2013-07-08 17 views
5

Numpy'yi kullanan bir Python işlevini nasıl hızlandırmaya çalışıyorum. lineprofiler'dan aldığım çıktı aşağıdadır ve bu, zamanın büyük çoğunluğunun ind_y, ind_x = np.where(seg_image == i) hattında harcandığını gösterir.Tam sayı segmentlerini ayıklamak için numpy.where hızlandırın mı?

seg_image, bir görüntünün bölümlendirilmesinin sonucu olan bir tamsayı dizisidir, böylece seg_image == i belirli bir bölümlenmiş nesneyi ayıklar. Ben bu nesnelerin birçoğunun içinden geçiyorum (aşağıdaki kodda test için sadece 5'e geçiyorum, ama aslında 20.000'in üzerinde döngü yapacağım) ve koşmak için çok uzun zaman harcıyor!

np.where çağrısının hızlandırılabileceği herhangi bir yol var mı? Ya da alternatif olarak, sondan bir önceki satırın (ki bu da zamanın iyi bir oranını da alır) hızlandırılabilir mi?

İdeal çözüm, kod dizisini döngüden ziyade tüm dizide çalıştırmak olabilir, ancak çalıştırmam gereken bazı işlevlerin yan etkileri olduğu için bunun mümkün olduğunu düşünmüyorum (Örneğin, bölümlenmiş bir nesnenin genişletilmesi, bir sonraki bölge ile 'çarpışmasını' ve böylece daha sonra yanlış sonuçlar vermesini sağlayabilir.

Herhangi bir fikri olan var mı? Eğer aynı zaman biraz yapabileceğini

Line #  Hits   Time Per Hit % Time Line Contents 
============================================================== 
    5           def correct_hot(hot_image, seg_image): 
    6   1  239810 239810.0  2.3  new_hot = hot_image.copy() 
    7   1  572966 572966.0  5.5  sign = np.zeros_like(hot_image) + 1 
    8   1  67565 67565.0  0.6  sign[:,:] = 1 
    9   1  1257867 1257867.0  12.1  sign[hot_image > 0] = -1 
    10           
    11   1   150 150.0  0.0  s_elem = np.ones((3, 3)) 
    12           
    13            #for i in xrange(1,seg_image.max()+1): 
    14   6   57  9.5  0.0  for i in range(1,6): 
    15   5  6092775 1218555.0  58.5   ind_y, ind_x = np.where(seg_image == i) 
    16           
    17             # Get the average HOT value of the object (really simple!) 
    18   5   2408 481.6  0.0   obj_avg = hot_image[ind_y, ind_x].mean() 
    19           
    20   5   333  66.6  0.0   miny = np.min(ind_y) 
    21             
    22   5   162  32.4  0.0   minx = np.min(ind_x) 
    23             
    24           
    25   5   369  73.8  0.0   new_ind_x = ind_x - minx + 3 
    26   5   113  22.6  0.0   new_ind_y = ind_y - miny + 3 
    27           
    28   5   211  42.2  0.0   maxy = np.max(new_ind_y) 
    29   5   143  28.6  0.0   maxx = np.max(new_ind_x) 
    30           
    31             # 7 is + 1 to deal with the zero-based indexing, + 2 * 3 to deal with the 3 cell padding above 
    32   5   217  43.4  0.0   obj = np.zeros((maxy+7, maxx+7)) 
    33           
    34   5   158  31.6  0.0   obj[new_ind_y, new_ind_x] = 1 
    35           
    36   5   2482 496.4  0.0   dilated = ndimage.binary_dilation(obj, s_elem) 
    37   5   1370 274.0  0.0   border = mahotas.borders(dilated) 
    38           
    39   5   122  24.4  0.0   border = np.logical_and(border, dilated) 
    40           
    41   5   355  71.0  0.0   border_ind_y, border_ind_x = np.where(border == 1) 
    42   5   136  27.2  0.0   border_ind_y = border_ind_y + miny - 3 
    43   5   123  24.6  0.0   border_ind_x = border_ind_x + minx - 3 
    44           
    45   5   645 129.0  0.0   border_avg = hot_image[border_ind_y, border_ind_x].mean() 
    46           
    47   5  2167729 433545.8  20.8   new_hot[seg_image == i] = (new_hot[ind_y, ind_x] + (sign[ind_y, ind_x] * np.abs(obj_avg - border_avg))) 
    48   5  10179 2035.8  0.1   print obj_avg, border_avg 
    49           
    50   1   4  4.0  0.0  return new_hot 

cevap

4

DÜZENLEME Ben kayıt için altta özgün cevabım bırakmış, ama aslında öğle yemeğinde daha ayrıntılı şekilde koduna baktım ve ben np.where kullanarak büyük bir hata olduğunu düşünüyorum:

In [63]: a = np.random.randint(100, size=(1000, 1000)) 

In [64]: %timeit a == 42 
1000 loops, best of 3: 950 us per loop 

In [65]: %timeit np.where(a == 42) 
100 loops, best of 3: 7.55 ms per loop 

Puanların gerçek koordinatlarını almak için gereken sürenin 1/8'inde boole dizisini (indeksleme için kullanabilirsiniz) alabilirsiniz!

yapmanız özelliklerin kırpma elbette vardır

ancak ndimage çevreleyen dilimleri döndüren bir find_objects işlevi vardır ve çok hızlı olarak görünmektedir: Bu dilim dizilerini listesini döndürür

In [66]: %timeit ndimage.find_objects(a) 
100 loops, best of 3: 11.5 ms per loop 

, tek bir nesnesinin endekslerini bulmak için, nesnelerinizin tümünün 0% 30 daha fazla sürede eklenmesi.

Şu anda test edemez gibi kutunun dışında çalışmayabilir, ama şu gibi bir şey kodunuzu yeniden yapılandıracak:

def correct_hot_bis(hot_image, seg_image): 
    # Need this to not index out of bounds when computing border_avg 
    hot_image_padded = np.pad(hot_image, 3, mode='constant', 
           constant_values=0) 
    new_hot = hot_image.copy() 
    sign = np.ones_like(hot_image, dtype=np.int8) 
    sign[hot_image > 0] = -1 
    s_elem = np.ones((3, 3)) 

    for j, slice_ in enumerate(ndimage.find_objects(seg_image)): 
     hot_image_view = hot_image[slice_] 
     seg_image_view = seg_image[slice_] 
     new_shape = tuple(dim+6 for dim in hot_image_view.shape) 
     new_slice = tuple(slice(dim.start, 
           dim.stop+6, 
           None) for dim in slice_) 
     indices = seg_image_view == j+1 

     obj_avg = hot_image_view[indices].mean() 

     obj = np.zeros(new_shape) 
     obj[3:-3, 3:-3][indices] = True 

     dilated = ndimage.binary_dilation(obj, s_elem) 
     border = mahotas.borders(dilated) 
     border &= dilated 

     border_avg = hot_image_padded[new_slice][border == 1].mean() 

     new_hot[slice_][indices] += (sign[slice_][indices] * 
            np.abs(obj_avg - border_avg)) 

    return new_hot 

Hala anlamaya gerekir çarpışmalar, ancak eş zamanlı olarak bir np.unique dayalı bir yaklaşım kullanarak tüm indekslerinin bir 2x hızlanmasından yaklaşık alabilir:

a = np.random.randint(100, size=(1000, 1000)) 

def get_pos(arr): 
    pos = [] 
    for j in xrange(100): 
     pos.append(np.where(arr == j)) 
    return pos 

def get_pos_bis(arr): 
    unq, flat_idx = np.unique(arr, return_inverse=True) 
    pos = np.argsort(flat_idx) 
    counts = np.bincount(flat_idx) 
    cum_counts = np.cumsum(counts) 
    multi_dim_idx = np.unravel_index(pos, arr.shape) 
    return zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx)) 

In [33]: %timeit get_pos(a) 
1 loops, best of 3: 766 ms per loop 

In [34]: %timeit get_pos_bis(a) 
1 loops, best of 3: 388 ms per loop 

Not her için piksel o bject farklı bir sırayla döndürülür, böylece eşitliği değerlendirmek için her iki işlevin de getirilerini kıyaslayamazsınız. Ama ikisi de aynı şeyi geri vermeli.

+0

Bu harika, harika ve şaşırtıcı - teşekkürler! İlk koştuğumda aslında orijinal kodumdan daha yavaş olduğunu buldum, ancak kodunuzun bir kısmını değiştirdim, böylece tüm işleri (genişleme, sınırlar vb.) Büyük bir dizi yerine küçük bir dizide yaptı. new_shape'un nasıl hesaplandığını değiştirerek. Şimdi hızda büyük bir artış yaşadım. Çalıştığım resimlerden birinde, eski versiyon iki buçuk saat sürdü, yenisi 11 saniye sürdü! – robintw

+0

Oops! Evet, jenerasyon ifadesinin 'new_shape = tuple (loş_genim_view.shape'de loş için dim + 6)', 'new_shape = tuple (hot_image.shape'de loş için loş + 6) 'değil gibi görünüyor. Bu değişti mi? Lütfen, cevabımı çalışma kodunu yansıtacak şekilde düzenlemek için çekinmeyin. – Jaime

2

Bir şey onu iki kere hesaplamak gerekmez böylece seg_image == i sonucunu kurtarmaktır. 15 & 47 numaralı hatlarda hesaplıyorsunuz, seg_mask = seg_image == i ekleyebilir ve daha sonra bu sonucu yeniden kullanabilirsiniz (Bu parçanın profil oluşturma amacıyla ayrılması da iyi olabilir).

Biraz performans çıkarmak için yapabileceğiniz başka bazı küçük şeyler olsa da, temel sorun, M'nin parça sayısı ve N olduğu bir O (M * N) algoritması kullandığınızdır. görüntünün boyutu. Aynı şeyi başarmak için daha hızlı bir algoritma olup olmadığını kodunuzdan açıkça görmüyorum, ama denemek ve hızlandırmak için aradığım ilk yer.