Merge 15160646fd
into a30945312a
This commit is contained in:
commit
cd1c2769d9
@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
def run(
|
||||||
|
self, message: Message, query: str, **kwargs: Any
|
||||||
|
) -> Union[Generator[LLMResultChunk, None, None], LLMResult]:
|
||||||
"""
|
"""
|
||||||
Run FunctionCall agent application
|
Run FunctionCall agent application
|
||||||
"""
|
"""
|
||||||
@ -72,6 +74,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
|
|
||||||
model_instance = self.model_instance
|
model_instance = self.model_instance
|
||||||
|
|
||||||
|
final_prompt_messages = None
|
||||||
|
final_system_fingerprint = None
|
||||||
|
|
||||||
|
def response_generator() -> Generator[LLMResultChunk, None, None]:
|
||||||
|
nonlocal \
|
||||||
|
function_call_state, \
|
||||||
|
function_call_state, \
|
||||||
|
iteration_step, \
|
||||||
|
prompt_messages_tools, \
|
||||||
|
final_answer, \
|
||||||
|
final_prompt_messages, \
|
||||||
|
final_system_fingerprint
|
||||||
|
|
||||||
while function_call_state and iteration_step <= max_iteration_steps:
|
while function_call_state and iteration_step <= max_iteration_steps:
|
||||||
function_call_state = False
|
function_call_state = False
|
||||||
|
|
||||||
@ -114,7 +129,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
if is_first_chunk:
|
if is_first_chunk:
|
||||||
self.queue_manager.publish(
|
self.queue_manager.publish(
|
||||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
)
|
)
|
||||||
is_first_chunk = False
|
is_first_chunk = False
|
||||||
# check if there is any tool call
|
# check if there is any tool call
|
||||||
@ -130,6 +146,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
# ensure ascii to avoid encoding error
|
# ensure ascii to avoid encoding error
|
||||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||||
|
|
||||||
|
final_prompt_messages = chunk.prompt_messages
|
||||||
|
final_system_fingerprint = chunk.system_fingerprint
|
||||||
if chunk.delta.message and chunk.delta.message.content:
|
if chunk.delta.message and chunk.delta.message.content:
|
||||||
if isinstance(chunk.delta.message.content, list):
|
if isinstance(chunk.delta.message.content, list):
|
||||||
for content in chunk.delta.message.content:
|
for content in chunk.delta.message.content:
|
||||||
@ -161,6 +179,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
increase_usage(llm_usage, result.usage)
|
increase_usage(llm_usage, result.usage)
|
||||||
current_llm_usage = result.usage
|
current_llm_usage = result.usage
|
||||||
|
|
||||||
|
final_prompt_messages = result.prompt_messages
|
||||||
|
final_system_fingerprint = result.system_fingerprint
|
||||||
if result.message and result.message.content:
|
if result.message and result.message.content:
|
||||||
if isinstance(result.message.content, list):
|
if isinstance(result.message.content, list):
|
||||||
for content in result.message.content:
|
for content in result.message.content:
|
||||||
@ -219,7 +239,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||||
)
|
)
|
||||||
|
|
||||||
final_answer += response + "\n"
|
final_answer += response
|
||||||
|
|
||||||
# call tools
|
# call tools
|
||||||
tool_responses = []
|
tool_responses = []
|
||||||
@ -230,7 +250,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
"tool_call_id": tool_call_id,
|
"tool_call_id": tool_call_id,
|
||||||
"tool_call_name": tool_call_name,
|
"tool_call_name": tool_call_name,
|
||||||
"tool_response": f"there is not a tool named {tool_call_name}",
|
"tool_response": f"there is not a tool named {tool_call_name}",
|
||||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
"meta": ToolInvokeMeta.error_instance(
|
||||||
|
f"there is not a tool named {tool_call_name}"
|
||||||
|
).to_dict(),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# invoke tool
|
# invoke tool
|
||||||
@ -314,6 +336,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
PublishFrom.APPLICATION_MANAGER,
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
chunk_generator = response_generator()
|
||||||
|
|
||||||
|
if app_generate_entity.stream:
|
||||||
|
return chunk_generator
|
||||||
|
else:
|
||||||
|
list(chunk_generator)
|
||||||
|
return LLMResult(
|
||||||
|
model=model_instance.model,
|
||||||
|
prompt_messages=final_prompt_messages or [],
|
||||||
|
message=AssistantPromptMessage(content=final_answer),
|
||||||
|
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||||
|
system_fingerprint=final_system_fingerprint or "",
|
||||||
|
)
|
||||||
|
|
||||||
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if there is any tool call in llm result chunk
|
Check if there is any tool call in llm result chunk
|
||||||
|
@ -82,9 +82,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
:param invoke_from: invoke from source
|
:param invoke_from: invoke from source
|
||||||
:param stream: is stream
|
:param stream: is stream
|
||||||
"""
|
"""
|
||||||
if not streaming:
|
|
||||||
raise ValueError("Agent Chat App does not support blocking mode")
|
|
||||||
|
|
||||||
if not args.get("query"):
|
if not args.get("query"):
|
||||||
raise ValueError("query is required")
|
raise ValueError("query is required")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user