import numpy as np
import wandb
# 3つの疾患を持つ医学的診断分類問題をシミュレート
n_samples = 200
n_classes = 3
# 真のラベル: 各サンプルに "Diabetes"(糖尿病)、"Hypertension"(高血圧)、
# または "Heart Disease"(心臓病)を割り当て
disease_labels = ["Diabetes", "Hypertension", "Heart Disease"]
# 0: Diabetes, 1: Hypertension, 2: Heart Disease
y_true = np.random.choice([0, 1, 2], size=n_samples)
# 予測確率: 予測をシミュレートし、各サンプルの合計が1になるようにする
y_probas = np.random.dirichlet(np.ones(n_classes), size=n_samples)
# プロットするクラスを指定(3つの疾患すべてをプロット)
classes_to_plot = [0, 1, 2]
# W&B run を初期化し、疾患分類の ROC 曲線プロットをログする
with wandb.init(project="medical_diagnosis") as run:
roc_plot = wandb.plot.roc_curve(
y_true=y_true,
y_probas=y_probas,
labels=disease_labels,
classes_to_plot=classes_to_plot,
title="ROC Curve for Disease Classification",
)
run.log({"roc-curve": roc_plot})