2016-07-15 11 views
6

numpy için yani n = 120 ve N = 100000 ve einsum döner aşağıdaki hata:Alternatifler Ben genellikle <code>einsum</code> kullanın einsum

ValueError: iterator is too large

3 iç içe döngüler yapmanın alternatif sınırların ötesine, yani herhangi bir alternatif var mı diye merak ediyorum. numpy GPU ya da bir şey erişimi sürece yavaş olacaktır bu nedenle bu hesaplama, en az ~ N × N = 173 milyar işlem (simetri dikkate almayan) için gerekli olacaktır

cevap

4

not edin. ~ 3 GHz CPU'lu modern bir bilgisayarda, tüm hesaplamaların SIMD/paralel hızlanma olmadığı varsayılarak tamamlanması yaklaşık 60 saniye sürmesi bekleniyor.

#!/usr/bin/env python3 

import numpy 
import time 

numpy.random.seed(0) 

n = 120 
N = 1000 
X = numpy.random.random((N, n)) 

start_time = time.time() 

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X) 

end_time = time.time() 

print('check:', M3[2,4,6], '= 125.401852515?') 
print('check:', M3[4,2,6], '= 125.401852515?') 
print('check:', M3[6,4,2], '= 125.401852515?') 
print('check:', numpy.sum(M3), '= 218028826.631?') 
print('total time =', end_time - start_time) 

Bu yaklaşık 8 saniye sürer:


test için biz doğruluğunu ve performansını denetlemek için kullanır en N = 1000 ile başlayalım. Bu başlangıç ​​noktasıdır.

en alternatif olarak 3 iç içe döngü ile başlayalım:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l]) 
# ~27 seconds 

Bu kabaca yarım dakika sürer, hiç iyi değil! Bunun bir nedeni, aslında dört iç içe geçmiş döngü olmasıdır: numpy.sum da bir döngü olarak düşünülebilir.

Biz toplamı bu 4 döngü kaldırmak için bir nokta ürüne dönüştürülebilir unutmayın:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l] 
# 14 seconds 

Şimdi çok daha iyi ama hala yavaş.

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     M3[j,k] = X[:,j] * X[:,k] @ X 
# ~0.5 seconds 

Huh: Ama biz nokta ürünü bir döngü kaldırmak için bir matris çarpma içine değiştirilebilir unutmayın? Şimdi bu, einsum'dan bile çok daha verimli! Cevabın gerçekten doğru olup olmadığını da kontrol edebiliriz.

daha ileri gidebilir miyiz? Evet! Biz tarafından k döngü ortadan kaldırabilir:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = numpy.repeat(X[:,j], n).reshape((N, n)) 
    M3[j] = (Y * X).T @ X 
# ~0.3 seconds 

Biz de kaçınmakla (X her satır için yani a * [b,c] == [a*b, a*c]) yayını kullanabilirsiniz numpy.repeat (teşekkürler @Divakar):

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = X[:,j].reshape((N, 1)) 
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j] 
    M3[j] = (Y * X).T @ X 
# ~0.16 seconds 

biz ölçeklerseniz Bu N = 100000 program (ama anlamak kod gerçekten zor hale getirebilir) öylesine j çok fazla yardımcı olmayabilir ortadan kaldırarak, teorik sınırı içinde olduğu 16 saniye, sürmesi bekleniyor. Bunu nihai çözüm olarak kabul edebiliriz.


Not: Python 2 kullanıyorsanız, a @ ba.dot(b) eşdeğerdir.

+0

Harika cevap, teşekkürler! –

+0

Gerçekten harika bir fikir. Burada biraz yayın ekleyebilirsem, “Y” yi oluşturmaktan ve doğrudan yinelemeli çıktıdan kaçınabiliriz: '(X [:, Yok, j] * X) .T @ X'. Bu bize biraz daha fazla performans artışı sağlamalı. – Divakar

+0

@Divakar: Teşekkürler! Güncellenmiş. – kennytm