ndarrayから条件を満たす要素を抽出する【Python】

ndarrayから条件を満たす要素を抽出する方法には、ブールインデックスを用いる方法とwhere関数を用いる方法の2通りがあります。ここではその方法を見ていきましょう。

開発環境

  • numpy 1.19.2
  • Python 3.7.9
目次

ブールインデックス参照を用いて抽出する

ブールインデックス参照とは?

ndarrayでは、同じ形状のndarrayでTrue / Falseのbool値のみを要素にしたndarrayを添え字に指定してやると、Trueの要素のみを抽出することができます。これをブールインデックス参照と呼び、ndarray以外にもpanda.DataFrameなどでも用いられ、Pythonで広く採用されている方法です。

ブールインデックス参照を用いた条件の指定方法

それでは、このブールインデックス参照を用いて1次元のndarray[1, 3, 2, 4, 7, 5, 6, 10, 8, 9]から5以上の要素を抽出する方法を見ていきましょう。ndarrayを条件式を当てはめると、その条件式がブロードキャストされて各要素ごとに条件を満たすか判定して、True/Falseから成る配列を得ることができます。

import numpy as np

a = np.array([1, 3, 2, 4, 7, 5, 6, 10, 8, 9])
print(a >= 5)
[False False False False  True  True  True  True  True  True]    

この結果をもとのndarrayに当てはめることでブールインデックス参照を行うことができ、Trueに対応する要素のみを抽出して1次元配列で返してくれます。

import numpy as np

a = np.array([1, 3, 2, 4, 7, 5, 6, 10, 8, 9])
print(a[[False, False, False, False, True, True, True, True, True, True]])
[ 7  5  6 10  8  9]

以上の手順をまとめると、ndarrayにおけるブールインデックス参照は次のようになります。

import numpy as np

a = np.array([1, 3, 2, 4, 7, 5, 6, 10, 8, 9])
print(a[a >= 5])
[ 7  5  6 10  8  9]

なお元の配列の形状がどのようなものであっても、この方法で得られる抽出結果は1次元配列になることに注意してください。

import numpy as np

b = np.array([[1, 2, 6, 4], [3, 8, 7, 6]])
print(b[b >= 5])
[6 8 7 6]

where関数を用いて抽出する

条件を満たす要素の抽出はwhere関数を用いても行うことができます。where関数に条件式を指定すると、その条件を満たす要素のインデックスを取得することができます。

import numpy as np

a = np.array([1, 3, 2, 4, 7, 5, 6, 10, 8, 9])
print(np.where(a >= 5))
(array([4, 5, 6, 7, 8, 9], dtype=int64),)

なお、本来where関数に指定できるのはTrue/Falseから成る配列(ブールインデックス)ですが、ここに条件式を指定することでそれぞれの要素に条件式を当てはめてTrue/Falseから成る配列を返してくれるので、条件式をそのまま指定することが可能になっています。

さらにここで取得した条件を満たす要素のインデックスをもとのndarrayに指定してやれば、その要素を取得することができます。

import numpy as np

a = np.array([1, 3, 2, 4, 7, 5, 6, 10, 8, 9])
print(a[np.where(a >= 5)])
[ 7  5  6 10  8  9]

ただし、where関数を用いた方法よりもブールインデックス参照を用いた方がシンプルに実装することができるので、単純に条件を満たす要素を抽出するだけならwhere関数を用いるメリットはあまりないかも知れません。where関数を用いるのは、条件を満たす要素を取得する場合よりも、むしろその要素のインデックスを取得するような場合などです。また、where関数を用いれば、条件を満たすか満たさないかによって値を書き換えることもできます。

スポンサーリンク

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

コメント

コメントする

nine + ten =

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

目次