scikit-learnにおけるデータセットの形式:Bunchクラスの扱い方【Python】

scikit-learnのサンプルデータはBunchクラスのオブジェクトとして取得できますが、このBunchクラスとは何なんでしょうか?ここではBunchクラスのオブジェクトからデータを取得する方法について説明していきます。

開発環境

  • Python 3.7.9
  • scikit-learn 0.23.2
目次

Bunchクラスとは

BunchクラスはPython標準ライブラリの辞書(dictクラス)を継承して定義されたクラスで、scikit-learnにおけるデータセットのパッケージに用いられます。dictクラスに属性参照の機能が追加されていることが特徴で、Bunchオブジェクトではそのキーを属性のように扱って値を取得することができます。

例えば、key01というキーの値を取得する場合は

# 辞書(sample_dict)の場合
value = sample_dict['key01']

# Bunchオブジェクト(sample_bunch)の場合
value = sample_bunch.key01

とすることができ、より簡単にその値を取得することができます。

実際のBunchインスタンスを見てみる

Bunchクラス自体は辞書に属性参照の機能を追加しただけのものですが、scikit-learnで機械学習を扱うにあたって自分で一からBunchインスタンスを作成するようなことはほぼなく、load_diabetes関数やload_iris関数などを用いてサンプルデータをBunchインスタンスとして取得するような場合がほとんどだと思われます。

サンプルデータとして取得されるBunchインスタンスのキーはインスタンスごとに異なりますが、以下の5つのキーに関しては多くのBunchインスタンスで共通して含まれています。また、load_****関数でas_frame=Falseの場合は、dataやtargetはndarrayとして取得されますが、Trueの場合はDataFrameとして取得されます。

キー説明値のデータ型
DESCRデータセットに関する詳細情報str
feature_names特徴量の項目名list / ndarray
data特徴量のデータセット (説明変数)ndarray
(as_frame=Trueの場合はDataFrame)
target目標値のデータセット (目的変数)ndarray
(as_frame=Trueの場合はDataFrame / Series)
framedata / targetを組み合わせたDataFrame
(as_frame=Trueの場合のみ)
DataFrame

データの本体はdata(説明変数)とtarget(目的変数)にそれぞれndarrayとして格納されています。また、説明変数の特徴量の項目はfeature_namesに格納されています。その他のキーはそのBunchオブジェクトごとに異なるので、keysメソッドを用いて確認してください。

例えばload_diabetes関数で得られる「糖尿病の診療についてのデータセット」のBunchインスタンスのキーを表示してみましょう。

from sklearn import datasets
diabetes = datasets.load_diabetes()
print(diabetes.keys())
dict_keys(['data', 'target', 'frame', 'DESCR', 'feature_names', 'data_filename', 'target_filename'])

「data」「target」「frame」「DESCR」「feature_names」に加えて「data_filename」「target_filename」の7つのキーが含まれていることが分かりました。特徴量の項目一覧は「feature_names」に格納されています。

from sklearn import datasets
diabetes = datasets.load_diabetes()
print(diabetes.feature_names)
['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']

これが特徴量の一覧ですが、これだけではそれぞれが何を表すのか全く分かりませんよね?そのようなときのために多くのBunchインスタンスではDESCRに格納されているデータについての詳しい説明が格納されています。

from sklearn import datasets
diabetes = datasets.load_diabetes()
print(diabetes.DESCR)
        ...(略)...

**Data Set Characteristics:**

  :Number of Instances: 442

  :Number of Attributes: First 10 columns are numeric predictive values

  :Target: Column 11 is a quantitative measure of disease progression one year after baseline

  :Attribute Information:
      - age     age in years
      - sex
      - bmi     body mass index
      - bp      average blood pressure
      - s1      tc, T-Cells (a type of white blood cells)
      - s2      ldl, low-density lipoproteins
      - s3      hdl, high-density lipoproteins
      - s4      tch, thyroid stimulating hormone
      - s5      ltg, lamotrigine
      - s6      glu, blood sugar level

Note: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times `n_samples` (i.e. the sum of squares of each column totals 1).

        ...(略)...

これでそのデータセットに何が格納されているかが分かります。

主なサンプルデータセット

scikit-learnで取得できる主なサンプルデータセットについては以下の記事にまとめているのでご覧ください。

スポンサーリンク

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

コメント

コメントする

one × three =

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

目次