메인 콘텐츠로 건너뛰기
GitHub source

function pr_curve

pr_curve(
    y_true: 'Iterable[T] | None' = None,
    y_probas: 'Iterable[numbers.Number] | None' = None,
    labels: 'list[str] | None' = None,
    classes_to_plot: 'list[T] | None' = None,
    interp_size: 'int' = 21,
    title: 'str' = 'Precision-Recall Curve',
    split_table: 'bool' = False
) → CustomChart
PR 곡선 (Precision-Recall curve)을 생성합니다. PR 곡선은 특히 불균형한 데이터셋에서 분류 모델을 평가하는 데 매우 유용합니다. PR 곡선 아래의 면적(AUC)이 넓을수록 높은 정밀도(낮은 거짓 양성 비율)와 높은 재현율(낮은 거짓 음성 비율)을 의미합니다. 이 곡선은 다양한 임계값 수준에서 거짓 양성과 거짓 음성 사이의 균형에 대한 통찰력을 제공하여 모델의 성능을 평가하는 데 도움을 줍니다. Args:
  • y_true: 실제 바이너리 레이블. 형태는 (num_samples,)여야 합니다.
  • y_probas: 각 클래스에 대해 예측된 점수 또는 확률. 이는 확률 추정치, 신뢰 점수 또는 임계값이 적용되지 않은 결정 값일 수 있습니다. 형태는 (num_samples, num_classes)여야 합니다.
  • labels: 차트 해석을 쉽게 하기 위해 y_true의 숫자 값을 대체할 클래스 이름의 선택적 리스트입니다. 예를 들어, labels = ['dog', 'cat', 'owl']로 설정하면 차트에서 0은 ‘dog’, 1은 ‘cat’, 2는 ‘owl’로 표시됩니다. 제공되지 않으면 y_true의 숫자 값이 그대로 사용됩니다.
  • classes_to_plot: 차트에 포함할 y_true의 고유 클래스 값 선택적 리스트입니다. 지정하지 않으면 y_true에 있는 모든 고유 클래스가 그려집니다.
  • interp_size: 재현율(recall) 값을 보간할 포인트의 수입니다. 재현율 값은 [0, 1] 범위 내에서 균등하게 분포된 interp_size개의 포인트로 고정되며, 정밀도(precision)는 이에 따라 보간됩니다.
  • title: 플롯의 제목. 기본값은 “Precision-Recall Curve”입니다.
  • split_table: W&B UI에서 테이블을 별도의 섹션으로 분리할지 여부입니다. True인 경우, 테이블은 “Custom Chart Tables”라는 섹션에 표시됩니다. 기본값은 False입니다.
Returns:
  • CustomChart: W&B에 로그를 남길 수 있는 커스텀 차트 오브젝트입니다. 차트를 로그하려면 wandb.log()에 전달하세요.
Raises:
  • wandb.Error: NumPy, pandas 또는 scikit-learn이 설치되어 있지 않은 경우 발생합니다.
Example:
import wandb

# 스팸 감지 예시 (이진 분류)
y_true = [0, 1, 1, 0, 1]  # 0 = 스팸 아님, 1 = 스팸
y_probas = [
    [0.9, 0.1],  # 첫 번째 샘플에 대한 예측 확률 (스팸 아님)
    [0.2, 0.8],  # 두 번째 샘플 (스팸) 등등
    [0.1, 0.9],
    [0.8, 0.2],
    [0.3, 0.7],
]

labels = ["not spam", "spam"]  # 가독성을 위한 선택적 클래스 이름

with wandb.init(project="spam-detection") as run:
    pr_curve = wandb.plot.pr_curve(
         y_true=y_true,
         y_probas=y_probas,
         labels=labels,
         title="Precision-Recall Curve for Spam Detection",
    )
    run.log({"pr-curve": pr_curve})