blog

リア向け。

Python(Numpy)での多重ループの書き方

Pythonで大量のデータを使おうとすると遅い。
Numpyを使ってもforで回してたら意味ない。遅い。
numpyのmeshgridを書けばfor文いらない。速い。

試しに3重ループになる行列積を書いてみる。

#!/usr/bin/python2
#vim:fileencoding=utf8

import numpy as np

def main():
    arr34 = np.random.randint(0,100,12).reshape((3,4))
    arr45 = np.random.randint(0,100,20).reshape((4,5))

    # 行列の掛け算(行列の場合のみ使えるが一般的な多重ループには無理)
    ref1 =  np.mat(arr34)*np.mat(arr45)

    # 典型的なダメな例。forループを使う。
    ref2 = np.zeros_like(ref1)
    for i in range(arr34.shape[0]):
        for j in range(arr45.shape[1]):
            for k in range(arr45.shape[0]):
                ref2[i,j] += arr34[i,k]*arr45[k,j]

    # これもダメな例。forループを使う。
    from itertools import product
    ref3 = np.zeros_like(ref1)
    for i,j in product(range(arr34.shape[0]),range(arr45.shape[1])):
        ref3[i,j] = np.sum(arr34[i,:]*arr45[:,j])

    # 多重ループの書き方としてはこれがベストか?ただし状況によってはメモリの使用量が心配
    idx345 = np.meshgrid(np.arange(3),np.arange(4),np.arange(5))
    tmp = np.sum(arr34[idx345[0],idx345[1]]*arr45[idx345[1],idx345[2]],axis=0)

    assert np.all(ref1==ref2)
    assert np.all(ref1==ref3)
    assert np.all(tmp==ref1)

if __name__ == "__main__":
    main()