メインコンテンツへスキップ
GitHub source

function pr_curve

pr_curve(
    y_true: 'Iterable[T] | None' = None,
    y_probas: 'Iterable[numbers.Number] | None' = None,
    labels: 'list[str] | None' = None,
    classes_to_plot: 'list[T] | None' = None,
    interp_size: 'int' = 21,
    title: 'str' = 'Precision-Recall Curve',
    split_table: 'bool' = False
) → CustomChart
Precision-Recall (PR) 曲線 を作成します。 PR曲線 は、不均衡な データセット に対する分類器の評価に特に有用です。PR曲線 下の面積(AUC)が大きいことは、高い適合率(低い偽陽性率)と高い再現率(低い偽陰性率)の両方を意味します。この曲線は、さまざまな閾値レベルにおける偽陽性と偽陰性のバランスに関する洞察を提供し、モデル の性能評価を支援します。 引数:
  • y_true: 正解のバイナリラベル。形状は (num_samples,) である必要があります。
  • y_probas: 各クラスの予測スコアまたは確率。これらは確率推定値、信頼スコア、または閾値処理前の決定値です。形状は (num_samples, num_classes) である必要があります。
  • labels: プロットの解釈を容易にするために、y_true の数値に置き換えるクラス名のリスト(任意)。例えば、labels = ['dog', 'cat', 'owl'] とすると、プロット内で 0 は ‘dog’、1 は ‘cat’、2 は ‘owl’ に置き換えられます。指定されない場合は、y_true の数値がそのまま使用されます。
  • classes_to_plot: プロットに含める y_true 内のユニークなクラス値のリスト(任意)。指定されない場合、y_true に含まれるすべてのユニークなクラスがプロットされます。
  • interp_size: 再現率の値を補間するポイントの数。再現率は [0, 1] の範囲で一様に分布した interp_size 個のポイントに固定され、適合率はそれに応じて補間されます。
  • title: プロットのタイトル。デフォルトは “Precision-Recall Curve” です。
  • split_table: W&B UI 上で テーブル を別のセクションに分割するかどうか。True の場合、テーブル は “Custom Chart Tables” という名前のセクションに表示されます。デフォルトは False です。
戻り値:
  • CustomChart: W&B に ログ 記録可能な カスタムチャート オブジェクト。チャートを ログ 記録するには、wandb.log() に渡します。
例外:
  • wandb.Error: NumPy、pandas、または scikit-learn がインストールされていない場合。
例:
import wandb

# スパム検出(二値分類)の例
y_true = [0, 1, 1, 0, 1]  # 0 = スパムではない, 1 = スパム
y_probas = [
    [0.9, 0.1],  # 1番目のサンプルの予測確率(スパムではない)
    [0.2, 0.8],  # 2番目のサンプルの予測確率(スパム)、以下同様
    [0.1, 0.9],
    [0.8, 0.2],
    [0.3, 0.7],
]

labels = ["not spam", "spam"]  # 読みやすさのための任意のクラス名

with wandb.init(project="spam-detection") as run:
    pr_curve = wandb.plot.pr_curve(
         y_true=y_true,
         y_probas=y_probas,
         labels=labels,
         title="Precision-Recall Curve for Spam Detection",
    )
    run.log({"pr-curve": pr_curve})