refactor: tool
This commit is contained in:
parent
3c1d32e3ac
commit
91cb80f795
@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
@ -23,6 +22,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
@ -31,18 +31,15 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -59,11 +56,9 @@ class BaseAgentRunner(AppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
model_instance: ModelInstance,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None,
|
||||
) -> None:
|
||||
"""
|
||||
Agent runner
|
||||
@ -93,8 +88,6 @@ class BaseAgentRunner(AppRunner):
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
@ -162,11 +155,10 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
assert tool_entity.entity.description
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.description.llm,
|
||||
description=tool_entity.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -201,9 +193,11 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
"""
|
||||
assert tool.entity.description
|
||||
|
||||
prompt_tool = PromptMessageTool(
|
||||
name=tool.identity.name,
|
||||
description=tool.description.llm,
|
||||
name=tool.entity.identity.name,
|
||||
description=tool.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -232,7 +226,7 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_instances = {}
|
||||
prompt_messages_tools = []
|
||||
|
||||
for tool in self.app_config.agent.tools if self.app_config.agent else []:
|
||||
for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
@ -249,7 +243,7 @@ class BaseAgentRunner(AppRunner):
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
@ -328,25 +322,29 @@ class BaseAgentRunner(AppRunner):
|
||||
def save_agent_thought(
|
||||
self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
answer: str,
|
||||
tool_name: str | None,
|
||||
tool_input: Union[str, dict, None],
|
||||
thought: str | None,
|
||||
observation: Union[str, dict, None],
|
||||
tool_invoke_meta: Union[str, dict, None],
|
||||
answer: str | None,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None,
|
||||
) -> MessageAgentThought:
|
||||
llm_usage: LLMUsage | None = None,
|
||||
):
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
updated_agent_thought = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
)
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
|
||||
if thought is not None:
|
||||
agent_thought.thought = thought
|
||||
updated_agent_thought.thought = thought
|
||||
|
||||
if tool_name is not None:
|
||||
agent_thought.tool = tool_name
|
||||
updated_agent_thought.tool = tool_name
|
||||
|
||||
if tool_input is not None:
|
||||
if isinstance(tool_input, dict):
|
||||
@ -355,7 +353,7 @@ class BaseAgentRunner(AppRunner):
|
||||
except Exception as e:
|
||||
tool_input = json.dumps(tool_input)
|
||||
|
||||
agent_thought.tool_input = tool_input
|
||||
updated_agent_thought.tool_input = tool_input
|
||||
|
||||
if observation is not None:
|
||||
if isinstance(observation, dict):
|
||||
@ -364,27 +362,27 @@ class BaseAgentRunner(AppRunner):
|
||||
except Exception as e:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
agent_thought.observation = observation
|
||||
updated_agent_thought.observation = observation
|
||||
|
||||
if answer is not None:
|
||||
agent_thought.answer = answer
|
||||
updated_agent_thought.answer = answer
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
updated_agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
if llm_usage:
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
agent_thought.answer_token = llm_usage.completion_tokens
|
||||
agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
updated_agent_thought.message_token = llm_usage.prompt_tokens
|
||||
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
updated_agent_thought.answer_token = llm_usage.completion_tokens
|
||||
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
updated_agent_thought.tokens = llm_usage.total_tokens
|
||||
updated_agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(";") if agent_thought.tool else []
|
||||
labels = updated_agent_thought.tool_labels or {}
|
||||
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
@ -395,7 +393,7 @@ class BaseAgentRunner(AppRunner):
|
||||
else:
|
||||
labels[tool] = {"en_US": tool, "zh_Hans": tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
updated_agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
if tool_invoke_meta is not None:
|
||||
if isinstance(tool_invoke_meta, dict):
|
||||
@ -404,28 +402,11 @@ class BaseAgentRunner(AppRunner):
|
||||
except Exception as e:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta)
|
||||
|
||||
agent_thought.tool_meta_str = tool_invoke_meta
|
||||
updated_agent_thought.tool_meta_str = tool_invoke_meta
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
db_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize agent history
|
||||
@ -515,6 +496,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
files = message.message_files
|
||||
if files:
|
||||
assert message.app_model_config
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
|
||||
if file_extra_config:
|
||||
@ -525,7 +507,7 @@ class BaseAgentRunner(AppRunner):
|
||||
if not file_objs:
|
||||
return UserPromptMessage(content=message.query)
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
|
||||
for file_obj in file_objs:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@ -26,11 +27,11 @@ from models.model import Message
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage] = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] = None
|
||||
_instruction: str = None
|
||||
_query: str = None
|
||||
_prompt_messages_tools: list[PromptMessage] = None
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
|
||||
def run(
|
||||
self,
|
||||
@ -41,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
@ -53,9 +55,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
assert app_config.prompt_template.simple_prompt_template
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
@ -63,13 +67,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
function_call_state = True
|
||||
llm_usage = {"usage": None}
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@ -115,10 +120,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# check llm result
|
||||
if not chunks:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
@ -139,11 +140,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
@ -152,6 +156,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
@ -168,7 +173,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought,
|
||||
observation="",
|
||||
answer=scratchpad.agent_response,
|
||||
answer=scratchpad.agent_response or "",
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
@ -248,7 +253,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
messages_ids=[],
|
||||
)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -266,7 +270,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
tool_instances: Mapping[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
@ -307,15 +311,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
message_file_ids.append(message_file_id.id)
|
||||
|
||||
return tool_invoke_response, tool_invoke_meta
|
||||
|
||||
@ -369,18 +370,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] = None
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
current_scratchpad: AgentScratchpadUnit = None
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
@ -400,6 +402,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
|
@ -4,6 +4,7 @@ from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
@ -16,6 +17,9 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
@ -27,12 +31,12 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
return SystemPromptMessage(content=system_prompt)
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
@ -57,8 +61,10 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
@ -11,6 +11,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
@ -38,18 +39,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage = {"usage": None}
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@ -99,7 +102,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if self.stream_tool_call:
|
||||
if isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
@ -133,7 +136,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
result: LLMResult = chunks
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
@ -236,15 +239,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
)
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
message_file_ids.append(message_file_id.id)
|
||||
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
@ -290,7 +290,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -321,9 +320,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(
|
||||
self, llm_result_chunk: LLMResultChunk
|
||||
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@ -346,7 +343,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
@ -370,7 +367,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(
|
||||
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
|
||||
self, prompt_template: str, prompt_messages: list[PromptMessage]
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
@ -385,12 +382,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
|
@ -16,10 +16,8 @@ from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationError
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -174,14 +172,6 @@ class AgentChatAppRunner(AppRunner):
|
||||
|
||||
agent_entity = app_config.agent
|
||||
|
||||
# load tool variables
|
||||
tool_conversation_variables = self._load_tool_variables(
|
||||
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
|
||||
)
|
||||
|
||||
# convert db variables to tool variables
|
||||
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
|
||||
|
||||
# init model instance
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
@ -234,8 +224,6 @@ class AgentChatAppRunner(AppRunner):
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
@ -253,50 +241,6 @@ class AgentChatAppRunner(AppRunner):
|
||||
agent=True,
|
||||
)
|
||||
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if tool_variables:
|
||||
# save tool variables to session, so that we can update it later
|
||||
db.session.add(tool_variables)
|
||||
else:
|
||||
# create new tool variables
|
||||
tool_variables = ToolConversationVariables(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
variables_str="[]",
|
||||
)
|
||||
db.session.add(tool_variables)
|
||||
db.session.commit()
|
||||
|
||||
return tool_variables
|
||||
|
||||
def _convert_db_variables_to_tool_variables(
|
||||
self, db_variables: ToolConversationVariables
|
||||
) -> ToolRuntimeVariablePool:
|
||||
"""
|
||||
convert db variables to tool variables
|
||||
"""
|
||||
return ToolRuntimeVariablePool(
|
||||
**{
|
||||
"conversation_id": db_variables.conversation_id,
|
||||
"user_id": db_variables.user_id,
|
||||
"tenant_id": db_variables.tenant_id,
|
||||
"pool": db_variables.variables,
|
||||
}
|
||||
)
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, message: Message
|
||||
) -> LLMUsage:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from typing import IO, Optional, Union, cast
|
||||
from typing import IO, Literal, Optional, Union, cast, overload
|
||||
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
@ -97,6 +97,42 @@ class ModelInstance:
|
||||
|
||||
return None
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Union[LLMResult, Generator]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
|
@ -1,72 +1,34 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeFrom,
|
||||
ToolEntity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
ToolRuntimeImageVariable,
|
||||
ToolRuntimeVariable,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class Tool(BaseModel, ABC):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
description: Optional[ToolDescription] = None
|
||||
is_team_authorization: bool = False
|
||||
class Tool(ABC):
|
||||
"""
|
||||
The base class of a tool
|
||||
"""
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
entity: ToolEntity
|
||||
runtime: ToolRuntime
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
Meta data of a tool call processing
|
||||
"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
if not self.runtime_parameters:
|
||||
self.runtime_parameters = {}
|
||||
|
||||
tenant_id: Optional[str] = None
|
||||
tool_id: Optional[str] = None
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: Optional[dict[str, Any]] = None
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
runtime: Optional[Runtime] = None
|
||||
variables: Optional[ToolRuntimeVariablePool] = None
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
class VariableKey(Enum):
|
||||
IMAGE = "image"
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@ -74,10 +36,8 @@ class Tool(BaseModel, ABC):
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@ -88,112 +48,6 @@ class Tool(BaseModel, ABC):
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
"""
|
||||
load variables from database
|
||||
|
||||
:param conversation_id: the conversation id
|
||||
"""
|
||||
self.variables = variables
|
||||
|
||||
def set_image_variable(self, variable_name: str, image_key: str) -> None:
|
||||
"""
|
||||
set an image variable
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
|
||||
self.variables.set_file(self.identity.name, variable_name, image_key)
|
||||
|
||||
def set_text_variable(self, variable_name: str, text: str) -> None:
|
||||
"""
|
||||
set a text variable
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
|
||||
self.variables.set_text(self.identity.name, variable_name, text)
|
||||
|
||||
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
|
||||
"""
|
||||
get a variable
|
||||
|
||||
:param name: the name of the variable
|
||||
:return: the variable
|
||||
"""
|
||||
if not self.variables:
|
||||
return None
|
||||
|
||||
if isinstance(name, Enum):
|
||||
name = name.value
|
||||
|
||||
for variable in self.variables.pool:
|
||||
if variable.name == name:
|
||||
return variable
|
||||
|
||||
return None
|
||||
|
||||
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
|
||||
"""
|
||||
get the default image variable
|
||||
|
||||
:return: the image variable
|
||||
"""
|
||||
if not self.variables:
|
||||
return None
|
||||
|
||||
return self.get_variable(self.VariableKey.IMAGE)
|
||||
|
||||
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
|
||||
"""
|
||||
get a variable file
|
||||
|
||||
:param name: the name of the variable
|
||||
:return: the variable file
|
||||
"""
|
||||
variable = self.get_variable(name)
|
||||
if not variable:
|
||||
return None
|
||||
|
||||
if not isinstance(variable, ToolRuntimeImageVariable):
|
||||
return None
|
||||
|
||||
message_file_id = variable.value
|
||||
# get file binary
|
||||
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
|
||||
if not file_binary:
|
||||
return None
|
||||
|
||||
return file_binary[0]
|
||||
|
||||
def list_variables(self) -> list[ToolRuntimeVariable]:
|
||||
"""
|
||||
list all variables
|
||||
|
||||
:return: the variables
|
||||
"""
|
||||
if not self.variables:
|
||||
return []
|
||||
|
||||
return self.variables.pool
|
||||
|
||||
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
|
||||
"""
|
||||
list all image variables
|
||||
|
||||
:return: the image variables
|
||||
"""
|
||||
if not self.variables:
|
||||
return []
|
||||
|
||||
result = []
|
||||
|
||||
for variable in self.variables.pool:
|
||||
if variable.name.startswith(self.VariableKey.IMAGE.value):
|
||||
result.append(variable)
|
||||
|
||||
return result
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
|
||||
if self.runtime and self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
@ -227,7 +81,7 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
|
||||
result = deepcopy(tool_parameters)
|
||||
for parameter in self.parameters or []:
|
||||
for parameter in self.entity.parameters:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
|
||||
tool_parameters[parameter.name], parameter.type
|
||||
@ -241,15 +95,6 @@ class Tool(BaseModel, ABC):
|
||||
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
|
||||
pass
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials
|
||||
|
||||
:param credentials: the credentials
|
||||
:param parameters: the parameters
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
@ -258,7 +103,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
:return: the runtime parameters
|
||||
"""
|
||||
return self.parameters or []
|
||||
return self.entity.parameters
|
||||
|
||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
@ -266,7 +111,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
:return: all runtime parameters
|
||||
"""
|
||||
parameters = self.parameters or []
|
||||
parameters = self.entity.parameters
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
@ -274,20 +119,16 @@ class Tool(BaseModel, ABC):
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
found = False
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
found = True
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
break
|
||||
|
||||
if found:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
@ -1,23 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolProviderIdentity,
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(BaseModel, ABC):
|
||||
identity: ToolProviderIdentity
|
||||
tools: list[Tool] = Field(default_factory=list)
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
class ToolProviderController(ABC):
|
||||
entity: ToolProviderEntity
|
||||
tools: list[Tool]
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True)
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
self.entity = entity
|
||||
self.tools = []
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
@ -25,7 +24,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return self.credentials_schema.copy()
|
||||
return self.entity.credentials_schema.copy()
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
@ -51,7 +50,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = self.credentials_schema
|
||||
credentials_schema = self.entity.credentials_schema
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
@ -62,7 +61,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.identity.name}"
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
|
36
api/core/tools/__base/tool_runtime.py
Normal file
36
api/core/tools/__base/tool_runtime.py
Normal file
@ -0,0 +1,36 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
"""
|
||||
Meta data of a tool call processing
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_id: Optional[str] = None
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: Optional[dict[str, Any]] = None
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class FakeToolRuntime(ToolRuntime):
|
||||
"""
|
||||
Fake tool runtime for testing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
tenant_id="fake_tenant_id",
|
||||
tool_id="fake_tool_id",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credentials={},
|
||||
runtime_parameters={},
|
||||
)
|
@ -2,13 +2,12 @@ from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
@ -17,10 +16,10 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
tools: list[BuiltinTool] = Field(default_factory=list)
|
||||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
|
||||
if self.provider_type == ToolProviderType.API:
|
||||
super().__init__(**data)
|
||||
return
|
||||
|
||||
@ -37,10 +36,12 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
for credential_name in provider_yaml["credentials_for_provider"]:
|
||||
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
|
||||
|
||||
super().__init__(**{
|
||||
'identity': provider_yaml['identity'],
|
||||
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
|
||||
})
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {},
|
||||
),
|
||||
)
|
||||
|
||||
def _get_builtin_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
@ -51,7 +52,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
if self.tools:
|
||||
return self.tools
|
||||
|
||||
provider = self.identity.name
|
||||
provider = self.entity.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
|
||||
# get all the yaml files in the tool path
|
||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||
@ -62,30 +63,36 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
|
||||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source(
|
||||
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
|
||||
script_path=path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool", "providers", provider, "tools", f"{tool_name}.py"
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool",
|
||||
"providers",
|
||||
provider,
|
||||
"tools",
|
||||
f"{tool_name}.py",
|
||||
),
|
||||
parent_type=BuiltinTool,
|
||||
)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
tools.append(assistant_tool_class(
|
||||
entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""),
|
||||
))
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.credentials_schema:
|
||||
if not self.entity.credentials_schema:
|
||||
return {}
|
||||
|
||||
return self.credentials_schema.copy()
|
||||
return self.entity.credentials_schema.copy()
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
@ -94,12 +101,12 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
:return: list of tools
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None:
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
@ -108,7 +115,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.credentials_schema is not None and len(self.credentials_schema) != 0
|
||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@ -133,8 +140,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
return self.identity.tags or []
|
||||
|
||||
return self.entity.identity.tags or []
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
@ -1,13 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class QRCodeProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"})
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
pass
|
||||
|
@ -1,16 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CurrentTimeTool().invoke(
|
||||
user_id="",
|
||||
tool_parameters={},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
pass
|
||||
|
@ -32,9 +32,9 @@ class BuiltinTool(Tool):
|
||||
# invoke model
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tool_type="builtin",
|
||||
tool_name=self.identity.name,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
@ -79,6 +79,7 @@ class BuiltinTool(Tool):
|
||||
stop=[],
|
||||
)
|
||||
|
||||
assert isinstance(summary.message.content, str)
|
||||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
|
@ -7,6 +7,8 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
@ -18,6 +20,11 @@ class ApiToolProviderController(ToolProviderController):
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
|
||||
super().__init__(entity)
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = {
|
||||
@ -64,25 +71,23 @@ class ApiToolProviderController(ToolProviderController):
|
||||
}
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
|
||||
user = db_provider.user
|
||||
user_name = user.name if user else ""
|
||||
|
||||
return ApiToolProviderController(
|
||||
**{
|
||||
"identity": {
|
||||
"author": user_name,
|
||||
"name": db_provider.name,
|
||||
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
|
||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
||||
"icon": db_provider.icon,
|
||||
},
|
||||
"credentials_schema": credentials_schema,
|
||||
"provider_id": db_provider.id or "",
|
||||
"tenant_id": db_provider.tenant_id or "",
|
||||
},
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user_name,
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
credentials_schema=credentials_schema,
|
||||
),
|
||||
provider_id=db_provider.id or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
)
|
||||
|
||||
@property
|
||||
@ -103,7 +108,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
"author": tool_bundle.author,
|
||||
"name": tool_bundle.operation_id,
|
||||
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
|
||||
"icon": self.identity.icon,
|
||||
"icon": self.entity.identity.icon,
|
||||
"provider": self.provider_id,
|
||||
},
|
||||
"description": {
|
||||
@ -141,7 +146,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@ -149,7 +154,6 @@ class ApiToolProviderController(ToolProviderController):
|
||||
for db_provider in db_providers:
|
||||
for tool in db_provider.tools:
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
assistant_tool.is_team_authorization = True
|
||||
tools.append(assistant_tool)
|
||||
|
||||
self.tools = tools
|
||||
@ -166,7 +170,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.get_tools(self.tenant_id)
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
|
@ -8,8 +8,9 @@ import httpx
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
|
||||
API_TOOL_DEFAULT_TIMEOUT = (
|
||||
@ -25,7 +26,11 @@ class ApiTool(Tool):
|
||||
Api tool
|
||||
"""
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
||||
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime):
|
||||
super().__init__(entity, runtime)
|
||||
self.api_bundle = api_bundle
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime):
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@ -33,11 +38,9 @@ class ApiTool(Tool):
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=self.identity.model_copy(),
|
||||
parameters=self.parameters.copy() if self.parameters else [],
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
entity=self.entity,
|
||||
api_bundle=self.api_bundle.model_copy(),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
def validate_credentials(
|
||||
@ -62,7 +65,7 @@ class ApiTool(Tool):
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.runtime == None:
|
||||
raise ToolProviderCredentialValidationError("runtime not initialized")
|
||||
|
||||
|
||||
headers = {}
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
@ -122,14 +123,14 @@ class ToolInvokeMessage(BaseModel):
|
||||
"""
|
||||
if not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types and lists are allowed.")
|
||||
|
||||
|
||||
# if stream is true, the value must be a string
|
||||
if values.get('stream'):
|
||||
if values.get("stream"):
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def transform_variable_name(cls, value) -> str:
|
||||
@ -158,22 +159,20 @@ class ToolInvokeMessage(BaseModel):
|
||||
meta: dict[str, Any] | None = None
|
||||
save_as: str = ""
|
||||
|
||||
@field_validator('message', mode='before')
|
||||
@field_validator("message", mode="before")
|
||||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
if isinstance(v, dict) and 'blob' in v:
|
||||
if isinstance(v, dict) and "blob" in v:
|
||||
try:
|
||||
v['blob'] = base64.b64decode(v['blob'])
|
||||
v["blob"] = base64.b64decode(v["blob"])
|
||||
except Exception:
|
||||
pass
|
||||
return v
|
||||
|
||||
@field_serializer('message')
|
||||
@field_serializer("message")
|
||||
def serialize_message(self, v):
|
||||
if isinstance(v, self.BlobMessage):
|
||||
return {
|
||||
'blob': base64.b64encode(v.blob).decode('utf-8')
|
||||
}
|
||||
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
|
||||
return v
|
||||
|
||||
|
||||
@ -252,9 +251,9 @@ class ToolParameter(BaseModel):
|
||||
option_objs = []
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US='', zh_Hans=''),
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
human_description=I18nObject(en_US='', zh_Hans=''),
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=type,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
@ -275,6 +274,11 @@ class ToolProviderIdentity(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
@ -288,131 +292,6 @@ class ToolIdentity(BaseModel):
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class ToolRuntimeVariableType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
||||
|
||||
class ToolRuntimeVariable(BaseModel):
|
||||
type: ToolRuntimeVariableType = Field(..., description="The type of the variable")
|
||||
name: str = Field(..., description="The name of the variable")
|
||||
position: int = Field(..., description="The position of the variable")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
|
||||
|
||||
class ToolRuntimeTextVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The value of the variable")
|
||||
|
||||
|
||||
class ToolRuntimeImageVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The path of the image")
|
||||
|
||||
|
||||
class ToolRuntimeVariablePool(BaseModel):
|
||||
conversation_id: str = Field(..., description="The conversation id")
|
||||
user_id: str = Field(..., description="The user id")
|
||||
tenant_id: str = Field(..., description="The tenant id of assistant")
|
||||
|
||||
pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
pool = data.get("pool", [])
|
||||
# convert pool into correct type
|
||||
for index, variable in enumerate(pool):
|
||||
if variable["type"] == ToolRuntimeVariableType.TEXT.value:
|
||||
pool[index] = ToolRuntimeTextVariable(**variable)
|
||||
elif variable["type"] == ToolRuntimeVariableType.IMAGE.value:
|
||||
pool[index] = ToolRuntimeImageVariable(**variable)
|
||||
super().__init__(**data)
|
||||
|
||||
def dict(self) -> dict:
|
||||
return {
|
||||
"conversation_id": self.conversation_id,
|
||||
"user_id": self.user_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"pool": [variable.model_dump() for variable in self.pool],
|
||||
}
|
||||
|
||||
def set_text(self, tool_name: str, name: str, value: str) -> None:
|
||||
"""
|
||||
set a text variable
|
||||
"""
|
||||
for variable in self.pool:
|
||||
if variable.name == name:
|
||||
if variable.type == ToolRuntimeVariableType.TEXT:
|
||||
variable = cast(ToolRuntimeTextVariable, variable)
|
||||
variable.value = value
|
||||
return
|
||||
|
||||
variable = ToolRuntimeTextVariable(
|
||||
type=ToolRuntimeVariableType.TEXT,
|
||||
name=name,
|
||||
position=len(self.pool),
|
||||
tool_name=tool_name,
|
||||
value=value,
|
||||
)
|
||||
|
||||
self.pool.append(variable)
|
||||
|
||||
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
|
||||
"""
|
||||
set an image variable
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:param value: the id of the file
|
||||
"""
|
||||
# check how many image variables are there
|
||||
image_variable_count = 0
|
||||
for variable in self.pool:
|
||||
if variable.type == ToolRuntimeVariableType.IMAGE:
|
||||
image_variable_count += 1
|
||||
|
||||
if name is None:
|
||||
name = f"file_{image_variable_count}"
|
||||
|
||||
for variable in self.pool:
|
||||
if variable.name == name:
|
||||
if variable.type == ToolRuntimeVariableType.IMAGE:
|
||||
variable = cast(ToolRuntimeImageVariable, variable)
|
||||
variable.value = value
|
||||
return
|
||||
|
||||
variable = ToolRuntimeImageVariable(
|
||||
type=ToolRuntimeVariableType.IMAGE,
|
||||
name=name,
|
||||
position=len(self.pool),
|
||||
tool_name=tool_name,
|
||||
value=value,
|
||||
)
|
||||
|
||||
self.pool.append(variable)
|
||||
|
||||
|
||||
class ModelToolPropertyKey(Enum):
|
||||
IMAGE_PARAMETER_NAME = "image_parameter_name"
|
||||
|
||||
|
||||
class ModelToolConfiguration(BaseModel):
|
||||
"""
|
||||
Model tool configuration
|
||||
"""
|
||||
|
||||
type: str = Field(..., description="The type of the model tool")
|
||||
model: str = Field(..., description="The model")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
|
||||
|
||||
|
||||
class ModelToolProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model tool provider configuration
|
||||
"""
|
||||
|
||||
provider: str = Field(..., description="The provider of the model tool")
|
||||
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
"""
|
||||
Workflow tool configuration
|
||||
@ -471,3 +350,17 @@ class ToolInvokeFrom(Enum):
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
description: Optional[ToolDescription] = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
@ -65,7 +65,7 @@ class ToolEngine:
|
||||
# invoke the tool
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
||||
@ -99,7 +99,7 @@ class ToolEngine:
|
||||
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=plain_text,
|
||||
message_id=message.id,
|
||||
@ -112,7 +112,7 @@ class ToolEngine:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
|
||||
error_response = f"there is not a tool named {tool.identity.name}"
|
||||
error_response = f"there is not a tool named {tool.entity.identity.name}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolParameterValidationError as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
@ -145,7 +145,7 @@ class ToolEngine:
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
||||
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
@ -158,7 +158,7 @@ class ToolEngine:
|
||||
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response,
|
||||
)
|
||||
@ -177,13 +177,13 @@ class ToolEngine:
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
||||
callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response,
|
||||
)
|
||||
@ -208,11 +208,11 @@ class ToolEngine:
|
||||
time_cost=0.0,
|
||||
error=None,
|
||||
tool_config={
|
||||
"tool_name": tool.identity.name,
|
||||
"tool_provider": tool.identity.provider,
|
||||
"tool_name": tool.entity.identity.name,
|
||||
"tool_provider": tool.entity.identity.provider,
|
||||
"tool_provider_type": tool.tool_provider_type().value,
|
||||
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
|
||||
"tool_icon": tool.identity.icon,
|
||||
"tool_icon": tool.entity.identity.icon,
|
||||
},
|
||||
)
|
||||
try:
|
||||
|
@ -6,6 +6,8 @@ from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
@ -105,12 +107,12 @@ class ToolManager:
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"tenant_id": tenant_id,
|
||||
"credentials": {},
|
||||
"invoke_from": invoke_from,
|
||||
"tool_invoke_from": tool_invoke_from,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@ -134,7 +136,7 @@ class ToolManager:
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credentials_schema(),
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.identity.name,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
@ -142,13 +144,13 @@ class ToolManager:
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"tenant_id": tenant_id,
|
||||
"credentials": decrypted_credentials,
|
||||
"runtime_parameters": {},
|
||||
"invoke_from": invoke_from,
|
||||
"tool_invoke_from": tool_invoke_from,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@ -163,19 +165,19 @@ class ToolManager:
|
||||
tenant_id=tenant_id,
|
||||
config=api_provider.get_credentials_schema(),
|
||||
provider_type=api_provider.provider_type.value,
|
||||
provider_identity=api_provider.identity.name,
|
||||
provider_identity=api_provider.entity.identity.name,
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
ApiTool,
|
||||
api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime={
|
||||
"tenant_id": tenant_id,
|
||||
"credentials": decrypted_credentials,
|
||||
"invoke_from": invoke_from,
|
||||
"tool_invoke_from": tool_invoke_from,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
@ -193,12 +195,12 @@ class ToolManager:
|
||||
return cast(
|
||||
WorkflowTool,
|
||||
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime={
|
||||
"tenant_id": tenant_id,
|
||||
"credentials": {},
|
||||
"invoke_from": invoke_from,
|
||||
"tool_invoke_from": tool_invoke_from,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
@ -336,7 +338,7 @@ class ToolManager:
|
||||
"providers",
|
||||
provider,
|
||||
"_assets",
|
||||
provider_controller.identity.icon,
|
||||
provider_controller.entity.identity.icon,
|
||||
)
|
||||
# check if the icon exists
|
||||
if not path.exists(absolute_path):
|
||||
@ -389,9 +391,9 @@ class ToolManager:
|
||||
parent_type=BuiltinToolProviderController,
|
||||
)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
cls._builtin_providers[provider.entity.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
|
||||
yield provider
|
||||
|
||||
except Exception as e:
|
||||
@ -466,11 +468,11 @@ class ToolManager:
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
)
|
||||
|
||||
result_providers[provider.identity.name] = user_provider
|
||||
result_providers[provider.entity.identity.name] = user_provider
|
||||
|
||||
# get db api providers
|
||||
|
||||
@ -589,7 +591,7 @@ class ToolManager:
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credentials_schema(),
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.identity.name,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
|
@ -59,12 +59,11 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = \
|
||||
data[field_name][:2] + \
|
||||
'*' * (len(data[field_name]) - 4) + \
|
||||
data[field_name][-2:]
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = '*' * len(data[field_name])
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
@ -75,9 +74,9 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_type}.{self.provider_identity}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
@ -98,14 +97,14 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_type}.{self.provider_identity}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager(BaseModel):
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
@ -116,6 +115,15 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
self.provider_type = provider_type
|
||||
self.identity_id = identity_id
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
@ -127,7 +135,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.parameters or []
|
||||
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
|
||||
# override parameters
|
||||
@ -203,8 +211,8 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
"""
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type.value}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
@ -236,8 +244,8 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type.value}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
|
@ -6,9 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
@ -20,11 +22,15 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
|
||||
class DatasetRetrieverTool(Tool):
|
||||
retrieval_tool: DatasetRetrieverBaseTool
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
retrieve_config: DatasetRetrieveConfigEntity | None,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
@ -54,7 +60,7 @@ class DatasetRetrieverTool(Tool):
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
||||
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
@ -63,13 +69,14 @@ class DatasetRetrieverTool(Tool):
|
||||
for retrieval_tool in retrieval_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
retrieval_tool=retrieval_tool,
|
||||
identity=ToolIdentity(
|
||||
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
|
||||
),
|
||||
parameters=[],
|
||||
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
|
||||
),
|
||||
parameters=[],
|
||||
is_team_authorization=True,
|
||||
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
|
||||
runtime=DatasetRetrieverTool.Runtime(),
|
||||
runtime=ToolRuntime(tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
@ -99,7 +106,7 @@ class DatasetRetrieverTool(Tool):
|
||||
"""
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
yield self.create_text_message(text='please input query')
|
||||
yield self.create_text_message(text="please input query")
|
||||
else:
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrieval_tool._run(query=query)
|
||||
|
@ -6,9 +6,11 @@ from pydantic import Field
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolParameter,
|
||||
ToolParameterOption,
|
||||
@ -63,7 +65,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
|
||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
@ -140,19 +142,23 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
raise ValueError("variable not found")
|
||||
|
||||
return WorkflowTool(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
||||
provider=self.provider_id,
|
||||
icon=db_provider.icon,
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
||||
provider=self.provider_id,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
llm=db_provider.description,
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
is_team_authorization=True,
|
||||
workflow_app_id=app.id,
|
||||
workflow_entities={
|
||||
"app": app,
|
||||
@ -201,7 +207,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return None
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
return None
|
||||
|
@ -1,12 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
@ -28,6 +28,26 @@ class WorkflowTool(Tool):
|
||||
Workflow tool.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_app_id: str,
|
||||
version: str,
|
||||
workflow_entities: dict[str, Any],
|
||||
workflow_call_depth: int,
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
label: str = "Workflow",
|
||||
thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
self.workflow_app_id = workflow_app_id
|
||||
self.version = version
|
||||
self.workflow_entities = workflow_entities
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.label = label
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
@ -94,7 +114,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
return user
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool":
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@ -102,10 +122,8 @@ class WorkflowTool(Tool):
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=deepcopy(self.identity),
|
||||
parameters=deepcopy(self.parameters),
|
||||
description=deepcopy(self.description),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
workflow_app_id=self.workflow_app_id,
|
||||
workflow_entities=self.workflow_entities,
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
|
@ -1,207 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolParameter,
|
||||
ToolParameterOption,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
tools: list[WorkflowTool] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
||||
app = db_provider.app
|
||||
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
controller = WorkflowToolProviderController(
|
||||
**{
|
||||
"identity": {
|
||||
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
||||
"name": db_provider.label,
|
||||
"label": {"en_US": db_provider.label, "zh_Hans": db_provider.label},
|
||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
||||
"icon": db_provider.icon,
|
||||
},
|
||||
"credentials_schema": {},
|
||||
"provider_id": db_provider.id or "",
|
||||
}
|
||||
)
|
||||
|
||||
# init tools
|
||||
|
||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
||||
|
||||
return controller
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
:param db_provider: the db provider
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow | None = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == db_provider.app_id,
|
||||
Workflow.version == db_provider.version
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found")
|
||||
|
||||
# fetch start node
|
||||
graph: Mapping = workflow.graph_dict
|
||||
features_dict: Mapping = workflow.features_dict
|
||||
features = WorkflowAppConfigManager.convert_features(
|
||||
config_dict=features_dict,
|
||||
app_mode=AppMode.WORKFLOW
|
||||
)
|
||||
|
||||
parameters = db_provider.parameter_configurations
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
|
||||
user = db_provider.user
|
||||
|
||||
workflow_tool_parameters = []
|
||||
for parameter in parameters:
|
||||
variable = fetch_workflow_variable(parameter.name)
|
||||
if variable:
|
||||
parameter_type = None
|
||||
options = []
|
||||
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
||||
raise ValueError(f"unsupported variable type {variable.type}")
|
||||
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
||||
|
||||
if variable.type == VariableEntityType.SELECT and variable.options:
|
||||
options = [
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in variable.options
|
||||
]
|
||||
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(en_US=variable.label, zh_Hans=variable.label),
|
||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
||||
type=parameter_type,
|
||||
form=parameter.form,
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
options=options,
|
||||
default=variable.default,
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
||||
type=ToolParameter.ToolParameterType.FILE,
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("variable not found")
|
||||
|
||||
return WorkflowTool(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
||||
provider=self.provider_id,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
is_team_authorization=True,
|
||||
workflow_app_id=app.id,
|
||||
workflow_entities={
|
||||
"app": app,
|
||||
"workflow": workflow,
|
||||
},
|
||||
version=db_provider.version,
|
||||
workflow_call_depth=0,
|
||||
label=db_provider.label,
|
||||
)
|
||||
|
||||
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
).first()
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
|
||||
app = db_providers.app
|
||||
if not app:
|
||||
raise ValueError("can not read app of workflow")
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
return None
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
return None
|
@ -1304,7 +1304,7 @@ class MessageChain(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAgentThought(db.Model):
|
||||
class MessageAgentThought(Base):
|
||||
__tablename__ = "message_agent_thoughts"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
|
||||
|
@ -5,6 +5,7 @@ from httpx import get
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -160,7 +161,7 @@ class ApiToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
provider_identity=provider_controller.entity.identity.name
|
||||
)
|
||||
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
@ -222,6 +223,7 @@ class ApiToolManageService:
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool_bundle,
|
||||
tenant_id=tenant_id,
|
||||
labels=labels,
|
||||
)
|
||||
for tool_bundle in provider.tools
|
||||
@ -291,7 +293,7 @@ class ApiToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
provider_identity=provider_controller.entity.identity.name
|
||||
)
|
||||
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
@ -410,7 +412,7 @@ class ApiToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
provider_identity=provider_controller.entity.identity.name
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
@ -424,10 +426,10 @@ class ApiToolManageService:
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
|
@ -32,7 +32,7 @@ class BuiltinToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
@ -71,7 +71,7 @@ class BuiltinToolManageService:
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
|
||||
return jsonable_encoder([v for _, v in (provider.entity.credentials_schema or {}).items()])
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
|
||||
@ -97,7 +97,7 @@ class BuiltinToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# get original credentials if exists
|
||||
@ -159,7 +159,7 @@ class BuiltinToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
credentials = tool_configuration.decrypt(provider_obj.credentials)
|
||||
credentials = tool_configuration.mask_tool_credentials(credentials)
|
||||
@ -191,7 +191,7 @@ class BuiltinToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
@ -241,7 +241,7 @@ class BuiltinToolManageService:
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.identity.name),
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional, Union
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
@ -69,19 +70,19 @@ class ToolTransformService:
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
result = UserToolProvider(
|
||||
id=provider_controller.identity.name,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
id=provider_controller.entity.identity.name,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
pt_BR=provider_controller.identity.description.pt_BR,
|
||||
en_US=provider_controller.entity.identity.description.en_US,
|
||||
zh_Hans=provider_controller.entity.identity.description.zh_Hans,
|
||||
pt_BR=provider_controller.entity.identity.description.pt_BR,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
pt_BR=provider_controller.identity.label.pt_BR,
|
||||
en_US=provider_controller.entity.identity.label.en_US,
|
||||
zh_Hans=provider_controller.entity.identity.label.zh_Hans,
|
||||
pt_BR=provider_controller.entity.identity.label.pt_BR,
|
||||
),
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
masked_credentials={},
|
||||
@ -111,7 +112,7 @@ class ToolTransformService:
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
||||
@ -155,16 +156,16 @@ class ToolTransformService:
|
||||
"""
|
||||
return UserToolProvider(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
en_US=provider_controller.entity.identity.description.en_US,
|
||||
zh_Hans=provider_controller.entity.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
en_US=provider_controller.entity.identity.label.en_US,
|
||||
zh_Hans=provider_controller.entity.identity.label.zh_Hans,
|
||||
),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
masked_credentials={},
|
||||
@ -189,7 +190,7 @@ class ToolTransformService:
|
||||
user = db_provider.user
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
|
||||
username = user.name
|
||||
except Exception as e:
|
||||
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
|
||||
@ -222,7 +223,7 @@ class ToolTransformService:
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
@ -236,8 +237,8 @@ class ToolTransformService:
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
tenant_id: str | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> UserTool:
|
||||
"""
|
||||
@ -246,14 +247,14 @@ class ToolTransformService:
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
parameters = tool.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters() or []
|
||||
# override parameters
|
||||
@ -270,10 +271,10 @@ class ToolTransformService:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return UserTool(
|
||||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human if tool.description else I18nObject(en_US=''),
|
||||
author=tool.entity.identity.author,
|
||||
name=tool.entity.identity.name,
|
||||
label=tool.entity.identity.label,
|
||||
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
|
||||
parameters=current_parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
|
@ -211,7 +211,9 @@ class WorkflowToolManageService:
|
||||
ToolTransformService.repack_provider(user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
|
||||
tool=tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=labels.get(tool.provider_id, []),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
@ -248,7 +250,7 @@ class WorkflowToolManageService:
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
return cls._get_workflow_tool(db_tool)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
@ -264,10 +266,10 @@ class WorkflowToolManageService:
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
)
|
||||
return cls._get_workflow_tool(db_tool)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None):
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
@ -298,7 +300,9 @@ class WorkflowToolManageService:
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
"synced": workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
@ -326,6 +330,8 @@ class WorkflowToolManageService:
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
|
@ -1,5 +1,9 @@
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
|
||||
from tests.integration_tests.tools.__mock.http import setup_http_mock
|
||||
|
||||
tool_bundle = {
|
||||
@ -29,7 +33,13 @@ parameters = {
|
||||
|
||||
|
||||
def test_api_tool(setup_http_mock):
|
||||
tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"}))
|
||||
tool = ApiTool(
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(provider="", author="", name="", label=I18nObject()),
|
||||
),
|
||||
api_bundle=ApiToolBundle(**tool_bundle),
|
||||
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
|
||||
)
|
||||
headers = tool.assembling_request(parameters)
|
||||
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user