diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index cc2a49a41f..b3ebf81bf6 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,10 +1,20 @@ +import time from flask_restful import Resource, reqparse from controllers.console.setup import setup_required from controllers.inner_api import api from controllers.inner_api.plugin.wraps import get_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only -from core.plugin.entities.request import RequestInvokeLLM, RequestInvokeModeration, RequestInvokeRerank, RequestInvokeSpeech2Text, RequestInvokeTTS, RequestInvokeTextEmbedding, RequestInvokeTool +from core.plugin.entities.request import ( + RequestInvokeLLM, + RequestInvokeModeration, + RequestInvokeRerank, + RequestInvokeSpeech2Text, + RequestInvokeTextEmbedding, + RequestInvokeTool, + RequestInvokeTTS, +) +from core.tools.entities.tool_entities import ToolInvokeMessage from libs.helper import compact_generate_response from models.account import Tenant from services.plugin.plugin_invoke_service import PluginInvokeService @@ -18,6 +28,7 @@ class PluginInvokeLLMApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM): pass + class PluginInvokeTextEmbeddingApi(Resource): @setup_required @plugin_inner_api_only @@ -26,6 +37,7 @@ class PluginInvokeTextEmbeddingApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): pass + class PluginInvokeRerankApi(Resource): @setup_required @plugin_inner_api_only @@ -34,6 +46,7 @@ class PluginInvokeRerankApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank): pass + class PluginInvokeTTSApi(Resource): @setup_required @plugin_inner_api_only @@ -42,6 +55,7 @@ class PluginInvokeTTSApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS): pass + class PluginInvokeSpeech2TextApi(Resource): @setup_required @plugin_inner_api_only @@ -50,6 +64,7 @@ class PluginInvokeSpeech2TextApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): pass + class PluginInvokeModerationApi(Resource): @setup_required @plugin_inner_api_only @@ -58,23 +73,27 @@ class PluginInvokeModerationApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration): pass + class PluginInvokeToolApi(Resource): @setup_required @plugin_inner_api_only @get_tenant @plugin_data(payload_type=RequestInvokeTool) - def post(self, user_id: str, tenant_model: Tenant): - parser = reqparse.RequestParser() - parser.add_argument('provider', type=dict, required=True, location='json') - parser.add_argument('tool', type=dict, required=True, location='json') - parser.add_argument('parameters', type=dict, required=True, location='json') + def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool): + def generator(): + for i in range(10): + time.sleep(0.1) + yield ( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=ToolInvokeMessage.TextMessage(text='helloworld'), + ) + .model_dump_json() + .encode() + + b'\n\n' + ) - args = parser.parse_args() - - response = PluginInvokeService.invoke_tool( - user_id, tenant_model, args['provider'], args['tool'], args['parameters'] - ) - return compact_generate_response(response) + return compact_generate_response(generator()) class PluginInvokeNodeApi(Resource): diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f11e8021f0..570e3c003f 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -1,7 +1,7 @@ import logging import time import uuid -from collections.abc import Generator +from collections.abc import Callable, Generator from datetime import timedelta from typing import Optional, Union @@ -91,7 +91,7 @@ class RateLimit: class RateLimitGenerator: - def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str): + def __init__(self, rate_limit: RateLimit, generator: Union[Generator, Callable[[], Generator]], request_id: str): self.rate_limit = rate_limit if callable(generator): self.generator = generator() diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5213f35a4b..1d39e7fb00 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -90,6 +90,12 @@ class ApiProviderAuthType(Enum): raise ValueError(f'invalid mode value {value}') class ToolInvokeMessage(BaseModel): + class TextMessage(BaseModel): + text: str + + class JsonMessage(BaseModel): + json_object: dict + class MessageType(Enum): TEXT = "text" IMAGE = "image" @@ -103,7 +109,7 @@ class ToolInvokeMessage(BaseModel): """ plain text, image url or link url """ - message: Optional[Union[str, bytes, dict]] = None + message: JsonMessage | TextMessage meta: Optional[dict[str, Any]] = None save_as: str = '' diff --git a/api/libs/helper.py b/api/libs/helper.py index 15cd65dd6a..c169d6ba17 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -36,8 +36,7 @@ def email(email): if re.match(pattern, email) is not None: return email - error = ('{email} is not a valid email.' - .format(email=email)) + error = '{email} is not a valid email.'.format(email=email) raise ValueError(error) @@ -49,10 +48,10 @@ def uuid_value(value): uuid_obj = uuid.UUID(value) return str(uuid_obj) except ValueError: - error = ('{value} is not a valid uuid.' - .format(value=value)) + error = '{value} is not a valid uuid.'.format(value=value) raise ValueError(error) + def alphanumeric(value: str): # check if the value is alphanumeric and underlined if re.match(r'^[a-zA-Z0-9_]+$', value): @@ -60,6 +59,7 @@ def alphanumeric(value: str): raise ValueError(f'{value} is not a valid alphanumeric value') + def timestamp_value(timestamp): try: int_timestamp = int(timestamp) @@ -67,13 +67,12 @@ def timestamp_value(timestamp): raise ValueError return int_timestamp except ValueError: - error = ('{timestamp} is not a valid timestamp.' - .format(timestamp=timestamp)) + error = '{timestamp} is not a valid timestamp.'.format(timestamp=timestamp) raise ValueError(error) class str_len: - """ Restrict input to an integer in a range (inclusive) """ + """Restrict input to an integer in a range (inclusive)""" def __init__(self, max_length, argument='argument'): self.max_length = max_length @@ -82,15 +81,17 @@ class str_len: def __call__(self, value): length = len(value) if length > self.max_length: - error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}' - .format(arg=self.argument, val=value, length=self.max_length)) + error = 'Invalid {arg}: {val}. {arg} cannot exceed length {length}'.format( + arg=self.argument, val=value, length=self.max_length + ) raise ValueError(error) return value class float_range: - """ Restrict input to an float in a range (inclusive) """ + """Restrict input to an float in a range (inclusive)""" + def __init__(self, low, high, argument='argument'): self.low = low self.high = high @@ -99,8 +100,9 @@ class float_range: def __call__(self, value): value = _get_float(value) if value < self.low or value > self.high: - error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}' - .format(arg=self.argument, val=value, lo=self.low, hi=self.high)) + error = 'Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'.format( + arg=self.argument, val=value, lo=self.low, hi=self.high + ) raise ValueError(error) return value @@ -115,8 +117,9 @@ class datetime_string: try: datetime.strptime(value, self.format) except ValueError: - error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}' - .format(arg=self.argument, val=value, format=self.format)) + error = 'Invalid {arg}: {val}. {arg} must be conform to the format {format}'.format( + arg=self.argument, val=value, format=self.format + ) raise ValueError(error) return value @@ -128,18 +131,18 @@ def _get_float(value): except (TypeError, ValueError): raise ValueError('{} is not a valid float'.format(value)) + def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string - error = ('{timezone_string} is not a valid timezone.' - .format(timezone_string=timezone_string)) + error = '{timezone_string} is not a valid timezone.'.format(timezone_string=timezone_string) raise ValueError(error) def generate_string(n): letters_digits = string.ascii_letters + string.digits - result = "" + result = '' for i in range(n): result += random.choice(letters_digits) @@ -149,8 +152,8 @@ def generate_string(n): def get_remote_ip(request) -> str: if request.headers.get('CF-Connecting-IP'): return request.headers.get('Cf-Connecting-Ip') - elif request.headers.getlist("X-Forwarded-For"): - return request.headers.getlist("X-Forwarded-For")[0] + elif request.headers.getlist('X-Forwarded-For'): + return request.headers.getlist('X-Forwarded-For')[0] else: return request.remote_addr @@ -160,19 +163,24 @@ def generate_text_hash(text: str) -> str: return sha256(hash_text.encode()).hexdigest() -def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: +def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') else: - def generate() -> Generator: - yield from response - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') + def generate() -> Generator: + for data in response: + if isinstance(data, dict): + yield json.dumps(data).encode() + if isinstance(data, str): + yield data.encode() + else: + yield data + + return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') class TokenManager: - @classmethod def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: old_token = cls._get_current_token_for_account(account.id, token_type) @@ -182,21 +190,13 @@ class TokenManager: cls.revoke_token(old_token, token_type) token = str(uuid.uuid4()) - token_data = { - 'account_id': account.id, - 'email': account.email, - 'token_type': token_type - } + token_data = {'account_id': account.id, 'email': account.email, 'token_type': token_type} if additional_data: token_data.update(additional_data) expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS'] token_key = cls._get_token_key(token, token_type) - redis_client.setex( - token_key, - expiry_hours * 60 * 60, - json.dumps(token_data) - ) + redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) return token @@ -215,7 +215,7 @@ class TokenManager: key = cls._get_token_key(token, token_type) token_data_json = redis_client.get(key) if token_data_json is None: - logging.warning(f"{token_type} token {token} not found with key {key}") + logging.warning(f'{token_type} token {token} not found with key {key}') return None token_data = json.loads(token_data_json) return token_data @@ -243,7 +243,7 @@ class RateLimiter: self.time_window = time_window def _get_key(self, email: str) -> str: - return f"{self.prefix}:{email}" + return f'{self.prefix}:{email}' def is_rate_limited(self, email: str) -> bool: key = self._get_key(email)