2013-09-26 15 views
5

Belirli bir koşulu karşılamayan satırların yerini tutam matrisler için sıfırlarla değiştirmenin en iyi yolunun ne olduğunu merak ediyorum. Örneğin, (I düz gösterim için diziler kullanılır):Belirli koşullara uyan scipy.sparse matris satırlarını sıfırlar olarak ayarlayın

a = np.array([[0,0,0,1,1], 
       [1,2,0,0,0], 
       [6,7,4,1,0], # sum > 10 
       [0,1,1,0,1], 
       [7,3,2,2,8], # sum > 10 
       [0,1,0,1,2]]) 

I değiştirmek isteyen bir miktarı toplamı sıfır bir sıra ile daha büyük 10 her satır değiştirmek isteyen

bir [2] ve [ 4] sıfırlarla, yani benim çıkış aşağıdaki gibi görünmelidir:

row_sum = a.sum(axis=1) 
to_keep = row_sum >= 10 
a[to_keep] = np.zeros(a.shape[1]) 
:

array([[0, 0, 0, 1, 1], 
     [1, 2, 0, 0, 0], 
     [0, 0, 0, 0, 0], 
     [0, 1, 1, 0, 1], 
     [0, 0, 0, 0, 0], 
     [0, 1, 0, 1, 2]]) 

Bu yoğun matrisler için yalındır oldukça olduğunu Ancak 0

, ben çalıştığınızda:

s = sparse.csr_matrix(a) 
s[to_keep, :] = np.zeros(a.shape[1]) 

bu hatayı alıyorum:

raise NotImplementedError("Fancy indexing in assignment not " 
NotImplementedError: Fancy indexing in assignment not supported for csr matrices. 

Dolayısıyla, ben seyrek matrisler için farklı bir çözüm gerekir. Ben bu geldi:

Bu bizim sıfıra kimlik matrisi içinde diyagonal 2. ve 4. unsurları kurarsanız, önceden çoğaltılmış matrisin satırları sıfıra ayarlanır gerçeğine dayanır
def zero_out_unfit_rows(s_mat, limit_row_sum): 
    row_sum = s_mat.sum(axis=1).T.A[0] 
    to_keep = row_sum <= limit_row_sum 
    to_keep = to_keep.astype('int8') 
    temp_diag = get_sparse_diag_mat(to_keep) 
    return temp_diag * s_mat 

def get_sparse_diag_mat(my_diag): 
    N = len(my_diag) 
    my_diags = my_diag[np.newaxis, :] 
    return sparse.dia_matrix((my_diags, [0]), shape=(N,N)) 

. Bununla birlikte, daha iyi, daha scipinic bir çözüm olduğunu hissediyorum. Daha iyi bir çözüm var mı?

cevap

4

Çok fazla scithonic olup olmadığından emin değilsiniz, ancak seyrek matrislerdeki pek çok işlem doğrudan guts erişerek daha iyi yapılır. Davanız için kişisel olarak şunu yapabilirim:

a = np.array([[0,0,0,1,1], 
       [1,2,0,0,0], 
       [6,7,4,1,0], # sum > 10 
       [0,1,1,0,1], 
       [7,3,2,2,8], # sum > 10 
       [0,1,0,1,2]]) 
sps_a = sps.csr_matrix(a) 

# get sum of each row: 
row_sum = np.add.reduceat(sps_a.data, sps_a.indptr[:-1]) 

# set values to zero 
row_mask = row_sum > 10 
nnz_per_row = np.diff(sps_a.indptr) 
sps_a.data[np.repeat(row_mask, nnz_per_row)] = 0 
# ask scipy.sparse to remove the zeroed entries 
sps_a.eliminate_zeros() 

>>> sps_a.toarray() 
array([[0, 0, 0, 1, 1], 
     [1, 2, 0, 0, 0], 
     [0, 0, 0, 0, 0], 
     [0, 1, 1, 0, 1], 
     [0, 0, 0, 0, 0], 
     [0, 1, 0, 1, 2]]) 
>>> sps_a.nnz # it does remove the entries, not simply set them to zero 
10 
İlgili konular