Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly 2024-12-04 19:02:50 +08:00
commit 3d3a42945f
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
5 changed files with 79 additions and 24 deletions

View File

@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@ -21,6 +21,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version

View File

@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource):
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)
return result.model_dump(mode="json"), 200

View File

@ -207,5 +207,8 @@ app_import_fields = {
"current_dsl_version": fields.String,
"imported_dsl_version": fields.String,
"error": fields.String,
}
app_import_check_dependencies_fields = {
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
}

View File

@ -9,7 +9,7 @@ import uuid
from collections.abc import Generator
from datetime import datetime
from hashlib import sha256
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from flask import Response, stream_with_context
from flask_restful import fields
@ -19,7 +19,9 @@ from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
from extensions.ext_redis import redis_client
from models.account import Account
if TYPE_CHECKING:
from models.account import Account
def run(script):
@ -196,7 +198,7 @@ class TokenManager:
def generate_token(
cls,
token_type: str,
account: Optional[Account] = None,
account: Optional["Account"] = None,
email: Optional[str] = None,
additional_data: Optional[dict] = None,
) -> str:

View File

@ -31,7 +31,8 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
IMPORT_INFO_REDIS_EXPIRY = 2 * 60 * 60 # 2 hours
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
CURRENT_DSL_VERSION = "0.1.4"
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
@ -54,10 +55,13 @@ class Import(BaseModel):
app_id: Optional[str] = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
error: str = ""
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
@ -87,6 +91,11 @@ class PendingData(BaseModel):
app_id: str | None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
app_id: str | None
class AppDslService:
def __init__(self, session: Session):
self._session = session
@ -243,23 +252,11 @@ class AppDslService:
imported_dsl_version=imported_version,
)
try:
dependencies = self.get_leaked_dependencies(account.current_tenant_id, data.get("dependencies", []))
except Exception as e:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
if len(dependencies) > 0:
return Import(
id=import_id,
status=ImportStatus.PENDING,
app_id=app_id,
imported_dsl_version=imported_version,
leaked_dependencies=dependencies,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
# Create or update app
app = self._create_or_update_app(
@ -271,6 +268,7 @@ class AppDslService:
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
dependencies=check_dependencies_pending_data,
)
return Import(
@ -355,6 +353,29 @@ class AppDslService:
error=str(e),
)
def check_dependencies(
self,
*,
app_model: App,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()
# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)
def _create_or_update_app(
self,
*,
@ -366,6 +387,7 @@ class AppDslService:
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
dependencies: Optional[list[PluginDependency]] = None,
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
@ -408,6 +430,14 @@ class AppDslService:
self._session.commit()
app_was_created.send(app, account=account)
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
)
# Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow")