メインコンテンツへスキップ
Try in Colab

Keras コールバック

W&B には Keras 用の 3 つのコールバックがあり、wandb v0.13.4 から利用可能です。レガシーな WandbCallback については、ページ下部をご覧ください。
  • WandbMetricsLogger : このコールバックは 実験管理 に使用します。トレーニングと検証のメトリクス、およびシステムメトリクスを W&B にログ記録します。
  • WandbModelCheckpoint : モデルのチェックポイントを W&B アーティファクト にログ記録するために使用します。
  • WandbEvalCallback: このベースコールバックは、インタラクティブな可視化のためにモデルの予測を W&B テーブル にログ記録します。
これらの新しいコールバックは以下の特徴を持ちます:
  • Keras の設計哲学に準拠しています。
  • 1 つのコールバック (WandbCallback) ですべてを行うことによる認知負荷を軽減します。
  • Keras ユーザーがコールバックをサブクラス化して、特定のユースケースに合わせて簡単に修正できるようにします。

WandbMetricsLogger で実験を追跡する

Try in Colab WandbMetricsLogger は、on_epoch_endon_batch_end などのコールバックメソッドが引数として受け取る Keras の logs 辞書を自動的にログ記録します。 以下を追跡します:
  • model.compile で定義されたトレーニングおよび検証メトリクス。
  • システム (CPU/GPU/TPU) メトリクス。
  • 学習率(固定値、または学習率スケジューラーの両方)。
import wandb
from wandb.integration.keras import WandbMetricsLogger

# 新しい W&B Run を初期化
wandb.init(config={"bs": 12})

# model.fit に WandbMetricsLogger を渡す
model.fit(
    X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbMetricsLogger()]
)

WandbMetricsLogger リファレンス

パラメータ説明
log_freq(epoch, batch, または int): epoch の場合、各エポックの終了時にメトリクスをログ記録します。batch の場合、各バッチの終了時にログ記録します。int の場合、その指定されたバッチ数ごとにログ記録します。デフォルトは epoch です。
initial_global_step(int): 特定の initial_epoch からトレーニングを再開し、学習率スケジューラーを使用する場合に、学習率を正しくログ記録するためにこの引数を使用します。これは step_size * initial_step として計算できます。デフォルトは 0 です。

WandbModelCheckpoint を使用したモデルのチェックポイント作成

Try in Colab WandbModelCheckpoint コールバックを使用して、Keras モデル (SavedModel 形式) またはモデルの重みを定期的に保存し、モデルのバージョン管理のために wandb.Artifact として W&B にアップロードします。 このコールバックは tf.keras.callbacks.ModelCheckpoint から継承されているため、チェックポイント作成のロジックは親コールバックによって処理されます。 このコールバックは以下を保存します:
  • モニターに基づき最高のパフォーマンスを達成したモデル。
  • パフォーマンスに関わらず、毎エポック終了時のモデル。
  • エポック終了時、または一定数のトレーニングバッチ終了時のモデル。
  • モデルの重みのみ、またはモデル全体。
  • SavedModel 形式または .h5 形式のモデル。
このコールバックは WandbMetricsLogger と組み合わせて使用してください。
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint

# 新しい W&B Run を初期化
wandb.init(config={"bs": 12})

# model.fit に WandbModelCheckpoint を渡す
model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    callbacks=[
        WandbMetricsLogger(),
        WandbModelCheckpoint("models"),
    ],
)

WandbModelCheckpoint リファレンス

パラメータ説明
filepath(str): モデルファイルを保存するパス。
monitor(str): 監視するメトリクス名。
verbose(int): 詳細モード、0 または 1。モード 0 はサイレント、モード 1 はコールバックがアクションを実行したときにメッセージを表示します。
save_best_only(Boolean): save_best_only=True の場合、monitormode 属性で定義された内容に従って、最新のモデル、または最適とみなされるモデルのみを保存します。
save_weights_only(Boolean): True の場合、モデルの重みのみを保存します。
mode(auto, min, または max): val_acc の場合は maxval_loss の場合は min などに設定します。
save_freq(“epoch” または int): ‘epoch’ を使用すると、各エポックの後にモデルを保存します。整数を使用すると、そのバッチ数の終了時にモデルを保存します。val_accval_loss などの検証メトリクスを監視する場合、それらのメトリクスはエポックの終了時にしか利用できないため、save_freq は “epoch” に設定する必要があります。
options(str): save_weights_only が true の場合はオプションの tf.train.CheckpointOptions オブジェクト、false の場合はオプションの tf.saved_model.SaveOptions オブジェクト。
initial_value_threshold(float): 監視対象メトリクスの初期の「最良」値。

N エポックごとにチェックポイントをログ記録する

デフォルト (save_freq="epoch") では、コールバックは各エポックの後にチェックポイントを作成し、アーティファクトとしてアップロードします。特定のバッチ数ごとにチェックポイントを作成するには、save_freq に整数を設定します。N エポックごとにチェックポイントを作成するには、train データローダーのカーディナリティを計算し、それを save_freq に渡します:
WandbModelCheckpoint(
    filepath="models/",
    save_freq=int((trainloader.cardinality()*N).numpy())
)

TPU アーキテクチャーで効率的にチェックポイントをログ記録する

TPU でチェックポイントを作成する際、UnimplementedError: File system scheme '[local]' not implemented というエラーメッセージに遭遇することがあります。これは、モデルディレクトリ (filepath) がクラウドストレージのバケットパス (gs://bucket-name/...) を使用する必要があり、かつそのバケットが TPU サーバーからアクセス可能である必要があるために発生します。しかし、ローカルパスを使用してチェックポイントを作成し、それをアーティファクトとしてアップロードすることも可能です。
checkpoint_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")

WandbModelCheckpoint(
    filepath="models/,
    options=checkpoint_options,
)

WandbEvalCallback を使用したモデル予測の可視化

Try in Colab WandbEvalCallback は、主にモデル予測、副次的にデータセットの可視化を目的とした Keras コールバックを構築するための抽象基底クラスです。 この抽象コールバックはデータセットやタスクに依存しません。これを使用するには、この基底 WandbEvalCallback クラスを継承し、add_ground_truthadd_model_prediction メソッドを実装します。 WandbEvalCallback は、以下のメソッドを提供するユーティリティクラスです:
  • データおよび予測用の wandb.Table インスタンスを作成する。
  • データおよび予測テーブルを wandb.Artifact としてログ記録する。
  • on_train_begin 時にデータテーブルをログ記録する。
  • on_epoch_end 時に予測テーブルをログ記録する。
次の例では、画像分類タスクに WandbClfEvalCallback を使用しています。この例のコールバックは、検証データ (data_table) を W&B にログ記録し、推論を実行して、各エポックの終了時に予測 (pred_table) を W&B にログ記録します。
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbEvalCallback


# モデル予測可視化用コールバックの実装
class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(
        self, validation_data, data_table_columns, pred_table_columns, num_samples=100
    ):
        super().__init__(data_table_columns, pred_table_columns)

        self.x = validation_data[0]
        self.y = validation_data[1]

    def add_ground_truth(self, logs=None):
        for idx, (image, label) in enumerate(zip(self.x, self.y)):
            self.data_table.add_data(idx, wandb.Image(image), label)

    def add_model_predictions(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = tf.argmax(preds, axis=-1)

        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )


# ...

# 新しい W&B Run を初期化
wandb.init(config={"hyper": "parameter"})

# コールバックを Model.fit に追加
model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    callbacks=[
        WandbMetricsLogger(),
        WandbClfEvalCallback(
            validation_data=(X_test, y_test),
            data_table_columns=["idx", "image", "label"],
            pred_table_columns=["epoch", "idx", "image", "label", "pred"],
        ),
    ],
)
W&B の アーティファクトページ には、デフォルトで Workspace ページではなくテーブルログが含まれます。

WandbEvalCallback リファレンス

パラメータ説明
data_table_columns(list) data_table のカラム名のリスト
pred_table_columns(list) pred_table のカラム名のリスト

メモリフットプリントの詳細

on_train_begin メソッドが呼び出されたときに data_table を W&B にログ記録します。W&B アーティファクトとしてアップロードされると、このテーブルへの参照を取得でき、data_table_ref クラス変数を使用してアクセスできます。data_table_refself.data_table_ref[idx][n] のようにインデックス指定できる 2D リストで、idx は行番号、n は列番号です。使い方は上記の例を確認してください。

コールバックのカスタマイズ

より詳細な制御を行うために、on_train_begin または on_epoch_end メソッドをオーバーライドできます。N バッチごとにサンプルをログ記録したい場合は、on_train_batch_end メソッドを実装できます。
WandbEvalCallback を継承してモデル予測可視化用のコールバックを実装する際に、不明点や修正が必要な箇所がある場合は、issue を開いてお知らせください。

WandbCallback [レガシー]

W&B ライブラリの WandbCallback クラスを使用して、model.fit で追跡されるすべてのメトリクスと損失値を自動的に保存します。
import wandb
from wandb.integration.keras import WandbCallback

wandb.init(config={"hyper": "parameter"})

...  # Keras でモデルをセットアップするコード

# コールバックを model.fit に渡す
model.fit(
    X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbCallback()]
)
短い動画 Get Started with Keras and W&B in Less Than a Minute をご覧いただけます。 より詳細な動画については、Integrate W&B with Keras をご覧ください。また、Colab Jupyter Notebook も確認できます。
Fashion MNIST の例 や、それによって生成される W&B ダッシュボード を含むスクリプトについては、公式サンプルリポジトリ を参照してください。
WandbCallback クラスは、監視するメトリクスの指定、重みと勾配の追跡、トレーニングデータおよび検証データの予測のログ記録など、多種多様なログ設定オプションをサポートしています。 詳細については、keras.WandbCallback のリファレンスドキュメントを確認してください。 WandbCallback は:
  • Keras によって収集されたすべてのメトリクスの履歴データを自動的にログ記録します:損失(loss)および keras_model.compile() に渡されたすべてのメトリクス。
  • monitor および mode 属性で定義された「最良」のトレーニングステップに関連付けられた実行のサマリーメトリクスを設定します。これはデフォルトで val_loss が最小のエポックになります。WandbCallback はデフォルトで最良の epoch に関連付けられたモデルを保存します。
  • オプションで勾配とパラメータのヒストグラムをログ記録します。
  • オプションで W&B が可視化するためのトレーニングおよび検証データを保存します。

WandbCallback リファレンス

引数説明
monitor(str) 監視するメトリクス名。デフォルトは val_loss
mode(str) {auto, min, max} のいずれか。min - monitor が最小化されたときにモデルを保存。max - monitor が最大化されたときにモデルを保存。auto - いつモデルを保存するかを自動推測(デフォルト)。
save_modelTrue - monitor が過去のすべてのエポックを上回ったときにモデルを保存。False - モデルを保存しない。
save_graph(boolean) True の場合、モデルグラフを wandb に保存(デフォルトは True)。
save_weights_only(boolean) True の場合、モデルの重みのみを保存 (model.save_weights(filepath))。それ以外の場合はモデル全体を保存。
log_weights(boolean) True の場合、モデルレイヤーの重みのヒストグラムを保存。
log_gradients(boolean) True の場合、トレーニング勾配のヒストグラムをログ記録。
training_data(tuple) model.fit に渡されるものと同じ形式 (X,y)。勾配の計算に必要です。log_gradientsTrue の場合は必須です。
validation_data(tuple) model.fit に渡されるものと同じ形式 (X,y)。W&B が可視化するためのデータセット。このフィールドを設定すると、毎エポック、W&B は少数の予測を行い、後で可視化するために結果を保存します。
generator(generator) W&B が可視化するための検証データを返すジェネレーター。このジェネレーターはタプル (X,y) を返す必要があります。特定のデータ例を可視化するには、validate_data またはジェネレーターのいずれかを設定する必要があります。
validation_steps(int) validation_data がジェネレーターの場合、検証セット全体に対してジェネレーターを実行するステップ数。
labels(list) W&B でデータを可視化する場合、複数のクラスを持つ分類器を構築していれば、このラベルリストにより数値出力を理解しやすい文字列に変換します。バイナリ分類器の場合は、2 つのラベルのリスト [label for false, label for true] を渡すことができます。validate_datagenerator の両方が false の場合、これは何もしません。
predictions(int) 各エポックで可視化のために作成する予測数。最大は 100 です。
input_type(string) 可視化を助けるためのモデル入力のタイプ。(image, images, segmentation_mask) のいずれか。
output_type(string) 可視化を助けるためのモデル出力のタイプ。(image, images, segmentation_mask) のいずれか。
log_evaluation(boolean) True の場合、各エポックでの検証データとモデルの予測を含むテーブルを保存します。詳細は validation_indexesvalidation_row_processoroutput_row_processor を参照してください。
class_colors([float, float, float]) 入力または出力がセグメンテーションマスクの場合、各クラスの RGB タプル(範囲 0-1)を含む配列。
log_batch_frequency(integer) None の場合、コールバックは毎エポックログを記録します。整数に設定されている場合、コールバックは log_batch_frequency バッチごとにトレーニングメトリクスをログ記録します。
log_best_prefix(string) None の場合、余分なサマリーメトリクスを保存しません。文字列に設定されている場合、監視されているメトリクスとエポックの前にプレフィックスを付加し、結果をサマリーメトリクスとして保存します。
validation_indexes([wandb.data_types._TableLinkMixin]) 各検証例に関連付けるインデックスキーの順序付きリスト。log_evaluation が True で validation_indexes を指定した場合、検証データのテーブルは作成されません。代わりに、各予測を TableLinkMixin で表される行に関連付けます。行キーのリストを取得するには、Table.get_index() を使用してください。
validation_row_processor(Callable) 検証データに適用する関数で、通常はデータの可視化に使用されます。この関数は ndx (int) と row (dict) を受け取ります。モデルの入力が 1 つの場合、row["input"] にはその行の入力データが含まれます。それ以外の場合は、入力スロットの名前が含まれます。fit 関数が単一のターゲットを受け取る場合、row["target"] にはその行のターゲットデータが含まれます。それ以外の場合は、出力スロットの名前が含まれます。例えば、入力データが単一の配列で、データを画像として可視化する場合、プロセッサーとして lambda ndx, row: {"img": wandb.Image(row["input"])} を指定します。log_evaluation が False または validation_indexes が存在する場合は無視されます。
output_row_processor(Callable) validation_row_processor と同様ですが、モデルの出力に適用されます。row["output"] にはモデル出力の結果が含まれます。
infer_missing_processors(Boolean) validation_row_processoroutput_row_processor が欠けている場合に推論するかどうかを決定します。デフォルトは True です。labels を指定した場合、W&B は適切な場所で分類タイプのプロセッサーを推論しようとします。
log_evaluation_frequency(int) 評価結果をログ記録する頻度を決定します。デフォルトは 0 で、トレーニングの終了時にのみログ記録します。1 に設定すると毎エポック、2 に設定すると 1 エポックおきにログ記録します。log_evaluation が False の場合は効果がありません。

よくある質問

Keras のマルチプロセッシングを wandb で使用するにはどうすればよいですか?

use_multiprocessing=True を設定すると、以下のエラーが発生することがあります:
Error("You must call wandb.init() before wandb.config.batch_size")
これを回避するには:
  1. Sequence クラスの構築時に wandb.init(group='...') を追加します。
  2. mainif __name__ == "__main__": を使用していることを確認し、スクリプトの残りのロジックをその中に入れます。