not edin. ~ 3 GHz CPU'lu modern bir bilgisayarda, tüm hesaplamaların SIMD/paralel hızlanma olmadığı varsayılarak tamamlanması yaklaşık 60 saniye sürmesi bekleniyor.
#!/usr/bin/env python3
import numpy
import time
numpy.random.seed(0)
n = 120
N = 1000
X = numpy.random.random((N, n))
start_time = time.time()
M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)
end_time = time.time()
print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)
Bu yaklaşık 8 saniye sürer:
test için biz doğruluğunu ve performansını denetlemek için kullanır en N = 1000 ile başlayalım. Bu başlangıç noktasıdır.
en alternatif olarak 3 iç içe döngü ile başlayalım:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds
Bu kabaca yarım dakika sürer, hiç iyi değil! Bunun bir nedeni, aslında dört iç içe geçmiş döngü olmasıdır: numpy.sum
da bir döngü olarak düşünülebilir.
Biz toplamı bu 4 döngü kaldırmak için bir nokta ürüne dönüştürülebilir unutmayın:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds
Şimdi çok daha iyi ama hala yavaş.
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds
Huh: Ama biz nokta ürünü bir döngü kaldırmak için bir matris çarpma içine değiştirilebilir unutmayın? Şimdi bu, einsum
'dan bile çok daha verimli! Cevabın gerçekten doğru olup olmadığını da kontrol edebiliriz.
daha ileri gidebilir miyiz? Evet! Biz tarafından k
döngü ortadan kaldırabilir:
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = numpy.repeat(X[:,j], n).reshape((N, n))
M3[j] = (Y * X).T @ X
# ~0.3 seconds
Biz de kaçınmakla (X her satır için yani a * [b,c] == [a*b, a*c]
) yayını kullanabilirsiniz numpy.repeat
(teşekkürler @Divakar):
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = X[:,j].reshape((N, 1))
## or, equivalently:
# Y = X[:, numpy.newaxis, j]
M3[j] = (Y * X).T @ X
# ~0.16 seconds
biz ölçeklerseniz Bu N = 100000 program (ama anlamak kod gerçekten zor hale getirebilir) öylesine j
çok fazla yardımcı olmayabilir ortadan kaldırarak, teorik sınırı içinde olduğu 16 saniye, sürmesi bekleniyor. Bunu nihai çözüm olarak kabul edebiliriz.
Not: Python 2 kullanıyorsanız, a @ b
a.dot(b)
eşdeğerdir.
Harika cevap, teşekkürler! –
Gerçekten harika bir fikir. Burada biraz yayın ekleyebilirsem, “Y” yi oluşturmaktan ve doğrudan yinelemeli çıktıdan kaçınabiliriz: '(X [:, Yok, j] * X) .T @ X'. Bu bize biraz daha fazla performans artışı sağlamalı. – Divakar
@Divakar: Teşekkürler! Güncellenmiş. – kennytm