前回は、MNIST(エムニスト)の手書きの数字の「画像とラベル」のデータベースを使って、数字を画像表示したりした。
今回は、画像データにたいして数字の分類をしてみよう。
テストデータの準備
MNISTのデータセットは、すでにトレーニングとテストのセットに分けられている。はじめの60,000がトレーニングデータで、それ以降がテストデータだ。
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
前回は、数字の5をsome_digit_imageに代入したので、それを使おう。
数字画像の分類
まずは数字が5かどうかの判定をしてみよう。
どうやってプログラミングするの?
ここでは、クラス分類をするためにScikit-LearnのSGD(stochastic gradient descent)というライブラリを使ってみよう。日本語だと、確率的勾配降下というらしいが、わかりずらいネーミング。SGDClassifierは、トレーニング中にランダムにシャッフルするようだが、同じ結果を再現させるには、random_stateを記述しよう。
fit()メソッドによってトレーニングしよう。fit(X,y)では、Xがトレーニングデータで、yがターゲットバリューとなる。
from sklearn.linear_model import SGDClassifier
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
print("y_train_5: {}\ny_test_5: {}\ny_train_5.shape: {}".format(y_train_5, y_test_5, y_train_5.shape))
sgd_clf = SGDClassifier(random_state = 42, max_iter = 500)
sgd_clf.fit(X_train, y_train_5)
上記コードを実行するとつぎのようなエラーがでるだろう。
525 max_iter=max_iter)
526 else:
--> 527 raise ValueError(
528 "The number of classes has to be greater than one;"
529 " got %d class" % n_classes)
ValueError: The number of classes has to be greater than one; got 1 class
何がいけないの?
print(np.unique(y_train_5))
とすると[False]しかなく、文字どおり1つのクラスしかないことが確認できる。
どういうこと?
これは、labelが文字列なので、このようにinteger(unit8)にキャストしないといけないよ。前回のコードも含めて全体を表示するとこんな感じになる。
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
import numpy as np
mnist = fetch_openml('mnist_784')
X, y = mnist["data"], mnist["target"]
y = y.astype(np.uint8)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
sgd_clf = SGDClassifier(random_state = 42. max_iter = 500)
sgd_clf.fit(X_train, y_train_5)
では、ホントに画像から数字を認識しているかチェックしてみよう。
some_digit = X[0]
sgd_clf.predict([some_digit])
正しく判定しているようだ。
まとめ
MNISTのデータセットをトレーニングとテストのセットに分け、画像が数字の5かどうかの判定を実施した。
すでにクラス分類をするためにScikit-LearnのSGDというライブラリがあるので、それを利用してトレーニングと予測をしてみた。
次回は、結果についてもうすこし詳しく見てみよう。
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html
https://scikit-learn.org/stable/modules/generated
/sklearn.linear_model.SGDClassifier.html#sklearn.linear_model.SGDClassifier.fit
Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition by Aurélien Géron Published by O’Reilly Media, Inc., 2019, Ch3
コメント