fix: invoke tool streamingly

This commit is contained in:
Yeuoly 2024-08-30 18:11:38 +08:00
parent cf4e9f317e
commit 886a160115
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
16 changed files with 149 additions and 92 deletions

View File

@ -4,8 +4,8 @@ from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
from models.provider import ProviderQuotaType
@ -143,7 +143,7 @@ class ProviderConfig(BasicProviderConfig):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelConfigScope | None
scope: AppSelectorScope | ModelConfigScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None

View File

@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):

View File

@ -1,10 +1,11 @@
from typing import Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool.tool import ToolParameter
@ -14,7 +15,7 @@ class UserTool(BaseModel):
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
labels: list[str] = None
labels: list[str] = Field(default_factory=list)
UserToolProviderTypeLiteral = Optional[Literal[
'builtin', 'api', 'workflow'
@ -32,8 +33,8 @@ class UserToolProvider(BaseModel):
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
tools: list[UserTool] = None
labels: list[str] = None
tools: list[UserTool] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
def to_dict(self) -> dict:
# -------------

View File

@ -25,7 +25,7 @@ class ToolLabelEnum(Enum):
UTILITIES = 'utilities'
OTHER = 'other'
class ToolProviderType(Enum):
class ToolProviderType(str, Enum):
"""
Enum class for tool provider
"""
@ -181,7 +181,7 @@ class ToolParameter(BaseModel):
if options:
option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
else:
option_objs = None
option_objs = []
return cls(
name=name,
label=I18nObject(en_US='', zh_Hans=''),

View File

@ -1,21 +1,23 @@
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ProviderConfig,
ToolCredentialsOption,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.tools import ApiToolProvider
class ApiToolProviderController(ToolProviderController):
provider_id: str
tenant_id: str
tools: list[ApiTool] = Field(default_factory=list)
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
@ -25,8 +27,8 @@ class ApiToolProviderController(ToolProviderController):
required=True,
type=ProviderConfig.Type.SELECT,
options=[
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
],
default='none',
help=I18nObject(
@ -67,9 +69,9 @@ class ApiToolProviderController(ToolProviderController):
zh_Hans='api key header 的前缀'
),
options=[
ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
]
)
}
@ -96,6 +98,7 @@ class ApiToolProviderController(ToolProviderController):
},
'credentials_schema': credentials_schema,
'provider_id': db_provider.id or '',
'tenant_id': db_provider.tenant_id or '',
})
@property
@ -142,7 +145,7 @@ class ApiToolProviderController(ToolProviderController):
return self.tools
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
def get_tools(self, tenant_id: str) -> list[ApiTool]:
"""
fetch tools from database
@ -153,7 +156,7 @@ class ApiToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
tools: list[Tool] = []
tools: list[ApiTool] = []
# get tenant api providers
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
@ -179,7 +182,7 @@ class ApiToolProviderController(ToolProviderController):
:return: the tool
"""
if self.tools is None:
self.get_tools()
self.get_tools(self.tenant_id)
for tool in self.tools:
if tool.identity.name == tool_name:

View File

@ -39,7 +39,7 @@ class BuiltinToolProviderController(ToolProviderController):
super().__init__(**{
'identity': provider_yaml['identity'],
'credentials_schema': provider_yaml.get('credentials_for_provider', None),
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
})
def _get_builtin_tools(self) -> list[BuiltinTool]:

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.tool_entities import (
@ -17,6 +17,8 @@ class ToolProviderController(BaseModel, ABC):
tools: list[Tool] = Field(default_factory=list)
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
model_config = ConfigDict(validate_assignment=True)
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
"""
returns the credentials schema of the provider

View File

@ -206,7 +206,16 @@ class Tool(BaseModel, ABC):
tool_parameters=tool_parameters,
)
return result
if isinstance(result, ToolInvokeMessage):
def single_generator():
yield result
return single_generator()
elif isinstance(result, list):
def generator():
yield from result
return generator()
else:
return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
@ -223,7 +232,7 @@ class Tool(BaseModel, ABC):
return result
@abstractmethod
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
pass
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

View File

@ -116,7 +116,12 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = cls.get_builtin_provider(provider_id)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value,
provider_identity=controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
@ -135,7 +140,12 @@ class ToolManager:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=api_provider.get_credentials_schema(),
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
@ -513,7 +523,12 @@ class ToolManager:
provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value,
provider_identity=controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

View File

@ -1,23 +1,25 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
ProviderConfig,
ToolParameter,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
class ToolConfigurationManager(BaseModel):
tenant_id: str
provider_controller: ToolProviderController
config: Mapping[str, BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
"""
@ -34,9 +36,9 @@ class ToolConfigurationManager(BaseModel):
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
fields = self.config
for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
@ -52,9 +54,9 @@ class ToolConfigurationManager(BaseModel):
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
fields = self.config
for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
@ -74,7 +76,7 @@ class ToolConfigurationManager(BaseModel):
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
identity_id=f'{self.provider_type}.{self.provider_identity}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
@ -82,9 +84,9 @@ class ToolConfigurationManager(BaseModel):
return cached_credentials
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
fields = self.config
for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT:
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
try:
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
@ -97,7 +99,7 @@ class ToolConfigurationManager(BaseModel):
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
identity_id=f'{self.provider_type}.{self.provider_identity}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cache.delete()

View File

@ -16,7 +16,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -173,7 +173,7 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
@staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@ -189,7 +189,8 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
"""
parse swagger to openapi
@ -255,7 +256,7 @@ class ApiBasedToolSchemaParser:
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
def auto_parse_to_tool_bundle(content: str, extra_info: dict | None = None, warning: dict | None = None) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle

View File

@ -1,6 +1,6 @@
from collections.abc import Generator, Sequence
from os import path
from typing import Any, cast
from typing import Any, Iterable, cast
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -158,14 +158,17 @@ class ToolNode(BaseNode):
tenant_id=self.tenant_id,
conversation_id=None,
)
result = list(messages)
# extract plain text and files
files = self._extract_tool_response_binary(messages)
plain_text = self._extract_tool_response_text(messages)
json = self._extract_tool_response_json(messages)
files = self._extract_tool_response_binary(result)
plain_text = self._extract_tool_response_text(result)
json = self._extract_tool_response_json(result)
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
"""
Extract tool response binary
"""
@ -215,7 +218,7 @@ class ToolNode(BaseNode):
return result
def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
@ -230,7 +233,7 @@ class ToolNode(BaseNode):
return '\n'.join(result)
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
result: list[dict] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.JSON:

View File

@ -7,7 +7,7 @@ from typing import Optional
from flask import request
from flask_login import UserMixin
from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, relationship
from configs import dify_config
from core.file.tool_file_parser import ToolFileParser
@ -495,14 +495,14 @@ class InstalledApp(db.Model):
return tenant
class Conversation(db.Model):
class Conversation(Base):
__tablename__ = 'conversations'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='conversation_pkey'),
db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True)
@ -526,8 +526,8 @@ class Conversation(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
messages: Mapped[list["Message"]] = relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations: Mapped[list["MessageAnnotation"]] = relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@ -660,10 +660,10 @@ class Message(Base):
model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text)
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
inputs = db.Column(db.JSON)
query = db.Column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
inputs: Mapped[str] = mapped_column(db.JSON)
query: Mapped[str] = mapped_column(db.Text, nullable=False)
message: Mapped[str] = mapped_column(db.JSON, nullable=False)
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
@ -944,7 +944,7 @@ class MessageFile(Base):
db.Index('message_file_created_by_idx', 'created_by')
)
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
@ -956,7 +956,7 @@ class MessageFile(Base):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class MessageAnnotation(db.Model):
class MessageAnnotation(Base):
__tablename__ = 'message_annotations'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='message_annotation_pkey'),
@ -967,7 +967,7 @@ class MessageAnnotation(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(StringUUID, nullable=True)
question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False)

View File

@ -77,10 +77,10 @@ class PublishedAppTool(db.Model):
return I18nObject(**json.loads(self.description))
@property
def app(self) -> App:
def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first()
class ApiToolProvider(db.Model):
class ApiToolProvider(Base):
"""
The table stores the api providers.
"""
@ -290,7 +290,7 @@ class ToolFile(Base):
db.Index('tool_file_conversation_id_idx', 'conversation_id'),
)
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id

View File

@ -3,6 +3,7 @@ import logging
from httpx import get
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject
@ -10,8 +11,6 @@ from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
ProviderConfig,
ToolCredentialsOption,
)
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
@ -45,8 +44,8 @@ class ApiToolManageService:
required=True,
default="none",
options=[
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
],
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
),
@ -79,15 +78,14 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
"""
convert schema to tool bundles
:return: the list of tool bundles, description
"""
try:
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
return tool_bundles
return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@ -111,7 +109,7 @@ class ApiToolManageService:
raise ValueError(f"invalid schema type {schema}")
# check if the provider exists
provider: ApiToolProvider = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@ -158,7 +156,13 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
@ -195,21 +199,21 @@ class ApiToolManageService:
return {"schema": schema}
@staticmethod
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
"""
list api tool provider tools
"""
provider: ApiToolProvider = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is None:
raise ValueError(f"you have not added provider {provider}")
raise ValueError(f"you have not added provider {provider_name}")
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(controller)
@ -243,7 +247,7 @@ class ApiToolManageService:
raise ValueError(f"invalid schema type {schema}")
# check if the provider exists
provider: ApiToolProvider = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@ -282,7 +286,12 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
@ -310,7 +319,7 @@ class ApiToolManageService:
"""
delete tool provider
"""
provider: ApiToolProvider = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@ -360,7 +369,7 @@ class ApiToolManageService:
if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}")
db_provider: ApiToolProvider = (
db_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@ -396,7 +405,12 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
@ -444,7 +458,7 @@ class ApiToolManageService:
# add icon
ToolTransformService.repack_provider(user_provider)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
tools = provider_controller.get_tools(tenant_id=tenant_id)
for tool in tools:
user_provider.tools.append(

View File

@ -3,12 +3,12 @@ import logging
from typing import Optional, Union
from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ProviderConfig,
ToolParameter,
ToolProviderType,
)
@ -106,7 +106,10 @@ class ToolTransformService:
# init tool configuration
tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
@ -143,7 +146,7 @@ class ToolTransformService:
@staticmethod
def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController, labels: list[str] = None
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
):
"""
convert provider controller to user provider
@ -174,7 +177,7 @@ class ToolTransformService:
provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider,
decrypt_credentials: bool = True,
labels: list[str] = None,
labels: list[str] | None = None,
) -> UserToolProvider:
"""
convert provider controller to user provider
@ -209,7 +212,10 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
# decrypt the credentials and mask the credentials
@ -223,9 +229,9 @@ class ToolTransformService:
@staticmethod
def tool_to_user_tool(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None,
tenant_id: str = None,
labels: list[str] = None,
credentials: dict | None = None,
tenant_id: str | None = None,
labels: list[str] | None = None,
) -> UserTool:
"""
convert tool to user tool