2015-12-18 12 views
6

bunun aşağıdaki dosya .pb için, basit bir kod formu öğretici ve çıkış var:Bir TensorFlow grafiğini C++'daki bir protobuftan nasıl çalıştırabilirim?

mnist_softmax_train.py

x = tf.placeholder("float", shape=[None, 784], name='input_x') 
y_ = tf.placeholder("float", shape=[None, 10], name='input_y') 

W = tf.Variable(tf.zeros([784, 10]), name='W') 
b = tf.Variable(tf.zeros([10]), name='b') 
tf.initialize_all_variables().run() 
y = tf.nn.softmax(tf.matmul(x,W)+b, name='softmax') 

cross_entropy = -tf.reduce_sum(y_*tf.log(y)) 

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy, name='train_step') 
train_step.run(feed_dict={x:input_x, y_:input_y}) 
C

++ aynı grafik yük ve sahte verileri feed test için:

: Ancak

Tensor input_x(DT_FLOAT, TensorShape({10,784})); 
Tensor input_y(DT_FLOAT, TensorShape({10,10})); 
Tensor W(DT_FLOAT, TensorShape({784,10})); 
Tensor b(DT_FLOAT, TensorShape({10,10})); 
Tensor input_test_x(DT_FLOAT, TensorShape({1,784})); 

for(int i=0;i<10;i++){ 
    for(int j=0;j<10;j++) 
     input_x.matrix<float>()(i,i+j) = 1.0;  

    input_y.matrix<float>()(i,i) = 1.0; 
    input_test_x.matrix<float>()(0,i) = 1.0; 
} 

std::vector<std::pair<string, tensorflow::Tensor>> inputs = { 
    { "input_x", input_x }, 
    { "input_y", input_y }, 
    { "W", W }, 
    { "b", b }, 
    { "input_test_x", input_test_x }, 
}; 

std::vector<tensorflow::Tensor> outputs; 
status = session->Run(inputs, {}, {"train_step"}, &outputs); 

std::cout << outputs[0].DebugString() << "\n"; 

, bu hata ile başarısız

Grafik, Python'da doğru bir şekilde çalışıyor. C++ ile nasıl doğru bir şekilde çalıştırabilirim?

+1

Bunu çalıştırmayı denediğinizde hangi hatayı alırsınız? – mrry

+0

Üzgünüz, kodumu değiştirmiştim. Train_step grafiğini çalıştırırsam, "Geçersiz argüman: node train_step/update_W/ApplyGradientDescent öğesinin giriş 0'ı, beklenen float_ref ile uyumlu olmayan _recv_W_0: 0 süzgecinden geçirildi." –

+0

[C++ API'sı ile bir TensorFlow grafiği yükleme] (https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f) – nobar

cevap

5

Sorun, yalnızca çıkarımdan çok daha fazla iş gerçekleştiren "train_step" hedefini çalıştırıyor olmanızdır. Özellikle, degrade iniş aşamasının sonucuyla W ve b değişkenlerini güncellemeye çalışır. Hata mesajı

Invalid argument: Input 0 of node train_step/update_W/ApplyGradientDescent was passed float from _recv_W_0:0 incompatible with expected float_ref. 

... düğümlerden biri size ("train_step/update_W/ApplyGradientDescent") (tip float_ref ile) değişken girdi beklenen çalıştırmayı denedi ama değer çünkü (tip float ile) değişmez girişi var demektir . beslenen içinde

vardır iki olası çözümler (en azından): yalnızca belirli bir girdi için tahminleri ve verilen ağırlıklar bakınız "softmax:0" yerinegetirmek istiyorsanız

  1. Session::Run() numaralı çağrıya.

  2. C++ eğitim gerçekleştirmek istiyorsanız, W ve b besleme değil, ancak bunun yerine daha sonra bu değişkenlere değer atamak "train_step" yürütmeye devam etmektedir. Grafiği Python'da oluştururken tf.train.Saver oluşturmayı ve daha sonra bir denetim noktasından değerleri kaydetmek ve geri yüklemek için ürettiği işlemleri çağırmayı daha kolay bulabilirsiniz.

+0

Çok teşekkürler! Ama "W ve b beslemeyin" derken ne demek istiyorsun? C++ için girdi olarak W veya b için Tensor oluşturamıyorum? –

+0

Sorun, "W" ve "b" nin TensorFlow * değişkenleri * olması, onları beslemenin desteklenmediği anlamına gelir. Bunun yerine, değerlerini ayarlamak için, onlara bir değer atayan bir Assign op çalıştırmalısınız. Bunu yapmanın bir yolu, Wassign (tf.placeholder (tf.float32, name = "w_placeholder"), name = "w_assign") 'gibi grafiğinize (Python'da) bazı operasyonlar eklemek olabilir. '' session :: Run() '' 'w_assign '' kelimesini hedef olarak ve' 'w_placeholder: 0' 'olarak adlandırılacak ve tensörün besleneceği (ve benzer şekilde 'b' için). – mrry

+0

Tekrar teşekkürler, bu arada, C++ 'da daha sonra kullanabileceğim .pb dosyasında eğitim verilerini kaydetmenin doğru yöntemi nedir? Ben tf.train.write_graph() kullanıyorum ama protobuf'un eğitim verilerini tutmadığı görülüyor. –

İlgili konular