【Python】pytorchで画像分類モデルを作る方法

Python
スポンサーリンク

pytorchとtimmを使えば、簡単に画像分類モデルを作ることができます。

# resnet18dを作る
model = timm.create_model("resnet18d", pretrained = True, num_classes = 10)

# モデルリストを確認する
timm.list_models()

# 確認には以下も便利
[m for m in timm.list_models() if "efficientnet" in m]

timmには既に作成済みかつ学習済みのモデルが入っているので、手っ取り早いです。

当ブログでは、もし私が機械学習初心者だったらどう勉強するかも解説しています。
>>コスパよく機械学習を勉強するロードマップ

Dataset

import torch
import torchvision
from torchvision.datasets import CIFAR10

train = CIFAR10(root = "data", download = True, transform = torchvision.transforms.ToTensor(), train = True)
test = CIFAR10(root = "data", download = True, transform = torchvision.transforms.ToTensor(), train = False)

torchvisionのdatasetsにあるCIFAR10をインポートします。
“data”フォルダができて、その中に画像データがダウンロードされます。
“train”が学習用、”test”が検証用データです。

import matplotlib.pyplot as plt
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(train[0][0].permute(1, 2, 0))
plt.title(train[1][1])
plt.subplot(1, 2, 2)
plt.imshow(test[0][0].permute(1, 2, 0))
plt.title(test[1][1])
plt.show()
画像サンプル

それぞれこのように32 x 32の画像が入っています。
写真の種類は数値ラベルとして保存されています。
カエルが6で、3は猫らしいです。

plt.figure()
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(train[i][0].permute(1, 2, 0))
    plt.title(train[i][1])
plt.tight_layout()
plt.show()
画像サンプル9枚

このように写真の種類別にラベルが振り分けられています。
CIFAR10は0~9までの10種類の画像があるデータセットです。

DataLoader

train_loader = torch.utils.data.DataLoader(train, batch_size = 256, shuffle = True)
test_loader = torch.utils.data.DataLoader(test, batch_size = 256 * 2, shuffle = False)

画像データは容量が重いので、バッチサイズを指定して小分けして処理します。
“batch_size”を256にしたので、1度に256枚の画像を出力します。

batch = next(iter(train_loader))
print(len(batch))
print(batch[0].shape, batch[1].shape)

# ========== output ==========
# 2
# torch.Size([256, 3, 32, 32]) torch.Size([256])

1つ目のバッチを取り出しました。
(画像, ラベル)の形式で入っており、それぞれ256個のデータになっています。

モデル作成

#!pip install timm # timmをインストールしていない場合
import timm

timmは優秀な画像モデルを使うことができるライブラリです。

[m for m in timm.list_models() if "resnet" in m]

# ========== output ==========
# ['cspresnet50',
#  'cspresnet50d',
#  'cspresnet50w',
#  'eca_resnet33ts',
#  'ecaresnet26t',
# ...

“list_models”を使うと、どんなモデルが入っているかを見ることができます。
数が多いので、上記コードでは”resnet”が入っているものに絞っています。
有名なのはefficientnetあたりですが、計算が重いのでresnet18dを使うことにしましょう。

model = timm.create_model("resnet18d", pretrained = True, num_classes = 10)

“create_model”でアルゴリズム名を渡すとモデルを作ることができます。
“pretrained”をTrueにすると事前学習されたモデルが使えるので、精度が高いです。
“num_classes”は学習するラベル数なので、今回は10ですね。

学習と検証

from timm.utils import AverageMeter
from tqdm import tqdm

# 最適化手法
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

# 損失関数
criterion = torch.nn.CrossEntropyLoss()

# ログ記録用の変数
history = {"train": [], "test": []}

# 学習回数
for epoch in range(5):
    print("\nEpoch:", epoch)

    # 学習
    model.train()
    train_loss = AverageMeter()
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        image = batch[0] #(batch_size, channel, size, size)
        label = batch[1] #(batch_size)
        preds = model(image) #(batch_size, num_class)
        loss = criterion(preds, label)
        loss.backward()
        optimizer.step()
        train_loss.update(val = loss.item(), n = len(image))

    # 検証
    model.eval()
    test_loss = AverageMeter()
    with torch.no_grad():
        for batch in tqdm(test_loader):
            image = batch[0] #(batch_size, channel, size, size)
            label = batch[1] #(batch_size)
            preds = model(image) #(batch_size, num_class)
            loss = criterion(preds, label)
            test_loss.update(val = loss.item(), n = len(image))

    # 誤差出力
    print(train_loss.avg)
    print(test_loss.avg)
    history["train"].append(train_loss.avg)
    history["test"].append(test_loss.avg)

最適化手法と損失関数を定義し、学習と検証を”epoch”の数だけ繰り返します。
pytorchの使い方自体の説明は、ここでは割愛します。
【関連記事】pytorchで機械学習モデルを作る方法

plt.plot(history["train"], label = "train")
plt.plot(history["test"], label = "test")
plt.legend()
plt.show()
学習経過

このように誤差(loss)が徐々に下がっています。

予測結果

model.eval()
preds = []
labels = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        image = batch[0]
        label = batch[1]
        # 最も値の大きい列番号
        preds += model(image).numpy().argmax(axis = 1).tolist()
        # 答え
        labels += label.numpy().tolist()

modelに画像データを入れたら予測を出力します。
ただし、返り値は(データ数, ラベル数)になっています。
順当にいけば最も値が大きい列番号が予測ラベルになるので、argmaxで取り出しました。

print(preds[:5])
print(labels[:5])

# ========== output ==========
# [5, 8, 8, 0, 6]
# [3, 8, 8, 0, 6]

予測結果はこの通り。だいたい正解していそうですね。

from sklearn.metrics import accuracy_score
print(accuracy_score(labels, preds))

# ========== output ==========
# 0.7434

正解率は74%でした。

Data Augmentation

image = train[0][0]
plt.subplot(1, 2, 1)
plt.imshow(image.permute(1, 2, 0))
plt.subplot(1, 2, 2)
plt.imshow(torchvision.transforms.RandomHorizontalFlip(p = 1)(image).permute(1, 2, 0))
plt.show()
水平反転した画像

画像モデルの学習ではData Augmentationが有効です。

例えばtorchvisionにあるRandomHorizontalFlipを使うと、水平方向に反転してくれます。
この画像は左右どちらを向いていてもカエルであることは変わりません。

なので、モデルとしてはどちらを向いていてもカエルであると学習してくれます。

model = timm.create_model("resnet18d", pretrained = True, num_classes = 10)

# 最適化手法
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

# 損失関数
criterion = torch.nn.CrossEntropyLoss()

# ログ記録用の変数
history = {"train": [], "test": []}

# 学習回数
for epoch in range(5):
    print("\nEpoch:", epoch)

    # 学習
    model.train()
    train_loss = AverageMeter()
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        image = batch[0] #(batch_size, channel, size, size)
        label = batch[1] #(batch_size)

        image = torchvision.transforms.RandomHorizontalFlip()(image)

        preds = model(image) #(batch_size, num_class)
        loss = criterion(preds, label)
        loss.backward()
        optimizer.step()
        train_loss.update(val = loss.item(), n = len(image))

    # 検証
    model.eval()
    test_loss = AverageMeter()
    with torch.no_grad():
        for batch in tqdm(test_loader):
            image = batch[0] #(batch_size, channel, size, size)
            label = batch[1] #(batch_size)
            preds = model(image) #(batch_size, num_class)
            loss = criterion(preds, label)
            test_loss.update(val = loss.item(), n = len(image))

    # 誤差出力
    print(train_loss.avg)
    print(test_loss.avg)
    history["train"].append(train_loss.avg)
    history["test"].append(test_loss.avg)

plt.plot(history["train"], label = "train")
plt.plot(history["test"], label = "test")
plt.legend()
plt.show()
水平反転も含めた学習経過

学習データでのみAugmentationを追加しましょう。
デフォルトでは50%の確立で反転してくれます。
ぱっと見ですが、先ほどよりもlossが下がり、0.8を切っていますね。

model.eval()
preds = []
labels = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        image = batch[0]
        label = batch[1]
        # 最も値の大きい列番号
        preds += model(image).numpy().argmax(axis = 1).tolist()
        # 答え
        labels += label.numpy().tolist()

print(accuracy_score(labels, preds))

# ========== output ==========
# 0.7525

このように正解率も上がっています。
他にも上下反転や回転、色の変更など様々な変換があるので、調べて試してみましょう。

まとめ

今回はpytorchとtimmを使って画像分類をする方法について解説しました。
画像サイズが大きくなったり、精度を高めようとするならGPUが必要になります。
その際はGoogle ColabortoryなどのクラウドGPUを使うか、PCを自作するといいですよ。

なんか適当に独学してるだけで、

どうやって勉強を進めたらいいかわからんな。。。

と悩んでいる人向けに、
もし私が初心者ならどう勉強するかを解説しているので、参考にどうぞ。
>>コスパよく初心者が機械学習を勉強する方法|ロードマップ

コメント

タイトルとURLをコピーしました