2016-04-03 26 views
5

BxWxHxD boyutunda bir tensörüm var mı? Tensörü işlemek için her WxH dilimindeki sadece maksimum elemanın tutulduğu ve diğer tüm değerlerin sıfır olduğu yeni bir BxWxHxD tensörüne sahip olmak istiyorum. Başka bir deyişle, bunu başarmanın en iyi yolunun bir şekilde WxH dilimleri boyunca bir 2B argmax alması, daha sonra tek bir sıcak BxWxHxD tensörüne dönüştürülebilen satırlar ve sütunlar için BxD indeks tensörlerinin elde edilmesi olduğunu düşünüyorum. maske olarak kullanılmak üzere. Bu işi nasıl yaparım?Tensorflow çok boyutlu argmax

cevap

1

Aşağıdaki işlevi başlangıç ​​noktası olarak kullanabilirsiniz. Her parti ve her kanal için maksimum elemanın indekslerini hesaplar. Elde edilen dizi formattadır (parti boyutu, 2, kanal sayısı).

def argmax_2d(tensor): 

    # input format: BxHxWxD 
    assert rank(tensor) == 4 

    # flatten the Tensor along the height and width axes 
    flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3])) 

    # argmax of the flat tensor 
    argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32) 

    # convert indexes into 2D coordinates 
    argmax_x = argmax // tf.shape(tensor)[2] 
    argmax_y = argmax % tf.shape(tensor)[2] 

    # stack and return 2D coordinates 
    return tf.stack((argmax_x, argmax_y), axis=1) 

def rank(tensor): 

    # return the rank of a Tensor 
    return len(tensor.get_shape())