技術

MNISTを使用した数字認識

みやもと

みやもと

2020/08/05

皆さん、おはこんばんにちは!
梅雨が明け気温が上がり、全く外出をしなくなったみやもとです。

今回は、「MNIST」というデータセットを使用して数字の画像認識を行っていきたいと思います。

1.MNIST

MNISTとは

MNIST(Mixed National Institute of Standards and Technology database)とは、手書き数字画像60,000枚と、テスト画像10,000枚を集めた、画像データセットです。さらに、手書きの数字「0〜9」に正解ラベルが与えられるデータセットでもあり、画像分類問題で人気の高いデータセットです。

MNISTは深層学習(ディープラーニング)の手法の1つであるニューラルネットワークを学ぶ上でも便利ですし、機械学習の入門のデータセットとしてもよく使われています。

手軽に入手できる点も含めて、人工知能(AI)の勉強を始める入り口としてMNISTは人気の高いデータセットです。


機械学習で便利な画像データセット「MNIST」を丁寧に解説!より引用

サンプルコード実行

参考にしている本から一部コードを抜粋してきましたが、MNISTを使用した数字認識は下記のようにコードを書きます。重みやバイアスは、すでに作成されたものを使用していますが、それ以外の処理は下記の数行で数字認識を行うことができます。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax
from PIL import Image

def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=False, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()
accuracy_cnt = 0

y = predict(network, x[0])
p = np.argmax(y)

img = x[0]
img = img.reshape(28, 28)

img_show(img)
print(p)

上記のコードを動かしたのがこの動画になります。
実際のコードは正答率を出すものですが少し修正して読み込まれた画像の表示と認識された数字を出力しています。修正前のコードを動かすと正答率は約70%程で正答率はあまり高くはありません。

まとめ

少しずつAI分野の理解が深まったように感じます。
しかし、まだまだ表面部分しか触れていないので、今後は数字認識で使用しているソフトマックス関数やシグモイド関数などを理解して機能を実装できるように勉強を続けていきます。

参考

ゼロから作る Deep Learning – deep-learning-from-scratch