2016-03-28 33 views
2

Dizilerimi ArrayFire'da doyurmaya çalışıyorum. 0.75'ten büyük olan tüm değerlerin 1.0'a doyurulmasını ve 0.25'ten azının 0.0'a doyurulmasını istiyorum. Aşağıdaki ifadeleri kullanıyorum.Koşullu koşullu hata

a(a > 0.75) = 1.0; 
a(a < 0.25) = 0.0; 

İşte bir af :: dizi türüdür. Bir süre için çalışır, ancak 0.75'ten büyük değerler olmayan bir dizi alır almaz aşağıdaki özel durumu alırım. Ben af::print("", a > 0.75); ararsam

terminate called after throwing an instance of 'af::exception' 
    what(): ArrayFire Exception (Invalid input size:203): 
In function verifyDims 
In file src/api/c/data.cpp:36 
Invalid dimension for argument 1 
Expected: ndims >= 1 

In function af::array af::constant(T, const af::dim4&, af::dtype) [with T = double; af::dtype = af_dtype] 
In file src/api/cpp/data.cpp:28 

Ben çöküyor hemen önce şu çıktıyı almak.

[10 1 1 1] 
     0 
     0 
     0 
     0 
     0 
     0 
     0 
     0 
     0 
     0 

bir şekilde bu dizi tüm sıfırlar olduğunu görecek mi ve sonra boyut sıfır olduğunu söyleyerek (non 0,75'den büyük olması nedeniyle bu olmalıdır)? Bu yanlış yaptığım bir şey mi yoksa kodlarında bir hata mı?

Aşağıdaki kod bunu düzeltiyor gibi görünüyor, ancak bu çözümün biraz verimsiz olduğunu hissediyorum. Ben bir sinir ağına dik iniş yapıyorum tüm işlevini görmek istiyorum olanlarınız için

af::array bellow = a[levels - 1] < 0.25f; 
af::array above = a[levels - 1] > 0.75f; 

if(af::anyTrue<bool>(above)) 
    a[levels - 1](above) = 0.75f; 

if(af::anyTrue<bool>(bellow)) 
    a[levels - 1](bellow) = 0.25f; 

. a aslında ve tür dizisi dizidir. Soruyu basitleştirmek için bunu bıraktım.

void train(const float* in, const float* expected_out, float learning_rate) 
{ 
    std::unique_ptr<af::array[]> a(new af::array[levels]), 
      z(new af::array[levels]), d(new af::array[levels]); 

    af::array in_array(inputs, in); 
    af::array y(dims[levels - 1], expected_out); 

    z[0] = af::matmul(weights[0], in_array) + biases[0]; 
    a[0] = sigma(z[0]); 


    for(size_t i = 1; i < levels; i++) 
    { 
     z[i] = af::matmul(weights[i], a[i - 1]) + biases[i]; 
     a[i] = sigma(z[i]); 
    } 


    a[levels - 1](a[levels - 1] < 0.25f) = 0.0f; 
    a[levels - 1](a[levels - 1] > 0.75f) = 1.0f; 

    d[levels - 1] = (y - a[levels - 1]) * sigma_prime(z[levels - 1]); 
    for(size_t i = levels - 1; i-- > 0;) 
     d[i] = af::matmul(weights[i + 1].T(), d[i + 1]) * sigma_prime(z[i]); 

    for(size_t i = 0; i < levels; i++) 
    { 
     biases[i] += learning_rate * d[i]; 
     weights[i] += learning_rate * af::matmul(d[i], (i ? a[i - 1] : in_array).T()); 
    } 
} 
+0

Lütfen hata veren kod parçasını ekleyin. – hyde

+0

İlk kod parçası yaptım. – chasep255

+0

'a (a> 0.75) = 1.0; a (a <0.25) = 0.0; ' – chasep255

cevap

4

görmekte olduğunuz hata nedeniyle bu open bug about zero length arrays taşımaktadır (EDIT: v3.4.0 itibaren sabit). Bu, bir süredir düzgün bir şekilde düzeltmeye çalıştığımız yaygın bir sorundur.

İşte işiniz için işiniz bitti. Yapmaya çalıştığınız şeyi elde etmek için indekslemeye bile ihtiyacınız yok.

a[levels - 1] = af::min(0.75, af::max(0.25, a[levels - 1])); 

DÜZENLEME: 3.4 itibaren, arrayfire aynı işlevselliği elde etmek için aşağıdakileri yapabilirsiniz:

a[levels - 1] = af::clamp(a[levels - 1], 0.25, 0.75); 

Bu yöntem çok daha hızlı dava için indeksleme fazla. sözü


, sen endeksleme yerine af::min ve af::max kullanamaz bazı durumlar vardır. Bu gibi durumlarda, size geçici bir çözüm olarak böyle bir şey yapabileceğini:

af::array cond = arr < some_val; 
arr = arr * (1 - cond) + cond * other_val; 

Bu aynı zamanda endeksleme daha hızlı olmalıdır. Ancak, dizilerde NAN varsa ve bunları değiştirmeye çalışıyorsanız, aritmetik çalışmaz. Bu durumda aşağıdaki işlevlerden birine geri dönebilirsiniz.select (ek bellek kullanır) kullanarak

:

arr = af::select(af::isNaN(arr), arr, other_val)); 

yerine kullanma (yerinde değiştirir, hiçbir ek bellek kullanılır):

af::replace(arr, af::isNaN(arr) other_val)); 

Ancak bazı kıyaslama select ve replace olabileceğini gösterdi Bazı durumlarda indekslemeden daha yavaş (düzeltmeye çalışıyoruz). Yani, select/replace algoritmanızda yavaşlarsa, indeksleme için aşağıdaki çalışmayı kullanmayı deneyebilirsiniz.

af::array idx = af::where(af::isNaN(arr)); 
if (idx.elements()) arr(idx) = replace_val; 

Not içten af::where çağıran af::array bir mantıksal o indeksleme. Yani bu Aşağıdakilerden sıfır boyutlu diziler için başarısız değil yararı ile

arr(arr < some_val) = other_val; 

kadar etkilidir.

DÜZENLEME: Gelecek için ek geçici çözümler eklenmiştir.