From 56b7853afea4f9a71fdffbbacb294d7d41feb86d Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 30 Sep 2024 23:22:03 +0800 Subject: [PATCH] feat: compat tool provider credentials to updated data --- api/core/tools/entities/tool_entities.py | 17 ++++ api/core/tools/tool_manager.py | 47 +++++++--- api/models/tools.py | 1 + .../tools/api_tools_manage_service.py | 4 +- .../tools/builtin_tools_manage_service.py | 85 +++++++++++-------- api/services/tools/tools_transform_service.py | 2 +- .../tools/workflow_tools_manage_service.py | 6 +- 7 files changed, 111 insertions(+), 51 deletions(-) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b62707b8f7..e6ef1df79f 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,4 +1,5 @@ import base64 +import re from enum import Enum from typing import Any, Optional, Union @@ -377,3 +378,19 @@ class ToolInvokeFrom(Enum): WORKFLOW = "workflow" AGENT = "agent" + + +class ToolProviderID: + organization: str + plugin_name: str + provider_name: str + + def __str__(self) -> str: + return f"{self.organization}/{self.plugin_name}/{self.provider_name}" + + def __init__(self, value: str) -> None: + # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name + if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value): + raise ValueError("Invalid plugin id") + + self.organization, self.plugin_name, self.provider_name = value.split("/") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index fadc649e1f..54ae3a4117 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -29,7 +29,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.tool import ApiTool from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolInvokeFrom, + ToolParameter, + ToolProviderID, + ToolProviderType, +) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager @@ -143,18 +149,30 @@ class ToolManager: ), ) - # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_id, + if isinstance(provider_controller, PluginToolProviderController): + provider_id_entity = ToolProviderID(provider_id) + # get credentials + builtin_provider: BuiltinToolProvider | None = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_id) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .first() ) - .first() - ) - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + else: + builtin_provider: BuiltinToolProvider | None = ( + db.session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .first() + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") # decrypt the credentials credentials = builtin_provider.credentials @@ -505,6 +523,13 @@ class ToolManager: db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() ) + # rewrite db_builtin_providers + for db_provider in db_builtin_providers: + try: + ToolProviderID(db_provider.provider) + except Exception: + db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}" + find_db_builtin_provider = lambda provider: next( (x for x in db_builtin_providers if x.provider == provider), None ) diff --git a/api/models/tools.py b/api/models/tools.py index 1e99749b24..b0d4ea3399 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -222,6 +222,7 @@ class ToolModelInvoke(Base): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) +@deprecated class ToolConversationVariables(Base): """ store the conversation variables from tool invoke diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 82bfc8c8e5..8db7a5cbc4 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -221,7 +221,7 @@ class ApiToolManageService: labels = ToolLabelManager.get_tool_labels(controller) return [ - ToolTransformService.tool_to_user_tool( + ToolTransformService.convert_tool_entity_to_api_entity( tool_bundle, tenant_id=tenant_id, labels=labels, @@ -465,7 +465,7 @@ class ApiToolManageService: for tool in tools: user_provider.tools.append( - ToolTransformService.tool_to_user_tool( + ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels ) ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 542add9336..de29727005 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -7,6 +7,7 @@ from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.tool_entities import ToolProviderID from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager @@ -40,14 +41,7 @@ class BuiltinToolManageService: provider_identity=provider_controller.entity.identity.name, ) # check if user has added the provider - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ) - .first() - ) + builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) credentials = {} if builtin_provider is not None: @@ -58,7 +52,7 @@ class BuiltinToolManageService: result = [] for tool in tools: result.append( - ToolTransformService.tool_to_user_tool( + ToolTransformService.convert_tool_entity_to_api_entity( tool=tool, credentials=credentials, tenant_id=tenant_id, @@ -86,14 +80,7 @@ class BuiltinToolManageService: update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ) - .first() - ) + provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) try: # get provider @@ -149,14 +136,7 @@ class BuiltinToolManageService: """ get builtin tool provider credentials """ - provider_obj: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ) - .first() - ) + provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) if provider_obj is None: return {} @@ -177,14 +157,7 @@ class BuiltinToolManageService: """ delete tool provider """ - provider_obj: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ) - .first() - ) + provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) if provider_obj is None: raise ValueError(f"you have not added provider {provider_name}") @@ -227,6 +200,13 @@ class BuiltinToolManageService: db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] ) + # rewrite db_providers + for db_provider in db_providers: + try: + ToolProviderID(db_provider.provider) + except Exception: + db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}" + # find provider find_provider = lambda provider: next( filter(lambda db_provider: db_provider.provider == provider, db_providers), None @@ -258,7 +238,7 @@ class BuiltinToolManageService: tools = provider_controller.get_tools() for tool in tools: user_builtin_provider.tools.append( - ToolTransformService.tool_to_user_tool( + ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, credentials=user_builtin_provider.original_credentials, @@ -271,3 +251,40 @@ class BuiltinToolManageService: raise e return BuiltinToolProviderSort.sort(result) + + @staticmethod + def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: + try: + provider_id_entity = ToolProviderID(provider_name) + provider_name = provider_id_entity.provider_name + if provider_id_entity.organization != "langgenius": + return None + + provider_obj = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_name) | (BuiltinToolProvider.provider == provider_name), + ) + .first() + ) + + if provider_obj is None: + return None + + try: + ToolProviderID(provider_obj.provider) + except Exception: + provider_obj.provider = f"langgenius/{provider_obj.provider}/{provider_obj.provider}" + + return provider_obj + except Exception: + # it's an old provider without organization + return ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_name), + ) + .first() + ) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 3d1f361088..6f07fa6dd5 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -223,7 +223,7 @@ class ToolTransformService: return result @staticmethod - def tool_to_user_tool( + def convert_tool_entity_to_api_entity( tool: Union[ApiToolBundle, WorkflowTool, Tool], tenant_id: str, credentials: dict | None = None, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 3178fe7999..87ddaf3e67 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -210,7 +210,7 @@ class WorkflowToolManageService: ) ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ - ToolTransformService.tool_to_user_tool( + ToolTransformService.convert_tool_entity_to_api_entity( tool=tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []), tenant_id=tenant_id, @@ -299,7 +299,7 @@ class WorkflowToolManageService: "icon": json.loads(db_tool.icon), "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), - "tool": ToolTransformService.tool_to_user_tool( + "tool": ToolTransformService.convert_tool_entity_to_api_entity( tool=tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool), tenant_id=tenant_id, @@ -329,7 +329,7 @@ class WorkflowToolManageService: tool = ToolTransformService.workflow_provider_to_controller(db_tool) return [ - ToolTransformService.tool_to_user_tool( + ToolTransformService.convert_tool_entity_to_api_entity( tool=tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool), tenant_id=tenant_id,