Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly 2024-12-04 19:02:50 +08:00
commit 3d3a42945f
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
5 changed files with 79 additions and 24 deletions

View File

@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@ -21,6 +21,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App # Import App
api.add_resource(AppImportApi, "/apps/imports") api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm") api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version from . import admin, apikey, extension, feature, ping, setup, version

View File

@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_import_fields from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required from libs.login import login_required
from models import Account from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus from services.app_dsl_service import AppDslService, ImportStatus
@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource):
if result.status == ImportStatus.FAILED.value: if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)
return result.model_dump(mode="json"), 200

View File

@ -207,5 +207,8 @@ app_import_fields = {
"current_dsl_version": fields.String, "current_dsl_version": fields.String,
"imported_dsl_version": fields.String, "imported_dsl_version": fields.String,
"error": fields.String, "error": fields.String,
}
app_import_check_dependencies_fields = {
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)), "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
} }

View File

@ -9,7 +9,7 @@ import uuid
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import fields from flask_restful import fields
@ -19,7 +19,9 @@ from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.account import Account
if TYPE_CHECKING:
from models.account import Account
def run(script): def run(script):
@ -196,7 +198,7 @@ class TokenManager:
def generate_token( def generate_token(
cls, cls,
token_type: str, token_type: str,
account: Optional[Account] = None, account: Optional["Account"] = None,
email: Optional[str] = None, email: Optional[str] = None,
additional_data: Optional[dict] = None, additional_data: Optional[dict] = None,
) -> str: ) -> str:

View File

@ -31,7 +31,8 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
IMPORT_INFO_REDIS_EXPIRY = 2 * 60 * 60 # 2 hours CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
CURRENT_DSL_VERSION = "0.1.4" CURRENT_DSL_VERSION = "0.1.4"
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
@ -54,10 +55,13 @@ class Import(BaseModel):
app_id: Optional[str] = None app_id: Optional[str] = None
current_dsl_version: str = CURRENT_DSL_VERSION current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = "" imported_dsl_version: str = ""
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
error: str = "" error: str = ""
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus: def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison""" """Determine import status based on version comparison"""
try: try:
@ -87,6 +91,11 @@ class PendingData(BaseModel):
app_id: str | None app_id: str | None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
app_id: str | None
class AppDslService: class AppDslService:
def __init__(self, session: Session): def __init__(self, session: Session):
self._session = session self._session = session
@ -243,23 +252,11 @@ class AppDslService:
imported_dsl_version=imported_version, imported_dsl_version=imported_version,
) )
try: # Extract dependencies
dependencies = self.get_leaked_dependencies(account.current_tenant_id, data.get("dependencies", [])) dependencies = data.get("dependencies", [])
except Exception as e: check_dependencies_pending_data = None
return Import( if dependencies:
id=import_id, check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
status=ImportStatus.FAILED,
error=str(e),
)
if len(dependencies) > 0:
return Import(
id=import_id,
status=ImportStatus.PENDING,
app_id=app_id,
imported_dsl_version=imported_version,
leaked_dependencies=dependencies,
)
# Create or update app # Create or update app
app = self._create_or_update_app( app = self._create_or_update_app(
@ -271,6 +268,7 @@ class AppDslService:
icon_type=icon_type, icon_type=icon_type,
icon=icon, icon=icon,
icon_background=icon_background, icon_background=icon_background,
dependencies=check_dependencies_pending_data,
) )
return Import( return Import(
@ -355,6 +353,29 @@ class AppDslService:
error=str(e), error=str(e),
) )
def check_dependencies(
self,
*,
app_model: App,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()
# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)
def _create_or_update_app( def _create_or_update_app(
self, self,
*, *,
@ -366,6 +387,7 @@ class AppDslService:
icon_type: Optional[str] = None, icon_type: Optional[str] = None,
icon: Optional[str] = None, icon: Optional[str] = None,
icon_background: Optional[str] = None, icon_background: Optional[str] = None,
dependencies: Optional[list[PluginDependency]] = None,
) -> App: ) -> App:
"""Create a new app or update an existing one.""" """Create a new app or update an existing one."""
app_data = data.get("app", {}) app_data = data.get("app", {})
@ -408,6 +430,14 @@ class AppDslService:
self._session.commit() self._session.commit()
app_was_created.send(app, account=account) app_was_created.send(app, account=account)
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
)
# Initialize app based on mode # Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow") workflow_data = data.get("workflow")