import numpy as np
import wandb
# 세 가지 질병에 대한 의료 진단 분류 문제 시뮬레이션
n_samples = 200
n_classes = 3
# 실제 라벨: 각 샘플에 "Diabetes", "Hypertension", "Heart Disease" 할당
# 각 샘플에 할당
disease_labels = ["Diabetes", "Hypertension", "Heart Disease"]
# 0: Diabetes, 1: Hypertension, 2: Heart Disease
y_true = np.random.choice([0, 1, 2], size=n_samples)
# 예측 확률: 각 샘플의 확률 합이 1이 되도록 시뮬레이션
# 각 샘플에 대해
y_probas = np.random.dirichlet(np.ones(n_classes), size=n_samples)
# 플롯할 클래스 지정 (세 가지 질병 모두 플롯)
classes_to_plot = [0, 1, 2]
# W&B run을 초기화하고 질병 분류를 위한 ROC 커브 플롯을 로그합니다
with wandb.init(project="medical_diagnosis") as run:
roc_plot = wandb.plot.roc_curve(
y_true=y_true,
y_probas=y_probas,
labels=disease_labels,
classes_to_plot=classes_to_plot,
title="ROC Curve for Disease Classification",
)
run.log({"roc-curve": roc_plot})