GPT结合业务探索之多步推理

参考

先是看到了这个文章: https://mp.weixin.qq.com/s/noM8Hy1atF8asuj1Bagpdg
AutoGPT 等自动推理,自动执行的项目火热出现
之前看过飞书的 AI 的介绍视频,里面有创建日程的场景

结合我们的场景

类型场景我们也有, 定会议室, 创建日程

代码实践

核心是应用 langchain 中的 Agent, Tools, Chain 来完成整个步骤
大致关系如下:

Agent

from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from langchain import OpenAI, LLMChain
from typing import List, Union
from langchain.schema import AgentAction, AgentFinish
import re
from langchain.agents import tool
from dateTool import dateCalculator
from dbSearch import employeeSearch
from schedule import getAvailabelSchedule
from session import getPgMemory
from langchain.callbacks import get_openai_callback
from dotenv import load_dotenv
# 引入环境变量
load_dotenv()
llm = OpenAI(temperature=0, verbose=False) # type: ignore

# 根据 uuid 创建一个临时存储, session 模块
history = getPgMemory("xx")
history.clear()

# 错误情况走的 Tool
@tool
def error_log(s) -> str:
"""just output anything received"""
return s

# 组装 Tool, 一个用来推断会议日期, 一个是获取参会人信息, 一个是汇总信息
tools = [
Tool(
name = "Date Calculator",
func = lambda string: dateCalculator(string,history),
description="用于日期相关计算"
),
Tool(
name = "Search",
func=lambda string: employeeSearch(string,history),
description="用于查询员工信息"
# description="用于查询模板,模块和模块配置信息"
),
Tool(
name = "Schedule",
func=lambda string: getAvailabelSchedule(history),
description="useful to get Availabel Schedule, 需要先Date Calculator计算出预约的时间, 在用Search分析出参会人的工号和姓名后才能调用",
return_direct=True
),
Tool(
name = "Error",
func= error_log,
description="当遇到错误,需要终止程序时调用",
return_direct=True
)
]

# Agent 逐步推断用的模板, 根据每一步要做的事情,选择合适的 Tool
template = """Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:

{tools}
Use the following format(response in Chinese):

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}].
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin! Remember to speak as a pirate when giving your final answer. Use lots of "Arg"s

Question: {input}
{agent_scratchpad}"""

# 自定义的提示模板.会有Thought/Action/Observation 几个部分
class CustomPromptTemplate(StringPromptTemplate):
# The template to use
template: str
# The list of tools available
tools: List[Tool]

def format(self, **kwargs) -> str:
# Get the intermediate steps (AgentAction, Observation tuples)
# Format them in a particular way
intermediate_steps = kwargs.pop("intermediate_steps")
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
# Set the agent_scratchpad variable to that value
kwargs["agent_scratchpad"] = thoughts
# Create a tools variable from the list of tools provided
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
# Create a list of tool names for the tools provided
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
return self.template.format(**kwargs)

prompt = CustomPromptTemplate(
template=template,
tools=tools,
# This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically
# This includes the `intermediate_steps` variable because that is needed
input_variables=["input", "intermediate_steps"]
)
# LLM chain consisting of the LLM and a prompt
llm_chain = LLMChain(llm=llm, prompt=prompt,verbose=False)

# 对输出做一些控制, 截取 gpt 返回的内容,并切割
class CustomOutputParser(AgentOutputParser):

def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
# Check if agent should finishs
if not llm_output.strip("\n").strip(" ").startswith("Thought:"):
llm_output ="\nThought:"+llm_output

if "Final Answer:" in llm_output:
return AgentFinish(
# Return values is generally always a dictionary with a single `output` key
# It is not recommended to try anything else at the moment :)
return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
log=llm_output,
)
# Parse out the action and action input
regex = r"Thought\s*\d*\s*:(.*?)\nAction\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
match = re.search(regex, llm_output, re.DOTALL)
if not match:
raise ValueError(f"Could not parse LLM output: `{llm_output}`")
thought = match.group(1).strip()
history.add_user_message(thought)
action = match.group(2).strip()
action_input = match.group(3)
# Return the action and action input
if action == "Error":
# 这是错误走的通道
error= history.messages[-2]
# 方便获取最后的错误内容
action_input =f"{action}:{error.content}";
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)

output_parser = CustomOutputParser()

agent = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["\nObservation:","\n\tObservation:"],
)
# verbose 能控制是否 debug 信息
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, max_iterations=10, verbose=True)

input_str = "给我约一个下周五的会议,参会人有我和张金海参会, 我的工号是 076533"
# input_str = "明天给我和王四大建一个早上的会 我的工号是 076533,创建会议等于'拉会','拉'等于创建的意思"
# input_str = "模板 id 为 72 的模板 中, 模块名字包含晋升的模块 id, 看下这个模块的启用状态"
history.add_user_message(input_str)

# 这是输出 token 计算相关信息
with get_openai_callback() as cb:
result = agent_executor.run(input_str)
print(f"Total Tokens: {cb.total_tokens}")
print(f"Prompt Tokens: {cb.prompt_tokens}")
print(f"Completion Tokens: {cb.completion_tokens}")
print(f"Total Cost (USD): ${cb.total_cost}")

print("Final Answer: ",result)

# 清除这次推理的记录
history.clear()

日期计算 Tool

import datetime
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI

dateLLM = ChatOpenAI(temperature=0,model_name="gpt-3.5-turbo", verbose=False) # type: ignore
system_message_prompt = SystemMessagePromptTemplate.from_template("你将作为一个时间计算器")
human_template = """当前时间是北京时间 {time}. 一周中的第一天是周一.每周是 7 天.请一步一步详细分析计算过程的解释(不要任何代码)并给出结果日期时间(北京时间). in format: mm/dd/yyyy
问题是:{question}"""
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

# 组装 Prompt
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
dateChain = LLMChain(llm=dateLLM, prompt=chat_prompt,verbose=True)

today = datetime.datetime.now()
from langchain.memory import RedisChatMessageHistory # type: ignore
def dateCalculator(date:str,history:RedisChatMessageHistory):
# if '/' not in date or 'mm' in date:
# history.add_ai_message("是的冯绍峰格式不对")
# return "出错了"
weekday = today.weekday()
weekday_str = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'][weekday]
# 推断出要计算的日期,有格式限制
result = dateChain.run(time=f"{today.strftime('%m/%d/%Y, %H:%M:%S')} {weekday_str}",question=date)
# 计算结果计入记录
history.add_ai_message(result)

return f"{result}"

查询员工信息 Tool

这里是连了数据库, 通过自然语言转换为 sql

import os
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.chat_models import ChatOpenAI

# TODO: 只能给只读权限, 不要让他修改数据.账号权限控制
db = SQLDatabase.from_uri(os.environ['TALENT_DB_DIALECT'],include_tables=['employees'])
# talent_position

from langchain.prompts.prompt import PromptTemplate

_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}
结果中必须包含每个人的工号(workcode)信息,并解释返回字段数据的含义, 如果没有找到任何结果则只返回Error
Question: {input}"""
PROMPT = PromptTemplate(
input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)


llm = ChatOpenAI(temperature=0) # type: ignore
db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT,verbose=True,top_k=3)
from langchain.memory import RedisChatMessageHistory # type: ignore
def employeeSearch(s:str,history:RedisChatMessageHistory):
# 获取他的想法,用想法去转 SQL 比较准.
thought = history.messages[-1]
# top_k 参数好像不好使,所以提示语里强制最多 3 条
result = db_chain.run(f"{thought.content}{s},最多返回3个")
# 空结果处理
if(result.find('Error')>=0):
history.add_ai_message(f"{thought}:没查询到结果")
return f"发生错误"
# 把结果加入记录
history.add_ai_message(result)
return f"{result}"

日程查询 Tool

  • 将日期和查询 Tool 得到的结果汇总处理.
  • 解析出需要的信息
  • 根据业务逻辑去判断空闲时间,或者日程冲突. 返回结果或者错误处理

from langchain import LLMChain, PromptTemplate
from functools import reduce
from langchain.chat_models import ChatOpenAI

# 示例
example="""
Q: 我的工号是 076533, 小王的工号是 098773, 我和小王要在 04/22/2023 那天 9:30 开个会.
A: {
"participant":["076533","098773"],
"meeting_date":"04/22/2023",
"meeting_start_time":"9:30"
}
Q: 我的工号是 076533, 小王的工号是 098773, 我和小王要在 04/22/2023 那天开个会.
A: {
"participant":["076533","098773"],
"meeting_date":"04/22/2023",
"meeting_start_time":""
}
"""
template = """
请从以下对话中抽取出 参会人工号, 参会日期, 参会时间. 并以 JSON 格式返回.如果没有明确说明开始时间,则meeting_start_time为"":
{example}
Q: {chat_history}.
A: """

prompt = PromptTemplate(
input_variables=["chat_history","example"],
template=template,
)
def concatStr(x,y):
return f"{x}\n{y}"

scheduleLLM = ChatOpenAI(temperature=0,model_name="gpt-3.5-turbo", verbose=False) # type: ignore
dateChain = LLMChain(llm=scheduleLLM, prompt=prompt,verbose=True)
from langchain.memory import RedisChatMessageHistory # type: ignore

def getAvailabelSchedule(history:RedisChatMessageHistory):
# 把之前的每个工具的问答都集合起来作为参考提示
chat_history = reduce(concatStr,map(lambda x:x.content,history.messages)) # type: ignore
json_str = dateChain.run(chat_history=chat_history,example=example)
# result = json.loads(json_str)
# TODO: 这里根据获取的 json 来判断和业务系统通信, 获取空闲时间, 或者报错说冲突

# 记录最终结果
history.add_ai_message(json_str)

return f"{json_str}"

Error Tool

当在推断过程中返回 类似出错了 这种意思的时候, agent 会选择 Error tool. 此时推断直接结束.
Tool 定义中的 return_direct=True 代表这个 Tool 返回的结果就直接结束了, 不会让程序再次判断给哪个 Tool 去处理

Session 模块

因为没有找到现成的能记录下各个 Tool 里的历史记录和 agent 的记录. 所以独立找 redis 来自己记录了

import os
from langchain.memory import RedisChatMessageHistory # type: ignore

def getPgMemory(sid:str):
return RedisChatMessageHistory(session_id=sid,url=os.environ['REDIS_HUI'])

示例解析过程(debug)

这个例子中, 日期推算错误,但格式是对的, 这里只需要针对每个 Tool 去做优化迭代就好了. 整个场景上是跑通的

> Entering new AgentExecutor chain...

Thought: 我需要先计算出下周五的日期
Action: Date Calculator
Action Input: 下周五

> Entering new LLMChain chain...
Prompt after formatting:
System: 你将作为一个时间计算器
Human: 当前时间是北京时间 04/21/2023, 14:47:38 星期五. 一周中的第一天是周一.每周是 7 天.请一步一步详细分析计算过程的解释(不要任何代码)并给出结果日期时间(北京时间). in format: mm/dd/yyyy
问题是:下周五

> Finished chain.


Observation:首先确定当前是星期五,那么下周五就是本周的第六天再加上一周的时间,即第13天。因为每周是7天,所以我们需要把13除以7得到商和余数,商表示经过了几个完整的周,余数表示在这个周的第几天。

13 ÷ 7 = 1...6

因此下周五是第二周的第六天,也就是从当前时间往后推8天的那个星期五。

当前日期是04/21/2023,那么下周五的日期就是04/29/2023。
Thought:下一步我需要查询参会人的工号和姓名
Action: Search
Action Input: 张金海, 076533

> Entering new SQLDatabaseChain chain...
下一步我需要查询参会人的工号和姓名张金海, 076533,最多返回3个
SQLQuery:SELECT workcode, name FROM employees WHERE name = '张金海' OR workcode = '076533' LIMIT 3
SQLResult: [('076533', '国杰'), ('164733', '张金海')]
Answer:返回了两个符合条件的结果,第一个结果的工号为'076533',姓名为'国杰',第二个结果的工号为'164733',姓名为'张金海'。工号和姓名的含义分别为员工的工号和姓名。
> Finished chain.


Observation:返回了两个符合条件的结果,第一个结果的工号为'076533',姓名为'国杰',第二个结果的工号为'164733',姓名为'张金海'。工号和姓名的含义分别为员工的工号和姓名。
Thought:下一步我需要查询可用的会议时间
Action: Schedule
Action Input: 04/29/2023

> Entering new LLMChain chain...
Prompt after formatting:

请从以下对话中抽取出 参会人工号, 参会日期, 参会时间. 并以 JSON 格式返回.如果没有明确说明开始时间,则meeting_start_time为"":

Q: 我的工号是 076533, 小王的工号是 098773, 我和小王要在 04/22/2023 那天 9:30 开个会.
A: {
"participant":["076533","098773"],
"meeting_date":"04/22/2023",
"meeting_start_time":"9:30"
}
Q: 我的工号是 076533, 小王的工号是 098773, 我和小王要在 04/22/2023 那天开个会.
A: {
"participant":["076533","098773"],
"meeting_date":"04/22/2023",
"meeting_start_time":""
}

Q: 给我约一个下周五的会议,参会人有我和张金海, 我的工号是076533
我需要先计算出下周五的日期
首先确定当前是星期五,那么下周五就是本周的第六天再加上一周的时间,即第13天。因为每周是7天,所以我们需要把13除以7得到商和余数,商表示经过了几个完整的周,余数表示在这个周的第几天。

13 ÷ 7 = 1...6

因此下周五是第二周的第六天,也就是从当前时间往后推8天的那个星期五。

当前日期是04/21/2023,那么下周五的日期就是04/29/2023。
下一步我需要查询参会人的工号和姓名
返回了两个符合条件的结果,第一个结果的工号为'076533',姓名为'国杰',第二个结果的工号为'164733',姓名为'张金海'。工号和姓名的含义分别为员工的工号和姓名。
下一步我需要查询可用的会议时间.
A:

> Finished chain.


Observation:{
"participant":["076533","164733"],
"meeting_date":"04/29/2023",
"meeting_start_time":""
}

> Finished chain.
Total Tokens: 7538
Prompt Tokens: 7107
Completion Tokens: 431
Total Cost (USD): $0.056440000000000004
Final Answer: {
"participant":["076533","164733"],
"meeting_date":"04/29/2023",
"meeting_start_time":""
}