fix: invoke tool streamingly
This commit is contained in:
parent
cf4e9f317e
commit
886a160115
@ -4,8 +4,8 @@ from typing import Optional, Union
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
@ -143,7 +143,7 @@ class ProviderConfig(BasicProviderConfig):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
scope: AppSelectorScope | ModelConfigScope | None
|
||||
scope: AppSelectorScope | ModelConfigScope | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
|
@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
class ToolProviderCredentialsCacheType(Enum):
|
||||
PROVIDER = "tool_provider"
|
||||
ENDPOINT = "endpoint"
|
||||
|
||||
class ToolProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
||||
|
@ -1,10 +1,11 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool.tool import ToolParameter
|
||||
|
||||
|
||||
@ -14,7 +15,7 @@ class UserTool(BaseModel):
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal[
|
||||
'builtin', 'api', 'workflow'
|
||||
@ -32,8 +33,8 @@ class UserToolProvider(BaseModel):
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
tools: list[UserTool] = None
|
||||
labels: list[str] = None
|
||||
tools: list[UserTool] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
|
@ -25,7 +25,7 @@ class ToolLabelEnum(Enum):
|
||||
UTILITIES = 'utilities'
|
||||
OTHER = 'other'
|
||||
|
||||
class ToolProviderType(Enum):
|
||||
class ToolProviderType(str, Enum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
@ -181,7 +181,7 @@ class ToolParameter(BaseModel):
|
||||
if options:
|
||||
option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
|
||||
else:
|
||||
option_objs = None
|
||||
option_objs = []
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US='', zh_Hans=''),
|
||||
|
@ -1,21 +1,23 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ProviderConfig,
|
||||
ToolCredentialsOption,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
|
||||
@ -25,8 +27,8 @@ class ApiToolProviderController(ToolProviderController):
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
||||
ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
||||
ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
||||
],
|
||||
default='none',
|
||||
help=I18nObject(
|
||||
@ -67,9 +69,9 @@ class ApiToolProviderController(ToolProviderController):
|
||||
zh_Hans='api key header 的前缀'
|
||||
),
|
||||
options=[
|
||||
ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
|
||||
ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
|
||||
ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
|
||||
ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
|
||||
ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
|
||||
ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
|
||||
]
|
||||
)
|
||||
}
|
||||
@ -96,6 +98,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
},
|
||||
'credentials_schema': credentials_schema,
|
||||
'provider_id': db_provider.id or '',
|
||||
'tenant_id': db_provider.tenant_id or '',
|
||||
})
|
||||
|
||||
@property
|
||||
@ -142,7 +145,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
|
||||
def get_tools(self, tenant_id: str) -> list[ApiTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
@ -153,7 +156,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
tools: list[Tool] = []
|
||||
tools: list[ApiTool] = []
|
||||
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
@ -179,7 +182,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
self.get_tools()
|
||||
self.get_tools(self.tenant_id)
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
|
@ -39,7 +39,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
super().__init__(**{
|
||||
'identity': provider_yaml['identity'],
|
||||
'credentials_schema': provider_yaml.get('credentials_for_provider', None),
|
||||
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
|
||||
})
|
||||
|
||||
def _get_builtin_tools(self) -> list[BuiltinTool]:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.tool_entities import (
|
||||
@ -17,6 +17,8 @@ class ToolProviderController(BaseModel, ABC):
|
||||
tools: list[Tool] = Field(default_factory=list)
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True)
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
@ -206,7 +206,16 @@ class Tool(BaseModel, ABC):
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
|
||||
return result
|
||||
if isinstance(result, ToolInvokeMessage):
|
||||
def single_generator():
|
||||
yield result
|
||||
return single_generator()
|
||||
elif isinstance(result, list):
|
||||
def generator():
|
||||
yield from result
|
||||
return generator()
|
||||
else:
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
@ -223,7 +232,7 @@ class Tool(BaseModel, ABC):
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
|
||||
pass
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
|
@ -116,7 +116,12 @@ class ToolManager:
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = cls.get_builtin_provider(provider_id)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credentials_schema(),
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.identity.name
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
@ -135,7 +140,12 @@ class ToolManager:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=api_provider.get_credentials_schema(),
|
||||
provider_type=api_provider.provider_type.value,
|
||||
provider_identity=api_provider.identity.name
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
|
||||
@ -513,7 +523,12 @@ class ToolManager:
|
||||
provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credentials_schema(),
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.identity.name
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
|
@ -1,23 +1,25 @@
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ProviderConfig,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class ToolConfigurationManager(BaseModel):
|
||||
tenant_id: str
|
||||
provider_controller: ToolProviderController
|
||||
config: Mapping[str, BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
@ -34,9 +36,9 @@ class ToolConfigurationManager(BaseModel):
|
||||
credentials = self._deep_copy(credentials)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
fields = self.config
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
|
||||
credentials[field_name] = encrypted
|
||||
@ -52,9 +54,9 @@ class ToolConfigurationManager(BaseModel):
|
||||
credentials = self._deep_copy(credentials)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
fields = self.config
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
if len(credentials[field_name]) > 6:
|
||||
credentials[field_name] = \
|
||||
@ -74,7 +76,7 @@ class ToolConfigurationManager(BaseModel):
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
identity_id=f'{self.provider_type}.{self.provider_identity}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
@ -82,9 +84,9 @@ class ToolConfigurationManager(BaseModel):
|
||||
return cached_credentials
|
||||
credentials = self._deep_copy(credentials)
|
||||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
fields = self.config
|
||||
for field_name, field in fields.items():
|
||||
if field.type == ProviderConfig.Type.SECRET_INPUT:
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
try:
|
||||
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
||||
@ -97,7 +99,7 @@ class ToolConfigurationManager(BaseModel):
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
identity_id=f'{self.provider_type}.{self.provider_identity}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cache.delete()
|
||||
|
@ -16,7 +16,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
@ -173,7 +173,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
@ -189,7 +189,8 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
@ -255,7 +256,7 @@ class ApiBasedToolSchemaParser:
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict | None = None, warning: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
from typing import Any, Iterable, cast
|
||||
|
||||
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
@ -158,14 +158,17 @@ class ToolNode(BaseNode):
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
result = list(messages)
|
||||
|
||||
# extract plain text and files
|
||||
files = self._extract_tool_response_binary(messages)
|
||||
plain_text = self._extract_tool_response_text(messages)
|
||||
json = self._extract_tool_response_json(messages)
|
||||
files = self._extract_tool_response_binary(result)
|
||||
plain_text = self._extract_tool_response_text(result)
|
||||
json = self._extract_tool_response_json(result)
|
||||
|
||||
return plain_text, files, json
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
|
||||
def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
@ -215,7 +218,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
return result
|
||||
|
||||
def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
|
||||
def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Extract tool response text
|
||||
"""
|
||||
@ -230,7 +233,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
return '\n'.join(result)
|
||||
|
||||
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
|
||||
def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
|
||||
result: list[dict] = []
|
||||
for message in tool_response:
|
||||
if message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
|
@ -7,7 +7,7 @@ from typing import Optional
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
@ -495,14 +495,14 @@ class InstalledApp(db.Model):
|
||||
return tenant
|
||||
|
||||
|
||||
class Conversation(db.Model):
|
||||
class Conversation(Base):
|
||||
__tablename__ = 'conversations'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='conversation_pkey'),
|
||||
db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
app_model_config_id = db.Column(StringUUID, nullable=True)
|
||||
model_provider = db.Column(db.String(255), nullable=True)
|
||||
@ -526,8 +526,8 @@ class Conversation(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
|
||||
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
|
||||
messages: Mapped[list["Message"]] = relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
|
||||
message_annotations: Mapped[list["MessageAnnotation"]] = relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
|
||||
|
||||
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
||||
|
||||
@ -660,10 +660,10 @@ class Message(Base):
|
||||
model_provider = db.Column(db.String(255), nullable=True)
|
||||
model_id = db.Column(db.String(255), nullable=True)
|
||||
override_model_configs = db.Column(db.Text)
|
||||
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
|
||||
inputs = db.Column(db.JSON)
|
||||
query = db.Column(db.Text, nullable=False)
|
||||
message = db.Column(db.JSON, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
|
||||
inputs: Mapped[str] = mapped_column(db.JSON)
|
||||
query: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
message: Mapped[str] = mapped_column(db.JSON, nullable=False)
|
||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
@ -944,7 +944,7 @@ class MessageFile(Base):
|
||||
db.Index('message_file_created_by_idx', 'created_by')
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
@ -956,7 +956,7 @@ class MessageFile(Base):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
|
||||
class MessageAnnotation(db.Model):
|
||||
class MessageAnnotation(Base):
|
||||
__tablename__ = 'message_annotations'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='message_annotation_pkey'),
|
||||
@ -967,7 +967,7 @@ class MessageAnnotation(db.Model):
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
|
||||
message_id = db.Column(StringUUID, nullable=True)
|
||||
question = db.Column(db.Text, nullable=True)
|
||||
content = db.Column(db.Text, nullable=False)
|
||||
|
@ -77,10 +77,10 @@ class PublishedAppTool(db.Model):
|
||||
return I18nObject(**json.loads(self.description))
|
||||
|
||||
@property
|
||||
def app(self) -> App:
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||
|
||||
class ApiToolProvider(db.Model):
|
||||
class ApiToolProvider(Base):
|
||||
"""
|
||||
The table stores the api providers.
|
||||
"""
|
||||
@ -290,7 +290,7 @@ class ToolFile(Base):
|
||||
db.Index('tool_file_conversation_id_idx', 'conversation_id'),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
# conversation user id
|
||||
user_id: Mapped[str] = mapped_column(StringUUID)
|
||||
# tenant id
|
||||
|
@ -3,6 +3,7 @@ import logging
|
||||
|
||||
from httpx import get
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -10,8 +11,6 @@ from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
ProviderConfig,
|
||||
ToolCredentialsOption,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
@ -45,8 +44,8 @@ class ApiToolManageService:
|
||||
required=True,
|
||||
default="none",
|
||||
options=[
|
||||
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
],
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
@ -79,15 +78,14 @@ class ApiToolManageService:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
return tool_bundles
|
||||
return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@ -111,7 +109,7 @@ class ApiToolManageService:
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = (
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@ -158,7 +156,13 @@ class ApiToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
)
|
||||
|
||||
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
||||
|
||||
@ -195,21 +199,21 @@ class ApiToolManageService:
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider = (
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
@ -243,7 +247,7 @@ class ApiToolManageService:
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = (
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@ -282,7 +286,12 @@ class ApiToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
)
|
||||
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
@ -310,7 +319,7 @@ class ApiToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider: ApiToolProvider = (
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@ -360,7 +369,7 @@ class ApiToolManageService:
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider: ApiToolProvider = (
|
||||
db_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@ -396,7 +405,12 @@ class ApiToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
@ -444,7 +458,7 @@ class ApiToolManageService:
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
tools = provider_controller.get_tools(tenant_id=tenant_id)
|
||||
|
||||
for tool in tools:
|
||||
user_provider.tools.append(
|
||||
|
@ -3,12 +3,12 @@ import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ProviderConfig,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
@ -106,7 +106,10 @@ class ToolTransformService:
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
@ -143,7 +146,7 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] = None
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
@ -174,7 +177,7 @@ class ToolTransformService:
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: list[str] = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
@ -209,7 +212,10 @@ class ToolTransformService:
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.identity.name
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
@ -223,9 +229,9 @@ class ToolTransformService:
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: dict = None,
|
||||
tenant_id: str = None,
|
||||
labels: list[str] = None,
|
||||
credentials: dict | None = None,
|
||||
tenant_id: str | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> UserTool:
|
||||
"""
|
||||
convert tool to user tool
|
||||
|
Loading…
Reference in New Issue
Block a user