個別のノートブックを用意していないデータセットを整える段取りを記録している。
20世紀前半からある古くて小さなデータセットであるが、線形モデルでは完全に識別できないことはわかっているので、多種の学習機のプロトタイプを手軽に試すデータセットとして重宝されてきた。学習課題はアヤメの花の品種を、花びらの長さなどの指標で識別することである。
import os
import csv
import numpy as np
! cat data/iris/iris.data | head -n 5
! cat data/iris/iris.data | tail -n 5
5.1,3.5,1.4,0.2,Iris-setosa 4.9,3.0,1.4,0.2,Iris-setosa 4.7,3.2,1.3,0.2,Iris-setosa 4.6,3.1,1.5,0.2,Iris-setosa 5.0,3.6,1.4,0.2,Iris-setosa 6.3,2.5,5.0,1.9,Iris-virginica 6.5,3.0,5.2,2.0,Iris-virginica 6.2,3.4,5.4,2.3,Iris-virginica 5.9,3.0,5.1,1.8,Iris-virginica
ここで、先頭の5行はデータであるのに対して、末尾の5行はデータのない空の行が一つある。__このからの行を削除し__、新たにiris_rev.data
として保存しておくこと。修正を加えたファイルの中身を覗いてみると以下のとおりである。
! cat data/iris/iris_rev.data | head -n 5
! cat data/iris/iris_rev.data | tail -n 5
5.1,3.5,1.4,0.2,Iris-setosa 4.9,3.0,1.4,0.2,Iris-setosa 4.7,3.2,1.3,0.2,Iris-setosa 4.6,3.1,1.5,0.2,Iris-setosa 5.0,3.6,1.4,0.2,Iris-setosa 6.7,3.0,5.2,2.3,Iris-virginica 6.3,2.5,5.0,1.9,Iris-virginica 6.5,3.0,5.2,2.0,Iris-virginica 6.2,3.4,5.4,2.3,Iris-virginica 5.9,3.0,5.1,1.8,Iris-virginica
! wc -l data/iris/iris_rev.data
150 data/iris/iris_rev.data
データがきちんと150点あることは上記からわかったので、安心して進むことができる。
NUM_DATA = 150
NUM_TRAIN = 100 # Set manually.
NUM_TEST = NUM_DATA - NUM_TRAIN
NUM_FEATURES = 4
NUM_CLASSES = 3
NUM_LABELS = 1
LABEL_DICT = {"Iris-setosa": 0,
"Iris-versicolor": 1,
"Iris-virginica": 2}
toread = os.path.join("data", "iris", "iris_rev.data")
data_X = np.zeros((NUM_DATA,NUM_FEATURES), dtype=np.float32)
data_y = np.zeros((NUM_DATA,1), dtype=np.int8)
with open(toread, newline="") as f_table:
f_reader = csv.reader(f_table, delimiter=",")
i = 0
for line in f_reader:
data_X[i,:] = np.array(line[0:-1], dtype=data_X.dtype)
data_y[i,:] = np.array(LABEL_DICT[line[-1]], dtype=data_y.dtype)
i += 1
訓練データを読み込んだのだが、検証データとともに一つの階層型ファイルにまとめるために、__PyTables__というパッケージを利用する。
import tables
# Open file connection, writing new file to disk.
myh5 = tables.open_file("data/iris/data.h5",
mode="w",
title="Iris data")
print(myh5) # currently empty.
data/iris/data.h5 (File) 'Iris data' Last modif.: 'Tue Aug 28 15:26:37 2018' Object Tree: / (RootGroup) 'Iris data'
myh5.create_group(myh5.root, "train", "Training data")
myh5.create_group(myh5.root, "test", "Testing data")
print(myh5)
data/iris/data.h5 (File) 'Iris data' Last modif.: 'Tue Aug 28 15:26:37 2018' Object Tree: / (RootGroup) 'Iris data' /test (Group) 'Testing data' /train (Group) 'Training data'
# Training data arrays.
a = tables.Int8Atom()
myh5.create_earray(myh5.root.train,
name="labels",
atom=a,
shape=(0,NUM_LABELS),
title="Label values")
a = tables.Float32Atom()
myh5.create_earray(myh5.root.train,
name="inputs",
atom=a,
shape=(0,NUM_FEATURES),
title="Input images")
# Testing data arrays.
a = tables.Int8Atom()
myh5.create_earray(myh5.root.test,
name="labels",
atom=a,
shape=(0,NUM_LABELS),
title="Label values")
a = tables.Float32Atom()
myh5.create_earray(myh5.root.test,
name="inputs",
atom=a,
shape=(0,NUM_FEATURES),
title="Input images")
print(myh5)
data/iris/data.h5 (File) 'Iris data' Last modif.: 'Tue Aug 28 15:26:37 2018' Object Tree: / (RootGroup) 'Iris data' /test (Group) 'Testing data' /test/inputs (EArray(0, 4)) 'Input images' /test/labels (EArray(0, 1)) 'Label values' /train (Group) 'Training data' /train/inputs (EArray(0, 4)) 'Input images' /train/labels (EArray(0, 1)) 'Label values'
訓練データと検証データに分ける前に、クラス分布が偏らないようにシャッフルしておく。
shufidx = np.random.choice(a=NUM_DATA, size=NUM_DATA, replace=False)
idx_tr = shufidx[0:NUM_TRAIN]
idx_te = shufidx[NUM_TRAIN:]
# Training data
for i in idx_tr:
myh5.root.train.inputs.append([data_X[i,:]])
myh5.root.train.labels.append([data_y[i,:]])
print(myh5)
data/iris/data.h5 (File) 'Iris data' Last modif.: 'Tue Aug 28 15:26:37 2018' Object Tree: / (RootGroup) 'Iris data' /test (Group) 'Testing data' /test/inputs (EArray(0, 4)) 'Input images' /test/labels (EArray(0, 1)) 'Label values' /train (Group) 'Training data' /train/inputs (EArray(100, 4)) 'Input images' /train/labels (EArray(100, 1)) 'Label values'
# Testing data
for i in idx_te:
myh5.root.test.inputs.append([data_X[i,:]])
myh5.root.test.labels.append([data_y[i,:]])
print(myh5)
data/iris/data.h5 (File) 'Iris data' Last modif.: 'Tue Aug 28 15:26:37 2018' Object Tree: / (RootGroup) 'Iris data' /test (Group) 'Testing data' /test/inputs (EArray(50, 4)) 'Input images' /test/labels (EArray(50, 1)) 'Label values' /train (Group) 'Training data' /train/inputs (EArray(100, 4)) 'Input images' /train/labels (EArray(100, 1)) 'Label values'
ファイルとの接続をここで打ち切る。
myh5.close()