feat: support backwards invocation
This commit is contained in:
parent
f29b44acd8
commit
d52476c1c9
@ -1,10 +1,20 @@
|
|||||||
|
import time
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.inner_api import api
|
from controllers.inner_api import api
|
||||||
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
|
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
|
||||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||||
from core.plugin.entities.request import RequestInvokeLLM, RequestInvokeModeration, RequestInvokeRerank, RequestInvokeSpeech2Text, RequestInvokeTTS, RequestInvokeTextEmbedding, RequestInvokeTool
|
from core.plugin.entities.request import (
|
||||||
|
RequestInvokeLLM,
|
||||||
|
RequestInvokeModeration,
|
||||||
|
RequestInvokeRerank,
|
||||||
|
RequestInvokeSpeech2Text,
|
||||||
|
RequestInvokeTextEmbedding,
|
||||||
|
RequestInvokeTool,
|
||||||
|
RequestInvokeTTS,
|
||||||
|
)
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from libs.helper import compact_generate_response
|
from libs.helper import compact_generate_response
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
from services.plugin.plugin_invoke_service import PluginInvokeService
|
from services.plugin.plugin_invoke_service import PluginInvokeService
|
||||||
@ -18,6 +28,7 @@ class PluginInvokeLLMApi(Resource):
|
|||||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeTextEmbeddingApi(Resource):
|
class PluginInvokeTextEmbeddingApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@ -26,6 +37,7 @@ class PluginInvokeTextEmbeddingApi(Resource):
|
|||||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeRerankApi(Resource):
|
class PluginInvokeRerankApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@ -34,6 +46,7 @@ class PluginInvokeRerankApi(Resource):
|
|||||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeTTSApi(Resource):
|
class PluginInvokeTTSApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@ -42,6 +55,7 @@ class PluginInvokeTTSApi(Resource):
|
|||||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS):
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeSpeech2TextApi(Resource):
|
class PluginInvokeSpeech2TextApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@ -50,6 +64,7 @@ class PluginInvokeSpeech2TextApi(Resource):
|
|||||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeModerationApi(Resource):
|
class PluginInvokeModerationApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@ -58,23 +73,27 @@ class PluginInvokeModerationApi(Resource):
|
|||||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration):
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeToolApi(Resource):
|
class PluginInvokeToolApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@plugin_inner_api_only
|
@plugin_inner_api_only
|
||||||
@get_tenant
|
@get_tenant
|
||||||
@plugin_data(payload_type=RequestInvokeTool)
|
@plugin_data(payload_type=RequestInvokeTool)
|
||||||
def post(self, user_id: str, tenant_model: Tenant):
|
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||||
parser = reqparse.RequestParser()
|
def generator():
|
||||||
parser.add_argument('provider', type=dict, required=True, location='json')
|
for i in range(10):
|
||||||
parser.add_argument('tool', type=dict, required=True, location='json')
|
time.sleep(0.1)
|
||||||
parser.add_argument('parameters', type=dict, required=True, location='json')
|
yield (
|
||||||
|
ToolInvokeMessage(
|
||||||
args = parser.parse_args()
|
type=ToolInvokeMessage.MessageType.TEXT,
|
||||||
|
message=ToolInvokeMessage.TextMessage(text='helloworld'),
|
||||||
response = PluginInvokeService.invoke_tool(
|
|
||||||
user_id, tenant_model, args['provider'], args['tool'], args['parameters']
|
|
||||||
)
|
)
|
||||||
return compact_generate_response(response)
|
.model_dump_json()
|
||||||
|
.encode()
|
||||||
|
+ b'\n\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
return compact_generate_response(generator())
|
||||||
|
|
||||||
|
|
||||||
class PluginInvokeNodeApi(Resource):
|
class PluginInvokeNodeApi(Resource):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator
|
from collections.abc import Callable, Generator
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ class RateLimit:
|
|||||||
|
|
||||||
|
|
||||||
class RateLimitGenerator:
|
class RateLimitGenerator:
|
||||||
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, Callable[[], Generator]], request_id: str):
|
||||||
self.rate_limit = rate_limit
|
self.rate_limit = rate_limit
|
||||||
if callable(generator):
|
if callable(generator):
|
||||||
self.generator = generator()
|
self.generator = generator()
|
||||||
|
@ -90,6 +90,12 @@ class ApiProviderAuthType(Enum):
|
|||||||
raise ValueError(f'invalid mode value {value}')
|
raise ValueError(f'invalid mode value {value}')
|
||||||
|
|
||||||
class ToolInvokeMessage(BaseModel):
|
class ToolInvokeMessage(BaseModel):
|
||||||
|
class TextMessage(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
class JsonMessage(BaseModel):
|
||||||
|
json_object: dict
|
||||||
|
|
||||||
class MessageType(Enum):
|
class MessageType(Enum):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
@ -103,7 +109,7 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
plain text, image url or link url
|
plain text, image url or link url
|
||||||
"""
|
"""
|
||||||
message: Optional[Union[str, bytes, dict]] = None
|
message: JsonMessage | TextMessage
|
||||||
meta: Optional[dict[str, Any]] = None
|
meta: Optional[dict[str, Any]] = None
|
||||||
save_as: str = ''
|
save_as: str = ''
|
||||||
|
|
||||||
|
@ -36,8 +36,7 @@ def email(email):
|
|||||||
if re.match(pattern, email) is not None:
|
if re.match(pattern, email) is not None:
|
||||||
return email
|
return email
|
||||||
|
|
||||||
error = ('{email} is not a valid email.'
|
error = '{email} is not a valid email.'.format(email=email)
|
||||||
.format(email=email))
|
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
|
|
||||||
@ -49,10 +48,10 @@ def uuid_value(value):
|
|||||||
uuid_obj = uuid.UUID(value)
|
uuid_obj = uuid.UUID(value)
|
||||||
return str(uuid_obj)
|
return str(uuid_obj)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
error = ('{value} is not a valid uuid.'
|
error = '{value} is not a valid uuid.'.format(value=value)
|
||||||
.format(value=value))
|
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
|
|
||||||
def alphanumeric(value: str):
|
def alphanumeric(value: str):
|
||||||
# check if the value is alphanumeric and underlined
|
# check if the value is alphanumeric and underlined
|
||||||
if re.match(r'^[a-zA-Z0-9_]+$', value):
|
if re.match(r'^[a-zA-Z0-9_]+$', value):
|
||||||
@ -60,6 +59,7 @@ def alphanumeric(value: str):
|
|||||||
|
|
||||||
raise ValueError(f'{value} is not a valid alphanumeric value')
|
raise ValueError(f'{value} is not a valid alphanumeric value')
|
||||||
|
|
||||||
|
|
||||||
def timestamp_value(timestamp):
|
def timestamp_value(timestamp):
|
||||||
try:
|
try:
|
||||||
int_timestamp = int(timestamp)
|
int_timestamp = int(timestamp)
|
||||||
@ -67,8 +67,7 @@ def timestamp_value(timestamp):
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
return int_timestamp
|
return int_timestamp
|
||||||
except ValueError:
|
except ValueError:
|
||||||
error = ('{timestamp} is not a valid timestamp.'
|
error = '{timestamp} is not a valid timestamp.'.format(timestamp=timestamp)
|
||||||
.format(timestamp=timestamp))
|
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
|
|
||||||
@ -82,8 +81,9 @@ class str_len:
|
|||||||
def __call__(self, value):
|
def __call__(self, value):
|
||||||
length = len(value)
|
length = len(value)
|
||||||
if length > self.max_length:
|
if length > self.max_length:
|
||||||
error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
|
error = 'Invalid {arg}: {val}. {arg} cannot exceed length {length}'.format(
|
||||||
.format(arg=self.argument, val=value, length=self.max_length))
|
arg=self.argument, val=value, length=self.max_length
|
||||||
|
)
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
@ -91,6 +91,7 @@ class str_len:
|
|||||||
|
|
||||||
class float_range:
|
class float_range:
|
||||||
"""Restrict input to an float in a range (inclusive)"""
|
"""Restrict input to an float in a range (inclusive)"""
|
||||||
|
|
||||||
def __init__(self, low, high, argument='argument'):
|
def __init__(self, low, high, argument='argument'):
|
||||||
self.low = low
|
self.low = low
|
||||||
self.high = high
|
self.high = high
|
||||||
@ -99,8 +100,9 @@ class float_range:
|
|||||||
def __call__(self, value):
|
def __call__(self, value):
|
||||||
value = _get_float(value)
|
value = _get_float(value)
|
||||||
if value < self.low or value > self.high:
|
if value < self.low or value > self.high:
|
||||||
error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
|
error = 'Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'.format(
|
||||||
.format(arg=self.argument, val=value, lo=self.low, hi=self.high))
|
arg=self.argument, val=value, lo=self.low, hi=self.high
|
||||||
|
)
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
@ -115,8 +117,9 @@ class datetime_string:
|
|||||||
try:
|
try:
|
||||||
datetime.strptime(value, self.format)
|
datetime.strptime(value, self.format)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
|
error = 'Invalid {arg}: {val}. {arg} must be conform to the format {format}'.format(
|
||||||
.format(arg=self.argument, val=value, format=self.format))
|
arg=self.argument, val=value, format=self.format
|
||||||
|
)
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
@ -128,18 +131,18 @@ def _get_float(value):
|
|||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
raise ValueError('{} is not a valid float'.format(value))
|
raise ValueError('{} is not a valid float'.format(value))
|
||||||
|
|
||||||
|
|
||||||
def timezone(timezone_string):
|
def timezone(timezone_string):
|
||||||
if timezone_string and timezone_string in available_timezones():
|
if timezone_string and timezone_string in available_timezones():
|
||||||
return timezone_string
|
return timezone_string
|
||||||
|
|
||||||
error = ('{timezone_string} is not a valid timezone.'
|
error = '{timezone_string} is not a valid timezone.'.format(timezone_string=timezone_string)
|
||||||
.format(timezone_string=timezone_string))
|
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
|
|
||||||
def generate_string(n):
|
def generate_string(n):
|
||||||
letters_digits = string.ascii_letters + string.digits
|
letters_digits = string.ascii_letters + string.digits
|
||||||
result = ""
|
result = ''
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
result += random.choice(letters_digits)
|
result += random.choice(letters_digits)
|
||||||
|
|
||||||
@ -149,8 +152,8 @@ def generate_string(n):
|
|||||||
def get_remote_ip(request) -> str:
|
def get_remote_ip(request) -> str:
|
||||||
if request.headers.get('CF-Connecting-IP'):
|
if request.headers.get('CF-Connecting-IP'):
|
||||||
return request.headers.get('Cf-Connecting-Ip')
|
return request.headers.get('Cf-Connecting-Ip')
|
||||||
elif request.headers.getlist("X-Forwarded-For"):
|
elif request.headers.getlist('X-Forwarded-For'):
|
||||||
return request.headers.getlist("X-Forwarded-For")[0]
|
return request.headers.getlist('X-Forwarded-For')[0]
|
||||||
else:
|
else:
|
||||||
return request.remote_addr
|
return request.remote_addr
|
||||||
|
|
||||||
@ -160,19 +163,24 @@ def generate_text_hash(text: str) -> str:
|
|||||||
return sha256(hash_text.encode()).hexdigest()
|
return sha256(hash_text.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
|
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||||
else:
|
else:
|
||||||
def generate() -> Generator:
|
|
||||||
yield from response
|
|
||||||
|
|
||||||
return Response(stream_with_context(generate()), status=200,
|
def generate() -> Generator:
|
||||||
mimetype='text/event-stream')
|
for data in response:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
yield json.dumps(data).encode()
|
||||||
|
if isinstance(data, str):
|
||||||
|
yield data.encode()
|
||||||
|
else:
|
||||||
|
yield data
|
||||||
|
|
||||||
|
return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
|
||||||
|
|
||||||
|
|
||||||
class TokenManager:
|
class TokenManager:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
|
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
|
||||||
old_token = cls._get_current_token_for_account(account.id, token_type)
|
old_token = cls._get_current_token_for_account(account.id, token_type)
|
||||||
@ -182,21 +190,13 @@ class TokenManager:
|
|||||||
cls.revoke_token(old_token, token_type)
|
cls.revoke_token(old_token, token_type)
|
||||||
|
|
||||||
token = str(uuid.uuid4())
|
token = str(uuid.uuid4())
|
||||||
token_data = {
|
token_data = {'account_id': account.id, 'email': account.email, 'token_type': token_type}
|
||||||
'account_id': account.id,
|
|
||||||
'email': account.email,
|
|
||||||
'token_type': token_type
|
|
||||||
}
|
|
||||||
if additional_data:
|
if additional_data:
|
||||||
token_data.update(additional_data)
|
token_data.update(additional_data)
|
||||||
|
|
||||||
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
|
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
|
||||||
token_key = cls._get_token_key(token, token_type)
|
token_key = cls._get_token_key(token, token_type)
|
||||||
redis_client.setex(
|
redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
|
||||||
token_key,
|
|
||||||
expiry_hours * 60 * 60,
|
|
||||||
json.dumps(token_data)
|
|
||||||
)
|
|
||||||
|
|
||||||
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
|
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
|
||||||
return token
|
return token
|
||||||
@ -215,7 +215,7 @@ class TokenManager:
|
|||||||
key = cls._get_token_key(token, token_type)
|
key = cls._get_token_key(token, token_type)
|
||||||
token_data_json = redis_client.get(key)
|
token_data_json = redis_client.get(key)
|
||||||
if token_data_json is None:
|
if token_data_json is None:
|
||||||
logging.warning(f"{token_type} token {token} not found with key {key}")
|
logging.warning(f'{token_type} token {token} not found with key {key}')
|
||||||
return None
|
return None
|
||||||
token_data = json.loads(token_data_json)
|
token_data = json.loads(token_data_json)
|
||||||
return token_data
|
return token_data
|
||||||
@ -243,7 +243,7 @@ class RateLimiter:
|
|||||||
self.time_window = time_window
|
self.time_window = time_window
|
||||||
|
|
||||||
def _get_key(self, email: str) -> str:
|
def _get_key(self, email: str) -> str:
|
||||||
return f"{self.prefix}:{email}"
|
return f'{self.prefix}:{email}'
|
||||||
|
|
||||||
def is_rate_limited(self, email: str) -> bool:
|
def is_rate_limited(self, email: str) -> bool:
|
||||||
key = self._get_key(email)
|
key = self._get_key(email)
|
||||||
|
Loading…
Reference in New Issue
Block a user