pytorchとtimmを使えば、簡単に画像分類モデルを作ることができます。
# resnet18dを作る
model = timm.create_model("resnet18d", pretrained = True, num_classes = 10)
# モデルリストを確認する
timm.list_models()
# 確認には以下も便利
[m for m in timm.list_models() if "efficientnet" in m]
timmには既に作成済みかつ学習済みのモデルが入っているので、手っ取り早いです。

CIFAR10
pytorchのtorchvisionからCIFAR10を引っ張ってきましょう。
import torch
import torchvision
from torchvision.datasets import CIFAR10
train_data = CIFAR10(root = "data", download = True, transform = torchvision.transforms.ToTensor(), train = True)
valid_data = CIFAR10(root = "data", download = True, transform = torchvision.transforms.ToTensor(), train = False)
torchvision:pytorchの補助的なライブラリ。
torchvisionのdatasetsにあるCIFAR10をインポートします。
CIFAR10から画像をダウンロードする際に以下の引数を設定してください。
download:Trueにしたらダウンロードできる。
transform:画像を変換する方法。ToTensorにするとpytorchのtensor型になる。
train:Trueで学習用データ、Falseで評価用データをダウンロード。
pytorchはテンソル(tensor)型としてデータを処理します。
pytorchでモデルを作るために必要なことは以下の通りです。
②Datasetを作る⇒今回はCIFAR10で作成済
③DataLoaderを作る
④モデルを定義する
⑤学習と検証
詳しくは以下の記事を参照してください。
DataLoaderで画像データを出力するシステムを作りましょう。
train_loader = torch.utils.data.DataLoader(train_data, batch_size = 128, shuffle = True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = 256, shuffle = False)
shuffle:データをシャッフルするか
画像データは容量が重いので、バッチサイズを指定して小分けして処理します。
batch = next(iter(train_loader))
len(batch)
バッチ処理されたデータは”iter”で取り出せます。
出力されるデータは2つあり、1つ目が画像で2つ目が画像のラベルです。
print(batch[0].shape)
print(batch[1].shape)
shapeでサイズを見ると、
ラベル:128(バッチサイズ)
になっているかと思います。
画像を1枚取り出してみましょう。↓
import matplotlib.pyplot as plt
image = batch[0][0]
image = image.numpy().transpose(1, 2, 0)
plt.imshow(image)
plt.show()

データはtensor型になっているので、numpyに変換しましょう。
画像を表示するには(縦, 横, 色)の順にする必要があるので、transposeで入れ替えました。
犬っぽいですね。
シャッフルされているので同じ画像じゃなくても大丈夫です。
各バッチの2つ目にはラベルが入っているので見てみます。
print(batch[1][0])
犬なら5が出力されます。
つまり犬の画像を渡したら5が出力されるモデルを作るというわけですね。
plt.figure(figsize = (10, 10))
for i in range(16):
image = batch[0][i]
image = image.numpy().transpose(1, 2, 0)
label = batch[1][i]
plt.subplot(4, 4, i + 1)
plt.imshow(image)
plt.title(label)
plt.tight_layout()
plt.show()

画像を複数出力しました。
犬や飛行機や鳥などバラバラで、それぞれ違うラベルを持っていますね。
CIFAR10には番号0~9の10クラスがあります。
モデル作成
早速モデルを作りましょう。
!pip install timm
import timm
“timm”をダウンロードしてインポートします。
この中にいくつもの優秀なアルゴリズムが入っているので、基本的に何もしなくてOKです!
timm.list_models()
上記コードを実行すると使えるアルゴリズムのリストが出てきます。
もし”efficientnet”など特定のカテゴリで検索したいなら、以下のように書きましょう。
[m for m in timm.list_models() if "efficientnet" in m]
有名なのはefficientnetあたりですが、計算が重いので、今回はresnet18dを使いましょう!
model = timm.create_model("resnet18d", pretrained = True, num_classes = 10)
“create_model”でアルゴリズム名を渡すとモデルを作ることができます。
num_classes:ラベル数。今回は10クラス。
timmに入っているモデルは膨大な画像で事前学習されています。
“pretrained”をTrueにするとその学習結果を引き継げるので性能が良いです。
最適化手法, 損失関数
損失関数:どんな指標が良くなったら良いモデルと言えるか
モデルを学習させるには、この2つの設定がマストです。
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()
最適化手法はたくさんあるのでどれが良いかは状況によります。
今回は無難に優秀な”Adam”にしました。
損失関数は予測する問題によって変わります。
複数クラスの分類をする際は”CrossEntropyLoss”を使いましょう。
モデル学習
model.train() #必須
train_loss = 0
for batch in train_loader:
optimizer.zero_grad() #必須
image = batch[0] #(batch_size, channel, size, size)
label = batch[1] #(batch_size)
preds = model(image) #(batch_size, num_class)
loss = criterion(preds, label) #必須
loss.backward() #必須
optimizer.step() #必須
train_loss += loss.item()
model.eval() #必須
valid_loss = 0
with torch.no_grad(): #必須
for batch in valid_loader:
image = batch[0] #(batch_size, channel, size, size)
label = batch[1] #(batch_size)
preds = model(image) #(batch_size, num_class)
loss = criterion(preds, label) #必須
valid_loss += loss.item()
print(train_loss / len(train_loader))
print(valid_loss / len(valid_loader))
やることは大きく2つに分かれます。
②評価(eval)
①学習
model.train() #必須
train_loss = 0
for batch in train_loader:
optimizer.zero_grad() #必須
image = batch[0] #(batch_size, channel, size, size)
label = batch[1] #(batch_size)
preds = model(image) #(batch_size, num_class)
loss = criterion(preds, label) #必須
loss.backward() #必須
optimizer.step() #必須
train_loss += loss.item()
trainで学習モードにします。
for文でデータローダーにアクセスし、バッチを繰り返し取り出しましょう。
“optimizer.zero_grad”は必須です。リセットみたいなもの。
次にバッチから画像とラベルを取り出します。画像はモデルに入れましょう。
各クラス(10クラス)の予測値がでてくるので、これと正解ラベルを損失関数に入れます。
“backward”で損失をモデルに伝え、”step”で改善する感じです。
モデルの成長を見たいので適当な変数(“train_loss”とか)として足していきましょう。
②評価
model.eval() #必須
valid_loss = 0
with torch.no_grad(): #必須
for batch in valid_loader:
image = batch[0] #(batch_size, channel, size, size)
label = batch[1] #(batch_size)
preds = model(image) #(batch_size, num_class)
loss = criterion(preds, label) #必須
valid_loss += loss.item()
評価時はevalを実行します。
“torch.no_grad”は忘れがちですが必須です。
optimizer関連とlossのbackwardは不要。
EPOCHS
上記①②のサイクルを複数繰り返してモデルを成長させます。
history = {"train" : [], "valid" : []}
for epoch in range(5):
print("EPOCH", epoch)
model.train() #必須
train_loss = 0
for batch in train_loader:
optimizer.zero_grad() #必須
image = batch[0] #(batch_size, channel, size, size)
label = batch[1] #(batch_size)
preds = model(image) #(batch_size, num_class)
loss = criterion(preds, label) #必須
loss.backward() #必須
optimizer.step() #必須
train_loss += loss.item()
model.eval() #必須
valid_loss = 0
with torch.no_grad(): #必須
for batch in valid_loader:
image = batch[0] #(batch_size, channel, size, size)
label = batch[1] #(batch_size)
preds = model(image) #(batch_size, num_class)
loss = criterion(preds, label) #必須
valid_loss += loss.item()
train_loss /= len(train_loader)
valid_loss /= len(valid_loader)
print(train_loss)
print(valid_loss)
history["train"].append(train_loss)
history["valid"].append(valid_loss)
for文で学習サイクルを繰り返しました。
長くなるので今回は5epochにしています。
plt.plot(history["train"], "red", label = "train")
plt.plot(history["valid"], "blue", label = "valid")
plt.legend()
plt.show()

hisotryにlossの履歴を残しているので可視化しました。
赤が学習データで青が評価データです。
学習データは繰り返せば繰り返すほど性能がよくなりますが、過学習すると汎化性能が落ちます。
後半は評価データでの性能が落ちている(lossが増えている)ことがわかりますね。
予測結果
学習したモデルを使って予測してみましょう。
model.eval()
OOF = []
labels = []
with torch.no_grad():
for batch in valid_loader:
image = batch[0]
label = batch[1]
preds = model(image)
OOF.append(preds.numpy())
labels.append(label)
評価時とほとんど同じことをしています。
OOFが予測結果でlabelsが正解データです。リストとして保存しておきました。
import numpy as np
OOF = np.concatenate(OOF)
labels = np.concatenate(labels)
print(OOF.shape, labels.shape)
numpyをインポートしてリストを結合しましょう。
データサイズを見て行方向が10000で同じであるか見ておきます。
print(labels[:10])
print("=" * 50)
print(OOF.argmax(axis = 1)[:10])

予測結果の列数はラベル数(10個)になっており、最も値の大きい列が予測値です。
argmaxで列方向(axis = 1)の最大の位置を取り出しました。
正解ラベルと比較するとあっていそうですね。
from sklearn.metrics import accuracy_score
acc = accuracy_score(labels, OOF.argmax(axis = 1))
print(acc)
“accucary_score”をインポートして正解率を計算しましょう。
だいたい8割強かと思います。
まとめ:画像分類に挑戦しよう
今回はpytorchとtimmを使って画像分類をする方法について解説しました。
性能を上げたいならAugmentationを使ったりschedulerを使ったり色々試すといいですよ。
画像サイズが大きくなったり、精度を高めようとするならGPUが必要になります。
その際はクラウドでGoogle Colabortoryを使うか、PCを自作するといいですよ。
>>【無料説明会あり】キカガクのAI人材育成コースで勉強する

コメント