feat: support backwards invocation

This commit is contained in:
Yeuoly 2024-07-29 18:57:34 +08:00
parent f29b44acd8
commit d52476c1c9
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
4 changed files with 77 additions and 52 deletions

View File

@ -1,10 +1,20 @@
import time
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.inner_api import api from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only from controllers.inner_api.wraps import plugin_inner_api_only
from core.plugin.entities.request import RequestInvokeLLM, RequestInvokeModeration, RequestInvokeRerank, RequestInvokeSpeech2Text, RequestInvokeTTS, RequestInvokeTextEmbedding, RequestInvokeTool from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeTextEmbedding,
RequestInvokeTool,
RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolInvokeMessage
from libs.helper import compact_generate_response from libs.helper import compact_generate_response
from models.account import Tenant from models.account import Tenant
from services.plugin.plugin_invoke_service import PluginInvokeService from services.plugin.plugin_invoke_service import PluginInvokeService
@ -18,6 +28,7 @@ class PluginInvokeLLMApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
pass pass
class PluginInvokeTextEmbeddingApi(Resource): class PluginInvokeTextEmbeddingApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -26,6 +37,7 @@ class PluginInvokeTextEmbeddingApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
pass pass
class PluginInvokeRerankApi(Resource): class PluginInvokeRerankApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -34,6 +46,7 @@ class PluginInvokeRerankApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
pass pass
class PluginInvokeTTSApi(Resource): class PluginInvokeTTSApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -42,6 +55,7 @@ class PluginInvokeTTSApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS):
pass pass
class PluginInvokeSpeech2TextApi(Resource): class PluginInvokeSpeech2TextApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -50,6 +64,7 @@ class PluginInvokeSpeech2TextApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
pass pass
class PluginInvokeModerationApi(Resource): class PluginInvokeModerationApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -58,23 +73,27 @@ class PluginInvokeModerationApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration):
pass pass
class PluginInvokeToolApi(Resource): class PluginInvokeToolApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_tenant @get_tenant
@plugin_data(payload_type=RequestInvokeTool) @plugin_data(payload_type=RequestInvokeTool)
def post(self, user_id: str, tenant_model: Tenant): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
parser = reqparse.RequestParser() def generator():
parser.add_argument('provider', type=dict, required=True, location='json') for i in range(10):
parser.add_argument('tool', type=dict, required=True, location='json') time.sleep(0.1)
parser.add_argument('parameters', type=dict, required=True, location='json') yield (
ToolInvokeMessage(
args = parser.parse_args() type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text='helloworld'),
response = PluginInvokeService.invoke_tool(
user_id, tenant_model, args['provider'], args['tool'], args['parameters']
) )
return compact_generate_response(response) .model_dump_json()
.encode()
+ b'\n\n'
)
return compact_generate_response(generator())
class PluginInvokeNodeApi(Resource): class PluginInvokeNodeApi(Resource):

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Callable, Generator
from datetime import timedelta from datetime import timedelta
from typing import Optional, Union from typing import Optional, Union
@ -91,7 +91,7 @@ class RateLimit:
class RateLimitGenerator: class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str): def __init__(self, rate_limit: RateLimit, generator: Union[Generator, Callable[[], Generator]], request_id: str):
self.rate_limit = rate_limit self.rate_limit = rate_limit
if callable(generator): if callable(generator):
self.generator = generator() self.generator = generator()

View File

@ -90,6 +90,12 @@ class ApiProviderAuthType(Enum):
raise ValueError(f'invalid mode value {value}') raise ValueError(f'invalid mode value {value}')
class ToolInvokeMessage(BaseModel): class ToolInvokeMessage(BaseModel):
class TextMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict
class MessageType(Enum): class MessageType(Enum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"
@ -103,7 +109,7 @@ class ToolInvokeMessage(BaseModel):
""" """
plain text, image url or link url plain text, image url or link url
""" """
message: Optional[Union[str, bytes, dict]] = None message: JsonMessage | TextMessage
meta: Optional[dict[str, Any]] = None meta: Optional[dict[str, Any]] = None
save_as: str = '' save_as: str = ''

View File

@ -36,8 +36,7 @@ def email(email):
if re.match(pattern, email) is not None: if re.match(pattern, email) is not None:
return email return email
error = ('{email} is not a valid email.' error = '{email} is not a valid email.'.format(email=email)
.format(email=email))
raise ValueError(error) raise ValueError(error)
@ -49,10 +48,10 @@ def uuid_value(value):
uuid_obj = uuid.UUID(value) uuid_obj = uuid.UUID(value)
return str(uuid_obj) return str(uuid_obj)
except ValueError: except ValueError:
error = ('{value} is not a valid uuid.' error = '{value} is not a valid uuid.'.format(value=value)
.format(value=value))
raise ValueError(error) raise ValueError(error)
def alphanumeric(value: str): def alphanumeric(value: str):
# check if the value is alphanumeric and underlined # check if the value is alphanumeric and underlined
if re.match(r'^[a-zA-Z0-9_]+$', value): if re.match(r'^[a-zA-Z0-9_]+$', value):
@ -60,6 +59,7 @@ def alphanumeric(value: str):
raise ValueError(f'{value} is not a valid alphanumeric value') raise ValueError(f'{value} is not a valid alphanumeric value')
def timestamp_value(timestamp): def timestamp_value(timestamp):
try: try:
int_timestamp = int(timestamp) int_timestamp = int(timestamp)
@ -67,8 +67,7 @@ def timestamp_value(timestamp):
raise ValueError raise ValueError
return int_timestamp return int_timestamp
except ValueError: except ValueError:
error = ('{timestamp} is not a valid timestamp.' error = '{timestamp} is not a valid timestamp.'.format(timestamp=timestamp)
.format(timestamp=timestamp))
raise ValueError(error) raise ValueError(error)
@ -82,8 +81,9 @@ class str_len:
def __call__(self, value): def __call__(self, value):
length = len(value) length = len(value)
if length > self.max_length: if length > self.max_length:
error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}' error = 'Invalid {arg}: {val}. {arg} cannot exceed length {length}'.format(
.format(arg=self.argument, val=value, length=self.max_length)) arg=self.argument, val=value, length=self.max_length
)
raise ValueError(error) raise ValueError(error)
return value return value
@ -91,6 +91,7 @@ class str_len:
class float_range: class float_range:
"""Restrict input to an float in a range (inclusive)""" """Restrict input to an float in a range (inclusive)"""
def __init__(self, low, high, argument='argument'): def __init__(self, low, high, argument='argument'):
self.low = low self.low = low
self.high = high self.high = high
@ -99,8 +100,9 @@ class float_range:
def __call__(self, value): def __call__(self, value):
value = _get_float(value) value = _get_float(value)
if value < self.low or value > self.high: if value < self.low or value > self.high:
error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}' error = 'Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'.format(
.format(arg=self.argument, val=value, lo=self.low, hi=self.high)) arg=self.argument, val=value, lo=self.low, hi=self.high
)
raise ValueError(error) raise ValueError(error)
return value return value
@ -115,8 +117,9 @@ class datetime_string:
try: try:
datetime.strptime(value, self.format) datetime.strptime(value, self.format)
except ValueError: except ValueError:
error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}' error = 'Invalid {arg}: {val}. {arg} must be conform to the format {format}'.format(
.format(arg=self.argument, val=value, format=self.format)) arg=self.argument, val=value, format=self.format
)
raise ValueError(error) raise ValueError(error)
return value return value
@ -128,18 +131,18 @@ def _get_float(value):
except (TypeError, ValueError): except (TypeError, ValueError):
raise ValueError('{} is not a valid float'.format(value)) raise ValueError('{} is not a valid float'.format(value))
def timezone(timezone_string): def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones(): if timezone_string and timezone_string in available_timezones():
return timezone_string return timezone_string
error = ('{timezone_string} is not a valid timezone.' error = '{timezone_string} is not a valid timezone.'.format(timezone_string=timezone_string)
.format(timezone_string=timezone_string))
raise ValueError(error) raise ValueError(error)
def generate_string(n): def generate_string(n):
letters_digits = string.ascii_letters + string.digits letters_digits = string.ascii_letters + string.digits
result = "" result = ''
for i in range(n): for i in range(n):
result += random.choice(letters_digits) result += random.choice(letters_digits)
@ -149,8 +152,8 @@ def generate_string(n):
def get_remote_ip(request) -> str: def get_remote_ip(request) -> str:
if request.headers.get('CF-Connecting-IP'): if request.headers.get('CF-Connecting-IP'):
return request.headers.get('Cf-Connecting-Ip') return request.headers.get('Cf-Connecting-Ip')
elif request.headers.getlist("X-Forwarded-For"): elif request.headers.getlist('X-Forwarded-For'):
return request.headers.getlist("X-Forwarded-For")[0] return request.headers.getlist('X-Forwarded-For')[0]
else: else:
return request.remote_addr return request.remote_addr
@ -160,19 +163,24 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest() return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict): if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json') return Response(response=json.dumps(response), status=200, mimetype='application/json')
else: else:
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200, def generate() -> Generator:
mimetype='text/event-stream') for data in response:
if isinstance(data, dict):
yield json.dumps(data).encode()
if isinstance(data, str):
yield data.encode()
else:
yield data
return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
class TokenManager: class TokenManager:
@classmethod @classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type) old_token = cls._get_current_token_for_account(account.id, token_type)
@ -182,21 +190,13 @@ class TokenManager:
cls.revoke_token(old_token, token_type) cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4()) token = str(uuid.uuid4())
token_data = { token_data = {'account_id': account.id, 'email': account.email, 'token_type': token_type}
'account_id': account.id,
'email': account.email,
'token_type': token_type
}
if additional_data: if additional_data:
token_data.update(additional_data) token_data.update(additional_data)
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS'] expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
token_key = cls._get_token_key(token, token_type) token_key = cls._get_token_key(token, token_type)
redis_client.setex( redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
token_key,
expiry_hours * 60 * 60,
json.dumps(token_data)
)
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token return token
@ -215,7 +215,7 @@ class TokenManager:
key = cls._get_token_key(token, token_type) key = cls._get_token_key(token, token_type)
token_data_json = redis_client.get(key) token_data_json = redis_client.get(key)
if token_data_json is None: if token_data_json is None:
logging.warning(f"{token_type} token {token} not found with key {key}") logging.warning(f'{token_type} token {token} not found with key {key}')
return None return None
token_data = json.loads(token_data_json) token_data = json.loads(token_data_json)
return token_data return token_data
@ -243,7 +243,7 @@ class RateLimiter:
self.time_window = time_window self.time_window = time_window
def _get_key(self, email: str) -> str: def _get_key(self, email: str) -> str:
return f"{self.prefix}:{email}" return f'{self.prefix}:{email}'
def is_rate_limited(self, email: str) -> bool: def is_rate_limited(self, email: str) -> bool:
key = self._get_key(email) key = self._get_key(email)