import numpy as np
import wandb
# 야생 동물에 대한 클래스 이름 정의
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]
# 200개의 동물 이미지에 대한 실제 라벨 시뮬레이션 (불균형 분포)
wildlife_y_true = np.random.choice(
[0, 1, 2, 3],
size=200,
p=[0.2, 0.3, 0.25, 0.25],
)
# 85% 정확도를 가진 모델 예측값 시뮬레이션
wildlife_preds = [
y_t
if np.random.rand() < 0.85
else np.random.choice([x for x in range(4) if x != y_t])
for y_t in wildlife_y_true
]
# W&B run 초기화 및 오차 행렬 로그
with wandb.init(project="wildlife_classification") as run:
confusion_matrix = wandb.plot.confusion_matrix(
preds=wildlife_preds,
y_true=wildlife_y_true,
class_names=wildlife_class_names,
title="Simulated Wildlife Classification Confusion Matrix",
)
run.log({"wildlife_confusion_matrix": confusion_matrix})