This commit is contained in:
Obada Khalili 2025-03-21 15:04:12 +08:00 committed by GitHub
commit cd1c2769d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 239 additions and 206 deletions

View File

@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner): class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: def run(
self, message: Message, query: str, **kwargs: Any
) -> Union[Generator[LLMResultChunk, None, None], LLMResult]:
""" """
Run FunctionCall agent application Run FunctionCall agent application
""" """
@ -72,55 +74,98 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_instance = self.model_instance model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps: final_prompt_messages = None
function_call_state = False final_system_fingerprint = None
if iteration_step == max_iteration_steps: def response_generator() -> Generator[LLMResultChunk, None, None]:
# the last iteration, remove all tools nonlocal \
prompt_messages_tools = [] function_call_state, \
function_call_state, \
iteration_step, \
prompt_messages_tools, \
final_answer, \
final_prompt_messages, \
final_system_fingerprint
message_file_ids: list[str] = [] while function_call_state and iteration_step <= max_iteration_steps:
agent_thought = self.create_agent_thought( function_call_state = False
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
# recalc llm max tokens if iteration_step == max_iteration_steps:
prompt_messages = self._organize_prompt_messages() # the last iteration, remove all tools
self.recalc_llm_max_tokens(self.model_config, prompt_messages) prompt_messages_tools = []
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
tool_calls: list[tuple[str, str, dict[str, Any]]] = [] message_file_ids: list[str] = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
# save full response # recalc llm max tokens
response = "" prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
# save tool call names and inputs tool_calls: list[tuple[str, str, dict[str, Any]]] = []
tool_call_names = ""
tool_call_inputs = ""
current_llm_usage = None # save full response
response = ""
if isinstance(chunks, Generator): # save tool call names and inputs
is_first_chunk = True tool_call_names = ""
for chunk in chunks: tool_call_inputs = ""
if is_first_chunk:
self.queue_manager.publish( current_llm_usage = None
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
) if isinstance(chunks, Generator):
is_first_chunk = False is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id),
PublishFrom.APPLICATION_MANAGER,
)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
final_prompt_messages = chunk.prompt_messages
final_system_fingerprint = chunk.system_fingerprint
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += str(chunk.delta.message.content)
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result = chunks
# check if there is any tool call # check if there is any tool call
if self.check_tool_calls(chunk): if self.check_blocking_tool_calls(result):
function_call_state = True function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk) or []) tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try: try:
tool_call_inputs = json.dumps( tool_call_inputs = json.dumps(
@ -130,189 +175,180 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# ensure ascii to avoid encoding error # ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if chunk.delta.message and chunk.delta.message.content: if result.usage:
if isinstance(chunk.delta.message.content, list): increase_usage(llm_usage, result.usage)
for content in chunk.delta.message.content: current_llm_usage = result.usage
final_prompt_messages = result.prompt_messages
final_system_fingerprint = result.system_fingerprint
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data response += content.data
else: else:
response += str(chunk.delta.message.content) response += str(result.message.content)
if chunk.delta.usage: if not result.message.content:
increase_usage(llm_usage, chunk.delta.usage) result.message.content = ""
current_llm_usage = chunk.delta.usage
yield chunk self.queue_manager.publish(
else: QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
result = chunks )
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if result.usage: yield LLMResultChunk(
increase_usage(llm_usage, result.usage) model=model_instance.model,
current_llm_usage = result.usage prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
if result.message and result.message.content: delta=LLMResultChunkDelta(
if isinstance(result.message.content, list): index=0,
for content in result.message.content: message=result.message,
response += content.data usage=result.usage,
else:
response += str(result.message.content)
if not result.message.content:
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
),
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
), ),
) )
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message) assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
# save thought assistant_message.tool_calls = [
self.save_agent_thought( AssistantPromptMessage.ToolCall(
agent_thought=agent_thought, id=tool_call[0],
tool_name=tool_call_names, type="function",
tool_input=tool_call_inputs, function=AssistantPromptMessage.ToolCall.ToolCallFunction(
thought=response, name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
tool_invoke_meta=None, ),
observation=None, )
answer=response, for tool_call in tool_calls
messages_ids=[], ]
llm_usage=current_llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
}
else: else:
# invoke tool assistant_message.content = response
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = { self._current_thoughts.append(assistant_message)
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict(),
}
tool_responses.append(tool_response) # save thought
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=str(tool_response["tool_response"]),
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name="", tool_name=tool_call_names,
tool_input="", tool_input=tool_call_inputs,
thought="", thought=response,
tool_invoke_meta={ tool_invoke_meta=None,
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses observation=None,
}, answer=response,
observation={ messages_ids=[],
tool_response["tool_call_name"]: tool_response["tool_response"] llm_usage=current_llm_usage,
for tool_response in tool_responses
},
answer="",
messages_ids=message_file_ids,
) )
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
) )
# update prompt tool final_answer += response
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1 # call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(
f"there is not a tool named {tool_call_name}"
).to_dict(),
}
else:
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
# publish end event tool_response = {
self.queue_manager.publish( "tool_call_id": tool_call_id,
QueueMessageEndEvent( "tool_call_name": tool_call_name,
llm_result=LLMResult( "tool_response": tool_invoke_response,
model=model_instance.model, "meta": tool_invoke_meta.to_dict(),
prompt_messages=prompt_messages, }
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(), tool_responses.append(tool_response)
system_fingerprint="", if tool_response["tool_response"] is not None:
) self._current_thoughts.append(
), ToolPromptMessage(
PublishFrom.APPLICATION_MANAGER, content=str(tool_response["tool_response"]),
) tool_call_id=tool_call_id,
name=tool_call_name,
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name="",
tool_input="",
thought="",
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
observation={
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer="",
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
chunk_generator = response_generator()
if app_generate_entity.stream:
return chunk_generator
else:
list(chunk_generator)
return LLMResult(
model=model_instance.model,
prompt_messages=final_prompt_messages or [],
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint=final_system_fingerprint or "",
)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
""" """

View File

@ -82,9 +82,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
if not streaming:
raise ValueError("Agent Chat App does not support blocking mode")
if not args.get("query"): if not args.get("query"):
raise ValueError("query is required") raise ValueError("query is required")