feat: support backwards invocation

This commit is contained in:
Yeuoly 2024-07-29 18:57:34 +08:00
parent f29b44acd8
commit d52476c1c9
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
4 changed files with 77 additions and 52 deletions

View File

@ -1,10 +1,20 @@
import time
from flask_restful import Resource, reqparse
from controllers.console.setup import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.plugin.entities.request import RequestInvokeLLM, RequestInvokeModeration, RequestInvokeRerank, RequestInvokeSpeech2Text, RequestInvokeTTS, RequestInvokeTextEmbedding, RequestInvokeTool
from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeTextEmbedding,
RequestInvokeTool,
RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolInvokeMessage
from libs.helper import compact_generate_response
from models.account import Tenant
from services.plugin.plugin_invoke_service import PluginInvokeService
@ -18,6 +28,7 @@ class PluginInvokeLLMApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
pass
class PluginInvokeTextEmbeddingApi(Resource):
@setup_required
@plugin_inner_api_only
@ -26,6 +37,7 @@ class PluginInvokeTextEmbeddingApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
pass
class PluginInvokeRerankApi(Resource):
@setup_required
@plugin_inner_api_only
@ -34,6 +46,7 @@ class PluginInvokeRerankApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
pass
class PluginInvokeTTSApi(Resource):
@setup_required
@plugin_inner_api_only
@ -42,6 +55,7 @@ class PluginInvokeTTSApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS):
pass
class PluginInvokeSpeech2TextApi(Resource):
@setup_required
@plugin_inner_api_only
@ -50,6 +64,7 @@ class PluginInvokeSpeech2TextApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
pass
class PluginInvokeModerationApi(Resource):
@setup_required
@plugin_inner_api_only
@ -58,23 +73,27 @@ class PluginInvokeModerationApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration):
pass
class PluginInvokeToolApi(Resource):
@setup_required
@plugin_inner_api_only
@get_tenant
@plugin_data(payload_type=RequestInvokeTool)
def post(self, user_id: str, tenant_model: Tenant):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=dict, required=True, location='json')
parser.add_argument('tool', type=dict, required=True, location='json')
parser.add_argument('parameters', type=dict, required=True, location='json')
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
def generator():
for i in range(10):
time.sleep(0.1)
yield (
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text='helloworld'),
)
.model_dump_json()
.encode()
+ b'\n\n'
)
args = parser.parse_args()
response = PluginInvokeService.invoke_tool(
user_id, tenant_model, args['provider'], args['tool'], args['parameters']
)
return compact_generate_response(response)
return compact_generate_response(generator())
class PluginInvokeNodeApi(Resource):

View File

@ -1,7 +1,7 @@
import logging
import time
import uuid
from collections.abc import Generator
from collections.abc import Callable, Generator
from datetime import timedelta
from typing import Optional, Union
@ -91,7 +91,7 @@ class RateLimit:
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, Callable[[], Generator]], request_id: str):
self.rate_limit = rate_limit
if callable(generator):
self.generator = generator()

View File

@ -90,6 +90,12 @@ class ApiProviderAuthType(Enum):
raise ValueError(f'invalid mode value {value}')
class ToolInvokeMessage(BaseModel):
class TextMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
@ -103,7 +109,7 @@ class ToolInvokeMessage(BaseModel):
"""
plain text, image url or link url
"""
message: Optional[Union[str, bytes, dict]] = None
message: JsonMessage | TextMessage
meta: Optional[dict[str, Any]] = None
save_as: str = ''

View File

@ -36,8 +36,7 @@ def email(email):
if re.match(pattern, email) is not None:
return email
error = ('{email} is not a valid email.'
.format(email=email))
error = '{email} is not a valid email.'.format(email=email)
raise ValueError(error)
@ -49,10 +48,10 @@ def uuid_value(value):
uuid_obj = uuid.UUID(value)
return str(uuid_obj)
except ValueError:
error = ('{value} is not a valid uuid.'
.format(value=value))
error = '{value} is not a valid uuid.'.format(value=value)
raise ValueError(error)
def alphanumeric(value: str):
# check if the value is alphanumeric and underlined
if re.match(r'^[a-zA-Z0-9_]+$', value):
@ -60,6 +59,7 @@ def alphanumeric(value: str):
raise ValueError(f'{value} is not a valid alphanumeric value')
def timestamp_value(timestamp):
try:
int_timestamp = int(timestamp)
@ -67,13 +67,12 @@ def timestamp_value(timestamp):
raise ValueError
return int_timestamp
except ValueError:
error = ('{timestamp} is not a valid timestamp.'
.format(timestamp=timestamp))
error = '{timestamp} is not a valid timestamp.'.format(timestamp=timestamp)
raise ValueError(error)
class str_len:
""" Restrict input to an integer in a range (inclusive) """
"""Restrict input to an integer in a range (inclusive)"""
def __init__(self, max_length, argument='argument'):
self.max_length = max_length
@ -82,15 +81,17 @@ class str_len:
def __call__(self, value):
length = len(value)
if length > self.max_length:
error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
.format(arg=self.argument, val=value, length=self.max_length))
error = 'Invalid {arg}: {val}. {arg} cannot exceed length {length}'.format(
arg=self.argument, val=value, length=self.max_length
)
raise ValueError(error)
return value
class float_range:
""" Restrict input to an float in a range (inclusive) """
"""Restrict input to an float in a range (inclusive)"""
def __init__(self, low, high, argument='argument'):
self.low = low
self.high = high
@ -99,8 +100,9 @@ class float_range:
def __call__(self, value):
value = _get_float(value)
if value < self.low or value > self.high:
error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
.format(arg=self.argument, val=value, lo=self.low, hi=self.high))
error = 'Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'.format(
arg=self.argument, val=value, lo=self.low, hi=self.high
)
raise ValueError(error)
return value
@ -115,8 +117,9 @@ class datetime_string:
try:
datetime.strptime(value, self.format)
except ValueError:
error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
.format(arg=self.argument, val=value, format=self.format))
error = 'Invalid {arg}: {val}. {arg} must be conform to the format {format}'.format(
arg=self.argument, val=value, format=self.format
)
raise ValueError(error)
return value
@ -128,18 +131,18 @@ def _get_float(value):
except (TypeError, ValueError):
raise ValueError('{} is not a valid float'.format(value))
def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones():
return timezone_string
error = ('{timezone_string} is not a valid timezone.'
.format(timezone_string=timezone_string))
error = '{timezone_string} is not a valid timezone.'.format(timezone_string=timezone_string)
raise ValueError(error)
def generate_string(n):
letters_digits = string.ascii_letters + string.digits
result = ""
result = ''
for i in range(n):
result += random.choice(letters_digits)
@ -149,8 +152,8 @@ def generate_string(n):
def get_remote_ip(request) -> str:
if request.headers.get('CF-Connecting-IP'):
return request.headers.get('Cf-Connecting-Ip')
elif request.headers.getlist("X-Forwarded-For"):
return request.headers.getlist("X-Forwarded-For")[0]
elif request.headers.getlist('X-Forwarded-For'):
return request.headers.getlist('X-Forwarded-For')[0]
else:
return request.remote_addr
@ -160,19 +163,24 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
def generate() -> Generator:
for data in response:
if isinstance(data, dict):
yield json.dumps(data).encode()
if isinstance(data, str):
yield data.encode()
else:
yield data
return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
class TokenManager:
@classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type)
@ -182,21 +190,13 @@ class TokenManager:
cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4())
token_data = {
'account_id': account.id,
'email': account.email,
'token_type': token_type
}
token_data = {'account_id': account.id, 'email': account.email, 'token_type': token_type}
if additional_data:
token_data.update(additional_data)
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
token_key = cls._get_token_key(token, token_type)
redis_client.setex(
token_key,
expiry_hours * 60 * 60,
json.dumps(token_data)
)
redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token
@ -215,7 +215,7 @@ class TokenManager:
key = cls._get_token_key(token, token_type)
token_data_json = redis_client.get(key)
if token_data_json is None:
logging.warning(f"{token_type} token {token} not found with key {key}")
logging.warning(f'{token_type} token {token} not found with key {key}')
return None
token_data = json.loads(token_data_json)
return token_data
@ -243,7 +243,7 @@ class RateLimiter:
self.time_window = time_window
def _get_key(self, email: str) -> str:
return f"{self.prefix}:{email}"
return f'{self.prefix}:{email}'
def is_rate_limited(self, email: str) -> bool:
key = self._get_key(email)