メインコンテンツへスキップ
Stable Baselines 3 (SB3) は、 PyTorch による信頼性の高い強化学習アルゴリズムの実装セットです。W&B の SB3 インテグレーションは以下の機能を提供します。
  • 損失やエピソードごとのリターンなどのメトリクスを記録します。
  • エージェントがゲームをプレイしている動画をアップロードします。
  • トレーニング済み モデル を保存します。
  • モデル の ハイパーパラメーター を ログ に記録します。
  • モデル の 勾配 ヒストグラムを ログ に記録します。
SB3 トレーニングの Run 例 をご確認ください。

SB3 Experiments のログを記録する

from wandb.integration.sb3 import WandbCallback

model.learn(..., callback=WandbCallback())
Stable Baselines 3 training with W&B

WandbCallback の引数

引数使用方法
verboseSB3 出力の冗長性
model_save_pathモデル が保存されるフォルダーへのパス。デフォルト 値 は `None` で、 モデル は ログ に記録されません
model_save_freqモデル を保存する頻度
gradient_save_freq勾配 を ログ に記録する頻度。デフォルト 値 は 0 で、 勾配 は ログ に記録されません

基本的な例

W&B の SB3 インテグレーションは、TensorBoard から出力される ログ を使用して メトリクス を記録します。
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback


config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_name": "CartPole-v1",
}
run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # SB3 の TensorBoard メトリクスを自動アップロード
    monitor_gym=True,  # エージェントがゲームをプレイしている動画を自動アップロード
    save_code=True,  # オプション
)


def make_env():
    env = gym.make(config["env_name"])
    env = Monitor(env)  # リターンなどの統計情報を記録
    return env


env = DummyVecEnv([make_env])
env = VecVideoRecorder(
    env,
    f"videos/{run.id}",
    record_video_trigger=lambda x: x % 2000 == 0,
    video_length=200,
)
model = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        gradient_save_freq=100,
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()