メインコンテンツへスキップ
Try in Colab PyTorch Lightning を使用して画像分類パイプラインを構築します。コードの可読性と再現性を高めるために、この スタイルガイド に従います。これに関する詳しい解説は こちら で確認できます。

PyTorch Lightning と W&B のセットアップ

このチュートリアルでは、PyTorch Lightning と W&B が必要です。
pip install lightning -q
pip install wandb -qU
import lightning.pytorch as pl

# お気に入りの機械学習トラッキングツール
from lightning.pytorch.loggers import WandbLogger

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import CIFAR10

import wandb
次に、wandb アカウントにログインする必要があります。
wandb.login()

DataModule - 私たちが求めるデータパイプライン

DataModule は、データ関連のフックを LightningModule から切り離す方法であり、データセットに依存しないモデルを開発できるようにします。 これにより、データパイプラインを共有可能で再利用可能な一つのクラスに整理できます。DataModule は、PyTorch におけるデータプロセッシングの 5 つのステップをカプセル化します:
  • ダウンロード / トークン化 / 処理
  • クリーニング、および(必要に応じて)ディスクへの保存
  • Dataset へのロード
  • 変換(回転、トークン化など)の適用
  • DataLoader へのラップ
DataModule についての詳細は こちら をご覧ください。CIFAR-10 データセット用の DataModule を構築してみましょう。
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        self.num_classes = 10
    
    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        # データローダーで使用するトレーニング/検証用データセットを割り当て
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # データローダーで使用するテスト用データセットを割り当て
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

Callbacks

コールバックは、プロジェクト間で再利用可能な自己完結型のプログラムです。PyTorch Lightning には、よく使用されるいくつかの 組み込みコールバック が用意されています。 PyTorch Lightning のコールバックについての詳細は こちら をご覧ください。

組み込みコールバック

このチュートリアルでは、Early StoppingModel Checkpoint の組み込みコールバックを使用します。これらは Trainer に渡すことができます。

カスタムコールバック

Keras のカスタムコールバックに慣れているなら、PyTorch のパイプラインでも同じことができるのは非常に魅力的です。 今回は画像分類を行っているので、いくつかの画像サンプルに対するモデルの予測を可視化できると便利です。これをコールバックの形式にすることで、早い段階でのモデルのデバッグに役立ちます。
class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
    
    def on_validation_epoch_end(self, trainer, pl_module):
        # テンソルを CPU に移動
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # モデルの予測を取得
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # 画像を wandb Image としてログ
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 
                           for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                 preds[:self.num_samples], 
                                                 val_labels[:self.num_samples])]
            })
        

LightningModule - システムの定義

LightningModule はモデルではなく「システム」を定義します。ここでは、すべての研究コードを一つのクラスにグループ化し、自己完結型にすることをシステムと呼びます。LightningModule は、PyTorch のコードを以下の 5 つのセクションに整理します:
  • 計算(__init__
  • トレーニングループ(training_step
  • 検証ループ(validation_step
  • テストループ(test_step
  • オプティマイザー(configure_optimizers
これにより、簡単に共有可能なデータセットに依存しないモデルを構築できます。CIFAR-10 分類用のシステムを構築しましょう。
class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        
        # ハイパーパラメーターのログ記録
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)
        
        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    # conv ブロックから Linear レイヤーに入る出力テンソルのサイズを返す
    def _get_conv_output(self, shape):
        batch_size = 1
        input = torch.autograd.Variable(torch.rand(batch_size, *shape))

        output_feat = self._forward_features(input) 
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size
        
    # conv ブロックからの特徴テンソルを返す
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x
    
    # 推論時に使用
    def forward(self, x):
       x = self._forward_features(x)
       x = x.view(x.size(0), -1)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = F.log_softmax(self.fc3(x), dim=1)
       
       return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # トレーニングメトリクス
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # 検証メトリクス
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # 検証メトリクス
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

トレーニングと評価

DataModule を使ってデータパイプラインを、LightningModule を使ってモデルアーキテクチャーとトレーニングループを整理しました。あとは PyTorch Lightning の Trainer がすべてを自動化してくれます。 Trainer が自動化するもの:
  • エポックとバッチのイテレーション
  • optimizer.step()backwardzero_grad() の呼び出し
  • .eval() の呼び出し、勾配の有効化/無効化
  • 重みの保存とロード
  • W&B へのログ記録
  • マルチ GPU トレーニングのサポート
  • TPU のサポート
  • 16-bit トレーニングのサポート
dm = CIFAR10DataModule(batch_size=32)
# データローダーにアクセスするために prepare_data と setup を呼び出す必要があります。
dm.prepare_data()
dm.setup()

# カスタム ImagePredictionLogger コールバックが画像予測をログするために必要なサンプル
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)

# wandb logger の初期化
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# コールバックの初期化
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# trainer の初期化
trainer = pl.Trainer(max_epochs=2,
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# モデルのトレーニング
trainer.fit(model, dm)

# ホールドアウトされたテストセットでモデルを評価 ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

# wandb run を終了
run.finish()

最後に

私は TensorFlow/Keras エコシステムの出身で、PyTorch は洗練されたフレームワークであるものの、少し敷居が高いと感じていました(あくまで個人的な経験です)。しかし、PyTorch Lightning を試してみると、PyTorch から遠ざかっていた理由のほとんどが解消されていることに気づきました。私の感動を簡単にまとめます:
  • 以前:従来の PyTorch のモデル定義はバラバラになりがちでした。モデルは model.py スクリプトにあり、トレーニングループは train.py ファイルにあるといった具合です。パイプラインを理解するために何度もファイルを行き来する必要がありました。
  • 現在:LightningModule がシステムとして機能し、モデルが training_stepvalidation_step などと一緒に定義されます。モジュール化され、共有しやすくなりました。
  • 以前:TensorFlow/Keras の最大の魅力は入力データパイプラインでした。データセットカタログが豊富で成長し続けています。PyTorch のデータパイプラインは最大の懸念点でした。通常の PyTorch コードでは、データのダウンロード/クリーニング/準備が多くのファイルに分散していることがよくあります。
  • 現在:DataModule がデータパイプラインを一つの共有・再利用可能なクラスにまとめます。これは単に train_dataloaderval_dataloadertest_dataloader と、それに対応する変換やデータ処理/ダウンロードステップの集まりにすぎません。
  • 以前:Keras では、model.fit でトレーニング、model.predict で推論を実行できました。model.evaluate はテストデータに対するシンプルで使い勝手の良い評価を提供していました。PyTorch ではそうはいかず、通常は別々の train.pytest.py ファイルが必要でした。
  • 現在:LightningModule を導入すれば、Trainer がすべてを自動化します。trainer.fittrainer.test を呼び出すだけで、モデルのトレーニングと評価が完了します。
  • 以前:TensorFlow は TPU が大好きですが、PyTorch は…
  • 現在:PyTorch Lightning を使えば、同じモデルを複数の GPU や TPU でトレーニングするのが非常に簡単です。
  • 以前:私はコールバックの大ファンで、カスタムコールバックを書くのが好きです。Early Stopping のような些細なことでさえ、従来の PyTorch では議論の的になることがありました。
  • 現在:PyTorch Lightning では Early Stopping や Model Checkpointing を使うのは非常に簡単です。カスタムコールバックを書くこともできます。

🎨 結論とリソース

このレポートがお役に立てば幸いです。ぜひコードを動かして、お好みのデータセットで画像分類器をトレーニングしてみてください。 PyTorch Lightning についてさらに詳しく学ぶためのリソースをいくつか紹介します: