메인 콘텐츠로 건너뛰기
GitHub source

function confusion_matrix

confusion_matrix(
    probs: 'Sequence[Sequence[float]] | None' = None,
    y_true: 'Sequence[T] | None' = None,
    preds: 'Sequence[T] | None' = None,
    class_names: 'Sequence[str] | None' = None,
    title: 'str' = 'Confusion Matrix Curve',
    split_table: 'bool' = False
) → CustomChart
일련의 확률값 또는 예측값을 사용하여 오차 행렬 (confusion matrix)을 생성합니다. Args:
  • probs: 각 클래스에 대한 예측 확률 시퀀스입니다. 시퀀스의 형태는 (N, K)여야 하며, 여기서 N은 샘플 수, K는 클래스 수입니다. 이 인자를 제공하는 경우 preds는 제공하지 않아야 합니다.
  • y_true: 실제 라벨의 시퀀스입니다.
  • preds: 예측된 클래스 라벨의 시퀀스입니다. 이 인자를 제공하는 경우 probs는 제공하지 않아야 합니다.
  • class_names: 클래스 이름의 시퀀스입니다. 제공되지 않으면 클래스 이름은 “Class_1”, “Class_2” 등으로 정의됩니다.
  • title: 오차 행렬 차트의 제목입니다.
  • split_table: 테이블을 W&B UI의 별도 섹션으로 분리할지 여부입니다. True인 경우 테이블은 “Custom Chart Tables”라는 섹션에 표시됩니다. 기본값은 False입니다.
Returns:
  • CustomChart: W&B에 로그를 남길 수 있는 커스텀 차트 오브젝트입니다. 차트를 로그하려면 wandb.log()에 전달하세요.
Raises:
  • ValueError: probspreds가 모두 제공되거나, 예측값과 실제 라벨의 수가 일치하지 않는 경우 발생합니다. 또한 고유한 예측 클래스 수가 클래스 이름의 수를 초과하거나 고유한 실제 라벨 수가 클래스 이름의 수를 초과하는 경우에도 발생합니다.
  • wandb.Error: numpy가 설치되어 있지 않은 경우 발생합니다.
Examples: 야생 동물 분류를 위한 무작위 확률값으로 오차 행렬을 로그하는 예시:
import numpy as np
import wandb

# 야생 동물에 대한 클래스 이름 정의
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

# 무작위 실제 라벨 생성 (10개 샘플에 대해 0에서 3 사이의 값)
wildlife_y_true = np.random.randint(0, 4, size=10)

# 각 클래스에 대한 무작위 확률값 생성 (10개 샘플 x 4개 클래스)
wildlife_probs = np.random.rand(10, 4)
wildlife_probs = np.exp(wildlife_probs) / np.sum(
    np.exp(wildlife_probs),
    axis=1,
    keepdims=True,
)

# W&B run 초기화 및 오차 행렬 로그
with wandb.init(project="wildlife_classification") as run:
    confusion_matrix = wandb.plot.confusion_matrix(
         probs=wildlife_probs,
         y_true=wildlife_y_true,
         class_names=wildlife_class_names,
         title="Wildlife Classification Confusion Matrix",
    )
    run.log({"wildlife_confusion_matrix": confusion_matrix})
이 예시에서는 무작위 확률값을 사용하여 오차 행렬을 생성합니다. 85%의 정확도를 가진 시뮬레이션된 모델 예측값으로 오차 행렬을 로그하는 예시:
import numpy as np
import wandb

# 야생 동물에 대한 클래스 이름 정의
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

# 200개의 동물 이미지에 대한 실제 라벨 시뮬레이션 (불균형 분포)
wildlife_y_true = np.random.choice(
    [0, 1, 2, 3],
    size=200,
    p=[0.2, 0.3, 0.25, 0.25],
)

# 85% 정확도를 가진 모델 예측값 시뮬레이션
wildlife_preds = [
    y_t
    if np.random.rand() < 0.85
    else np.random.choice([x for x in range(4) if x != y_t])
    for y_t in wildlife_y_true
]

# W&B run 초기화 및 오차 행렬 로그
with wandb.init(project="wildlife_classification") as run:
    confusion_matrix = wandb.plot.confusion_matrix(
         preds=wildlife_preds,
         y_true=wildlife_y_true,
         class_names=wildlife_class_names,
         title="Simulated Wildlife Classification Confusion Matrix",
    )
    run.log({"wildlife_confusion_matrix": confusion_matrix})
이 예시에서는 오차 행렬을 생성하기 위해 85% 정확도로 시뮬레이션된 예측값을 사용합니다.