2016-04-04 20 views
2

Merhaba VGG modelini tensorflow'dan ayarlamak istiyorum. İki sorum var.Tensorflow modelinden ağırlık alın

Ağdan ağırlıklar nasıl alınır? Trainable_variables benim için boş liste döndürür.

Şu anki modeli kullandım: https://github.com/ry/tensorflow-vgg16. Yüklemeyi ağırlıkla ilgili buluyorum ancak import_graph_def nedeniyle bu benim için çalışmıyor. Get the value of some weights in a model trained by TensorFlow

import tensorflow as tf 
import PIL.Image 
import numpy as np 

with open("../vgg16.tfmodel", mode='rb') as f: 
    fileContent = f.read() 

graph_def = tf.GraphDef() 
graph_def.ParseFromString(fileContent) 

images = tf.placeholder("float", [None, 224, 224, 3]) 

tf.import_graph_def(graph_def, input_map={ "images": images }) 
print("graph loaded from disk") 

graph = tf.get_default_graph() 

cat = np.asarray(PIL.Image.open('../cat224.jpg')) 
print(cat.shape) 
init = tf.initialize_all_variables() 

with tf.Session(graph=graph) as sess: 
    print(tf.trainable_variables()) 
    sess.run(init) 
+0

Aynı anda birden fazla soru sormaktan kaçınmaya çalışın. Her ikisine de bir cevap bulamıyorsanız ve bunların her ikisi de değerli sorular olduğunu düşünüyorsanız, ikisini de ayrı yayınlarda sormaya çekinmeyin. – Giewev

cevap

4

Bu pretrained VGG-16 modeltf.constant() ops gibi model parametrelerinin hepsi kodlar. (Örneğin, tf.constant()here numaralı çağrılara bakın.) Sonuç olarak, model parametreleri tf.trainable_variables()'da görünmeyecek ve model önemli bir ameliyat olmadan değiştirilemez: sabit düğümleri tf.Variable ile başlayan nesnelerle değiştirmeniz gerekir. eğitime devam etmek için aynı değer. Genel olarak, yeniden bir grafik almak için içe aktarırken, tf.train.import_meta_graph() işlevi kullanılmalıdır, çünkü bu işlev ek meta verileri (değişkenlerin koleksiyonları dahil) yükler. tf.import_graph_def() işlevi alt düzeydedir ve bu koleksiyonları doldurmaz.