Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly 2024-11-22 18:19:27 +08:00
commit f69d5caa14
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
7 changed files with 89 additions and 24 deletions

View File

@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
""" """
node_id: str node_id: str
inputs: dict inputs: Mapping
single_iteration_run: Optional[SingleIterationRunEntity] = None single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@ -219,7 +219,7 @@ class ModelInstance:
input_type=input_type, input_type=input_type,
) )
def get_text_embedding_num_tokens(self, texts: list[str]) -> int: def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
""" """
Get number of tokens for text embedding Get number of tokens for text embedding
@ -235,7 +235,7 @@ class ModelInstance:
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
texts=texts, texts=texts,
) )[0] # TODO: fix this, this is only for temporary compatibility with old
def invoke_rerank( def invoke_rerank(
self, self,

View File

@ -52,7 +52,7 @@ class TextEmbeddingModel(AIModel):
except Exception as e: except Exception as e:
raise self._transform_invoke_error(e) raise self._transform_invoke_error(e)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
""" """
Get number of tokens for given prompt messages Get number of tokens for given prompt messages

View File

@ -71,7 +71,15 @@ class PluginModelProviderEntity(BaseModel):
declaration: ProviderEntity = Field(description="The declaration of the model provider.") declaration: ProviderEntity = Field(description="The declaration of the model provider.")
class PluginNumTokensResponse(BaseModel): class PluginTextEmbeddingNumTokensResponse(BaseModel):
"""
Response for number of tokens.
"""
num_tokens: list[int] = Field(description="The number of tokens.")
class PluginLLMNumTokensResponse(BaseModel):
""" """
Response for number of tokens. Response for number of tokens.
""" """

View File

@ -17,6 +17,14 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError, InvokeServerUnavailableError,
) )
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
from core.plugin.manager.exc import (
PluginDaemonBadRequestError,
PluginDaemonInternalServerError,
PluginDaemonNotFoundError,
PluginDaemonUnauthorizedError,
PluginPermissionDeniedError,
PluginUniqueIdentifierError,
)
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL
plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
@ -190,17 +198,32 @@ class BasePluginManager:
""" """
args = args or {} args = args or {}
if error_type == PluginDaemonInnerError.__name__: match error_type:
raise PluginDaemonInnerError(code=-500, message=message) case PluginDaemonInnerError.__name__:
elif error_type == InvokeRateLimitError.__name__: raise PluginDaemonInnerError(code=-500, message=message)
raise InvokeRateLimitError(description=args.get("description")) case InvokeRateLimitError.__name__:
elif error_type == InvokeAuthorizationError.__name__: raise InvokeRateLimitError(description=args.get("description"))
raise InvokeAuthorizationError(description=args.get("description")) case InvokeAuthorizationError.__name__:
elif error_type == InvokeBadRequestError.__name__: raise InvokeAuthorizationError(description=args.get("description"))
raise InvokeBadRequestError(description=args.get("description")) case InvokeBadRequestError.__name__:
elif error_type == InvokeConnectionError.__name__: raise InvokeBadRequestError(description=args.get("description"))
raise InvokeConnectionError(description=args.get("description")) case InvokeConnectionError.__name__:
elif error_type == InvokeServerUnavailableError.__name__: raise InvokeConnectionError(description=args.get("description"))
raise InvokeServerUnavailableError(description=args.get("description")) case InvokeServerUnavailableError.__name__:
else: raise InvokeServerUnavailableError(description=args.get("description"))
raise ValueError(f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}") case PluginDaemonInternalServerError.__name__:
raise PluginDaemonInternalServerError(description=message)
case PluginDaemonBadRequestError.__name__:
raise PluginDaemonBadRequestError(description=message)
case PluginDaemonNotFoundError.__name__:
raise PluginDaemonNotFoundError(description=message)
case PluginUniqueIdentifierError.__name__:
raise PluginUniqueIdentifierError(description=message)
case PluginDaemonUnauthorizedError.__name__:
raise PluginDaemonUnauthorizedError(description=message)
case PluginPermissionDeniedError.__name__:
raise PluginPermissionDeniedError(description=message)
case _:
raise ValueError(
f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}"
)

View File

@ -0,0 +1,33 @@
class PluginDaemonError(Exception):
"""Base class for all plugin daemon errors."""
def __init__(self, description: str) -> None:
self.description = description
class PluginDaemonInternalServerError(PluginDaemonError):
description: str = "Internal Server Error"
class PluginDaemonBadRequestError(PluginDaemonError):
description: str = "Bad Request"
class PluginDaemonNotFoundError(PluginDaemonError):
description: str = "Not Found"
class PluginUniqueIdentifierError(PluginDaemonError):
description: str = "Unique Identifier Error"
class PluginNotFoundError(PluginDaemonError):
description: str = "Plugin Not Found"
class PluginDaemonUnauthorizedError(PluginDaemonError):
description: str = "Unauthorized"
class PluginPermissionDeniedError(PluginDaemonError):
description: str = "Permission Denied"

View File

@ -11,10 +11,11 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse, PluginBasicBooleanResponse,
PluginDaemonInnerError, PluginDaemonInnerError,
PluginLLMNumTokensResponse,
PluginModelProviderEntity, PluginModelProviderEntity,
PluginModelSchemaEntity, PluginModelSchemaEntity,
PluginNumTokensResponse,
PluginStringResultResponse, PluginStringResultResponse,
PluginTextEmbeddingNumTokensResponse,
PluginVoicesResponse, PluginVoicesResponse,
) )
from core.plugin.manager.base import BasePluginManager from core.plugin.manager.base import BasePluginManager
@ -201,7 +202,7 @@ class PluginModelManager(BasePluginManager):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type=PluginNumTokensResponse, type=PluginLLMNumTokensResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@ -277,14 +278,14 @@ class PluginModelManager(BasePluginManager):
model: str, model: str,
credentials: dict, credentials: dict,
texts: list[str], texts: list[str],
) -> int: ) -> list[int]:
""" """
Get number of tokens for text embedding Get number of tokens for text embedding
""" """
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type=PluginNumTokensResponse, type=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@ -306,7 +307,7 @@ class PluginModelManager(BasePluginManager):
for resp in response: for resp in response:
return resp.num_tokens return resp.num_tokens
return 0 return []
def invoke_rerank( def invoke_rerank(
self, self,