diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 55fd8825de..64075ed231 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -2,7 +2,6 @@ import json import logging import uuid from collections.abc import Mapping, Sequence -from datetime import datetime, timezone from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity @@ -23,6 +22,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, @@ -31,18 +31,15 @@ from core.model_runtime.entities.message_entities import ( ) from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( ToolParameter, - ToolRuntimeVariablePool, ) from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from core.tools.utils.tool_parameter_converter import ToolParameterConverter from extensions.ext_database import db from models.model import Conversation, Message, MessageAgentThought -from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) @@ -59,11 +56,9 @@ class BaseAgentRunner(AppRunner): queue_manager: AppQueueManager, message: Message, user_id: str, + model_instance: ModelInstance, memory: Optional[TokenBufferMemory] = None, prompt_messages: Optional[list[PromptMessage]] = None, - variables_pool: Optional[ToolRuntimeVariablePool] = None, - db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None, ) -> None: """ Agent runner @@ -93,8 +88,6 @@ class BaseAgentRunner(AppRunner): self.user_id = user_id self.memory = memory self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) - self.variables_pool = variables_pool - self.db_variables_pool = db_variables self.model_instance = model_instance # init callback @@ -162,11 +155,10 @@ class BaseAgentRunner(AppRunner): agent_tool=tool, invoke_from=self.application_generate_entity.invoke_from, ) - tool_entity.load_variables(self.variables_pool) - + assert tool_entity.entity.description message_tool = PromptMessageTool( name=tool.tool_name, - description=tool_entity.description.llm, + description=tool_entity.entity.description.llm, parameters={ "type": "object", "properties": {}, @@ -201,9 +193,11 @@ class BaseAgentRunner(AppRunner): """ convert dataset retriever tool to prompt message tool """ + assert tool.entity.description + prompt_tool = PromptMessageTool( - name=tool.identity.name, - description=tool.description.llm, + name=tool.entity.identity.name, + description=tool.entity.description.llm, parameters={ "type": "object", "properties": {}, @@ -232,7 +226,7 @@ class BaseAgentRunner(AppRunner): tool_instances = {} prompt_messages_tools = [] - for tool in self.app_config.agent.tools if self.app_config.agent else []: + for tool in self.app_config.agent.tools or [] if self.app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -249,7 +243,7 @@ class BaseAgentRunner(AppRunner): # save prompt tool prompt_messages_tools.append(prompt_tool) # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + tool_instances[dataset_tool.entity.identity.name] = dataset_tool return tool_instances, prompt_messages_tools @@ -328,25 +322,29 @@ class BaseAgentRunner(AppRunner): def save_agent_thought( self, agent_thought: MessageAgentThought, - tool_name: str, - tool_input: Union[str, dict], - thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], - answer: str, + tool_name: str | None, + tool_input: Union[str, dict, None], + thought: str | None, + observation: Union[str, dict, None], + tool_invoke_meta: Union[str, dict, None], + answer: str | None, messages_ids: list[str], - llm_usage: LLMUsage = None, - ) -> MessageAgentThought: + llm_usage: LLMUsage | None = None, + ): """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + updated_agent_thought = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + ) + if not updated_agent_thought: + raise ValueError("agent thought not found") if thought is not None: - agent_thought.thought = thought + updated_agent_thought.thought = thought if tool_name is not None: - agent_thought.tool = tool_name + updated_agent_thought.tool = tool_name if tool_input is not None: if isinstance(tool_input, dict): @@ -355,7 +353,7 @@ class BaseAgentRunner(AppRunner): except Exception as e: tool_input = json.dumps(tool_input) - agent_thought.tool_input = tool_input + updated_agent_thought.tool_input = tool_input if observation is not None: if isinstance(observation, dict): @@ -364,27 +362,27 @@ class BaseAgentRunner(AppRunner): except Exception as e: observation = json.dumps(observation) - agent_thought.observation = observation + updated_agent_thought.observation = observation if answer is not None: - agent_thought.answer = answer + updated_agent_thought.answer = answer if messages_ids is not None and len(messages_ids) > 0: - agent_thought.message_files = json.dumps(messages_ids) + updated_agent_thought.message_files = json.dumps(messages_ids) if llm_usage: - agent_thought.message_token = llm_usage.prompt_tokens - agent_thought.message_price_unit = llm_usage.prompt_price_unit - agent_thought.message_unit_price = llm_usage.prompt_unit_price - agent_thought.answer_token = llm_usage.completion_tokens - agent_thought.answer_price_unit = llm_usage.completion_price_unit - agent_thought.answer_unit_price = llm_usage.completion_unit_price - agent_thought.tokens = llm_usage.total_tokens - agent_thought.total_price = llm_usage.total_price + updated_agent_thought.message_token = llm_usage.prompt_tokens + updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit + updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price + updated_agent_thought.answer_token = llm_usage.completion_tokens + updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit + updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price + updated_agent_thought.tokens = llm_usage.total_tokens + updated_agent_thought.total_price = llm_usage.total_price # check if tool labels is not empty - labels = agent_thought.tool_labels or {} - tools = agent_thought.tool.split(";") if agent_thought.tool else [] + labels = updated_agent_thought.tool_labels or {} + tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else [] for tool in tools: if not tool: continue @@ -395,7 +393,7 @@ class BaseAgentRunner(AppRunner): else: labels[tool] = {"en_US": tool, "zh_Hans": tool} - agent_thought.tool_labels_str = json.dumps(labels) + updated_agent_thought.tool_labels_str = json.dumps(labels) if tool_invoke_meta is not None: if isinstance(tool_invoke_meta, dict): @@ -404,28 +402,11 @@ class BaseAgentRunner(AppRunner): except Exception as e: tool_invoke_meta = json.dumps(tool_invoke_meta) - agent_thought.tool_meta_str = tool_invoke_meta + updated_agent_thought.tool_meta_str = tool_invoke_meta db.session.commit() db.session.close() - def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): - """ - convert tool variables to db variables - """ - db_variables = ( - db.session.query(ToolConversationVariables) - .filter( - ToolConversationVariables.conversation_id == self.message.conversation_id, - ) - .first() - ) - - db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) - db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) - db.session.commit() - db.session.close() - def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize agent history @@ -515,6 +496,7 @@ class BaseAgentRunner(AppRunner): files = message.message_files if files: + assert message.app_model_config file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) if file_extra_config: @@ -525,7 +507,7 @@ class BaseAgentRunner(AppRunner): if not file_objs: return UserPromptMessage(content=message.query) else: - prompt_message_contents = [TextPromptMessageContent(data=message.query)] + prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)] for file_obj in file_objs: prompt_message_contents.append(file_obj.prompt_message_content) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 0d74b1e5eb..1e62b4308d 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,6 +1,6 @@ import json from abc import ABC, abstractmethod -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from typing import Optional, Union from core.agent.base_agent_runner import BaseAgentRunner @@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageTool, ToolPromptMessage, UserPromptMessage, ) @@ -26,11 +27,11 @@ from models.model import Message class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] = None - _agent_scratchpad: list[AgentScratchpadUnit] = None - _instruction: str = None - _query: str = None - _prompt_messages_tools: list[PromptMessage] = None + _historic_prompt_messages: list[PromptMessage] + _agent_scratchpad: list[AgentScratchpadUnit] + _instruction: str + _query: str + _prompt_messages_tools: Sequence[PromptMessageTool] def run( self, @@ -41,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ Run Cot agent application """ + app_generate_entity = self.application_generate_entity self._repack_app_generate_entity(app_generate_entity) self._init_react_state(query) @@ -53,9 +55,11 @@ class CotAgentRunner(BaseAgentRunner, ABC): app_generate_entity.model_conf.stop.append("Observation") app_config = self.app_config + assert app_config.agent # init instruction inputs = inputs or {} + assert app_config.prompt_template.simple_prompt_template instruction = app_config.prompt_template.simple_prompt_template self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) @@ -63,13 +67,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 # convert tools into ModelRuntime Tool format - tool_instances, self._prompt_messages_tools = self._init_prompt_tools() + tool_instances, prompt_messages_tools = self._init_prompt_tools() + self._prompt_messages_tools = prompt_messages_tools function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -115,10 +120,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): callbacks=[], ) - # check llm result - if not chunks: - raise ValueError("failed to invoke llm") - usage_dict = {} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( @@ -139,11 +140,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): if isinstance(chunk, AgentScratchpadUnit.Action): action = chunk # detect action + assert scratchpad.agent_response is not None scratchpad.agent_response += json.dumps(chunk.model_dump()) scratchpad.action_str = json.dumps(chunk.model_dump()) scratchpad.action = action else: + assert scratchpad.agent_response is not None scratchpad.agent_response += chunk + assert scratchpad.thought is not None scratchpad.thought += chunk yield LLMResultChunk( model=self.model_config.model, @@ -152,6 +156,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) + assert scratchpad.thought is not None scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) @@ -168,7 +173,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_invoke_meta={}, thought=scratchpad.thought, observation="", - answer=scratchpad.agent_response, + answer=scratchpad.agent_response or "", messages_ids=[], llm_usage=usage_dict["usage"], ) @@ -248,7 +253,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): messages_ids=[], ) - self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -266,7 +270,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): def _handle_invoke_action( self, action: AgentScratchpadUnit.Action, - tool_instances: dict[str, Tool], + tool_instances: Mapping[str, Tool], message_file_ids: list[str], trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[str, ToolInvokeMeta]: @@ -307,15 +311,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): # publish files for message_file_id, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) - # publish message file self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER ) # add message file ids - message_file_ids.append(message_file_id) + message_file_ids.append(message_file_id.id) return tool_invoke_response, tool_invoke_meta @@ -369,18 +370,19 @@ class CotAgentRunner(BaseAgentRunner, ABC): return message def _organize_historic_prompt_messages( - self, current_session_messages: list[PromptMessage] = None + self, current_session_messages: list[PromptMessage] | None = None ) -> list[PromptMessage]: """ organize historic prompt messages """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit = None + current_scratchpad: AgentScratchpadUnit | None = None for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): if not current_scratchpad: + assert isinstance(message.content, str) current_scratchpad = AgentScratchpadUnit( agent_response=message.content, thought=message.content or "I am thinking about how to help you", @@ -400,6 +402,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): pass elif isinstance(message, ToolPromptMessage): if current_scratchpad: + assert isinstance(message.content, str) current_scratchpad.observation = message.content elif isinstance(message, UserPromptMessage): if scratchpads: diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index bdec6b7ed1..095f8775ae 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -4,6 +4,7 @@ from core.agent.cot_agent_runner import CotAgentRunner from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, @@ -16,6 +17,9 @@ class CotChatAgentRunner(CotAgentRunner): """ Organize system prompt """ + assert self.app_config.agent + assert self.app_config.agent.prompt + prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt @@ -27,12 +31,12 @@ class CotChatAgentRunner(CotAgentRunner): return SystemPromptMessage(content=system_prompt) - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ if self.files: - prompt_message_contents = [TextPromptMessageContent(data=query)] + prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)] for file_obj in self.files: prompt_message_contents.append(file_obj.prompt_message_content) @@ -57,8 +61,10 @@ class CotChatAgentRunner(CotAgentRunner): assistant_message = AssistantPromptMessage(content="") for unit in agent_scratchpad: if unit.is_final(): + assert isinstance(assistant_message.content, str) assistant_message.content += f"Final Answer: {unit.agent_response}" else: + assert isinstance(assistant_message.content, str) assistant_message.content += f"Thought: {unit.thought}\n\n" if unit.action_str: assistant_message.content += f"Action: {unit.action_str}\n\n" diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 13164e0bfc..991c542846 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Generator from copy import deepcopy -from typing import Any, Union +from typing import Any, Optional, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom @@ -11,6 +11,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, PromptMessageContentType, SystemPromptMessage, TextPromptMessageContent, @@ -38,18 +39,20 @@ class FunctionCallAgentRunner(BaseAgentRunner): # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() + assert app_config.agent + iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 # continue to run until there is not any tool call function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" # get tracing instance trace_manager = app_generate_entity.trace_manager - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -99,7 +102,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): current_llm_usage = None - if self.stream_tool_call: + if isinstance(chunks, Generator): is_first_chunk = True for chunk in chunks: if is_first_chunk: @@ -133,7 +136,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): yield chunk else: - result: LLMResult = chunks + result = chunks # check if there is any tool call if self.check_blocking_tool_calls(result): function_call_state = True @@ -236,15 +239,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): ) # publish files for message_file_id, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) - # publish message file self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER ) # add message file ids - message_file_ids.append(message_file_id) + message_file_ids.append(message_file_id.id) tool_response = { "tool_call_id": tool_call_id, @@ -290,7 +290,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): iteration_step += 1 - self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -321,9 +320,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): return True return False - def extract_tool_calls( - self, llm_result_chunk: LLMResultChunk - ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: + def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: """ Extract tool calls from llm result chunk @@ -346,7 +343,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): return tool_calls - def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: """ Extract blocking tool calls from llm result @@ -370,7 +367,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): return tool_calls def _init_system_message( - self, prompt_template: str, prompt_messages: list[PromptMessage] = None + self, prompt_template: str, prompt_messages: list[PromptMessage] ) -> list[PromptMessage]: """ Initialize system message @@ -385,12 +382,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): return prompt_messages - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ if self.files: - prompt_message_contents = [TextPromptMessageContent(data=query)] + prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)] for file_obj in self.files: prompt_message_contents.append(file_obj.prompt_message_content) diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 45b1bf0093..3a9262f54e 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -16,10 +16,8 @@ from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationError -from core.tools.entities.tool_entities import ToolRuntimeVariablePool from extensions.ext_database import db from models.model import App, Conversation, Message, MessageAgentThought -from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) @@ -174,14 +172,6 @@ class AgentChatAppRunner(AppRunner): agent_entity = app_config.agent - # load tool variables - tool_conversation_variables = self._load_tool_variables( - conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id - ) - - # convert db variables to tool variables - tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) - # init model instance model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, @@ -234,8 +224,6 @@ class AgentChatAppRunner(AppRunner): user_id=application_generate_entity.user_id, memory=memory, prompt_messages=prompt_message, - variables_pool=tool_variables, - db_variables=tool_conversation_variables, model_instance=model_instance, ) @@ -253,50 +241,6 @@ class AgentChatAppRunner(AppRunner): agent=True, ) - def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: - """ - load tool variables from database - """ - tool_variables: ToolConversationVariables = ( - db.session.query(ToolConversationVariables) - .filter( - ToolConversationVariables.conversation_id == conversation_id, - ToolConversationVariables.tenant_id == tenant_id, - ) - .first() - ) - - if tool_variables: - # save tool variables to session, so that we can update it later - db.session.add(tool_variables) - else: - # create new tool variables - tool_variables = ToolConversationVariables( - conversation_id=conversation_id, - user_id=user_id, - tenant_id=tenant_id, - variables_str="[]", - ) - db.session.add(tool_variables) - db.session.commit() - - return tool_variables - - def _convert_db_variables_to_tool_variables( - self, db_variables: ToolConversationVariables - ) -> ToolRuntimeVariablePool: - """ - convert db variables to tool variables - """ - return ToolRuntimeVariablePool( - **{ - "conversation_id": db_variables.conversation_id, - "user_id": db_variables.user_id, - "tenant_id": db_variables.tenant_id, - "pool": db_variables.variables, - } - ) - def _get_usage_of_all_agent_thoughts( self, model_config: ModelConfigWithCredentialsEntity, message: Message ) -> LLMUsage: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 990efd36c6..28f01e1a19 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,7 +1,7 @@ import logging import os from collections.abc import Callable, Generator, Sequence -from typing import IO, Optional, Union, cast +from typing import IO, Literal, Optional, Union, cast, overload from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration @@ -97,6 +97,42 @@ class ModelInstance: return None + @overload + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: Literal[True] = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: ... + + @overload + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: Literal[False] = False, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: ... + def invoke_llm( self, prompt_messages: list[PromptMessage], diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 49f9bf68ea..548db51a2a 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -1,72 +1,34 @@ from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy -from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional -from pydantic import BaseModel, ConfigDict, Field, field_validator -from pydantic_core.core_schema import ValidationInfo - -from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( - ToolDescription, - ToolIdentity, - ToolInvokeFrom, + ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType, - ToolRuntimeImageVariable, - ToolRuntimeVariable, - ToolRuntimeVariablePool, ) -from core.tools.tool_file_manager import ToolFileManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter if TYPE_CHECKING: from core.file.file_obj import FileVar -class Tool(BaseModel, ABC): - identity: ToolIdentity - parameters: list[ToolParameter] = Field(default_factory=list) - description: Optional[ToolDescription] = None - is_team_authorization: bool = False +class Tool(ABC): + """ + The base class of a tool + """ - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) + entity: ToolEntity + runtime: ToolRuntime - @field_validator("parameters", mode="before") - @classmethod - def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: - return v or [] + def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: + self.entity = entity + self.runtime = runtime - class Runtime(BaseModel): - """ - Meta data of a tool call processing - """ - - def __init__(self, **data: Any): - super().__init__(**data) - if not self.runtime_parameters: - self.runtime_parameters = {} - - tenant_id: Optional[str] = None - tool_id: Optional[str] = None - invoke_from: Optional[InvokeFrom] = None - tool_invoke_from: Optional[ToolInvokeFrom] = None - credentials: Optional[dict[str, Any]] = None - runtime_parameters: dict[str, Any] = Field(default_factory=dict) - - runtime: Optional[Runtime] = None - variables: Optional[ToolRuntimeVariablePool] = None - - def __init__(self, **data: Any): - super().__init__(**data) - - class VariableKey(Enum): - IMAGE = "image" - - def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": """ fork a new tool with meta data @@ -74,10 +36,8 @@ class Tool(BaseModel, ABC): :return: the new tool """ return self.__class__( - identity=self.identity.model_copy() if self.identity else None, - parameters=self.parameters.copy() if self.parameters else None, - description=self.description.model_copy() if self.description else None, - runtime=Tool.Runtime(**runtime), + entity=self.entity.model_copy(), + runtime=runtime, ) @abstractmethod @@ -88,112 +48,6 @@ class Tool(BaseModel, ABC): :return: the tool provider type """ - def load_variables(self, variables: ToolRuntimeVariablePool): - """ - load variables from database - - :param conversation_id: the conversation id - """ - self.variables = variables - - def set_image_variable(self, variable_name: str, image_key: str) -> None: - """ - set an image variable - """ - if not self.variables: - return - - self.variables.set_file(self.identity.name, variable_name, image_key) - - def set_text_variable(self, variable_name: str, text: str) -> None: - """ - set a text variable - """ - if not self.variables: - return - - self.variables.set_text(self.identity.name, variable_name, text) - - def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: - """ - get a variable - - :param name: the name of the variable - :return: the variable - """ - if not self.variables: - return None - - if isinstance(name, Enum): - name = name.value - - for variable in self.variables.pool: - if variable.name == name: - return variable - - return None - - def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: - """ - get the default image variable - - :return: the image variable - """ - if not self.variables: - return None - - return self.get_variable(self.VariableKey.IMAGE) - - def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: - """ - get a variable file - - :param name: the name of the variable - :return: the variable file - """ - variable = self.get_variable(name) - if not variable: - return None - - if not isinstance(variable, ToolRuntimeImageVariable): - return None - - message_file_id = variable.value - # get file binary - file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) - if not file_binary: - return None - - return file_binary[0] - - def list_variables(self) -> list[ToolRuntimeVariable]: - """ - list all variables - - :return: the variables - """ - if not self.variables: - return [] - - return self.variables.pool - - def list_default_image_variables(self) -> list[ToolRuntimeVariable]: - """ - list all image variables - - :return: the image variables - """ - if not self.variables: - return [] - - result = [] - - for variable in self.variables.pool: - if variable.name.startswith(self.VariableKey.IMAGE.value): - result.append(variable) - - return result - def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: if self.runtime and self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) @@ -227,7 +81,7 @@ class Tool(BaseModel, ABC): """ # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials result = deepcopy(tool_parameters) - for parameter in self.parameters or []: + for parameter in self.entity.parameters: if parameter.name in tool_parameters: result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( tool_parameters[parameter.name], parameter.type @@ -241,15 +95,6 @@ class Tool(BaseModel, ABC): ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: pass - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: - """ - validate the credentials - - :param credentials: the credentials - :param parameters: the parameters - """ - pass - def get_runtime_parameters(self) -> list[ToolParameter]: """ get the runtime parameters @@ -258,7 +103,7 @@ class Tool(BaseModel, ABC): :return: the runtime parameters """ - return self.parameters or [] + return self.entity.parameters def get_all_runtime_parameters(self) -> list[ToolParameter]: """ @@ -266,7 +111,7 @@ class Tool(BaseModel, ABC): :return: all runtime parameters """ - parameters = self.parameters or [] + parameters = self.entity.parameters parameters = parameters.copy() user_parameters = self.get_runtime_parameters() or [] user_parameters = user_parameters.copy() @@ -274,20 +119,16 @@ class Tool(BaseModel, ABC): # override parameters for parameter in user_parameters: # check if parameter in tool parameters - found = False for tool_parameter in parameters: if tool_parameter.name == parameter.name: - found = True + # override parameter + tool_parameter.type = parameter.type + tool_parameter.form = parameter.form + tool_parameter.required = parameter.required + tool_parameter.default = parameter.default + tool_parameter.options = parameter.options + tool_parameter.llm_description = parameter.llm_description break - - if found: - # override parameter - tool_parameter.type = parameter.type - tool_parameter.form = parameter.form - tool_parameter.required = parameter.required - tool_parameter.default = parameter.default - tool_parameter.options = parameter.options - tool_parameter.llm_description = parameter.llm_description else: # add new parameter parameters.append(parameter) diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index 7960ed5f84..c71885e48d 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -1,23 +1,22 @@ from abc import ABC, abstractmethod from typing import Any -from pydantic import BaseModel, ConfigDict, Field - from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( - ToolProviderIdentity, + ToolProviderEntity, ToolProviderType, ) from core.tools.errors import ToolProviderCredentialValidationError -class ToolProviderController(BaseModel, ABC): - identity: ToolProviderIdentity - tools: list[Tool] = Field(default_factory=list) - credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) +class ToolProviderController(ABC): + entity: ToolProviderEntity + tools: list[Tool] - model_config = ConfigDict(validate_assignment=True) + def __init__(self, entity: ToolProviderEntity) -> None: + self.entity = entity + self.tools = [] def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ @@ -25,7 +24,7 @@ class ToolProviderController(BaseModel, ABC): :return: the credentials schema """ - return self.credentials_schema.copy() + return self.entity.credentials_schema.copy() @abstractmethod def get_tool(self, tool_name: str) -> Tool: @@ -51,7 +50,7 @@ class ToolProviderController(BaseModel, ABC): :param credentials: the credentials of the tool """ - credentials_schema = self.credentials_schema + credentials_schema = self.entity.credentials_schema if credentials_schema is None: return @@ -62,7 +61,7 @@ class ToolProviderController(BaseModel, ABC): for credential_name in credentials: if credential_name not in credentials_need_to_validate: raise ToolProviderCredentialValidationError( - f"credential {credential_name} not found in provider {self.identity.name}" + f"credential {credential_name} not found in provider {self.entity.identity.name}" ) # check type diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py new file mode 100644 index 0000000000..d4b2ef6104 --- /dev/null +++ b/api/core/tools/__base/tool_runtime.py @@ -0,0 +1,36 @@ +from typing import Any, Optional + +from openai import BaseModel +from pydantic import Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.entities.tool_entities import ToolInvokeFrom + + +class ToolRuntime(BaseModel): + """ + Meta data of a tool call processing + """ + + tenant_id: str + tool_id: Optional[str] = None + invoke_from: Optional[InvokeFrom] = None + tool_invoke_from: Optional[ToolInvokeFrom] = None + credentials: Optional[dict[str, Any]] = None + runtime_parameters: dict[str, Any] = Field(default_factory=dict) + + +class FakeToolRuntime(ToolRuntime): + """ + Fake tool runtime for testing + """ + + def __init__(self): + super().__init__( + tenant_id="fake_tenant_id", + tool_id="fake_tool_id", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credentials={}, + runtime_parameters={}, + ) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 7d1775b7f5..4ebd82f8e7 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -2,13 +2,12 @@ from abc import abstractmethod from os import listdir, path from typing import Any -from pydantic import Field - from core.entities.provider_entities import ProviderConfig from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, @@ -17,10 +16,10 @@ from core.tools.utils.yaml_utils import load_yaml_file class BuiltinToolProviderController(ToolProviderController): - tools: list[BuiltinTool] = Field(default_factory=list) + tools: list[BuiltinTool] def __init__(self, **data: Any) -> None: - if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: + if self.provider_type == ToolProviderType.API: super().__init__(**data) return @@ -37,10 +36,12 @@ class BuiltinToolProviderController(ToolProviderController): for credential_name in provider_yaml["credentials_for_provider"]: provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name - super().__init__(**{ - 'identity': provider_yaml['identity'], - 'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {}, - }) + super().__init__( + entity=ToolProviderEntity( + identity=provider_yaml["identity"], + credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {}, + ), + ) def _get_builtin_tools(self) -> list[BuiltinTool]: """ @@ -51,7 +52,7 @@ class BuiltinToolProviderController(ToolProviderController): if self.tools: return self.tools - provider = self.identity.name + provider = self.entity.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools") # get all the yaml files in the tool path tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path))) @@ -62,30 +63,36 @@ class BuiltinToolProviderController(ToolProviderController): tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) # get tool class, import the module - assistant_tool_class = load_single_subclass_from_source( + assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source( module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}", script_path=path.join( - path.dirname(path.realpath(__file__)), - "builtin_tool", "providers", provider, "tools", f"{tool_name}.py" + path.dirname(path.realpath(__file__)), + "builtin_tool", + "providers", + provider, + "tools", + f"{tool_name}.py", ), parent_type=BuiltinTool, ) tool["identity"]["provider"] = provider - tools.append(assistant_tool_class(**tool)) + tools.append(assistant_tool_class( + entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""), + )) self.tools = tools return tools - + def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ returns the credentials schema of the provider :return: the credentials schema """ - if not self.credentials_schema: + if not self.entity.credentials_schema: return {} - return self.credentials_schema.copy() + return self.entity.credentials_schema.copy() def get_tools(self) -> list[BuiltinTool]: """ @@ -94,12 +101,12 @@ class BuiltinToolProviderController(ToolProviderController): :return: list of tools """ return self._get_builtin_tools() - + def get_tool(self, tool_name: str) -> BuiltinTool | None: """ returns the tool that the provider can provide """ - return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) @property def need_credentials(self) -> bool: @@ -108,7 +115,7 @@ class BuiltinToolProviderController(ToolProviderController): :return: whether the provider needs credentials """ - return self.credentials_schema is not None and len(self.credentials_schema) != 0 + return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 @property def provider_type(self) -> ToolProviderType: @@ -133,8 +140,8 @@ class BuiltinToolProviderController(ToolProviderController): """ returns the labels of the provider """ - return self.identity.tags or [] - + return self.entity.identity.tags or [] + def validate_credentials(self, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider diff --git a/api/core/tools/builtin_tool/providers/qrcode/qrcode.py b/api/core/tools/builtin_tool/providers/qrcode/qrcode.py index 542ee7b63e..e792382ee3 100644 --- a/api/core/tools/builtin_tool/providers/qrcode/qrcode.py +++ b/api/core/tools/builtin_tool/providers/qrcode/qrcode.py @@ -1,13 +1,8 @@ from typing import Any from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool -from core.tools.errors import ToolProviderCredentialValidationError class QRCodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - try: - QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"}) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + pass diff --git a/api/core/tools/builtin_tool/providers/time/time.py b/api/core/tools/builtin_tool/providers/time/time.py index 234ca9d9d6..d70fc22dfc 100644 --- a/api/core/tools/builtin_tool/providers/time/time.py +++ b/api/core/tools/builtin_tool/providers/time/time.py @@ -1,16 +1,8 @@ from typing import Any from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool -from core.tools.errors import ToolProviderCredentialValidationError class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - try: - CurrentTimeTool().invoke( - user_id="", - tool_parameters={}, - ) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + pass diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 243d99dee3..fe77f9ac77 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -32,9 +32,9 @@ class BuiltinTool(Tool): # invoke model return ModelInvocationUtils.invoke( user_id=user_id, - tenant_id=self.runtime.tenant_id or "", + tenant_id=self.runtime.tenant_id, tool_type="builtin", - tool_name=self.identity.name, + tool_name=self.entity.identity.name, prompt_messages=prompt_messages, ) @@ -79,6 +79,7 @@ class BuiltinTool(Tool): stop=[], ) + assert isinstance(summary.message.content, str) return summary.message.content lines = content.split("\n") diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 7ebaa6c5c6..32eda1d9bc 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -7,6 +7,8 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolProviderEntity, + ToolProviderIdentity, ToolProviderType, ) from extensions.ext_database import db @@ -18,6 +20,11 @@ class ApiToolProviderController(ToolProviderController): tenant_id: str tools: list[ApiTool] = Field(default_factory=list) + def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None: + super().__init__(entity) + self.provider_id = provider_id + self.tenant_id = tenant_id + @staticmethod def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": credentials_schema = { @@ -64,25 +71,23 @@ class ApiToolProviderController(ToolProviderController): } elif auth_type == ApiProviderAuthType.NONE: pass - else: - raise ValueError(f"invalid auth type {auth_type}") user = db_provider.user user_name = user.name if user else "" return ApiToolProviderController( - **{ - "identity": { - "author": user_name, - "name": db_provider.name, - "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, - "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, - "icon": db_provider.icon, - }, - "credentials_schema": credentials_schema, - "provider_id": db_provider.id or "", - "tenant_id": db_provider.tenant_id or "", - }, + entity=ToolProviderEntity( + identity=ToolProviderIdentity( + author=user_name, + name=db_provider.name, + label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), + description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), + icon=db_provider.icon, + ), + credentials_schema=credentials_schema, + ), + provider_id=db_provider.id or "", + tenant_id=db_provider.tenant_id or "", ) @property @@ -103,7 +108,7 @@ class ApiToolProviderController(ToolProviderController): "author": tool_bundle.author, "name": tool_bundle.operation_id, "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, - "icon": self.identity.icon, + "icon": self.entity.identity.icon, "provider": self.provider_id, }, "description": { @@ -141,7 +146,7 @@ class ApiToolProviderController(ToolProviderController): # get tenant api providers db_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) .all() ) @@ -149,7 +154,6 @@ class ApiToolProviderController(ToolProviderController): for db_provider in db_providers: for tool in db_provider.tools: assistant_tool = self._parse_tool_bundle(tool) - assistant_tool.is_team_authorization = True tools.append(assistant_tool) self.tools = tools @@ -166,7 +170,7 @@ class ApiToolProviderController(ToolProviderController): self.get_tools(self.tenant_id) for tool in self.tools: - if tool.identity.name == tool_name: + if tool.entity.identity.name == tool_name: return tool raise ValueError(f"tool {tool_name} not found") diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 9a728bb684..e36c97a2de 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -8,8 +8,9 @@ import httpx from core.helper import ssrf_proxy from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError API_TOOL_DEFAULT_TIMEOUT = ( @@ -25,7 +26,11 @@ class ApiTool(Tool): Api tool """ - def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": + def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime): + super().__init__(entity, runtime) + self.api_bundle = api_bundle + + def fork_tool_runtime(self, runtime: ToolRuntime): """ fork a new tool with meta data @@ -33,11 +38,9 @@ class ApiTool(Tool): :return: the new tool """ return self.__class__( - identity=self.identity.model_copy(), - parameters=self.parameters.copy() if self.parameters else [], - description=self.description.model_copy() if self.description else None, + entity=self.entity, api_bundle=self.api_bundle.model_copy(), - runtime=Tool.Runtime(**runtime), + runtime=runtime, ) def validate_credentials( @@ -62,7 +65,7 @@ class ApiTool(Tool): def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: if self.runtime == None: raise ToolProviderCredentialValidationError("runtime not initialized") - + headers = {} credentials = self.runtime.credentials or {} diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index cc84f6eaad..80334a274e 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,10 +1,11 @@ import base64 from enum import Enum -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union -from pydantic import BaseModel, Field, field_serializer, field_validator +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope +from core.entities.provider_entities import ProviderConfig from core.tools.entities.common_entities import I18nObject @@ -122,14 +123,14 @@ class ToolInvokeMessage(BaseModel): """ if not isinstance(value, dict | list | str | int | float | bool): raise ValueError("Only basic types and lists are allowed.") - + # if stream is true, the value must be a string - if values.get('stream'): + if values.get("stream"): if not isinstance(value, str): raise ValueError("When 'stream' is True, 'variable_value' must be a string.") return value - + @field_validator("variable_name", mode="before") @classmethod def transform_variable_name(cls, value) -> str: @@ -158,22 +159,20 @@ class ToolInvokeMessage(BaseModel): meta: dict[str, Any] | None = None save_as: str = "" - @field_validator('message', mode='before') + @field_validator("message", mode="before") @classmethod def decode_blob_message(cls, v): - if isinstance(v, dict) and 'blob' in v: + if isinstance(v, dict) and "blob" in v: try: - v['blob'] = base64.b64decode(v['blob']) + v["blob"] = base64.b64decode(v["blob"]) except Exception: pass return v - @field_serializer('message') + @field_serializer("message") def serialize_message(self, v): if isinstance(v, self.BlobMessage): - return { - 'blob': base64.b64encode(v.blob).decode('utf-8') - } + return {"blob": base64.b64encode(v.blob).decode("utf-8")} return v @@ -252,9 +251,9 @@ class ToolParameter(BaseModel): option_objs = [] return cls( name=name, - label=I18nObject(en_US='', zh_Hans=''), + label=I18nObject(en_US="", zh_Hans=""), placeholder=None, - human_description=I18nObject(en_US='', zh_Hans=''), + human_description=I18nObject(en_US="", zh_Hans=""), type=type, form=cls.ToolParameterForm.LLM, llm_description=llm_description, @@ -275,6 +274,11 @@ class ToolProviderIdentity(BaseModel): ) +class ToolProviderEntity(BaseModel): + identity: ToolProviderIdentity + credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) + + class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") @@ -288,131 +292,6 @@ class ToolIdentity(BaseModel): icon: Optional[str] = None -class ToolRuntimeVariableType(Enum): - TEXT = "text" - IMAGE = "image" - - -class ToolRuntimeVariable(BaseModel): - type: ToolRuntimeVariableType = Field(..., description="The type of the variable") - name: str = Field(..., description="The name of the variable") - position: int = Field(..., description="The position of the variable") - tool_name: str = Field(..., description="The name of the tool") - - -class ToolRuntimeTextVariable(ToolRuntimeVariable): - value: str = Field(..., description="The value of the variable") - - -class ToolRuntimeImageVariable(ToolRuntimeVariable): - value: str = Field(..., description="The path of the image") - - -class ToolRuntimeVariablePool(BaseModel): - conversation_id: str = Field(..., description="The conversation id") - user_id: str = Field(..., description="The user id") - tenant_id: str = Field(..., description="The tenant id of assistant") - - pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables") - - def __init__(self, **data: Any): - pool = data.get("pool", []) - # convert pool into correct type - for index, variable in enumerate(pool): - if variable["type"] == ToolRuntimeVariableType.TEXT.value: - pool[index] = ToolRuntimeTextVariable(**variable) - elif variable["type"] == ToolRuntimeVariableType.IMAGE.value: - pool[index] = ToolRuntimeImageVariable(**variable) - super().__init__(**data) - - def dict(self) -> dict: - return { - "conversation_id": self.conversation_id, - "user_id": self.user_id, - "tenant_id": self.tenant_id, - "pool": [variable.model_dump() for variable in self.pool], - } - - def set_text(self, tool_name: str, name: str, value: str) -> None: - """ - set a text variable - """ - for variable in self.pool: - if variable.name == name: - if variable.type == ToolRuntimeVariableType.TEXT: - variable = cast(ToolRuntimeTextVariable, variable) - variable.value = value - return - - variable = ToolRuntimeTextVariable( - type=ToolRuntimeVariableType.TEXT, - name=name, - position=len(self.pool), - tool_name=tool_name, - value=value, - ) - - self.pool.append(variable) - - def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None: - """ - set an image variable - - :param tool_name: the name of the tool - :param value: the id of the file - """ - # check how many image variables are there - image_variable_count = 0 - for variable in self.pool: - if variable.type == ToolRuntimeVariableType.IMAGE: - image_variable_count += 1 - - if name is None: - name = f"file_{image_variable_count}" - - for variable in self.pool: - if variable.name == name: - if variable.type == ToolRuntimeVariableType.IMAGE: - variable = cast(ToolRuntimeImageVariable, variable) - variable.value = value - return - - variable = ToolRuntimeImageVariable( - type=ToolRuntimeVariableType.IMAGE, - name=name, - position=len(self.pool), - tool_name=tool_name, - value=value, - ) - - self.pool.append(variable) - - -class ModelToolPropertyKey(Enum): - IMAGE_PARAMETER_NAME = "image_parameter_name" - - -class ModelToolConfiguration(BaseModel): - """ - Model tool configuration - """ - - type: str = Field(..., description="The type of the model tool") - model: str = Field(..., description="The model") - label: I18nObject = Field(..., description="The label of the model tool") - properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") - - -class ModelToolProviderConfiguration(BaseModel): - """ - Model tool provider configuration - """ - - provider: str = Field(..., description="The provider of the model tool") - models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") - label: I18nObject = Field(..., description="The label of the model tool") - - class WorkflowToolParameterConfiguration(BaseModel): """ Workflow tool configuration @@ -471,3 +350,17 @@ class ToolInvokeFrom(Enum): WORKFLOW = "workflow" AGENT = "agent" + + +class ToolEntity(BaseModel): + identity: ToolIdentity + parameters: list[ToolParameter] = Field(default_factory=list) + description: Optional[ToolDescription] = None + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + @field_validator("parameters", mode="before") + @classmethod + def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: + return v or [] diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index cb2f6a899c..d8889917f0 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -65,7 +65,7 @@ class ToolEngine: # invoke the tool try: # hit the callback handler - agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) + agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) messages = ToolEngine._invoke(tool, tool_parameters, user_id) invocation_meta_dict: dict[str, ToolInvokeMeta] = {} @@ -99,7 +99,7 @@ class ToolEngine: # hit the callback handler agent_tool_callback.on_tool_end( - tool_name=tool.identity.name, + tool_name=tool.entity.identity.name, tool_inputs=tool_parameters, tool_outputs=plain_text, message_id=message.id, @@ -112,7 +112,7 @@ class ToolEngine: error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: - error_response = f"there is not a tool named {tool.identity.name}" + error_response = f"there is not a tool named {tool.entity.identity.name}" agent_tool_callback.on_tool_error(e) except ToolParameterValidationError as e: error_response = f"tool parameters validation error: {e}, please check your tool parameters" @@ -145,7 +145,7 @@ class ToolEngine: """ try: # hit the callback handler - workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) + workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 @@ -158,7 +158,7 @@ class ToolEngine: # hit the callback handler workflow_tool_callback.on_tool_end( - tool_name=tool.identity.name, + tool_name=tool.entity.identity.name, tool_inputs=tool_parameters, tool_outputs=response, ) @@ -177,13 +177,13 @@ class ToolEngine: """ try: # hit the callback handler - callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) + callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) response = tool.invoke(user_id, tool_parameters) # hit the callback handler callback.on_tool_end( - tool_name=tool.identity.name, + tool_name=tool.entity.identity.name, tool_inputs=tool_parameters, tool_outputs=response, ) @@ -208,11 +208,11 @@ class ToolEngine: time_cost=0.0, error=None, tool_config={ - "tool_name": tool.identity.name, - "tool_provider": tool.identity.provider, + "tool_name": tool.entity.identity.name, + "tool_provider": tool.entity.identity.provider, "tool_provider_type": tool.tool_provider_type().value, "tool_parameters": deepcopy(tool.runtime.runtime_parameters), - "tool_icon": tool.identity.icon, + "tool_icon": tool.entity.identity.icon, }, ) try: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0cfcb6d9b9..c37ee730c8 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -6,6 +6,8 @@ from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Union, cast +from core.tools.__base.tool_runtime import ToolRuntime + if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity @@ -105,12 +107,12 @@ class ToolManager: return cast( BuiltinTool, builtin_tool.fork_tool_runtime( - runtime={ - "tenant_id": tenant_id, - "credentials": {}, - "invoke_from": invoke_from, - "tool_invoke_from": tool_invoke_from, - } + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ), ) @@ -134,7 +136,7 @@ class ToolManager: tenant_id=tenant_id, config=controller.get_credentials_schema(), provider_type=controller.provider_type.value, - provider_identity=controller.identity.name, + provider_identity=controller.entity.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) @@ -142,13 +144,13 @@ class ToolManager: return cast( BuiltinTool, builtin_tool.fork_tool_runtime( - runtime={ - "tenant_id": tenant_id, - "credentials": decrypted_credentials, - "runtime_parameters": {}, - "invoke_from": invoke_from, - "tool_invoke_from": tool_invoke_from, - } + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=decrypted_credentials, + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ), ) @@ -163,19 +165,19 @@ class ToolManager: tenant_id=tenant_id, config=api_provider.get_credentials_schema(), provider_type=api_provider.provider_type.value, - provider_identity=api_provider.identity.name, + provider_identity=api_provider.entity.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) return cast( ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime( - runtime={ - "tenant_id": tenant_id, - "credentials": decrypted_credentials, - "invoke_from": invoke_from, - "tool_invoke_from": tool_invoke_from, - } + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=decrypted_credentials, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ), ) elif provider_type == ToolProviderType.WORKFLOW: @@ -193,12 +195,12 @@ class ToolManager: return cast( WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime={ - "tenant_id": tenant_id, - "credentials": {}, - "invoke_from": invoke_from, - "tool_invoke_from": tool_invoke_from, - } + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ), ) elif provider_type == ToolProviderType.APP: @@ -336,7 +338,7 @@ class ToolManager: "providers", provider, "_assets", - provider_controller.identity.icon, + provider_controller.entity.identity.icon, ) # check if the icon exists if not path.exists(absolute_path): @@ -389,9 +391,9 @@ class ToolManager: parent_type=BuiltinToolProviderController, ) provider: BuiltinToolProviderController = provider_class() - cls._builtin_providers[provider.identity.name] = provider + cls._builtin_providers[provider.entity.identity.name] = provider for tool in provider.get_tools(): - cls._builtin_tools_labels[tool.identity.name] = tool.identity.label + cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label yield provider except Exception as e: @@ -466,11 +468,11 @@ class ToolManager: user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, - db_provider=find_db_builtin_provider(provider.identity.name), + db_provider=find_db_builtin_provider(provider.entity.identity.name), decrypt_credentials=False, ) - result_providers[provider.identity.name] = user_provider + result_providers[provider.entity.identity.name] = user_provider # get db api providers @@ -589,7 +591,7 @@ class ToolManager: tenant_id=tenant_id, config=controller.get_credentials_schema(), provider_type=controller.provider_type.value, - provider_identity=controller.identity.name, + provider_identity=controller.entity.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 0ab2b0021a..9f685a89b6 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -59,12 +59,11 @@ class ProviderConfigEncrypter(BaseModel): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: if len(data[field_name]) > 6: - data[field_name] = \ - data[field_name][:2] + \ - '*' * (len(data[field_name]) - 4) + \ - data[field_name][-2:] + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) else: - data[field_name] = '*' * len(data[field_name]) + data[field_name] = "*" * len(data[field_name]) return data @@ -75,9 +74,9 @@ class ProviderConfigEncrypter(BaseModel): return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_type}.{self.provider_identity}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() if cached_credentials: @@ -98,14 +97,14 @@ class ProviderConfigEncrypter(BaseModel): def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_type}.{self.provider_identity}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() -class ToolParameterConfigurationManager(BaseModel): +class ToolParameterConfigurationManager: """ Tool parameter configuration manager """ @@ -116,6 +115,15 @@ class ToolParameterConfigurationManager(BaseModel): provider_type: ToolProviderType identity_id: str + def __init__( + self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str + ) -> None: + self.tenant_id = tenant_id + self.tool_runtime = tool_runtime + self.provider_name = provider_name + self.provider_type = provider_type + self.identity_id = identity_id + def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: """ deep copy parameters @@ -127,7 +135,7 @@ class ToolParameterConfigurationManager(BaseModel): merge parameters """ # get tool parameters - tool_parameters = self.tool_runtime.parameters or [] + tool_parameters = self.tool_runtime.entity.parameters or [] # get tool runtime parameters runtime_parameters = self.tool_runtime.get_runtime_parameters() or [] # override parameters @@ -203,8 +211,8 @@ class ToolParameterConfigurationManager(BaseModel): """ cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type.value}.{self.provider_name}', - tool_name=self.tool_runtime.identity.name, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, cache_type=ToolParameterCacheType.PARAMETER, identity_id=self.identity_id, ) @@ -236,8 +244,8 @@ class ToolParameterConfigurationManager(BaseModel): def delete_tool_parameters_cache(self): cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type.value}.{self.provider_name}', - tool_name=self.tool_runtime.identity.name, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, cache_type=ToolParameterCacheType.PARAMETER, identity_id=self.identity_id, ) diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index 7b61222722..136491005c 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -6,9 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ToolDescription, + ToolEntity, ToolIdentity, ToolInvokeMessage, ToolParameter, @@ -20,11 +22,15 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas class DatasetRetrieverTool(Tool): retrieval_tool: DatasetRetrieverBaseTool + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: + super().__init__(entity, runtime) + self.retrieval_tool = retrieval_tool + @staticmethod def get_dataset_tools( tenant_id: str, dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, + retrieve_config: DatasetRetrieveConfigEntity | None, return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, @@ -54,7 +60,7 @@ class DatasetRetrieverTool(Tool): ) if retrieval_tools is None or len(retrieval_tools) == 0: return [] - + # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode @@ -63,13 +69,14 @@ class DatasetRetrieverTool(Tool): for retrieval_tool in retrieval_tools: tool = DatasetRetrieverTool( retrieval_tool=retrieval_tool, - identity=ToolIdentity( - provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + entity=ToolEntity( + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), + parameters=[], + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), ), - parameters=[], - is_team_authorization=True, - description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), - runtime=DatasetRetrieverTool.Runtime(), + runtime=ToolRuntime(tenant_id=tenant_id), ) tools.append(tool) @@ -99,7 +106,7 @@ class DatasetRetrieverTool(Tool): """ query = tool_parameters.get("query") if not query: - yield self.create_text_message(text='please input query') + yield self.create_text_message(text="please input query") else: # invoke dataset retriever tool result = self.retrieval_tool._run(query=query) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index cab5f84506..2d0d33ffd9 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -6,9 +6,11 @@ from pydantic import Field from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ToolDescription, + ToolEntity, ToolIdentity, ToolParameter, ToolParameterOption, @@ -63,7 +65,7 @@ class WorkflowToolProviderController(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - + def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ get db provider tool @@ -140,19 +142,23 @@ class WorkflowToolProviderController(ToolProviderController): raise ValueError("variable not found") return WorkflowTool( - identity=ToolIdentity( - author=user.name if user else "", - name=db_provider.name, - label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), - provider=self.provider_id, - icon=db_provider.icon, + entity=ToolEntity( + identity=ToolIdentity( + author=user.name if user else "", + name=db_provider.name, + label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), + provider=self.provider_id, + icon=db_provider.icon, + ), + description=ToolDescription( + human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), + llm=db_provider.description, + ), + parameters=workflow_tool_parameters, ), - description=ToolDescription( - human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), - llm=db_provider.description, + runtime=ToolRuntime( + tenant_id=db_provider.tenant_id, ), - parameters=workflow_tool_parameters, - is_team_authorization=True, workflow_app_id=app.id, workflow_entities={ "app": app, @@ -201,7 +207,7 @@ class WorkflowToolProviderController(ToolProviderController): return None for tool in self.tools: - if tool.identity.name == tool_name: + if tool.entity.identity.name == tool_name: return tool return None diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 72aae2796c..e1fc5140d0 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,12 +1,12 @@ import json import logging from collections.abc import Generator -from copy import deepcopy from typing import Any, Optional, Union from core.file.file_obj import FileTransferMethod, FileVar from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType from extensions.ext_database import db from models.account import Account from models.model import App, EndUser @@ -28,6 +28,26 @@ class WorkflowTool(Tool): Workflow tool. """ + def __init__( + self, + workflow_app_id: str, + version: str, + workflow_entities: dict[str, Any], + workflow_call_depth: int, + entity: ToolEntity, + runtime: ToolRuntime, + label: str = "Workflow", + thread_pool_id: Optional[str] = None, + ): + self.workflow_app_id = workflow_app_id + self.version = version + self.workflow_entities = workflow_entities + self.workflow_call_depth = workflow_call_depth + self.thread_pool_id = thread_pool_id + self.label = label + + super().__init__(entity=entity, runtime=runtime) + def tool_provider_type(self) -> ToolProviderType: """ get the tool provider type @@ -94,7 +114,7 @@ class WorkflowTool(Tool): return user - def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a new tool with meta data @@ -102,10 +122,8 @@ class WorkflowTool(Tool): :return: the new tool """ return self.__class__( - identity=deepcopy(self.identity), - parameters=deepcopy(self.parameters), - description=deepcopy(self.description), - runtime=Tool.Runtime(**runtime), + entity=self.entity.model_copy(), + runtime=runtime, workflow_app_id=self.workflow_app_id, workflow_entities=self.workflow_entities, workflow_call_depth=self.workflow_call_depth, diff --git a/api/core/tools/workflow_as_tool/workflow_tool_provider.py b/api/core/tools/workflow_as_tool/workflow_tool_provider.py deleted file mode 100644 index cab5f84506..0000000000 --- a/api/core/tools/workflow_as_tool/workflow_tool_provider.py +++ /dev/null @@ -1,207 +0,0 @@ -from collections.abc import Mapping -from typing import Optional - -from pydantic import Field - -from core.app.app_config.entities import VariableEntity, VariableEntityType -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.tools.__base.tool_provider import ToolProviderController -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ToolDescription, - ToolIdentity, - ToolParameter, - ToolParameterOption, - ToolProviderType, -) -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from core.tools.workflow_as_tool.tool import WorkflowTool -from extensions.ext_database import db -from models.model import App, AppMode -from models.tools import WorkflowToolProvider -from models.workflow import Workflow - -VARIABLE_TO_PARAMETER_TYPE_MAPPING = { - VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING, - VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, - VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, - VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, -} - - -class WorkflowToolProviderController(ToolProviderController): - provider_id: str - tools: list[WorkflowTool] = Field(default_factory=list) - - @classmethod - def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": - app = db_provider.app - - if not app: - raise ValueError("app not found") - - controller = WorkflowToolProviderController( - **{ - "identity": { - "author": db_provider.user.name if db_provider.user_id and db_provider.user else "", - "name": db_provider.label, - "label": {"en_US": db_provider.label, "zh_Hans": db_provider.label}, - "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, - "icon": db_provider.icon, - }, - "credentials_schema": {}, - "provider_id": db_provider.id or "", - } - ) - - # init tools - - controller.tools = [controller._get_db_provider_tool(db_provider, app)] - - return controller - - @property - def provider_type(self) -> ToolProviderType: - return ToolProviderType.WORKFLOW - - def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: - """ - get db provider tool - :param db_provider: the db provider - :param app: the app - :return: the tool - """ - workflow: Workflow | None = db.session.query(Workflow).filter( - Workflow.app_id == db_provider.app_id, - Workflow.version == db_provider.version - ).first() - - if not workflow: - raise ValueError("workflow not found") - - # fetch start node - graph: Mapping = workflow.graph_dict - features_dict: Mapping = workflow.features_dict - features = WorkflowAppConfigManager.convert_features( - config_dict=features_dict, - app_mode=AppMode.WORKFLOW - ) - - parameters = db_provider.parameter_configurations - variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) - - def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: - return next(filter(lambda x: x.variable == variable_name, variables), None) - - user = db_provider.user - - workflow_tool_parameters = [] - for parameter in parameters: - variable = fetch_workflow_variable(parameter.name) - if variable: - parameter_type = None - options = [] - if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: - raise ValueError(f"unsupported variable type {variable.type}") - parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] - - if variable.type == VariableEntityType.SELECT and variable.options: - options = [ - ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) - for option in variable.options - ] - - workflow_tool_parameters.append( - ToolParameter( - name=parameter.name, - label=I18nObject(en_US=variable.label, zh_Hans=variable.label), - human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), - type=parameter_type, - form=parameter.form, - llm_description=parameter.description, - required=variable.required, - options=options, - default=variable.default, - ) - ) - elif features.file_upload: - workflow_tool_parameters.append( - ToolParameter( - name=parameter.name, - label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), - human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), - type=ToolParameter.ToolParameterType.FILE, - llm_description=parameter.description, - required=False, - form=parameter.form, - ) - ) - else: - raise ValueError("variable not found") - - return WorkflowTool( - identity=ToolIdentity( - author=user.name if user else "", - name=db_provider.name, - label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), - provider=self.provider_id, - icon=db_provider.icon, - ), - description=ToolDescription( - human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), - llm=db_provider.description, - ), - parameters=workflow_tool_parameters, - is_team_authorization=True, - workflow_app_id=app.id, - workflow_entities={ - "app": app, - "workflow": workflow, - }, - version=db_provider.version, - workflow_call_depth=0, - label=db_provider.label, - ) - - def get_tools(self, tenant_id: str) -> list[WorkflowTool]: - """ - fetch tools from database - - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools - """ - if self.tools is not None: - return self.tools - - db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, - ).first() - - if not db_providers: - return [] - - app = db_providers.app - if not app: - raise ValueError("can not read app of workflow") - - self.tools = [self._get_db_provider_tool(db_providers, app)] - - return self.tools - - def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: - """ - get tool by name - - :param tool_name: the name of the tool - :return: the tool - """ - if self.tools is None: - return None - - for tool in self.tools: - if tool.identity.name == tool_name: - return tool - - return None diff --git a/api/models/model.py b/api/models/model.py index f4e7686849..26b42963cc 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1304,7 +1304,7 @@ class MessageChain(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class MessageAgentThought(db.Model): +class MessageAgentThought(Base): __tablename__ = "message_agent_thoughts" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 11aa3ba529..e39c2b8a5b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -5,6 +5,7 @@ from httpx import get from core.entities.provider_entities import ProviderConfig from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.common_entities import I18nObject @@ -160,7 +161,7 @@ class ApiToolManageService: tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name + provider_identity=provider_controller.entity.identity.name ) encrypted_credentials = tool_configuration.encrypt(credentials) @@ -222,6 +223,7 @@ class ApiToolManageService: return [ ToolTransformService.tool_to_user_tool( tool_bundle, + tenant_id=tenant_id, labels=labels, ) for tool_bundle in provider.tools @@ -291,7 +293,7 @@ class ApiToolManageService: tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name + provider_identity=provider_controller.entity.identity.name ) original_credentials = tool_configuration.decrypt(provider.credentials) @@ -410,7 +412,7 @@ class ApiToolManageService: tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name + provider_identity=provider_controller.entity.identity.name ) decrypted_credentials = tool_configuration.decrypt(credentials) # check if the credential has changed, save the original credential @@ -424,10 +426,10 @@ class ApiToolManageService: # get tool tool = provider_controller.get_tool(tool_name) tool = tool.fork_tool_runtime( - runtime={ - "credentials": credentials, - "tenant_id": tenant_id, - } + runtime=ToolRuntime( + credentials=credentials, + tenant_id=tenant_id, + ) ) result = tool.validate_credentials(credentials, parameters) except Exception as e: diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6db8718b6b..83b363bb58 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -32,7 +32,7 @@ class BuiltinToolManageService: tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name, + provider_identity=provider_controller.entity.identity.name, ) # check if user has added the provider builtin_provider: BuiltinToolProvider | None = ( @@ -71,7 +71,7 @@ class BuiltinToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) + return jsonable_encoder([v for _, v in (provider.entity.credentials_schema or {}).items()]) @staticmethod def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): @@ -97,7 +97,7 @@ class BuiltinToolManageService: tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name, + provider_identity=provider_controller.entity.identity.name, ) # get original credentials if exists @@ -159,7 +159,7 @@ class BuiltinToolManageService: tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name, + provider_identity=provider_controller.entity.identity.name, ) credentials = tool_configuration.decrypt(provider_obj.credentials) credentials = tool_configuration.mask_tool_credentials(credentials) @@ -191,7 +191,7 @@ class BuiltinToolManageService: tenant_id=tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name, + provider_identity=provider_controller.entity.identity.name, ) tool_configuration.delete_tool_credentials_cache() @@ -241,7 +241,7 @@ class BuiltinToolManageService: # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, - db_provider=find_provider(provider_controller.identity.name), + db_provider=find_provider(provider_controller.entity.identity.name), decrypt_credentials=True, ) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b7488621c6..d4f132e902 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -4,6 +4,7 @@ from typing import Optional, Union from configs import dify_config from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.entities.api_entities import UserTool, UserToolProvider @@ -69,19 +70,19 @@ class ToolTransformService: convert provider controller to user provider """ result = UserToolProvider( - id=provider_controller.identity.name, - author=provider_controller.identity.author, - name=provider_controller.identity.name, + id=provider_controller.entity.identity.name, + author=provider_controller.entity.identity.author, + name=provider_controller.entity.identity.name, description=I18nObject( - en_US=provider_controller.identity.description.en_US, - zh_Hans=provider_controller.identity.description.zh_Hans, - pt_BR=provider_controller.identity.description.pt_BR, + en_US=provider_controller.entity.identity.description.en_US, + zh_Hans=provider_controller.entity.identity.description.zh_Hans, + pt_BR=provider_controller.entity.identity.description.pt_BR, ), - icon=provider_controller.identity.icon, + icon=provider_controller.entity.identity.icon, label=I18nObject( - en_US=provider_controller.identity.label.en_US, - zh_Hans=provider_controller.identity.label.zh_Hans, - pt_BR=provider_controller.identity.label.pt_BR, + en_US=provider_controller.entity.identity.label.en_US, + zh_Hans=provider_controller.entity.identity.label.zh_Hans, + pt_BR=provider_controller.entity.identity.label.pt_BR, ), type=ToolProviderType.BUILT_IN, masked_credentials={}, @@ -111,7 +112,7 @@ class ToolTransformService: tenant_id=db_provider.tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name + provider_identity=provider_controller.entity.identity.name, ) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt(data=credentials) @@ -155,16 +156,16 @@ class ToolTransformService: """ return UserToolProvider( id=provider_controller.provider_id, - author=provider_controller.identity.author, - name=provider_controller.identity.name, + author=provider_controller.entity.identity.author, + name=provider_controller.entity.identity.name, description=I18nObject( - en_US=provider_controller.identity.description.en_US, - zh_Hans=provider_controller.identity.description.zh_Hans, + en_US=provider_controller.entity.identity.description.en_US, + zh_Hans=provider_controller.entity.identity.description.zh_Hans, ), - icon=provider_controller.identity.icon, + icon=provider_controller.entity.identity.icon, label=I18nObject( - en_US=provider_controller.identity.label.en_US, - zh_Hans=provider_controller.identity.label.zh_Hans, + en_US=provider_controller.entity.identity.label.en_US, + zh_Hans=provider_controller.entity.identity.label.zh_Hans, ), type=ToolProviderType.WORKFLOW, masked_credentials={}, @@ -189,7 +190,7 @@ class ToolTransformService: user = db_provider.user if not user: raise ValueError("user not found") - + username = user.name except Exception as e: logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") @@ -222,7 +223,7 @@ class ToolTransformService: tenant_id=db_provider.tenant_id, config=provider_controller.get_credentials_schema(), provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.identity.name + provider_identity=provider_controller.entity.identity.name, ) # decrypt the credentials and mask the credentials @@ -236,8 +237,8 @@ class ToolTransformService: @staticmethod def tool_to_user_tool( tool: Union[ApiToolBundle, WorkflowTool, Tool], + tenant_id: str, credentials: dict | None = None, - tenant_id: str | None = None, labels: list[str] | None = None, ) -> UserTool: """ @@ -246,14 +247,14 @@ class ToolTransformService: if isinstance(tool, Tool): # fork tool runtime tool = tool.fork_tool_runtime( - runtime={ - "credentials": credentials, - "tenant_id": tenant_id, - } + runtime=ToolRuntime( + credentials=credentials, + tenant_id=tenant_id, + ) ) # get tool parameters - parameters = tool.parameters or [] + parameters = tool.entity.parameters or [] # get tool runtime parameters runtime_parameters = tool.get_runtime_parameters() or [] # override parameters @@ -270,10 +271,10 @@ class ToolTransformService: current_parameters.append(runtime_parameter) return UserTool( - author=tool.identity.author, - name=tool.identity.name, - label=tool.identity.label, - description=tool.description.human if tool.description else I18nObject(en_US=''), + author=tool.entity.identity.author, + name=tool.entity.identity.name, + label=tool.entity.identity.label, + description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""), parameters=current_parameters, labels=labels or [], ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 60e26aa282..58bf7946bf 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -211,7 +211,9 @@ class WorkflowToolManageService: ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) + tool=tool.get_tools(user_id, tenant_id)[0], + labels=labels.get(tool.provider_id, []), + tenant_id=tenant_id, ) ] result.append(user_tool_provider) @@ -248,7 +250,7 @@ class WorkflowToolManageService: .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) - return cls._get_workflow_tool(db_tool) + return cls._get_workflow_tool(tenant_id, db_tool) @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: @@ -264,10 +266,10 @@ class WorkflowToolManageService: .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .first() ) - return cls._get_workflow_tool(db_tool) + return cls._get_workflow_tool(tenant_id, db_tool) @classmethod - def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None): + def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): """ Get a workflow tool. :db_tool: the database tool @@ -298,7 +300,9 @@ class WorkflowToolManageService: "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + tool=tool.get_tools(db_tool.tenant_id)[0], + labels=ToolLabelManager.get_tool_labels(tool), + tenant_id=tenant_id, ), "synced": workflow.version == db_tool.version, "privacy_policy": db_tool.privacy_policy, @@ -326,6 +330,8 @@ class WorkflowToolManageService: return [ ToolTransformService.tool_to_user_tool( - tool=tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + tool=tool.get_tools(db_tool.tenant_id)[0], + labels=ToolLabelManager.get_tool_labels(tool), + tenant_id=tenant_id, ) ] diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index e4798e02c3..b945d3fef7 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -1,5 +1,9 @@ from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity from tests.integration_tests.tools.__mock.http import setup_http_mock tool_bundle = { @@ -29,7 +33,13 @@ parameters = { def test_api_tool(setup_http_mock): - tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"})) + tool = ApiTool( + entity=ToolEntity( + identity=ToolIdentity(provider="", author="", name="", label=I18nObject()), + ), + api_bundle=ApiToolBundle(**tool_bundle), + runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}), + ) headers = tool.assembling_request(parameters) response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)