diff --git a/api/commands.py b/api/commands.py index b67b4f8676..2e883ada8d 100644 --- a/api/commands.py +++ b/api/commands.py @@ -134,7 +134,7 @@ def generate_upper_string(): @click.command('gen-recommended-apps', help='Number of records to generate') def generate_recommended_apps(): print('Generating recommended app data...') - apps = App.query.all() + apps = App.query.filter(App.is_public == True).all() for app in apps: recommended_app = RecommendedApp( app_id=app.id, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 971e489971..6834d3a0c5 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -5,8 +5,11 @@ from libs.external_api import ExternalApi bp = Blueprint('console', __name__, url_prefix='/console/api') api = ExternalApi(bp) +# Import other controllers +from . import setup, version, apikey, admin + # Import app controllers -from .app import app, site, explore, completion, model_config, statistic, conversation, message +from .app import app, site, completion, model_config, statistic, conversation, message # Import auth controllers from .auth import login, oauth @@ -14,7 +17,8 @@ from .auth import login, oauth # Import datasets controllers from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing -# Import other controllers -from . import setup, version, apikey - +# Import workspace controllers from .workspace import workspace, members, providers, account + +# Import explore controllers +from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py new file mode 100644 index 0000000000..7a337a4918 --- /dev/null +++ b/api/controllers/console/admin.py @@ -0,0 +1,135 @@ +import os +from functools import wraps + +from flask import request +from flask_restful import Resource, reqparse +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.console import api +from controllers.console.wraps import only_edition_cloud +from extensions.ext_database import db +from models.model import RecommendedApp, App, InstalledApp + + +def admin_required(view): + @wraps(view) + def decorated(*args, **kwargs): + if not os.getenv('ADMIN_API_KEY'): + raise Unauthorized('API key is invalid.') + + auth_header = request.headers.get('Authorization') + if auth_header is None: + raise Unauthorized('Authorization header is missing.') + + if ' ' not in auth_header: + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + + if auth_scheme != 'bearer': + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + + if os.getenv('ADMIN_API_KEY') != auth_token: + raise Unauthorized('API key is invalid.') + + return view(*args, **kwargs) + + return decorated + + +class InsertExploreAppListApi(Resource): + @only_edition_cloud + @admin_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('app_id', type=str, required=True, nullable=False, location='json') + parser.add_argument('desc_en', type=str, location='json') + parser.add_argument('desc_zh', type=str, location='json') + parser.add_argument('copyright', type=str, location='json') + parser.add_argument('privacy_policy', type=str, location='json') + parser.add_argument('category', type=str, required=True, nullable=False, location='json') + parser.add_argument('position', type=int, required=True, nullable=False, location='json') + args = parser.parse_args() + + app = App.query.filter(App.id == args['app_id']).first() + if not app: + raise NotFound('App not found') + + site = app.site + if not site: + desc = args['desc_en'] + copy_right = args['copyright'] + privacy_policy = args['privacy_policy'] + else: + desc = site.description if not args['desc_en'] else args['desc_en'] + copy_right = site.copyright if not args['copyright'] else args['copyright'] + privacy_policy = site.privacy_policy if not args['privacy_policy'] else args['privacy_policy'] + + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + + if not recommended_app: + recommended_app = RecommendedApp( + app_id=app.id, + description={ + 'en': desc, + 'zh': desc if not args['desc_zh'] else args['desc_zh'] + }, + copyright=copy_right, + privacy_policy=privacy_policy, + category=args['category'], + position=args['position'] + ) + + db.session.add(recommended_app) + + app.is_public = True + db.session.commit() + + return {'result': 'success'}, 201 + else: + recommended_app.description = { + 'en': args['desc_en'], + 'zh': args['desc_zh'] + } + + recommended_app.copyright = args['copyright'] + recommended_app.privacy_policy = args['privacy_policy'] + recommended_app.category = args['category'] + recommended_app.position = args['position'] + + app.is_public = True + + db.session.commit() + + return {'result': 'success'}, 200 + + +class InsertExploreAppApi(Resource): + @only_edition_cloud + @admin_required + def delete(self, app_id): + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() + if not recommended_app: + return {'result': 'success'}, 204 + + app = App.query.filter(App.id == recommended_app.app_id).first() + if app: + app.is_public = False + + installed_apps = InstalledApp.query.filter( + InstalledApp.app_id == recommended_app.app_id, + InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id + ).all() + + for installed_app in installed_apps: + db.session.delete(installed_app) + + db.session.delete(recommended_app) + db.session.commit() + + return {'result': 'success'}, 204 + + +api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps') +api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/') diff --git a/api/controllers/console/app/explore.py b/api/controllers/console/app/explore.py deleted file mode 100644 index eeec2ddc24..0000000000 --- a/api/controllers/console/app/explore.py +++ /dev/null @@ -1,209 +0,0 @@ -# -*- coding:utf-8 -*- -from datetime import datetime - -from flask_login import login_required, current_user -from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs -from sqlalchemy import and_ - -from controllers.console import api -from extensions.ext_database import db -from models.model import Tenant, App, InstalledApp, RecommendedApp -from services.account_service import TenantService - -app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String -} - -installed_app_fields = { - 'id': fields.String, - 'app': fields.Nested(app_fields, attribute='app'), - 'app_owner_tenant_id': fields.String, - 'is_pinned': fields.Boolean, - 'last_used_at': fields.DateTime, - 'editable': fields.Boolean -} - -installed_app_list_fields = { - 'installed_apps': fields.List(fields.Nested(installed_app_fields)) -} - -recommended_app_fields = { - 'app': fields.Nested(app_fields, attribute='app'), - 'app_id': fields.String, - 'description': fields.String(attribute='description'), - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'category': fields.String, - 'position': fields.Integer, - 'is_listed': fields.Boolean, - 'install_count': fields.Integer, - 'installed': fields.Boolean, - 'editable': fields.Boolean -} - -recommended_app_list_fields = { - 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), - 'categories': fields.List(fields.String) -} - - -class InstalledAppsListResource(Resource): - @login_required - @marshal_with(installed_app_list_fields) - def get(self): - current_tenant_id = Tenant.query.first().id - installed_apps = db.session.query(InstalledApp).filter( - InstalledApp.tenant_id == current_tenant_id - ).all() - - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) - installed_apps = [ - { - **installed_app, - "editable": current_user.role in ["owner", "admin"], - } - for installed_app in installed_apps - ] - installed_apps.sort(key=lambda app: (-app.is_pinned, app.last_used_at)) - - return {'installed_apps': installed_apps} - - @login_required - def post(self): - parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') - args = parser.parse_args() - - current_tenant_id = Tenant.query.first().id - app = App.query.get(args['app_id']) - if app is None: - abort(404, message='App not found') - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() - if recommended_app is None: - abort(404, message='App not found') - if not app.is_public: - abort(403, message="You can't install a non-public app") - - installed_app = InstalledApp.query.filter(and_( - InstalledApp.app_id == args['app_id'], - InstalledApp.tenant_id == current_tenant_id - )).first() - - if installed_app is None: - # todo: position - recommended_app.install_count += 1 - - new_installed_app = InstalledApp( - app_id=args['app_id'], - tenant_id=current_tenant_id, - is_pinned=False, - last_used_at=datetime.utcnow() - ) - db.session.add(new_installed_app) - db.session.commit() - - return {'message': 'App installed successfully'} - - -class InstalledAppResource(Resource): - - @login_required - def delete(self, installed_app_id): - - installed_app = InstalledApp.query.filter(and_( - InstalledApp.id == str(installed_app_id), - InstalledApp.tenant_id == current_user.current_tenant_id - )).first() - - if installed_app is None: - abort(404, message='App not found') - - if installed_app.app_owner_tenant_id == current_user.current_tenant_id: - abort(400, message="You can't uninstall an app owned by the current tenant") - - db.session.delete(installed_app) - db.session.commit() - - return {'result': 'success', 'message': 'App uninstalled successfully'} - - @login_required - def patch(self, installed_app_id): - parser = reqparse.RequestParser() - parser.add_argument('is_pinned', type=inputs.boolean) - args = parser.parse_args() - - current_tenant_id = Tenant.query.first().id - installed_app = InstalledApp.query.filter(and_( - InstalledApp.id == str(installed_app_id), - InstalledApp.tenant_id == current_tenant_id - )).first() - - if installed_app is None: - abort(404, message='Installed app not found') - - commit_args = False - if 'is_pinned' in args: - installed_app.is_pinned = args['is_pinned'] - commit_args = True - - if commit_args: - db.session.commit() - - return {'result': 'success', 'message': 'App info updated successfully'} - - -class RecommendedAppsResource(Resource): - @login_required - @marshal_with(recommended_app_list_fields) - def get(self): - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True - ).all() - - categories = set() - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) - recommended_apps_result = [] - for recommended_app in recommended_apps: - installed = db.session.query(InstalledApp).filter( - and_( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id == current_user.current_tenant_id - ) - ).first() is not None - - language_prefix = current_user.interface_language.split('-')[0] - desc = None - if recommended_app.description: - if language_prefix in recommended_app.description: - desc = recommended_app.description[language_prefix] - elif 'en' in recommended_app.description: - desc = recommended_app.description['en'] - - recommended_app_result = { - 'id': recommended_app.id, - 'app': recommended_app.app, - 'app_id': recommended_app.app_id, - 'description': desc, - 'copyright': recommended_app.copyright, - 'privacy_policy': recommended_app.privacy_policy, - 'category': recommended_app.category, - 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed, - 'install_count': recommended_app.install_count, - 'installed': installed, - 'editable': current_user.role in ['owner', 'admin'], - } - recommended_apps_result.append(recommended_app_result) - - categories.add(recommended_app.category) # add category to categories - - return {'recommended_apps': recommended_apps_result, 'categories': list(categories)} - - -api.add_resource(InstalledAppsListResource, '/installed-apps') -api.add_resource(InstalledAppResource, '/installed-apps/') -api.add_resource(RecommendedAppsResource, '/explore/apps') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py new file mode 100644 index 0000000000..f2a1acd6d5 --- /dev/null +++ b/api/controllers/console/explore/completion.py @@ -0,0 +1,180 @@ +# -*- coding:utf-8 -*- +import json +import logging +from typing import Generator, Union + +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound + +import services +from controllers.console import api +from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \ + ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError +from controllers.console.explore.error import NotCompletionAppError, NotChatAppError +from controllers.console.explore.wraps import InstalledAppResource +from core.conversation_message_task import PubHandler +from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ + LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from libs.helper import uuid_value +from services.completion_service import CompletionService + + +# define completion api for user +class CompletionApi(InstalledAppResource): + + def post(self, installed_app): + app_model = installed_app.app + if app_model.mode != 'completion': + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, location='json') + parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.completion( + app_model=app_model, + user=current_user, + args=args, + from_source='console', + streaming=streaming + ) + + return compact_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class CompletionStopApi(InstalledAppResource): + def post(self, installed_app, task_id): + app_model = installed_app.app + if app_model.mode != 'completion': + raise NotCompletionAppError() + + PubHandler.stop(current_user, task_id) + + return {'result': 'success'}, 200 + + +class ChatApi(InstalledAppResource): + def post(self, installed_app): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('conversation_id', type=uuid_value, location='json') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.completion( + app_model=app_model, + user=current_user, + args=args, + from_source='console', + streaming=streaming + ) + + return compact_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class ChatStopApi(InstalledAppResource): + def post(self, installed_app, task_id): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotChatAppError() + + PubHandler.stop(current_user, task_id) + + return {'result': 'success'}, 200 + + +def compact_response(response: Union[dict | Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + try: + for chunk in response: + yield chunk + except services.errors.conversation.ConversationNotExistsError: + yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n" + except services.errors.conversation.ConversationCompletedError: + yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n" + except services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" + except ProviderTokenNotInitError: + yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" + except QuotaExceededError: + yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" + except ModelCurrentlyNotSupportError: + yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" + except ValueError as e: + yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" + except Exception: + logging.exception("internal server error.") + yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n" + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + +api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') +api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') +api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') +api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py new file mode 100644 index 0000000000..1e25acc14e --- /dev/null +++ b/api/controllers/console/explore/conversation.py @@ -0,0 +1,127 @@ +# -*- coding:utf-8 -*- +from flask_login import current_user +from flask_restful import fields, reqparse, marshal_with +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + +from controllers.console import api +from controllers.console.explore.error import NotChatAppError +from controllers.console.explore.wraps import InstalledAppResource +from libs.helper import TimestampField, uuid_value +from services.conversation_service import ConversationService +from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError +from services.web_conversation_service import WebConversationService + +conversation_fields = { + 'id': fields.String, + 'name': fields.String, + 'inputs': fields.Raw, + 'status': fields.String, + 'introduction': fields.String, + 'created_at': TimestampField +} + +conversation_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(conversation_fields)) +} + + +class ConversationListApi(InstalledAppResource): + + @marshal_with(conversation_infinite_scroll_pagination_fields) + def get(self, installed_app): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument('last_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + args = parser.parse_args() + + pinned = None + if 'pinned' in args and args['pinned'] is not None: + pinned = True if args['pinned'] == 'true' else False + + try: + return WebConversationService.pagination_by_last_id( + app_model=app_model, + user=current_user, + last_id=args['last_id'], + limit=args['limit'], + pinned=pinned + ) + except LastConversationNotExistsError: + raise NotFound("Last Conversation Not Exists.") + + +class ConversationApi(InstalledAppResource): + def delete(self, installed_app, c_id): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + ConversationService.delete(app_model, conversation_id, current_user) + WebConversationService.unpin(app_model, conversation_id, current_user) + + return {"result": "success"}, 204 + + +class ConversationRenameApi(InstalledAppResource): + + @marshal_with(conversation_fields) + def post(self, installed_app, c_id): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + args = parser.parse_args() + + try: + return ConversationService.rename(app_model, conversation_id, current_user, args['name']) + except ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + + +class ConversationPinApi(InstalledAppResource): + + def patch(self, installed_app, c_id): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + + try: + WebConversationService.pin(app_model, conversation_id, current_user) + except ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + + return {"result": "success"} + + +class ConversationUnPinApi(InstalledAppResource): + def patch(self, installed_app, c_id): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + WebConversationService.unpin(app_model, conversation_id, current_user) + + return {"result": "success"} + + +api.add_resource(ConversationRenameApi, '/installed-apps//conversations//name', endpoint='installed_app_conversation_rename') +api.add_resource(ConversationListApi, '/installed-apps//conversations', endpoint='installed_app_conversations') +api.add_resource(ConversationApi, '/installed-apps//conversations/', endpoint='installed_app_conversation') +api.add_resource(ConversationPinApi, '/installed-apps//conversations//pin', endpoint='installed_app_conversation_pin') +api.add_resource(ConversationUnPinApi, '/installed-apps//conversations//unpin', endpoint='installed_app_conversation_unpin') diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py new file mode 100644 index 0000000000..e3180bf987 --- /dev/null +++ b/api/controllers/console/explore/error.py @@ -0,0 +1,20 @@ +# -*- coding:utf-8 -*- +from libs.exception import BaseHTTPException + + +class NotCompletionAppError(BaseHTTPException): + error_code = 'not_completion_app' + description = "Not Completion App" + code = 400 + + +class NotChatAppError(BaseHTTPException): + error_code = 'not_chat_app' + description = "Not Chat App" + code = 400 + + +class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): + error_code = 'app_suggested_questions_after_answer_disabled' + description = "Function Suggested questions after answer disabled." + code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py new file mode 100644 index 0000000000..3a2a1dbee9 --- /dev/null +++ b/api/controllers/console/explore/installed_app.py @@ -0,0 +1,143 @@ +# -*- coding:utf-8 -*- +from datetime import datetime + +from flask_login import login_required, current_user +from flask_restful import Resource, reqparse, fields, marshal_with, inputs +from sqlalchemy import and_ +from werkzeug.exceptions import NotFound, Forbidden, BadRequest + +from controllers.console import api +from controllers.console.explore.wraps import InstalledAppResource +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db +from libs.helper import TimestampField +from models.model import App, InstalledApp, RecommendedApp +from services.account_service import TenantService + +app_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String +} + +installed_app_fields = { + 'id': fields.String, + 'app': fields.Nested(app_fields), + 'app_owner_tenant_id': fields.String, + 'is_pinned': fields.Boolean, + 'last_used_at': TimestampField, + 'editable': fields.Boolean, + 'uninstallable': fields.Boolean, +} + +installed_app_list_fields = { + 'installed_apps': fields.List(fields.Nested(installed_app_fields)) +} + + +class InstalledAppsListApi(Resource): + @login_required + @account_initialization_required + @marshal_with(installed_app_list_fields) + def get(self): + current_tenant_id = current_user.current_tenant_id + installed_apps = db.session.query(InstalledApp).filter( + InstalledApp.tenant_id == current_tenant_id + ).all() + + current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) + installed_apps = [ + { + 'id': installed_app.id, + 'app': installed_app.app, + 'app_owner_tenant_id': installed_app.app_owner_tenant_id, + 'is_pinned': installed_app.is_pinned, + 'last_used_at': installed_app.last_used_at, + "editable": current_user.role in ["owner", "admin"], + "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id + } + for installed_app in installed_apps + ] + installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at'] + if app['last_used_at'] is not None else datetime.min)) + + return {'installed_apps': installed_apps} + + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') + args = parser.parse_args() + + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + if recommended_app is None: + raise NotFound('App not found') + + current_tenant_id = current_user.current_tenant_id + app = db.session.query(App).filter( + App.id == args['app_id'] + ).first() + + if app is None: + raise NotFound('App not found') + + if not app.is_public: + raise Forbidden('You can\'t install a non-public app') + + installed_app = InstalledApp.query.filter(and_( + InstalledApp.app_id == args['app_id'], + InstalledApp.tenant_id == current_tenant_id + )).first() + + if installed_app is None: + # todo: position + recommended_app.install_count += 1 + + new_installed_app = InstalledApp( + app_id=args['app_id'], + tenant_id=current_tenant_id, + app_owner_tenant_id=app.tenant_id, + is_pinned=False, + last_used_at=datetime.utcnow() + ) + db.session.add(new_installed_app) + db.session.commit() + + return {'message': 'App installed successfully'} + + +class InstalledAppApi(InstalledAppResource): + """ + update and delete an installed app + use InstalledAppResource to apply default decorators and get installed_app + """ + def delete(self, installed_app): + if installed_app.app_owner_tenant_id == current_user.current_tenant_id: + raise BadRequest('You can\'t uninstall an app owned by the current tenant') + + db.session.delete(installed_app) + db.session.commit() + + return {'result': 'success', 'message': 'App uninstalled successfully'} + + def patch(self, installed_app): + parser = reqparse.RequestParser() + parser.add_argument('is_pinned', type=inputs.boolean) + args = parser.parse_args() + + commit_args = False + if 'is_pinned' in args: + installed_app.is_pinned = args['is_pinned'] + commit_args = True + + if commit_args: + db.session.commit() + + return {'result': 'success', 'message': 'App info updated successfully'} + + +api.add_resource(InstalledAppsListApi, '/installed-apps') +api.add_resource(InstalledAppApi, '/installed-apps/') diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py new file mode 100644 index 0000000000..b5b9547ff7 --- /dev/null +++ b/api/controllers/console/explore/message.py @@ -0,0 +1,196 @@ +# -*- coding:utf-8 -*- +import json +import logging +from typing import Generator, Union + +from flask import stream_with_context, Response +from flask_login import current_user +from flask_restful import reqparse, fields, marshal_with +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound, InternalServerError + +import services +from controllers.console import api +from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \ + ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError +from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError +from controllers.console.explore.wraps import InstalledAppResource +from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ + ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError +from libs.helper import uuid_value, TimestampField +from services.completion_service import CompletionService +from services.errors.app import MoreLikeThisDisabledError +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from services.message_service import MessageService + + +class MessageListApi(InstalledAppResource): + feedback_fields = { + 'rating': fields.String + } + + message_fields = { + 'id': fields.String, + 'conversation_id': fields.String, + 'inputs': fields.Raw, + 'query': fields.String, + 'answer': fields.String, + 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'created_at': TimestampField + } + + message_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(message_fields)) + } + + @marshal_with(message_infinite_scroll_pagination_fields) + def get(self, installed_app): + app_model = installed_app.app + + if app_model.mode != 'chat': + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') + parser.add_argument('first_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + try: + return MessageService.pagination_by_first_id(app_model, current_user, + args['conversation_id'], args['first_id'], args['limit']) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.message.FirstMessageNotExistsError: + raise NotFound("First Message Not Exists.") + + +class MessageFeedbackApi(InstalledAppResource): + def post(self, installed_app, message_id): + app_model = installed_app.app + + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + args = parser.parse_args() + + try: + MessageService.create_feedback(app_model, message_id, current_user, args['rating']) + except services.errors.message.MessageNotExistsError: + raise NotFound("Message Not Exists.") + + return {'result': 'success'} + + +class MessageMoreLikeThisApi(InstalledAppResource): + def get(self, installed_app, message_id): + app_model = installed_app.app + if app_model.mode != 'completion': + raise NotCompletionAppError() + + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming) + return compact_response(response) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + except MoreLikeThisDisabledError: + raise AppMoreLikeThisDisabledError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + +def compact_response(response: Union[dict | Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + try: + for chunk in response: + yield chunk + except MessageNotExistsError: + yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" + except MoreLikeThisDisabledError: + yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" + except ProviderTokenNotInitError: + yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" + except QuotaExceededError: + yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" + except ModelCurrentlyNotSupportError: + yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" + except ValueError as e: + yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" + except Exception: + logging.exception("internal server error.") + yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n" + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + +class MessageSuggestedQuestionApi(InstalledAppResource): + def get(self, installed_app, message_id): + app_model = installed_app.app + if app_model.mode != 'chat': + raise NotCompletionAppError() + + message_id = str(message_id) + + try: + questions = MessageService.get_suggested_questions_after_answer( + app_model=app_model, + user=current_user, + message_id=message_id + ) + except MessageNotExistsError: + raise NotFound("Message not found") + except ConversationNotExistsError: + raise NotFound("Conversation not found") + except SuggestedQuestionsAfterAnswerDisabledError: + raise AppSuggestedQuestionsAfterAnswerDisabledError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + return {'data': questions} + + +api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') +api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') +api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') +api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py new file mode 100644 index 0000000000..2d0459ed40 --- /dev/null +++ b/api/controllers/console/explore/parameter.py @@ -0,0 +1,43 @@ +# -*- coding:utf-8 -*- +from flask_restful import marshal_with, fields + +from controllers.console import api +from controllers.console.explore.wraps import InstalledAppResource + + +class AppParameterApi(InstalledAppResource): + """Resource for app variables.""" + variable_fields = { + 'key': fields.String, + 'name': fields.String, + 'description': fields.String, + 'type': fields.String, + 'default': fields.String, + 'max_length': fields.Integer, + 'options': fields.List(fields.String) + } + + parameters_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw, + 'suggested_questions_after_answer': fields.Raw, + 'more_like_this': fields.Raw, + 'user_input_form': fields.Raw, + } + + @marshal_with(parameters_fields) + def get(self, installed_app): + """Retrieve app parameters.""" + app_model = installed_app.app + app_model_config = app_model.app_model_config + + return { + 'opening_statement': app_model_config.opening_statement, + 'suggested_questions': app_model_config.suggested_questions_list, + 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, + 'more_like_this': app_model_config.more_like_this_dict, + 'user_input_form': app_model_config.user_input_form_list + } + + +api.add_resource(AppParameterApi, '/installed-apps//parameters', endpoint='installed_app_parameters') diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py new file mode 100644 index 0000000000..d3942c90bf --- /dev/null +++ b/api/controllers/console/explore/recommended_app.py @@ -0,0 +1,139 @@ +# -*- coding:utf-8 -*- +from flask_login import login_required, current_user +from flask_restful import Resource, fields, marshal_with +from sqlalchemy import and_ + +from controllers.console import api +from controllers.console.app.error import AppNotFoundError +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db +from models.model import App, InstalledApp, RecommendedApp +from services.account_service import TenantService + +app_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String +} + +recommended_app_fields = { + 'app': fields.Nested(app_fields, attribute='app'), + 'app_id': fields.String, + 'description': fields.String(attribute='description'), + 'copyright': fields.String, + 'privacy_policy': fields.String, + 'category': fields.String, + 'position': fields.Integer, + 'is_listed': fields.Boolean, + 'install_count': fields.Integer, + 'installed': fields.Boolean, + 'editable': fields.Boolean +} + +recommended_app_list_fields = { + 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), + 'categories': fields.List(fields.String) +} + + +class RecommendedAppListApi(Resource): + @login_required + @account_initialization_required + @marshal_with(recommended_app_list_fields) + def get(self): + recommended_apps = db.session.query(RecommendedApp).filter( + RecommendedApp.is_listed == True + ).all() + + categories = set() + current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) + recommended_apps_result = [] + for recommended_app in recommended_apps: + installed = db.session.query(InstalledApp).filter( + and_( + InstalledApp.app_id == recommended_app.app_id, + InstalledApp.tenant_id == current_user.current_tenant_id + ) + ).first() is not None + + app = recommended_app.app + if not app or not app.is_public: + continue + + language_prefix = current_user.interface_language.split('-')[0] + desc = None + if recommended_app.description: + if language_prefix in recommended_app.description: + desc = recommended_app.description[language_prefix] + elif 'en' in recommended_app.description: + desc = recommended_app.description['en'] + + recommended_app_result = { + 'id': recommended_app.id, + 'app': app, + 'app_id': recommended_app.app_id, + 'description': desc, + 'copyright': recommended_app.copyright, + 'privacy_policy': recommended_app.privacy_policy, + 'category': recommended_app.category, + 'position': recommended_app.position, + 'is_listed': recommended_app.is_listed, + 'install_count': recommended_app.install_count, + 'installed': installed, + 'editable': current_user.role in ['owner', 'admin'], + } + recommended_apps_result.append(recommended_app_result) + + categories.add(recommended_app.category) # add category to categories + + return {'recommended_apps': recommended_apps_result, 'categories': list(categories)} + + +class RecommendedAppApi(Resource): + model_config_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), + 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), + 'more_like_this': fields.Raw(attribute='more_like_this_dict'), + 'model': fields.Raw(attribute='model_dict'), + 'user_input_form': fields.Raw(attribute='user_input_form_list'), + 'pre_prompt': fields.String, + 'agent_mode': fields.Raw(attribute='agent_mode_dict'), + } + + app_simple_detail_fields = { + 'id': fields.String, + 'name': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'mode': fields.String, + 'app_model_config': fields.Nested(model_config_fields), + } + + @login_required + @account_initialization_required + @marshal_with(app_simple_detail_fields) + def get(self, app_id): + app_id = str(app_id) + + # is in public recommended list + recommended_app = db.session.query(RecommendedApp).filter( + RecommendedApp.is_listed == True, + RecommendedApp.app_id == app_id + ).first() + + if not recommended_app: + raise AppNotFoundError + + # get app detail + app = db.session.query(App).filter(App.id == app_id).first() + if not app or not app.is_public: + raise AppNotFoundError + + return app + + +api.add_resource(RecommendedAppListApi, '/explore/apps') +api.add_resource(RecommendedAppApi, '/explore/apps/') diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py new file mode 100644 index 0000000000..3f9bc63096 --- /dev/null +++ b/api/controllers/console/explore/saved_message.py @@ -0,0 +1,79 @@ +from flask_login import current_user +from flask_restful import reqparse, marshal_with, fields +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + +from controllers.console import api +from controllers.console.explore.error import NotCompletionAppError +from controllers.console.explore.wraps import InstalledAppResource +from libs.helper import uuid_value, TimestampField +from services.errors.message import MessageNotExistsError +from services.saved_message_service import SavedMessageService + +feedback_fields = { + 'rating': fields.String +} + +message_fields = { + 'id': fields.String, + 'inputs': fields.Raw, + 'query': fields.String, + 'answer': fields.String, + 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'created_at': TimestampField +} + + +class SavedMessageListApi(InstalledAppResource): + saved_message_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(message_fields)) + } + + @marshal_with(saved_message_infinite_scroll_pagination_fields) + def get(self, installed_app): + app_model = installed_app.app + if app_model.mode != 'completion': + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument('last_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit']) + + def post(self, installed_app): + app_model = installed_app.app + if app_model.mode != 'completion': + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument('message_id', type=uuid_value, required=True, location='json') + args = parser.parse_args() + + try: + SavedMessageService.save(app_model, current_user, args['message_id']) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + + return {'result': 'success'} + + +class SavedMessageApi(InstalledAppResource): + def delete(self, installed_app, message_id): + app_model = installed_app.app + + message_id = str(message_id) + + if app_model.mode != 'completion': + raise NotCompletionAppError() + + SavedMessageService.delete(app_model, current_user, message_id) + + return {'result': 'success'} + + +api.add_resource(SavedMessageListApi, '/installed-apps//saved-messages', endpoint='installed_app_saved_messages') +api.add_resource(SavedMessageApi, '/installed-apps//saved-messages/', endpoint='installed_app_saved_message') diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py new file mode 100644 index 0000000000..601e9352ea --- /dev/null +++ b/api/controllers/console/explore/wraps.py @@ -0,0 +1,48 @@ +from flask_login import login_required, current_user +from flask_restful import Resource +from functools import wraps + +from werkzeug.exceptions import NotFound + +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db +from models.model import InstalledApp + + +def installed_app_required(view=None): + def decorator(view): + @wraps(view) + def decorated(*args, **kwargs): + if not kwargs.get('installed_app_id'): + raise ValueError('missing installed_app_id in path parameters') + + installed_app_id = kwargs.get('installed_app_id') + installed_app_id = str(installed_app_id) + + del kwargs['installed_app_id'] + + installed_app = db.session.query(InstalledApp).filter( + InstalledApp.id == str(installed_app_id), + InstalledApp.tenant_id == current_user.current_tenant_id + ).first() + + if installed_app is None: + raise NotFound('Installed app not found') + + if not installed_app.app: + db.session.delete(installed_app) + db.session.commit() + + raise NotFound('Installed app not found') + + return view(installed_app, *args, **kwargs) + return decorated + + if view: + return decorator(view) + return decorator + + +class InstalledAppResource(Resource): + # must be reversed if there are multiple decorators + method_decorators = [installed_app_required, account_initialization_required, login_required] diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 53ba382051..06c500857d 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -47,7 +47,7 @@ class ConversationListApi(WebApiResource): try: return WebConversationService.pagination_by_last_id( app_model=app_model, - end_user=end_user, + user=end_user, last_id=args['last_id'], limit=args['limit'], pinned=pinned diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index d227a9659e..c68b8f1cf2 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -42,13 +42,16 @@ def validate_and_get_site(): """ auth_header = request.headers.get('Authorization') if auth_header is None: - raise Unauthorized() + raise Unauthorized('Authorization header is missing.') + + if ' ' not in auth_header: + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() if auth_scheme != 'bearer': - raise Unauthorized() + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') site = db.session.query(Site).filter( Site.code == auth_token, diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py new file mode 100644 index 0000000000..c7e3e801ec --- /dev/null +++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py @@ -0,0 +1,46 @@ +"""add created by role + +Revision ID: 9f4e3427ea84 +Revises: 64b051264f32 +Create Date: 2023-05-17 17:29:01.060435 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9f4e3427ea84' +down_revision = '64b051264f32' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False)) + batch_op.drop_index('pinned_conversation_conversation_idx') + batch_op.create_index('pinned_conversation_conversation_idx', ['app_id', 'conversation_id', 'created_by_role', 'created_by'], unique=False) + + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False)) + batch_op.drop_index('saved_message_message_idx') + batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + batch_op.drop_index('saved_message_message_idx') + batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by'], unique=False) + batch_op.drop_column('created_by_role') + + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.drop_index('pinned_conversation_conversation_idx') + batch_op.create_index('pinned_conversation_conversation_idx', ['app_id', 'conversation_id', 'created_by'], unique=False) + batch_op.drop_column('created_by_role') + + # ### end Alembic commands ### diff --git a/api/models/web.py b/api/models/web.py index 1580ce74c9..b2466430b9 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -8,12 +8,13 @@ class SavedMessage(db.Model): __tablename__ = 'saved_messages' __table_args__ = ( db.PrimaryKeyConstraint('id', name='saved_message_pkey'), - db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by'), + db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) message_id = db.Column(UUID, nullable=False) + created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -26,11 +27,12 @@ class PinnedConversation(db.Model): __tablename__ = 'pinned_conversations' __table_args__ = ( db.PrimaryKeyConstraint('id', name='pinned_conversation_pkey'), - db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by'), + db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) conversation_id = db.Column(UUID, nullable=False) + created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/services/message_service.py b/api/services/message_service.py index b59fb0f10c..5c60017a97 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -127,7 +127,7 @@ class MessageService: message_id=message_id ) - feedback = message.user_feedback + feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback if not rating and feedback: db.session.delete(feedback) diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 1a68a1ba34..e363f65fa9 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,7 +1,8 @@ -from typing import Optional +from typing import Optional, Union from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db +from models.account import Account from models.model import App, EndUser from models.web import SavedMessage from services.message_service import MessageService @@ -9,27 +10,29 @@ from services.message_service import MessageService class SavedMessageService: @classmethod - def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser], + def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], last_id: Optional[str], limit: int) -> InfiniteScrollPagination: saved_messages = db.session.query(SavedMessage).filter( SavedMessage.app_id == app_model.id, - SavedMessage.created_by == end_user.id + SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), + SavedMessage.created_by == user.id ).order_by(SavedMessage.created_at.desc()).all() message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( app_model=app_model, - user=end_user, + user=user, last_id=last_id, limit=limit, include_ids=message_ids ) @classmethod - def save(cls, app_model: App, user: Optional[EndUser], message_id: str): + def save(cls, app_model: App, user: Optional[Union[Account | EndUser]], message_id: str): saved_message = db.session.query(SavedMessage).filter( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), SavedMessage.created_by == user.id ).first() @@ -45,6 +48,7 @@ class SavedMessageService: saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, + created_by_role='account' if isinstance(user, Account) else 'end_user', created_by=user.id ) @@ -52,10 +56,11 @@ class SavedMessageService: db.session.commit() @classmethod - def delete(cls, app_model: App, user: Optional[EndUser], message_id: str): + def delete(cls, app_model: App, user: Optional[Union[Account | EndUser]], message_id: str): saved_message = db.session.query(SavedMessage).filter( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), SavedMessage.created_by == user.id ).first() diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 5cfab25006..231083db19 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -2,6 +2,7 @@ from typing import Optional, Union from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db +from models.account import Account from models.model import App, EndUser from models.web import PinnedConversation from services.conversation_service import ConversationService @@ -9,14 +10,15 @@ from services.conversation_service import ConversationService class WebConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser], + def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: pinned_conversations = db.session.query(PinnedConversation).filter( PinnedConversation.app_id == app_model.id, - PinnedConversation.created_by == end_user.id + PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), + PinnedConversation.created_by == user.id ).order_by(PinnedConversation.created_at.desc()).all() pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] if pinned: @@ -26,7 +28,7 @@ class WebConversationService: return ConversationService.pagination_by_last_id( app_model=app_model, - user=end_user, + user=user, last_id=last_id, limit=limit, include_ids=include_ids, @@ -34,10 +36,11 @@ class WebConversationService: ) @classmethod - def pin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]): + def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]): pinned_conversation = db.session.query(PinnedConversation).filter( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), PinnedConversation.created_by == user.id ).first() @@ -53,6 +56,7 @@ class WebConversationService: pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, + created_by_role='account' if isinstance(user, Account) else 'end_user', created_by=user.id ) @@ -60,10 +64,11 @@ class WebConversationService: db.session.commit() @classmethod - def unpin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]): + def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]): pinned_conversation = db.session.query(PinnedConversation).filter( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), PinnedConversation.created_by == user.id ).first()