diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 31c3a996e1..1f6756a4ee 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ node_id: str - inputs: dict + inputs: Mapping single_iteration_run: Optional[SingleIterationRunEntity] = None diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 011c51aad4..c12134a97d 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -219,7 +219,7 @@ class ModelInstance: input_type=input_type, ) - def get_text_embedding_num_tokens(self, texts: list[str]) -> int: + def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: """ Get number of tokens for text embedding @@ -235,7 +235,7 @@ class ModelInstance: model=self.model, credentials=self.credentials, texts=texts, - ) + )[0] # TODO: fix this, this is only for temporary compatibility with old def invoke_rerank( self, diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 6da5db3883..4ff1c9032a 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -52,7 +52,7 @@ class TextEmbeddingModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]: """ Get number of tokens for given prompt messages diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index ef91fa6046..17ce71d01a 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -71,7 +71,15 @@ class PluginModelProviderEntity(BaseModel): declaration: ProviderEntity = Field(description="The declaration of the model provider.") -class PluginNumTokensResponse(BaseModel): +class PluginTextEmbeddingNumTokensResponse(BaseModel): + """ + Response for number of tokens. + """ + + num_tokens: list[int] = Field(description="The number of tokens.") + + +class PluginLLMNumTokensResponse(BaseModel): """ Response for number of tokens. """ diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py index 5968c2cf1f..ff90b6ea21 100644 --- a/api/core/plugin/manager/base.py +++ b/api/core/plugin/manager/base.py @@ -17,6 +17,14 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError +from core.plugin.manager.exc import ( + PluginDaemonBadRequestError, + PluginDaemonInternalServerError, + PluginDaemonNotFoundError, + PluginDaemonUnauthorizedError, + PluginPermissionDeniedError, + PluginUniqueIdentifierError, +) plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY @@ -190,17 +198,32 @@ class BasePluginManager: """ args = args or {} - if error_type == PluginDaemonInnerError.__name__: - raise PluginDaemonInnerError(code=-500, message=message) - elif error_type == InvokeRateLimitError.__name__: - raise InvokeRateLimitError(description=args.get("description")) - elif error_type == InvokeAuthorizationError.__name__: - raise InvokeAuthorizationError(description=args.get("description")) - elif error_type == InvokeBadRequestError.__name__: - raise InvokeBadRequestError(description=args.get("description")) - elif error_type == InvokeConnectionError.__name__: - raise InvokeConnectionError(description=args.get("description")) - elif error_type == InvokeServerUnavailableError.__name__: - raise InvokeServerUnavailableError(description=args.get("description")) - else: - raise ValueError(f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}") + match error_type: + case PluginDaemonInnerError.__name__: + raise PluginDaemonInnerError(code=-500, message=message) + case InvokeRateLimitError.__name__: + raise InvokeRateLimitError(description=args.get("description")) + case InvokeAuthorizationError.__name__: + raise InvokeAuthorizationError(description=args.get("description")) + case InvokeBadRequestError.__name__: + raise InvokeBadRequestError(description=args.get("description")) + case InvokeConnectionError.__name__: + raise InvokeConnectionError(description=args.get("description")) + case InvokeServerUnavailableError.__name__: + raise InvokeServerUnavailableError(description=args.get("description")) + case PluginDaemonInternalServerError.__name__: + raise PluginDaemonInternalServerError(description=message) + case PluginDaemonBadRequestError.__name__: + raise PluginDaemonBadRequestError(description=message) + case PluginDaemonNotFoundError.__name__: + raise PluginDaemonNotFoundError(description=message) + case PluginUniqueIdentifierError.__name__: + raise PluginUniqueIdentifierError(description=message) + case PluginDaemonUnauthorizedError.__name__: + raise PluginDaemonUnauthorizedError(description=message) + case PluginPermissionDeniedError.__name__: + raise PluginPermissionDeniedError(description=message) + case _: + raise ValueError( + f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}" + ) diff --git a/api/core/plugin/manager/exc.py b/api/core/plugin/manager/exc.py new file mode 100644 index 0000000000..8c2d78b526 --- /dev/null +++ b/api/core/plugin/manager/exc.py @@ -0,0 +1,33 @@ +class PluginDaemonError(Exception): + """Base class for all plugin daemon errors.""" + + def __init__(self, description: str) -> None: + self.description = description + + +class PluginDaemonInternalServerError(PluginDaemonError): + description: str = "Internal Server Error" + + +class PluginDaemonBadRequestError(PluginDaemonError): + description: str = "Bad Request" + + +class PluginDaemonNotFoundError(PluginDaemonError): + description: str = "Not Found" + + +class PluginUniqueIdentifierError(PluginDaemonError): + description: str = "Unique Identifier Error" + + +class PluginNotFoundError(PluginDaemonError): + description: str = "Plugin Not Found" + + +class PluginDaemonUnauthorizedError(PluginDaemonError): + description: str = "Unauthorized" + + +class PluginPermissionDeniedError(PluginDaemonError): + description: str = "Permission Denied" diff --git a/api/core/plugin/manager/model.py b/api/core/plugin/manager/model.py index 4188148812..7842d624a3 100644 --- a/api/core/plugin/manager/model.py +++ b/api/core/plugin/manager/model.py @@ -11,10 +11,11 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, + PluginLLMNumTokensResponse, PluginModelProviderEntity, PluginModelSchemaEntity, - PluginNumTokensResponse, PluginStringResultResponse, + PluginTextEmbeddingNumTokensResponse, PluginVoicesResponse, ) from core.plugin.manager.base import BasePluginManager @@ -201,7 +202,7 @@ class PluginModelManager(BasePluginManager): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", - type=PluginNumTokensResponse, + type=PluginLLMNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -277,14 +278,14 @@ class PluginModelManager(BasePluginManager): model: str, credentials: dict, texts: list[str], - ) -> int: + ) -> list[int]: """ Get number of tokens for text embedding """ response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", - type=PluginNumTokensResponse, + type=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -306,7 +307,7 @@ class PluginModelManager(BasePluginManager): for resp in response: return resp.num_tokens - return 0 + return [] def invoke_rerank( self,