【Python】画像分類モデルを作る方法|pytorch, timm

Python
Image by Peri Priatna from Pixabay
スポンサーリンク

pytorchとtimmを使えば、簡単に画像分類モデルを作ることができます。

model = timm.create_model("resnet18d", pretrained = True, num_classes = 10)

timmには既に作成済みかつ学習済みのモデルが入っているので、手っ取り早いです。

CIFAR10

pytorchのtorchvisionからCIFAR10を引っ張ってきましょう。

import torch
import torchvision
from torchvision.datasets import CIFAR10
train_data = CIFAR10(root = "data", download = True, transform = torchvision.transforms.ToTensor(), train = True)
valid_data = CIFAR10(root = "data", download = True, transform = torchvision.transforms.ToTensor(), train = False)
torch:pytorchのライブラリ。
torchvision:pytorchの補助的なライブラリ。

torchvisionのdatasetsにあるCIFAR10をインポートします。

CIFAR10から画像をダウンロードする際に以下の引数を設定してください。

root:画像を一時的にダウンロードする保存先。
download:Trueにしたらダウンロードできる。
transform:画像を変換する方法。ToTensorにするとpytorchのtensor型になる。
train:Trueで学習用データ、Falseで評価用データをダウンロード。

pytorchはテンソル(tensor)型としてデータを処理します。

pytorchでモデルを作るために必要なことは以下の通りです。

①データを用意する
②Datasetを作る⇒今回はCIFAR10で作成済
③DataLoaderを作る
④モデルを定義する
⑤学習と検証

詳しくは以下の記事を参照してください。

DataLoaderで画像データを出力するシステムを作りましょう。

train_loader = torch.utils.data.DataLoader(train_data, batch_size = 128, shuffle = True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = 256, shuffle = False)
batch_size:一度に何枚の画像を取り出すか
shuffle:データをシャッフルするか

画像データは容量が重いので、バッチサイズを指定して小分けして処理します。

batch = next(iter(train_loader))
len(batch)

バッチ処理されたデータは”iter”で取り出せます。

出力されるデータは2つあり、1つ目が画像で2つ目が画像のラベルです。

print(batch[0].shape)
print(batch[1].shape)

shapeでサイズを見ると、

画像:128, 3, 32, 32(バッチサイズ, 色, 縦, 横)
ラベル:128(バッチサイズ)

になっているかと思います。

画像を1枚取り出してみましょう。↓

import matplotlib.pyplot as plt
image = batch[0][0]
image = image.numpy().transpose(1, 2, 0)
plt.imshow(image)
plt.show()

データはtensor型になっているので、numpyに変換しましょう。

画像を表示するには(縦, 横, 色)の順にする必要があるので、transposeで入れ替えました。

犬っぽいですね。

シャッフルされているので同じ画像じゃなくても大丈夫です。

各バッチの2つ目にはラベルが入っているので見てみます。

print(batch[1][0])

犬なら5が出力されます。

つまり犬の画像を渡したら5が出力されるモデルを作るというわけですね。

plt.figure(figsize = (10, 10))
for i in range(16):
    image = batch[0][i]
    image = image.numpy().transpose(1, 2, 0)
    label = batch[1][i]
    plt.subplot(4, 4, i + 1)
    plt.imshow(image)
    plt.title(label)
plt.tight_layout()
plt.show()

画像を複数出力しました。

犬や飛行機や鳥などバラバラで、それぞれ違うラベルを持っていますね。

CIFAR10には番号0~9の10クラスがあります。

モデル作成

早速モデルを作りましょう。

!pip install timm
import timm

“timm”をダウンロードしてインポートします。

この中にいくつもの優秀なアルゴリズムが入っているので、基本的に何もしなくてOKです!

from pprint import pprint
pprint(timm.list_models(pretrained = True))

上記コードを実行すると使えるアルゴリズムのリストが出てきます。

有名なのはefficientnetあたりですが、計算が重いので、今回はresnet18dを使いましょう!

model = timm.create_model("resnet18d", pretrained = True, num_classes = 10)

“create_model”でアルゴリズム名を渡すとモデルを作ることができます。

pretrained:事前学習をするか。基本True。
num_classes:ラベル数。今回は10クラス。

timmに入っているモデルは膨大な画像で事前学習されています。

“pretrained”をTrueにするとその学習結果を引き継げるので性能が良いです。

最適化手法, 損失関数

最適化手法:どんな手法でモデルを改善するか
損失関数:どんな指標が良くなったら良いモデルと言えるか

モデルを学習させるには、この2つの設定がマストです。

optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

最適化手法はたくさんあるのでどれが良いかは状況によります。

今回は無難に優秀な”Adam”にしました。

損失関数は予測する問題によって変わります。

複数クラスの分類をする際は”CrossEntropyLoss”を使いましょう。

モデル学習

model.train() #必須
train_loss = 0
for batch in train_loader:
    optimizer.zero_grad() #必須
    image = batch[0] #(batch_size, channel, size, size)
    label = batch[1] #(batch_size)
    preds = model(image) #(batch_size, num_class)
    loss = criterion(preds, label) #必須
    loss.backward() #必須
    optimizer.step() #必須
    train_loss += loss.item()

model.eval() #必須
valid_loss = 0
with torch.no_grad(): #必須
    for batch in valid_loader:
        image = batch[0] #(batch_size, channel, size, size)
        label = batch[1] #(batch_size)
        preds = model(image) #(batch_size, num_class)
        loss = criterion(preds, label) #必須
        valid_loss += loss.item()

print(train_loss / len(train_loader))
print(valid_loss / len(valid_loader))

やることは大きく2つに分かれます。

①学習(train)
②評価(eval)

①学習

model.train() #必須
train_loss = 0
for batch in train_loader:
    optimizer.zero_grad() #必須
    image = batch[0] #(batch_size, channel, size, size)
    label = batch[1] #(batch_size)
    preds = model(image) #(batch_size, num_class)
    loss = criterion(preds, label) #必須
    loss.backward() #必須
    optimizer.step() #必須
    train_loss += loss.item()

trainで学習モードにします。

for文でデータローダーにアクセスし、バッチを繰り返し取り出しましょう。

“optimizer.zero_grad”は必須です。リセットみたいなもの。

次にバッチから画像とラベルを取り出します。画像はモデルに入れましょう。

各クラス(10クラス)の予測値がでてくるので、これと正解ラベルを損失関数に入れます。

“backward”で損失をモデルに伝え、”step”で改善する感じです。

モデルの成長を見たいので適当な変数(“train_loss”とか)として足していきましょう。

②評価

model.eval() #必須
valid_loss = 0
with torch.no_grad(): #必須
    for batch in valid_loader:
        image = batch[0] #(batch_size, channel, size, size)
        label = batch[1] #(batch_size)
        preds = model(image) #(batch_size, num_class)
        loss = criterion(preds, label) #必須
        valid_loss += loss.item()

評価時はevalを実行します。

“torch.no_grad”は忘れがちですが必須です。

optimizer関連とlossのbackwardは不要。

EPOCHS

上記①②のサイクルを複数繰り返してモデルを成長させます。

history = {"train" : [], "valid" : []}
for epoch in range(5):
    print("EPOCH", epoch)
    model.train() #必須
    train_loss = 0
    for batch in train_loader:
        optimizer.zero_grad() #必須
        image = batch[0] #(batch_size, channel, size, size)
        label = batch[1] #(batch_size)
        preds = model(image) #(batch_size, num_class)
        loss = criterion(preds, label) #必須
        loss.backward() #必須
        optimizer.step() #必須
        train_loss += loss.item()

    model.eval() #必須
    valid_loss = 0
    with torch.no_grad(): #必須
        for batch in valid_loader:
            image = batch[0] #(batch_size, channel, size, size)
            label = batch[1] #(batch_size)
            preds = model(image) #(batch_size, num_class)
            loss = criterion(preds, label) #必須
            valid_loss += loss.item()

    train_loss /= len(train_loader)
    valid_loss /= len(valid_loader)
    print(train_loss)
    print(valid_loss)
    history["train"].append(train_loss)
    history["valid"].append(valid_loss)

for文で学習サイクルを繰り返しました。

長くなるので今回は5epochにしています。

plt.plot(history["train"], "red", label = "train")
plt.plot(history["valid"], "blue", label = "valid")
plt.legend()
plt.show()

hisotryにlossの履歴を残しているので可視化しました。

赤が学習データで青が評価データです。

学習データは繰り返せば繰り返すほど性能がよくなりますが、過学習すると汎化性能が落ちます。

後半は評価データでの性能が落ちている(lossが増えている)ことがわかりますね。

予測結果

学習したモデルを使って予測してみましょう。

model.eval()
OOF = []
labels = []
with torch.no_grad():
    for batch in valid_loader:
        image = batch[0]
        label = batch[1]
        preds = model(image)
        OOF.append(preds.numpy())
        labels.append(label)

評価時とほとんど同じことをしています。

OOFが予測結果でlabelsが正解データです。リストとして保存しておきました。

import numpy as np
OOF = np.concatenate(OOF)
labels = np.concatenate(labels)
print(OOF.shape, labels.shape)

numpyをインポートしてリストを結合しましょう。

データサイズを見て行方向が10000で同じであるか見ておきます。

print(labels[:10])
print("=" * 50)
print(OOF.argmax(axis = 1)[:10])

予測結果の列数はラベル数(10個)になっており、最も値の大きい列が予測値です。

argmaxで列方向(axis = 1)の最大の位置を取り出しました。

正解ラベルと比較するとあっていそうですね。

from sklearn.metrics import accuracy_score
acc = accuracy_score(labels, OOF.argmax(axis = 1))
print(acc)

“accucary_score”をインポートして正解率を計算しましょう。

だいたい8割強かと思います。

まとめ:画像分類に挑戦しよう

今回はpytorchとtimmを使って画像分類をする方法について解説しました。

意外と簡単に実装できたのではないでしょうか。

性能を上げたいならAugmentationを使ったりschedulerを使ったり色々試すといいですよ。

コメント

タイトルとURLをコピーしました