メインコンテンツへスキップ
Try in Colab このチュートリアルでは、 MONAI を使用してマルチラベル 3D 脳腫瘍セグメンテーションタスクのトレーニングワークフローを構築し、 W&B の 実験管理 および データ可視化 機能を使用する方法を解説します。このチュートリアルには以下の機能が含まれています。
  1. W&B Run を初期化し、再現性のために Run に関連付けられたすべての設定(configs)を同期する。
  2. MONAI transform API:
    1. 辞書形式データのための MONAI Transforms。
    2. MONAI transforms API に従って新しい transform を定義する方法。
    3. データ拡張のために強度(intensity)をランダムに調整する方法。
  3. データの読み込みと可視化:
    1. メタデータを含む Nifti 画像の読み込み、画像のリストの読み込みとスタック。
    2. トレーニングと検証を加速するための IO と transform のキャッシュ。
    3. wandb.Table と W&B 上のインタラクティブなセグメンテーションオーバーレイを使用したデータの可視化。
  4. 3D SegResNet モデル のトレーニング
    1. MONAI の networkslossesmetrics API の使用。
    2. PyTorch トレーニングループを使用した 3D SegResNet モデル のトレーニング。
    3. W&B を使用したトレーニング 実験 の追跡。
    4. モデルのチェックポイントを W&B 上の モデルアーティファクト として ログ および バージョン管理 する。
  5. wandb.Table と W&B 上のインタラクティブなセグメンテーションオーバーレイを使用して、検証 データセット に対する 予測 を可視化し比較する。

セットアップとインストール

まず、最新バージョンの MONAI と W&B をインストールします。
!python -c "import monai" || pip install -q -U "monai[nibabel, tqdm]"
!python -c "import wandb" || pip install -q -U wandb
import os

import numpy as np
from tqdm.auto import tqdm
import wandb

from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism

import torch
次に、Colab インスタンスを W&B で認証します。
wandb.login()

W&B Run の初期化

新しい W&B Run を開始して 実験 の追跡を開始します。適切な設定システム(config system)を使用することは、再現可能な 機械学習 のための推奨されるベストプラクティスです。 W&B を使用して、すべての 実験 の ハイパーパラメーター を追跡できます。
with wandb.init(project="monai-brain-tumor-segmentation") as run:

    config = run.config
    config.seed = 0
    config.roi_size = [224, 224, 144]
    config.batch_size = 1
    config.num_workers = 4
    config.max_train_images_visualized = 20
    config.max_val_images_visualized = 20
    config.dice_loss_smoothen_numerator = 0
    config.dice_loss_smoothen_denominator = 1e-5
    config.dice_loss_squared_prediction = True
    config.dice_loss_target_onehot = False
    config.dice_loss_apply_sigmoid = True
    config.initial_learning_rate = 1e-4
    config.weight_decay = 1e-5
    config.max_train_epochs = 50
    config.validation_intervals = 1
    config.dataset_dir = "./dataset/"
    config.checkpoint_dir = "./checkpoints"
    config.inference_roi_size = (128, 128, 64)
    config.max_prediction_images_visualized = 20
また、決定論的な トレーニング を有効または無効にするために、各モジュールの乱数シードを設定する必要があります。
set_determinism(seed=config.seed)

# ディレクトリの作成
os.makedirs(config.dataset_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)

データの読み込みと変換

ここでは、 monai.transforms API を使用して、マルチクラスのラベルを one-hot 形式のマルチラベルセグメンテーションタスクに変換するカスタム transform を作成します。
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    BraTS クラスに基づいてラベルをマルチチャンネルに変換します:
    ラベル 1 は腫瘍周囲の浮腫 (peritumoral edema)
    ラベル 2 は GD 増強腫瘍 (GD-enhancing tumor)
    ラベル 3 は壊死および非増強腫瘍コア (necrotic and non-enhancing tumor core)
    可能なクラスは TC (Tumor core), WT (Whole tumor) および ET (Enhancing tumor) です。

    リファレンス: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # ラベル 2 とラベル 3 をマージして TC を構築
            result.append(torch.logical_or(d[key] == 2, d[key] == 3))
            # ラベル 1, 2, 3 をマージして WT を構築
            result.append(
                torch.logical_or(
                    torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
                )
            )
            # ラベル 2 は ET
            result.append(d[key] == 2)
            d[key] = torch.stack(result, axis=0).float()
        return d
次に、トレーニング用と検証用の データセット に対して、それぞれ transform を設定します。
train_transform = Compose(
    [
        # 4つの Nifti 画像をロードしてスタックする
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(
            keys=["image", "label"], roi_size=config.roi_size, random_size=False
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

データセット

この 実験 で使用される データセット は http://medicaldecathlon.com/ から提供されています。これは、マルチモーダル・マルチサイトの MRI データ(FLAIR, T1w, T1gd, T2w)を使用して、神経膠腫(Gliomas)、壊死/活動性腫瘍、および浮腫をセグメント化します。 データセット は 750 個の 4D ボリューム(トレーニング用 484 + テスト用 266)で構成されています。 DecathlonDataset を使用して、 データセット を自動的にダウンロードおよび解凍します。これは MONAI の CacheDataset を継承しており、メモリサイズに応じて cache_num=N を設定してトレーニング用に N 個のアイテムをキャッシュしたり、デフォルトの 引数 を使用して検証用のすべてのアイテムをキャッシュしたりできます。
train_dataset = DecathlonDataset(
    root_dir=config.dataset_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)
val_dataset = DecathlonDataset(
    root_dir=config.dataset_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=4,
)
注: train_datasettrain_transform を適用する代わりに、トレーニングと検証の両方の データセット に val_transform を適用します。これは、 トレーニング を開始する前に、両方の分割(split)からサンプルを可視化するためです。

データセット の可視化

W&B は画像、ビデオ、オーディオなどをサポートしています。リッチメディアを ログ に記録して、結果を探索したり、 Run 、 モデル 、 データセット を視覚的に比較したりできます。セグメンテーションマスクオーバーレイシステム を使用して、データのボリュームを可視化します。セグメンテーションマスクを テーブル (Tables) に ログ 記録するには、テーブルの各行に対して wandb.Image オブジェクトを提供する必要があります。 疑似コードの例を以下に示します。
table = wandb.Table(columns=["ID", "Image"])

for id, img, label in zip(ids, images, labels):
    mask_img = wandb.Image(
        img,
        masks={
            "prediction": {"mask_data": label, "class_labels": class_labels}
            # ...
        },
    )

    table.add_data(id, img)

run.log({"Table": table})
次に、サンプル画像、ラベル、 wandb.Table オブジェクト、および関連する メタデータ を受け取り、W&B ダッシュボード に ログ 記録されるテーブルの行に入力する簡単なユーティリティ関数を作成します。
def log_data_samples_into_tables(
    sample_image: np.array,
    sample_label: np.array,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            ground_truth_wandb_images = []
            for channel_idx in range(num_channels):
                ground_truth_wandb_images.append(
                    masks = {
                        "ground-truth/Tumor-Core": {
                            "mask_data": sample_label[0, :, :, slice_idx],
                            "class_labels": {0: "background", 1: "Tumor Core"},
                        },
                        "ground-truth/Whole-Tumor": {
                            "mask_data": sample_label[1, :, :, slice_idx] * 2,
                            "class_labels": {0: "background", 2: "Whole Tumor"},
                        },
                        "ground-truth/Enhancing-Tumor": {
                            "mask_data": sample_label[2, :, :, slice_idx] * 3,
                            "class_labels": {0: "background", 3: "Enhancing Tumor"},
                        },
                    }
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks=masks,
                    )
                )
            table.add_data(split, data_idx, slice_idx, *ground_truth_wandb_images)
            progress_bar.update(1)
    return table
次に、 wandb.Table オブジェクトと、データ可視化を取り込むための列(columns)を定義します。
table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Image-Channel-0",
        "Image-Channel-1",
        "Image-Channel-2",
        "Image-Channel-3",
    ]
)
その後、 train_datasetval_dataset をそれぞれループして、データサンプルの可視化を生成し、 ダッシュボード に ログ 記録するテーブルの行を埋めます。
# train_dataset の可視化を生成
max_samples = (
    min(config.max_train_images_visualized, len(train_dataset))
    if config.max_train_images_visualized > 0
    else len(train_dataset)
)
progress_bar = tqdm(
    enumerate(train_dataset[:max_samples]),
    total=max_samples,
    desc="Generating Train Dataset Visualizations:",
)
for data_idx, sample in progress_bar:
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="train",
        data_idx=data_idx,
        table=table,
    )

# val_dataset の可視化を生成
max_samples = (
    min(config.max_val_images_visualized, len(val_dataset))
    if config.max_val_images_visualized > 0
    else len(val_dataset)
)
progress_bar = tqdm(
    enumerate(val_dataset[:max_samples]),
    total=max_samples,
    desc="Generating Validation Dataset Visualizations:",
)
for data_idx, sample in progress_bar:
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="val",
        data_idx=data_idx,
        table=table,
    )

# テーブルをダッシュボードにログ記録
run.log({"Tumor-Segmentation-Data": table})
データは W&B ダッシュボード にインタラクティブなテーブル形式で表示されます。データボリュームの特定のスライスの各チャンネルが、それぞれのセグメンテーションマスクとオーバーレイされて各行に表示されているのが確認できます。 Weave クエリ を記述してテーブルのデータをフィルタリングし、特定の行に焦点を当てることもできます。
Logged table data
画像を開いて、インタラクティブなオーバーレイを使用して各セグメンテーションマスクをどのように操作できるかを確認してください。
Segmentation maps
注: データセット のラベルは、クラス間で重複しないマスクで構成されています。オーバーレイでは、ラベルが個別のマスクとして ログ 記録されます。

データの読み込み

データセット からデータを読み込むための PyTorch DataLoader を作成します。 DataLoader を作成する前に、 train_datasettransformtrain_transform に設定して、トレーニング用にデータを前処理および変換します。
# トレーニングデータセットに train_transforms を適用
train_dataset.transform = train_transform

# train_loader の作成
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)

# val_loader の作成
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
)

モデル、損失関数、オプティマイザーの作成

このチュートリアルでは、論文 3D MRI brain tumor segmentation using auto-encoder regularization に基づいた SegResNet モデル を作成します。 SegResNet モデル は monai.networks API の一部として PyTorch モジュールとして実装されており、オプティマイザー や学習率スケジューラも含まれています。
device = torch.device("cuda:0")

# モデルの作成
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)

# オプティマイザーの作成
optimizer = torch.optim.Adam(
    model.parameters(),
    config.initial_learning_rate,
    weight_decay=config.weight_decay,
)

# 学習率スケジューラの作成
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.max_train_epochs
)
損失関数を monai.losses API を使用してマルチラベル DiceLoss として定義し、対応する Dice メトリクス を monai.metrics API を使用して定義します。
loss_function = DiceLoss(
    smooth_nr=config.dice_loss_smoothen_numerator,
    smooth_dr=config.dice_loss_smoothen_denominator,
    squared_pred=config.dice_loss_squared_prediction,
    to_onehot_y=config.dice_loss_target_onehot,
    sigmoid=config.dice_loss_apply_sigmoid,
)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# トレーニングを加速するために自動混合精度を使用
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
混合精度推論のための小さなユーティリティを定義します。これは、トレーニング プロセス の検証ステップや、トレーニング後に モデル を実行する際に役立ちます。
def inference(model, input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    with torch.cuda.amp.autocast():
        return _compute(input)

トレーニングと検証

トレーニングの前に、トレーニングと検証の 実験 を追跡するために run.log() で後ほど ログ 記録する メトリクス プロパティを定義します。
run.define_metric("epoch/epoch_step")
run.define_metric("epoch/*", step_metric="epoch/epoch_step")
run.define_metric("batch/batch_step")
run.define_metric("batch/*", step_metric="batch/batch_step")
run.define_metric("validation/validation_step")
run.define_metric("validation/*", step_metric="validation/validation_step")

batch_step = 0
validation_step = 0
metric_values = []
metric_values_tumor_core = []
metric_values_whole_tumor = []
metric_values_enhanced_tumor = []

標準的な PyTorch トレーニングループの実行

with wandb.init(
    project="monai-brain-tumor-segmentation",
    config=config,
    job_type="train",
    reinit=True,
) as run:

    # W&B Artifact オブジェクトの定義
    artifact = wandb.Artifact(
        name=f"{run.id}-checkpoint", type="model"
    )

    epoch_progress_bar = tqdm(range(config.max_train_epochs), desc="Training:")

    for epoch in epoch_progress_bar:
        model.train()
        epoch_loss = 0

        total_batch_steps = len(train_dataset) // train_loader.batch_size
        batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)
        
        # トレーニングステップ
        for batch_data in batch_progress_bar:
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            batch_progress_bar.set_description(f"train_loss: {loss.item():.4f}:")
            ## バッチごとのトレーニング損失を W&B にログ記録
            run.log({"batch/batch_step": batch_step, "batch/train_loss": loss.item()})
            batch_step += 1

        lr_scheduler.step()
        epoch_loss /= total_batch_steps
        ## エポックごとのトレーニング損失と学習率を W&B にログ記録
        run.log(
            {
                "epoch/epoch_step": epoch,
                "epoch/mean_train_loss": epoch_loss,
                "epoch/learning_rate": lr_scheduler.get_last_lr()[0],
            }
        )
        epoch_progress_bar.set_description(f"Training: train_loss: {epoch_loss:.4f}:")

        # 検証およびモデルチェックポイント作成ステップ
        if (epoch + 1) % config.validation_intervals == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = inference(model, val_inputs)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    dice_metric(y_pred=val_outputs, y=val_labels)
                    dice_metric_batch(y_pred=val_outputs, y=val_labels)

                metric_values.append(dice_metric.aggregate().item())
                metric_batch = dice_metric_batch.aggregate()
                metric_values_tumor_core.append(metric_batch[0].item())
                metric_values_whole_tumor.append(metric_batch[1].item())
                metric_values_enhanced_tumor.append(metric_batch[2].item())
                dice_metric.reset()
                dice_metric_batch.reset()

                checkpoint_path = os.path.join(config.checkpoint_dir, "model.pth")
                torch.save(model.state_dict(), checkpoint_path)
                
                # W&B Artifacts を使用してモデルのチェックポイントをログ記録し、バージョン管理する
                artifact.add_file(local_path=checkpoint_path)
                run.log_artifact(artifact, aliases=[f"epoch_{epoch}"])

                # 検証メトリクスを W&B ダッシュボードにログ記録
                run.log(
                    {
                        "validation/validation_step": validation_step,
                        "validation/mean_dice": metric_values[-1],
                        "validation/mean_dice_tumor_core": metric_values_tumor_core[-1],
                        "validation/mean_dice_whole_tumor": metric_values_whole_tumor[-1],
                        "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[-1],
                    }
                )
                validation_step += 1


    # アーティファクトのログ記録が完了するのを待機
    artifact.wait()
コードに wandb.log を組み込むことで、トレーニングおよび検証 プロセス に関連するすべての メトリクス を追跡できるだけでなく、すべてのシステム メトリクス (この場合は CPU と GPU )も W&B ダッシュボード 上で追跡できるようになります。
Training and validation tracking
W&B の Run ダッシュボード の artifacts タブに移動すると、トレーニング中に ログ 記録された モデル チェックポイント アーティファクト のさまざまな バージョン にアクセスできます。
Model checkpoints logging

推論

アーティファクトインターフェースを使用して、どの バージョン の アーティファクト が最適な モデル チェックポイントであるか(この場合は、エポックごとの平均トレーニング損失が最小のものなど)を選択できます。また、 アーティファクト の完全な リネージ (系統)を探索し、必要な バージョン を使用することもできます。
Model artifact tracking
エポックごとの平均トレーニング損失が最も低い モデルアーティファクト の バージョン を取得し、チェックポイントの状態辞書を モデル にロードします。
run = wandb.init(
    project="monai-brain-tumor-segmentation",
    job_type="inference",
    reinit=True,
)
model_artifact = run.use_artifact(
    "geekyrakshit/monai-brain-tumor-segmentation/d5ex6n4a-checkpoint:v49",
    type="model",
)
model_artifact_dir = model_artifact.download()
model.load_state_dict(torch.load(os.path.join(model_artifact_dir, "model.pth")))
model.eval()

予測 の可視化と 正解 ラベルとの比較

学習済み モデル の 予測 を可視化し、インタラクティブなセグメンテーションマスクオーバーレイを使用して対応する 正解 セグメンテーションマスクと比較するための別のユーティリティ関数を作成します。
def log_predictions_into_tables(
    sample_image: np.array,
    sample_label: np.array,
    predicted_label: np.array,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            wandb_images = []
            for channel_idx in range(num_channels):
                wandb_images += [
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Tumor-Core": {
                                "mask_data": sample_label[0, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Tumor Core"},
                            },
                            "prediction/Tumor-Core": {
                                "mask_data": predicted_label[0, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Tumor Core"},
                            },
                        },
                    ),
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Whole-Tumor": {
                                "mask_data": sample_label[1, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Whole Tumor"},
                            },
                            "prediction/Whole-Tumor": {
                                "mask_data": predicted_label[1, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Whole Tumor"},
                            },
                        },
                    ),
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks={
                            "ground-truth/Enhancing-Tumor": {
                                "mask_data": sample_label[2, :, :, slice_idx],
                                "class_labels": {0: "background", 1: "Enhancing Tumor"},
                            },
                            "prediction/Enhancing-Tumor": {
                                "mask_data": predicted_label[2, :, :, slice_idx] * 2,
                                "class_labels": {0: "background", 2: "Enhancing Tumor"},
                            },
                        },
                    ),
                ]
            table.add_data(split, data_idx, slice_idx, *wandb_images)
            progress_bar.update(1)
    return table
予測 結果を予測テーブルに ログ 記録します。
run = wandb.init(
    project="monai-brain-tumor-segmentation",
    job_type="inference",
    reinit=True,
)
# 予測テーブルの作成
prediction_table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Image-Channel-0/Tumor-Core",
        "Image-Channel-1/Tumor-Core",
        "Image-Channel-2/Tumor-Core",
        "Image-Channel-3/Tumor-Core",
        "Image-Channel-0/Whole-Tumor",
        "Image-Channel-1/Whole-Tumor",
        "Image-Channel-2/Whole-Tumor",
        "Image-Channel-3/Whole-Tumor",
        "Image-Channel-0/Enhancing-Tumor",
        "Image-Channel-1/Enhancing-Tumor",
        "Image-Channel-2/Enhancing-Tumor",
        "Image-Channel-3/Enhancing-Tumor",
    ]
)

# 推論と可視化の実行
with torch.no_grad():
    config.max_prediction_images_visualized
    max_samples = (
        min(config.max_prediction_images_visualized, len(val_dataset))
        if config.max_prediction_images_visualized > 0
        else len(val_dataset)
    )
    progress_bar = tqdm(
        enumerate(val_dataset[:max_samples]),
        total=max_samples,
        desc="Generating Predictions:",
    )
    for data_idx, sample in progress_bar:
        val_input = sample["image"].unsqueeze(0).to(device)
        val_output = inference(model, val_input)
        val_output = post_trans(val_output[0])
        prediction_table = log_predictions_into_tables(
            sample_image=sample["image"].cpu().numpy(),
            sample_label=sample["label"].cpu().numpy(),
            predicted_label=val_output.cpu().numpy(),
            data_idx=data_idx,
            split="validation",
            table=prediction_table,
        )

    run.log({"Predictions/Tumor-Segmentation-Data": prediction_table})


# 実験の終了
run.finish()
インタラクティブなセグメンテーションマスクオーバーレイを使用して、各クラスの予測されたセグメンテーションマスクと 正解 ラベルを分析および比較します。
Predictions and ground-truth

謝辞およびその他のリソース