This commit is contained in:
Obada Khalili 2025-03-21 15:04:12 +08:00 committed by GitHub
commit cd1c2769d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 239 additions and 206 deletions

View File

@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner): class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: def run(
self, message: Message, query: str, **kwargs: Any
) -> Union[Generator[LLMResultChunk, None, None], LLMResult]:
""" """
Run FunctionCall agent application Run FunctionCall agent application
""" """
@ -72,6 +74,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_instance = self.model_instance model_instance = self.model_instance
final_prompt_messages = None
final_system_fingerprint = None
def response_generator() -> Generator[LLMResultChunk, None, None]:
nonlocal \
function_call_state, \
function_call_state, \
iteration_step, \
prompt_messages_tools, \
final_answer, \
final_prompt_messages, \
final_system_fingerprint
while function_call_state and iteration_step <= max_iteration_steps: while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False function_call_state = False
@ -114,7 +129,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for chunk in chunks: for chunk in chunks:
if is_first_chunk: if is_first_chunk:
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought.id),
PublishFrom.APPLICATION_MANAGER,
) )
is_first_chunk = False is_first_chunk = False
# check if there is any tool call # check if there is any tool call
@ -130,6 +146,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# ensure ascii to avoid encoding error # ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
final_prompt_messages = chunk.prompt_messages
final_system_fingerprint = chunk.system_fingerprint
if chunk.delta.message and chunk.delta.message.content: if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list): if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content: for content in chunk.delta.message.content:
@ -161,6 +179,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
increase_usage(llm_usage, result.usage) increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage current_llm_usage = result.usage
final_prompt_messages = result.prompt_messages
final_system_fingerprint = result.system_fingerprint
if result.message and result.message.content: if result.message and result.message.content:
if isinstance(result.message.content, list): if isinstance(result.message.content, list):
for content in result.message.content: for content in result.message.content:
@ -219,7 +239,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
) )
final_answer += response + "\n" final_answer += response
# call tools # call tools
tool_responses = [] tool_responses = []
@ -230,7 +250,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"tool_call_name": tool_call_name, "tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}", "tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), "meta": ToolInvokeMeta.error_instance(
f"there is not a tool named {tool_call_name}"
).to_dict(),
} }
else: else:
# invoke tool # invoke tool
@ -314,6 +336,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
PublishFrom.APPLICATION_MANAGER, PublishFrom.APPLICATION_MANAGER,
) )
chunk_generator = response_generator()
if app_generate_entity.stream:
return chunk_generator
else:
list(chunk_generator)
return LLMResult(
model=model_instance.model,
prompt_messages=final_prompt_messages or [],
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint=final_system_fingerprint or "",
)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
""" """
Check if there is any tool call in llm result chunk Check if there is any tool call in llm result chunk

View File

@ -82,9 +82,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
if not streaming:
raise ValueError("Agent Chat App does not support blocking mode")
if not args.get("query"): if not args.get("query"):
raise ValueError("query is required") raise ValueError("query is required")