2016-01-26 10 views
16

'dan tüm değişkenleri nasıl alabilirim tf.initialize_all_variables()'u kullanan ana başlatmadan sonra LSTM'yi başlatmam gereken bir kurulum var. Yani Ben SADECE başlatabilmesi Tensorflow: rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell

yüzden

  • rnn_cell.MultiRNNCell

    • rnn_cell.BasicLSTM: Ben tf.initialize_variables([var_list])

      çağırmak için her ikisi için tüm iç eğitilebilir değişkenleri toplamak için yolu var mı istiyorum bu parametreler?

      Bunun temel sebebi, daha önce bazı eğitimli değerleri yeniden başlatmak istemediğimden kaynaklanmaktadır.

  • cevap

    17

    Sorununuzu çözmenin en kolay yolu, değişken kapsamı kullanmaktır. Bir kapsamdaki değişkenlerin isimleri, ismi ile önekine eklenir.

    cell = rnn_cell.BasicLSTMCell(num_nodes) 
    
    with tf.variable_scope("LSTM") as vs: 
        # Execute the LSTM cell here in any way, for example: 
        for i in range(num_steps): 
        output[i], state = cell(input_data[i], state) 
    
        # Retrieve just the LSTM variables. 
        lstm_variables = [v for v in tf.all_variables() 
            if v.name.startswith(vs.name)] 
    
    # [..] 
    # Initialize the LSTM variables. 
    tf.initialize_variables(lstm_variables) 
    

    O MultiRNNCell ile aynı şekilde çalışır olacaktır: Burada kısa bir snippet'tir.

    DÜZENLEME: Ayrıca tf.get_collection() kullanabilirsiniz tf.all_variables()

    +0

    Bu mükemmel, teşekkürler. Tf.trainable_variables() 'nin kapsamına saygı duyduğunun farkında değildim, ama sanırım geçmişte bu mantıklı! – bge0

    +1

    "tf.trainable_variables()" yerine tf.all_variables() 'yi eklemek daha iyi bir seçim olacaktır. Temel olarak, eğitilebilir değişkenlere sahip olmayan, ancak yine de başlatılması gereken optimizer gibi şeyler olduğu için. – bge0

    +1

    Teşekkürler, haklısınız. Kodu güncelledim. –

    11

    için tf.trainable_variables değiştirdi:

    cell = rnn_cell.BasicLSTMCell(num_nodes) 
    with tf.variable_scope("LSTM") as vs: 
        # Execute the LSTM cell here in any way, for example: 
        for i in range(num_steps): 
        output[i], state = cell(input_data[i], state) 
    
        lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name) 
    

    son satırı liste anlamada eşdeğer olduğunu

    Not (kısmen Rafal cevabı kopyalanmış) Rafal'ın kodu.

    Temel olarak, tensorflow, tf.all_variables() veya tf.get_collection(tf.GraphKeys.VARIABLES) tarafından getirilebilen genel bir değişkenler koleksiyonunu depolar. tf.get_collection() işlevinde scope (kapsam adı) belirtirseniz, kapsamları belirtilen kapsamın altında olan koleksiyonda yalnızca tensörleri (bu durumda değişkenler) getireceksiniz.

    EDIT: Yalnızca eğitilebilir değişkenler almak için tf.GraphKeys.TRAINABLE_VARIABLES'u kullanabilirsiniz. Fakat vanilya BasicLSTMCell herhangi bir eğitilebilir olmayan değişkeni başlatmadığından, her ikisi de işlevsel olarak eşdeğer olacaktır. Varsayılan grafik koleksiyonlarının tam listesi için, this'u kontrol edin.

    +0

    'daki tüm diğer değişkenlerle tutarlı olması için farklı şekilde ayarlanmasını diliyorum. Bu, Rafal'ın çözümünden daha iyi bir yoldur :-) –

    +1

    Yukarıdaki gibi, belki de daha iyi kullanmalıyım tf.get_collection (..., scope = vs.name + "/") çünkü "LSTM2" adlı başka bir alan da olabilir. – Albert