メインコンテンツへスキップ
これはインタラクティブなノートブックです。ローカルで実行するか、以下のリンクを使用できます:

マルチエージェントシステムのための Structured Outputs

OpenAI は、強い言葉でのプロンプトを使わずに、モデルが常に提供された JSON スキーマに準拠したレスポンスを生成することを保証する Structured Outputs をリリースしました。Structured Outputs を使用すると、不適切な形式のレスポンスを検証したり、再試行したりする必要がなくなります。 新しいパラメータ strict: true を使用することで、レスポンスが提供されたスキーマに従うことを保証できます。 マルチエージェントシステムで Structured Outputs を使用すると、エージェント間で一貫性があり、処理しやすいデータが保証されるため、コミュニケーションが強化されます。また、明示的な拒否が可能になることで安全性が向上し、再試行や検証の必要がなくなることでパフォーマンスが向上します。これにより、相互作用が簡素化され、システム全体の効率が高まります。 このチュートリアルでは、マルチエージェントシステムで Structured Outputs を活用し、Weave でそれらをトレースする方法を紹介します。
ソース: このクックブックは、OpenAI の Structured Outputs のサンプルコードに基づいており、Weave を使用して可視化を向上させるための修正が加えられています。

依存関係のインストール

このチュートリアルには以下のライブラリが必要です:
  • マルチエージェントシステムを作成するための OpenAI
  • LLM ワークフローを追跡し、プロンプティング戦略を評価するための Weave
!pip install -qU openai weave wandb
python
%%capture
# OpenAI のバグを修正するための暫定的な回避策:
# TypeError: Client.__init__() got an unexpected keyword argument 'proxies'
# 詳細は https://community.openai.com/t/error-with-openai-1-56-0-client-init-got-an-unexpected-keyword-argument-proxies/1040332/15 を参照
!pip install "httpx<0.28"
wandb.login() で簡単にログインできるように、環境変数に WANDB_API_KEY を設定します(これは Colab の secret として提供する必要があります)。 ログを記録したい W&B のプロジェクトを name_of_wandb_project に設定します。 注意: name_of_wandb_project は、トレースをログに記録するチームを指定するために {team_name}/{project_name} の形式にすることもできます。 次に、weave.init() を呼び出して Weave クライアントを取得します。 OpenAI API を使用するため、OpenAI API キーも必要です。OpenAI プラットフォームで サインアップ して、独自の API キーを取得できます(これも Colab の secret として提供する必要があります)。
import base64
import json
import os
from io import BytesIO, StringIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
from google.colab import userdata
from openai import OpenAI

import weave
python
os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")
os.environ["OPENAI_API_KEY"] = userdata.get("OPENAI_API_KEY")

wandb.login()
name_of_wandb_project = "multi-agent-structured-output"
weave.init(name_of_wandb_project)

client = OpenAI()
MODEL = "gpt-4o-2024-08-06"

エージェントのセットアップ

今回取り組むユースケースはデータ分析タスクです。 まず、4 つのエージェントシステムをセットアップしましょう:
  • Triaging agent(振り分けエージェント):どのアージェントを呼び出すかを決定します。
  • Data pre-processing Agent(データ前処理エージェント):クリーニングなどを行い、分析のためのデータを準備します。
  • Data Analysis Agent(データ分析エージェント):データに対して分析を実行します。
  • Data Visualization Agent(データ可視化エージェント):分析結果を可視化してインサイトを抽出します。 まず、これらの各エージェントのシステムプロンプトを定義することから始めます。
triaging_system_prompt = """あなたは Triaging Agent です。あなたの役割は、ユーザーのクエリを評価し、関連するエージェントにルーティングすることです。利用可能なエージェントは以下の通りです:
- Data Processing Agent: データのクリーニング、変換、集計を行います。
- Analysis Agent: 統計分析、相関分析、回帰分析を行います。
- Visualization Agent: 棒グラフ、折れ線グラフ、円グラフを作成します。

send_query_to_agents ツールを使用して、ユーザーのクエリを関連するエージェントに転送してください。また、必要に応じて speak_to_user ツールを使用してユーザーから詳細情報を取得してください。"""

processing_system_prompt = """あなたは Data Processing Agent です。あなたの役割は、以下のツールを使用してデータのクリーニング、変換、集計を行うことです:
- clean_data
- transform_data
- aggregate_data"""

analysis_system_prompt = """あなたは Analysis Agent です。あなたの役割は、以下のツールを使用して統計分析、相関分析、回帰分析を行うことです:
- stat_analysis
- correlation_analysis
- regression_analysis"""

visualization_system_prompt = """あなたは Visualization Agent です。あなたの役割は、以下のツールを使用して棒グラフ、折れ線グラフ、円グラフを作成することです:
- create_bar_chart
- create_line_chart
- create_pie_chart"""
次に、各エージェントのツールを定義します。 Triaging agent を除き、各エージェントにはその役割に特化したツールが装備されます: Data pre-processing agent: 1. Clean data, 2. Transform data, 3. Aggregate data Data analysis agent: 1. Statistical analysis, 2. Correlation analysis, 3. Regression Analysis Data visualization agent: 1. Create bar chart, 2. Create line chart, 3. Create pie chart
triage_tools = [
    {
        "type": "function",
        "function": {
            "name": "send_query_to_agents",
            "description": "能力に基づいてユーザーのクエリを関連するエージェントに送信します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "agents": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "クエリを送信するエージェント名の配列。",
                    },
                    "query": {
                        "type": "string",
                        "description": "送信するユーザークエリ。",
                    },
                },
                "required": ["agents", "query"],
            },
        },
        "strict": True,
    }
]

preprocess_tools = [
    {
        "type": "function",
        "function": {
            "name": "clean_data",
            "description": "重複を削除し、欠損値を処理することで、提供されたデータをクリーニングします。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "クリーニングするデータセット。JSON や CSV などの適切な形式である必要があります。",
                    }
                },
                "required": ["data"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "transform_data",
            "description": "指定されたルールに基づいてデータを変換します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "変換するデータ。JSON や CSV などの適切な形式である必要があります。",
                    },
                    "rules": {
                        "type": "string",
                        "description": "適用する変換ルール。構造化された形式で指定します。",
                    },
                },
                "required": ["data", "rules"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "aggregate_data",
            "description": "指定された列と操作でデータを集計します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "集計するデータ。JSON や CSV などの適切な形式である必要があります。",
                    },
                    "group_by": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "グループ化する列。",
                    },
                    "operations": {
                        "type": "string",
                        "description": "実行する集計操作。構造化された形式で指定します。",
                    },
                },
                "required": ["data", "group_by", "operations"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
]

analysis_tools = [
    {
        "type": "function",
        "function": {
            "name": "stat_analysis",
            "description": "指定されたデータセットに対して統計分析を実行します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "分析するデータセット。JSON や CSV などの適切な形式である必要があります。",
                    }
                },
                "required": ["data"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "correlation_analysis",
            "description": "データセット内の変数間の相関係数を計算します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "分析するデータセット。JSON や CSV などの適切な形式である必要があります。",
                    },
                    "variables": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "相関を計算する変数のリスト。",
                    },
                },
                "required": ["data", "variables"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "regression_analysis",
            "description": "データセットに対して回帰分析を実行します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "分析するデータセット。JSON や CSV などの適切な形式である必要があります。",
                    },
                    "dependent_var": {
                        "type": "string",
                        "description": "回帰の従属変数。",
                    },
                    "independent_vars": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "独立変数のリスト。",
                    },
                },
                "required": ["data", "dependent_var", "independent_vars"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
]

visualization_tools = [
    {
        "type": "function",
        "function": {
            "name": "create_bar_chart",
            "description": "提供されたデータから棒グラフを作成します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "棒グラフ用のデータ。JSON や CSV などの適切な形式である必要があります。",
                    },
                    "x": {"type": "string", "description": "X 軸の列。"},
                    "y": {"type": "string", "description": "Y 軸の列。"},
                },
                "required": ["data", "x", "y"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "create_line_chart",
            "description": "提供されたデータから折れ線グラフを作成します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "折れ線グラフ用のデータ。JSON や CSV などの適切な形式である必要があります。",
                    },
                    "x": {"type": "string", "description": "X 軸の列。"},
                    "y": {"type": "string", "description": "Y 軸の列。"},
                },
                "required": ["data", "x", "y"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "create_pie_chart",
            "description": "提供されたデータから円グラフを作成します。",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {
                        "type": "string",
                        "description": "円グラフ用のデータ。JSON や CSV などの適切な形式である必要があります。",
                    },
                    "labels": {
                        "type": "string",
                        "description": "ラベル用の列。",
                    },
                    "values": {
                        "type": "string",
                        "description": "値用の列。",
                    },
                },
                "required": ["data", "labels", "values"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
]

Weave を使用したマルチエージェントの追跡の有効化

以下の処理を行うコードロジックを書く必要があります:
  • ユーザーのクエリをマルチエージェントシステムに渡す
  • マルチエージェントシステムの内部動作を処理する
  • ツール呼び出しを実行する
# クエリの例

user_query = """
以下にいくつかのデータがあります。まず重複を削除してから、データの統計を分析し、折れ線グラフをプロットしてください。

house_size (m3), house_price ($)
90, 100
80, 90
100, 120
90, 100
"""
ユーザーのクエリから、呼び出す必要があるツールは clean_datastart_analysis、および use_line_chart であることが推測できます。 まず、ツール呼び出しを実行する実行関数を定義します。 Python 関数を @weave.op() でデコレートすることで、言語モデルの入力、出力、およびトレースを記録し、デバッグできます。 マルチエージェントシステムを作成すると多くの関数が登場しますが、それらの上に @weave.op() を追加するだけで十分です。
@weave.op()
def clean_data(data):
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")
    df_deduplicated = df.drop_duplicates()
    return df_deduplicated

@weave.op()
def stat_analysis(data):
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")
    return df.describe()

@weave.op()
def plot_line_chart(data):
    data_io = StringIO(data)
    df = pd.read_csv(data_io, sep=",")

    x = df.iloc[:, 0]
    y = df.iloc[:, 1]

    coefficients = np.polyfit(x, y, 1)
    polynomial = np.poly1d(coefficients)
    y_fit = polynomial(x)

    plt.figure(figsize=(10, 6))
    plt.plot(x, y, "o", label="Data Points")
    plt.plot(x, y_fit, "-", label="Best Fit Line")
    plt.title("Line Chart with Best Fit Line")
    plt.xlabel(df.columns[0])
    plt.ylabel(df.columns[1])
    plt.legend()
    plt.grid(True)

    # 表示する前にプロットを BytesIO バッファに保存
    buf = BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)

    # プロットを表示
    plt.show()

    # データ URL 用に画像を base64 でエンコード
    image_data = buf.getvalue()
    base64_encoded_data = base64.b64encode(image_data)
    base64_string = base64_encoded_data.decode("utf-8")
    data_url = f"data:image/png;base64,{base64_string}"

    return data_url

# ツールを実行する関数を定義
@weave.op()
def execute_tool(tool_calls, messages):
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        tool_arguments = json.loads(tool_call.function.arguments)

        if tool_name == "clean_data":
            # データクリーニングをシミュレート
            cleaned_df = clean_data(tool_arguments["data"])
            cleaned_data = {"cleaned_data": cleaned_df.to_dict()}
            messages.append(
                {"role": "tool", "name": tool_name, "content": json.dumps(cleaned_data)}
            )
            print("Cleaned data: ", cleaned_df)
        elif tool_name == "transform_data":
            # データ変換をシミュレート
            transformed_data = {"transformed_data": "sample_transformed_data"}
            messages.append(
                {
                    "role": "tool",
                    "name": tool_name,
                    "content": json.dumps(transformed_data),
                }
            )
        elif tool_name == "aggregate_data":
            # データ集計をシミュレート
            aggregated_data = {"aggregated_data": "sample_aggregated_data"}
            messages.append(
                {
                    "role": "tool",
                    "name": tool_name,
                    "content": json.dumps(aggregated_data),
                }
            )
        elif tool_name == "stat_analysis":
            # 統計分析をシミュレート
            stats_df = stat_analysis(tool_arguments["data"])
            stats = {"stats": stats_df.to_dict()}
            messages.append(
                {"role": "tool", "name": tool_name, "content": json.dumps(stats)}
            )
            print("Statistical Analysis: ", stats_df)
        elif tool_name == "correlation_analysis":
            # 相関分析をシミュレート
            correlations = {"correlations": "sample_correlations"}
            messages.append(
                {"role": "tool", "name": tool_name, "content": json.dumps(correlations)}
            )
        elif tool_name == "regression_analysis":
            # 回帰分析をシミュレート
            regression_results = {"regression_results": "sample_regression_results"}
            messages.append(
                {
                    "role": "tool",
                    "name": tool_name,
                    "content": json.dumps(regression_results),
                }
            )
        elif tool_name == "create_bar_chart":
            # 棒グラフ作成をシミュレート
            bar_chart = {"bar_chart": "sample_bar_chart"}
            messages.append(
                {"role": "tool", "name": tool_name, "content": json.dumps(bar_chart)}
            )
        elif tool_name == "create_line_chart":
            # 折れ線グラフ作成をシミュレート
            line_chart = {"line_chart": plot_line_chart(tool_arguments["data"])}
            messages.append(
                {"role": "tool", "name": tool_name, "content": json.dumps(line_chart)}
            )
        elif tool_name == "create_pie_chart":
            # 円グラフ作成をシミュレート
            pie_chart = {"pie_chart": "sample_pie_chart"}
            messages.append(
                {"role": "tool", "name": tool_name, "content": json.dumps(pie_chart)}
            )
    return messages
次に、各サブエージェントのツールハンドラーを作成します。これらは、モデルに渡される独自のアシスタントプロンプトとツールセットを持っています。出力はその後、ツール呼び出しを実行する実行関数に渡されます。
# 各エージェントの処理を処理する関数を定義
@weave.op()
def handle_data_processing_agent(query, conversation_messages):
    messages = [{"role": "system", "content": processing_system_prompt}]
    messages.append({"role": "user", "content": query})

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=preprocess_tools,
    )

    conversation_messages.append(
        [tool_call.function for tool_call in response.choices[0].message.tool_calls]
    )
    execute_tool(response.choices[0].message.tool_calls, conversation_messages)

@weave.op()
def handle_analysis_agent(query, conversation_messages):
    messages = [{"role": "system", "content": analysis_system_prompt}]
    messages.append({"role": "user", "content": query})

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=analysis_tools,
    )

    conversation_messages.append(
        [tool_call.function for tool_call in response.choices[0].message.tool_calls]
    )
    execute_tool(response.choices[0].message.tool_calls, conversation_messages)

@weave.op()
def handle_visualization_agent(query, conversation_messages):
    messages = [{"role": "system", "content": visualization_system_prompt}]
    messages.append({"role": "user", "content": query})

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=visualization_tools,
    )

    conversation_messages.append(
        [tool_call.function for tool_call in response.choices[0].message.tool_calls]
    )
    execute_tool(response.choices[0].message.tool_calls, conversation_messages)
最後に、ユーザークエリの処理全体を管理するツールを作成します。この関数はユーザークエリを受け取り、モデルからレスポンスを取得し、実行のために他のエージェントに渡す処理を行います。
# ユーザー入力と振り分けを処理する関数
@weave.op()
def handle_user_message(user_query, conversation_messages=None):
    if conversation_messages is None:
        conversation_messages = []
    user_message = {"role": "user", "content": user_query}
    conversation_messages.append(user_message)

    messages = [{"role": "system", "content": triaging_system_prompt}]
    messages.extend(conversation_messages)

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=0,
        tools=triage_tools,
    )

    conversation_messages.append(
        [tool_call.function for tool_call in response.choices[0].message.tool_calls]
    )

    for tool_call in response.choices[0].message.tool_calls:
        if tool_call.function.name == "send_query_to_agents":
            agents = json.loads(tool_call.function.arguments)["agents"]
            query = json.loads(tool_call.function.arguments)["query"]
            for agent in agents:
                if agent == "Data Processing Agent":
                    handle_data_processing_agent(query, conversation_messages)
                elif agent == "Analysis Agent":
                    handle_analysis_agent(query, conversation_messages)
                elif agent == "Visualization Agent":
                    handle_visualization_agent(query, conversation_messages)

    outputs = extract_tool_contents(conversation_messages)

    return outputs

functions = [
    "clean_data",
    "transform_data",
    "stat_analysis",
    "aggregate_data",
    "correlation_analysis",
    "regression_analysis",
    "create_bar_chart",
    "create_line_chart",
    "create_pie_chart",
]

@weave.op()
def extract_tool_contents(data):
    contents = {}
    contents["all"] = data
    for element in data:
        if (
            isinstance(element, dict)
            and element.get("role") == "tool"
            and element.get("name") in functions
        ):
            name = element["name"]
            content_str = element["content"]
            try:
                content_json = json.loads(content_str)
                if "chart" not in element.get("name"):
                    contents[name] = [content_json]
                else:
                    first_key = next(iter(content_json))
                    second_level = content_json[first_key]
                    if isinstance(second_level, dict):
                        second_key = next(iter(second_level))
                        contents[name] = second_level[second_key]
                    else:
                        contents[name] = second_level
            except json.JSONDecodeError:
                print(f"Error decoding JSON for {name}")
                contents[name] = None

    return contents

Weave でのマルチエージェントシステムの実行と可視化

最後に、ユーザーの入力を使用して主要な handle_user_message 関数を実行し、結果を観察します。
handle_user_message(user_query)
Weave の URL をクリックすると、以下のように実行がトレースされていることが確認できます。トレースページでは、入力と出力を確認できます。分かりやすくするために、各出力をクリックしたときに表示される結果のスクリーンショットを図に追加しています。Weave は OpenAI の API とのインテグレーションを提供しており、コストを自動的に計算できます。そのため、右端にコストとレイテンシも表示されていることが確認できます。 1-1.png 行をクリックすると、マルチエージェントシステム内で実行された中間プロセスを確認できます。例えば、analysis_agent の入力と出力を見ると、それが Structured Output 形式であることがわかります。OpenAI の Structured Output はエージェント間の連携を容易にしますが、システムが複雑になるにつれて、これらのやり取りがどのような形式で行われているかを把握するのが難しくなります。Weave を使用すると、これらの中間プロセスとその入出力を、まるで手元で見ているかのように理解することができます。
3.png
Weave でのトレースがどのように処理されるか、ぜひ詳しく見てみてください!

まとめ

このチュートリアルでは、Structured Output と、入力、最終出力、および中間出力形式を追跡するための Weave を使用して、マルチエージェントシステムを便利に開発する方法を学びました。