diff --git a/api/core/file/models.py b/api/core/file/models.py index 2f0026a203..7eef2d2b33 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -97,32 +97,18 @@ class File(BaseModel): return text def generate_url(self) -> Optional[str]: - if self.type == FileType.IMAGE: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.remote_url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - if self.related_id is None: - raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id) - elif self.transfer_method == FileTransferMethod.TOOL_FILE: - assert self.related_id is not None - assert self.extension is not None - return ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=self.related_id, extension=self.extension - ) - else: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.remote_url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - if self.related_id is None: - raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id) - elif self.transfer_method == FileTransferMethod.TOOL_FILE: - assert self.related_id is not None - assert self.extension is not None - return ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=self.related_id, extension=self.extension - ) + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) def to_plugin_parameter(self) -> dict[str, Any]: return { diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 24d3c8b906..c4f69f6f6b 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -7,7 +7,7 @@ import httpx from sqlalchemy import select from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig +from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile @@ -158,6 +158,39 @@ def _build_from_remote_url( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: + upload_file_id = mapping.get("upload_file_id") + if upload_file_id: + try: + uuid.UUID(upload_file_id) + except ValueError: + raise ValueError("Invalid upload file id format") + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + + upload_file = db.session.scalar(stmt) + if upload_file is None: + raise ValueError("Invalid upload file") + + file_type = FileType(mapping.get("type", "custom")) + file_type = _standardize_file_type( + file_type, extension="." + upload_file.extension, mime_type=upload_file.mime_type + ) + + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=tenant_id, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + related_id=mapping.get("upload_file_id"), + size=upload_file.size, + storage_key=upload_file.key, + ) url = mapping.get("url") or mapping.get("remote_url") if not url: raise ValueError("Invalid file url") diff --git a/api/models/model.py b/api/models/model.py index cb099d5654..87806eb918 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1081,19 +1081,19 @@ class Message(db.Model): # type: ignore[name-defined] files = [] for message_file in message_files: - if message_file.transfer_method == "local_file": + if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: if message_file.upload_file_id is None: raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") file = file_factory.build_from_mapping( mapping={ "id": message_file.id, - "upload_file_id": message_file.upload_file_id, - "transfer_method": message_file.transfer_method, "type": message_file.type, + "transfer_method": message_file.transfer_method, + "upload_file_id": message_file.upload_file_id, }, tenant_id=current_app.tenant_id, ) - elif message_file.transfer_method == "remote_url": + elif message_file.transfer_method == FileTransferMethod.REMOTE_URL.value: if message_file.url is None: raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") file = file_factory.build_from_mapping( @@ -1101,11 +1101,12 @@ class Message(db.Model): # type: ignore[name-defined] "id": message_file.id, "type": message_file.type, "transfer_method": message_file.transfer_method, + "upload_file_id": message_file.upload_file_id, "url": message_file.url, }, tenant_id=current_app.tenant_id, ) - elif message_file.transfer_method == "tool_file": + elif message_file.transfer_method == FileTransferMethod.TOOL_FILE.value: if message_file.upload_file_id is None: assert message_file.url is not None message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0]