用 Pydantic 自动生成 LLM Tool Schema

定义 tool 参数后, 不引入其他库, 仅用 Pydantic 自动生成符合 OpenAI 规范的 Tool Schema. 想法很简单, 把 Pydantic 的 model_json_schema 生成的 JSON Schema 处理成 OpenAI 规范即可.

好处是 (1) 不用引入或依赖其他乱七八糟的库; (2) 不用手动额外维护一套工具描述; (3) 能利用 Pydantic 的一些功能, 从 JSON string load 之后自动校验参数, 自动转换类型等.

基础示例

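比如, 先用 Pydantic 定义工具的参数模型, 再调用 create_tool_from_pydantic (实现见下文 "完整代码") 生成 Tool Schema; 代码之后紧跟的 JSON 即打印结果: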

# Argument model for the weather tool. The class docstring and every Field
# description are runtime data: they are emitted as the tool description and
# the parameter descriptions of the generated schema, so they stay verbatim.
class GetWeatherArgs(BaseModel):
    """Retrieves current weather for the given location."""
    # Place to look up; listed under "required" in the generated schema.
    location: str = Field(description="City and country e.g. Bogotá, Colombia")
    # Restricted to two literal values, which show up as an "enum" in the schema.
    units: Literal["celsius", "fahrenheit"] = Field(description="Units the temperature will be returned in.")


def get_weather(args: GetWeatherArgs):
    """Placeholder for the real tool logic; currently does nothing and returns None."""
    return None


# Build the tool definition from the argument model and pretty-print it.
# ensure_ascii=False keeps non-ASCII text (e.g. "Bogotá") readable in the output.
get_weather_tool = create_tool_from_pydantic(GetWeatherArgs)
print(json.dumps(get_weather_tool, ensure_ascii=False, indent=2))
{
  "type": "function",
  "function": {
    "name": "get_weather",
    "description": "Retrieves current weather for the given location.",
    "parameters": {
      "type": "object",
      "properties": {
        "location": {
          "type": "string",
          "description": "City and country e.g. Bogotá, Colombia"
        },
        "units": {
          "type": "string",
          "description": "Units the temperature will be returned in.",
          "enum": [
            "celsius",
            "fahrenheit"
          ]
        }
      },
      "required": [
        "location",
        "units"
      ]
    }
  }
}

完整代码

import datetime
import json
import re
import textwrap
from enum import StrEnum
from typing import Type, Literal, Optional, List, Any

import pydantic
from pydantic import BaseModel, Field


def _clean_text(text: str) -> str:
    """Strip the indentation shared by all lines, then trim surrounding whitespace."""
    dedented = textwrap.dedent(text)
    return dedented.strip()


def _process_property(prop_schema: dict, defs: dict) -> dict:
    """Recursively convert one property's JSON Schema into tool-parameter form.

    Args:
        prop_schema: JSON Schema fragment describing a single property.
        defs: The ``$defs`` mapping of the root schema, used to resolve ``$ref``.

    Returns:
        A new dict holding only the keys the tool format uses (``type``,
        ``description``, ``enum``, ``items``), or a nested object schema when
        the property refers to another model. An empty dict is returned when
        an ``anyOf`` contains nothing but ``null``.

    NOTE: there is no cycle detection, so a self-referencing model recurses
    without bound.
    """

    def with_outer_description(converted: dict) -> dict:
        # A description written on the field itself takes precedence over the
        # one carried by the wrapped / referenced schema.
        if 'description' in prop_schema:
            converted['description'] = _clean_text(prop_schema['description'])
        return converted

    # 1. Optional[T] is rendered as anyOf [T, {"type": "null"}]; keep the
    #    first non-null alternative.
    if 'anyOf' in prop_schema:
        non_null_schema = next(
            (s for s in prop_schema['anyOf'] if s.get('type') != 'null'), None
        )
        if not non_null_schema:  # only "null" alternatives: nothing to describe
            return {}
        return with_outer_description(_process_property(non_null_schema, defs))

    # 1b. Some Pydantic 2.x releases wrap a $ref in a single-element allOf when
    #     the field carries extra metadata; unwrap it the same way.
    if len(prop_schema.get('allOf', ())) == 1:
        return with_outer_description(_process_property(prop_schema['allOf'][0], defs))

    # 2. Resolve references into $defs.
    if '$ref' in prop_schema:
        ref_name = prop_schema['$ref'].split('/')[-1]
        target = defs.get(ref_name)
        if target:
            if target.get('type') == 'object' or 'properties' in target:
                # Nested model: convert it like a top-level parameters object.
                converted = pydantic_to_tool_schema(target, defs)
            else:
                # Non-object definition (e.g. an Enum): it must be converted as
                # a plain property, otherwise it would be emitted as an empty
                # "object" and lose its type/enum.
                converted = _process_property(target, defs)
            return with_outer_description(converted)
        # Unresolvable reference: fall through and keep whatever is usable.

    # 3. Primitive types and arrays: copy over only the supported keys.
    result = {}
    prop_type = prop_schema.get('type')

    if prop_type:
        result['type'] = prop_type
    if 'description' in prop_schema:
        result['description'] = _clean_text(prop_schema['description'])
    if 'enum' in prop_schema:
        # Copy so the result never aliases a list stored in the input schema.
        result['enum'] = list(prop_schema['enum'])

    # 3a. List[T]: convert the element schema recursively.
    if prop_type == 'array' and 'items' in prop_schema:
        result['items'] = _process_property(prop_schema['items'], defs)

    return result


def pydantic_to_tool_schema(schema: dict, defs: Optional[dict] = None) -> dict:
    """Convert a Pydantic JSON Schema into the ``parameters`` part of a tool.

    Args:
        schema: JSON Schema of an object (a model's root schema or one of its
            ``$defs`` entries).
        defs: Definitions used to resolve ``$ref``. When omitted, the
            ``$defs`` of ``schema`` itself are used.

    Returns:
        A new dict with ``type``, ``properties`` and ``required`` keys, plus a
        ``description`` key when ``schema`` has one. The input is not modified
        and the result shares no list with it.
    """
    if defs is None:
        defs = schema.get('$defs', {})

    tool_params = {
        "type": "object",
        "properties": {
            name: _process_property(prop_schema, defs)
            for name, prop_schema in schema.get("properties", {}).items()
        },
        # Copy: handing back the very list stored in `schema` would let a
        # caller that edits the result silently corrupt the source schema.
        "required": list(schema.get("required", [])),
    }

    # Top-level description (originates from the model's docstring).
    if 'description' in schema:
        tool_params['description'] = _clean_text(schema['description'])

    return tool_params


def create_tool_from_pydantic(pydantic_model: Type[BaseModel]) -> dict:
    """
    Build an OpenAI-style tool definition from a Pydantic model.

    - The function name is derived from the class name: a trailing ``Args`` is
      dropped and CamelCase becomes snake_case (GetWeatherArgs -> get_weather).
      Runs of capitals are kept together (HTTPRequestArgs -> http_request).
    - The model's docstring is used as the tool description.

    Args:
        pydantic_model: The model class describing the tool's arguments.

    Returns:
        A dict of the form
        ``{"type": "function", "function": {"name", "description", "parameters"}}``.
    """
    # 1. Derive the function name from the class name. Fall back to the full
    #    name when the class is literally called "Args", so the name is never empty.
    model_name = pydantic_model.__name__
    class_name = model_name.removesuffix('Args') or model_name
    # Insert "_" at lower/digit -> upper boundaries and before the last capital
    # of an acronym that is followed by a lowercase letter, then lowercase.
    function_name = re.sub(
        r'(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])', '_', class_name
    ).lower()

    # 2. Generate the Pydantic schema and convert it to the tool format.
    pydantic_schema = pydantic_model.model_json_schema()
    tool_schema = pydantic_to_tool_schema(pydantic_schema)

    # The description belongs one level up, next to the function name.
    description = tool_schema.pop("description", "")

    # 3. Assemble the complete tool definition.
    return {
        "type": "function",
        "function": {
            "name": function_name,
            "description": description,
            "parameters": tool_schema,
        },
    }


# Argument model for the weather tool. The class docstring and every Field
# description are runtime data: they are emitted as the tool description and
# the parameter descriptions of the generated schema, so they stay verbatim.
class GetWeatherArgs(BaseModel):
    """Retrieves current weather for the given location."""
    # Place to look up; listed under "required" in the generated schema.
    location: str = Field(description="City and country e.g. Bogotá, Colombia")
    # Restricted to two literal values, which show up as an "enum" in the schema.
    units: Literal["celsius", "fahrenheit"] = Field(description="Units the temperature will be returned in.")


def get_weather(args: GetWeatherArgs):
    """Placeholder for the real tool logic; currently does nothing and returns None."""
    return None


# Build the tool definition from the argument model and pretty-print it.
# ensure_ascii=False keeps non-ASCII text (e.g. "Bogotá") readable in the output.
get_weather_tool = create_tool_from_pydantic(GetWeatherArgs)
print(json.dumps(get_weather_tool, ensure_ascii=False, indent=2))

复杂点的例子

可以定义嵌套模型, 枚举类型, 添加自定义校验逻辑等. 下面的 SearchFilesArgs 模型演示了如何处理文件搜索场景, 它包含了对文件类型 (FileType 枚举) 和创建时间 (嵌套的 TimeRange 模型) 的筛选.

我们还定义了一个 LLMProofBaseModel 基类, 它通过 before 校验器把 LLM 传入的字符串 'null' (不区分大小写) 统一转换为 None. 嵌套的 TimeRange 模型中的校验器 check_dates 则展示了如何在数据模型层面封装业务规则: 当开始日期晚于结束日期时, 把结束日期修正为开始日期. 代码之后紧跟的 JSON 是打印出的两个工具的 Schema.

# --- 接上一段代码 ---

class LLMProofBaseModel(BaseModel):
    """自动将所有字段中值为字符串 'null' 的输入转换为 None"""

    # English gist of the docstring above: every field whose input is the
    # string 'null' is converted to None. It is left verbatim because a model's
    # docstring is exposed as its schema description.
    @pydantic.field_validator('*', mode='before')
    @classmethod
    def _clean_null_str(cls, v: Any) -> Any:
        # Applied to every field ('*') before type validation. Only strings
        # are inspected; the comparison ignores case ("NULL", "Null", ...).
        if not isinstance(v, str):
            return v
        return None if v.lower() == 'null' else v


class TimeRange(LLMProofBaseModel):
    """这个 docstring 不会用到"""

    # The docstring above says it is not used: when this model is nested as an
    # Optional field, the field's own description replaces it in the tool
    # schema. All three fields are optional and default to None.
    start_date: Optional[datetime.date] = Field(
        default=None,
        description="开始日期 (YYYY-MM-DD)",
    )
    end_date: Optional[datetime.date] = Field(
        default=None,
        description="结束日期 (YYYY-MM-DD)",
    )
    random_field: Optional[str] = Field(default=None, description='演示用')

    @pydantic.model_validator(mode='after')
    def check_dates(self) -> 'TimeRange':
        # Runs after field validation. A range whose start lies after its end
        # is repaired by collapsing the end onto the start; raising an error
        # here would be the alternative policy.
        start, end = self.start_date, self.end_date
        is_reversed = bool(start and end and start > end)
        if is_reversed:
            self.end_date = start
        return self


# Closed set of file types the search tool can filter on. As a StrEnum, each
# member is also a plain string equal to its value.
class FileType(StrEnum):
    PDF = "pdf"
    PPT = "ppt"


# Argument model for the file-search tool. The multi-line docstring and the
# Field descriptions are runtime data (they become the tool and parameter
# descriptions of the generated schema), so they are left verbatim.
class SearchFilesArgs(LLMProofBaseModel):
    """
    搜索文件

    多行示例
    - xx
    - yy
    """
    # The only field without a default, hence the only entry under "required".
    query: str = Field(description="根据用户问题提炼出的核心搜索查询语句")
    # Literal[*FileType] unpacks the enum members into a Literal (star-unpacking
    # inside a subscript needs Python 3.11+), so the schema gets an inline
    # "enum" on the array items.
    file_types: Optional[List[Literal[*FileType]]] = Field(None, description="文件类型")
    # Nested model; rendered as a nested "object" carrying this field's description.
    time_range: Optional[TimeRange] = Field(None, description="文件创建时间范围")


# Generate the second tool definition from its argument model.
search_file_tool = create_tool_from_pydantic(SearchFilesArgs)

# Print both tool definitions as a single JSON array.
tools = [
    get_weather_tool,
    search_file_tool,
]
print(json.dumps(tools, ensure_ascii=False, indent=2))

# Validate argument payloads given as plain dicts.
args1 = GetWeatherArgs.model_validate({"location": "Bogotá, Colombia", "units": "celsius"})
# This payload exercises the validators: start_date is later than end_date
# (check_dates sets end_date to start_date) and random_field is the string
# "null" (converted to None by the base-class validator).
args2 = SearchFilesArgs.model_validate(
    {
        "query": "年报", "file_types": ["pdf"],
        "time_range": {"start_date": "2025-01-01", "end_date": "2024-01-01", "random_field": "null"},
    }
)
[
  {
    "type": "function",
    "function": {
      "name": "get_weather",
      "description": "Retrieves current weather for the given location.",
      "parameters": {
        "type": "object",
        "properties": {
          "location": {
            "type": "string",
            "description": "City and country e.g. Bogotá, Colombia"
          },
          "units": {
            "type": "string",
            "description": "Units the temperature will be returned in.",
            "enum": [
              "celsius",
              "fahrenheit"
            ]
          }
        },
        "required": [
          "location",
          "units"
        ]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "search_files",
      "description": "搜索文件\n\n多行示例\n- xx\n- yy",
      "parameters": {
        "type": "object",
        "properties": {
          "query": {
            "type": "string",
            "description": "根据用户问题提炼出的核心搜索查询语句"
          },
          "file_types": {
            "type": "array",
            "items": {
              "type": "string",
              "enum": [
                "pdf",
                "ppt"
              ]
            },
            "description": "文件类型"
          },
          "time_range": {
            "type": "object",
            "properties": {
              "start_date": {
                "type": "string",
                "description": "开始日期 (YYYY-MM-DD)"
              },
              "end_date": {
                "type": "string",
                "description": "结束日期 (YYYY-MM-DD)"
              },
              "random_field": {
                "type": "string",
                "description": "演示用"
              }
            },
            "required": [],
            "description": "文件创建时间范围"
          }
        },
        "required": [
          "query"
        ]
      }
    }
  }
]