From 886a1601152b5fd887f6e5777efd0bd37b88f84c Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 30 Aug 2024 18:11:38 +0800 Subject: [PATCH] fix: invoke tool streamingly --- api/core/entities/provider_entities.py | 4 +- api/core/helper/tool_provider_cache.py | 1 + api/core/tools/entities/api_entities.py | 11 ++-- api/core/tools/entities/tool_entities.py | 4 +- api/core/tools/provider/api_tool_provider.py | 25 +++++---- .../tools/provider/builtin_tool_provider.py | 2 +- api/core/tools/provider/tool_provider.py | 4 +- api/core/tools/tool/tool.py | 13 ++++- api/core/tools/tool_manager.py | 21 ++++++-- api/core/tools/utils/configuration.py | 24 +++++---- api/core/tools/utils/parser.py | 11 ++-- api/core/workflow/nodes/tool/tool_node.py | 17 +++--- api/models/model.py | 24 ++++----- api/models/tools.py | 6 +-- .../tools/api_tools_manage_service.py | 52 ++++++++++++------- api/services/tools/tools_transform_service.py | 22 +++++--- 16 files changed, 149 insertions(+), 92 deletions(-) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index ae78d9ecf9..e0d1de151f 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -4,8 +4,8 @@ from typing import Optional, Union from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope -from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType +from core.tools.entities.common_entities import I18nObject from models.provider import ProviderQuotaType @@ -143,7 +143,7 @@ class ProviderConfig(BasicProviderConfig): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - scope: AppSelectorScope | ModelConfigScope | None + scope: AppSelectorScope | ModelConfigScope | None = None required: bool = False default: Optional[Union[int, str]] = None options: Optional[list[Option]] = None diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 6c5d3b8fb6..2777367963 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client class ToolProviderCredentialsCacheType(Enum): PROVIDER = "tool_provider" + ENDPOINT = "endpoint" class ToolProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 71db8d8b2d..2aaca35060 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,10 +1,11 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field +from core.entities.provider_entities import ProviderConfig from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool.tool import ToolParameter @@ -14,7 +15,7 @@ class UserTool(BaseModel): label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None - labels: list[str] = None + labels: list[str] = Field(default_factory=list) UserToolProviderTypeLiteral = Optional[Literal[ 'builtin', 'api', 'workflow' @@ -32,8 +33,8 @@ class UserToolProvider(BaseModel): original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - tools: list[UserTool] = None - labels: list[str] = None + tools: list[UserTool] = Field(default_factory=list) + labels: list[str] = Field(default_factory=list) def to_dict(self) -> dict: # ------------- diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 98efb92a0d..4b0961fb09 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -25,7 +25,7 @@ class ToolLabelEnum(Enum): UTILITIES = 'utilities' OTHER = 'other' -class ToolProviderType(Enum): +class ToolProviderType(str, Enum): """ Enum class for tool provider """ @@ -181,7 +181,7 @@ class ToolParameter(BaseModel): if options: option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] else: - option_objs = None + option_objs = [] return cls( name=name, label=I18nObject(en_US='', zh_Hans=''), diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index fc7fcb675a..880ddc4955 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,21 +1,23 @@ +from pydantic import Field + +from core.entities.provider_entities import ProviderConfig from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, - ProviderConfig, - ToolCredentialsOption, ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.api_tool import ApiTool -from core.tools.tool.tool import Tool from extensions.ext_database import db from models.tools import ApiToolProvider class ApiToolProviderController(ToolProviderController): provider_id: str + tenant_id: str + tools: list[ApiTool] = Field(default_factory=list) @staticmethod def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': @@ -25,8 +27,8 @@ class ApiToolProviderController(ToolProviderController): required=True, type=ProviderConfig.Type.SELECT, options=[ - ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')), - ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) + ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='无')), + ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) ], default='none', help=I18nObject( @@ -67,9 +69,9 @@ class ApiToolProviderController(ToolProviderController): zh_Hans='api key header 的前缀' ), options=[ - ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), - ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), - ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) + ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), + ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), + ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) ] ) } @@ -96,6 +98,7 @@ class ApiToolProviderController(ToolProviderController): }, 'credentials_schema': credentials_schema, 'provider_id': db_provider.id or '', + 'tenant_id': db_provider.tenant_id or '', }) @property @@ -142,7 +145,7 @@ class ApiToolProviderController(ToolProviderController): return self.tools - def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: + def get_tools(self, tenant_id: str) -> list[ApiTool]: """ fetch tools from database @@ -153,7 +156,7 @@ class ApiToolProviderController(ToolProviderController): if self.tools is not None: return self.tools - tools: list[Tool] = [] + tools: list[ApiTool] = [] # get tenant api providers db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( @@ -179,7 +182,7 @@ class ApiToolProviderController(ToolProviderController): :return: the tool """ if self.tools is None: - self.get_tools() + self.get_tools(self.tenant_id) for tool in self.tools: if tool.identity.name == tool_name: diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 7ad8a5468b..8dd543b00a 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -39,7 +39,7 @@ class BuiltinToolProviderController(ToolProviderController): super().__init__(**{ 'identity': provider_yaml['identity'], - 'credentials_schema': provider_yaml.get('credentials_for_provider', None), + 'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {}, }) def _get_builtin_tools(self) -> list[BuiltinTool]: diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index ac770a2a60..057f3060ed 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from core.entities.provider_entities import ProviderConfig from core.tools.entities.tool_entities import ( @@ -17,6 +17,8 @@ class ToolProviderController(BaseModel, ABC): tools: list[Tool] = Field(default_factory=list) credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) + model_config = ConfigDict(validate_assignment=True) + def get_credentials_schema(self) -> dict[str, ProviderConfig]: """ returns the credentials schema of the provider diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 6005297118..6f21afdb35 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -206,7 +206,16 @@ class Tool(BaseModel, ABC): tool_parameters=tool_parameters, ) - return result + if isinstance(result, ToolInvokeMessage): + def single_generator(): + yield result + return single_generator() + elif isinstance(result, list): + def generator(): + yield from result + return generator() + else: + return result def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: """ @@ -223,7 +232,7 @@ class Tool(BaseModel, ABC): return result @abstractmethod - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: pass def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index efc2802016..56e97252f9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -116,7 +116,12 @@ class ToolManager: # decrypt the credentials credentials = builtin_provider.credentials controller = cls.get_builtin_provider(provider_id) - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager( + tenant_id=tenant_id, + config=controller.get_credentials_schema(), + provider_type=controller.provider_type.value, + provider_identity=controller.identity.name + ) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) @@ -135,7 +140,12 @@ class ToolManager: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) # decrypt the credentials - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) + tool_configuration = ToolConfigurationManager( + tenant_id=tenant_id, + config=api_provider.get_credentials_schema(), + provider_type=api_provider.provider_type.value, + provider_identity=api_provider.identity.name + ) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ @@ -513,7 +523,12 @@ class ToolManager: provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) # init tool configuration - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager( + tenant_id=tenant_id, + config=controller.get_credentials_schema(), + provider_type=controller.provider_type.value, + provider_identity=controller.identity.name + ) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 2fc0ba3bcd..5b65ce443f 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,23 +1,25 @@ +from collections.abc import Mapping from copy import deepcopy from typing import Any from pydantic import BaseModel +from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.entities.tool_entities import ( - ProviderConfig, ToolParameter, ToolProviderType, ) -from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.tool import Tool class ToolConfigurationManager(BaseModel): tenant_id: str - provider_controller: ToolProviderController + config: Mapping[str, BasicProviderConfig] + provider_type: str + provider_identity: str def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]: """ @@ -34,9 +36,9 @@ class ToolConfigurationManager(BaseModel): credentials = self._deep_copy(credentials) # get fields need to be decrypted - fields = self.provider_controller.get_credentials_schema() + fields = self.config for field_name, field in fields.items(): - if field.type == ProviderConfig.Type.SECRET_INPUT: + if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in credentials: encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name]) credentials[field_name] = encrypted @@ -52,9 +54,9 @@ class ToolConfigurationManager(BaseModel): credentials = self._deep_copy(credentials) # get fields need to be decrypted - fields = self.provider_controller.get_credentials_schema() + fields = self.config for field_name, field in fields.items(): - if field.type == ProviderConfig.Type.SECRET_INPUT: + if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in credentials: if len(credentials[field_name]) > 6: credentials[field_name] = \ @@ -74,7 +76,7 @@ class ToolConfigurationManager(BaseModel): """ cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', + identity_id=f'{self.provider_type}.{self.provider_identity}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cached_credentials = cache.get() @@ -82,9 +84,9 @@ class ToolConfigurationManager(BaseModel): return cached_credentials credentials = self._deep_copy(credentials) # get fields need to be decrypted - fields = self.provider_controller.get_credentials_schema() + fields = self.config for field_name, field in fields.items(): - if field.type == ProviderConfig.Type.SECRET_INPUT: + if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in credentials: try: credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name]) @@ -97,7 +99,7 @@ class ToolConfigurationManager(BaseModel): def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', + identity_id=f'{self.provider_type}.{self.provider_identity}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cache.delete() diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index f711f7c9f3..882e276afe 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -16,7 +16,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro class ApiBasedToolSchemaParser: @staticmethod - def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -173,7 +173,7 @@ class ApiBasedToolSchemaParser: return ToolParameter.ToolParameterType.STRING @staticmethod - def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]: """ parse openapi yaml to tool bundle @@ -189,7 +189,8 @@ class ApiBasedToolSchemaParser: return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @staticmethod - def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict: + def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict: + warning = warning or {} """ parse swagger to openapi @@ -255,7 +256,7 @@ class ApiBasedToolSchemaParser: return openapi @staticmethod - def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]: """ parse openapi plugin yaml to tool bundle @@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser: return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) @staticmethod - def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]: + def auto_parse_to_tool_bundle(content: str, extra_info: dict | None = None, warning: dict | None = None) -> tuple[list[ApiToolBundle], str]: """ auto parse to tool bundle diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 6ba7e7e09b..2d01bad1f4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,6 +1,6 @@ from collections.abc import Generator, Sequence from os import path -from typing import Any, cast +from typing import Any, Iterable, cast from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler @@ -158,14 +158,17 @@ class ToolNode(BaseNode): tenant_id=self.tenant_id, conversation_id=None, ) + + result = list(messages) + # extract plain text and files - files = self._extract_tool_response_binary(messages) - plain_text = self._extract_tool_response_text(messages) - json = self._extract_tool_response_json(messages) + files = self._extract_tool_response_binary(result) + plain_text = self._extract_tool_response_text(result) + json = self._extract_tool_response_json(result) return plain_text, files, json - def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]: + def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]: """ Extract tool response binary """ @@ -215,7 +218,7 @@ class ToolNode(BaseNode): return result - def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str: + def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str: """ Extract tool response text """ @@ -230,7 +233,7 @@ class ToolNode(BaseNode): return '\n'.join(result) - def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]: + def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]: result: list[dict] = [] for message in tool_response: if message.type == ToolInvokeMessage.MessageType.JSON: diff --git a/api/models/model.py b/api/models/model.py index 298bfbda12..74ba4a7fd5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,7 +7,7 @@ from typing import Optional from flask import request from flask_login import UserMixin from sqlalchemy import Float, func, text -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from configs import dify_config from core.file.tool_file_parser import ToolFileParser @@ -495,14 +495,14 @@ class InstalledApp(db.Model): return tenant -class Conversation(db.Model): +class Conversation(Base): __tablename__ = 'conversations' __table_args__ = ( db.PrimaryKeyConstraint('id', name='conversation_pkey'), db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) @@ -526,8 +526,8 @@ class Conversation(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") + messages: Mapped[list["Message"]] = relationship("Message", backref="conversation", lazy='select', passive_deletes="all") + message_annotations: Mapped[list["MessageAnnotation"]] = relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -660,10 +660,10 @@ class Message(Base): model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) - inputs = db.Column(db.JSON) - query = db.Column(db.Text, nullable=False) - message = db.Column(db.JSON, nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) + inputs: Mapped[str] = mapped_column(db.JSON) + query: Mapped[str] = mapped_column(db.Text, nullable=False) + message: Mapped[str] = mapped_column(db.JSON, nullable=False) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) @@ -944,7 +944,7 @@ class MessageFile(Base): db.Index('message_file_created_by_idx', 'created_by') ) - id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) @@ -956,7 +956,7 @@ class MessageFile(Base): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) -class MessageAnnotation(db.Model): +class MessageAnnotation(Base): __tablename__ = 'message_annotations' __table_args__ = ( db.PrimaryKeyConstraint('id', name='message_annotation_pkey'), @@ -967,7 +967,7 @@ class MessageAnnotation(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) + conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) message_id = db.Column(StringUUID, nullable=True) question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) diff --git a/api/models/tools.py b/api/models/tools.py index 1e7421622a..a87dfad079 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -77,10 +77,10 @@ class PublishedAppTool(db.Model): return I18nObject(**json.loads(self.description)) @property - def app(self) -> App: + def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() -class ApiToolProvider(db.Model): +class ApiToolProvider(Base): """ The table stores the api providers. """ @@ -290,7 +290,7 @@ class ToolFile(Base): db.Index('tool_file_conversation_id_idx', 'conversation_id'), ) - id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()')) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()')) # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) # tenant id diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index d7538bd812..bfb9827ce2 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -3,6 +3,7 @@ import logging from httpx import get +from core.entities.provider_entities import ProviderConfig from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.common_entities import I18nObject @@ -10,8 +11,6 @@ from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, - ProviderConfig, - ToolCredentialsOption, ) from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.tool_label_manager import ToolLabelManager @@ -45,8 +44,8 @@ class ApiToolManageService: required=True, default="none", options=[ - ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), - ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), + ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), ], placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), ), @@ -79,15 +78,14 @@ class ApiToolManageService: raise ValueError(f"invalid schema: {str(e)}") @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: + def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]: """ convert schema to tool bundles :return: the list of tool bundles, description """ try: - tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) - return tool_bundles + return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) except Exception as e: raise ValueError(f"invalid schema: {str(e)}") @@ -111,7 +109,7 @@ class ApiToolManageService: raise ValueError(f"invalid schema type {schema}") # check if the provider exists - provider: ApiToolProvider = ( + provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -158,7 +156,13 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager( + tenant_id=tenant_id, + config=provider_controller.get_credentials_schema(), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.identity.name + ) + encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) db_provider.credentials_str = json.dumps(encrypted_credentials) @@ -195,21 +199,21 @@ class ApiToolManageService: return {"schema": schema} @staticmethod - def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]: """ list api tool provider tools """ - provider: ApiToolProvider = ( + provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, + ApiToolProvider.name == provider_name, ) .first() ) if provider is None: - raise ValueError(f"you have not added provider {provider}") + raise ValueError(f"you have not added provider {provider_name}") controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(controller) @@ -243,7 +247,7 @@ class ApiToolManageService: raise ValueError(f"invalid schema type {schema}") # check if the provider exists - provider: ApiToolProvider = ( + provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -282,7 +286,12 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager( + tenant_id=tenant_id, + config=provider_controller.get_credentials_schema(), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.identity.name + ) original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) @@ -310,7 +319,7 @@ class ApiToolManageService: """ delete tool provider """ - provider: ApiToolProvider = ( + provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -360,7 +369,7 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider: ApiToolProvider = ( + db_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -396,7 +405,12 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager( + tenant_id=tenant_id, + config=provider_controller.get_credentials_schema(), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.identity.name + ) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) # check if the credential has changed, save the original credential masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) @@ -444,7 +458,7 @@ class ApiToolManageService: # add icon ToolTransformService.repack_provider(user_provider) - tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) + tools = provider_controller.get_tools(tenant_id=tenant_id) for tool in tools: user_provider.tools.append( diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 1848fb2a13..513a421966 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -3,12 +3,12 @@ import logging from typing import Optional, Union from configs import dify_config +from core.entities.provider_entities import ProviderConfig from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, - ProviderConfig, ToolParameter, ToolProviderType, ) @@ -106,7 +106,10 @@ class ToolTransformService: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, provider_controller=provider_controller + tenant_id=db_provider.tenant_id, + config=provider_controller.get_credentials_schema(), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.identity.name ) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -143,7 +146,7 @@ class ToolTransformService: @staticmethod def workflow_provider_to_user_provider( - provider_controller: WorkflowToolProviderController, labels: list[str] = None + provider_controller: WorkflowToolProviderController, labels: list[str] | None = None ): """ convert provider controller to user provider @@ -174,7 +177,7 @@ class ToolTransformService: provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, decrypt_credentials: bool = True, - labels: list[str] = None, + labels: list[str] | None = None, ) -> UserToolProvider: """ convert provider controller to user provider @@ -209,7 +212,10 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, provider_controller=provider_controller + tenant_id=db_provider.tenant_id, + config=provider_controller.get_credentials_schema(), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.identity.name ) # decrypt the credentials and mask the credentials @@ -223,9 +229,9 @@ class ToolTransformService: @staticmethod def tool_to_user_tool( tool: Union[ApiToolBundle, WorkflowTool, Tool], - credentials: dict = None, - tenant_id: str = None, - labels: list[str] = None, + credentials: dict | None = None, + tenant_id: str | None = None, + labels: list[str] | None = None, ) -> UserTool: """ convert tool to user tool