メインコンテンツへスキップ
Try in Colab PyTorch Lightning は、PyTorch コードを整理し、分散トレーニングや 16-bit 精度などの高度な機能を簡単に追加するための軽量なラッパーを提供します。W&B は、ML 実験をログに記録するための軽量なラッパーを提供します。これら 2 つを自分自身で組み合わせる必要はありません。W&B は、WandbLogger を介して PyTorch Lightning ライブラリに直接組み込まれています。

Lightning とのインテグレーション

from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

# すべてのモデルをログに記録するように設定してインスタンス化
wandb_logger = WandbLogger(log_model="all")
trainer = Trainer(logger=wandb_logger)
wandb.log() の使用: WandbLogger は Trainer の global_step を使用して W&B にログを記録します。コード内で直接 wandb.log を追加で呼び出す場合は、wandb.log()step 引数を 使用しないでください代わりに、他のメトリクスと同様に Trainer の global_step をログに記録してください:
wandb.log({"accuracy":0.99, "trainer/global_step": step})
Interactive dashboards

サインアップと APIキー の作成

APIキー は、お使いのマシンを W&B に対して認証します。ユーザープロファイルから APIキー を生成できます。
For a more streamlined approach, create an API key by going directly to User Settings. Copy the newly created API key immediately and save it in a secure location such as a password manager.
  1. 右上隅にあるユーザープロファイルアイコンをクリックします。
  2. User Settings を選択し、API Keys セクションまでスクロールします。

wandb ライブラリのインストールとログイン

wandb ライブラリをローカルにインストールしてログインするには:
  1. WANDB_API_KEY 環境変数 に APIキー を設定します。
    export WANDB_API_KEY=<your_api_key>
    
  2. wandb ライブラリをインストールしてログインします。
    pip install wandb
    
    wandb login
    

PyTorch Lightning の WandbLogger を使用する

PyTorch Lightning には、メトリクス、モデルの重み、メディアなどをログに記録するための複数の WandbLogger クラスがあります。 Lightning と統合するには、WandbLogger をインスタンス化し、Lightning の Trainer または Fabric に渡します。
trainer = Trainer(logger=wandb_logger)

一般的なロガー引数

以下は WandbLogger で最もよく使用されるパラメータの一部です。すべてのロガー引数の詳細については、PyTorch Lightning のドキュメントを確認してください。
パラメータ説明
projectログを記録する W&B の Projects を定義します
nameW&B の Runs に名前を付けます
log_modellog_model="all" の場合はすべてのモデルを、log_model=True の場合はトレーニング終了時にログを記録します
save_dirデータが保存されるパス

ハイパーパラメータのログ記録

class LitModule(LightningModule):
    def __init__(self, *args, **kwarg):
        # ハイパーパラメータを保存(W&Bによって自動的にログ記録されます)
        self.save_hyperparameters()

追加のコンフィグパラメータのログ記録

# 単一のパラメータを追加
wandb_logger.experiment.config["key"] = value

# 複数のパラメータを追加
wandb_logger.experiment.config.update({key1: val1, key2: val2})

# wandb モジュールを直接使用
wandb.config["key"] = value
wandb.config.update()

勾配、パラメータのヒストグラム、モデル構造のログ記録

モデルオブジェクトを wandblogger.watch() に渡すことで、トレーニング中にモデルの勾配とパラメータを監視できます。詳細は PyTorch Lightning の WandbLogger ドキュメントを参照してください。

メトリクスのログ記録

WandbLogger を使用している場合、training_stepvalidation_step メソッドなどの LightningModule 内で self.log('my_metric_name', metric_vale) を呼び出すことで、メトリクスを W&B にログ記録できます。以下のコードスニペットは、メトリクスと LightningModule のハイパーパラメータをログに記録するための LightningModule の定義方法を示しています。この例では、torchmetrics ライブラリを使用してメトリクスを計算しています。
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from lightning.pytorch import LightningModule


class My_LitModule(LightningModule):
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        """モデルパラメータを定義するために使用されるメソッド"""
        super().__init__()

        # mnist 画像は (1, 28, 28) (チャンネル, 幅, 高さ)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        self.loss = CrossEntropyLoss()
        self.lr = lr

        # ハイパーパラメータを self.hparams に保存(W&Bによって自動ログ記録)
        self.save_hyperparameters()

    def forward(self, x):
        """推論に使用されるメソッド input -> output"""

        # (b, 1, 28, 28) -> (b, 1*28*28)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        # 3 x (linear + relu) を実行
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        """単一バッチから損失を返す必要がある"""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # 損失とメトリクスをログ記録
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """メトリクスのログ記録に使用"""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # 損失とメトリクスをログ記録
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

    def configure_optimizers(self):
        """モデルのオプティマイザーを定義"""
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """train/valid/test ステップが類似しているための便利な関数"""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y)
        return preds, loss, acc

メトリクスの最小値/最大値のログ記録

W&B の define_metric 関数を使用すると、W&B サマリーメトリクスに、そのメトリクスの最小値、最大値、平均値、または最良値のどれを表示するかを定義できます。define_metric が使用されない場合は、最後にログに記録された値がサマリーメトリクスに表示されます。詳細は define_metricリファレンスドキュメント および ガイド を参照してください。 W&B サマリーメトリクスで最大検証精度を追跡するように W&B に指示するには、トレーニングの開始時に一度だけ wandb.define_metric を呼び出します。
class My_LitModule(LightningModule):
    ...

    def validation_step(self, batch, batch_idx):
        if trainer.global_step == 0:
            # 検証精度の最大値を追跡するように定義
            wandb.define_metric("val_accuracy", summary="max")

        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # 損失とメトリクスをログ記録
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

モデルのチェックポイント作成

モデルのチェックポイントを W&B Artifacts として保存するには、 Lightning の ModelCheckpoint コールバックを使用し、WandbLoggerlog_model 引数を設定します。
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
latest および best のエイリアスが自動的に設定され、W&B Artifact からモデルチェックポイントを簡単に取得できます。
# 参照はアーティファクトパネルで取得可能
# "VERSION" はバージョン(例: "v2")またはエイリアス("latest" または "best")を指定可能
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"
# チェックポイントをローカルにダウンロード(まだキャッシュされていない場合)
wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
# チェックポイントのロード
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
ログに記録されたモデルチェックポイントは W&B Artifacts UI で表示可能であり、完全なモデルリネージが含まれます(UI でのモデルチェックポイントの例は こちら を参照してください)。 最高のモデルチェックポイントをブックマークし、チーム全体で一元管理するには、それらを W&B Model Registry にリンクできます。 ここでは、タスクごとに最適なモデルを整理し、モデルのライフサイクルを管理し、ML ライフサイクル全体を通じた容易な追跡と監査を促進し、Webhook やジョブを使用してダウンストリームのアクションを 自動化 することができます。

画像、テキスト、その他のログ記録

WandbLogger には、メディアをログに記録するための log_imagelog_textlog_table メソッドがあります。 また、wandb.log または trainer.logger.experiment.log を直接呼び出して、オーディオ、分子、点群、3D オブジェクトなどの他のメディアタイプをログに記録することもできます。
# tensor、numpy 配列、または PIL 画像を使用
wandb_logger.log_image(key="samples", images=[img1, img2])

# キャプションの追加
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# ファイルパスを使用
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# trainer の .log を使用
trainer.logger.experiment.log(
    {"samples": [wandb.Image(img, caption=caption) for (img, caption) in my_images]},
    step=current_trainer_global_step,
)
Lightning のコールバックシステムを使用して、WandbLogger を介して W&B にログを記録するタイミングを制御できます。この例では、検証画像と予測のサンプルをログに記録します。
import torch
import wandb
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

# または
# from wandb.integration.lightning.fabric import WandbLogger


class LogPredictionSamplesCallback(Callback):
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """検証バッチが終了したときに呼び出されます。"""

        # `outputs` は `LightningModule.validation_step` から取得されます
        # この場合はモデルの予測に対応します

        # 最初のバッチから 20 個のサンプル画像の予測をログ記録します
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [
                f"Ground Truth: {y_i} - Prediction: {y_pred}"
                for y_i, y_pred in zip(y[:n], outputs[:n])
            ]

            # オプション 1: `WandbLogger.log_image` で画像をログ記録
            wandb_logger.log_image(key="sample_images", images=images, caption=captions)

            # オプション 2: 画像と予測を W&B Table としてログ記録
            columns = ["image", "ground truth", "prediction"]
            data = [
                [wandb.Image(x_i), y_i, y_pred] or x_i,
                y_i,
                y_pred in list(zip(x[:n], y[:n], outputs[:n])),
            ]
            wandb_logger.log_table(key="sample_table", columns=columns, data=data)


trainer = pl.Trainer(callbacks=[LogPredictionSamplesCallback()])

Lightning と W&B で複数の GPU を使用する

PyTorch Lightning は、DDP インターフェースを通じてマルチ GPU サポートを提供しています。ただし、PyTorch Lightning の設計上、GPU のインスタンス化方法に注意する必要があります。 Lightning は、トレーニングループ内の各 GPU(またはランク)が、同じ初期条件でまったく同じようにインスタンス化される必要があると想定しています。しかし、ランク 0 のプロセスのみが wandb.run オブジェクトにアクセスでき、ランクが 0 以外のプロセスの場合は wandb.run = None となります。これにより、ランクが 0 以外のプロセスが失敗する可能性があります。このような状況では、ランク 0 のプロセスが、すでにクラッシュしたランク 0 以外のプロセスの参加を待機するため、デッドロック に陥る可能性があります。 このため、トレーニングコードのセットアップ方法に注意してください。推奨されるセットアップ方法は、コードを wandb.run オブジェクトに依存しないようにすることです。
class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("train/loss", loss)
        return {"train_loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("val/loss", loss)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def main():
    # すべての乱数シードを同じ値に設定します。
    # これは分散トレーニング設定において重要です。
    # 各ランクは独自の初期重みセットを取得します。
    # それらが一致しない場合、勾配も一致せず、
    # トレーニングが収束しない可能性があります。
    pl.seed_everything(1)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = MNISTClassifier()
    wandb_logger = WandbLogger(project="<project_name>")
    callbacks = [
        ModelCheckpoint(
            dirpath="checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs=3, gpus=2, logger=wandb_logger, strategy="ddp", callbacks=callbacks
    )
    trainer.fit(model, train_loader, val_loader)

Colab ノートブック付きのビデオチュートリアル で手順を確認できます。

よくある質問

W&B は Lightning とどのように統合されますか?

コアとなる統合は Lightning loggers API に基づいており、これによりフレームワークに依存しない方法でログコードの大部分を記述できます。LoggerLightning Trainer に渡され、その API の豊富な フックとコールバックシステム に基づいてトリガーされます。これにより、研究コードをエンジニアリングやログコードから適切に分離し続けることができます。

追加のコードなしでインテグレーションは何をログに記録しますか?

モデルのチェックポイントを W&B に保存し、そこで表示したり、将来の Runs で使用するためにダウンロードしたりできます。また、GPU 使用率やネットワーク I/O などの システムメトリクス、ハードウェアや OS 情報などの環境情報、コードの状態(git のコミットや diff パッチ、ノートブックの内容、セッション履歴を含む)、および標準出力に印刷されたすべての内容をキャプチャします。

トレーニングセットアップで wandb.run を使用する必要がある場合はどうすればよいですか?

アクセスする必要がある変数のスコープを自分自身で拡張する必要があります。言い換えれば、すべてのプロセスで初期条件が同じであることを確認してください。
if os.environ.get("LOCAL_RANK", None) is None:
    os.environ["WANDB_DIR"] = wandb.run.dir
初期条件が同じであれば、os.environ["WANDB_DIR"] を使用してモデルのチェックポイントディレクトリをセットアップできます。これにより、ランクが 0 以外のプロセスでも wandb.run.dir にアクセスできるようになります。