Machine Learning 6: 10分でできる手書きの数字の読み方(2)

IT

前回は、MNIST(エムニスト)の手書きの数字の「画像とラベル」のデータベースを使って、数字を画像表示したりした。

今回は、画像データにたいして数字の分類をしてみよう。

テストデータの準備

MNISTのデータセットは、すでにトレーニングとテストのセットに分けられている。はじめの60,000がトレーニングデータで、それ以降がテストデータだ。

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

前回は、数字の5をsome_digit_imageに代入したので、それを使おう。

数字画像の分類

まずは数字が5かどうかの判定をしてみよう。

どうやってプログラミングするの?

ここでは、クラス分類をするためにScikit-LearnのSGD(stochastic gradient descent)というライブラリを使ってみよう。日本語だと、確率的勾配降下というらしいが、わかりずらいネーミング。SGDClassifierは、トレーニング中にランダムにシャッフルするようだが、同じ結果を再現させるには、random_stateを記述しよう。

fit()メソッドによってトレーニングしよう。fit(X,y)では、Xがトレーニングデータで、yがターゲットバリューとなる。

from sklearn.linear_model import SGDClassifier

y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

print("y_train_5: {}\ny_test_5: {}\ny_train_5.shape: {}".format(y_train_5, y_test_5, y_train_5.shape))

sgd_clf = SGDClassifier(random_state = 42, max_iter = 500)

sgd_clf.fit(X_train, y_train_5)

上記コードを実行するとつぎのようなエラーがでるだろう。

    525                              max_iter=max_iter)
    526         else:
--> 527             raise ValueError(
    528                 "The number of classes has to be greater than one;"
    529                 " got %d class" % n_classes)

ValueError: The number of classes has to be greater than one; got 1 class

何がいけないの?

print(np.unique(y_train_5))

とすると[False]しかなく、文字どおり1つのクラスしかないことが確認できる。

どういうこと?

これは、labelが文字列なので、このようにinteger(unit8)にキャストしないといけないよ。前回のコードも含めて全体を表示するとこんな感じになる。

from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
import numpy as np

mnist = fetch_openml('mnist_784')
X, y = mnist["data"], mnist["target"]

y = y.astype(np.uint8)

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)


sgd_clf = SGDClassifier(random_state = 42. max_iter = 500)

sgd_clf.fit(X_train, y_train_5)

では、ホントに画像から数字を認識しているかチェックしてみよう。

some_digit = X[0]
sgd_clf.predict([some_digit])

正しく判定しているようだ。

まとめ

MNISTのデータセットをトレーニングとテストのセットに分け、画像が数字の5かどうかの判定を実施した。

すでにクラス分類をするためにScikit-LearnのSGDというライブラリがあるので、それを利用してトレーニングと予測をしてみた。

次回は、結果についてもうすこし詳しく見てみよう。

参考:
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html

https://scikit-learn.org/stable/modules/generated
/sklearn.linear_model.SGDClassifier.html#sklearn.linear_model.SGDClassifier.fit

Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition by Aurélien Géron Published by O’Reilly Media, Inc., 2019, Ch3

コメント

タイトルとURLをコピーしました