文章目录
- BaseTool 源码分析
- 核心属性
- 以 `TavilySearchResults(BaseTool)` 为例
- name
- description
- args_schema
- response_format
- 查询选项属性
- 需要子类实现的抽象方法
- 以 `TavilySearchResults(BaseTool)` 为例
- 核心方法
- `arun()`:`run()`的异步执行版本
- `invoke()`和`ainvoke()`
BaseTool 源码分析
BaseTool 是 LangChain 框架中定义 tools 的模板类
核心属性
- `name`:表示 tool 唯一名称的字符串(用于识别)
- `description`:对如何 / 何时 / 为何使用该 tool 的描述,帮助模型决定什么时候调用该 tool
- `args_schema`:用于验证工具输入参数的 Pydantic model 或 schema
- `return_direct`:如果为 `True`,则直接返回 tool 的输出
- `response_format`:定义 tool 的响应格式
以 `TavilySearchResults(BaseTool)` 为例
name
# Unique name used to identify this tool.
name: str = "tavily_search_results_json"
description
# Natural-language description that helps the model decide when (and why)
# to call this tool.
description: str = (
    "A search engine optimized for comprehensive, accurate, and trusted results. "
    "Useful for when you need to answer questions about current events. "
    "Input should be a search query."
)
args_schema
class TavilyInput(BaseModel):
    """Input for the Tavily tool."""

    # No default is supplied, so ``query`` is required: input without it
    # fails validation.
    query: str = Field(description="search query to look up")


# Tool input is validated against the schema defined by ``TavilyInput``.
# The value of ``args_schema`` must be a ``BaseModel`` subclass.
args_schema: Type[BaseModel] = TavilyInput
- 按照 `TavilyInput` 的规则,如果输入没有提供 `query` 值,将抛出验证错误
- `Field` 函数用于向字段添加元数据(此处为字段描述)
response_format
# "content_and_artifact": the tool's ``_run`` returns a (content, artifact)
# tuple -- for this tool, the cleaned results and the raw API response.
response_format: Literal["content_and_artifact"] = "content_and_artifact"
- 使用 `Literal` 将取值限制为特定的字面值(这里只能是 `"content_and_artifact"`)。下面是 Python `typing` 模块中 `Literal` 的实现:
@_LiteralSpecialForm
# Presumably memoizes the resulting alias; typed=True presumably also keys the
# cache on argument types -- confirm against _tp_cache.
@_tp_cache(typed=True)
def Literal(self, *parameters):
    """Special typing form to define literal types (a.k.a. value types).

    This form can be used to indicate to type checkers that the corresponding
    variable or function parameter has a value equivalent to the provided
    literal (or one of several literals):

        def validate_simple(data: Any) -> Literal[True]:  # always returns True
            ...

        MODE = Literal['r', 'rb', 'w', 'wb']
        def open_helper(file: str, mode: MODE) -> str:
            ...

        open_helper('/some/path', 'r')  # Passes type check
        open_helper('/other/path', 'typo')  # Error in type checker

    Literal[...] cannot be subclassed. At runtime, an arbitrary value
    is allowed as type argument to Literal[...], but type checkers may
    impose restrictions.
    """
    # There is no '_type_check' call because arguments to Literal[...] are
    # values, not types.
    # Presumably expands nested Literal[...] arguments into one flat tuple.
    parameters = _flatten_literal_params(parameters)
    try:
        # Drop duplicate parameters; deduplication works on (value, type)
        # pairs (presumably so that e.g. 1 and True stay distinct -- confirm
        # against _value_and_type_iter).
        parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
    except TypeError:  # unhashable parameters
        # Deduplication requires hashable values; keep the parameters as given.
        pass
    return _LiteralGenericAlias(self, parameters)
查询选项属性
- `max_results`:返回的最大结果数量,默认为 5。
- `search_depth`:查询的深度,可以是 `"basic"` 或 `"advanced"`,默认是 `"advanced"`。
- `include_domains`:结果中限定包含的域名列表(默认为空,即不限制域名)。
- `exclude_domains`:需要从结果中排除的域名列表(默认为空)。
- `include_answer`:是否在结果中包含简短答案,默认值为 `False`。
- `include_raw_content`:是否在结果中包含网页原始内容(解析后的 HTML),默认关闭。
- `include_images`:是否在结果中包含相关图片,默认值为 `False`。
需要子类实现的抽象方法
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> Any:
    """Use the tool.

    Subclasses must implement the tool's synchronous logic here.
    Add run_manager: Optional[CallbackManagerForToolRun] = None
    to child implementations to enable tracing.
    """
以 `TavilySearchResults(BaseTool)` 为例
# Client wrapper that performs the actual HTTP calls; default_factory builds
# a fresh TavilySearchAPIWrapper for each tool instance.
api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
- `api_wrapper` 是一个 `TavilySearchAPIWrapper` 实例,用于封装 API 调用的细节
class TavilySearchAPIWrapper(BaseModel):
    """Wrapper for Tavily Search API.

    ``raw_results`` / ``raw_results_async`` return the API's JSON response
    unchanged; ``results`` / ``results_async`` additionally reduce every hit
    to its ``url`` and ``content`` via ``clean_results``.
    """

    tavily_api_key: SecretStr

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Fill in ``tavily_api_key`` before field validation.

        Uses ``get_from_dict_or_env`` with the ``TAVILY_API_KEY`` variable
        name (presumably falling back to the environment when the key is not
        in ``values`` -- confirm against that helper). Only the API key is
        resolved here; the endpoint is the module-level ``TAVILY_API_URL``.
        """
        tavily_api_key = get_from_dict_or_env(
            values, "tavily_api_key", "TAVILY_API_KEY"
        )
        values["tavily_api_key"] = tavily_api_key
        return values

    def _build_params(
        self,
        query: str,
        max_results: Optional[int],
        search_depth: Optional[str],
        include_domains: Optional[List[str]],
        exclude_domains: Optional[List[str]],
        include_answer: Optional[bool],
        include_raw_content: Optional[bool],
        include_images: Optional[bool],
    ) -> Dict:
        """Build the JSON request body shared by the sync and async calls."""
        return {
            "api_key": self.tavily_api_key.get_secret_value(),
            "query": query,
            "max_results": max_results,
            "search_depth": search_depth,
            # The domain filters default to None instead of a shared mutable
            # ``[]``; send ``[]`` so the request body matches the old default.
            "include_domains": [] if include_domains is None else include_domains,
            "exclude_domains": [] if exclude_domains is None else exclude_domains,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "include_images": include_images,
        }

    def raw_results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        """Query the Tavily ``/search`` endpoint synchronously.

        Arguments are the same as for ``results``; ``None`` for either domain
        list means "no filter" and is sent as an empty list.

        Returns:
            The decoded JSON response body, unmodified.

        Raises:
            requests.HTTPError: if the API answers with an error status
                (raised by ``raise_for_status``).
        """
        params = self._build_params(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        response = requests.post(f"{TAVILY_API_URL}/search", json=params)
        response.raise_for_status()
        return response.json()

    def results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Run query through Tavily Search and return the cleaned hits.

        Args:
            query: The query to search for.
            max_results: The maximum number of results to return.
            search_depth: The depth of the search. Can be "basic" or "advanced".
            include_domains: A list of domains to include in the search
                (``None`` means no restriction).
            exclude_domains: A list of domains to exclude from the search
                (``None`` means no exclusion).
            include_answer: Whether to include the answer in the results.
            include_raw_content: Whether to include the raw content in the results.
            include_images: Whether to include images in the results.

        Returns:
            One dict per search hit, each containing only the ``url`` and
            ``content`` keys (see ``clean_results``). The rest of the API
            response (e.g. ``answer``, ``images``, ``follow_up_questions``,
            per-hit ``title``/``score``/``raw_content``) is only available
            from ``raw_results``.
        """
        raw_search_results = self.raw_results(
            query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(raw_search_results["results"])

    async def raw_results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        """Get results from the Tavily Search API asynchronously.

        Arguments are the same as for ``results``.

        Returns:
            The decoded JSON response body, unmodified.

        Raises:
            Exception: for any HTTP status other than 200. Note this is
                stricter than the sync ``raw_results``, which only rejects
                error statuses.
        """
        params = self._build_params(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
                if res.status != 200:
                    raise Exception(f"Error {res.status}: {res.reason}")
                body = await res.text()
        return json.loads(body)

    async def results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Async counterpart of ``results``: same arguments, same cleaned output."""
        results_json = await self.raw_results_async(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(results_json["results"])

    def clean_results(self, results: List[Dict]) -> List[Dict]:
        """Reduce each raw search hit to its ``url`` and ``content`` fields."""
        return [
            {"url": result["url"], "content": result["content"]}
            for result in results
        ]
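- `raw_results()`:同步调用 API,返回原始 JSON 响应。
- `raw_results_async()`:异步调用 API,返回原始 JSON 响应。
- `results()` / `results_async()`:分别调用上面两个方法,并返回清理后的结果。
- `clean_results()`:清理查询结果,每条结果只保留 `url` 和 `content`。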
def _run(
    self,
    query: str,
    run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
    """Run a Tavily search for ``query`` with this tool's configured options.

    Returns:
        A ``(content, artifact)`` pair. On success ``content`` is the cleaned
        result list and ``artifact`` is the raw API response; on failure
        ``content`` is ``repr`` of the exception and ``artifact`` is ``{}``.
    """
    # TODO: remove try/except, should be handled by BaseTool
    try:
        # Pass the options by keyword: positional passing would silently swap
        # values (e.g. include_domains / exclude_domains) if the wrapper's
        # parameter order ever changed.
        raw_results = self.api_wrapper.raw_results(
            query,
            max_results=self.max_results,
            search_depth=self.search_depth,
            include_domains=self.include_domains,
            exclude_domains=self.exclude_domains,
            include_answer=self.include_answer,
            include_raw_content=self.include_raw_content,
            include_images=self.include_images,
        )
    except Exception as e:
        # Deliberate best effort: report the error as tool output instead of
        # raising it.
        return repr(e), {}
    return self.api_wrapper.clean_results(raw_results["results"]), raw_results
- 将工具上配置的查询参数传给 `TavilySearchAPIWrapper.raw_results()` 获取原始结果,并返回 `(清理后的结果, 原始响应)` 二元组。
- 如果查询失败,则返回 `(repr(异常), {})`,即把错误信息作为工具输出返回,而不是抛出异常。
核心方法
`arun()`:`run()` 的异步执行版本

下面展示的是 BaseTool 中默认的 `_arun()` 实现,它把同步的 `_run()` 放到执行器中运行:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
    """Use the tool asynchronously.

    The default implementation runs the synchronous ``self._run`` through
    ``run_in_executor`` instead of calling it directly inside the coroutine.
    Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    to child implementations to enable tracing.
    """
    if "run_manager" in kwargs:
        run_params = signature(self._run).parameters
        if run_params.get("run_manager"):
            run_manager = kwargs["run_manager"]
            if run_manager:
                # ``_run`` is synchronous, so hand it the sync counterpart of
                # the async callback manager.
                kwargs["run_manager"] = run_manager.get_sync()
        elif not any(p.kind == p.VAR_KEYWORD for p in run_params.values()):
            # ``_run`` neither declares ``run_manager`` nor takes **kwargs;
            # forwarding it would only raise TypeError, so drop it.
            del kwargs["run_manager"]
    return await run_in_executor(None, self._run, *args, **kwargs)
- 若调用时传入了 `run_manager`,且 `self._run` 的签名中声明了 `run_manager` 参数,则通过 `get_sync()` 将异步回调管理器转换为同步版本;随后使用默认执行器(传入 `None`)异步运行 `self._run` 方法
- `run_in_executor` 是一个辅助函数,它把同步函数交给执行器(例如线程池)运行并返回可等待的结果,从而避免阻塞当前的事件循环
`invoke()` 和 `ainvoke()`
def invoke(
    self,
    input: Union[str, dict, ToolCall],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> Any:
    """Run the tool synchronously through the Runnable interface.

    ``input`` and ``config`` (plus any extra keyword arguments) are
    normalized by ``_prep_run_args`` into the tool input and the keyword
    arguments that are then forwarded to ``self.run``.
    """
    prepared_input, run_kwargs = _prep_run_args(input, config, **kwargs)
    output = self.run(prepared_input, **run_kwargs)
    return output
async def ainvoke(
    self,
    input: Union[str, dict, ToolCall],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> Any:
    """Async counterpart of ``invoke``: prepare the arguments, await ``self.arun``.

    ``input`` and ``config`` (plus any extra keyword arguments) are
    normalized by ``_prep_run_args`` exactly as in the sync path.
    """
    prepared_input, run_kwargs = _prep_run_args(input, config, **kwargs)
    output = await self.arun(prepared_input, **run_kwargs)
    return output
- 充当执行工具逻辑的入口点
- 通过 `_prep_run_args()` 准备输入参数,然后在内部分别调用 `run()` 或 `arun()`