From 8d8a8fe2959e53e8fec3ec8546fa376eb9e34e59 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 21 Oct 2024 17:46:24 +0800 Subject: [PATCH] feat(file-upload): add support for optional file source parameter (#9554) --- api/controllers/console/datasets/file.py | 8 ++++++-- api/controllers/web/file.py | 8 ++++++-- api/services/file_service.py | 17 +++++++++++------ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 5ed9a61545..51be7e7a7d 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -2,7 +2,7 @@ import urllib.parse from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with, reqparse import services from configs import dify_config @@ -48,6 +48,10 @@ class FileApi(Resource): # get file from request file = request.files["file"] + parser = reqparse.RequestParser() + parser.add_argument("source", type=str, required=False, location="args") + source = parser.parse_args().get("source") + # check file if "file" not in request.files: raise NoFileUploadedError() @@ -55,7 +59,7 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() try: - upload_file = FileService.upload_file(file=file, user=current_user) + upload_file = FileService.upload_file(file=file, user=current_user, source=source) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index 6b9c267003..c029a07707 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -1,7 +1,7 @@ import urllib.parse from flask import request -from flask_restful import marshal_with +from flask_restful import marshal_with, reqparse import services from controllers.web import api @@ -18,6 +18,10 @@ class FileApi(WebApiResource): # get file from request file = request.files["file"] + parser = reqparse.RequestParser() + parser.add_argument("source", type=str, required=False, location="args") + source = parser.parse_args().get("source") + # check file if "file" not in request.files: raise NoFileUploadedError() @@ -25,7 +29,7 @@ class FileApi(WebApiResource): if len(request.files) > 1: raise TooManyFilesError() try: - upload_file = FileService.upload_file(file, end_user) + upload_file = FileService.upload_file(file=file, user=end_user, source=source) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/services/file_service.py b/api/services/file_service.py index 0b35561600..84ccc4e882 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -2,7 +2,7 @@ import datetime import hashlib import uuid from collections.abc import Generator -from typing import Union +from typing import Literal, Union from flask_login import current_user from werkzeug.datastructures import FileStorage @@ -28,7 +28,9 @@ PREVIEW_WORDS_LIMIT = 3000 class FileService: @staticmethod - def upload_file(file: FileStorage, user: Union[Account, EndUser]) -> UploadFile: + def upload_file( + file: FileStorage, user: Union[Account, EndUser], source: Literal["datasets"] | None = None + ) -> UploadFile: # get file name filename = file.filename if not filename: @@ -36,11 +38,9 @@ class FileService: extension = filename.split(".")[-1] if len(filename) > 200: filename = filename.split(".")[0][:200] + "." + extension - # read file content - file_content = file.read() - # get file size - file_size = len(file_content) + if source == "datasets" and extension not in DOCUMENT_EXTENSIONS: + raise UnsupportedFileTypeError() # select file size limit if extension in IMAGE_EXTENSIONS: @@ -52,6 +52,11 @@ class FileService: else: file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 + # read file content + file_content = file.read() + # get file size + file_size = len(file_content) + # check if the file size is exceeded if file_size > file_size_limit: message = f"File size exceeded. {file_size} > {file_size_limit}"