diff --git a/api/configs/middleware/vdb/upstash_config.py b/api/configs/middleware/vdb/upstash_config.py new file mode 100644 index 0000000000..412c56374a --- /dev/null +++ b/api/configs/middleware/vdb/upstash_config.py @@ -0,0 +1,20 @@ +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class UpstashConfig(BaseSettings): + """ + Configuration settings for Upstash vector database + """ + + UPSTASH_VECTOR_URL: Optional[str] = Field( + description="URL of the upstash server (e.g., 'https://vector.upstash.io')", + default=None, + ) + + UPSTASH_VECTOR_TOKEN: Optional[str] = Field( + description="Token for authenticating with the upstash server", + default=None, + ) diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py deleted file mode 100644 index 3c4d7046f4..0000000000 --- a/api/core/app/segments/parser.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - -from core.workflow.entities.variable_pool import VariablePool - -from . import SegmentGroup, factory - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -def convert_template(*, template: str, variable_pool: VariablePool): - parts = re.split(VARIABLE_PATTERN, template) - segments = [] - for part in filter(lambda x: x, parts): - if "." 
in part and (value := variable_pool.get(part.split("."))): - segments.append(value) - else: - segments.append(factory.build_segment(part)) - return SegmentGroup(value=segments) diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py deleted file mode 100644 index 10bc9f6ed7..0000000000 --- a/api/core/entities/message_entities.py +++ /dev/null @@ -1,29 +0,0 @@ -import enum -from typing import Any - -from pydantic import BaseModel - - -class PromptMessageFileType(enum.Enum): - IMAGE = "image" - - @staticmethod - def value_of(value): - for member in PromptMessageFileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class PromptMessageFile(BaseModel): - type: PromptMessageFileType - data: Any = None - - -class ImagePromptMessageFile(PromptMessageFile): - class DETAIL(enum.Enum): - LOW = "low" - HIGH = "high" - - type: PromptMessageFileType = PromptMessageFileType.IMAGE - detail: DETAIL = DETAIL.LOW diff --git a/api/core/file/constants.py b/api/core/file/constants.py new file mode 100644 index 0000000000..ce1d238e93 --- /dev/null +++ b/api/core/file/constants.py @@ -0,0 +1 @@ +FILE_MODEL_IDENTITY = "__dify__file__" diff --git a/api/core/file/enums.py b/api/core/file/enums.py new file mode 100644 index 0000000000..f4153f1676 --- /dev/null +++ b/api/core/file/enums.py @@ -0,0 +1,55 @@ +from enum import Enum + + +class FileType(str, Enum): + IMAGE = "image" + DOCUMENT = "document" + AUDIO = "audio" + VIDEO = "video" + CUSTOM = "custom" + + @staticmethod + def value_of(value): + for member in FileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileTransferMethod(str, Enum): + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" + + @staticmethod + def value_of(value): + for member in FileTransferMethod: + if member.value == value: + return member + raise 
ValueError(f"No matching enum found for value '{value}'") + + +class FileBelongsTo(str, Enum): + USER = "user" + ASSISTANT = "assistant" + + @staticmethod + def value_of(value): + for member in FileBelongsTo: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileAttribute(str, Enum): + TYPE = "type" + SIZE = "size" + NAME = "name" + MIME_TYPE = "mime_type" + TRANSFER_METHOD = "transfer_method" + URL = "url" + EXTENSION = "extension" + + +class ArrayFileAttribute(str, Enum): + LENGTH = "length" diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py new file mode 100644 index 0000000000..0c6ce8ce75 --- /dev/null +++ b/api/core/file/file_manager.py @@ -0,0 +1,156 @@ +import base64 + +from configs import dify_config +from core.file import file_repository +from core.helper import ssrf_proxy +from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent +from extensions.ext_database import db +from extensions.ext_storage import storage + +from . import helpers +from .enums import FileAttribute +from .models import File, FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +def get_attr(*, file: File, attr: FileAttribute): + match attr: + case FileAttribute.TYPE: + return file.type.value + case FileAttribute.SIZE: + return file.size + case FileAttribute.NAME: + return file.filename + case FileAttribute.MIME_TYPE: + return file.mime_type + case FileAttribute.TRANSFER_METHOD: + return file.transfer_method.value + case FileAttribute.URL: + return file.remote_url + case FileAttribute.EXTENSION: + return file.extension + case _: + raise ValueError(f"Invalid file attribute: {attr}") + + +def to_prompt_message_content(f: File, /): + """ + Convert a File object to an ImagePromptMessageContent object. 
+ + This function takes a File object and converts it to an ImagePromptMessageContent + object, which can be used as a prompt for image-based AI models. + + Args: + file (File): The File object to convert. Must be of type FileType.IMAGE. + + Returns: + ImagePromptMessageContent: An object containing the image data and detail level. + + Raises: + ValueError: If the file is not an image or if the file data is missing. + + Note: + The detail level of the image prompt is determined by the file's extra_config. + If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW. + """ + match f.type: + case FileType.IMAGE: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": + data = _to_url(f) + else: + data = _to_base64_data_string(f) + + if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail: + detail = f._extra_config.image_config.detail + else: + detail = ImagePromptMessageContent.DETAIL.LOW + + return ImagePromptMessageContent(data=data, detail=detail) + case FileType.AUDIO: + encoded_string = _file_to_encoded_string(f) + if f.extension is None: + raise ValueError("Missing file extension") + return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def download(f: File, /): + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + return _download_file_content(upload_file.key) + + +def _download_file_content(path: str, /): + """ + Download and return the contents of a file as bytes. + + This function loads the file from storage and ensures it's in bytes format. + + Args: + path (str): The path to the file in storage. + + Returns: + bytes: The contents of the file as a bytes object. + + Raises: + ValueError: If the loaded file is not a bytes object. 
+ """ + data = storage.load(path, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {path} is not a bytes object") + return data + + +def _get_encoded_string(f: File, /): + match f.transfer_method: + case FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url) + response.raise_for_status() + content = response.content + encoded_string = base64.b64encode(content).decode("utf-8") + return encoded_string + case FileTransferMethod.LOCAL_FILE: + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + data = _download_file_content(upload_file.key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case FileTransferMethod.TOOL_FILE: + tool_file = file_repository.get_tool_file(session=db.session(), file=f) + data = _download_file_content(tool_file.file_key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case _: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + +def _to_base64_data_string(f: File, /): + encoded_string = _get_encoded_string(f) + return f"data:{f.mime_type};base64,{encoded_string}" + + +def _file_to_encoded_string(f: File, /): + match f.type: + case FileType.IMAGE: + return _to_base64_data_string(f) + case FileType.AUDIO: + return _get_encoded_string(f) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def _to_url(f: File, /): + if f.transfer_method == FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") + return f.remote_url + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + if f.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=f.related_id) + elif f.transfer_method == FileTransferMethod.TOOL_FILE: + # add sign url + if f.related_id is None or f.extension is None: + raise ValueError("Missing file related_id or extension") + return 
ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) + else: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py deleted file mode 100644 index 5c4e694025..0000000000 --- a/api/core/file/file_obj.py +++ /dev/null @@ -1,145 +0,0 @@ -import enum -from typing import Any, Optional - -from pydantic import BaseModel - -from core.file.tool_file_parser import ToolFileParser -from core.file.upload_file_parser import UploadFileParser -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from extensions.ext_database import db - - -class FileExtraConfig(BaseModel): - """ - File Upload Entity. - """ - - image_config: Optional[dict[str, Any]] = None - - -class FileType(enum.Enum): - IMAGE = "image" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(enum.Enum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(enum.Enum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileVar(BaseModel): - id: Optional[str] = None # message file id - tenant_id: str - type: FileType - transfer_method: FileTransferMethod - url: Optional[str] = None # remote url - related_id: Optional[str] = None - extra_config: Optional[FileExtraConfig] = None - filename: Optional[str] = None - extension: Optional[str] = None - mime_type: Optional[str] = None - - def 
to_dict(self) -> dict: - return { - "__variant": self.__class__.__name__, - "tenant_id": self.tenant_id, - "type": self.type.value, - "transfer_method": self.transfer_method.value, - "url": self.preview_url, - "remote_url": self.url, - "related_id": self.related_id, - "filename": self.filename, - "extension": self.extension, - "mime_type": self.mime_type, - } - - def to_markdown(self) -> str: - """ - Convert file to markdown - :return: - """ - preview_url = self.preview_url - if self.type == FileType.IMAGE: - text = f'![{self.filename or ""}]({preview_url})' - else: - text = f"[{self.filename or preview_url}]({preview_url})" - - return text - - @property - def data(self) -> Optional[str]: - """ - Get image data, file signed url or base64 data - depending on config MULTIMODAL_SEND_IMAGE_FORMAT - :return: - """ - return self._get_data() - - @property - def preview_url(self) -> Optional[str]: - """ - Get signed preview url - :return: - """ - return self._get_data(force_url=True) - - @property - def prompt_message_content(self) -> ImagePromptMessageContent: - if self.type == FileType.IMAGE: - image_config = self.extra_config.image_config - - return ImagePromptMessageContent( - data=self.data, - detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" - else ImagePromptMessageContent.DETAIL.LOW, - ) - - def _get_data(self, force_url: bool = False) -> Optional[str]: - from models.model import UploadFile - - if self.type == FileType.IMAGE: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = ( - db.session.query(UploadFile) - .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) - .first() - ) - - return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) - elif self.transfer_method == FileTransferMethod.TOOL_FILE: - extension = self.extension - # add sign url - return 
ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=self.related_id, extension=extension - ) - - return None diff --git a/api/core/file/file_repository.py b/api/core/file/file_repository.py new file mode 100644 index 0000000000..975e1e72db --- /dev/null +++ b/api/core/file/file_repository.py @@ -0,0 +1,32 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models import ToolFile, UploadFile + +from .models import File + + +def get_upload_file(*, session: Session, file: File): + if file.related_id is None: + raise ValueError("Missing file related_id") + stmt = select(UploadFile).filter( + UploadFile.id == file.related_id, + UploadFile.tenant_id == file.tenant_id, + ) + record = session.scalar(stmt) + if not record: + raise ValueError(f"upload file {file.related_id} not found") + return record + + +def get_tool_file(*, session: Session, file: File): + if file.related_id is None: + raise ValueError("Missing file related_id") + stmt = select(ToolFile).filter( + ToolFile.id == file.related_id, + ToolFile.tenant_id == file.tenant_id, + ) + record = session.scalar(stmt) + if not record: + raise ValueError(f"tool file {file.related_id} not found") + return record diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py new file mode 100644 index 0000000000..12123cf3f7 --- /dev/null +++ b/api/core/file/helpers.py @@ -0,0 +1,48 @@ +import base64 +import hashlib +import hmac +import os +import time + +from configs import dify_config + + +def get_signed_file_url(upload_file_id: str) -> str: + url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + key = dify_config.SECRET_KEY.encode() + msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + +def 
verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py deleted file mode 100644 index 641686bd7c..0000000000 --- a/api/core/file/message_file_parser.py +++ /dev/null @@ -1,243 +0,0 @@ -import re -from collections.abc import Mapping, Sequence -from typing import Any, Union -from urllib.parse import parse_qs, urlparse - -import requests - -from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, MessageFile, UploadFile -from services.file_service import IMAGE_EXTENSIONS - - -class MessageFileParser: - def __init__(self, tenant_id: str, app_id: str) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - - def 
validate_and_transform_files_arg( - self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] - ) -> list[FileVar]: - """ - validate and transform files arg - - :param files: - :param file_extra_config: - :param user: - :return: - """ - for file in files: - if not isinstance(file, dict): - raise ValueError("Invalid file format, must be dict") - if not file.get("type"): - raise ValueError("Missing file type") - FileType.value_of(file.get("type")) - if not file.get("transfer_method"): - raise ValueError("Missing file transfer method") - FileTransferMethod.value_of(file.get("transfer_method")) - if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: - if not file.get("url"): - raise ValueError("Missing file url") - if not file.get("url").startswith("http"): - raise ValueError("Invalid file url") - if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): - raise ValueError("Missing file upload_file_id") - if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): - raise ValueError("Missing file tool_file_id") - - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # validate files - new_files = [] - for file_type, file_objs in type_file_objs.items(): - if file_type == FileType.IMAGE: - # parse and validate files - image_config = file_extra_config.image_config - - # check if image file feature is enabled - if not image_config: - continue - - # Validate number of files - if len(files) > image_config["number_limits"]: - raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") - - for file_obj in file_objs: - # Validate transfer method - if file_obj.transfer_method.value not in image_config["transfer_methods"]: - raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") - - # Validate file type - if 
file_obj.type != FileType.IMAGE: - raise ValueError(f"Invalid file type: {file_obj.type}") - - if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: - # check remote url valid and is image - result, error = self._check_image_remote_url(file_obj.url) - if result is False: - raise ValueError(error) - elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: - # get upload file from upload_file_id - upload_file = ( - db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), - UploadFile.extension.in_(IMAGE_EXTENSIONS), - ) - .first() - ) - - # check upload file is belong to tenant and user - if not upload_file: - raise ValueError("Invalid upload file") - - new_files.append(file_obj) - - # return all file objs - return new_files - - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): - """ - transform message files - - :param files: - :param file_extra_config: - :return: - """ - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # return all file objs - return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - - def _to_file_objs( - self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig - ) -> dict[FileType, list[FileVar]]: - """ - transform files to file objs - - :param files: - :param file_extra_config: - :return: - """ - type_file_objs: dict[FileType, list[FileVar]] = { - # Currently only support image - FileType.IMAGE: [] - } - - if not files: - return type_file_objs - - # group by file type and convert file args or message files to FileObj - for file in files: - if isinstance(file, MessageFile): - if file.belongs_to == FileBelongsTo.ASSISTANT.value: - continue - - file_obj = self._to_file_obj(file, file_extra_config) - 
if file_obj.type not in type_file_objs: - continue - - type_file_objs[file_obj.type].append(file_obj) - - return type_file_objs - - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): - """ - transform file to file obj - - :param file: - :return: - """ - if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) - if transfer_method != FileTransferMethod.TOOL_FILE: - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get("type")), - transfer_method=transfer_method, - url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config, - ) - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get("type")), - transfer_method=transfer_method, - url=None, - related_id=file.get("tool_file_id"), - extra_config=file_extra_config, - ) - else: - return FileVar( - id=file.id, - tenant_id=self.tenant_id, - type=FileType.value_of(file.type), - transfer_method=FileTransferMethod.value_of(file.transfer_method), - url=file.url, - related_id=file.upload_file_id or None, - extra_config=file_extra_config, - ) - - def _check_image_remote_url(self, url): - try: - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" - " Chrome/91.0.4472.124 Safari/537.36" - } - - def is_s3_presigned_url(url): - try: - parsed_url = urlparse(url) - if "amazonaws.com" not in parsed_url.netloc: - return False - query_params = parse_qs(parsed_url.query) - - def check_presign_v2(query_params): - required_params = ["Signature", "Expires"] - for param in required_params: - if param not in query_params: - return False - if not query_params["Expires"][0].isdigit(): - return False - signature = query_params["Signature"][0] - if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): - 
return False - - return True - - def check_presign_v4(query_params): - required_params = ["X-Amz-Signature", "X-Amz-Expires"] - for param in required_params: - if param not in query_params: - return False - if not query_params["X-Amz-Expires"][0].isdigit(): - return False - signature = query_params["X-Amz-Signature"][0] - if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): - return False - - return True - - return check_presign_v4(query_params) or check_presign_v2(query_params) - except Exception: - return False - - if is_s3_presigned_url(url): - response = requests.get(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - - response = requests.head(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - else: - return False, "URL does not exist." - except requests.RequestException as e: - return False, f"Error checking URL: {e}" diff --git a/api/core/file/models.py b/api/core/file/models.py new file mode 100644 index 0000000000..866ff3155b --- /dev/null +++ b/api/core/file/models.py @@ -0,0 +1,140 @@ +from collections.abc import Mapping, Sequence +from typing import Optional + +from pydantic import BaseModel, Field, model_validator + +from core.model_runtime.entities.message_entities import ImagePromptMessageContent + +from . import helpers +from .constants import FILE_MODEL_IDENTITY +from .enums import FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +class ImageConfig(BaseModel): + """ + NOTE: This part of validation is deprecated, but still used in app features "Image Upload". + """ + + number_limits: int = 0 + transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + detail: ImagePromptMessageContent.DETAIL | None = None + + +class FileExtraConfig(BaseModel): + """ + File Upload Entity. 
+ """ + + image_config: Optional[ImageConfig] = None + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_extensions: Sequence[str] = Field(default_factory=list) + allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + number_limits: int = 0 + + +class File(BaseModel): + dify_model_identity: str = FILE_MODEL_IDENTITY + + id: Optional[str] = None # message file id + tenant_id: str + type: FileType + transfer_method: FileTransferMethod + remote_url: Optional[str] = None # remote url + related_id: Optional[str] = None + filename: Optional[str] = None + extension: Optional[str] = Field(default=None, description="File extension, should contains dot") + mime_type: Optional[str] = None + size: int = -1 + _extra_config: FileExtraConfig | None = None + + def to_dict(self) -> Mapping[str, str | int | None]: + data = self.model_dump(mode="json") + return { + **data, + "url": self.generate_url(), + } + + @property + def markdown(self) -> str: + url = self.generate_url() + if self.type == FileType.IMAGE: + text = f'![{self.filename or ""}]({url})' + else: + text = f"[{self.filename or url}]({url})" + + return text + + def generate_url(self) -> Optional[str]: + if self.type == FileType.IMAGE: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + else: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise 
ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + + @model_validator(mode="after") + def validate_after(self): + match self.transfer_method: + case FileTransferMethod.REMOTE_URL: + if not self.remote_url: + raise ValueError("Missing file url") + if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): + raise ValueError("Invalid file url") + case FileTransferMethod.LOCAL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + case FileTransferMethod.TOOL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + + # Validate the extra config. + if not self._extra_config: + return self + + if self._extra_config.allowed_file_types: + if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM: + raise ValueError(f"Invalid file type: {self.type}") + + if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions: + raise ValueError(f"Invalid file extension: {self.extension}") + + if ( + self._extra_config.allowed_upload_methods + and self.transfer_method not in self._extra_config.allowed_upload_methods + ): + raise ValueError(f"Invalid transfer method: {self.transfer_method}") + + match self.type: + case FileType.IMAGE: + # NOTE: This part of validation is deprecated, but still used in app features "Image Upload". 
+ if not self._extra_config.image_config: + return self + # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field + if ( + self._extra_config.image_config.transfer_methods + and self.transfer_method not in self._extra_config.image_config.transfer_methods + ): + raise ValueError(f"Invalid transfer method: {self.transfer_method}") + + return self diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py deleted file mode 100644 index a8c1fd4d02..0000000000 --- a/api/core/file/upload_file_parser.py +++ /dev/null @@ -1,79 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -import os -import time -from typing import Optional - -from configs import dify_config -from extensions.ext_storage import storage - -IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - - -class UploadFileParser: - @classmethod - def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: - if not upload_file: - return None - - if upload_file.extension not in IMAGE_EXTENSIONS: - return None - - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: - return cls.get_signed_temp_image_url(upload_file.id) - else: - # get image file base64 - try: - data = storage.load(upload_file.key) - except FileNotFoundError: - logging.error(f"File not found: {upload_file.key}") - return None - - encoded_string = base64.b64encode(data).decode("utf-8") - return f"data:{upload_file.mime_type};base64,{encoded_string}" - - @classmethod - def get_signed_temp_image_url(cls, upload_file_id) -> str: - """ - get signed url from upload file - - :param upload_file: UploadFile object - :return: - """ - base_url = dify_config.FILES_URL - image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - data_to_sign = 
f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - - @classmethod - def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - - :param upload_file_id: file id - :param timestamp: timestamp - :param nonce: nonce - :param sign: signature - :return: - """ - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml new file mode 100644 index 0000000000..e20b8c4960 --- /dev/null +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml @@ -0,0 +1,39 @@ +model: claude-3-5-sonnet-20241022 +label: + en_US: claude-3-5-sonnet-20241022 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '3.00' + output: '15.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v2.yaml new file mode 100644 index 0000000000..b1e5698375 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v2.yaml @@ -0,0 +1,60 @@ +model: anthropic.claude-3-5-sonnet-20241022-v2:0 +label: + en_US: Claude 3.5 Sonnet V2 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. 
+ - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v2.yaml new file mode 100644 index 0000000000..8d831e6fcb --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v2.yaml @@ -0,0 +1,60 @@ +model: eu.anthropic.claude-3-5-sonnet-20241022-v2:0 +label: + en_US: Claude 3.5 Sonnet V2(EU.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to 
generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml new file mode 100644 index 0000000000..31a403289b --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml @@ -0,0 +1,60 @@ +model: us.anthropic.claude-3-5-sonnet-20241022-v2:0 +label: + en_US: Claude 3.5 Sonnet V2(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. 
+ - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml new file mode 100644 index 0000000000..5632218797 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml @@ -0,0 +1,26 @@ +model: llama-3.2-11b-vision-preview +label: + zh_Hans: Llama 3.2 11B Vision (Preview) + en_US: Llama 3.2 11B Vision (Preview) +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml new 
class GroqSpeech2TextModel(OAICompatSpeech2TextModel):
    """
    Model class for Groq Speech to text model.

    Reuses the OpenAI-compatible speech2text implementation, pointing it at
    the fixed Groq API base URL.
    """

    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        self._add_custom_parameters(credentials)
        # Forward `user` as well — the previous call dropped it, losing
        # per-user attribution in the OpenAI-compatible layer.
        return super()._invoke(model, credentials, file, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate credentials against the Groq endpoint."""
        self._add_custom_parameters(credentials)
        return super().validate_credentials(model, credentials)

    @classmethod
    def _add_custom_parameters(cls, credentials: dict) -> None:
        # Groq exposes an OpenAI-compatible API under this fixed base URL.
        credentials["endpoint_url"] = "https://api.groq.com/openai/v1"
b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -0,0 +1,44 @@ +model: gpt-4o-audio-preview +label: + zh_Hans: gpt-4o-audio-preview + en_US: gpt-4o-audio-preview +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '5.00' + output: '15.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/__init__.py b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py new file mode 100644 index 0000000000..508da4bf20 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py @@ -0,0 +1,159 @@ +from json import dumps +from typing import Optional + +import httpx +from requests import post +from yarl import URL + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + 
class OAICompatRerankModel(RerankModel):
    """
    Rerank model client for OpenAI-API-compatible rerank endpoints.

    The API is compatible with the Jina rerank API, so the JinaRerankModel
    logic is reused here, enhanced for llama.cpp which returns raw scores
    rather than scores normalized to 0~1 (Dify expects the latter).
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n documents to return
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        server_url = credentials["endpoint_url"]
        model_name = model

        if not server_url:
            raise CredentialsValidateFailedError("server_url is required")
        if not model_name:
            raise CredentialsValidateFailedError("model_name is required")

        url = server_url
        headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}

        # TODO: Do we need truncate docs to avoid llama.cpp return error?

        data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n}

        try:
            # Use httpx so that raise_for_status() raises httpx.HTTPStatusError,
            # which is what the except clause below and _invoke_error_mapping
            # expect. The previous requests-based call raised
            # requests.exceptions.HTTPError, which was never caught here.
            response = httpx.post(str(URL(url) / "rerank"), headers=headers, content=dumps(data), timeout=60)
            response.raise_for_status()
            results = response.json()

            returned_results = results.get("results") or []
            if not returned_results:
                # Defensive: an empty result list would make min()/max() below raise.
                return RerankResult(model=model, docs=[])

            rerank_documents = []
            scores = [result["relevance_score"] for result in returned_results]

            # Min-Max Normalization: Normalize scores to 0 ~ 1.0 range
            min_score = min(scores)
            max_score = max(scores)
            score_range = max_score - min_score if max_score != min_score else 1.0  # Avoid division by zero

            for result in returned_results:
                index = result["index"]

                # Retrieve document text (fallback if llama.cpp rerank doesn't return it)
                text = result.get("document", {}).get("text", docs[index])

                # Normalize the score
                normalized_score = (result["relevance_score"] - min_score) / score_range

                # Create RerankDocument object with normalized score
                rerank_document = RerankDocument(
                    index=index,
                    text=text,
                    score=normalized_score,
                )

                # Apply threshold (if defined)
                if score_threshold is None or normalized_score >= score_threshold:
                    rerank_documents.append(rerank_document)

            # Sort rerank_documents by normalized score in descending order
            rerank_documents.sort(key=lambda doc: doc.score, reverse=True)

            return RerankResult(model=model, docs=rerank_documents)

        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.RERANK,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={},
        )

        return entity
stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
import json
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, model_validator
from upstash_vector import Index, Vector

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset


class UpstashVectorConfig(BaseModel):
    """Connection settings for an Upstash Vector index."""

    url: str
    token: str

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """Reject empty/missing url or token before model construction."""
        if not values["url"]:
            raise ValueError("Upstash URL is required")
        if not values["token"]:
            raise ValueError("Upstash Token is required")
        return values


class UpstashVector(BaseVector):
    """Vector store backed by an Upstash Vector index."""

    def __init__(self, collection_name: str, config: UpstashVectorConfig):
        super().__init__(collection_name)
        self._table_name = collection_name
        self.index = Index(url=config.url, token=config.token)

    def _get_index_dimension(self) -> int:
        """Return the index's embedding dimension, defaulting to 1536 if unknown."""
        index_info = self.index.info()
        if index_info and index_info.dimension:
            return index_info.dimension
        else:
            return 1536

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection by inserting the initial documents."""
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Upsert documents with their embeddings; each gets a fresh UUID."""
        vectors = [
            Vector(
                id=str(uuid4()),
                vector=embedding,
                metadata=doc.metadata,
                data=doc.page_content,
            )
            for doc, embedding in zip(documents, embeddings)
        ]
        self.index.upsert(vectors=vectors)

    def text_exists(self, id: str) -> bool:
        """Return True if any stored vector carries this doc_id in metadata."""
        response = self.get_ids_by_metadata_field("doc_id", id)
        return len(response) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete all vectors whose metadata doc_id matches any of `ids`.

        Note: the original implementation reassigned the loop iterable `ids`
        inside the loop and tested the builtin `id` (always truthy); both
        bugs are fixed here by using a distinct local name.
        """
        item_ids: list[str] = []
        for doc_id in ids:
            found_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
            if found_ids:
                item_ids += found_ids
        self._delete_by_ids(ids=item_ids)

    def _delete_by_ids(self, ids: list[str]) -> None:
        # Upstash rejects empty delete requests, so guard against no-op calls.
        if ids:
            self.index.delete(ids=ids)

    def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
        """Return vector ids whose metadata has key == value.

        Upstash requires a query vector even for pure metadata filtering, so
        an arbitrary probe vector of the index's dimension is used.
        """
        query_result = self.index.query(
            vector=[1.001 * i for i in range(self._get_index_dimension())],
            include_metadata=True,
            top_k=1000,
            filter=f"{key} = '{value}'",
        )
        return [result.id for result in query_result]

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete all vectors whose metadata has key == value."""
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._delete_by_ids(ids)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Similarity search; honors `top_k` and `score_threshold` kwargs."""
        top_k = kwargs.get("top_k", 4)
        result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
        docs = []
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        for record in result:
            # metadata may be absent on a record; normalize to a dict so the
            # score assignment below cannot raise.
            metadata = record.metadata or {}
            text = record.data
            score = record.score
            metadata["score"] = score
            if score > score_threshold:
                docs.append(Document(page_content=text, metadata=metadata))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # Upstash Vector has no full-text search; return no matches.
        return []

    def delete(self) -> None:
        """Drop every vector in the index."""
        self.index.reset()

    def get_type(self) -> str:
        return VectorType.UPSTASH


class UpstashVectorFactory(AbstractVectorFactory):
    """Factory that builds an UpstashVector for a dataset."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name))

        return UpstashVector(
            collection_name=collection_name,
            config=UpstashVectorConfig(
                url=dify_config.UPSTASH_VECTOR_URL,
                token=dify_config.UPSTASH_VECTOR_TOKEN,
            ),
        )
class UnstructuredPDFExtractor(BaseExtractor):
    """Load pdf files.


    Args:
        file_path: Path to the file to load.

        api_url: Unstructured API URL

        api_key: Unstructured API Key
    """

    def __init__(self, file_path: str, api_url: str, api_key: str):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url
        self._api_key = api_key

    def extract(self) -> list[Document]:
        """Partition the PDF, chunk by title, and wrap each chunk as a Document."""
        elements = self._partition()

        from unstructured.chunking.title import chunk_by_title

        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
        return [Document(page_content=chunk.text.strip()) for chunk in chunks]

    def _partition(self):
        """Partition via the hosted Unstructured API when configured, else locally."""
        if self._api_url:
            from unstructured.partition.api import partition_via_api

            return partition_via_api(
                filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto"
            )

        from unstructured.partition.pdf import partition_pdf

        return partition_pdf(filename=self._file_path, strategy="auto")
class AliYuqueTool:
    """Thin HTTP client for the Yuque (yuque.com) open API."""

    # yuque service url
    server_url = "https://www.yuque.com"

    @staticmethod
    def auth(token):
        """Validate a team token by fetching the current user profile.

        :param token: Yuque team token
        :return: parsed JSON response of /api/v2/user
        :raises requests.HTTPError: if the token is rejected
        """
        session = requests.Session()
        session.headers.update({"Accept": "application/json", "X-Auth-Token": token})
        login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user")
        login.raise_for_status()
        resp = login.json()
        return resp

    def request(self, method: str, token, tool_parameters: dict[str, Any], path: str) -> str:
        """Send an authenticated request to the Yuque API.

        Parameters whose names appear as ``{name}`` placeholders in ``path``
        are substituted into the path and removed from the payload; the
        remaining parameters are sent as a JSON body (POST/PUT) or as a
        query string (other methods).

        :param method: HTTP method (case-insensitive)
        :param token: Yuque team token
        :param tool_parameters: request parameters, possibly including path variables
        :param path: API path template, e.g. "/api/v2/repos/{book_id}/docs"
        :return: raw response body text
        """
        if not token:
            raise Exception("token is required")
        # Normalize once so both branches below use the same casing
        # (the original uppercased only in the POST/PUT branch).
        method = method.upper()
        session = requests.Session()
        session.headers.update({"accept": "application/json", "X-Auth-Token": token})
        new_params = {**tool_parameters}
        # Find the parameters that are path template variables.
        replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path}

        # Substitute them into the path and drop them from the payload.
        for key, value in replacements.items():
            path = path.replace(f"{{{key}}}", str(value))
            del new_params[key]
        # Issue the request.
        if method in {"POST", "PUT"}:
            session.headers.update(
                {
                    "Content-Type": "application/json",
                }
            )
            response = session.request(method, self.server_url + path, json=new_params)
        else:
            response = session.request(method, self.server_url + path, params=new_params)
        response.raise_for_status()
        return response.text
100644 index 0000000000..b9d1c60327 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml @@ -0,0 +1,99 @@ +identity: + name: aliyuque_create_document + author: 佐井 + label: + en_US: Create Document + zh_Hans: 创建文档 + icon: icon.svg +description: + human: + en_US: Creates a new document within a knowledge base without automatic addition to the table of contents. Requires a subsequent call to the "knowledge base directory update API". Supports setting visibility, format, and content. # 接口英文描述 + zh_Hans: 在知识库中创建新文档,但不会自动加入目录,需额外调用“知识库目录更新接口”。允许设置公开性、格式及正文内容。 + llm: Creates docs in a KB. + +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库ID + human_description: + en_US: The unique identifier of the knowledge base where the document will be created. + zh_Hans: 文档将被创建的知识库的唯一标识。 + llm_description: ID of the target knowledge base. + + - name: title + type: string + required: false + form: llm + label: + en_US: Title + zh_Hans: 标题 + human_description: + en_US: The title of the document, defaults to 'Untitled' if not provided. + zh_Hans: 文档标题,默认为'无标题'如未提供。 + llm_description: Title of the document, defaults to 'Untitled'. + + - name: public + type: select + required: false + form: llm + options: + - value: 0 + label: + en_US: Private + zh_Hans: 私密 + - value: 1 + label: + en_US: Public + zh_Hans: 公开 + - value: 2 + label: + en_US: Enterprise-only + zh_Hans: 企业内公开 + label: + en_US: Visibility + zh_Hans: 公开性 + human_description: + en_US: Document visibility (0 Private, 1 Public, 2 Enterprise-only). + zh_Hans: 文档可见性(0 私密, 1 公开, 2 企业内公开)。 + llm_description: Doc visibility options, 0-private, 1-public, 2-enterprise. 
class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Delete the document addressed by {book_id} and {id}."""
        token = self.runtime.credentials.get("token", None)
        if not token:
            raise Exception("token is required")
        endpoint = "/api/v2/repos/{book_id}/docs/{id}"
        result = self.request("DELETE", token, tool_parameters, endpoint)
        return self.create_text_message(result)
class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Fetch the index (home) page of a knowledge base by group login and book slug."""
        token = self.runtime.credentials.get("token", None)
        if not token:
            raise Exception("token is required")
        endpoint = "/api/v2/repos/{group_login}/{book_slug}/index_page"
        result = self.request("GET", token, tool_parameters, endpoint)
        return self.create_text_message(result)
a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml new file mode 100644 index 0000000000..5e490725d1 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml @@ -0,0 +1,38 @@ +identity: + name: aliyuque_describe_book_index_page + author: 佐井 + label: + en_US: Get Repo Index Page + zh_Hans: 获取知识库首页 + icon: icon.svg + +description: + human: + en_US: Retrieves the homepage of a knowledge base within a group, supporting both book ID and group login with book slug access. + zh_Hans: 获取团队中知识库的首页信息,可通过书籍ID或团队登录名与书籍路径访问。 + llm: Fetches the knowledge base homepage using group and book identifiers with support for alternate access paths. + +parameters: + - name: group_login + type: string + required: true + form: llm + label: + en_US: Group Login + zh_Hans: 团队登录名 + human_description: + en_US: The login name of the group that owns the knowledge base. + zh_Hans: 拥有该知识库的团队登录名。 + llm_description: Team login identifier for the knowledge base owner. + + - name: book_slug + type: string + required: true + form: llm + label: + en_US: Book Slug + zh_Hans: 知识库路径 + human_description: + en_US: The unique slug representing the path of the knowledge base. + zh_Hans: 知识库的唯一路径标识。 + llm_description: Unique path identifier for the knowledge base. 
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Return the table of contents of a knowledge base ({book_id})."""
        token = self.runtime.credentials.get("token", None)
        if not token:
            raise Exception("token is required")
        return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Fetch a Yuque document's content from its web URL.

        Resolves the URL into group/book/doc identifiers, looks up the
        numeric book_id via the repo index page, then fetches the document.
        """
        new_params = {**tool_parameters}
        # An explicit per-call token overrides the credential-bound one; the
        # literal string "none" is treated as absent. pop() with a default so
        # a missing optional parameter cannot raise KeyError (the parameter
        # is declared required: false).
        token = new_params.pop("token", None)
        if not token or str(token).lower() == "none":
            token = self.runtime.credentials.get("token", None)
        if not token:
            raise Exception("token is required")
        # Keep working on the same copy so the popped token is NOT
        # re-introduced and leaked to the API as a query parameter.
        url = new_params.pop("url", None)
        if not url or not url.startswith("http"):
            raise Exception("url is not valid")

        parsed_url = urlparse(url)
        path_parts = parsed_url.path.strip("/").split("/")
        if len(path_parts) < 3:
            raise Exception("url is not correct")
        # Yuque doc URLs end with /<group_login>/<book_slug>/<doc_id>.
        doc_id = path_parts[-1]
        book_slug = path_parts[-2]
        group_id = path_parts[-3]

        # 1. Resolve the numeric book_id from the repo's index page.
        new_params["group_login"] = group_id
        new_params["book_slug"] = book_slug
        index_page = json.loads(
            self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page")
        )
        book_id = index_page.get("data", {}).get("book", {}).get("id")
        if not book_id:
            raise Exception(f"can not parse book_id from {index_page}")

        # 2. Fetch the document itself.
        new_params["book_id"] = book_id
        new_params["id"] = doc_id
        data = json.loads(self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}"))
        body_only = str(tool_parameters.get("body_only") or "")
        if body_only.lower() == "true":
            return self.create_text_message(data.get("data").get("body"))
        else:
            raw = data.get("data")
            # Strip the bulky alternative renderings; pop() instead of del
            # so an absent key does not raise KeyError.
            raw.pop("body_lake", None)
            raw.pop("body_html", None)
            return self.create_text_message(json.dumps(data))
class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Get the detail of a document addressed by {book_id} and {id}."""
        token = self.runtime.credentials.get("token", None)
        if not token:
            raise Exception("token is required")
        endpoint = "/api/v2/repos/{book_id}/docs/{id}"
        result = self.request("GET", token, tool_parameters, endpoint)
        return self.create_text_message(result)
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
    # NOTE(review): the class name was copied from the "describe TOC" tool
    # even though this tool UPDATES the TOC; kept unchanged so external
    # references keep working — consider renaming to
    # YuqueUpdateBookTableOfContentsTool.
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Update a knowledge base's table of contents (PUT /repos/{book_id}/toc)."""
        token = self.runtime.credentials.get("token", None)
        if not token:
            raise Exception("token is required")

        doc_ids = tool_parameters.get("doc_ids")
        if doc_ids:
            # The API expects a list of ints; tolerate surrounding whitespace
            # and empty segments (e.g. a trailing comma) that would otherwise
            # make int("") raise ValueError.
            tool_parameters["doc_ids"] = [int(part) for part in (s.strip() for s in doc_ids.split(",")) if part]

        return self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
+ zh_Hans: 操作,创建场景下不支持同级头插 prependNode,删除节点不会删除关联文档,删除节点时action_mode=sibling (删除当前节点), action_mode=child (删除当前节点及子节点) + llm_description: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children). + + + - name: action_mode + type: select + required: false + form: llm + options: + - value: sibling + label: + en_US: sibling + zh_Hans: 同级 + pt_BR: sibling + - value: child + label: + en_US: child + zh_Hans: 子集 + pt_BR: child + label: + en_US: Action Type + zh_Hans: 操作 + human_description: + en_US: Operation mode (sibling:same level, child:child level). + zh_Hans: 操作模式 (sibling:同级, child:子级)。 + llm_description: Operation mode (sibling:same level, child:child level). + + - name: target_uuid + type: string + required: false + form: llm + label: + en_US: Target node UUID + zh_Hans: 目标节点 UUID + human_description: + en_US: Target node UUID, defaults to root node if left empty. + zh_Hans: 目标节点 UUID, 不填默认为根节点。 + llm_description: Target node UUID, defaults to root node if left empty. + + - name: node_uuid + type: string + required: false + form: llm + label: + en_US: Node UUID + zh_Hans: 操作节点 UUID + human_description: + en_US: Operation node UUID [required for move/update/delete]. + zh_Hans: 操作节点 UUID [移动/更新/删除必填]。 + llm_description: Operation node UUID [required for move/update/delete]. + + - name: doc_ids + type: string + required: false + form: llm + label: + en_US: Document IDs + zh_Hans: 文档id列表 + human_description: + en_US: Document IDs [required for creating documents], separate multiple IDs with ','. + zh_Hans: 文档 IDs [创建文档必填],多个用','分隔。 + llm_description: Document IDs [required for creating documents], separate multiple IDs with ','. 
+ + + - name: type + type: select + required: false + form: llm + default: DOC + options: + - value: DOC + label: + en_US: DOC + zh_Hans: 文档 + pt_BR: DOC + - value: LINK + label: + en_US: LINK + zh_Hans: 链接 + pt_BR: LINK + - value: TITLE + label: + en_US: TITLE + zh_Hans: 分组 + pt_BR: TITLE + label: + en_US: Node type + zh_Hans: 操节点类型 + human_description: + en_US: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). + zh_Hans: 操节点类型 [创建必填] (DOC:文档, LINK:外链, TITLE:分组)。 + llm_description: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). + + - name: title + type: string + required: false + form: llm + label: + en_US: Node Name + zh_Hans: 节点名称 + human_description: + en_US: Node name [required for creating groups/external links]. + zh_Hans: 节点名称 [创建分组/外链必填]。 + llm_description: Node name [required for creating groups/external links]. + + - name: url + type: string + required: false + form: llm + label: + en_US: Node URL + zh_Hans: 节点URL + human_description: + en_US: Node URL [required for creating external links]. + zh_Hans: 节点 URL [创建外链必填]。 + llm_description: Node URL [required for creating external links]. + + + - name: open_window + type: select + required: false + form: llm + default: 0 + options: + - value: 0 + label: + en_US: DOC + zh_Hans: Current Page + pt_BR: DOC + - value: 1 + label: + en_US: LINK + zh_Hans: New Page + pt_BR: LINK + label: + en_US: Open in new window + zh_Hans: 是否新窗口打开 + human_description: + en_US: Open in new window [optional for external links] (0:open in current page, 1:open in new window). + zh_Hans: 是否新窗口打开 [外链选填] (0:当前页打开, 1:新窗口打开)。 + llm_description: Open in new window [optional for external links] (0:open in current page, 1:open in new window). 
class AliYuqueUpdateDocumentTool(AliYuqueTool, BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Update an existing document addressed by {book_id} and {id}."""
        token = self.runtime.credentials.get("token", None)
        if not token:
            raise Exception("token is required")
        endpoint = "/api/v2/repos/{book_id}/docs/{id}"
        result = self.request("PUT", token, tool_parameters, endpoint)
        return self.create_text_message(result)
the document ID or path. + zh_Hans: 通过提供文档ID或路径,更新指定知识库中的现有文档。 + llm: Update doc in a knowledge base via ID/path. +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库 ID + human_description: + en_US: The unique identifier of the knowledge base where the document resides. + zh_Hans: 文档所属知识库的ID。 + llm_description: ID of the knowledge base holding the doc. + - name: id + type: string + required: true + form: llm + label: + en_US: Document ID or Path + zh_Hans: 文档 ID 或 路径 + human_description: + en_US: The unique identifier or the path of the document to be updated. + zh_Hans: 要更新的文档的唯一ID或路径。 + llm_description: Doc's ID or path for update. + + - name: title + type: string + required: false + form: llm + label: + en_US: Title + zh_Hans: 标题 + human_description: + en_US: The title of the document, defaults to 'Untitled' if not provided. + zh_Hans: 文档标题,默认为'无标题'如未提供。 + llm_description: Title of the document, defaults to 'Untitled'. + + - name: format + type: select + required: false + form: llm + options: + - value: markdown + label: + en_US: markdown + zh_Hans: markdown + pt_BR: markdown + - value: html + label: + en_US: html + zh_Hans: html + pt_BR: html + - value: lake + label: + en_US: lake + zh_Hans: lake + pt_BR: lake + label: + en_US: Content Format + zh_Hans: 内容格式 + human_description: + en_US: Format of the document content (markdown, HTML, Lake). + zh_Hans: 文档内容格式(markdown, HTML, Lake)。 + llm_description: Content format choices, markdown, HTML, Lake. + + - name: body + type: string + required: true + form: llm + label: + en_US: Body Content + zh_Hans: 正文内容 + human_description: + en_US: The actual content of the document. + zh_Hans: 文档的实际内容。 + llm_description: Content of the document. 
class AddBaseRecordTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Append a single record to a Feishu Bitable data table."""
        url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records"

        # Validate required parameters up front; each missing value is
        # reported as a text message rather than an exception.
        access_token = tool_parameters.get("Authorization", "")
        if not access_token:
            return self.create_text_message("Invalid parameter access_token")
        app_token = tool_parameters.get("app_token", "")
        if not app_token:
            return self.create_text_message("Invalid parameter app_token")
        table_id = tool_parameters.get("table_id", "")
        if not table_id:
            return self.create_text_message("Invalid parameter table_id")
        fields = tool_parameters.get("fields", "")
        if not fields:
            return self.create_text_message("Invalid parameter fields")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {access_token}",
        }
        try:
            res = httpx.post(
                url.format(app_token=app_token, table_id=table_id),
                headers=headers,
                params={},
                json={"fields": json.loads(fields)},
                timeout=30,
            )
            res_json = res.json()
            if res.is_success:
                return self.create_text_message(text=json.dumps(res_json))
            return self.create_text_message(
                f"Failed to add base record, status code: {res.status_code}, response: {res.text}"
            )
        except Exception as e:
            return self.create_text_message("Failed to add base record. {}".format(e))
-parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: fields - type: string - required: true - label: - en_US: fields - zh_Hans: 数据表的列字段内容 - human_description: - en_US: The fields of the Base data table are the columns of the data table. 
class AddRecordsTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
        """Add multiple records to a Feishu Bitable table via FeishuRequest."""
        credentials = self.runtime.credentials
        client = FeishuRequest(credentials.get("app_id"), credentials.get("app_secret"))

        res = client.add_records(
            tool_parameters.get("app_token"),
            tool_parameters.get("table_id"),
            tool_parameters.get("table_name"),
            tool_parameters.get("records"),
            tool_parameters.get("user_id_type", "open_id"),
        )
        return self.create_json_message(res)
a/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml new file mode 100644 index 0000000000..f2a93490dc --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml @@ -0,0 +1,91 @@ +identity: + name: add_records + author: Doug Lea + label: + en_US: Add Records + zh_Hans: 新增多条记录 +description: + human: + en_US: Add Multiple Records to Multidimensional Table + zh_Hans: 在多维表格数据表中新增多条记录 + llm: A tool for adding multiple records to a multidimensional table. (在多维表格数据表中新增多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: records + type: string + required: true + label: + en_US: records + zh_Hans: 记录列表 + human_description: + en_US: | + List of records to be added in this request. 
Example value: [{"multi-line-text":"text content","single_select":"option 1","date":1674206443000}] + For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). + zh_Hans: | + 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + llm_description: | + 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py deleted file mode 100644 index b05d700113..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py +++ /dev/null @@ -1,48 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class CreateBaseTableTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - - access_token = tool_parameters.get("Authorization", "") - if not access_token: - return self.create_text_message("Invalid parameter access_token") - - app_token = tool_parameters.get("app_token", "") - if not app_token: - return self.create_text_message("Invalid parameter app_token") - - name = tool_parameters.get("name", "") - - fields = tool_parameters.get("fields", "") - if not fields: - return self.create_text_message("Invalid parameter fields") - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - } - - params = {} - payload = {"table": {"name": name, "fields": json.loads(fields)}} - - try: - res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to create base table, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to create base table. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml deleted file mode 100644 index 48c46bec14..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml +++ /dev/null @@ -1,106 +0,0 @@ -identity: - name: create_base_table - author: Doug Lea - label: - en_US: Create Base Table - zh_Hans: 多维表格新增一个数据表 -description: - human: - en_US: Create base table - zh_Hans: | - 多维表格新增一个数据表,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table/create - llm: A tool for add a new data table to the multidimensional table. -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: name - type: string - required: false - label: - en_US: name - zh_Hans: name - human_description: - en_US: Multidimensional table data table name - zh_Hans: 多维表格数据表名称 - llm_description: Multidimensional table data table name - form: llm - - - name: fields - type: string - required: true - label: - en_US: fields - zh_Hans: fields - human_description: - en_US: Initial fields of the data table - zh_Hans: | - 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。 - field_name:字段名; - type: 字段类型;可选值有 - 1:多行文本 - 2:数字 - 3:单选 - 4:多选 - 5:日期 - 7:复选框 - 11:人员 - 13:电话号码 - 15:超链接 - 17:附件 - 18:单向关联 - 
20:公式 - 21:双向关联 - 22:地理位置 - 23:群组 - 1001:创建时间 - 1002:最后更新时间 - 1003:创建人 - 1004:修改人 - 1005:自动编号 - llm_description: | - 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。 - field_name:字段名; - type: 字段类型;可选值有 - 1:多行文本 - 2:数字 - 3:单选 - 4:多选 - 5:日期 - 7:复选框 - 11:人员 - 13:电话号码 - 15:超链接 - 17:附件 - 18:单向关联 - 20:公式 - 21:双向关联 - 22:地理位置 - 23:群组 - 1001:创建时间 - 1002:最后更新时间 - 1003:创建人 - 1004:修改人 - 1005:自动编号 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_table.py new file mode 100644 index 0000000000..81f2617545 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_table.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateTableTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_name = tool_parameters.get("table_name") + default_view_name = tool_parameters.get("default_view_name") + fields = tool_parameters.get("fields") + + res = client.create_table(app_token, table_name, default_view_name, fields) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml new file mode 100644 index 0000000000..8b1007b9a5 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml @@ -0,0 +1,61 @@ +identity: + name: create_table + author: Doug Lea + label: + en_US: 
Create Table + zh_Hans: 新增数据表 +description: + human: + en_US: Add a Data Table to Multidimensional Table + zh_Hans: 在多维表格中新增一个数据表 + llm: A tool for adding a data table to a multidimensional table. (在多维表格中新增一个数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_name + type: string + required: true + label: + en_US: Table Name + zh_Hans: 数据表名称 + human_description: + en_US: | + The name of the data table, length range: 1 character to 100 characters. + zh_Hans: 数据表名称,长度范围:1 字符 ~ 100 字符。 + llm_description: 数据表名称,长度范围:1 字符 ~ 100 字符。 + form: llm + + - name: default_view_name + type: string + required: false + label: + en_US: Default View Name + zh_Hans: 默认表格视图的名称 + human_description: + en_US: The name of the default table view, defaults to "Table" if not filled. + zh_Hans: 默认表格视图的名称,不填则默认为"表格"。 + llm_description: 默认表格视图的名称,不填则默认为"表格"。 + form: llm + + - name: fields + type: string + required: true + label: + en_US: Initial Fields + zh_Hans: 初始字段 + human_description: + en_US: | + Initial fields of the data table, format: [ { "field_name": "Multi-line Text","type": 1 },{ "field_name": "Number","type": 2 },{ "field_name": "Single Select","type": 3 },{ "field_name": "Multiple Select","type": 4 },{ "field_name": "Date","type": 5 } ]. 
For field details, refer to: https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + zh_Hans: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + llm_description: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py deleted file mode 100644 index 862eb2171b..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py +++ /dev/null @@ -1,56 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class DeleteBaseRecordsTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete" - - access_token = tool_parameters.get("Authorization", "") - if not access_token: - return self.create_text_message("Invalid parameter access_token") - - app_token = tool_parameters.get("app_token", "") - if not app_token: - return self.create_text_message("Invalid parameter app_token") - - table_id = tool_parameters.get("table_id", "") - if not table_id: - return self.create_text_message("Invalid parameter table_id") - - record_ids = tool_parameters.get("record_ids", "") - if not record_ids: - return 
self.create_text_message("Invalid parameter record_ids") - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - } - - params = {} - payload = {"records": json.loads(record_ids)} - - try: - res = httpx.post( - url.format(app_token=app_token, table_id=table_id), - headers=headers, - params=params, - json=payload, - timeout=30, - ) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to delete base records, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to delete base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml deleted file mode 100644 index 595b287029..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml +++ /dev/null @@ -1,60 +0,0 @@ -identity: - name: delete_base_records - author: Doug Lea - label: - en_US: Delete Base Records - zh_Hans: 在多维表格数据表中删除多条记录 -description: - human: - en_US: Delete base records - zh_Hans: | - 该接口用于删除多维表格数据表中的多条记录,单次调用中最多删除 500 条记录。 - llm: A tool for delete multiple records in a multidimensional table data table, up to 500 records can be deleted in a single call. 
-parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: record_ids - type: string - required: true - label: - en_US: record_ids - zh_Hans: record_ids - human_description: - en_US: A list of multiple record IDs to be deleted, for example ["recwNXzPQv","recpCsf4ME"] - zh_Hans: 待删除的多条记录id列表,示例为 ["recwNXzPQv","recpCsf4ME"] - llm_description: A list of multiple record IDs to be deleted, for example ["recwNXzPQv","recpCsf4ME"] - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py deleted file mode 100644 index f512186303..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py +++ /dev/null @@ -1,46 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class DeleteBaseTablesTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete" - - 
access_token = tool_parameters.get("Authorization", "") - if not access_token: - return self.create_text_message("Invalid parameter access_token") - - app_token = tool_parameters.get("app_token", "") - if not app_token: - return self.create_text_message("Invalid parameter app_token") - - table_ids = tool_parameters.get("table_ids", "") - if not table_ids: - return self.create_text_message("Invalid parameter table_ids") - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - } - - params = {} - payload = {"table_ids": json.loads(table_ids)} - - try: - res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to delete base tables. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml deleted file mode 100644 index 5d72814363..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml +++ /dev/null @@ -1,48 +0,0 @@ -identity: - name: delete_base_tables - author: Doug Lea - label: - en_US: Delete Base Tables - zh_Hans: 删除多维表格中的数据表 -description: - human: - en_US: Delete base tables - zh_Hans: | - 删除多维表格中的数据表 - llm: A tool for deleting a data table in a multidimensional table -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_ids - type: string - required: true - label: - en_US: table_ids - zh_Hans: table_ids - human_description: - en_US: The ID list of the data tables to be deleted. Currently, a maximum of 50 data tables can be deleted at a time. The example is ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] - zh_Hans: 待删除数据表的id列表,当前一次操作最多支持50个数据表,示例为 ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] - llm_description: The ID list of the data tables to be deleted. Currently, a maximum of 50 data tables can be deleted at a time. 
The example is ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py new file mode 100644 index 0000000000..c896a2c81b --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class DeleteRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + record_ids = tool_parameters.get("record_ids") + + res = client.delete_records(app_token, table_id, table_name, record_ids) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml new file mode 100644 index 0000000000..c30ebd630c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml @@ -0,0 +1,86 @@ +identity: + name: delete_records + author: Doug Lea + label: + en_US: Delete Records + zh_Hans: 删除多条记录 +description: + human: + en_US: Delete Multiple Records from Multidimensional Table + zh_Hans: 删除多维表格数据表中的多条记录 + llm: A tool for deleting multiple records from a multidimensional table. 
(删除多维表格数据表中的多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: record_ids + type: string + required: true + label: + en_US: Record IDs + zh_Hans: 记录 ID 列表 + human_description: + en_US: | + List of IDs for the records to be deleted, example value: ["recwNXzPQv"]. + zh_Hans: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 + llm_description: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py new file mode 100644 index 0000000000..f732a16da6 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class DeleteTablesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_ids = tool_parameters.get("table_ids") + table_names = tool_parameters.get("table_names") + + res = client.delete_tables(app_token, table_ids, table_names) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml new file mode 100644 index 0000000000..498126eae5 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml @@ -0,0 +1,49 @@ +identity: + name: delete_tables + author: Doug Lea + label: + en_US: Delete Tables + zh_Hans: 删除数据表 +description: + human: + en_US: Batch Delete Data Tables from Multidimensional Table + zh_Hans: 批量删除多维表格中的数据表 + llm: A tool for batch deleting data tables from a multidimensional table. 
(批量删除多维表格中的数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_ids + type: string + required: false + label: + en_US: Table IDs + zh_Hans: 数据表 ID + human_description: + en_US: | + IDs of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["tbl1TkhyTWDkSoZ3"]. Ensure that either table_ids or table_names is not empty. + zh_Hans: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 + llm_description: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 + form: llm + + - name: table_names + type: string + required: false + label: + en_US: Table Names + zh_Hans: 数据表名称 + human_description: + en_US: | + Names of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["Table1", "Table2"]. Ensure that either table_names or table_ids is not empty. 
+ zh_Hans: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 + llm_description: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py deleted file mode 100644 index 2ea61d0068..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py +++ /dev/null @@ -1,48 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class GetTenantAccessTokenTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" - - app_id = tool_parameters.get("app_id", "") - if not app_id: - return self.create_text_message("Invalid parameter app_id") - - app_secret = tool_parameters.get("app_secret", "") - if not app_secret: - return self.create_text_message("Invalid parameter app_secret") - - headers = { - "Content-Type": "application/json", - } - params = {} - payload = {"app_id": app_id, "app_secret": app_secret} - - """ - { - "code": 0, - "msg": "ok", - "tenant_access_token": "t-caecc734c2e3328a62489fe0648c4b98779515d3", - "expire": 7200 - } - """ - try: - res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to get tenant access token. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml deleted file mode 100644 index 88acc27e06..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml +++ /dev/null @@ -1,39 +0,0 @@ -identity: - name: get_tenant_access_token - author: Doug Lea - label: - en_US: Get Tenant Access Token - zh_Hans: 获取飞书自建应用的 tenant_access_token -description: - human: - en_US: Get tenant access token - zh_Hans: | - 获取飞书自建应用的 tenant_access_token,响应体示例: - {"code":0,"msg":"ok","tenant_access_token":"t-caecc734c2e3328a62489fe0648c4b98779515d3","expire":7200} - tenant_access_token: 租户访问凭证; - expire: tenant_access_token 的过期时间,单位为秒; - llm: A tool for obtaining a tenant access token. The input parameters must include app_id and app_secret. -parameters: - - name: app_id - type: string - required: true - label: - en_US: app_id - zh_Hans: 应用唯一标识 - human_description: - en_US: app_id is the unique identifier of the Lark Open Platform application - zh_Hans: app_id 是飞书开放平台应用的唯一标识 - llm_description: app_id is the unique identifier of the Lark Open Platform application - form: llm - - - name: app_secret - type: secret-input - required: true - label: - en_US: app_secret - zh_Hans: 应用秘钥 - human_description: - en_US: app_secret is the secret key of the application - zh_Hans: app_secret 是应用的秘钥 - llm_description: app_secret is the secret key of the application - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py deleted file mode 100644 index e579d02f69..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from 
core.tools.tool.builtin_tool import BuiltinTool - - -class ListBaseRecordsTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search" - - access_token = tool_parameters.get("Authorization", "") - if not access_token: - return self.create_text_message("Invalid parameter access_token") - - app_token = tool_parameters.get("app_token", "") - if not app_token: - return self.create_text_message("Invalid parameter app_token") - - table_id = tool_parameters.get("table_id", "") - if not table_id: - return self.create_text_message("Invalid parameter table_id") - - page_token = tool_parameters.get("page_token", "") - page_size = tool_parameters.get("page_size", "") - sort_condition = tool_parameters.get("sort_condition", "") - filter_condition = tool_parameters.get("filter_condition", "") - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - } - - params = { - "page_token": page_token, - "page_size": page_size, - } - - payload = {"automatic_fields": True} - if sort_condition: - payload["sort"] = json.loads(sort_condition) - if filter_condition: - payload["filter"] = json.loads(filter_condition) - - try: - res = httpx.post( - url.format(app_token=app_token, table_id=table_id), - headers=headers, - params=params, - json=payload, - timeout=30, - ) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to list base records, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to list base records. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml deleted file mode 100644 index 8647c880a6..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml +++ /dev/null @@ -1,108 +0,0 @@ -identity: - name: list_base_records - author: Doug Lea - label: - en_US: List Base Records - zh_Hans: 查询多维表格数据表中的现有记录 -description: - human: - en_US: List base records - zh_Hans: | - 查询多维表格数据表中的现有记录,单次最多查询 500 行记录,支持分页获取。 - llm: Query existing records in a multidimensional table data table. A maximum of 500 rows of records can be queried at a time, and paging retrieval is supported. -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: Pagination mark. If it is not filled in the first request, it means to traverse from the beginning. 
- zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历。 - llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 - form: llm - - - name: page_size - type: number - required: false - default: 20 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: paging size - zh_Hans: 分页大小,默认值为 20,最大值为 100。 - llm_description: The default value of paging size is 20 and the maximum value is 100. - form: llm - - - name: sort_condition - type: string - required: false - label: - en_US: sort_condition - zh_Hans: 排序条件 - human_description: - en_US: sort condition - zh_Hans: | - 排序条件,格式为:[{"field_name":"多行文本","desc":true}]。 - field_name: 字段名称; - desc: 是否倒序排序; - llm_description: | - Sorting conditions, the format is: [{"field_name":"multi-line text","desc":true}]. - form: llm - - - name: filter_condition - type: string - required: false - label: - en_US: filter_condition - zh_Hans: 筛选条件 - human_description: - en_US: filter condition - zh_Hans: | - 筛选条件,格式为:{"conjunction":"and","conditions":[{"field_name":"字段1","operator":"is","value":["文本内容"]}]}。 - conjunction:条件逻辑连接词; - conditions:筛选条件集合; - field_name:筛选条件的左值,值为字段的名称; - operator:条件运算符; - value:目标值; - llm_description: | - The format of the filter condition is: {"conjunction":"and","conditions":[{"field_name":"Field 1","operator":"is","value":["text content"]}]}. 
- form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py deleted file mode 100644 index 4ec9a476bc..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class ListBaseTablesTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - - access_token = tool_parameters.get("Authorization", "") - if not access_token: - return self.create_text_message("Invalid parameter access_token") - - app_token = tool_parameters.get("app_token", "") - if not app_token: - return self.create_text_message("Invalid parameter app_token") - - page_token = tool_parameters.get("page_token", "") - page_size = tool_parameters.get("page_size", "") - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - } - - params = { - "page_token": page_token, - "page_size": page_size, - } - - try: - res = httpx.get(url.format(app_token=app_token), headers=headers, params=params, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to list base tables, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to list base tables. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml deleted file mode 100644 index 9887124a28..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml +++ /dev/null @@ -1,65 +0,0 @@ -identity: - name: list_base_tables - author: Doug Lea - label: - en_US: List Base Tables - zh_Hans: 根据 app_token 获取多维表格下的所有数据表 -description: - human: - en_US: List base tables - zh_Hans: | - 根据 app_token 获取多维表格下的所有数据表 - llm: A tool for getting all data tables under a multidimensional table based on app_token. -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: Pagination mark. If it is not filled in the first request, it means to traverse from the beginning. - zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历。 - llm_description: | - Pagination token. If it is not filled in the first request, it means to start traversal from the beginning. - If there are more items in the pagination query result, a new page_token will be returned at the same time. - The page_token can be used to obtain the query result in the next traversal. 
- 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 - form: llm - - - name: page_size - type: number - required: false - default: 20 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: paging size - zh_Hans: 分页大小,默认值为 20,最大值为 100。 - llm_description: The default value of paging size is 20 and the maximum value is 100. - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py new file mode 100644 index 0000000000..c7768a496d --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ListTablesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + page_token = tool_parameters.get("page_token") + page_size = tool_parameters.get("page_size", 20) + + res = client.list_tables(app_token, page_token, page_size) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml new file mode 100644 index 0000000000..5a3891bd45 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml @@ -0,0 +1,50 @@ +identity: + name: list_tables + author: Doug Lea + label: + en_US: List Tables + zh_Hans: 列出数据表 +description: + human: + en_US: Get All Data Tables under Multidimensional Table + zh_Hans: 获取多维表格下的所有数据表 + llm: A tool for getting all data tables under a multidimensional table. 
(获取多维表格下的所有数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: | + Page size, default value: 20, maximum value: 100. + zh_Hans: 分页大小,默认值:20,最大值:100。 + llm_description: 分页大小,默认值:20,最大值:100。 + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: | + Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py deleted file mode 100644 index fb818f8380..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py +++ /dev/null @@ -1,49 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class ReadBaseRecordTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = 
"https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - - access_token = tool_parameters.get("Authorization", "") - if not access_token: - return self.create_text_message("Invalid parameter access_token") - - app_token = tool_parameters.get("app_token", "") - if not app_token: - return self.create_text_message("Invalid parameter app_token") - - table_id = tool_parameters.get("table_id", "") - if not table_id: - return self.create_text_message("Invalid parameter table_id") - - record_id = tool_parameters.get("record_id", "") - if not record_id: - return self.create_text_message("Invalid parameter record_id") - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - } - - try: - res = httpx.get( - url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30 - ) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to read base record, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to read base record. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml deleted file mode 100644 index 400e9a1021..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml +++ /dev/null @@ -1,60 +0,0 @@ -identity: - name: read_base_record - author: Doug Lea - label: - en_US: Read Base Record - zh_Hans: 根据 record_id 的值检索多维表格数据表的记录 -description: - human: - en_US: Read base record - zh_Hans: | - 根据 record_id 的值检索多维表格数据表的记录 - llm: Retrieve records from a multidimensional table based on the value of record_id -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: record_id - type: string - required: true - label: - en_US: record_id - zh_Hans: 单条记录的 id - human_description: - en_US: The id of a single record - zh_Hans: 单条记录的 id - llm_description: The id of a single record - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_records.py b/api/core/tools/provider/builtin/feishu_base/tools/read_records.py new file mode 100644 index 0000000000..46f3df4ff0 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_records.py @@ -0,0 
+1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ReadRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + record_ids = tool_parameters.get("record_ids") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_records(app_token, table_id, table_name, record_ids, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml new file mode 100644 index 0000000000..911e667cfc --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml @@ -0,0 +1,86 @@ +identity: + name: read_records + author: Doug Lea + label: + en_US: Read Records + zh_Hans: 批量获取记录 +description: + human: + en_US: Batch Retrieve Records from Multidimensional Table + zh_Hans: 批量获取多维表格数据表中的记录信息 + llm: A tool for batch retrieving records from a multidimensional table, supporting up to 100 records per call. (批量获取多维表格数据表中的记录信息,单次调用最多支持查询 100 条记录) + +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. 
+ zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: record_ids + type: string + required: true + label: + en_US: record_ids + zh_Hans: 记录 ID 列表 + human_description: + en_US: List of record IDs, which can be obtained by calling the "Query Records API". + zh_Hans: 记录 ID 列表,可以通过调用"查询记录接口"获取。 + llm_description: 记录 ID 列表,可以通过调用"查询记录接口"获取。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py new file mode 100644 index 0000000000..c959496735 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py @@ -0,0 +1,39 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SearchRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + view_id = tool_parameters.get("view_id") + field_names = tool_parameters.get("field_names") + sort = tool_parameters.get("sort") + filters = tool_parameters.get("filter") + page_token = tool_parameters.get("page_token") + automatic_fields = tool_parameters.get("automatic_fields", False) + user_id_type = tool_parameters.get("user_id_type", "open_id") + page_size = tool_parameters.get("page_size", 20) + + res = client.search_record( + app_token, + table_id, + table_name, + view_id, + field_names, + sort, + filters, + page_token, + automatic_fields, + user_id_type, + page_size, + ) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml new file mode 100644 index 0000000000..6cac4b0524 --- /dev/null +++ 
b/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml @@ -0,0 +1,163 @@ +identity: + name: search_records + author: Doug Lea + label: + en_US: Search Records + zh_Hans: 查询记录 +description: + human: + en_US: Query records in a multidimensional table, up to 500 rows per query. + zh_Hans: 查询多维表格数据表中的记录,单次最多查询 500 行记录。 + llm: A tool for querying records in a multidimensional table, up to 500 rows per query. (查询多维表格数据表中的记录,单次最多查询 500 行记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: view_id + type: string + required: false + label: + en_US: view_id + zh_Hans: 视图唯一标识 + human_description: + en_US: | + Unique identifier for a view in a multidimensional table. It can be found in the URL's query parameter with the key 'view'. For example: https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx. 
+ zh_Hans: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx。 + llm_description: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx。 + form: llm + + - name: field_names + type: string + required: false + label: + en_US: field_names + zh_Hans: 字段名称 + human_description: + en_US: | + Field names to specify which fields to include in the returned records. Example value: ["Field1", "Field2"]. + zh_Hans: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 + llm_description: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 + form: llm + + - name: sort + type: string + required: false + label: + en_US: sort + zh_Hans: 排序条件 + human_description: + en_US: | + Sorting conditions, for example: [{"field_name":"Multiline Text","desc":true}]. + zh_Hans: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 + llm_description: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 + form: llm + + - name: filter + type: string + required: false + label: + en_US: filter + zh_Hans: 筛选条件 + human_description: + en_US: Object containing filter information. For details on how to fill in the filter, refer to the record filter parameter guide (https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide). 
+ zh_Hans: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 + llm_description: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 + form: llm + + - name: automatic_fields + type: boolean + required: false + label: + en_US: automatic_fields + zh_Hans: automatic_fields + human_description: + en_US: Whether to return automatically calculated fields. Default is false, meaning they are not returned. + zh_Hans: 是否返回自动计算的字段。默认为 false,表示不返回。 + llm_description: 是否返回自动计算的字段。默认为 false,表示不返回。 + form: form + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: | + Page size, default value: 20, maximum value: 500. + zh_Hans: 分页大小,默认值:20,最大值:500。 + llm_description: 分页大小,默认值:20,最大值:500。 + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: | + Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". 
+ zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py deleted file mode 100644 index 6d7e33f3ff..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class UpdateBaseRecordTool(BuiltinTool): - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - - access_token = tool_parameters.get("Authorization", "") - if not access_token: - return self.create_text_message("Invalid parameter access_token") - - app_token = tool_parameters.get("app_token", "") - if not app_token: - return self.create_text_message("Invalid parameter app_token") - - table_id = tool_parameters.get("table_id", "") - if not table_id: - return self.create_text_message("Invalid parameter table_id") - - record_id = tool_parameters.get("record_id", "") - if not record_id: - return self.create_text_message("Invalid parameter record_id") - - fields = tool_parameters.get("fields", "") - if not fields: - return self.create_text_message("Invalid parameter fields") - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - } - - params = {} - payload = {"fields": json.loads(fields)} - - try: - res = httpx.put( - url.format(app_token=app_token, table_id=table_id, record_id=record_id), - 
headers=headers, - params=params, - json=payload, - timeout=30, - ) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to update base record, status code: {res.status_code}, response: {res.text}" - ) - except Exception as e: - return self.create_text_message("Failed to update base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml deleted file mode 100644 index 788798c4b3..0000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml +++ /dev/null @@ -1,78 +0,0 @@ -identity: - name: update_base_record - author: Doug Lea - label: - en_US: Update Base Record - zh_Hans: 更新多维表格数据表中的一条记录 -description: - human: - en_US: Update base record - zh_Hans: | - 更新多维表格数据表中的一条记录,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-record/update - llm: Update a record in a multidimensional table data table -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: record_id - type: string - required: true - label: - en_US: 
record_id - zh_Hans: 单条记录的 id - human_description: - en_US: The id of a single record - zh_Hans: 单条记录的 id - llm_description: The id of a single record - form: llm - - - name: fields - type: string - required: true - label: - en_US: fields - zh_Hans: 数据表的列字段内容 - human_description: - en_US: The fields of a multidimensional table data table, that is, the columns of the data table. - zh_Hans: | - 要更新一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} - 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 - 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure - llm_description: | - 要更新一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} - 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 - 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py new file mode 100644 index 0000000000..a7b0363875 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class UpdateRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + 
app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + records = tool_parameters.get("records") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.update_records(app_token, table_id, table_name, records, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml new file mode 100644 index 0000000000..68117e7136 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml @@ -0,0 +1,91 @@ +identity: + name: update_records + author: Doug Lea + label: + en_US: Update Records + zh_Hans: 更新多条记录 +description: + human: + en_US: Update Multiple Records in Multidimensional Table + zh_Hans: 更新多维表格数据表中的多条记录 + llm: A tool for updating multiple records in a multidimensional table. (更新多维表格数据表中的多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. 
+ zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: records + type: string + required: true + label: + en_US: records + zh_Hans: 记录列表 + human_description: + en_US: | + List of records to be updated in this request. Example value: [{"fields":{"multi-line-text":"text content","single_select":"option 1","date":1674206443000},"record_id":"recupK4f4RM5RX"}]. + For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). + zh_Hans: | + 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + llm_description: | + 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg b/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg new file mode 100644 index 0000000000..01743c9cd3 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py new file mode 100644 index 0000000000..0b9c025834 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py @@ -0,0 +1,33 @@ +from typing import Any + +import openai + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class PodcastGeneratorProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + tts_service = credentials.get("tts_service") + api_key = credentials.get("api_key") + + if not tts_service: + raise ToolProviderCredentialValidationError("TTS service is not specified") + + if not api_key: + raise ToolProviderCredentialValidationError("API key is missing") + + if tts_service == "openai": + self._validate_openai_credentials(api_key) + else: + raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}") + + def _validate_openai_credentials(self, api_key: str) -> None: + client = openai.OpenAI(api_key=api_key) + try: + # We're using a simple API call to validate the credentials + client.models.list() + except openai.AuthenticationError: + raise ToolProviderCredentialValidationError("Invalid OpenAI API key") + except Exception as e: + raise ToolProviderCredentialValidationError(f"Error 
validating OpenAI API key: {str(e)}") diff --git a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml new file mode 100644 index 0000000000..bd02b32020 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml @@ -0,0 +1,34 @@ +identity: + author: Dify + name: podcast_generator + label: + en_US: Podcast Generator + zh_Hans: 播客生成器 + description: + en_US: Generate podcast audio using Text-to-Speech services + zh_Hans: 使用文字转语音服务生成播客音频 + icon: icon.svg +credentials_for_provider: + tts_service: + type: select + required: true + label: + en_US: TTS Service + zh_Hans: TTS 服务 + placeholder: + en_US: Select a TTS service + zh_Hans: 选择一个 TTS 服务 + options: + - label: + en_US: OpenAI TTS + zh_Hans: OpenAI TTS + value: openai + api_key: + type: secret-input + required: true + label: + en_US: API Key + zh_Hans: API 密钥 + placeholder: + en_US: Enter your TTS service API key + zh_Hans: 输入您的 TTS 服务 API 密钥 diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py new file mode 100644 index 0000000000..8c8dd9bf68 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py @@ -0,0 +1,100 @@ +import concurrent.futures +import io +import random +from typing import Any, Literal, Optional, Union + +import openai +from pydub import AudioSegment + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class PodcastAudioGeneratorTool(BuiltinTool): + @staticmethod + def _generate_silence(duration: float): + # Generate silent WAV data using pydub + silence = AudioSegment.silent(duration=int(duration * 1000)) # pydub uses milliseconds + 
return silence + + @staticmethod + def _generate_audio_segment( + client: openai.OpenAI, + line: str, + voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + index: int, + ) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]: + try: + response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav") + audio = AudioSegment.from_wav(io.BytesIO(response.content)) + silence_duration = random.uniform(0.1, 1.5) + silence = PodcastAudioGeneratorTool._generate_silence(silence_duration) + return index, audio, silence + except Exception as e: + return index, f"Error generating audio: {str(e)}", None + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + # Extract parameters + script = tool_parameters.get("script", "") + host1_voice = tool_parameters.get("host1_voice") + host2_voice = tool_parameters.get("host2_voice") + + # Split the script into lines + script_lines = [line for line in script.split("\n") if line.strip()] + + # Ensure voices are provided + if not host1_voice or not host2_voice: + raise ToolParameterValidationError("Host voices are required") + + # Get OpenAI API key from credentials + if not self.runtime or not self.runtime.credentials: + raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") + api_key = self.runtime.credentials.get("api_key") + if not api_key: + raise ToolProviderCredentialValidationError("OpenAI API key is missing") + + # Initialize OpenAI client + client = openai.OpenAI(api_key=api_key) + + # Create a thread pool + max_workers = 5 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for i, line in enumerate(script_lines): + voice = host1_voice if i % 2 == 0 else host2_voice + future = executor.submit(self._generate_audio_segment, client, line, voice, i) + futures.append(future) + + # Collect results + 
audio_segments: list[Any] = [None] * len(script_lines) + for future in concurrent.futures.as_completed(futures): + index, audio, silence = future.result() + if isinstance(audio, str): # Error occurred + return self.create_text_message(audio) + audio_segments[index] = (audio, silence) + + # Combine audio segments in the correct order + combined_audio = AudioSegment.empty() + for i, (audio, silence) in enumerate(audio_segments): + if audio: + combined_audio += audio + if i < len(audio_segments) - 1 and silence: + combined_audio += silence + + # Export the combined audio to a WAV file in memory + buffer = io.BytesIO() + combined_audio.export(buffer, format="wav") + wav_bytes = buffer.getvalue() + + # Create a blob message with the combined audio + return [ + self.create_text_message("Audio generated successfully"), + self.create_blob_message( + blob=wav_bytes, + meta={"mime_type": "audio/x-wav"}, + save_as=self.VariableKey.AUDIO, + ), + ] diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml new file mode 100644 index 0000000000..d6ae98f595 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml @@ -0,0 +1,95 @@ +identity: + name: podcast_audio_generator + author: Dify + label: + en_US: Podcast Audio Generator + zh_Hans: 播客音频生成器 +description: + human: + en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service. + zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。 + llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts. +parameters: + - name: script + type: string + required: true + label: + en_US: Podcast Script + zh_Hans: 播客脚本 + human_description: + en_US: A string containing alternating lines for two hosts, separated by newline characters. 
+ zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。 + llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters. + form: llm + - name: host1_voice + type: select + required: true + label: + en_US: Host 1 Voice + zh_Hans: 主持人1 音色 + human_description: + en_US: The voice for the first host. + zh_Hans: 第一位主持人的音色。 + llm_description: The voice identifier for the first host's voice. + options: + - label: + en_US: Alloy + zh_Hans: Alloy + value: alloy + - label: + en_US: Echo + zh_Hans: Echo + value: echo + - label: + en_US: Fable + zh_Hans: Fable + value: fable + - label: + en_US: Onyx + zh_Hans: Onyx + value: onyx + - label: + en_US: Nova + zh_Hans: Nova + value: nova + - label: + en_US: Shimmer + zh_Hans: Shimmer + value: shimmer + form: form + - name: host2_voice + type: select + required: true + label: + en_US: Host 2 Voice + zh_Hans: 主持人2 音色 + human_description: + en_US: The voice for the second host. + zh_Hans: 第二位主持人的音色。 + llm_description: The voice identifier for the second host's voice. 
+ options: + - label: + en_US: Alloy + zh_Hans: Alloy + value: alloy + - label: + en_US: Echo + zh_Hans: Echo + value: echo + - label: + en_US: Fable + zh_Hans: Fable + value: fable + - label: + en_US: Onyx + zh_Hans: Onyx + value: onyx + - label: + en_US: Nova + zh_Hans: Nova + value: nova + - label: + en_US: Shimmer + zh_Hans: Shimmer + value: shimmer + form: form diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py deleted file mode 100644 index 6f7610651c..0000000000 --- a/api/core/tools/utils/tool_parameter_converter.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolParameter - - -class ToolParameterConverter: - @staticmethod - def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: - match parameter_type: - case ( - ToolParameter.ToolParameterType.STRING - | ToolParameter.ToolParameterType.SECRET_INPUT - | ToolParameter.ToolParameterType.SELECT - ): - return "string" - - case ToolParameter.ToolParameterType.BOOLEAN: - return "boolean" - - case ToolParameter.ToolParameterType.NUMBER: - return "number" - - case _: - raise ValueError(f"Unsupported parameter type {parameter_type}") - - @staticmethod - def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: - # convert tool parameter config to correct type - try: - match parameter_type: - case ( - ToolParameter.ToolParameterType.STRING - | ToolParameter.ToolParameterType.SECRET_INPUT - | ToolParameter.ToolParameterType.SELECT - ): - if value is None: - return "" - else: - return value if isinstance(value, str) else str(value) - - case ToolParameter.ToolParameterType.BOOLEAN: - if value is None: - return False - elif isinstance(value, str): - # Allowed YAML boolean value strings: https://yaml.org/type/bool.html - # and also '0' for False and '1' for True - match value.lower(): - case "true" | "yes" | "y" | "1": - return True - case "false" | "no" | "n" | 
"0": - return False - case _: - return bool(value) - else: - return value if isinstance(value, bool) else bool(value) - - case ToolParameter.ToolParameterType.NUMBER: - if isinstance(value, int) | isinstance(value, float): - return value - elif isinstance(value, str) and value != "": - if "." in value: - return float(value) - else: - return int(value) - case ToolParameter.ToolParameterType.FILE: - return value - case _: - return str(value) - - except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") diff --git a/api/core/app/segments/__init__.py b/api/core/variables/__init__.py similarity index 78% rename from api/core/app/segments/__init__.py rename to api/core/variables/__init__.py index 652ef243b4..87f9e3ed45 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/variables/__init__.py @@ -1,7 +1,12 @@ from .segment_group import SegmentGroup from .segments import ( ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, ArraySegment, + ArrayStringSegment, + FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -15,6 +20,7 @@ from .variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, + FileVariable, FloatVariable, IntegerVariable, NoneVariable, @@ -46,4 +52,10 @@ __all__ = [ "ArrayNumberVariable", "ArrayObjectVariable", "ArraySegment", + "ArrayFileSegment", + "ArrayNumberSegment", + "ArrayObjectSegment", + "ArrayStringSegment", + "FileSegment", + "FileVariable", ] diff --git a/api/core/app/segments/exc.py b/api/core/variables/exc.py similarity index 100% rename from api/core/app/segments/exc.py rename to api/core/variables/exc.py diff --git a/api/core/app/segments/segment_group.py b/api/core/variables/segment_group.py similarity index 100% rename from api/core/app/segments/segment_group.py rename to api/core/variables/segment_group.py diff --git a/api/core/app/segments/segments.py b/api/core/variables/segments.py similarity index 77% 
rename from api/core/app/segments/segments.py rename to api/core/variables/segments.py index b26b3c8291..b71882b043 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/variables/segments.py @@ -5,6 +5,8 @@ from typing import Any from pydantic import BaseModel, ConfigDict, field_validator +from core.file import File + from .types import SegmentType @@ -39,6 +41,9 @@ class Segment(BaseModel): @property def size(self) -> int: + """ + Return the size of the value in bytes. + """ return sys.getsizeof(self.value) def to_object(self) -> Any: @@ -51,15 +56,15 @@ class NoneSegment(Segment): @property def text(self) -> str: - return "null" + return "" @property def log(self) -> str: - return "null" + return "" @property def markdown(self) -> str: - return "null" + return "" class StringSegment(Segment): @@ -99,13 +104,27 @@ class ArraySegment(Segment): def markdown(self) -> str: items = [] for item in self.value: - if hasattr(item, "to_markdown"): - items.append(item.to_markdown()) - else: - items.append(str(item)) + items.append(str(item)) return "\n".join(items) +class FileSegment(Segment): + value_type: SegmentType = SegmentType.FILE + value: File + + @property + def markdown(self) -> str: + return self.value.markdown + + @property + def log(self) -> str: + return str(self.value) + + @property + def text(self) -> str: + return str(self.value) + + class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY value: Sequence[Any] @@ -124,3 +143,15 @@ class ArrayNumberSegment(ArraySegment): class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] + + +class ArrayFileSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_FILE + value: Sequence[File] + + @property + def markdown(self) -> str: + items = [] + for item in self.value: + items.append(item.markdown) + return "\n".join(items) diff --git a/api/core/app/segments/types.py b/api/core/variables/types.py 
similarity index 86% rename from api/core/app/segments/types.py rename to api/core/variables/types.py index 9cf0856df5..53c2e8a3aa 100644 --- a/api/core/app/segments/types.py +++ b/api/core/variables/types.py @@ -11,5 +11,7 @@ class SegmentType(str, Enum): ARRAY_NUMBER = "array[number]" ARRAY_OBJECT = "array[object]" OBJECT = "object" + FILE = "file" + ARRAY_FILE = "array[file]" GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/variables/variables.py similarity index 95% rename from api/core/app/segments/variables.py rename to api/core/variables/variables.py index f0e403ab8d..ddc6914192 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/variables/variables.py @@ -7,6 +7,7 @@ from .segments import ( ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -73,3 +74,7 @@ class SecretVariable(StringVariable): class NoneVariable(NoneSegment, Variable): value_type: SegmentType = SegmentType.NONE value: None = None + + +class FileVariable(FileSegment, Variable): + pass diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py similarity index 99% rename from api/core/app/apps/workflow_logging_callback.py rename to api/core/workflow/callbacks/workflow_logging_callback.py index 60683b0f21..17913de7b0 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -1,7 +1,6 @@ from typing import Optional from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, @@ -20,6 +19,8 @@ from core.workflow.graph_engine.entities.event import ( ParallelBranchRunSucceededEvent, ) +from .base_workflow_callback import WorkflowCallback + _TEXT_COLOR_MAPPING = { "blue": "36;1", "yellow": "33;1", 
diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py new file mode 100644 index 0000000000..e3fe17c284 --- /dev/null +++ b/api/core/workflow/constants.py @@ -0,0 +1,3 @@ +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py new file mode 100644 index 0000000000..61f727740c --- /dev/null +++ b/api/core/workflow/nodes/base/__init__.py @@ -0,0 +1,4 @@ +from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData +from .node import BaseNode + +__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"] diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/nodes/base/entities.py similarity index 100% rename from api/core/workflow/entities/base_node_data_entities.py rename to api/core/workflow/nodes/base/entities.py diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base/node.py similarity index 60% rename from api/core/workflow/nodes/base_node.py rename to api/core/workflow/nodes/base/node.py index 7bfe45a13c..053a339ba7 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,17 +1,27 @@ -from abc import ABC, abstractmethod +import logging +from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams -from 
core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import NodeEvent, RunCompletedEvent +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import BaseNodeData + +if TYPE_CHECKING: + from core.workflow.graph_engine.entities.event import InNodeEvent + from core.workflow.graph_engine.entities.graph import Graph + from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + +logger = logging.getLogger(__name__) + +GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) -class BaseNode(ABC): +class BaseNode(Generic[GenericNodeData]): _node_data_cls: type[BaseNodeData] _node_type: NodeType @@ -19,9 +29,9 @@ class BaseNode(ABC): self, id: str, config: Mapping[str, Any], - graph_init_params: GraphInitParams, - graph: Graph, - graph_runtime_state: GraphRuntimeState, + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", previous_node_id: Optional[str] = None, thread_pool_id: Optional[str] = None, ) -> None: @@ -45,22 +55,25 @@ class BaseNode(ABC): raise ValueError("Node ID is required.") self.node_id = node_id - self.node_data = self._node_data_cls(**config.get("data", {})) + self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {}))) @abstractmethod - def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: """ Run node :return: """ raise NotImplementedError - def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: - """ - Run node entry - :return: - """ - result = self._run() + def 
run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: + try: + result = self._run() + except Exception as e: + logger.error(f"Node {self.node_id} failed to run: {e}") + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) if isinstance(result, NodeRunResult): yield RunCompletedEvent(run_result=result) @@ -69,7 +82,10 @@ class BaseNode(ABC): @classmethod def extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], config: dict + cls, + *, + graph_config: Mapping[str, Any], + config: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -83,12 +99,16 @@ class BaseNode(ABC): node_data = cls._node_data_cls(**config.get("data", {})) return cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=node_data + graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: GenericNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/core/workflow/nodes/document_extractor/__init__.py new file mode 100644 index 0000000000..3cc5fae187 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/__init__.py @@ -0,0 +1,4 @@ +from .entities import DocumentExtractorNodeData +from .node import DocumentExtractorNode + +__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"] diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/core/workflow/nodes/document_extractor/entities.py new file mode 100644 index 0000000000..7e9ffaa889 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/entities.py @@ -0,0 +1,7 @@ 
+from collections.abc import Sequence + +from core.workflow.nodes.base import BaseNodeData + + +class DocumentExtractorNodeData(BaseNodeData): + variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/core/workflow/nodes/document_extractor/exc.py new file mode 100644 index 0000000000..c9d4bb8ef6 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/exc.py @@ -0,0 +1,14 @@ +class DocumentExtractorError(Exception): + """Base exception for errors related to the DocumentExtractorNode.""" + + +class FileDownloadError(DocumentExtractorError): + """Exception raised when there's an error downloading a file.""" + + +class UnsupportedFileTypeError(DocumentExtractorError): + """Exception raised when trying to extract text from an unsupported file type.""" + + +class TextExtractionError(DocumentExtractorError): + """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py new file mode 100644 index 0000000000..b4ffee1f13 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -0,0 +1,274 @@ +import csv +import io + +import docx +import pandas as pd +import pypdfium2 +from unstructured.partition.email import partition_email +from unstructured.partition.epub import partition_epub +from unstructured.partition.msg import partition_msg +from unstructured.partition.ppt import partition_ppt +from unstructured.partition.pptx import partition_pptx + +from core.file import File, FileTransferMethod, file_manager +from core.helper import ssrf_proxy +from core.variables import ArrayFileSegment +from core.variables.segments import FileSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import 
DocumentExtractorNodeData +from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError + + +class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): + """ + Extracts text content from various file types. + Supports plain text, PDF, DOC/DOCX, CSV, Excel, PPT/PPTX, EPUB, EML, and MSG files. + """ + + _node_data_cls = DocumentExtractorNodeData + _node_type = NodeType.DOCUMENT_EXTRACTOR + + def _run(self): + variable_selector = self.node_data.variable_selector + variable = self.graph_runtime_state.variable_pool.get(variable_selector) + + if variable is None: + error_message = f"File variable not found for selector: {variable_selector}" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): + error_message = f"Variable {variable_selector} is not an ArrayFileSegment" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + + value = variable.value + inputs = {"variable_selector": variable_selector} + process_data = {"documents": value if isinstance(value, list) else [value]} + + try: + if isinstance(value, list): + extracted_text_list = list(map(_extract_text_from_file, value)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text_list}, + ) + elif isinstance(value, File): + extracted_text = _extract_text_from_file(value) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text}, + ) + else: + raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") + except DocumentExtractorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + process_data=process_data, + ) + + +def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: + 
"""Extract text from a file based on its MIME type.""" + if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}: + return _extract_text_from_plain_text(file_content) + elif mime_type == "application/pdf": + return _extract_text_from_pdf(file_content) + elif mime_type in { + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/msword", + }: + return _extract_text_from_doc(file_content) + elif mime_type == "text/csv": + return _extract_text_from_csv(file_content) + elif mime_type in { + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.ms-excel", + }: + return _extract_text_from_excel(file_content) + elif mime_type == "application/vnd.ms-powerpoint": + return _extract_text_from_ppt(file_content) + elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return _extract_text_from_pptx(file_content) + elif mime_type == "application/epub+zip": + return _extract_text_from_epub(file_content) + elif mime_type == "message/rfc822": + return _extract_text_from_eml(file_content) + elif mime_type == "application/vnd.ms-outlook": + return _extract_text_from_msg(file_content) + else: + raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") + + +def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: + """Extract text from a file based on its file extension.""" + match file_extension: + case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + return _extract_text_from_plain_text(file_content) + case ".pdf": + return _extract_text_from_pdf(file_content) + case ".doc" | ".docx": + return _extract_text_from_doc(file_content) + case ".csv": + return _extract_text_from_csv(file_content) + case ".xls" | ".xlsx": + return _extract_text_from_excel(file_content) + case ".ppt": + return _extract_text_from_ppt(file_content) + case ".pptx": + return 
_extract_text_from_pptx(file_content) + case ".epub": + return _extract_text_from_epub(file_content) + case ".eml": + return _extract_text_from_eml(file_content) + case ".msg": + return _extract_text_from_msg(file_content) + case _: + raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") + + +def _extract_text_from_plain_text(file_content: bytes) -> str: + try: + return file_content.decode("utf-8") + except UnicodeDecodeError as e: + raise TextExtractionError("Failed to decode plain text file") from e + + +def _extract_text_from_pdf(file_content: bytes) -> str: + try: + pdf_file = io.BytesIO(file_content) + pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) + text = "" + for page in pdf_document: + text_page = page.get_textpage() + text += text_page.get_text_range() + text_page.close() + page.close() + return text + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e + + +def _extract_text_from_doc(file_content: bytes) -> str: + try: + doc_file = io.BytesIO(file_content) + doc = docx.Document(doc_file) + return "\n".join([paragraph.text for paragraph in doc.paragraphs]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e + + +def _download_file_content(file: File) -> bytes: + """Download the content of a file based on its transfer method.""" + try: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + if file.remote_url is None: + raise FileDownloadError("Missing URL for remote file") + response = ssrf_proxy.get(file.remote_url) + response.raise_for_status() + return response.content + elif file.transfer_method == FileTransferMethod.LOCAL_FILE: + return file_manager.download(file) + else: + raise ValueError(f"Unsupported transfer method: {file.transfer_method}") + except Exception as e: + raise FileDownloadError(f"Error downloading file: {str(e)}") from e + + +def _extract_text_from_file(file: File): + if 
file.mime_type is None: + raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing") + file_content = _download_file_content(file) + if file.transfer_method == FileTransferMethod.REMOTE_URL: + extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type) + else: + extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension) + return extracted_text + + +def _extract_text_from_csv(file_content: bytes) -> str: + try: + csv_file = io.StringIO(file_content.decode("utf-8")) + csv_reader = csv.reader(csv_file) + rows = list(csv_reader) + + if not rows: + return "" + + # Create Markdown table + markdown_table = "| " + " | ".join(rows[0]) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n" + for row in rows[1:]: + markdown_table += "| " + " | ".join(row) + " |\n" + + return markdown_table.strip() + except Exception as e: + raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e + + +def _extract_text_from_excel(file_content: bytes) -> str: + """Extract text from an Excel file using pandas.""" + + try: + df = pd.read_excel(io.BytesIO(file_content)) + + # Drop rows where all elements are NaN + df.dropna(how="all", inplace=True) + + # Convert DataFrame to Markdown table + markdown_table = df.to_markdown(index=False) + return markdown_table + except Exception as e: + raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e + + +def _extract_text_from_ppt(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_ppt(file=file) + return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e + + +def _extract_text_from_pptx(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_pptx(file=file) + 
return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e + + +def _extract_text_from_epub(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_epub(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e + + +def _extract_text_from_eml(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_email(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e + + +def _extract_text_from_msg(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_msg(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py new file mode 100644 index 0000000000..208144655b --- /dev/null +++ b/api/core/workflow/nodes/enums.py @@ -0,0 +1,24 @@ +from enum import Enum + + +class NodeType(str, Enum): + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" + VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # Fake start node for iteration. 
+ PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" + DOCUMENT_EXTRACTOR = "document-extractor" + LIST_OPERATOR = "list-operator" diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py new file mode 100644 index 0000000000..581def9553 --- /dev/null +++ b/api/core/workflow/nodes/event/__init__.py @@ -0,0 +1,10 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from .types import NodeEvent + +__all__ = [ + "RunCompletedEvent", + "RunRetrieverResourceEvent", + "RunStreamChunkEvent", + "NodeEvent", + "ModelInvokeCompletedEvent", +] diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event/event.py similarity index 72% rename from api/core/workflow/nodes/event.py rename to api/core/workflow/nodes/event/event.py index 276c13a6d4..b7034561bf 100644 --- a/api/core/workflow/nodes/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, Field +from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.node_entities import NodeRunResult @@ -17,4 +18,11 @@ class RunRetrieverResourceEvent(BaseModel): context: str = Field(..., description="context") -RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent +class ModelInvokeCompletedEvent(BaseModel): + """ + Model invoke completed + """ + + text: str + usage: LLMUsage + finish_reason: str | None = None diff --git a/api/core/workflow/nodes/event/types.py b/api/core/workflow/nodes/event/types.py new file mode 100644 index 0000000000..b19a91022d --- /dev/null +++ b/api/core/workflow/nodes/event/types.py @@ -0,0 +1,3 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent + +NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent diff --git 
a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py new file mode 100644 index 0000000000..0270d7e0fd --- /dev/null +++ b/api/core/workflow/nodes/http_request/executor.py @@ -0,0 +1,321 @@ +import json +from collections.abc import Mapping +from copy import deepcopy +from random import randint +from typing import Any, Literal +from urllib.parse import urlencode, urlparse + +import httpx + +from configs import dify_config +from core.file import file_manager +from core.helper import ssrf_proxy +from core.workflow.entities.variable_pool import VariablePool + +from .entities import ( + HttpRequestNodeAuthorization, + HttpRequestNodeData, + HttpRequestNodeTimeout, + Response, +) + +BODY_TYPE_TO_CONTENT_TYPE = { + "json": "application/json", + "x-www-form-urlencoded": "application/x-www-form-urlencoded", + "form-data": "multipart/form-data", + "raw-text": "text/plain", +} + + +class Executor: + method: Literal["get", "head", "post", "put", "delete", "patch"] + url: str + params: Mapping[str, str] | None + content: str | bytes | None + data: Mapping[str, Any] | None + files: Mapping[str, bytes] | None + json: Any + headers: dict[str, str] + auth: HttpRequestNodeAuthorization + timeout: HttpRequestNodeTimeout + + boundary: str + + def __init__( + self, + *, + node_data: HttpRequestNodeData, + timeout: HttpRequestNodeTimeout, + variable_pool: VariablePool, + ): + # If authorization API key is present, convert the API key using the variable pool + if node_data.authorization.type == "api-key": + if node_data.authorization.config is None: + raise ValueError("authorization config is required") + node_data.authorization.config.api_key = variable_pool.convert_template( + node_data.authorization.config.api_key + ).text + + self.url: str = node_data.url + self.method = node_data.method + self.auth = node_data.authorization + self.timeout = timeout + self.params = {} + self.headers = {} + self.content = None + self.files = None + 
self.data = None + self.json = None + + # init template + self.variable_pool = variable_pool + self.node_data = node_data + self._initialize() + + def _initialize(self): + self._init_url() + self._init_params() + self._init_headers() + self._init_body() + + def _init_url(self): + self.url = self.variable_pool.convert_template(self.node_data.url).text + + def _init_params(self): + params = self.variable_pool.convert_template(self.node_data.params).text + self.params = _plain_text_to_dict(params) + + def _init_headers(self): + headers = self.variable_pool.convert_template(self.node_data.headers).text + self.headers = _plain_text_to_dict(headers) + + body = self.node_data.body + if body is None: + return + if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: + self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + if body.type == "form-data": + self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" + + def _init_body(self): + body = self.node_data.body + if body is not None: + data = body.data + match body.type: + case "none": + self.content = "" + case "raw-text": + self.content = self.variable_pool.convert_template(data[0].value).text + case "json": + json_string = self.variable_pool.convert_template(data[0].value).text + json_object = json.loads(json_string) + self.json = json_object + # self.json = self._parse_object_contains_variables(json_object) + case "binary": + file_selector = data[0].file + file_variable = self.variable_pool.get_file(file_selector) + if file_variable is None: + raise ValueError(f"cannot fetch file with selector {file_selector}") + file = file_variable.value + self.content = file_manager.download(file) + case "x-www-form-urlencoded": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in data + 
} + self.data = form_data + case "form-data": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in filter(lambda item: item.type == "text", data) + } + file_selectors = { + self.variable_pool.convert_template(item.key).text: item.file + for item in filter(lambda item: item.type == "file", data) + } + files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} + files = {k: v for k, v in files.items() if v is not None} + files = {k: variable.value for k, variable in files.items()} + files = {k: file_manager.download(v) for k, v in files.items() if v.related_id is not None} + + self.data = form_data + self.files = files + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.auth) + headers = deepcopy(self.headers) or {} + if self.auth.type == "api-key": + if self.auth.config is None: + raise ValueError("self.authorization config is required") + if authorization.config is None: + raise ValueError("authorization config is required") + + if self.auth.config.api_key is None: + raise ValueError("api_key is required") + + if not authorization.config.header: + authorization.config.header = "Authorization" + + if self.auth.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.auth.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.auth.config.type == "custom": + headers[authorization.config.header] = authorization.config.api_key or "" + + return headers + + def _validate_and_parse_response(self, response: httpx.Response) -> Response: + executor_response = Response(response) + + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) + if executor_response.size > threshold_size: + raise ValueError( + 
f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' + ) + + return executor_response + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + if self.method not in {"get", "head", "post", "put", "delete", "patch"}: + raise ValueError(f"Invalid http method {self.method}") + + request_args = { + "url": self.url, + "data": self.data, + "files": self.files, + "json": self.json, + "content": self.content, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, + } + + response = getattr(ssrf_proxy, self.method)(**request_args) + return response + + def invoke(self) -> Response: + # assemble headers + headers = self._assembling_headers() + # do http request + response = self._do_http_request(headers) + # validate response + return self._validate_and_parse_response(response) + + def to_log(self): + url_parts = urlparse(self.url) + path = url_parts.path or "/" + + # Add query parameters + if self.params: + query_string = urlencode(self.params) + path += f"?{query_string}" + elif url_parts.query: + path += f"?{url_parts.query}" + + raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" + raw += f"Host: {url_parts.netloc}\r\n" + + headers = self._assembling_headers() + for k, v in headers.items(): + if self.auth.type == "api-key": + authorization_header = "Authorization" + if self.auth.config and self.auth.config.header: + authorization_header = self.auth.config.header + if k.lower() == authorization_header.lower(): + raw += f'{k}: {"*" * len(v)}\r\n' + continue + raw += f"{k}: {v}\r\n" + + body = "" + if self.files: + boundary = self.boundary + for k, v in self.files.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body 
+= f"{v[1]}\r\n" + body += f"--{boundary}--\r\n" + elif self.node_data.body: + if self.content: + if isinstance(self.content, str): + body = self.content + elif isinstance(self.content, bytes): + body = self.content.decode("utf-8", errors="replace") + elif self.data and self.node_data.body.type == "x-www-form-urlencoded": + body = urlencode(self.data) + elif self.data and self.node_data.body.type == "form-data": + boundary = self.boundary + for key, value in self.data.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body += f"{value}\r\n" + body += f"--{boundary}--\r\n" + elif self.json: + body = json.dumps(self.json) + elif self.node_data.body.type == "raw-text": + body = self.node_data.body.data[0].value + if body: + raw += f"Content-Length: {len(body)}\r\n" + raw += "\r\n" # Empty line between headers and body + raw += body + + return raw + + +def _plain_text_to_dict(text: str, /) -> dict[str, str]: + """ + Convert a string of key-value pairs to a dictionary. + + Each line in the input string represents a key-value pair. + Keys and values are separated by ':'. + Empty values are allowed. + + Examples: + 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} + 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} + 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} + + Args: + convert_text (str): The input string to convert. + + Returns: + dict[str, str]: A dictionary of key-value pairs. + """ + return { + key.strip(): (value[0].strip() if value else "") + for line in text.splitlines() + if line.strip() + for key, *value in [line.split(":", 1)] + } + + +def _generate_random_string(n: int) -> str: + """ + Generate a random string of lowercase ASCII letters. + + Args: + n (int): The length of the random string to generate. + + Returns: + str: A random string of lowercase ASCII letters with length n. 
+ + Example: + >>> _generate_random_string(5) + 'abcde' + """ + return "".join([chr(randint(97, 122)) for _ in range(n)]) diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py deleted file mode 100644 index f8ab4e3132..0000000000 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ /dev/null @@ -1,343 +0,0 @@ -import json -from copy import deepcopy -from random import randint -from typing import Any, Optional, Union -from urllib.parse import urlencode - -import httpx - -from configs import dify_config -from core.helper import ssrf_proxy -from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.http_request.entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.utils.variable_template_parser import VariableTemplateParser - - -class HttpExecutorResponse: - headers: dict[str, str] - response: httpx.Response - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {} - - @property - def is_file(self) -> bool: - """ - check if response is file - """ - content_type = self.get_content_type() - file_content_types = ["image", "audio", "video"] - - return any(v in content_type for v in file_content_types) - - def get_content_type(self) -> str: - return self.headers.get("content-type", "") - - def extract_file(self) -> tuple[str, bytes]: - """ - extract file from response if content type is file related - """ - if self.is_file: - return self.get_content_type(), self.body - - return "", b"" - - @property - def content(self) -> str: - if isinstance(self.response, httpx.Response): - return self.response.text - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def 
body(self) -> bytes: - if isinstance(self.response, httpx.Response): - return self.response.content - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def status_code(self) -> int: - if isinstance(self.response, httpx.Response): - return self.response.status_code - else: - raise ValueError(f"Invalid response type {type(self.response)}") - - @property - def size(self) -> int: - return len(self.body) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - -class HttpExecutor: - server_url: str - method: str - authorization: HttpRequestNodeAuthorization - params: dict[str, Any] - headers: dict[str, Any] - body: Union[None, str] - files: Union[None, dict[str, Any]] - boundary: str - variable_selectors: list[VariableSelector] - timeout: HttpRequestNodeTimeout - - def __init__( - self, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: Optional[VariablePool] = None, - ): - self.server_url = node_data.url - self.method = node_data.method - self.authorization = node_data.authorization - self.timeout = timeout - self.params = {} - self.headers = {} - self.body = None - self.files = None - - # init template - self.variable_selectors = [] - self._init_template(node_data, variable_pool) - - @staticmethod - def _is_json_body(body: HttpRequestNodeBody): - """ - check if body is json - """ - if body and body.type == "json" and body.data: - try: - json.loads(body.data) - return True - except: - return False - - return False - - @staticmethod - def _to_dict(convert_text: str): - """ - Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` - """ - kv_paris = convert_text.split("\n") - result = {} - for kv in kv_paris: - if not kv.strip(): - continue - - kv = kv.split(":", maxsplit=1) - if len(kv) == 1: - k, v = kv[0], "" - 
else: - k, v = kv - result[k.strip()] = v - return result - - def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None): - # extract all template in url - self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) - - # extract all template in params - params, params_variable_selectors = self._format_template(node_data.params, variable_pool) - self.params = self._to_dict(params) - - # extract all template in headers - headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) - self.headers = self._to_dict(headers) - - # extract all template in body - body_data_variable_selectors = [] - if node_data.body: - # check if it's a valid JSON - is_valid_json = self._is_json_body(node_data.body) - - body_data = node_data.body.data or "" - if body_data: - body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - - content_type_is_set = any(key.lower() == "content-type" for key in self.headers) - if node_data.body.type == "json" and not content_type_is_set: - self.headers["Content-Type"] = "application/json" - elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: - self.headers["Content-Type"] = "application/x-www-form-urlencoded" - - if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: - body = self._to_dict(body_data) - - if node_data.body.type == "form-data": - self.files = {k: ("", v) for k, v in body.items()} - random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f"----WebKitFormBoundary{random_str(16)}" - - self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" - else: - self.body = urlencode(body) - elif node_data.body.type in {"json", "raw-text"}: - self.body = body_data - elif node_data.body.type == "none": - self.body = "" - - self.variable_selectors = ( - server_url_variable_selectors - + 
params_variable_selectors - + headers_variable_selectors - + body_data_variable_selectors - ) - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.authorization) - headers = deepcopy(self.headers) or {} - if self.authorization.type == "api-key": - if self.authorization.config is None: - raise ValueError("self.authorization config is required") - if authorization.config is None: - raise ValueError("authorization config is required") - - if self.authorization.config.api_key is None: - raise ValueError("api_key is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.authorization.config.type == "bearer": - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.authorization.config.type == "basic": - headers[authorization.config.header] = f"Basic {authorization.config.api_key}" - elif self.authorization.config.type == "custom": - headers[authorization.config.header] = authorization.config.api_key - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse: - """ - validate the response - """ - if isinstance(response, httpx.Response): - executor_response = HttpExecutorResponse(response) - else: - raise ValueError(f"Invalid response type {type(response)}") - - threshold_size = ( - dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE - if executor_response.is_file - else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE - ) - if executor_response.size > threshold_size: - raise ValueError( - f'{"File" if executor_response.is_file else "Text"} size is too large,' - f' max size is {threshold_size / 1024 / 1024:.2f} MB,' - f' but current size is {executor_response.readable_size}.' 
- ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - kwargs = { - "url": self.server_url, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "follow_redirects": True, - } - - if self.method in {"get", "head", "post", "put", "delete", "patch"}: - response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) - else: - raise ValueError(f"Invalid http method {self.method}") - return response - - def invoke(self) -> HttpExecutorResponse: - """ - invoke http request - """ - # assemble headers - headers = self._assembling_headers() - - # do http request - response = self._do_http_request(headers) - - # validate response - return self._validate_and_parse_response(response) - - def to_raw_request(self) -> str: - """ - convert to raw request - """ - server_url = self.server_url - if self.params: - server_url += f"?{urlencode(self.params)}" - - raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" - - headers = self._assembling_headers() - for k, v in headers.items(): - # get authorization header - if self.authorization.type == "api-key": - authorization_header = "Authorization" - if self.authorization.config and self.authorization.config.header: - authorization_header = self.authorization.config.header - - if k.lower() == authorization_header.lower(): - raw_request += f'{k}: {"*" * len(v)}\n' - continue - - raw_request += f"{k}: {v}\n" - - raw_request += "\n" - - # if files, use multipart/form-data with boundary - if self.files: - boundary = self.boundary - raw_request += f"--{boundary}" - for k, v in self.files.items(): - raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f"{v[1]}\n" - raw_request += f"--{boundary}" - raw_request += "--" - else: - raw_request += self.body or "" - - return raw_request - - def _format_template( 
- self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False - ) -> tuple[str, list[VariableSelector]]: - """ - format template - """ - variable_template_parser = VariableTemplateParser(template=template) - variable_selectors = variable_template_parser.extract_variable_selectors() - - if variable_pool: - variable_value_mapping = {} - for variable_selector in variable_selectors: - variable = variable_pool.get_any(variable_selector.value_selector) - if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") - if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace("\n", "\\n") - else: - value = variable - variable_value_mapping[variable_selector.variable] = value - - return variable_template_parser.format(variable_value_mapping), variable_selectors - else: - return template, variable_selectors diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py deleted file mode 100644 index cd40819126..0000000000 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ /dev/null @@ -1,165 +0,0 @@ -import logging -from collections.abc import Mapping, Sequence -from mimetypes import guess_extension -from os import path -from typing import Any, cast - -from configs import dify_config -from core.app.segments import parser -from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.http_request.entities import ( - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse -from models.workflow import WorkflowNodeExecutionStatus - -HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( - 
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, -) - - -class HttpRequestNode(BaseNode): - _node_data_cls = HttpRequestNodeData - _node_type = NodeType.HTTP_REQUEST - - @classmethod - def get_default_config(cls, filters: dict | None = None) -> dict: - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, - }, - }, - } - - def _run(self) -> NodeRunResult: - node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) - # TODO: Switch to use segment directly - if node_data.authorization.config and node_data.authorization.config.api_key: - node_data.authorization.config.api_key = parser.convert_template( - template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool - ).text - - # init http executor - http_executor = None - try: - http_executor = HttpExecutor( - node_data=node_data, - timeout=self._get_request_timeout(node_data), - variable_pool=self.graph_runtime_state.variable_pool, - ) - - # invoke http executor - response = http_executor.invoke() - except Exception as e: - process_data = {} - if http_executor: - process_data = { - "request": http_executor.to_raw_request(), - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - ) - - files = self.extract_files(http_executor.server_url, response) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.content if not files else "", - "headers": 
response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_raw_request(), - }, - ) - - @staticmethod - def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - timeout = node_data.timeout - if timeout is None: - return HTTP_REQUEST_DEFAULT_TIMEOUT - - timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect - timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read - timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write - return timeout - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - try: - http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) - - variable_selectors = http_executor.variable_selectors - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[node_id + "." 
+ variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - except Exception as e: - logging.exception(f"Failed to extract variable selector to variable mapping: {e}") - return {} - - def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: - """ - Extract files from response - """ - files = [] - mimetype, file_binary = response.extract_file() - - if mimetype: - # extract filename from url - filename = path.basename(url) - # extract extension if possible - extension = guess_extension(mimetype) or ".bin" - - tool_file = ToolFileManager.create_file_by_raw( - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - file_binary=file_binary, - mimetype=mimetype, - ) - - files.append( - FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=filename, - extension=extension, - mime_type=mimetype, - ) - ) - - return files diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py new file mode 100644 index 0000000000..483d0e2b7e --- /dev/null +++ b/api/core/workflow/nodes/http_request/node.py @@ -0,0 +1,174 @@ +import logging +from collections.abc import Mapping, Sequence +from mimetypes import guess_extension +from os import path +from typing import Any + +from configs import dify_config +from core.file import File, FileTransferMethod, FileType +from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.http_request.executor import Executor +from core.workflow.utils import variable_template_parser +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ( + HttpRequestNodeData, + 
HttpRequestNodeTimeout, + Response, +) + +HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, +) + +logger = logging.getLogger(__name__) + + +class HttpRequestNode(BaseNode[HttpRequestNodeData]): + _node_data_cls = HttpRequestNodeData + _node_type = NodeType.HTTP_REQUEST + + @classmethod + def get_default_config(cls, filters: dict | None = None) -> dict: + return { + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", + }, + "body": {"type": "none"}, + "timeout": { + **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + }, + }, + } + + def _run(self) -> NodeRunResult: + process_data = {} + try: + http_executor = Executor( + node_data=self.node_data, + timeout=self._get_request_timeout(self.node_data), + variable_pool=self.graph_runtime_state.variable_pool, + ) + process_data["request"] = http_executor.to_log() + + response = http_executor.invoke() + files = self.extract_files(url=http_executor.url, response=response) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "status_code": response.status_code, + "body": response.text if not files else "", + "headers": response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_log(), + }, + ) + except Exception as e: + logger.warning(f"http request node {self.node_id} failed to run: {e}") + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + process_data=process_data, + ) + + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + timeout = node_data.timeout + if timeout is None: + 
return HTTP_REQUEST_DEFAULT_TIMEOUT + + timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect + timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read + timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write + return timeout + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: HttpRequestNodeData, + ) -> Mapping[str, Sequence[str]]: + selectors: list[VariableSelector] = [] + selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(node_data.params) + if node_data.body: + body_type = node_data.body.type + data = node_data.body.data + match body_type: + case "binary": + selector = data[0].file + selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) + case "json" | "raw-text": + selectors += variable_template_parser.extract_selectors_from_template(data[0].key) + selectors += variable_template_parser.extract_selectors_from_template(data[0].value) + case "x-www-form-urlencoded": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + selectors += variable_template_parser.extract_selectors_from_template(item.value) + case "form-data": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + if item.type == "text": + selectors += variable_template_parser.extract_selectors_from_template(item.value) + elif item.type == "file": + selectors.append( + VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) + ) + + mapping = {} + for selector in selectors: + mapping[node_id + "." 
+ selector.variable] = selector.value_selector + + return mapping + + def extract_files(self, url: str, response: Response) -> list[File]: + """ + Extract files from response + """ + files = [] + content_type = response.content_type + content = response.content + + if content_type: + # extract filename from url + filename = path.basename(url) + # extract extension if possible + extension = guess_extension(content_type) or ".bin" + + tool_file = ToolFileManager.create_file_by_raw( + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + file_binary=content, + mimetype=content_type, + ) + + files.append( + File( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file.id, + filename=filename, + extension=extension, + mime_type=content_type, + ) + ) + + return files diff --git a/api/core/workflow/nodes/list_operator/__init__.py b/api/core/workflow/nodes/list_operator/__init__.py new file mode 100644 index 0000000000..1877586ef4 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/__init__.py @@ -0,0 +1,3 @@ +from .node import ListOperatorNode + +__all__ = ["ListOperatorNode"] diff --git a/api/core/workflow/nodes/list_operator/entities.py b/api/core/workflow/nodes/list_operator/entities.py new file mode 100644 index 0000000000..79cef1c27a --- /dev/null +++ b/api/core/workflow/nodes/list_operator/entities.py @@ -0,0 +1,56 @@ +from collections.abc import Sequence +from typing import Literal + +from pydantic import BaseModel, Field + +from core.workflow.nodes.base import BaseNodeData + +_Condition = Literal[ + # string conditions + "contains", + "start with", + "end with", + "is", + "in", + "empty", + "not contains", + "is not", + "not in", + "not empty", + # number conditions + "=", + "≠", + "<", + ">", + "≥", + "≤", +] + + +class FilterCondition(BaseModel): + key: str = "" + comparison_operator: _Condition = "contains" + value: str | Sequence[str] = "" + + +class 
FilterBy(BaseModel): + enabled: bool = False + conditions: Sequence[FilterCondition] = Field(default_factory=list) + + +class OrderBy(BaseModel): + enabled: bool = False + key: str = "" + value: Literal["asc", "desc"] = "asc" + + +class Limit(BaseModel): + enabled: bool = False + size: int = -1 + + +class ListOperatorNodeData(BaseNodeData): + variable: Sequence[str] = Field(default_factory=list) + filter_by: FilterBy + order_by: OrderBy + limit: Limit diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py new file mode 100644 index 0000000000..d7e4c64313 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/node.py @@ -0,0 +1,259 @@ +from collections.abc import Callable, Sequence +from typing import Literal + +from core.file import File +from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ListOperatorNodeData + + +class ListOperatorNode(BaseNode[ListOperatorNodeData]): + _node_data_cls = ListOperatorNodeData + _node_type = NodeType.LIST_OPERATOR + + def _run(self): + inputs = {} + process_data = {} + outputs = {} + + variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) + if variable is None: + error_message = f"Variable not found for selector: {self.node_data.variable}" + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs + ) + if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): + error_message = ( + f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " + "or ArrayStringSegment" + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, 
error=error_message, inputs=inputs, outputs=outputs + ) + + if isinstance(variable, ArrayFileSegment): + process_data["variable"] = [item.to_dict() for item in variable.value] + else: + process_data["variable"] = variable.value + + # Filter + if self.node_data.filter_by.enabled: + for condition in self.node_data.filter_by.conditions: + if isinstance(variable, ArrayStringSegment): + if not isinstance(condition.value, str): + raise ValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + if not isinstance(condition.value, str): + raise ValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + if isinstance(condition.value, str): + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + else: + value = condition.value + filter_func = _get_file_filter_func( + key=condition.key, + condition=condition.comparison_operator, + value=value, + ) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + + # Order + if self.node_data.order_by.enabled: + if isinstance(variable, ArrayStringSegment): + result = _order_string(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + result = 
_order_number(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + result = _order_file( + order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + ) + variable = variable.model_copy(update={"value": result}) + + # Slice + if self.node_data.limit.enabled: + result = variable.value[: self.node_data.limit.size] + variable = variable.model_copy(update={"value": result}) + + outputs = { + "result": variable.value, + "first_record": variable.value[0] if variable.value else None, + "last_record": variable.value[-1] if variable.value else None, + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + + +def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: + match key: + case "size": + return lambda x: x.size + case _: + raise ValueError(f"Invalid key: {key}") + + +def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: + match key: + case "name": + return lambda x: x.filename or "" + case "type": + return lambda x: x.type + case "extension": + return lambda x: x.extension or "" + case "mimetype": + return lambda x: x.mime_type or "" + case "transfer_method": + return lambda x: x.transfer_method + case "url": + return lambda x: x.remote_url or "" + case _: + raise ValueError(f"Invalid key: {key}") + + +def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: + match condition: + case "contains": + return _contains(value) + case "start with": + return _startswith(value) + case "end with": + return _endswith(value) + case "is": + return _is(value) + case "in": + return _in(value) + case "empty": + return lambda x: x == "" + case "not contains": + return lambda x: not _contains(value)(x) + case "is not": + return lambda x: not _is(value)(x) + case "not in": + return 
lambda x: not _in(value)(x) + case "not empty": + return lambda x: x != "" + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: + match condition: + case "in": + return _in(value) + case "not in": + return lambda x: not _in(value)(x) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: + match condition: + case "=": + return _eq(value) + case "≠": + return _ne(value) + case "<": + return _lt(value) + case "≤": + return _le(value) + case ">": + return _gt(value) + case "≥": + return _ge(value) + case _: + raise ValueError(f"Invalid condition: {condition}") + + +def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) + if key in {"type", "transfer_method"} and isinstance(value, Sequence): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) + elif key == "size" and isinstance(value, str): + extract_func = _get_file_extract_number_func(key=key) + return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) + else: + raise ValueError(f"Invalid key: {key}") + + +def _contains(value: str): + return lambda x: value in x + + +def _startswith(value: str): + return lambda x: x.startswith(value) + + +def _endswith(value: str): + return lambda x: x.endswith(value) + + +def _is(value: str): + return lambda x: x is value + + +def _in(value: str | Sequence[str]): + return lambda x: x in value + + +def _eq(value: int | float): + return 
lambda x: x == value + + +def _ne(value: int | float): + return lambda x: x != value + + +def _lt(value: int | float): + return lambda x: x < value + + +def _le(value: int | float): + return lambda x: x <= value + + +def _gt(value: int | float): + return lambda x: x > value + + +def _ge(value: int | float): + return lambda x: x >= value + + +def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: + extract_func = _get_file_extract_string_func(key=order_by) + return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + elif order_by == "size": + extract_func = _get_file_extract_number_func(key=order_by) + return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + else: + raise ValueError(f"Invalid order key: {order_by}") diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/node.py similarity index 71% rename from api/core/workflow/nodes/llm/llm_node.py rename to api/core/workflow/nodes/llm/node.py index 3d336b0b0b..abf77f3339 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,39 +1,48 @@ import json from collections.abc import Generator, Mapping, Sequence -from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, cast -from pydantic import BaseModel - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, 
ProviderTokenNotInitError, QuotaExceededError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( + AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, + TextPromptMessageContent, ) +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.variables import ( + ArrayAnySegment, + ArrayFileSegment, + ArraySegment, + FileSegment, + NoneSegment, + ObjectSegment, + StringSegment, +) +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.llm.entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from 
core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, + NodeEvent, + RunCompletedEvent, + RunRetrieverResourceEvent, + RunStreamChunkEvent, ) from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db @@ -41,44 +50,34 @@ from models.model import Conversation from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus +from .entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, +) + if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File -class ModelInvokeCompleted(BaseModel): - """ - Model invoke completed - """ - - text: str - usage: LLMUsage - finish_reason: Optional[str] = None - - -class LLMNode(BaseNode): +class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: - """ - Run node - :return: - """ - node_data = cast(LLMNodeData, deepcopy(self.node_data)) - variable_pool = self.graph_runtime_state.variable_pool - + def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: node_inputs = None process_data = None try: # init messages template - node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template) + self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data, variable_pool) + inputs = self._fetch_inputs(node_data=self.node_data) # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool) + jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) # merge inputs inputs.update(jinja_inputs) @@ -86,13 +85,17 @@ class LLMNode(BaseNode): node_inputs = {} # fetch files - files = self._fetch_files(node_data, variable_pool) + files = ( + 
self._fetch_files(selector=self.node_data.vision.configs.variable_selector) + if self.node_data.vision.enabled + else [] + ) if files: node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - generator = self._fetch_context(node_data, variable_pool) + generator = self._fetch_context(node_data=self.node_data) context = None for event in generator: if isinstance(event, RunRetrieverResourceEvent): @@ -103,21 +106,31 @@ class LLMNode(BaseNode): node_inputs["#context#"] = context # type: ignore # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance, model_config = self._fetch_model_config(self.node_data.model) # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) # fetch prompt messages + if self.node_data.memory: + query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if not query: + raise ValueError("Query not found") + query = query.text + else: + query = None + prompt_messages, stop = self._fetch_prompt_messages( - node_data=node_data, - query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, - query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, + system_query=query, inputs=inputs, files=files, context=context, memory=memory, model_config=model_config, + prompt_template=self.node_data.prompt_template, + memory_config=self.node_data.memory, + vision_enabled=self.node_data.vision.enabled, + vision_detail=self.node_data.vision.configs.detail, ) process_data = { @@ -131,7 +144,7 @@ class LLMNode(BaseNode): # handle invoke result generator = self._invoke_llm( - node_data_model=node_data.model, + node_data_model=self.node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -143,7 
+156,7 @@ class LLMNode(BaseNode): for event in generator: if isinstance(event, RunStreamChunkEvent): yield event - elif isinstance(event, ModelInvokeCompleted): + elif isinstance(event, ModelInvokeCompletedEvent): result_text = event.text usage = event.usage finish_reason = event.finish_reason @@ -182,15 +195,7 @@ class LLMNode(BaseNode): model_instance: ModelInstance, prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None, - ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: - """ - Invoke large language model - :param node_data_model: node data model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: - """ + ) -> Generator[NodeEvent, None, None]: db.session.close() invoke_result = model_instance.invoke_llm( @@ -207,20 +212,13 @@ class LLMNode(BaseNode): usage = LLMUsage.empty_usage() for event in generator: yield event - if isinstance(event, ModelInvokeCompleted): + if isinstance(event, ModelInvokeCompletedEvent): usage = event.usage # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - def _handle_invoke_result( - self, invoke_result: LLMResult | Generator - ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: - """ - Handle invoke result - :param invoke_result: invoke result - :return: - """ + def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: if isinstance(invoke_result, LLMResult): return @@ -250,18 +248,11 @@ class LLMNode(BaseNode): if not usage: usage = LLMUsage.empty_usage() - yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) + yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason) def _transform_chat_messages( - self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - """ - 
Transform chat messages - - :param messages: chat messages - :return: - """ - + self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / + ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: if isinstance(messages, LLMNodeCompletionModelPromptTemplate): if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = messages.jinja2_text @@ -274,69 +265,51 @@ class LLMNode(BaseNode): return messages - def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch jinja inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: variables = {} if not node_data.prompt_config: return variables for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + variable_name = variable_selector.variable + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: + raise ValueError(f"Variable {variable_selector.variable} not found") - def parse_dict(d: dict) -> str: + def parse_dict(input_dict: Mapping[str, Any]) -> str: """ Parse dict into string """ # check if it's a context structure - if "metadata" in d and "_source" in d["metadata"] and "content" in d: - return d["content"] + if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: + return input_dict["content"] # else, parse the dict try: - return json.dumps(d, ensure_ascii=False) + return json.dumps(input_dict, ensure_ascii=False) except Exception: - return str(d) + return str(input_dict) - if isinstance(value, str): - value = value - elif isinstance(value, list): + if isinstance(variable, ArraySegment): result = "" - for item in value: + for item in variable.value: if 
isinstance(item, dict): result += parse_dict(item) - elif isinstance(item, str): - result += item - elif isinstance(item, int | float): - result += str(item) else: result += str(item) result += "\n" value = result.strip() - elif isinstance(value, dict): - value = parse_dict(value) - elif isinstance(value, int | float): - value = str(value) + elif isinstance(variable, ObjectSegment): + value = parse_dict(variable.value) else: - value = str(value) + value = variable.text - variables[variable] = value + variables[variable_name] = value return variables - def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: inputs = {} prompt_template = node_data.prompt_template @@ -350,11 +323,12 @@ class LLMNode(BaseNode): variable_selectors = variable_template_parser.extract_variable_selectors() for variable_selector in variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) - if variable_value is None: + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: raise ValueError(f"Variable {variable_selector.variable} not found") - - inputs[variable_selector.variable] = variable_value + if isinstance(variable, NoneSegment): + continue + inputs[variable_selector.variable] = variable.to_object() memory = node_data.memory if memory and memory.query_prompt_template: @@ -362,51 +336,44 @@ class LLMNode(BaseNode): template=memory.query_prompt_template ).extract_variable_selectors() for variable_selector in query_variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) - if variable_value is None: + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: raise ValueError(f"Variable 
{variable_selector.variable} not found") - - inputs[variable_selector.variable] = variable_value + if isinstance(variable, NoneSegment): + continue + inputs[variable_selector.variable] = variable.to_object() return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: - """ - Fetch files - :param node_data: node data - :param variable_pool: variable pool - :return: - """ - if not node_data.vision.enabled: + def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]: + variable = self.graph_runtime_state.variable_pool.get(selector) + if variable is None: return [] - - files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value]) - if not files: + if isinstance(variable, FileSegment): + return [variable.value] + if isinstance(variable, ArrayFileSegment): + return variable.value + # FIXME: Temporary fix for empty array, + # all variables added to variable pool should be a Segment instance. + if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0: return [] + raise ValueError(f"Invalid variable type: {type(variable)}") - return files - - def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: - """ - Fetch context - :param node_data: node data - :param variable_pool: variable pool - :return: - """ + def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: return if not node_data.context.variable_selector: return - context_value = variable_pool.get_any(node_data.context.variable_selector) - if context_value: - if isinstance(context_value, str): - yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) - elif isinstance(context_value, list): + context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) + if context_value_variable: + if isinstance(context_value_variable, StringSegment): + yield 
RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource = [] - for item in context_value: + for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" else: @@ -424,11 +391,6 @@ class LLMNode(BaseNode): ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: - """ - Convert to original retriever resource, temp. - :param context_dict: context dict - :return: - """ if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -451,6 +413,7 @@ class LLMNode(BaseNode): "segment_position": metadata.get("segment_position"), "index_node_hash": metadata.get("segment_index_node_hash"), "content": context_dict.get("content"), + "page": metadata.get("page"), } return source @@ -460,11 +423,6 @@ class LLMNode(BaseNode): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - """ - Fetch model config - :param node_data_model: node data model - :return: - """ model_name = node_data_model.name provider_name = node_data_model.provider @@ -523,21 +481,18 @@ class LLMNode(BaseNode): ) def _fetch_memory( - self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance ) -> Optional[TokenBufferMemory]: - """ - Fetch memory - :param node_data_memory: node data memory - :param variable_pool: variable pool - :return: - """ if not node_data_memory: return None # get conversation id - conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value]) - if conversation_id is None: + conversation_id_variable = self.graph_runtime_state.variable_pool.get( + ["sys", SystemVariableKey.CONVERSATION_ID.value] + ) + if not isinstance(conversation_id_variable, 
StringSegment): return None + conversation_id = conversation_id_variable.value # get conversation conversation = ( @@ -555,43 +510,32 @@ class LLMNode(BaseNode): def _fetch_prompt_messages( self, - node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], + *, + system_query: str | None = None, + inputs: dict[str, str] | None = None, + files: Sequence["File"], + context: str | None = None, + memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + memory_config: MemoryConfig | None = None, + vision_enabled: bool = False, + vision_detail: ImagePromptMessageContent.DETAIL, ) -> tuple[list[PromptMessage], Optional[list[str]]]: - """ - Fetch prompt messages - :param node_data: node data - :param query: query - :param query_prompt_template: query prompt template - :param inputs: inputs - :param files: files - :param context: context - :param memory: memory - :param model_config: model config - :return: - """ + inputs = inputs or {} + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_messages = prompt_transform.get_prompt( - prompt_template=node_data.prompt_template, + prompt_template=prompt_template, inputs=inputs, - query=query or "", + query=system_query or "", files=files, context=context, - memory_config=node_data.memory, + memory_config=memory_config, memory=memory, model_config=model_config, - query_prompt_template=query_prompt_template, ) stop = model_config.stop - - vision_enabled = node_data.vision.enabled - vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None filtered_prompt_messages = [] for prompt_message in prompt_messages: if prompt_message.is_empty(): @@ -599,17 +543,17 @@ class LLMNode(BaseNode): if not 
isinstance(prompt_message.content, str): prompt_message_content = [] - for content_item in prompt_message.content: - if ( - vision_enabled - and content_item.type == PromptMessageContentType.IMAGE - and isinstance(content_item, ImagePromptMessageContent) - ): - # Override vision config if LLM node has vision config - if vision_detail: - content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) + for content_item in prompt_message.content or []: + # Skip image if vision is disabled + if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + continue + + if isinstance(content_item, ImagePromptMessageContent): + # Override vision config if LLM node has vision config, + # cuz vision detail is related to the configuration from FileUpload feature. + content_item.detail = vision_detail prompt_message_content.append(content_item) - elif content_item.type == PromptMessageContentType.TEXT: + elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent): prompt_message_content.append(content_item) if len(prompt_message_content) > 1: @@ -631,13 +575,6 @@ class LLMNode(BaseNode): @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: - """ - Deduct LLM quota - :param tenant_id: tenant id - :param model_instance: model instance - :param usage: usage - :return: - """ provider_model_bundle = model_instance.provider_model_bundle provider_configuration = provider_model_bundle.configuration @@ -668,7 +605,7 @@ class LLMNode(BaseNode): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == tenant_id, Provider.provider_name == model_instance.provider, @@ -680,27 +617,28 @@ class LLMNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData + 
cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: LLMNodeData, ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ prompt_template = node_data.prompt_template variable_selectors = [] - if isinstance(prompt_template, list): + if isinstance(prompt_template, list) and all( + isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template + ): for prompt in prompt_template: if prompt.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - else: + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() + else: + raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping = {} for variable_selector in variable_selectors: @@ -745,11 +683,6 @@ class LLMNode(BaseNode): @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. 
- :return: - """ return { "type": "llm", "config": { diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py new file mode 100644 index 0000000000..56b1d6bd28 --- /dev/null +++ b/api/extensions/ext_logging.py @@ -0,0 +1,45 @@ +import logging +import os +import sys +from logging.handlers import RotatingFileHandler + +from flask import Flask + +from configs import dify_config + + +def init_app(app: Flask): + log_handlers = None + log_file = dify_config.LOG_FILE + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + log_handlers = [ + RotatingFileHandler( + filename=log_file, + maxBytes=dify_config.LOG_FILE_MAX_SIZE * 1024 * 1024, + backupCount=dify_config.LOG_FILE_BACKUP_COUNT, + ), + logging.StreamHandler(sys.stdout), + ] + + logging.basicConfig( + level=dify_config.LOG_LEVEL, + format=dify_config.LOG_FORMAT, + datefmt=dify_config.LOG_DATEFORMAT, + handlers=log_handlers, + force=True, + ) + log_tz = dify_config.LOG_TZ + if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + + for handler in logging.root.handlers: + handler.formatter.converter = time_converter diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py new file mode 100644 index 0000000000..ead7b9a8b3 --- /dev/null +++ b/api/factories/file_factory.py @@ -0,0 +1,251 @@ +import mimetypes +from collections.abc import Mapping, Sequence +from typing import Any + +import httpx +from sqlalchemy import select + +from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS +from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType +from core.helper import ssrf_proxy +from extensions.ext_database import db +from models import MessageFile, ToolFile, UploadFile +from models.enums import CreatedByRole + + +def 
build_from_message_files( + *, + message_files: Sequence["MessageFile"], + tenant_id: str, + config: FileExtraConfig, +) -> Sequence[File]: + results = [ + build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) + for file in message_files + if file.belongs_to != FileBelongsTo.ASSISTANT + ] + return results + + +def build_from_message_file( + *, + message_file: "MessageFile", + tenant_id: str, + config: FileExtraConfig, +): + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "id": message_file.id, + "type": message_file.type, + "upload_file_id": message_file.upload_file_id, + } + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + user_id=message_file.created_by, + role=CreatedByRole(message_file.created_by_role), + config=config, + ) + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + role: "CreatedByRole", + config: FileExtraConfig, +): + transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) + match transfer_method: + case FileTransferMethod.REMOTE_URL: + file = _build_from_remote_url( + mapping=mapping, + tenant_id=tenant_id, + config=config, + transfer_method=transfer_method, + ) + case FileTransferMethod.LOCAL_FILE: + file = _build_from_local_file( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + role=role, + config=config, + transfer_method=transfer_method, + ) + case FileTransferMethod.TOOL_FILE: + file = _build_from_tool_file( + mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + config=config, + transfer_method=transfer_method, + ) + case _: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileExtraConfig | None, + tenant_id: str, + user_id: str, + role: "CreatedByRole", +) -> Sequence[File]: + if not config: + return [] + + files = [ + build_from_mapping( + 
mapping=mapping, + tenant_id=tenant_id, + user_id=user_id, + role=role, + config=config, + ) + for mapping in mappings + ] + + if ( + # If image config is set. + config.image_config + # And the number of image files exceeds the maximum limit + and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + role: "CreatedByRole", + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + # check if the upload file exists. + file_type = FileType.value_of(mapping.get("type")) + stmt = select(UploadFile).where( + UploadFile.id == mapping.get("upload_file_id"), + UploadFile.tenant_id == tenant_id, + UploadFile.created_by == user_id, + UploadFile.created_by_role == role, + ) + if file_type == FileType.IMAGE: + stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS)) + elif file_type == FileType.VIDEO: + stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS)) + elif file_type == FileType.AUDIO: + stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS)) + elif file_type == FileType.DOCUMENT: + stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS)) + row = db.session.scalar(stmt) + if row is None: + raise ValueError("Invalid upload file") + file = File( + id=mapping.get("id"), + filename=row.name, + extension="." 
+ row.extension, + mime_type=row.mime_type, + tenant_id=tenant_id, + type=file_type, + transfer_method=transfer_method, + remote_url=None, + related_id=mapping.get("upload_file_id"), + _extra_config=config, + size=row.size, + ) + return file + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + url = mapping.get("url") + if not url: + raise ValueError("Invalid file url") + + mime_type = mimetypes.guess_type(url)[0] or "" + file_size = -1 + filename = url.split("/")[-1].split("?")[0] or "unknown_file" + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + if content_disposition := resp.headers.get("Content-Disposition"): + filename = content_disposition.split("filename=")[-1].strip('"') + file_size = int(resp.headers.get("Content-Length", file_size)) + mime_type = mime_type or str(resp.headers.get("Content-Type", "")) + + # Determine file extension + extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" + + if not mime_type: + mime_type, _ = mimetypes.guess_type(url) + file = File( + id=mapping.get("id"), + filename=filename, + tenant_id=tenant_id, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=url, + _extra_config=config, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + return file + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + user_id: str, + config: FileExtraConfig, + transfer_method: FileTransferMethod, +): + tool_file = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == mapping.get("tool_file_id"), + ToolFile.tenant_id == tenant_id, + ToolFile.user_id == user_id, + ) + .first() + ) + if tool_file is None: + raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + + path = tool_file.file_key + if "." in path: + extension = "." 
+ path.split("/")[-1].split(".")[-1] + else: + extension = ".bin" + file = File( + id=mapping.get("id"), + tenant_id=tenant_id, + filename=tool_file.name, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=tool_file.original_url, + related_id=tool_file.id, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + _extra_config=config, + ) + return file diff --git a/api/core/app/segments/factory.py b/api/factories/variable_factory.py similarity index 73% rename from api/core/app/segments/factory.py rename to api/factories/variable_factory.py index 40a69ed4eb..a758f9981f 100644 --- a/api/core/app/segments/factory.py +++ b/api/factories/variable_factory.py @@ -2,29 +2,32 @@ from collections.abc import Mapping from typing import Any from configs import dify_config - -from .exc import VariableError -from .segments import ( +from core.file import File +from core.variables import ( ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayNumberVariable, + ArrayObjectSegment, + ArrayObjectVariable, + ArrayStringSegment, + ArrayStringVariable, + FileSegment, FloatSegment, + FloatVariable, IntegerSegment, + IntegerVariable, NoneSegment, ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, ObjectVariable, SecretVariable, + Segment, + SegmentType, + StringSegment, StringVariable, Variable, ) +from core.variables.exc import VariableError def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: @@ -71,6 +74,22 @@ def build_segment(value: Any, /) -> Segment: return FloatSegment(value=value) if isinstance(value, dict): return ObjectSegment(value=value) + if isinstance(value, File): + return FileSegment(value=value) if isinstance(value, list): - return ArrayAnySegment(value=value) + items = [build_segment(item) for item in value] + 
types = {item.value_type for item in items} + if len(types) != 1: + return ArrayAnySegment(value=value) + match types.pop(): + case SegmentType.STRING: + return ArrayStringSegment(value=value) + case SegmentType.NUMBER: + return ArrayNumberSegment(value=value) + case SegmentType.OBJECT: + return ArrayObjectSegment(value=value) + case SegmentType.FILE: + return ArrayFileSegment(value=value) + case _: + raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") diff --git a/api/fields/raws.py b/api/fields/raws.py new file mode 100644 index 0000000000..15ec16ab13 --- /dev/null +++ b/api/fields/raws.py @@ -0,0 +1,17 @@ +from flask_restful import fields + +from core.file import File + + +class FilesContainedField(fields.Raw): + def format(self, value): + return self._format_file_object(value) + + def _format_file_object(self, v): + if isinstance(v, File): + return v.model_dump() + if isinstance(v, dict): + return {k: self._format_file_object(vv) for k, vv in v.items()} + if isinstance(v, list): + return [self._format_file_object(vv) for vv in v] + return v diff --git a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py new file mode 100644 index 0000000000..c17d1db77a --- /dev/null +++ b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py @@ -0,0 +1,49 @@ +"""add name and size to tool_files + +Revision ID: bbadea11becb +Revises: 33f5fac87f29 +Create Date: 2024-10-10 05:16:14.764268 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'bbadea11becb' +down_revision = 'd8e744d88ed6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + # Get the database connection + conn = op.get_bind() + + # Use SQLAlchemy inspector to get the columns of the 'tool_files' table + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('tool_files')] + + # If 'name' or 'size' columns already exist, exit the upgrade function + if 'name' in columns or 'size' in columns: + return + + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(), nullable=True)) + batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True)) + op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL") + op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL") + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('name', existing_type=sa.String(), nullable=False) + batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.drop_column('size') + batch_op.drop_column('name') + # ### end Alembic commands ### diff --git a/api/models/enums.py b/api/models/enums.py new file mode 100644 index 0000000000..a83d35e042 --- /dev/null +++ b/api/models/enums.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class CreatedByRole(str, Enum): + ACCOUNT = "account" + END_USER = "end_user" + + +class UserFrom(str, Enum): + ACCOUNT = "account" + END_USER = "end-user" + + +class WorkflowRunTriggeredFrom(str, Enum): + DEBUGGING = "debugging" + APP_RUN = "app-run" diff --git a/api/services/errors/workspace.py b/api/services/errors/workspace.py new file mode 100644 index 0000000000..714064ffdf --- /dev/null +++ b/api/services/errors/workspace.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class WorkSpaceNotAllowedCreateError(BaseServiceError): + pass + + +class WorkSpaceNotFoundError(BaseServiceError): + pass diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py new file mode 100644 index 0000000000..d78fc2b891 --- /dev/null +++ b/api/tasks/mail_email_code_login.py @@ -0,0 +1,41 @@ +import logging +import time + +import click +from celery import shared_task +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_email_code_login_mail_task(language: str, to: str, code: str): + """ + Async Send email code login mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Email code to be included in the email + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start email code login mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + # send email code login mail using different languages + try: + if language == "zh-Hans": + html_content = 
render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code) + mail.send(to=to, subject="邮箱验证码", html=html_content) + else: + html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Email Code", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send email code login mail to {} failed".format(to)) diff --git a/api/templates/email_code_login_mail_template_en-US.html b/api/templates/email_code_login_mail_template_en-US.html new file mode 100644 index 0000000000..066818d10c --- /dev/null +++ b/api/templates/email_code_login_mail_template_en-US.html @@ -0,0 +1,74 @@ + + + + + + +
+
+ + Dify Logo +
+

Your login code for Dify

+

Copy and paste this code, this code will only be valid for the next 5 minutes.

+
+ {{code}} +
+

If you didn't request a login, don't worry. You can safely ignore this email.

+
+ + diff --git a/api/templates/email_code_login_mail_template_zh-CN.html b/api/templates/email_code_login_mail_template_zh-CN.html new file mode 100644 index 0000000000..0c2b63a1f1 --- /dev/null +++ b/api/templates/email_code_login_mail_template_zh-CN.html @@ -0,0 +1,74 @@ + + + + + + +
+
+ + Dify Logo +
+

Dify 的登录验证码

+

复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。

+
+ {{code}} +
+

如果您没有请求登录,请不要担心。您可以安全地忽略此电子邮件。

+
+ + diff --git a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py new file mode 100644 index 0000000000..c93292bd8a --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py @@ -0,0 +1,75 @@ +import os +from typing import Optional + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from upstash_vector import Index + + +# Mocking the Index class from upstash_vector +class MockIndex: + def __init__(self, url="", token=""): + self.url = url + self.token = token + self.vectors = [] + + def upsert(self, vectors): + for vector in vectors: + vector.score = 0.5 + self.vectors.append(vector) + return {"code": 0, "msg": "operation success", "affectedCount": len(vectors)} + + def fetch(self, ids): + return [vector for vector in self.vectors if vector.id in ids] + + def delete(self, ids): + self.vectors = [vector for vector in self.vectors if vector.id not in ids] + return {"code": 0, "msg": "Success"} + + def query( + self, + vector: None, + top_k: int = 10, + include_vectors: bool = False, + include_metadata: bool = False, + filter: str = "", + data: Optional[str] = None, + namespace: str = "", + include_data: bool = False, + ): + # Simple mock query, in real scenario you would calculate similarity + mock_result = [] + for vector_data in self.vectors: + mock_result.append(vector_data) + return mock_result[:top_k] + + def reset(self): + self.vectors = [] + + def info(self): + return AttrDict({"dimension": 1024}) + + +class AttrDict(dict): + def __getattr__(self, item): + return self.get(item) + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_upstashvector_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(Index, "__init__", MockIndex.__init__) + monkeypatch.setattr(Index, "upsert", MockIndex.upsert) + monkeypatch.setattr(Index, "fetch", MockIndex.fetch) + monkeypatch.setattr(Index, "delete", MockIndex.delete) 
+ monkeypatch.setattr(Index, "query", MockIndex.query) + monkeypatch.setattr(Index, "reset", MockIndex.reset) + monkeypatch.setattr(Index, "info", MockIndex.info) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/upstash/__init__.py b/api/tests/integration_tests/vdb/upstash/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py b/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py new file mode 100644 index 0000000000..23470474ff --- /dev/null +++ b/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py @@ -0,0 +1,28 @@ +from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig +from core.rag.models.document import Document +from tests.integration_tests.vdb.__mock.upstashvectordb import setup_upstashvector_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text + + +class UpstashVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = UpstashVector( + collection_name="test_collection", + config=UpstashVectorConfig( + url="your-server-url", + token="your-access-token", + ), + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) != 0 + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_upstash_vector(setup_upstashvector_mock): + UpstashVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/workflow/test_sync_workflow.py b/api/tests/integration_tests/workflow/test_sync_workflow.py new file mode 100644 index 0000000000..df2ec95ebc --- /dev/null +++ b/api/tests/integration_tests/workflow/test_sync_workflow.py @@ -0,0 +1,57 @@ +""" +This test file is used to verify the 
compatibility of Workflow before and after supporting multiple file types. +""" + +import json + +from models import Workflow + +OLD_VERSION_WORKFLOW_FEATURES = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + +NEW_VERSION_WORKFLOW_FEATURES = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_extensions": [], + "allowed_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + + +def test_workflow_features(): + workflow = Workflow( + tenant_id="", + app_id="", + type="", + version="", + graph="", + features=json.dumps(OLD_VERSION_WORKFLOW_FEATURES), + created_by="", + environment_variables=[], + conversation_variables=[], + ) + + assert workflow.features_dict == NEW_VERSION_WORKFLOW_FEATURES diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py new file mode 100644 index 0000000000..aa61c1c6f7 --- /dev/null +++ b/api/tests/unit_tests/core/test_file.py @@ -0,0 +1,40 @@ +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType + + +def test_file_loads_and_dumps(): + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image1.jpg", + ) + + 
file_dict = file.model_dump() + assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY + assert file_dict["type"] == file.type.value + assert isinstance(file_dict["type"], str) + assert file_dict["transfer_method"] == file.transfer_method.value + assert isinstance(file_dict["transfer_method"], str) + assert "_extra_config" not in file_dict + + file_obj = File.model_validate(file_dict) + assert file_obj.id == file.id + assert file_obj.tenant_id == file.tenant_id + assert file_obj.type == file.type + assert file_obj.transfer_method == file.transfer_method + assert file_obj.remote_url == file.remote_url + + +def test_file_to_dict(): + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image1.jpg", + ) + + file_dict = file.to_dict() + assert "_extra_config" not in file_dict + assert "url" in file_dict diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py deleted file mode 100644 index 279a6cdbc3..0000000000 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest - -from core.tools.entities.tool_entities import ToolParameter -from core.tools.utils.tool_parameter_converter import ToolParameterConverter - - -def test_get_parameter_type(): - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string" - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string" - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean" - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number" - with pytest.raises(ValueError): - ToolParameterConverter.get_parameter_type("unsupported_type") - - -def test_cast_parameter_by_type(): - # string - assert 
ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test" - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1" - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0" - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == "" - - # secret input - assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test" - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1" - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0" - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == "" - - # select - assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test" - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1" - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0" - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == "" - - # boolean - true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] - for value in true_values: - assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True - - false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] - for value in false_values: - assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False - - # number - assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type("1.0", 
ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0 - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None - - # unknown - assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1" - assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1" - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_type.py b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py new file mode 100644 index 0000000000..8a41678267 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py @@ -0,0 +1,49 @@ +from core.tools.entities.tool_entities import ToolParameter + + +def test_get_parameter_type(): + assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean" + assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number" + assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file" + assert ToolParameter.ToolParameterType.FILES.as_normal_type() == "files" + + +def test_cast_parameter_by_type(): + # string + assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.STRING.cast_value(1) == "1" + assert 
ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.STRING.cast_value(None) == "" + + # secret input + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == "" + + # select + assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1" + assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == "" + + # boolean + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] + for value in true_values: + assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True + + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] + for value in false_values: + assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False + + # number + assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1 + assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py new file mode 100644 index 0000000000..a141fa9a13 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -0,0 +1,167 @@ 
"""Unit tests for the DocumentExtractorNode workflow node."""

from unittest.mock import Mock, patch

import pytest

from core.file import File, FileTransferMethod
from core.variables import ArrayFileSegment
from core.variables.variables import StringVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
from core.workflow.nodes.document_extractor.node import (
    _extract_text_from_doc,
    _extract_text_from_pdf,
    _extract_text_from_plain_text,
)
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus


@pytest.fixture
def document_extractor_node():
    """A DocumentExtractorNode whose graph collaborators are all mocked."""
    data = DocumentExtractorNodeData(
        title="Test Document Extractor",
        variable_selector=["node_id", "variable_name"],
    )
    return DocumentExtractorNode(
        id="test_node_id",
        config={"id": "test_node_id", "data": data.model_dump()},
        graph_init_params=Mock(),
        graph=Mock(),
        graph_runtime_state=Mock(),
    )


@pytest.fixture
def mock_graph_runtime_state():
    """A bare mock standing in for the node's runtime state."""
    return Mock()


def test_run_variable_not_found(document_extractor_node, mock_graph_runtime_state):
    """The node fails with a descriptive error when the file variable is absent."""
    document_extractor_node.graph_runtime_state = mock_graph_runtime_state
    mock_graph_runtime_state.variable_pool.get.return_value = None

    outcome = document_extractor_node._run()

    assert isinstance(outcome, NodeRunResult)
    assert outcome.status == WorkflowNodeExecutionStatus.FAILED
    assert outcome.error is not None
    assert "File variable not found" in outcome.error


def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state):
    """The node fails when the variable exists but is not an ArrayFileSegment."""
    document_extractor_node.graph_runtime_state = mock_graph_runtime_state
    mock_graph_runtime_state.variable_pool.get.return_value = StringVariable(
        value="Not an ArrayFileSegment", name="test"
    )

    outcome = document_extractor_node._run()

    assert isinstance(outcome, NodeRunResult)
    assert outcome.status == WorkflowNodeExecutionStatus.FAILED
    assert outcome.error is not None
    assert "is not an ArrayFileSegment" in outcome.error


@pytest.mark.parametrize(
    ("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
    [
        ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
        (
            "application/pdf",
            b"%PDF-1.5\n%Test PDF content",
            ["Mocked PDF content"],
            FileTransferMethod.LOCAL_FILE,
            ".pdf",
        ),
        (
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            b"PK\x03\x04",
            ["Mocked DOCX content"],
            FileTransferMethod.REMOTE_URL,
            "",
        ),
        ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
    ],
)
def test_run_extract_text(
    document_extractor_node,
    mock_graph_runtime_state,
    mime_type,
    file_content,
    expected_text,
    transfer_method,
    extension,
    monkeypatch,
):
    """Text is routed to the right extractor per mime type, for local and remote files."""
    document_extractor_node.graph_runtime_state = mock_graph_runtime_state

    is_local = transfer_method == FileTransferMethod.LOCAL_FILE
    fake_file = Mock(spec=File)
    fake_file.mime_type = mime_type
    fake_file.transfer_method = transfer_method
    fake_file.related_id = "test_file_id" if is_local else None
    fake_file.remote_url = None if is_local else "https://example.com/file.txt"
    fake_file.extension = extension

    segment = Mock(spec=ArrayFileSegment)
    segment.value = [fake_file]
    mock_graph_runtime_state.variable_pool.get.return_value = segment

    download_stub = Mock(return_value=file_content)
    monkeypatch.setattr("core.file.file_manager.download", download_stub)

    http_get_stub = Mock()
    http_get_stub.return_value.content = file_content
    http_get_stub.return_value.raise_for_status = Mock()
    monkeypatch.setattr("core.helper.ssrf_proxy.get", http_get_stub)

    # Heavy parser dependencies are stubbed out; only routing is under test here.
    if mime_type == "application/pdf":
        monkeypatch.setattr(
            "core.workflow.nodes.document_extractor.node._extract_text_from_pdf",
            Mock(return_value=expected_text[0]),
        )
    elif mime_type.startswith("application/vnd.openxmlformats"):
        monkeypatch.setattr(
            "core.workflow.nodes.document_extractor.node._extract_text_from_doc",
            Mock(return_value=expected_text[0]),
        )

    outcome = document_extractor_node._run()

    assert isinstance(outcome, NodeRunResult)
    assert outcome.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert outcome.outputs is not None
    assert outcome.outputs["text"] == expected_text

    if is_local:
        download_stub.assert_called_once_with(fake_file)
    else:
        http_get_stub.assert_called_once_with("https://example.com/file.txt")


def test_extract_text_from_plain_text():
    """Plain-text bytes decode straight to the same string."""
    assert _extract_text_from_plain_text(b"Hello, world!") == "Hello, world!"


@patch("pypdfium2.PdfDocument")
def test_extract_text_from_pdf(mock_pdf_document):
    """PDF extraction reads the text range of each page."""
    text_page = Mock()
    text_page.get_text_range.return_value = "PDF content"
    page = Mock()
    page.get_textpage.return_value = text_page
    mock_pdf_document.return_value = [page]

    assert _extract_text_from_pdf(b"%PDF-1.5\n%Test PDF content") == "PDF content"


@patch("docx.Document")
def test_extract_text_from_doc(mock_document):
    """DOCX extraction joins paragraph texts with newlines."""
    paragraphs = []
    for body in ("Paragraph 1", "Paragraph 2"):
        paragraph = Mock()
        paragraph.text = body
        paragraphs.append(paragraph)
    mock_document.return_value.paragraphs = paragraphs

    assert _extract_text_from_doc(b"PK\x03\x04") == "Paragraph 1\nParagraph 2"


def test_node_type(document_extractor_node):
    """The node advertises the DOCUMENT_EXTRACTOR node type."""
    assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR
"""Unit tests for the HTTP request workflow node and its Executor.

The node/graph scaffolding and the JSON-body executor setup were duplicated
verbatim across tests; they are factored into private helpers below. The
unused ``import json`` is dropped.
"""

import httpx

from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import FileVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.http_request import (
    BodyData,
    HttpRequestNode,
    HttpRequestNodeAuthorization,
    HttpRequestNodeBody,
    HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


def _variable_pool_with_file() -> VariablePool:
    """Pool holding one local image file under selector ["1111", "file"]."""
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.add(
        ["1111", "file"],
        FileVariable(
            name="file",
            value=File(
                tenant_id="1",
                type=FileType.IMAGE,
                transfer_method=FileTransferMethod.LOCAL_FILE,
                related_id="1111",
            ),
        ),
    )
    return pool


def _build_node(data: HttpRequestNodeData, variable_pool: VariablePool) -> HttpRequestNode:
    """Wrap node data in an HttpRequestNode with minimal graph scaffolding."""
    return HttpRequestNode(
        id="1",
        config={"id": "1", "data": data.model_dump()},
        graph_init_params=GraphInitParams(
            tenant_id="1",
            app_id="1",
            workflow_type=WorkflowType.WORKFLOW,
            workflow_id="1",
            graph_config={},
            user_id="1",
            user_from=UserFrom.ACCOUNT,
            invoke_from=InvokeFrom.SERVICE_API,
            call_depth=0,
        ),
        graph=Graph(
            root_node_id="1",
            answer_stream_generate_routes=AnswerStreamGenerateRoute(
                answer_dependencies={},
                answer_generate_route={},
            ),
            end_stream_param=EndStreamParam(
                end_dependencies={},
                end_stream_variable_selector_mapping={},
            ),
        ),
        graph_runtime_state=GraphRuntimeState(
            variable_pool=variable_pool,
            start_at=0,
        ),
    )


def _json_body_executor(title: str, value: str, variable_pool: VariablePool) -> Executor:
    """Build an Executor for a JSON-body POST whose single field holds *value*."""
    node_data = HttpRequestNodeData(
        title=title,
        method="post",
        url="https://api.example.com/data",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="Content-Type: application/json",
        params="",
        body=HttpRequestNodeBody(
            type="json",
            data=[
                BodyData(
                    key="",
                    type="text",
                    value=value,
                )
            ],
        ),
    )
    return Executor(
        node_data=node_data,
        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
        variable_pool=variable_pool,
    )


def test_plain_text_to_dict():
    """Header-style text parses into a dict; blank lines and padding are ignored."""
    assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
    assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
    assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
    assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}


def test_http_request_node_binary_file(monkeypatch):
    """A binary body posts the raw downloaded file content."""
    data = HttpRequestNodeData(
        title="test",
        method="post",
        url="http://example.org/post",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="",
        params="",
        body=HttpRequestNodeBody(
            type="binary",
            data=[
                BodyData(
                    key="file",
                    type="file",
                    value="",
                    file=["1111", "file"],
                )
            ],
        ),
    )
    node = _build_node(data, _variable_pool_with_file())
    monkeypatch.setattr(
        "core.workflow.nodes.http_request.executor.file_manager.download",
        lambda *args, **kwargs: b"test",
    )
    # Echo the posted content back so the assertion closes the loop.
    monkeypatch.setattr(
        "core.helper.ssrf_proxy.post",
        lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
    )

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["body"] == "test"


def test_http_request_node_form_with_file(monkeypatch):
    """A form-data body splits file parts from plain text fields."""
    data = HttpRequestNodeData(
        title="test",
        method="post",
        url="http://example.org/post",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="",
        params="",
        body=HttpRequestNodeBody(
            type="form-data",
            data=[
                BodyData(
                    key="file",
                    type="file",
                    file=["1111", "file"],
                ),
                BodyData(
                    key="name",
                    type="text",
                    value="test",
                ),
            ],
        ),
    )
    node = _build_node(data, _variable_pool_with_file())
    monkeypatch.setattr(
        "core.workflow.nodes.http_request.executor.file_manager.download",
        lambda *args, **kwargs: b"test",
    )

    def attr_checker(*args, **kwargs):
        # The file part and the text field must travel in separate arguments.
        assert kwargs["data"] == {"name": "test"}
        assert kwargs["files"] == {"file": b"test"}
        return httpx.Response(200, content=b"")

    monkeypatch.setattr("core.helper.ssrf_proxy.post", attr_checker)

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["body"] == ""


def test_executor_with_json_body_and_number_variable():
    """A numeric variable interpolates into the JSON body unquoted."""
    variable_pool = VariablePool(system_variables={}, user_inputs={})
    variable_pool.add(["pre_node_id", "number"], 42)

    executor = _json_body_executor(
        "Test JSON Body with Number Variable",
        '{"number": {{#pre_node_id.number#}}}',
        variable_pool,
    )

    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == {}
    assert executor.json == {"number": 42}
    assert executor.data is None
    assert executor.files is None
    assert executor.content is None

    raw_request = executor.to_log()
    assert "POST /data HTTP/1.1" in raw_request
    assert "Host: api.example.com" in raw_request
    assert "Content-Type: application/json" in raw_request
    assert '{"number": 42}' in raw_request


def test_executor_with_json_body_and_object_variable():
    """An object variable can replace the whole body template."""
    variable_pool = VariablePool(system_variables={}, user_inputs={})
    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})

    executor = _json_body_executor(
        "Test JSON Body with Object Variable",
        "{{#pre_node_id.object#}}",
        variable_pool,
    )

    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == {}
    assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
    assert executor.data is None
    assert executor.files is None
    assert executor.content is None

    raw_request = executor.to_log()
    assert "POST /data HTTP/1.1" in raw_request
    assert "Host: api.example.com" in raw_request
    assert "Content-Type: application/json" in raw_request
    assert '"name": "John Doe"' in raw_request
    assert '"age": 30' in raw_request
    assert '"email": "john@example.com"' in raw_request


def test_executor_with_json_body_and_nested_object_variable():
    """An object variable nests correctly inside a larger JSON template."""
    variable_pool = VariablePool(system_variables={}, user_inputs={})
    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})

    executor = _json_body_executor(
        "Test JSON Body with Nested Object Variable",
        '{"object": {{#pre_node_id.object#}}}',
        variable_pool,
    )

    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == {}
    assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
    assert executor.data is None
    assert executor.files is None
    assert executor.content is None

    raw_request = executor.to_log()
    assert "POST /data HTTP/1.1" in raw_request
    assert "Host: api.example.com" in raw_request
    assert "Content-Type: application/json" in raw_request
    assert '"object": {' in raw_request
    assert '"name": "John Doe"' in raw_request
    assert '"age": 30' in raw_request
    assert '"email": "john@example.com"' in raw_request
"""Unit tests for the list-operator workflow node (type filtering).

Fix: the original comparison loop used bare ``zip``, which silently truncates
— the test would pass even if the node returned fewer files than expected.
A length assertion now guards the loop. File construction is factored into a
helper to remove four near-identical literals.
"""

from unittest.mock import MagicMock

import pytest

from core.file import File
from core.file.models import FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy
from core.workflow.nodes.list_operator.node import ListOperatorNode
from models.workflow import WorkflowNodeExecutionStatus


@pytest.fixture
def list_operator_node():
    """A ListOperatorNode configured to keep only image and document files."""
    config = {
        "variable": ["test_variable"],
        "filter_by": FilterBy(
            enabled=True,
            conditions=[
                FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
            ],
        ),
        "order_by": OrderBy(enabled=False, value="asc"),
        "limit": Limit(enabled=False, size=0),
        "title": "Test Title",
    }
    node_data = ListOperatorNodeData(**config)
    node = ListOperatorNode(
        id="test_node_id",
        config={
            "id": "test_node_id",
            "data": node_data.model_dump(),
        },
        graph_init_params=MagicMock(),
        graph=MagicMock(),
        graph_runtime_state=MagicMock(),
    )
    node.graph_runtime_state = MagicMock()
    node.graph_runtime_state.variable_pool = MagicMock()
    return node


def _local_file(filename: str, file_type: FileType, related_id: str) -> File:
    """Shorthand for a local-transfer File owned by tenant1."""
    return File(
        filename=filename,
        type=file_type,
        tenant_id="tenant1",
        transfer_method=FileTransferMethod.LOCAL_FILE,
        related_id=related_id,
    )


def test_filter_files_by_type(list_operator_node):
    """Only image/document files survive the type filter, in input order."""
    files = [
        _local_file("image1.jpg", FileType.IMAGE, "related1"),
        _local_file("document1.pdf", FileType.DOCUMENT, "related2"),
        _local_file("image2.png", FileType.IMAGE, "related3"),
        _local_file("audio1.mp3", FileType.AUDIO, "related4"),
    ]
    list_operator_node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(value=files)

    result = list_operator_node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    expected = [
        ("image1.jpg", FileType.IMAGE, "related1"),
        ("document1.pdf", FileType.DOCUMENT, "related2"),
        ("image2.png", FileType.IMAGE, "related3"),
    ]
    filtered = result.outputs["result"]
    # Guard against silent truncation: zip alone would pass even if the node
    # dropped or added files.
    assert len(filtered) == len(expected)
    for (filename, file_type, related_id), actual in zip(expected, filtered):
        assert actual.filename == filename
        assert actual.type == file_type
        assert actual.tenant_id == "tenant1"
        assert actual.transfer_method == FileTransferMethod.LOCAL_FILE
        assert actual.related_id == related_id
def test_init_question_classifier_node_data():
    """All fields round-trip when a full config (including vision) is given."""
    data = {
        "title": "test classifier node",
        "query_variable_selector": ["id", "name"],
        "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
        "classes": [{"id": "1", "name": "class 1"}],
        "instruction": "This is a test instruction",
        "memory": {
            "role_prefix": {"user": "Human:", "assistant": "AI:"},
            "window": {"enabled": True, "size": 5},
            "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
        },
        "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
    }

    node_data = QuestionClassifierNodeData(**data)

    assert node_data.query_variable_selector == ["id", "name"]
    assert node_data.model.provider == "openai"
    assert node_data.classes[0].id == "1"
    assert node_data.instruction == "This is a test instruction"
    assert node_data.memory is not None
    assert node_data.memory.role_prefix is not None
    assert node_data.memory.role_prefix.user == "Human:"
    assert node_data.memory.role_prefix.assistant == "AI:"
    # `is True` rather than `== True`: asserts an actual bool, not merely a
    # truthy value (flake8/ruff E712).
    assert node_data.memory.window.enabled is True
    assert node_data.memory.window.size == 5
    assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
    assert node_data.vision.enabled is True
    assert node_data.vision.configs.variable_selector == ["image"]
    assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW


def test_init_question_classifier_node_data_without_vision_config():
    """Omitting `vision` falls back to defaults: disabled, sys.files, HIGH detail."""
    data = {
        "title": "test classifier node",
        "query_variable_selector": ["id", "name"],
        "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
        "classes": [{"id": "1", "name": "class 1"}],
        "instruction": "This is a test instruction",
        "memory": {
            "role_prefix": {"user": "Human:", "assistant": "AI:"},
            "window": {"enabled": True, "size": 5},
            "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
        },
    }

    node_data = QuestionClassifierNodeData(**data)

    assert node_data.query_variable_selector == ["id", "name"]
    assert node_data.model.provider == "openai"
    assert node_data.classes[0].id == "1"
    assert node_data.instruction == "This is a test instruction"
    assert node_data.memory is not None
    assert node_data.memory.role_prefix is not None
    assert node_data.memory.role_prefix.user == "Human:"
    assert node_data.memory.role_prefix.assistant == "AI:"
    assert node_data.memory.window.enabled is True
    assert node_data.memory.window.size == 5
    assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
    # Defaults applied by the model when no vision config is supplied.
    assert node_data.vision.enabled is False
    assert node_data.vision.configs.variable_selector == ["sys", "files"]
    assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
"""Tests for VariablePool selector resolution, including file attributes."""

import pytest

from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool


@pytest.fixture
def pool():
    """An empty variable pool."""
    return VariablePool(system_variables={}, user_inputs={})


@pytest.fixture
def file():
    """A minimal local document file."""
    return File(
        tenant_id="test_tenant_id",
        type=FileType.DOCUMENT,
        transfer_method=FileTransferMethod.LOCAL_FILE,
        related_id="test_related_id",
        remote_url="test_url",
        filename="test_file.txt",
    )


def test_get_file_attribute(pool, file):
    """An extra selector part resolves to an attribute of a stored file."""
    pool.add(("node_1", "file_var"), FileSegment(value=file))

    name_segment = pool.get(("node_1", "file_var", "name"))
    assert name_segment is not None
    assert name_segment.value == file.filename

    # Unknown attributes resolve to nothing rather than raising.
    assert pool.get(("node_1", "file_var", "non_existent_attr")) is None


def test_use_long_selector(pool):
    """Selectors deeper than two parts store and load transparently."""
    selector = ("node_1", "part_1", "part_2")
    pool.add(selector, StringSegment(value="test_value"))

    loaded = pool.get(selector)
    assert loaded is not None
    assert loaded.value == "test_value"
"""Tests for extracting variable selectors out of a prompt template."""

from core.variables import SecretVariable
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser


def test_extract_selectors_from_template():
    """System, node and environment references are all recognised, in order."""
    pool = VariablePool(
        system_variables={SystemVariableKey("user_id"): "fake-user-id"},
        user_inputs={},
        environment_variables=[SecretVariable(name="secret_key", value="fake-secret-key")],
        conversation_variables=[],
    )
    pool.add(("node_id", "custom_query"), "fake-user-query")

    template = (
        "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
    )

    found = variable_template_parser.extract_selectors_from_template(template)

    assert found == [
        VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]),
        VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
        VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
    ]
"""In-process mock of the Volcengine TOS client for storage unit tests.

Fixes: the unused ``from typing import Union`` import is removed, and the
fixture no longer calls ``monkeypatch.undo()`` manually — the ``monkeypatch``
fixture reverses all patches at teardown on its own.
"""

import os
from unittest.mock import MagicMock

import pytest
from _pytest.monkeypatch import MonkeyPatch
from tos import TosClientV2
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput


class AttrDict(dict):
    """Dict that also exposes its keys as attributes (missing keys -> None)."""

    def __getattr__(self, item):
        return self.get(item)


def get_example_bucket() -> str:
    return "dify"


def get_example_filename() -> str:
    return "test.txt"


def get_example_data() -> bytes:
    return b"test"


def get_example_filepath() -> str:
    return "/test"


class MockVolcengineTosClass:
    """Stands in for TosClientV2; every method asserts the expected arguments."""

    def __init__(self, ak="", sk="", endpoint="", region=""):
        self.bucket_name = get_example_bucket()
        self.key = get_example_filename()
        self.content = get_example_data()
        self.filepath = get_example_filepath()
        # Minimal response fields the TOS output classes expect to parse.
        self.resp = AttrDict(
            {
                "x-tos-server-side-encryption": "kms",
                "x-tos-server-side-encryption-kms-key-id": "trn:kms:cn-beijing:****:keyrings/ring-test/keys/key-test",
                "x-tos-server-side-encryption-customer-algorithm": "AES256",
                "x-tos-version-id": "test",
                "x-tos-hash-crc64ecma": 123456,
                "request_id": "test",
                "headers": {
                    "x-tos-id-2": "test",
                    "ETag": "123456",
                },
                "status": 200,
            }
        )

    def put_object(self, bucket: str, key: str, content=None) -> PutObjectOutput:
        assert bucket == self.bucket_name
        assert key == self.key
        assert content == self.content
        return PutObjectOutput(self.resp)

    def get_object(self, bucket: str, key: str) -> GetObjectOutput:
        assert bucket == self.bucket_name
        assert key == self.key

        get_object_output = MagicMock(GetObjectOutput)
        get_object_output.read.return_value = self.content
        return get_object_output

    def get_object_to_file(self, bucket: str, key: str, file_path: str):
        assert bucket == self.bucket_name
        assert key == self.key
        assert file_path == self.filepath

    def head_object(self, bucket: str, key: str) -> HeadObjectOutput:
        assert bucket == self.bucket_name
        assert key == self.key
        return HeadObjectOutput(self.resp)

    def delete_object(self, bucket: str, key: str):
        assert bucket == self.bucket_name
        assert key == self.key
        return DeleteObjectOutput(self.resp)


# Mocking is opt-in so the same suite can also run against a real TOS endpoint.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_volcengine_tos_mock(monkeypatch: MonkeyPatch):
    """Patch TosClientV2 methods with the mock when MOCK_SWITCH is enabled."""
    if MOCK:
        monkeypatch.setattr(TosClientV2, "__init__", MockVolcengineTosClass.__init__)
        monkeypatch.setattr(TosClientV2, "put_object", MockVolcengineTosClass.put_object)
        monkeypatch.setattr(TosClientV2, "get_object", MockVolcengineTosClass.get_object)
        monkeypatch.setattr(TosClientV2, "get_object_to_file", MockVolcengineTosClass.get_object_to_file)
        monkeypatch.setattr(TosClientV2, "head_object", MockVolcengineTosClass.head_object)
        monkeypatch.setattr(TosClientV2, "delete_object", MockVolcengineTosClass.delete_object)

    yield
"""Unit tests for VolcengineTosStorage against the mocked TOS client.

Fixes: ``cls._instance == None`` replaced with the identity check
``is None`` (ruff E711) and the ``__new__`` branches collapsed; the unused
``flask`` and ``tos.clientv2`` output-class imports are removed.
"""

from collections.abc import Generator

from tos import TosClientV2

from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.volcengine_tos import (
    get_example_bucket,
    get_example_data,
    get_example_filename,
    get_example_filepath,
    setup_volcengine_tos_mock,
)


class VolcengineTosTest:
    """Lazily-initialised singleton sharing one storage client across tests."""

    _instance = None

    def __new__(cls):
        # Identity comparison for the sentinel: `is None`, never `== None`.
        if cls._instance is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self):
        # Note: __init__ re-runs on every instantiation of the singleton;
        # it only reassigns the same deterministic configuration.
        self.storage = VolcengineTosStorage()
        self.storage.bucket_name = get_example_bucket()
        self.storage.client = TosClientV2(
            ak="dify",
            sk="dify",
            endpoint="https://xxx.volces.com",
            region="cn-beijing",
        )


def test_save(setup_volcengine_tos_mock):
    """Saving under the expected key passes the mock's argument assertions."""
    volc_tos = VolcengineTosTest()
    volc_tos.storage.save(get_example_filename(), get_example_data())


def test_load_once(setup_volcengine_tos_mock):
    """Loading the object returns its full content in one call."""
    volc_tos = VolcengineTosTest()
    assert volc_tos.storage.load_once(get_example_filename()) == get_example_data()


def test_load_stream(setup_volcengine_tos_mock):
    """Streaming yields the object's content chunk by chunk."""
    volc_tos = VolcengineTosTest()
    generator = volc_tos.storage.load_stream(get_example_filename())
    assert isinstance(generator, Generator)
    assert next(generator) == get_example_data()


def test_download(setup_volcengine_tos_mock):
    """Downloading targets the expected local path (checked by the mock)."""
    volc_tos = VolcengineTosTest()
    volc_tos.storage.download(get_example_filename(), get_example_filepath())


def test_exists(setup_volcengine_tos_mock):
    """head_object on an existing key reports presence."""
    volc_tos = VolcengineTosTest()
    assert volc_tos.storage.exists(get_example_filename())


def test_delete(setup_volcengine_tos_mock):
    """Deleting the expected key passes the mock's argument assertions."""
    volc_tos = VolcengineTosTest()
    volc_tos.storage.delete(get_example_filename())
VolcengineTosTest() + volc_tos.storage.delete(get_example_filename()) diff --git a/web/__mocks__/mime.js b/web/__mocks__/mime.js new file mode 100644 index 0000000000..e69de29bb2 diff --git a/web/app/components/app/app-publisher/features-wrapper.tsx b/web/app/components/app/app-publisher/features-wrapper.tsx new file mode 100644 index 0000000000..dadd112135 --- /dev/null +++ b/web/app/components/app/app-publisher/features-wrapper.tsx @@ -0,0 +1,86 @@ +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import type { AppPublisherProps } from '@/app/components/app/app-publisher' +import Confirm from '@/app/components/base/confirm' +import AppPublisher from '@/app/components/app/app-publisher' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { ModelAndParameter } from '@/app/components/app/configuration/debug/types' +import type { FileUpload } from '@/app/components/base/features/types' +import { Resolution } from '@/types/app' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' + +type Props = Omit & { + onPublish?: (modelAndParameter?: ModelAndParameter, features?: any) => Promise | any + publishedConfig?: any + resetAppConfig?: () => void +} + +const FeaturesWrappedAppPublisher = (props: Props) => { + const { t } = useTranslation() + const features = useFeatures(s => s.features) + const featuresStore = useFeaturesStore() + const [restoreConfirmOpen, setRestoreConfirmOpen] = useState(false) + const handleConfirm = useCallback(() => { + props.resetAppConfig?.() + const { + features, + setFeatures, + } = featuresStore!.getState() + const newFeatures = produce(features, (draft) => { + draft.moreLikeThis = props.publishedConfig.modelConfig.more_like_this || { enabled: false } + draft.opening = { + enabled: 
!!props.publishedConfig.modelConfig.opening_statement, + opening_statement: props.publishedConfig.modelConfig.opening_statement || '', + suggested_questions: props.publishedConfig.modelConfig.suggested_questions || [], + } + draft.moderation = props.publishedConfig.modelConfig.sensitive_word_avoidance || { enabled: false } + draft.speech2text = props.publishedConfig.modelConfig.speech_to_text || { enabled: false } + draft.text2speech = props.publishedConfig.modelConfig.text_to_speech || { enabled: false } + draft.suggested = props.publishedConfig.modelConfig.suggested_questions_after_answer || { enabled: false } + draft.citation = props.publishedConfig.modelConfig.retriever_resource || { enabled: false } + draft.annotationReply = props.publishedConfig.modelConfig.annotation_reply || { enabled: false } + draft.file = { + image: { + detail: props.publishedConfig.modelConfig.file_upload?.image?.detail || Resolution.high, + enabled: !!props.publishedConfig.modelConfig.file_upload?.image?.enabled, + number_limits: props.publishedConfig.modelConfig.file_upload?.image?.number_limits || 3, + transfer_methods: props.publishedConfig.modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(props.publishedConfig.modelConfig.file_upload?.enabled || props.publishedConfig.modelConfig.file_upload?.image?.enabled), + allowed_file_types: props.publishedConfig.modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], + allowed_file_extensions: props.publishedConfig.modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), + allowed_file_upload_methods: props.publishedConfig.modelConfig.file_upload?.allowed_file_upload_methods || props.publishedConfig.modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + number_limits: props.publishedConfig.modelConfig.file_upload?.number_limits || 
props.publishedConfig.modelConfig.file_upload?.image?.number_limits || 3, + } as FileUpload + }) + setFeatures(newFeatures) + setRestoreConfirmOpen(false) + }, [featuresStore, props]) + + const handlePublish = useCallback((modelAndParameter?: ModelAndParameter) => { + return props.onPublish?.(modelAndParameter, features) + }, [features, props]) + + return ( + <> + setRestoreConfirmOpen(true), + }}/> + {restoreConfirmOpen && ( + setRestoreConfirmOpen(false)} + /> + )} + + ) +} + +export default FeaturesWrappedAppPublisher diff --git a/web/app/components/app/configuration/config-var/select-type-item/style.module.css b/web/app/components/app/configuration/config-var/select-type-item/style.module.css deleted file mode 100644 index 8ff716d58b..0000000000 --- a/web/app/components/app/configuration/config-var/select-type-item/style.module.css +++ /dev/null @@ -1,40 +0,0 @@ -.item { - display: flex; - flex-direction: column; - justify-content: center; - align-items: center; - height: 58px; - width: 98px; - border-radius: 8px; - border: 1px solid #EAECF0; - box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); - background-color: #fff; - cursor: pointer; -} - -.item:not(.selected):hover { - border-color: #B2CCFF; - background-color: #F5F8FF; - box-shadow: 0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06); -} - -.item.selected { - color: #155EEF; - border-color: #528BFF; - background-color: #F5F8FF; - box-shadow: 0px 1px 3px rgba(16, 24, 40, 0.1), 0px 1px 2px rgba(16, 24, 40, 0.06); -} - -.text { - font-size: 13px; - color: #667085; - font-weight: 500; -} - -.item.selected .text { - color: #155EEF; -} - -.item:not(.selected):hover { - color: #344054; -} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-vision/radio-group/index.tsx b/web/app/components/app/configuration/config-vision/radio-group/index.tsx deleted file mode 100644 index a1cfb06e6a..0000000000 --- 
a/web/app/components/app/configuration/config-vision/radio-group/index.tsx +++ /dev/null @@ -1,40 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import s from './style.module.css' -import cn from '@/utils/classnames' - -type OPTION = { - label: string - value: any -} - -type Props = { - className?: string - options: OPTION[] - value: any - onChange: (value: any) => void -} - -const RadioGroup: FC = ({ - className = '', - options, - value, - onChange, -}) => { - return ( -
- {options.map(item => ( -
onChange(item.value)} - > -
-
{item.label}
-
- ))} -
- ) -} -export default React.memo(RadioGroup) diff --git a/web/app/components/app/configuration/config-vision/radio-group/style.module.css b/web/app/components/app/configuration/config-vision/radio-group/style.module.css deleted file mode 100644 index 22c29c6a42..0000000000 --- a/web/app/components/app/configuration/config-vision/radio-group/style.module.css +++ /dev/null @@ -1,24 +0,0 @@ -.item { - @apply grow flex items-center h-8 px-2.5 rounded-lg bg-gray-25 border border-gray-100 cursor-pointer space-x-2; -} - -.item:hover { - background-color: #ffffff; - border-color: #B2CCFF; - box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); -} - -.item.checked { - background-color: #ffffff; - border-color: #528BFF; - box-shadow: 0px 1px 2px 0px rgba(16, 24, 40, 0.06), 0px 1px 3px 0px rgba(16, 24, 40, 0.10); -} - -.radio { - @apply w-4 h-4 border-[2px] border-gray-200 rounded-full; -} - -.item.checked .radio { - border-width: 5px; - border-color: #155eef; -} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-voice/param-config-content.tsx b/web/app/components/app/configuration/config-voice/param-config-content.tsx deleted file mode 100644 index 4e70bdda21..0000000000 --- a/web/app/components/app/configuration/config-voice/param-config-content.tsx +++ /dev/null @@ -1,220 +0,0 @@ -'use client' -import useSWR from 'swr' -import type { FC } from 'react' -import { useContext } from 'use-context-selector' -import React, { Fragment } from 'react' -import { usePathname } from 'next/navigation' -import { useTranslation } from 'react-i18next' -import { Listbox, Transition } from '@headlessui/react' -import { CheckIcon, ChevronDownIcon } from '@heroicons/react/20/solid' -import classNames from '@/utils/classnames' -import RadioGroup from '@/app/components/app/configuration/config-vision/radio-group' -import type { Item } from '@/app/components/base/select' -import ConfigContext from 
'@/context/debug-configuration' -import { fetchAppVoices } from '@/service/apps' -import Tooltip from '@/app/components/base/tooltip' -import { languages } from '@/i18n/language' -import { TtsAutoPlay } from '@/types/app' -const VoiceParamConfig: FC = () => { - const { t } = useTranslation() - const pathname = usePathname() - const matched = pathname.match(/\/app\/([^/]+)/) - const appId = (matched?.length && matched[1]) ? matched[1] : '' - - const { - textToSpeechConfig, - setTextToSpeechConfig, - } = useContext(ConfigContext) - - let languageItem = languages.find(item => item.value === textToSpeechConfig.language) - const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') - if (languages && !languageItem && languages.length > 0) - languageItem = languages[0] - const language = languageItem?.value - const voiceItems = useSWR({ appId, language }, fetchAppVoices).data - let voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) - if (voiceItems && !voiceItem && voiceItems.length > 0) - voiceItem = voiceItems[0] - - const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') - - return ( -
-
-
{t('appDebug.voice.voiceSettings.title')}
-
-
-
-
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
- } - /> -
- { - setTextToSpeechConfig({ - ...textToSpeechConfig, - language: String(value.value), - }) - }} - > -
- - - {languageItem?.name ? t(`common.voice.language.${languageItem?.value.replace('-', '')}`) : localLanguagePlaceholder} - - - - - - - - {languages.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {t(`common.voice.language.${(item.value).toString().replace('-', '')}`)} - {(selected || item.value === textToSpeechConfig.language) && ( - - - )} - - )} - - ))} - - -
-
-
-
-
{t('appDebug.voice.voiceSettings.voice')}
- { - if (!value.value) - return - setTextToSpeechConfig({ - ...textToSpeechConfig, - voice: String(value.value), - }) - }} - > -
- - {voiceItem?.name ?? localVoicePlaceholder} - - - - - - - {voiceItems?.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {item.name} - {(selected || item.value === textToSpeechConfig.voice) && ( - - - )} - - )} - - ))} - - -
-
-
-
-
{t('appDebug.voice.voiceSettings.autoPlay')}
- { - setTextToSpeechConfig({ - ...textToSpeechConfig, - autoPlay: value, - }) - }} - /> -
-
-
- - ) -} - -export default React.memo(VoiceParamConfig) diff --git a/web/app/components/app/configuration/config-voice/param-config.tsx b/web/app/components/app/configuration/config-voice/param-config.tsx deleted file mode 100644 index f1e2475495..0000000000 --- a/web/app/components/app/configuration/config-voice/param-config.tsx +++ /dev/null @@ -1,41 +0,0 @@ -'use client' -import type { FC } from 'react' -import { memo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import VoiceParamConfig from './param-config-content' -import cn from '@/utils/classnames' -import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' - -const ParamsConfig: FC = () => { - const { t } = useTranslation() - const [open, setOpen] = useState(false) - - return ( - - setOpen(v => !v)}> -
- -
{t('appDebug.voice.settings')}
-
-
- -
- -
-
-
- ) -} -export default memo(ParamsConfig) diff --git a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx new file mode 100644 index 0000000000..b63e3e2693 --- /dev/null +++ b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx @@ -0,0 +1,220 @@ +import type { FC } from 'react' +import React from 'react' +import cn from 'classnames' +import useBoolean from 'ahooks/lib/useBoolean' +import { useTranslation } from 'react-i18next' +import ConfigPrompt from '../../config-prompt' +import { languageMap } from '../../../../workflow/nodes/_base/components/editor/code-editor/index' +import { generateRuleCode } from '@/service/debug' +import type { CodeGenRes } from '@/service/debug' +import { type AppType, type Model, ModelModeType } from '@/types/app' +import Modal from '@/app/components/base/modal' +import Button from '@/app/components/base/button' +import { Generator } from '@/app/components/base/icons/src/vender/other' +import Toast from '@/app/components/base/toast' +import Loading from '@/app/components/base/loading' +import Confirm from '@/app/components/base/confirm' +import type { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' +import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name' +export type IGetCodeGeneratorResProps = { + mode: AppType + isShow: boolean + codeLanguages: CodeLanguage + onClose: () => void + onFinished: (res: CodeGenRes) => void +} + +export const GetCodeGeneratorResModal: FC = ( + { + mode, + isShow, 
+ codeLanguages, + onClose, + onFinished, + }, +) => { + const { + currentProvider, + currentModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) + const { t } = useTranslation() + const [instruction, setInstruction] = React.useState('') + const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) + const [res, setRes] = React.useState(null) + const isValid = () => { + if (instruction.trim() === '') { + Toast.notify({ + type: 'error', + message: t('common.errorMsg.fieldRequired', { + field: t('appDebug.code.instruction'), + }), + }) + return false + } + return true + } + const model: Model = { + provider: currentProvider?.provider || '', + name: currentModel?.model || '', + mode: ModelModeType.chat, + // This is a fixed parameter + completion_params: { + temperature: 0.7, + max_tokens: 0, + top_p: 0, + echo: false, + stop: [], + presence_penalty: 0, + frequency_penalty: 0, + }, + } + const isInLLMNode = true + const onGenerate = async () => { + if (!isValid()) + return + if (isLoading) + return + setLoadingTrue() + try { + const { error, ...res } = await generateRuleCode({ + instruction, + model_config: model, + no_variable: !!isInLLMNode, + code_language: languageMap[codeLanguages] || 'javascript', + }) + setRes(res) + if (error) { + Toast.notify({ + type: 'error', + message: error, + }) + } + } + finally { + setLoadingFalse() + } + } + const [showConfirmOverwrite, setShowConfirmOverwrite] = React.useState(false) + + const renderLoading = ( +
+ +
{t('appDebug.codegen.loading')}
+
+ ) + + return ( + +
+
+
+
{t('appDebug.codegen.title')}
+
{t('appDebug.codegen.description')}
+
+
+ + +
+
+
+
{t('appDebug.codegen.instruction')}
+ -
- ) - : ( -
- )} - {renderQuestions()} - ) : ( -
{t('appDebug.openingStatement.noDataPlaceHolder')}
- )} - - {isShowConfirmAddVar && ( - - )} - -
- - ) -} -export default React.memo(OpeningStatement) diff --git a/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx b/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx deleted file mode 100644 index 2e08a99122..0000000000 --- a/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx +++ /dev/null @@ -1,38 +0,0 @@ -import ReactSlider from 'react-slider' -import s from './style.module.css' -import cn from '@/utils/classnames' - -type ISliderProps = { - className?: string - value: number - max?: number - min?: number - step?: number - disabled?: boolean - onChange: (value: number) => void -} - -const Slider: React.FC = ({ className, max, min, step, value, disabled, onChange }) => { - return ( -
-
-
- {(state.valueNow / 100).toFixed(2)} -
-
-
- )} - /> -} - -export default Slider diff --git a/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css b/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css deleted file mode 100644 index 4e93b39563..0000000000 --- a/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css +++ /dev/null @@ -1,20 +0,0 @@ -.slider { - position: relative; -} - -.slider.disabled { - opacity: 0.6; -} - -.slider-thumb:focus { - outline: none; -} - -.slider-track { - background-color: #528BFF; - height: 2px; -} - -.slider-track-1 { - background-color: #E5E7EB; -} \ No newline at end of file diff --git a/web/app/components/base/features/feature-panel/speech-to-text/index.tsx b/web/app/components/base/features/feature-panel/speech-to-text/index.tsx deleted file mode 100644 index 2e5e3de439..0000000000 --- a/web/app/components/base/features/feature-panel/speech-to-text/index.tsx +++ /dev/null @@ -1,22 +0,0 @@ -'use client' -import React, { type FC } from 'react' -import { useTranslation } from 'react-i18next' -import { Microphone01 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' - -const SpeechToTextConfig: FC = () => { - const { t } = useTranslation() - - return ( -
-
- -
-
-
{t('appDebug.feature.speechToText.title')}
-
-
-
{t('appDebug.feature.speechToText.resDes')}
-
- ) -} -export default React.memo(SpeechToTextConfig) diff --git a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx deleted file mode 100644 index e6d0b6e7e0..0000000000 --- a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx +++ /dev/null @@ -1,25 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import { useTranslation } from 'react-i18next' -import { MessageSmileSquare } from '@/app/components/base/icons/src/vender/solid/communication' -import Tooltip from '@/app/components/base/tooltip' - -const SuggestedQuestionsAfterAnswer: FC = () => { - const { t } = useTranslation() - - return ( -
-
- -
-
-
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
- -
-
-
{t('appDebug.feature.suggestedQuestionsAfterAnswer.resDes')}
-
- ) -} -export default React.memo(SuggestedQuestionsAfterAnswer) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/index.tsx b/web/app/components/base/features/feature-panel/text-to-speech/index.tsx deleted file mode 100644 index 2480a19077..0000000000 --- a/web/app/components/base/features/feature-panel/text-to-speech/index.tsx +++ /dev/null @@ -1,62 +0,0 @@ -'use client' -import useSWR from 'swr' -import React from 'react' -import { useTranslation } from 'react-i18next' -import { usePathname } from 'next/navigation' -import { useFeatures } from '../../hooks' -import type { OnFeaturesChange } from '../../types' -import ParamsConfig from './params-config' -import { Speaker } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' -import { languages } from '@/i18n/language' -import { fetchAppVoices } from '@/service/apps' -import AudioBtn from '@/app/components/base/audio-btn' - -type TextToSpeechProps = { - onChange?: OnFeaturesChange - disabled?: boolean -} -const TextToSpeech = ({ - onChange, - disabled, -}: TextToSpeechProps) => { - const { t } = useTranslation() - const textToSpeech = useFeatures(s => s.features.text2speech) - - const pathname = usePathname() - const matched = pathname.match(/\/app\/([^/]+)/) - const appId = (matched?.length && matched[1]) ? matched[1] : '' - const language = textToSpeech?.language - const languageInfo = languages.find(i => i.value === textToSpeech?.language) - - const voiceItems = useSWR({ appId, language }, fetchAppVoices).data - const voiceItem = voiceItems?.find(item => item.value === textToSpeech?.voice) - - return ( -
-
- -
-
- {t('appDebug.feature.textToSpeech.title')} -
-
-
-
- {languageInfo && (`${languageInfo?.name} - `)}{voiceItem?.name ?? t('appDebug.voice.defaultDisplay')} - { languageInfo?.example && ( - - )} -
-
- -
-
- ) -} -export default React.memo(TextToSpeech) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx deleted file mode 100644 index e923d9a333..0000000000 --- a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx +++ /dev/null @@ -1,241 +0,0 @@ -'use client' -import useSWR from 'swr' -import produce from 'immer' -import React, { Fragment } from 'react' -import { usePathname } from 'next/navigation' -import { useTranslation } from 'react-i18next' -import { Listbox, Transition } from '@headlessui/react' -import { CheckIcon, ChevronDownIcon } from '@heroicons/react/20/solid' -import { - useFeatures, - useFeaturesStore, -} from '../../hooks' -import type { OnFeaturesChange } from '../../types' -import classNames from '@/utils/classnames' -import type { Item } from '@/app/components/base/select' -import { fetchAppVoices } from '@/service/apps' -import Tooltip from '@/app/components/base/tooltip' -import { languages } from '@/i18n/language' -import RadioGroup from '@/app/components/app/configuration/config-vision/radio-group' -import { TtsAutoPlay } from '@/types/app' - -type VoiceParamConfigProps = { - onChange?: OnFeaturesChange -} -const VoiceParamConfig = ({ - onChange, -}: VoiceParamConfigProps) => { - const { t } = useTranslation() - const pathname = usePathname() - const matched = pathname.match(/\/app\/([^/]+)/) - const appId = (matched?.length && matched[1]) ? 
matched[1] : '' - const text2speech = useFeatures(state => state.features.text2speech) - const featuresStore = useFeaturesStore() - - let languageItem = languages.find(item => item.value === text2speech?.language) - if (languages && !languageItem) - languageItem = languages[0] - const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') - - const language = languageItem?.value - const voiceItems = useSWR({ appId, language }, fetchAppVoices).data - let voiceItem = voiceItems?.find(item => item.value === text2speech?.voice) - if (voiceItems && !voiceItem) - voiceItem = voiceItems[0] - const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') - - const handleChange = (value: Record) => { - const { - features, - setFeatures, - } = featuresStore!.getState() - - const newFeatures = produce(features, (draft) => { - draft.text2speech = { - ...draft.text2speech, - ...value, - } - }) - - setFeatures(newFeatures) - if (onChange) - onChange(newFeatures) - } - - return ( -
-
-
{t('appDebug.voice.voiceSettings.title')}
-
-
-
-
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item} -
- ))} -
- } - /> -
- { - handleChange({ - language: String(value.value), - }) - }} - > -
- - - {languageItem?.name ? t(`common.voice.language.${languageItem?.value.replace('-', '')}`) : localLanguagePlaceholder} - - - - - - - - {languages.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {t(`common.voice.language.${(item.value).toString().replace('-', '')}`)} - {(selected || item.value === text2speech?.language) && ( - - - )} - - )} - - ))} - - -
-
-
- -
-
{t('appDebug.voice.voiceSettings.voice')}
- { - handleChange({ - voice: String(value.value), - }) - }} - > -
- - {voiceItem?.name ?? localVoicePlaceholder} - - - - - - - {voiceItems?.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {item.name} - {(selected || item.value === text2speech?.voice) && ( - - - )} - - )} - - ))} - - -
-
-
-
-
{t('appDebug.voice.voiceSettings.autoPlay')}
- { - handleChange({ - autoPlay: value, - }) - }} - /> -
-
-
-
- ) -} - -export default React.memo(VoiceParamConfig) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx b/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx deleted file mode 100644 index 095fd6cce8..0000000000 --- a/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx +++ /dev/null @@ -1,48 +0,0 @@ -'use client' -import { memo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import type { OnFeaturesChange } from '../../types' -import ParamConfigContent from './param-config-content' -import cn from '@/utils/classnames' -import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' - -type ParamsConfigProps = { - onChange?: OnFeaturesChange - disabled?: boolean -} -const ParamsConfig = ({ - onChange, - disabled, -}: ParamsConfigProps) => { - const { t } = useTranslation() - const [open, setOpen] = useState(false) - - return ( - - !disabled && setOpen(v => !v)}> -
- -
{t('appDebug.voice.settings')}
-
-
- -
- -
-
-
- ) -} -export default memo(ParamsConfig) diff --git a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-btn/index.tsx similarity index 100% rename from web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx rename to web/app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-btn/index.tsx diff --git a/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx similarity index 99% rename from web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx rename to web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx index b660977d08..801f1348ee 100644 --- a/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' -import ScoreSlider from '../score-slider' +import ScoreSlider from './score-slider' import { Item } from './config-param' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/config-param.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param.tsx new file mode 100644 index 0000000000..8b3a0af240 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param.tsx @@ -0,0 +1,24 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Tooltip from '@/app/components/base/tooltip' + +export const Item: FC<{ title: string; tooltip: string; children: 
JSX.Element }> = ({ + title, + tooltip, + children, +}) => { + return ( +
+
+
{title}
+ {tooltip}
+ } + /> +
+
{children}
+
+ ) +} diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx new file mode 100644 index 0000000000..f44aab5b9c --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx @@ -0,0 +1,152 @@ +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { usePathname, useRouter } from 'next/navigation' +import produce from 'immer' +import { RiEqualizer2Line, RiExternalLinkLine } from '@remixicon/react' +import { MessageFast } from '@/app/components/base/icons/src/vender/features' +import FeatureCard from '@/app/components/base/features/new-feature-panel/feature-card' +import Button from '@/app/components/base/button' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { OnFeaturesChange } from '@/app/components/base/features/types' +import useAnnotationConfig from '@/app/components/base/features/new-feature-panel/annotation-reply/use-annotation-config' +import ConfigParamModal from '@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal' +import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' +import { ANNOTATION_DEFAULT } from '@/config' + +type Props = { + disabled?: boolean + onChange?: OnFeaturesChange +} + +const AnnotationReply = ({ + disabled, + onChange, +}: Props) => { + const { t } = useTranslation() + const router = useRouter() + const pathname = usePathname() + const matched = pathname.match(/\/app\/([^/]+)/) + const appId = (matched?.length && matched[1]) ? 
matched[1] : '' + const featuresStore = useFeaturesStore() + const annotationReply = useFeatures(s => s.features.annotationReply) + + const updateAnnotationReply = useCallback((newConfig: any) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + const newFeatures = produce(features, (draft) => { + draft.annotationReply = newConfig + }) + setFeatures(newFeatures) + if (onChange) + onChange(newFeatures) + }, [featuresStore, onChange]) + + const { + handleEnableAnnotation, + handleDisableAnnotation, + isShowAnnotationConfigInit, + setIsShowAnnotationConfigInit, + isShowAnnotationFullModal, + setIsShowAnnotationFullModal, + } = useAnnotationConfig({ + appId, + annotationConfig: annotationReply as any || { + id: '', + enabled: false, + score_threshold: ANNOTATION_DEFAULT.score_threshold, + embedding_model: { + embedding_provider_name: '', + embedding_model_name: '', + }, + }, + setAnnotationConfig: updateAnnotationReply, + }) + + const handleSwitch = useCallback((enabled: boolean) => { + if (enabled) + setIsShowAnnotationConfigInit(true) + else + handleDisableAnnotation(annotationReply?.embedding_model as any) + }, [annotationReply?.embedding_model, handleDisableAnnotation, setIsShowAnnotationConfigInit]) + + const [isHovering, setIsHovering] = useState(false) + + return ( + <> + + + + } + title={t('appDebug.feature.annotation.title')} + value={!!annotationReply?.enabled} + onChange={state => handleSwitch(state)} + onMouseEnter={() => setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + disabled={disabled} + > + <> + {!annotationReply?.enabled && ( +
{t('appDebug.feature.annotation.description')}
+ )} + {!!annotationReply?.enabled && ( + <> + {!isHovering && ( +
+
+
{t('appDebug.feature.annotation.scoreThreshold.title')}
+
{annotationReply.score_threshold || '-'}
+
+
+
+
{t('common.modelProvider.embeddingModel.key')}
+
{annotationReply.embedding_model?.embedding_model_name}
+
+
+ )} + {isHovering && ( +
+ + +
+ )} + + )} + +
+ { + setIsShowAnnotationConfigInit(false) + // showChooseFeatureTrue() + }} + onSave={async (embeddingModel, score) => { + await handleEnableAnnotation(embeddingModel, score) + setIsShowAnnotationConfigInit(false) + }} + annotationConfig={annotationReply as any} + /> + {isShowAnnotationFullModal && ( + setIsShowAnnotationFullModal(false)} + /> + )} + + ) +} + +export default AnnotationReply diff --git a/web/app/components/app/configuration/toolbox/score-slider/base-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx similarity index 100% rename from web/app/components/app/configuration/toolbox/score-slider/base-slider/index.tsx rename to web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx diff --git a/web/app/components/app/configuration/toolbox/score-slider/base-slider/style.module.css b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css similarity index 100% rename from web/app/components/app/configuration/toolbox/score-slider/base-slider/style.module.css rename to web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css diff --git a/web/app/components/base/features/feature-panel/score-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx similarity index 90% rename from web/app/components/base/features/feature-panel/score-slider/index.tsx rename to web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx index 9826cbadcf..d68db9be73 100644 --- a/web/app/components/base/features/feature-panel/score-slider/index.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import Slider 
from '@/app/components/app/configuration/toolbox/score-slider/base-slider' +import Slider from '@/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider' type Props = { className?: string diff --git a/web/app/components/app/configuration/toolbox/annotation/type.ts b/web/app/components/base/features/new-feature-panel/annotation-reply/type.ts similarity index 100% rename from web/app/components/app/configuration/toolbox/annotation/type.ts rename to web/app/components/base/features/new-feature-panel/annotation-reply/type.ts diff --git a/web/app/components/app/configuration/toolbox/annotation/use-annotation-config.ts b/web/app/components/base/features/new-feature-panel/annotation-reply/use-annotation-config.ts similarity index 100% rename from web/app/components/app/configuration/toolbox/annotation/use-annotation-config.ts rename to web/app/components/base/features/new-feature-panel/annotation-reply/use-annotation-config.ts diff --git a/web/app/components/base/features/new-feature-panel/citation.tsx b/web/app/components/base/features/new-feature-panel/citation.tsx new file mode 100644 index 0000000000..a0b702e9f9 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/citation.tsx @@ -0,0 +1,56 @@ +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import { Citations } from '@/app/components/base/icons/src/vender/features' +import FeatureCard from '@/app/components/base/features/new-feature-panel/feature-card' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { OnFeaturesChange } from '@/app/components/base/features/types' +import { FeatureEnum } from '@/app/components/base/features/types' + +type Props = { + disabled?: boolean + onChange?: OnFeaturesChange +} + +const Citation = ({ + disabled, + onChange, +}: Props) => { + const { t } = useTranslation() + const features = useFeatures(s => s.features) + 
const featuresStore = useFeaturesStore() + + const handleChange = useCallback((type: FeatureEnum, enabled: boolean) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + draft[type] = { + ...draft[type], + enabled, + } + }) + setFeatures(newFeatures) + if (onChange) + onChange(newFeatures) + }, [featuresStore, onChange]) + + return ( + + + + } + title={t('appDebug.feature.citation.title')} + value={!!features.citation?.enabled} + description={t('appDebug.feature.citation.description')!} + onChange={state => handleChange(FeatureEnum.citation, state)} + disabled={disabled} + /> + ) +} + +export default Citation diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx new file mode 100644 index 0000000000..ab6b3ec6db --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx @@ -0,0 +1,119 @@ +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import { RiEditLine } from '@remixicon/react' +import { LoveMessage } from '@/app/components/base/icons/src/vender/features' +import FeatureCard from '@/app/components/base/features/new-feature-panel/feature-card' +import Button from '@/app/components/base/button' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { OnFeaturesChange } from '@/app/components/base/features/types' +import { FeatureEnum } from '@/app/components/base/features/types' +import { useModalContext } from '@/context/modal-context' +import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' + +type Props = { + disabled?: boolean + onChange?: OnFeaturesChange + promptVariables?: PromptVariable[] + workflowVariables?: InputVar[] + 
onAutoAddPromptVariable?: (variable: PromptVariable[]) => void +} + +const ConversationOpener = ({ + disabled, + onChange, + promptVariables, + workflowVariables, + onAutoAddPromptVariable, +}: Props) => { + const { t } = useTranslation() + const { setShowOpeningModal } = useModalContext() + const opening = useFeatures(s => s.features.opening) + const featuresStore = useFeaturesStore() + const [isHovering, setIsHovering] = useState(false) + const handleOpenOpeningModal = useCallback(() => { + if (disabled) + return + const { + features, + setFeatures, + } = featuresStore!.getState() + setShowOpeningModal({ + payload: { + ...opening, + promptVariables, + workflowVariables, + onAutoAddPromptVariable, + }, + onSaveCallback: (newOpening) => { + const newFeatures = produce(features, (draft) => { + draft.opening = newOpening + }) + setFeatures(newFeatures) + if (onChange) + onChange() + }, + onCancelCallback: () => { + if (onChange) + onChange() + }, + }) + }, [disabled, featuresStore, onAutoAddPromptVariable, onChange, opening, promptVariables, setShowOpeningModal]) + + const handleChange = useCallback((type: FeatureEnum, enabled: boolean) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + draft[type] = { + ...draft[type], + enabled, + } + }) + setFeatures(newFeatures) + if (onChange) + onChange() + }, [featuresStore, onChange]) + + return ( + + + + } + title={t('appDebug.feature.conversationOpener.title')} + value={!!opening?.enabled} + onChange={state => handleChange(FeatureEnum.opening, state)} + onMouseEnter={() => setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + disabled={disabled} + > + <> + {!opening?.enabled && ( +
{t('appDebug.feature.conversationOpener.description')}
+ )} + {!!opening?.enabled && ( + <> + {!isHovering && ( +
+ {opening.opening_statement || t('appDebug.openingStatement.placeholder')} +
+ )} + {isHovering && ( + + )} + + )} + +
+ ) +} + +export default ConversationOpener diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx new file mode 100644 index 0000000000..9f25d0fa11 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -0,0 +1,206 @@ +import React, { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useBoolean } from 'ahooks' +import produce from 'immer' +import { ReactSortable } from 'react-sortablejs' +import { RiAddLine, RiAsterisk, RiCloseLine, RiDeleteBinLine, RiDraggable } from '@remixicon/react' +import Modal from '@/app/components/base/modal' +import Button from '@/app/components/base/button' +import ConfirmAddVar from '@/app/components/app/configuration/config-prompt/confirm-add-var' +import type { OpeningStatement } from '@/app/components/base/features/types' +import { getInputKeys } from '@/app/components/base/block-input' +import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' +import { getNewVar } from '@/utils/var' + +type OpeningSettingModalProps = { + data: OpeningStatement + onSave: (newState: OpeningStatement) => void + onCancel: () => void + promptVariables?: PromptVariable[] + workflowVariables?: InputVar[] + onAutoAddPromptVariable?: (variable: PromptVariable[]) => void +} + +const MAX_QUESTION_NUM = 5 + +const OpeningSettingModal = ({ + data, + onSave, + onCancel, + promptVariables = [], + workflowVariables = [], + onAutoAddPromptVariable, +}: OpeningSettingModalProps) => { + const { t } = useTranslation() + const [tempValue, setTempValue] = useState(data?.opening_statement || '') + useEffect(() => { + setTempValue(data.opening_statement || '') + }, [data.opening_statement]) + const [tempSuggestedQuestions, setTempSuggestedQuestions] = useState(data.suggested_questions || 
[]) + const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) + const [notIncludeKeys, setNotIncludeKeys] = useState([]) + + const handleSave = useCallback((ignoreVariablesCheck?: boolean) => { + if (!ignoreVariablesCheck) { + const keys = getInputKeys(tempValue) + const promptKeys = promptVariables.map(item => item.key) + const workflowVariableKeys = workflowVariables.map(item => item.variable) + let notIncludeKeys: string[] = [] + + if (promptKeys.length === 0 && workflowVariables.length === 0) { + if (keys.length > 0) + notIncludeKeys = keys + } + else { + if (workflowVariables.length > 0) + notIncludeKeys = keys.filter(key => !workflowVariableKeys.includes(key)) + else notIncludeKeys = keys.filter(key => !promptKeys.includes(key)) + } + + if (notIncludeKeys.length > 0) { + setNotIncludeKeys(notIncludeKeys) + showConfirmAddVar() + return + } + } + const newOpening = produce(data, (draft) => { + if (draft) { + draft.opening_statement = tempValue + draft.suggested_questions = tempSuggestedQuestions + } + }) + onSave(newOpening) + }, [data, onSave, promptVariables, workflowVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue]) + + const cancelAutoAddVar = useCallback(() => { + hideConfirmAddVar() + handleSave(true) + }, [handleSave, hideConfirmAddVar]) + + const autoAddVar = useCallback(() => { + onAutoAddPromptVariable?.([ + ...notIncludeKeys.map(key => getNewVar(key, 'string')), + ]) + hideConfirmAddVar() + handleSave(true) + }, [handleSave, hideConfirmAddVar, notIncludeKeys, onAutoAddPromptVariable]) + + const renderQuestions = () => { + return ( +
+
+
+
{t('appDebug.openingStatement.openingQuestion')}
+
·
+
{tempSuggestedQuestions.length}/{MAX_QUESTION_NUM}
+
+
+
+ { + return { + id: index, + name, + } + })} + setList={list => setTempSuggestedQuestions(list.map(item => item.name))} + handle='.handle' + ghostClass="opacity-50" + animation={150} + > + {tempSuggestedQuestions.map((question, index) => { + return ( +
+ + { + const value = e.target.value + setTempSuggestedQuestions(tempSuggestedQuestions.map((item, i) => { + if (index === i) + return value + + return item + })) + }} + className={'w-full overflow-x-auto pl-1.5 pr-8 text-sm leading-9 text-gray-900 border-0 grow h-9 bg-transparent focus:outline-none cursor-pointer rounded-lg'} + /> + +
{ + setTempSuggestedQuestions(tempSuggestedQuestions.filter((_, i) => index !== i)) + }} + > + +
+
+ ) + })}
+ {tempSuggestedQuestions.length < MAX_QUESTION_NUM && ( +
{ setTempSuggestedQuestions([...tempSuggestedQuestions, '']) }} + className='mt-1 flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100 hover:bg-gray-200'> + +
{t('appDebug.variableConfig.addOption')}
+
+ )} +
+ ) + } + + return ( + { }} + className='!p-6 !mt-14 !max-w-none !w-[640px] !bg-components-panel-bg-blur' + > +
+
{t('appDebug.feature.conversationOpener.title')}
+
+
+
+
+ +
+
+ + ) + }, +) +Textarea.displayName = 'Textarea' + +export default Textarea +export { Textarea, textareaVariants } diff --git a/web/app/components/signin/countdown.tsx b/web/app/components/signin/countdown.tsx new file mode 100644 index 0000000000..6282480d10 --- /dev/null +++ b/web/app/components/signin/countdown.tsx @@ -0,0 +1,41 @@ +'use client' +import { useCountDown } from 'ahooks' +import { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' + +export const COUNT_DOWN_TIME_MS = 59000 +export const COUNT_DOWN_KEY = 'leftTime' + +type CountdownProps = { + onResend?: () => void +} + +export default function Countdown({ onResend }: CountdownProps) { + const { t } = useTranslation() + const [leftTime, setLeftTime] = useState(Number(localStorage.getItem(COUNT_DOWN_KEY) || COUNT_DOWN_TIME_MS)) + const [time] = useCountDown({ + leftTime, + onEnd: () => { + setLeftTime(0) + localStorage.removeItem(COUNT_DOWN_KEY) + }, + }) + + const resend = async function () { + setLeftTime(COUNT_DOWN_TIME_MS) + localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`) + onResend?.() + } + + useEffect(() => { + localStorage.setItem(COUNT_DOWN_KEY, `${time}`) + }, [time]) + + return

+ {t('login.checkCode.didNotReceiveCode')} + {time > 0 && {Math.round(time / 1000)}s} + { + time <= 0 && {t('login.checkCode.resend')} + } +

+} diff --git a/web/app/components/workflow/header/global-variable-button.tsx b/web/app/components/workflow/header/global-variable-button.tsx new file mode 100644 index 0000000000..ff02604b26 --- /dev/null +++ b/web/app/components/workflow/header/global-variable-button.tsx @@ -0,0 +1,20 @@ +import { memo } from 'react' +import Button from '@/app/components/base/button' +import { GlobalVariable } from '@/app/components/base/icons/src/vender/line/others' +import { useStore } from '@/app/components/workflow/store' + +const GlobalVariableButton = ({ disabled }: { disabled: boolean }) => { + const setShowPanel = useStore(s => s.setShowGlobalVariablePanel) + + const handleClick = () => { + setShowPanel(true) + } + + return ( + + ) +} + +export default memo(GlobalVariableButton) diff --git a/web/app/components/workflow/hooks/use-config-vision.ts b/web/app/components/workflow/hooks/use-config-vision.ts new file mode 100644 index 0000000000..a3cddbc47c --- /dev/null +++ b/web/app/components/workflow/hooks/use-config-vision.ts @@ -0,0 +1,88 @@ +import produce from 'immer' +import { useCallback } from 'react' +import { useIsChatMode } from './use-workflow' +import type { ModelConfig, VisionSetting } from '@/app/components/workflow/types' +import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { + ModelFeatureEnum, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { Resolution } from '@/types/app' + +type Payload = { + enabled: boolean + configs?: VisionSetting +} + +type Params = { + payload: Payload + onChange: (payload: Payload) => void +} +const useConfigVision = (model: ModelConfig, { + payload = { + enabled: false, + }, + onChange, +}: Params) => { + const { + currentModel: currModel, + } = useTextGenerationCurrentProviderAndModelAndModelList( + { + provider: model.provider, + model: model.name, + }, + ) + + const isChatMode = useIsChatMode() 
+ + const getIsVisionModel = useCallback(() => { + return !!currModel?.features?.includes(ModelFeatureEnum.vision) + }, [currModel]) + + const isVisionModel = getIsVisionModel() + + const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => { + const newPayload = produce(payload, (draft) => { + draft.enabled = enabled + if (enabled && isChatMode) { + draft.configs = { + detail: Resolution.high, + variable_selector: ['sys', 'files'], + } + } + }) + onChange(newPayload) + }, [isChatMode, onChange, payload]) + + const handleVisionResolutionChange = useCallback((config: VisionSetting) => { + const newPayload = produce(payload, (draft) => { + draft.configs = config + }) + onChange(newPayload) + }, [onChange, payload]) + + const handleModelChanged = useCallback(() => { + const isVisionModel = getIsVisionModel() + if (!isVisionModel) { + handleVisionResolutionEnabledChange(false) + return + } + if (payload.enabled) { + onChange({ + enabled: true, + configs: { + detail: Resolution.high, + variable_selector: [], + }, + }) + } + }, [getIsVisionModel, handleVisionResolutionEnabledChange, onChange, payload.enabled]) + + return { + isVisionModel, + handleVisionResolutionEnabledChange, + handleVisionResolutionChange, + handleModelChanged, + } +} + +export default useConfigVision diff --git a/web/app/components/workflow/nodes/_base/components/code-generator-button.tsx b/web/app/components/workflow/nodes/_base/components/code-generator-button.tsx new file mode 100644 index 0000000000..7f3a71dc09 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/code-generator-button.tsx @@ -0,0 +1,48 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useBoolean } from 'ahooks' +import cn from 'classnames' +import type { CodeLanguage } from '../../code/types' +import { Generator } from '@/app/components/base/icons/src/vender/other' +import { ActionButton } from '@/app/components/base/action-button' +import 
{ AppType } from '@/types/app' +import type { CodeGenRes } from '@/service/debug' +import { GetCodeGeneratorResModal } from '@/app/components/app/configuration/config/code-generator/get-code-generator-res' + +type Props = { + className?: string + onGenerated?: (prompt: string) => void + codeLanguages: CodeLanguage +} + +const CodeGenerateBtn: FC = ({ + className, + codeLanguages, + onGenerated, +}) => { + const [showAutomatic, { setTrue: showAutomaticTrue, setFalse: showAutomaticFalse }] = useBoolean(false) + const handleAutomaticRes = useCallback((res: CodeGenRes) => { + onGenerated?.(res.code) + showAutomaticFalse() + }, [onGenerated, showAutomaticFalse]) + return ( +
+ + + + {showAutomatic && ( + + )} +
+ ) +} +export default React.memo(CodeGenerateBtn) diff --git a/web/app/components/workflow/nodes/_base/components/config-vision.tsx b/web/app/components/workflow/nodes/_base/components/config-vision.tsx new file mode 100644 index 0000000000..56cd1a5dbb --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/config-vision.tsx @@ -0,0 +1,91 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import VarReferencePicker from './variable/var-reference-picker' +import ResolutionPicker from '@/app/components/workflow/nodes/llm/components/resolution-picker' +import Field from '@/app/components/workflow/nodes/_base/components/field' +import Switch from '@/app/components/base/switch' +import { type ValueSelector, type Var, VarType, type VisionSetting } from '@/app/components/workflow/types' +import { Resolution } from '@/types/app' +import Tooltip from '@/app/components/base/tooltip' +const i18nPrefix = 'workflow.nodes.llm' + +type Props = { + isVisionModel: boolean + readOnly: boolean + enabled: boolean + onEnabledChange: (enabled: boolean) => void + nodeId: string + config?: VisionSetting + onConfigChange: (config: VisionSetting) => void +} + +const ConfigVision: FC = ({ + isVisionModel, + readOnly, + enabled, + onEnabledChange, + nodeId, + config = { + detail: Resolution.high, + variable_selector: [], + }, + onConfigChange, +}) => { + const { t } = useTranslation() + + const filterVar = useCallback((payload: Var) => { + return [VarType.file, VarType.arrayFile].includes(payload.type) + }, []) + const handleVisionResolutionChange = useCallback((resolution: Resolution) => { + const newConfig = produce(config, (draft) => { + draft.detail = resolution + }) + onConfigChange(newConfig) + }, [config, onConfigChange]) + + const handleVarSelectorChange = useCallback((valueSelector: ValueSelector | string) => { + const newConfig = produce(config, 
(draft) => { + draft.variable_selector = valueSelector as ValueSelector + }) + onConfigChange(newConfig) + }, [config, onConfigChange]) + + return ( + + + + } + > + {(enabled && isVisionModel) + ? ( +
+ + +
+ ) + : null} + +
+ ) +} +export default React.memo(ConfigVision) diff --git a/web/app/components/workflow/nodes/_base/components/file-type-item.tsx b/web/app/components/workflow/nodes/_base/components/file-type-item.tsx new file mode 100644 index 0000000000..c3d52f265b --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/file-type-item.tsx @@ -0,0 +1,77 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import { SupportUploadFileTypes } from '../../../types' +import cn from '@/utils/classnames' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' +import TagInput from '@/app/components/base/tag-input' +import Checkbox from '@/app/components/base/checkbox' +import { FileTypeIcon } from '@/app/components/base/file-uploader' + +type Props = { + type: SupportUploadFileTypes.image | SupportUploadFileTypes.document | SupportUploadFileTypes.audio | SupportUploadFileTypes.video | SupportUploadFileTypes.custom + selected: boolean + onToggle: (type: SupportUploadFileTypes) => void + onCustomFileTypesChange?: (customFileTypes: string[]) => void + customFileTypes?: string[] +} + +const FileTypeItem: FC = ({ + type, + selected, + onToggle, + customFileTypes = [], + onCustomFileTypesChange = () => { }, +}) => { + const { t } = useTranslation() + + const handleOnSelect = useCallback(() => { + onToggle(type) + }, [onToggle, type]) + + const isCustomSelected = type === SupportUploadFileTypes.custom && selected + + return ( +
+ {isCustomSelected + ? ( +
+
+ +
{t(`appDebug.variableConfig.file.${type}.name`)}
+ +
+
e.stopPropagation()}> + +
+
+ ) + : ( +
+ +
+
{t(`appDebug.variableConfig.file.${type}.name`)}
+
{type !== SupportUploadFileTypes.custom ? FILE_EXTS[type].join(', ') : t('appDebug.variableConfig.file.custom.description')}
+
+ +
+ )} + +
+ ) +} + +export default React.memo(FileTypeItem) diff --git a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx new file mode 100644 index 0000000000..82a3a906cf --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx @@ -0,0 +1,195 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import useSWR from 'swr' +import produce from 'immer' +import { useTranslation } from 'react-i18next' +import type { UploadFileSetting } from '../../../types' +import { SupportUploadFileTypes } from '../../../types' +import OptionCard from './option-card' +import FileTypeItem from './file-type-item' +import InputNumberWithSlider from './input-number-with-slider' +import Field from '@/app/components/app/configuration/config-var/config-modal/field' +import { TransferMethod } from '@/types/app' +import { fetchFileUploadConfig } from '@/service/common' +import { useFileSizeLimit } from '@/app/components/base/file-uploader/hooks' +import { formatFileSize } from '@/utils/format' + +type Props = { + payload: UploadFileSetting + isMultiple: boolean + inFeaturePanel?: boolean + hideSupportFileType?: boolean + onChange: (payload: UploadFileSetting) => void +} + +const FileUploadSetting: FC = ({ + payload, + isMultiple, + inFeaturePanel = false, + hideSupportFileType = false, + onChange, +}) => { + const { t } = useTranslation() + + const { + allowed_file_upload_methods, + max_length, + allowed_file_types, + allowed_file_extensions, + } = payload + const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) + const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileUploadConfigResponse) + + const handleSupportFileTypeChange = useCallback((type: SupportUploadFileTypes) => { + const newPayload = produce(payload, (draft) => { + if (type === 
SupportUploadFileTypes.custom) { + if (!draft.allowed_file_types.includes(SupportUploadFileTypes.custom)) + draft.allowed_file_types = [SupportUploadFileTypes.custom] + + else + draft.allowed_file_types = draft.allowed_file_types.filter(v => v !== type) + } + else { + draft.allowed_file_types = draft.allowed_file_types.filter(v => v !== SupportUploadFileTypes.custom) + if (draft.allowed_file_types.includes(type)) + draft.allowed_file_types = draft.allowed_file_types.filter(v => v !== type) + else + draft.allowed_file_types.push(type) + } + }) + onChange(newPayload) + }, [onChange, payload]) + + const handleUploadMethodChange = useCallback((method: TransferMethod) => { + return () => { + const newPayload = produce(payload, (draft) => { + if (method === TransferMethod.all) + draft.allowed_file_upload_methods = [TransferMethod.local_file, TransferMethod.remote_url] + else + draft.allowed_file_upload_methods = [method] + }) + onChange(newPayload) + } + }, [onChange, payload]) + + const handleCustomFileTypesChange = useCallback((customFileTypes: string[]) => { + const newPayload = produce(payload, (draft) => { + draft.allowed_file_extensions = customFileTypes.map((v) => { + if (v.startsWith('.')) // Not start with dot + return v.slice(1) + return v + }) + }) + onChange(newPayload) + }, [onChange, payload]) + + const handleMaxUploadNumLimitChange = useCallback((value: number) => { + const newPayload = produce(payload, (draft) => { + draft.max_length = value + }) + onChange(newPayload) + }, [onChange, payload]) + + return ( +
+ {!inFeaturePanel && ( + +
+ { + [SupportUploadFileTypes.document, SupportUploadFileTypes.image, SupportUploadFileTypes.audio, SupportUploadFileTypes.video].map((type: SupportUploadFileTypes) => ( + + )) + } + `.${item}`)} + onCustomFileTypesChange={handleCustomFileTypesChange} + /> +
+
+ )} + +
+ + + +
+
+ {isMultiple && ( + +
+
{t('appDebug.variableConfig.maxNumberTip', { + imgLimit: formatFileSize(imgSizeLimit), + docLimit: formatFileSize(docSizeLimit), + audioLimit: formatFileSize(audioSizeLimit), + videoLimit: formatFileSize(videoSizeLimit), + })}
+ + +
+
+ )} + {inFeaturePanel && !hideSupportFileType && ( + +
+ { + [SupportUploadFileTypes.document, SupportUploadFileTypes.image, SupportUploadFileTypes.audio, SupportUploadFileTypes.video].map((type: SupportUploadFileTypes) => ( + + )) + } + +
+
+ )} + +
+ ) +} +export default React.memo(FileUploadSetting) diff --git a/web/app/components/workflow/nodes/_base/components/input-number-with-slider.tsx b/web/app/components/workflow/nodes/_base/components/input-number-with-slider.tsx new file mode 100644 index 0000000000..0210db2f8e --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/input-number-with-slider.tsx @@ -0,0 +1,65 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import Slider from '@/app/components/base/slider' + +type Props = { + value: number + defaultValue?: number + min?: number + max?: number + readonly?: boolean + onChange: (value: number) => void +} + +const InputNumberWithSlider: FC = ({ + value, + defaultValue = 0, + min, + max, + readonly, + onChange, +}) => { + const handleBlur = useCallback(() => { + if (value === undefined || value === null) { + onChange(defaultValue) + return + } + if (max !== undefined && value > max) { + onChange(max) + return + } + if (min !== undefined && value < min) + onChange(min) + }, [defaultValue, max, min, onChange, value]) + + const handleChange = useCallback((e: React.ChangeEvent) => { + onChange(Number.parseFloat(e.target.value)) + }, [onChange]) + + return ( +
+ + +
+ ) +} +export default React.memo(InputNumberWithSlider) diff --git a/web/app/components/workflow/nodes/code/dependency-picker.tsx b/web/app/components/workflow/nodes/code/dependency-picker.tsx new file mode 100644 index 0000000000..43e8523e17 --- /dev/null +++ b/web/app/components/workflow/nodes/code/dependency-picker.tsx @@ -0,0 +1,85 @@ +import type { FC } from 'react' +import React, { useCallback, useState } from 'react' +import { t } from 'i18next' +import { + RiArrowDownSLine, +} from '@remixicon/react' +import type { CodeDependency } from './types' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import Input from '@/app/components/base/input' +import { Check } from '@/app/components/base/icons/src/vender/line/general' + +type Props = { + value: CodeDependency + available_dependencies: CodeDependency[] + onChange: (dependency: CodeDependency) => void +} + +const DependencyPicker: FC = ({ + available_dependencies, + value, + onChange, +}) => { + const [open, setOpen] = useState(false) + const [searchText, setSearchText] = useState('') + + const handleChange = useCallback((dependency: CodeDependency) => { + return () => { + setOpen(false) + onChange(dependency) + } + }, [onChange]) + + return ( + + setOpen(!open)} className='flex-grow cursor-pointer'> +
+
{value.name}
+ +
+
+ +
+
+ setSearchText(e.target.value)} + onClear={() => setSearchText('')} + autoFocus + /> +
+
+ {available_dependencies.filter((v) => { + if (!searchText) + return true + return v.name.toLowerCase().includes(searchText.toLowerCase()) + }).map(dependency => ( +
+
{dependency.name}
+ {dependency.name === value.name && } +
+ ))} +
+
+
+
+ ) +} + +export default React.memo(DependencyPicker) diff --git a/web/app/components/workflow/nodes/document-extractor/default.ts b/web/app/components/workflow/nodes/document-extractor/default.ts new file mode 100644 index 0000000000..26eddff62b --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/default.ts @@ -0,0 +1,36 @@ +import { BlockEnum } from '../../types' +import type { NodeDefault } from '../../types' +import type { DocExtractorNodeType } from './types' +import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' +const i18nPrefix = 'workflow.errorMsg' + +const nodeDefault: NodeDefault = { + defaultValue: { + variable_selector: [], + is_array_file: false, + }, + getAvailablePrevNodes(isChatMode: boolean) { + const nodes = isChatMode + ? ALL_CHAT_AVAILABLE_BLOCKS + : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End) + return nodes + }, + getAvailableNextNodes(isChatMode: boolean) { + const nodes = isChatMode ? 
ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + return nodes + }, + checkValid(payload: DocExtractorNodeType, t: any) { + let errorMessages = '' + const { variable_selector: variable } = payload + + if (!errorMessages && !variable?.length) + errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.assigner.assignedVariable') }) + + return { + isValid: !errorMessages, + errorMessage: errorMessages, + } + }, +} + +export default nodeDefault diff --git a/web/app/components/workflow/nodes/document-extractor/node.tsx b/web/app/components/workflow/nodes/document-extractor/node.tsx new file mode 100644 index 0000000000..becf9fda95 --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/node.tsx @@ -0,0 +1,42 @@ +import type { FC } from 'react' +import React from 'react' +import { useNodes } from 'reactflow' +import { useTranslation } from 'react-i18next' +import NodeVariableItem from '../variable-assigner/components/node-variable-item' +import type { DocExtractorNodeType } from './types' +import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import { BlockEnum, type Node, type NodeProps } from '@/app/components/workflow/types' + +const i18nPrefix = 'workflow.nodes.docExtractor' + +const NodeComponent: FC> = ({ + data, +}) => { + const { t } = useTranslation() + + const nodes: Node[] = useNodes() + const { variable_selector: variable } = data + + if (!variable || variable.length === 0) + return null + + const isSystem = isSystemVar(variable) + const isEnv = isENV(variable) + const isChatVar = isConversationVar(variable) + const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === variable[0]) + const varName = isSystem ? `sys.${variable[variable.length - 1]}` : variable.slice(1).join('.') + return ( +
+
{t(`${i18nPrefix}.inputVar`)}
+ +
+ ) +} + +export default React.memo(NodeComponent) diff --git a/web/app/components/workflow/nodes/document-extractor/panel.tsx b/web/app/components/workflow/nodes/document-extractor/panel.tsx new file mode 100644 index 0000000000..52491875cd --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/panel.tsx @@ -0,0 +1,88 @@ +import type { FC } from 'react' +import React from 'react' +import useSWR from 'swr' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import VarReferencePicker from '../_base/components/variable/var-reference-picker' +import OutputVars, { VarItem } from '../_base/components/output-vars' +import Split from '../_base/components/split' +import { useNodeHelpLink } from '../_base/hooks/use-node-help-link' +import useConfig from './use-config' +import type { DocExtractorNodeType } from './types' +import { fetchSupportFileTypes } from '@/service/datasets' +import Field from '@/app/components/workflow/nodes/_base/components/field' +import { BlockEnum, type NodePanelProps } from '@/app/components/workflow/types' +import I18n from '@/context/i18n' +import { LanguagesSupported } from '@/i18n/language' + +const i18nPrefix = 'workflow.nodes.docExtractor' + +const Panel: FC> = ({ + id, + data, +}) => { + const { t } = useTranslation() + const { locale } = useContext(I18n) + const link = useNodeHelpLink(BlockEnum.DocExtractor) + const { data: supportFileTypesResponse } = useSWR({ url: '/files/support-type' }, fetchSupportFileTypes) + const supportTypes = supportFileTypesResponse?.allowed_extensions || [] + const supportTypesShowNames = (() => { + const extensionMap: { [key: string]: string } = { + md: 'markdown', + pptx: 'pptx', + htm: 'html', + xlsx: 'xlsx', + docx: 'docx', + } + + return [...supportTypes] + .map(item => extensionMap[item] || item) // map to standardized extension + .map(item => item.toLowerCase()) // convert to lower case + .filter((item, index, self) => self.indexOf(item) 
=== index) // remove duplicates + .join(locale !== LanguagesSupported[1] ? ', ' : '、 ') + })() + const { + readOnly, + inputs, + handleVarChanges, + filterVar, + } = useConfig(id, data) + + return ( +
+
+ + <> + +
+ {t(`${i18nPrefix}.supportFileTypes`, { types: supportTypesShowNames })} + {t(`${i18nPrefix}.learnMore`)} +
+ +
+
+ +
+ + + +
+
+ ) +} + +export default React.memo(Panel) diff --git a/web/app/components/workflow/nodes/document-extractor/types.ts b/web/app/components/workflow/nodes/document-extractor/types.ts new file mode 100644 index 0000000000..8ab7592109 --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/types.ts @@ -0,0 +1,6 @@ +import type { CommonNodeType, ValueSelector } from '@/app/components/workflow/types' + +export type DocExtractorNodeType = CommonNodeType & { + variable_selector: ValueSelector + is_array_file: boolean +} diff --git a/web/app/components/workflow/nodes/document-extractor/use-config.ts b/web/app/components/workflow/nodes/document-extractor/use-config.ts new file mode 100644 index 0000000000..1654bee02a --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/use-config.ts @@ -0,0 +1,66 @@ +import { useCallback, useMemo } from 'react' +import produce from 'immer' +import { useStoreApi } from 'reactflow' + +import type { ValueSelector, Var } from '../../types' +import { VarType } from '../../types' +import type { DocExtractorNodeType } from './types' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { + useIsChatMode, + useNodesReadOnly, + useWorkflow, + useWorkflowVariables, +} from '@/app/components/workflow/hooks' + +const useConfig = (id: string, payload: DocExtractorNodeType) => { + const { nodesReadOnly: readOnly } = useNodesReadOnly() + const { inputs, setInputs } = useNodeCrud(id, payload) + + const filterVar = useCallback((varPayload: Var) => { + return varPayload.type === VarType.file || varPayload.type === VarType.arrayFile + }, []) + + const isChatMode = useIsChatMode() + + const store = useStoreApi() + const { getBeforeNodesInSameBranch } = useWorkflow() + const { + getNodes, + } = store.getState() + const currentNode = getNodes().find(n => n.id === id) + const isInIteration = payload.isInIteration + const iterationNode = isInIteration ? 
getNodes().find(n => n.id === currentNode!.parentId) : null + const availableNodes = useMemo(() => { + return getBeforeNodesInSameBranch(id) + }, [getBeforeNodesInSameBranch, id]) + + const { getCurrentVariableType } = useWorkflowVariables() + const getType = useCallback((variable?: ValueSelector) => { + const varType = getCurrentVariableType({ + parentNode: iterationNode, + valueSelector: variable || [], + availableNodes, + isChatMode, + isConstant: false, + }) + return varType + }, [getCurrentVariableType, availableNodes, isChatMode, iterationNode]) + + const handleVarChanges = useCallback((variable: ValueSelector | string) => { + const newInputs = produce(inputs, (draft) => { + draft.variable_selector = variable as ValueSelector + draft.is_array_file = getType(draft.variable_selector) === VarType.arrayFile + }) + setInputs(newInputs) + }, [getType, inputs, setInputs]) + + return { + readOnly, + inputs, + filterVar, + handleVarChanges, + } +} + +export default useConfig diff --git a/web/app/components/workflow/nodes/if-else/components/condition-files-list-value.tsx b/web/app/components/workflow/nodes/if-else/components/condition-files-list-value.tsx new file mode 100644 index 0000000000..f21a3fac10 --- /dev/null +++ b/web/app/components/workflow/nodes/if-else/components/condition-files-list-value.tsx @@ -0,0 +1,115 @@ +import { + memo, + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { ComparisonOperator, type Condition } from '../types' +import { + comparisonOperatorNotRequireValue, + isComparisonOperatorNeedTranslate, + isEmptyRelatedOperator, +} from '../utils' +import { FILE_TYPE_OPTIONS, TRANSFER_METHOD } from '../default' +import type { ValueSelector } from '../../../types' +import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' +import { BubbleX, Env } from '@/app/components/base/icons/src/vender/line/others' +import cn from '@/utils/classnames' +import { isConversationVar, isENV, 
isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' +const i18nPrefix = 'workflow.nodes.ifElse' + +type ConditionValueProps = { + condition: Condition +} +const ConditionValue = ({ + condition, +}: ConditionValueProps) => { + const { t } = useTranslation() + const { + variable_selector, + comparison_operator: operator, + sub_variable_condition, + } = condition + + const variableSelector = variable_selector as ValueSelector + + const variableName = (isSystemVar(variableSelector) ? variableSelector.slice(0).join('.') : variableSelector.slice(1).join('.')) + const operatorName = isComparisonOperatorNeedTranslate(operator) ? t(`workflow.nodes.ifElse.comparisonOperator.${operator}`) : operator + const notHasValue = comparisonOperatorNotRequireValue(operator) + const isEnvVar = isENV(variableSelector) + const isChatVar = isConversationVar(variableSelector) + const formatValue = useCallback((c: Condition) => { + const notHasValue = comparisonOperatorNotRequireValue(c.comparison_operator) + if (notHasValue) + return '' + + const value = c.value as string + return value.replace(/{{#([^#]*)#}}/g, (a, b) => { + const arr: string[] = b.split('.') + if (isSystemVar(arr)) + return `{{${b}}}` + + return `{{${arr.slice(1).join('.')}}}` + }) + }, []) + + const isSelect = useCallback((c: Condition) => { + return c.comparison_operator === ComparisonOperator.in || c.comparison_operator === ComparisonOperator.notIn + }, []) + + const selectName = useCallback((c: Condition) => { + const isSelect = c.comparison_operator === ComparisonOperator.in || c.comparison_operator === ComparisonOperator.notIn + if (isSelect) { + const name = [...FILE_TYPE_OPTIONS, ...TRANSFER_METHOD].filter(item => item.value === (Array.isArray(c.value) ? c.value[0] : c.value))[0] + return name + ? 
t(`workflow.nodes.ifElse.optionName.${name.i18nKey}`).replace(/{{#([^#]*)#}}/g, (a, b) => { + const arr: string[] = b.split('.') + if (isSystemVar(arr)) + return `{{${b}}}` + + return `{{${arr.slice(1).join('.')}}}` + }) + : '' + } + return '' + }, []) + + return ( +
+
+ {!isEnvVar && !isChatVar && } + {isEnvVar && } + {isChatVar && } + +
+ {variableName} +
+
+ {operatorName} +
+
+
+ { + sub_variable_condition?.conditions.map((c: Condition, index) => ( +
+
{c.key}
+
{isComparisonOperatorNeedTranslate(c.comparison_operator) ? t(`workflow.nodes.ifElse.comparisonOperator.${c.comparison_operator}`) : c.comparison_operator}
+ {c.comparison_operator && !isEmptyRelatedOperator(c.comparison_operator) &&
{isSelect(c) ? selectName(c) : formatValue(c)}
} + {index !== sub_variable_condition.conditions.length - 1 && (
{t(`${i18nPrefix}.${sub_variable_condition.logical_operator}`)}
)} +
+ )) + } +
+
+ ) +} + +export default memo(ConditionValue) diff --git a/web/app/components/workflow/nodes/if-else/components/condition-wrap.tsx b/web/app/components/workflow/nodes/if-else/components/condition-wrap.tsx new file mode 100644 index 0000000000..39c03c9b38 --- /dev/null +++ b/web/app/components/workflow/nodes/if-else/components/condition-wrap.tsx @@ -0,0 +1,225 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { ReactSortable } from 'react-sortablejs' +import { + RiAddLine, + RiDeleteBinLine, + RiDraggable, +} from '@remixicon/react' +import type { CaseItem, HandleAddCondition, HandleAddSubVariableCondition, HandleRemoveCondition, HandleToggleConditionLogicalOperator, HandleToggleSubVariableConditionLogicalOperator, HandleUpdateCondition, HandleUpdateSubVariableCondition, handleRemoveSubVariableCondition } from '../types' +import type { Node, NodeOutPutVar, Var } from '../../../types' +import { VarType } from '../../../types' +import { useGetAvailableVars } from '../../variable-assigner/hooks' +import { SUB_VARIABLES } from '../default' +import ConditionList from './condition-list' +import ConditionAdd from './condition-add' +import cn from '@/utils/classnames' +import Button from '@/app/components/base/button' +import { PortalSelect as Select } from '@/app/components/base/select' + +type Props = { + isSubVariable?: boolean + caseId?: string + conditionId?: string + cases: CaseItem[] + readOnly: boolean + handleSortCase?: (sortedCases: (CaseItem & { id: string })[]) => void + handleRemoveCase?: (caseId: string) => void + handleAddCondition?: HandleAddCondition + handleRemoveCondition?: HandleRemoveCondition + handleUpdateCondition?: HandleUpdateCondition + handleToggleConditionLogicalOperator?: HandleToggleConditionLogicalOperator + handleAddSubVariableCondition?: HandleAddSubVariableCondition + handleRemoveSubVariableCondition?: 
handleRemoveSubVariableCondition + handleUpdateSubVariableCondition?: HandleUpdateSubVariableCondition + handleToggleSubVariableConditionLogicalOperator?: HandleToggleSubVariableConditionLogicalOperator + nodeId: string + nodesOutputVars: NodeOutPutVar[] + availableNodes: Node[] + varsIsVarFileAttribute?: Record + filterVar: (varPayload: Var) => boolean +} + +const ConditionWrap: FC = ({ + isSubVariable, + caseId, + conditionId, + nodeId: id = '', + cases = [], + readOnly, + handleSortCase = () => { }, + handleRemoveCase, + handleUpdateCondition, + handleAddCondition, + handleRemoveCondition, + handleToggleConditionLogicalOperator, + handleAddSubVariableCondition, + handleRemoveSubVariableCondition, + handleUpdateSubVariableCondition, + handleToggleSubVariableConditionLogicalOperator, + nodesOutputVars = [], + availableNodes = [], + varsIsVarFileAttribute = {}, + filterVar = () => true, +}) => { + const { t } = useTranslation() + + const getAvailableVars = useGetAvailableVars() + + const [willDeleteCaseId, setWillDeleteCaseId] = useState('') + const casesLength = cases.length + + const filterNumberVar = useCallback((varPayload: Var) => { + return varPayload.type === VarType.number + }, []) + + const subVarOptions = SUB_VARIABLES.map(item => ({ + name: item, + value: item, + })) + + return ( + <> + ({ ...caseItem, id: caseItem.case_id }))} + setList={handleSortCase} + handle='.handle' + ghostClass='bg-components-panel-bg' + animation={150} + disabled={readOnly || isSubVariable} + > + { + cases.map((item, index) => ( +
+
+ {!isSubVariable && ( + <> + 1 && 'group-hover:block', + )} /> +
+ { + index === 0 ? 'IF' : 'ELIF' + } + { + casesLength > 1 && ( +
CASE {index + 1}
+ ) + } +
+ + )} + + { + !!item.conditions.length && ( +
+ +
+ ) + } + +
+ {isSubVariable + ? ( + handleChange('value')(item.value)} + className='!text-[13px]' + wrapperClassName='grow h-8' + placeholder='Select value' + /> + )} + {!isSelect && ( + handleChange('value')(e.target.value)} + /> + )} + + )} +
+
+ ) +} +export default React.memo(FilterCondition) diff --git a/web/app/components/workflow/nodes/list-operator/components/limit-config.tsx b/web/app/components/workflow/nodes/list-operator/components/limit-config.tsx new file mode 100644 index 0000000000..b8812d3473 --- /dev/null +++ b/web/app/components/workflow/nodes/list-operator/components/limit-config.tsx @@ -0,0 +1,80 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import type { Limit } from '../types' +import InputNumberWithSlider from '../../_base/components/input-number-with-slider' +import cn from '@/utils/classnames' +import Field from '@/app/components/workflow/nodes/_base/components/field' +import Switch from '@/app/components/base/switch' + +const i18nPrefix = 'workflow.nodes.listFilter' +const LIMIT_SIZE_MIN = 1 +const LIMIT_SIZE_MAX = 20 +const LIMIT_SIZE_DEFAULT = 10 + +type Props = { + className?: string + readonly: boolean + config: Limit + onChange: (limit: Limit) => void + canSetRoleName?: boolean +} + +const LIMIT_DEFAULT: Limit = { + enabled: false, + size: LIMIT_SIZE_DEFAULT, +} + +const LimitConfig: FC = ({ + className, + readonly, + config = LIMIT_DEFAULT, + onChange, +}) => { + const { t } = useTranslation() + const payload = config + + const handleLimitEnabledChange = useCallback((enabled: boolean) => { + onChange({ + ...config, + enabled, + }) + }, [config, onChange]) + + const handleLimitSizeChange = useCallback((size: number | string) => { + onChange({ + ...config, + size: Number.parseInt(size as string), + }) + }, [onChange, config]) + + return ( +
+ + } + > + {payload?.enabled + ? ( + + ) + : null} + +
+ ) +} +export default React.memo(LimitConfig) diff --git a/web/app/components/workflow/nodes/list-operator/components/sub-variable-picker.tsx b/web/app/components/workflow/nodes/list-operator/components/sub-variable-picker.tsx new file mode 100644 index 0000000000..0a210504cf --- /dev/null +++ b/web/app/components/workflow/nodes/list-operator/components/sub-variable-picker.tsx @@ -0,0 +1,73 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import { SUB_VARIABLES } from '../../if-else/default' +import type { Item } from '@/app/components/base/select' +import { SimpleSelect as Select } from '@/app/components/base/select' +import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' +import cn from '@/utils/classnames' + +type Props = { + value: string + onChange: (value: string) => void + className?: string +} + +const SubVariablePicker: FC = ({ + value, + onChange, + className, +}) => { + const { t } = useTranslation() + const subVarOptions = SUB_VARIABLES.map(item => ({ + value: item, + name: item, + })) + + const renderOption = ({ item }: { item: Record }) => { + return ( +
+
+ + {item.name} +
+ {item.type} +
+ ) + } + + const handleChange = useCallback(({ value }: Item) => { + onChange(value as string) + }, [onChange]) + + return ( +
+ + + setVerifyCode(e.target.value)} max-length={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} /> + + + +
+
+
+
router.back()} className='flex items-center justify-center h-9 text-text-tertiary cursor-pointer'> +
+ +
+ {t('login.back')} +
+
+} diff --git a/web/app/reset-password/layout.tsx b/web/app/reset-password/layout.tsx new file mode 100644 index 0000000000..16d8642ed2 --- /dev/null +++ b/web/app/reset-password/layout.tsx @@ -0,0 +1,39 @@ +import Header from '../signin/_header' +import style from '../signin/page.module.css' + +import cn from '@/utils/classnames' + +export default async function SignInLayout({ children }: any) { + return <> +
+
+
+
+
+ {children} +
+
+
+ © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. +
+
+
+ +} diff --git a/web/app/reset-password/page.tsx b/web/app/reset-password/page.tsx new file mode 100644 index 0000000000..65f1db3fb5 --- /dev/null +++ b/web/app/reset-password/page.tsx @@ -0,0 +1,101 @@ +'use client' +import Link from 'next/link' +import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useState } from 'react' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '../components/signin/countdown' +import { emailRegex } from '@/config' +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import Toast from '@/app/components/base/toast' +import { sendResetPasswordCode } from '@/service/common' +import I18NContext from '@/context/i18n' + +export default function CheckCode() { + const { t } = useTranslation() + const searchParams = useSearchParams() + const router = useRouter() + const [email, setEmail] = useState('') + const [loading, setIsLoading] = useState(false) + const { locale } = useContext(I18NContext) + + const handleGetEMailVerificationCode = async () => { + try { + if (!email) { + Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + return + } + + if (!emailRegex.test(email)) { + Toast.notify({ + type: 'error', + message: t('login.error.emailInValid'), + }) + return + } + setIsLoading(true) + const res = await sendResetPasswordCode(email, locale) + if (res.result === 'success') { + localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`) + const params = new URLSearchParams(searchParams) + params.set('token', encodeURIComponent(res.data)) + params.set('email', encodeURIComponent(email)) + router.push(`/reset-password/check-code?${params.toString()}`) + } + else if (res.code === 'account_not_found') { + Toast.notify({ + type: 'error', + message: t('login.error.registrationNotAllowed'), + }) + 
} + else { + Toast.notify({ + type: 'error', + message: res.data, + }) + } + } + catch (error) { + console.error(error) + } + finally { + setIsLoading(false) + } + } + + return
+
+ +
+
+

{t('login.resetPassword')}

+

+ {t('login.resetPasswordDesc')} +

+
+ +
{ }}> + +
+ +
+ setEmail(e.target.value)} /> +
+
+ +
+
+
+
+
+
+ +
+ +
+ {t('login.backToLogin')} + +
+} diff --git a/web/app/reset-password/set-password/page.tsx b/web/app/reset-password/set-password/page.tsx new file mode 100644 index 0000000000..7948c59a9a --- /dev/null +++ b/web/app/reset-password/set-password/page.tsx @@ -0,0 +1,193 @@ +'use client' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from 'next/navigation' +import cn from 'classnames' +import { RiCheckboxCircleFill } from '@remixicon/react' +import { useCountDown } from 'ahooks' +import Button from '@/app/components/base/button' +import { changePasswordWithToken } from '@/service/common' +import Toast from '@/app/components/base/toast' +import Input from '@/app/components/base/input' + +const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ + +const ChangePasswordForm = () => { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const token = decodeURIComponent(searchParams.get('token') || '') + + const [password, setPassword] = useState('') + const [confirmPassword, setConfirmPassword] = useState('') + const [showSuccess, setShowSuccess] = useState(false) + const [showPassword, setShowPassword] = useState(false) + const [showConfirmPassword, setShowConfirmPassword] = useState(false) + + const showErrorMessage = useCallback((message: string) => { + Toast.notify({ + type: 'error', + message, + }) + }, []) + + const getSignInUrl = () => { + if (searchParams.has('invite_token')) { + const params = new URLSearchParams() + params.set('token', searchParams.get('invite_token') as string) + return `/activate?${params.toString()}` + } + return '/signin' + } + + const AUTO_REDIRECT_TIME = 5000 + const [leftTime, setLeftTime] = useState(undefined) + const [countdown] = useCountDown({ + leftTime, + onEnd: () => { + router.replace(getSignInUrl()) + }, + }) + + const valid = useCallback(() => { + if (!password.trim()) { + showErrorMessage(t('login.error.passwordEmpty')) + 
return false + } + if (!validPassword.test(password)) { + showErrorMessage(t('login.error.passwordInvalid')) + return false + } + if (password !== confirmPassword) { + showErrorMessage(t('common.account.notEqual')) + return false + } + return true + }, [password, confirmPassword, showErrorMessage, t]) + + const handleChangePassword = useCallback(async () => { + if (!valid()) + return + try { + await changePasswordWithToken({ + url: '/forgot-password/resets', + body: { + token, + new_password: password, + password_confirm: confirmPassword, + }, + }) + setShowSuccess(true) + setLeftTime(AUTO_REDIRECT_TIME) + } + catch (error) { + console.error(error) + } + }, [password, token, valid, confirmPassword]) + + return ( +
+ {!showSuccess && ( +
+
+

+ {t('login.changePassword')} +

+

+ {t('login.changePasswordTip')} +

+
+ +
+
+ {/* Password */} +
+ +
+ setPassword(e.target.value)} + placeholder={t('login.passwordPlaceholder') || ''} + /> + +
+ +
+
+
{t('login.error.passwordInvalid')}
+
+ {/* Confirm Password */} +
+ +
+ setConfirmPassword(e.target.value)} + placeholder={t('login.confirmPasswordPlaceholder') || ''} + /> +
+ +
+
+
+
+ +
+
+
+
+ )} + {showSuccess && ( +
+
+
+ +
+

+ {t('login.passwordChangedTip')} +

+
+
+ +
+
+ )} +
+ ) +} + +export default ChangePasswordForm diff --git a/web/app/signin/check-code/page.tsx b/web/app/signin/check-code/page.tsx new file mode 100644 index 0000000000..4767308f72 --- /dev/null +++ b/web/app/signin/check-code/page.tsx @@ -0,0 +1,96 @@ +'use client' +import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useState } from 'react' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import Countdown from '@/app/components/signin/countdown' +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import Toast from '@/app/components/base/toast' +import { emailLoginWithCode, sendEMailLoginCode } from '@/service/common' +import I18NContext from '@/context/i18n' + +export default function CheckCode() { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const email = decodeURIComponent(searchParams.get('email') as string) + const token = decodeURIComponent(searchParams.get('token') as string) + const invite_token = decodeURIComponent(searchParams.get('invite_token') || '') + const [code, setVerifyCode] = useState('') + const [loading, setIsLoading] = useState(false) + const { locale } = useContext(I18NContext) + + const verify = async () => { + try { + if (!code.trim()) { + Toast.notify({ + type: 'error', + message: t('login.checkCode.emptyCode'), + }) + return + } + if (!/\d{6}/.test(code)) { + Toast.notify({ + type: 'error', + message: t('login.checkCode.invalidCode'), + }) + return + } + setIsLoading(true) + const ret = await emailLoginWithCode({ email, code, token }) + if (ret.result === 'success') { + localStorage.setItem('console_token', ret.data.access_token) + localStorage.setItem('refresh_token', ret.data.refresh_token) + router.replace(invite_token ? 
`/signin/invite-settings?${searchParams.toString()}` : '/apps') + } + } + catch (error) { console.error(error) } + finally { + setIsLoading(false) + } + } + + const resendCode = async () => { + try { + const ret = await sendEMailLoginCode(email, locale) + if (ret.result === 'success') { + const params = new URLSearchParams(searchParams) + params.set('token', encodeURIComponent(ret.data)) + router.replace(`/signin/check-code?${params.toString()}`) + } + } + catch (error) { console.error(error) } + } + + return
+
+ +
+
+

{t('login.checkCode.checkYourEmail')}

+

+ +
+ {t('login.checkCode.validTime')} +

+
+ +
+ + setVerifyCode(e.target.value)} max-length={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} /> + + + +
+
+
+
router.back()} className='flex items-center justify-center h-9 text-text-tertiary cursor-pointer'> +
+ +
+ {t('login.back')} +
+
+} diff --git a/web/app/signin/components/mail-and-code-auth.tsx b/web/app/signin/components/mail-and-code-auth.tsx new file mode 100644 index 0000000000..7225b094d4 --- /dev/null +++ b/web/app/signin/components/mail-and-code-auth.tsx @@ -0,0 +1,71 @@ +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import Input from '@/app/components/base/input' +import Button from '@/app/components/base/button' +import { emailRegex } from '@/config' +import Toast from '@/app/components/base/toast' +import { sendEMailLoginCode } from '@/service/common' +import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' +import I18NContext from '@/context/i18n' + +type MailAndCodeAuthProps = { + isInvite: boolean +} + +export default function MailAndCodeAuth({ isInvite }: MailAndCodeAuthProps) { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const emailFromLink = decodeURIComponent(searchParams.get('email') || '') + const [email, setEmail] = useState(emailFromLink) + const [loading, setIsLoading] = useState(false) + const { locale } = useContext(I18NContext) + + const handleGetEMailVerificationCode = async () => { + try { + if (!email) { + Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + return + } + + if (!emailRegex.test(email)) { + Toast.notify({ + type: 'error', + message: t('login.error.emailInValid'), + }) + return + } + setIsLoading(true) + const ret = await sendEMailLoginCode(email, locale) + if (ret.result === 'success') { + localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`) + const params = new URLSearchParams(searchParams) + params.set('email', encodeURIComponent(email)) + params.set('token', encodeURIComponent(ret.data)) + router.push(`/signin/check-code?${params.toString()}`) + } + } + catch (error) { + 
console.error(error) + } + finally { + setIsLoading(false) + } + } + + return (
{ }}> + +
+ +
+ setEmail(e.target.value)} /> +
+
+ +
+
+
+ ) +} diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx new file mode 100644 index 0000000000..210c877bb7 --- /dev/null +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -0,0 +1,167 @@ +import Link from 'next/link' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import Button from '@/app/components/base/button' +import Toast from '@/app/components/base/toast' +import { emailRegex } from '@/config' +import { login } from '@/service/common' +import Input from '@/app/components/base/input' +import I18NContext from '@/context/i18n' + +type MailAndPasswordAuthProps = { + isInvite: boolean + allowRegistration: boolean +} + +const passwordRegex = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ + +export default function MailAndPasswordAuth({ isInvite, allowRegistration }: MailAndPasswordAuthProps) { + const { t } = useTranslation() + const { locale } = useContext(I18NContext) + const router = useRouter() + const searchParams = useSearchParams() + const [showPassword, setShowPassword] = useState(false) + const emailFromLink = decodeURIComponent(searchParams.get('email') || '') + const [email, setEmail] = useState(emailFromLink) + const [password, setPassword] = useState('') + + const [isLoading, setIsLoading] = useState(false) + const handleEmailPasswordLogin = async () => { + if (!email) { + Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + return + } + if (!emailRegex.test(email)) { + Toast.notify({ + type: 'error', + message: t('login.error.emailInValid'), + }) + return + } + if (!password?.trim()) { + Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') }) + return + } + if (!passwordRegex.test(password)) { + Toast.notify({ + type: 'error', + message: t('login.error.passwordInvalid'), + }) + return + } + try { + 
setIsLoading(true) + const loginData: Record = { + email, + password, + language: locale, + remember_me: true, + } + if (isInvite) + loginData.invite_token = decodeURIComponent(searchParams.get('invite_token') as string) + const res = await login({ + url: '/login', + body: loginData, + }) + if (res.result === 'success') { + if (isInvite) { + router.replace(`/signin/invite-settings?${searchParams.toString()}`) + } + else { + localStorage.setItem('console_token', res.data.access_token) + localStorage.setItem('refresh_token', res.data.refresh_token) + router.replace('/apps') + } + } + else if (res.code === 'account_not_found') { + if (allowRegistration) { + const params = new URLSearchParams() + params.append('email', encodeURIComponent(email)) + params.append('token', encodeURIComponent(res.data)) + router.replace(`/reset-password/check-code?${params.toString()}`) + } + else { + Toast.notify({ + type: 'error', + message: t('login.error.registrationNotAllowed'), + }) + } + } + else { + Toast.notify({ + type: 'error', + message: res.data, + }) + } + } + + finally { + setIsLoading(false) + } + } + + return
{ }}> +
+ +
+ setEmail(e.target.value)} + disabled={isInvite} + id="email" + type="email" + autoComplete="email" + placeholder={t('login.emailPlaceholder') || ''} + tabIndex={1} + /> +
+
+ +
+ +
+ setPassword(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter') + handleEmailPasswordLogin() + }} + type={showPassword ? 'text' : 'password'} + autoComplete="current-password" + placeholder={t('login.passwordPlaceholder') || ''} + tabIndex={2} + /> +
+ +
+
+
+ +
+ +
+
+} diff --git a/web/app/signin/components/social-auth.tsx b/web/app/signin/components/social-auth.tsx new file mode 100644 index 0000000000..39d7ceaa40 --- /dev/null +++ b/web/app/signin/components/social-auth.tsx @@ -0,0 +1,62 @@ +import { useTranslation } from 'react-i18next' +import { useSearchParams } from 'next/navigation' +import style from '../page.module.css' +import Button from '@/app/components/base/button' +import { apiPrefix } from '@/config' +import classNames from '@/utils/classnames' +import { getPurifyHref } from '@/utils' + +type SocialAuthProps = { + disabled?: boolean +} + +export default function SocialAuth(props: SocialAuthProps) { + const { t } = useTranslation() + const searchParams = useSearchParams() + + const getOAuthLink = (href: string) => { + const url = getPurifyHref(`${apiPrefix}${href}`) + if (searchParams.has('invite_token')) + return `${url}?${searchParams.toString()}` + + return url + } + return <> + + + +} diff --git a/web/app/signin/components/sso-auth.tsx b/web/app/signin/components/sso-auth.tsx new file mode 100644 index 0000000000..fb303b93e2 --- /dev/null +++ b/web/app/signin/components/sso-auth.tsx @@ -0,0 +1,73 @@ +'use client' +import { useRouter, useSearchParams } from 'next/navigation' +import type { FC } from 'react' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' +import Toast from '@/app/components/base/toast' +import { getUserOAuth2SSOUrl, getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso' +import Button from '@/app/components/base/button' +import { SSOProtocol } from '@/types/feature' + +type SSOAuthProps = { + protocol: SSOProtocol | '' +} + +const SSOAuth: FC = ({ + protocol, +}) => { + const router = useRouter() + const { t } = useTranslation() + const searchParams = useSearchParams() + const invite_token = decodeURIComponent(searchParams.get('invite_token') || '') + + const [isLoading, 
setIsLoading] = useState(false) + + const handleSSOLogin = () => { + setIsLoading(true) + if (protocol === SSOProtocol.SAML) { + getUserSAMLSSOUrl(invite_token).then((res) => { + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else if (protocol === SSOProtocol.OIDC) { + getUserOIDCSSOUrl(invite_token).then((res) => { + document.cookie = `user-oidc-state=${res.state}` + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else if (protocol === SSOProtocol.OAuth2) { + getUserOAuth2SSOUrl(invite_token).then((res) => { + document.cookie = `user-oauth2-state=${res.state}` + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else { + Toast.notify({ + type: 'error', + message: 'invalid SSO protocol', + }) + setIsLoading(false) + } + } + + return ( + + ) +} + +export default SSOAuth diff --git a/web/app/signin/forms.tsx b/web/app/signin/forms.tsx deleted file mode 100644 index 70a34c26fa..0000000000 --- a/web/app/signin/forms.tsx +++ /dev/null @@ -1,34 +0,0 @@ -'use client' -import React from 'react' -import { useSearchParams } from 'next/navigation' - -import NormalForm from './normalForm' -import OneMoreStep from './oneMoreStep' -import cn from '@/utils/classnames' - -const Forms = () => { - const searchParams = useSearchParams() - const step = searchParams.get('step') - - const getForm = () => { - switch (step) { - case 'next': - return - default: - return - } - } - return
-
- {getForm()} -
-
-} - -export default Forms diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx new file mode 100644 index 0000000000..2138399ec3 --- /dev/null +++ b/web/app/signin/invite-settings/page.tsx @@ -0,0 +1,154 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { useCallback, useState } from 'react' +import Link from 'next/link' +import { useContext } from 'use-context-selector' +import { useRouter, useSearchParams } from 'next/navigation' +import useSWR from 'swr' +import { RiAccountCircleLine } from '@remixicon/react' +import Input from '@/app/components/base/input' +import { SimpleSelect } from '@/app/components/base/select' +import Button from '@/app/components/base/button' +import { timezones } from '@/utils/timezone' +import { LanguagesSupported, languages } from '@/i18n/language' +import I18n from '@/context/i18n' +import { activateMember, invitationCheck } from '@/service/common' +import Loading from '@/app/components/base/loading' +import Toast from '@/app/components/base/toast' + +export default function InviteSettingsPage() { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const token = decodeURIComponent(searchParams.get('invite_token') as string) + const { locale, setLocaleOnClient } = useContext(I18n) + const [name, setName] = useState('') + const [language, setLanguage] = useState(LanguagesSupported[0]) + const [timezone, setTimezone] = useState(Intl.DateTimeFormat().resolvedOptions().timeZone || 'America/Los_Angeles') + + const checkParams = { + url: '/activate/check', + params: { + token, + }, + } + const { data: checkRes, mutate: recheck } = useSWR(checkParams, invitationCheck, { + revalidateOnFocus: false, + }) + + const handleActivate = useCallback(async () => { + try { + if (!name) { + Toast.notify({ type: 'error', message: t('login.enterYourName') }) + return + } + const res = await activateMember({ + url: '/activate', + body: { + 
token, + name, + interface_language: language, + timezone, + }, + }) + if (res.result === 'success') { + localStorage.setItem('console_token', res.data.access_token) + localStorage.setItem('refresh_token', res.data.refresh_token) + setLocaleOnClient(language, false) + router.replace('/apps') + } + } + catch { + recheck() + } + }, [language, name, recheck, setLocaleOnClient, timezone, token, router, t]) + + if (!checkRes) + return + if (!checkRes.is_valid) { + return
+
+
🤷‍♂️
+

{t('login.invalid')}

+
+ +
+ } + + return
+
+ +
+
+

{t('login.setYourAccount')}

+
+
+ +
+ +
+ setName(e.target.value)} + placeholder={t('login.namePlaceholder') || ''} + /> +
+
+
+ +
+ item.supported)} + onSelect={(item) => { + setLanguage(item.value as string) + }} + /> +
+
+ {/* timezone */} +
+ +
+ { + setTimezone(item.value as string) + }} + /> +
+
+
+ +
+
+
+ {t('login.license.tip')} +   + {t('login.license.link')} +
+
+} diff --git a/web/app/signin/layout.tsx b/web/app/signin/layout.tsx new file mode 100644 index 0000000000..342876bc53 --- /dev/null +++ b/web/app/signin/layout.tsx @@ -0,0 +1,54 @@ +import Script from 'next/script' +import Header from './_header' +import style from './page.module.css' + +import cn from '@/utils/classnames' +import { IS_CE_EDITION } from '@/config' + +export default async function SignInLayout({ children }: any) { + return <> + {!IS_CE_EDITION && ( + <> + + + + )} + +
+
+
+
+
+ {children} +
+
+
+ © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. +
+
+
+ +} diff --git a/web/app/signin/userSSOForm.tsx b/web/app/signin/userSSOForm.tsx deleted file mode 100644 index f01afa9eaf..0000000000 --- a/web/app/signin/userSSOForm.tsx +++ /dev/null @@ -1,107 +0,0 @@ -'use client' -import { useRouter, useSearchParams } from 'next/navigation' -import type { FC } from 'react' -import { useEffect, useState } from 'react' -import { useTranslation } from 'react-i18next' -import cn from '@/utils/classnames' -import Toast from '@/app/components/base/toast' -import { getUserOAuth2SSOUrl, getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso' -import Button from '@/app/components/base/button' -import useRefreshToken from '@/hooks/use-refresh-token' - -type UserSSOFormProps = { - protocol: string -} - -const UserSSOForm: FC = ({ - protocol, -}) => { - const { getNewAccessToken } = useRefreshToken() - const searchParams = useSearchParams() - const consoleToken = searchParams.get('access_token') - const refreshToken = searchParams.get('refresh_token') - const message = searchParams.get('message') - - const router = useRouter() - const { t } = useTranslation() - - const [isLoading, setIsLoading] = useState(false) - - useEffect(() => { - if (refreshToken && consoleToken) { - localStorage.setItem('console_token', consoleToken) - localStorage.setItem('refresh_token', refreshToken) - getNewAccessToken() - router.replace('/apps') - } - - if (message) { - Toast.notify({ - type: 'error', - message, - }) - } - }, [consoleToken, refreshToken, message, router]) - - const handleSSOLogin = () => { - setIsLoading(true) - if (protocol === 'saml') { - getUserSAMLSSOUrl().then((res) => { - router.push(res.url) - }).finally(() => { - setIsLoading(false) - }) - } - else if (protocol === 'oidc') { - getUserOIDCSSOUrl().then((res) => { - document.cookie = `user-oidc-state=${res.state}` - router.push(res.url) - }).finally(() => { - setIsLoading(false) - }) - } - else if (protocol === 'oauth2') { - getUserOAuth2SSOUrl().then((res) => { - document.cookie = 
`user-oauth2-state=${res.state}` - router.push(res.url) - }).finally(() => { - setIsLoading(false) - }) - } - else { - Toast.notify({ - type: 'error', - message: 'invalid SSO protocol', - }) - setIsLoading(false) - } - } - - return ( -
-
-
-

{t('login.pageTitle')}

-
-
- -
-
-
- ) -} - -export default UserSSOForm diff --git a/web/tailwind-common-config.ts b/web/tailwind-common-config.ts index 9e800750a3..35fd22e0a4 100644 --- a/web/tailwind-common-config.ts +++ b/web/tailwind-common-config.ts @@ -83,6 +83,11 @@ const config = { fontSize: { '2xs': '0.625rem', }, + backgroundImage: { + 'chatbot-bg': 'var(--color-chatbot-bg)', + 'chat-bubble-bg': 'var(--color-chat-bubble-bg)', + 'workflow-process-bg': 'var(--color-workflow-process-bg)', + }, animation: { 'spin-slow': 'spin 2s linear infinite', },