diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index b8246aacb3..c6fe87264d 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -6,6 +6,7 @@ on: - main paths: - api/migrations/** + - .github/workflows/db-migration-test.yml concurrency: group: db-migration-test-${{ github.ref }} diff --git a/api/.env.example b/api/.env.example index afa7d6c799..14d68a56b9 100644 --- a/api/.env.example +++ b/api/.env.example @@ -285,8 +285,9 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 -# Model Configuration +# Model configuration MULTIMODAL_SEND_IMAGE_FORMAT=base64 +MULTIMODAL_SEND_VIDEO_FORMAT=base64 PROMPT_GENERATION_MAX_TOKENS=512 CODE_GENERATION_MAX_TOKENS=1024 @@ -324,10 +325,10 @@ UNSTRUCTURED_API_KEY= SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= SSRF_DEFAULT_MAX_RETRIES=3 -SSRF_DEFAULT_TIME_OUT= -SSRF_DEFAULT_CONNECT_TIME_OUT= -SSRF_DEFAULT_READ_TIME_OUT= -SSRF_DEFAULT_WRITE_TIME_OUT= +SSRF_DEFAULT_TIME_OUT=5 +SSRF_DEFAULT_CONNECT_TIME_OUT=5 +SSRF_DEFAULT_READ_TIME_OUT=5 +SSRF_DEFAULT_WRITE_TIME_OUT=5 BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database diff --git a/api/app.py b/api/app.py index ed214bde97..60cd622ef4 100644 --- a/api/app.py +++ b/api/app.py @@ -2,7 +2,7 @@ import os from configs import dify_config -if os.environ.get("DEBUG", "false").lower() != "true": +if not dify_config.DEBUG: from gevent import monkey monkey.patch_all() diff --git a/api/app_factory.py b/api/app_factory.py index aba78ccab8..60a584798b 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,6 +1,8 @@ import os -if os.environ.get("DEBUG", "false").lower() != "true": +from configs import dify_config + +if not dify_config.DEBUG: from gevent import monkey monkey.patch_all() diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index c9308f8c11..bd69656c1f 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -329,6 +329,16 @@ class HttpConfig(BaseSettings): default=1 * 1024 * 1024, ) + SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field( + description="Maximum number of retries for network requests (SSRF)", + default=3, + ) + + SSRF_PROXY_ALL_URL: Optional[str] = Field( + description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)", + default=None, + ) + SSRF_PROXY_HTTP_URL: Optional[str] = Field( description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)", default=None, @@ -677,12 +687,17 @@ class IndexingConfig(BaseSettings): ) -class ImageFormatConfig(BaseSettings): +class VisionFormatConfig(BaseSettings): MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", default="base64", ) + MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field( + description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64", + default="base64", + ) + class CeleryBeatConfig(BaseSettings): CELERY_BEAT_SCHEDULER_TIME: int = Field( @@ -787,7 +802,7 @@ class FeatureConfig( FileAccessConfig, FileUploadConfig, HttpConfig, - ImageFormatConfig, + VisionFormatConfig, InnerAPIConfig, IndexingConfig, LoggingConfig, diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 521805a651..3c91a58f8b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -956,7 +956,7 @@ class DocumentRetryApi(DocumentResource): raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception as e: - logging.error(f"Document {document_id} retry failed: {str(e)}") + logging.exception(f"Document {document_id} retry failed: {str(e)}") continue # retry document DocumentService.retry_document(dataset_id, retry_documents) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 527ef4ecd3..815fd6a27a 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -7,7 +7,11 @@ from controllers.service_api import api from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom -from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from fields.conversation_fields import ( + conversation_delete_fields, + conversation_infinite_scroll_pagination_fields, + simple_conversation_fields, +) from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -49,7 +53,7 @@ class ConversationApi(Resource): class ConversationDetailApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(simple_conversation_fields) + @marshal_with(conversation_delete_fields) def delete(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -58,10 +62,9 @@ class ConversationDetailApi(Resource): conversation_id = str(c_id) try: - ConversationService.delete(app_model, conversation_id, end_user) + return ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 200 class ConversationRenameApi(Resource): diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 18c5526c53..ea974acded 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -1,6 +1,5 @@ import contextvars import logging -import os import threading import uuid from collections.abc import Generator @@ -10,6 +9,7 @@ from flask import Flask, current_app from pydantic import ValidationError import contexts +from configs import dify_config from constants import UUID_NIL from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -328,7 +328,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG", "false").lower() == "true": + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 1fc7ffe2c7..1d4c0ea0fa 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -242,7 +242,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc start_listener_time = time.time() yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) except Exception as e: - logger.error(e) + logger.exception(e) break if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index b2b161cdca..d439bedcb5 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -1,5 +1,4 @@ import logging -import os import threading import uuid from collections.abc import Generator @@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from configs import dify_config from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -235,7 +235,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 12bcb5a777..a73c351818 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,5 +1,4 @@ import logging -import os import threading import uuid from collections.abc import Generator @@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from configs import dify_config from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 7fb05192c7..551df32032 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,5 +1,4 @@ import logging -import os import threading import uuid from collections.abc import Generator @@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from configs import dify_config from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom @@ -213,7 +213,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 9e7591545d..09fd033e35 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -1,6 +1,5 @@ import contextvars import logging -import os import threading import uuid from collections.abc import Generator, Mapping, Sequence @@ -10,6 +9,7 @@ from flask import Flask, current_app from pydantic import ValidationError import contexts +from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom @@ -273,7 +273,7 @@ class WorkflowAppGenerator(BaseAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true": + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index d119d94a61..aaa4824fe8 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -216,7 +216,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa else: yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) except Exception as e: - logger.error(e) + logger.exception(e) break if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index b69d7a74c0..ff9220d35f 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -3,7 +3,7 @@ import base64 from configs import dify_config from core.file import file_repository from core.helper import ssrf_proxy -from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent +from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent from extensions.ext_database import db from extensions.ext_storage import storage @@ -71,6 +71,12 @@ def to_prompt_message_content(f: File, /): if f.extension is None: raise ValueError("Missing file extension") return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) + case FileType.VIDEO: + if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url": + data = _to_url(f) + else: + data = _to_base64_data_string(f) + return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) case _: raise ValueError(f"file type {f.type} is not supported") @@ -112,7 +118,7 @@ def _download_file_content(path: str, /): def _get_encoded_string(f: File, /): match f.transfer_method: case FileTransferMethod.REMOTE_URL: - response = ssrf_proxy.get(f.remote_url) + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() content = response.content encoded_string = base64.b64encode(content).decode("utf-8") @@ -140,6 +146,8 @@ def _file_to_encoded_string(f: File, /): match f.type: case FileType.IMAGE: return _to_base64_data_string(f) + case FileType.VIDEO: + return _to_base64_data_string(f) case FileType.AUDIO: return _get_encoded_string(f) case _: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index df812ca83f..374bd9d57b 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -3,26 +3,20 @@ Proxy requests to avoid SSRF """ import logging -import os import time import httpx -SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "") -SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") -SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") -SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) -SSRF_DEFAULT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_TIME_OUT", "5")) -SSRF_DEFAULT_CONNECT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_CONNECT_TIME_OUT", "5")) -SSRF_DEFAULT_READ_TIME_OUT = float(os.getenv("SSRF_DEFAULT_READ_TIME_OUT", "5")) -SSRF_DEFAULT_WRITE_TIME_OUT = float(os.getenv("SSRF_DEFAULT_WRITE_TIME_OUT", "5")) +from configs import dify_config + +SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES proxy_mounts = ( { - "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL), - "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL), + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL), + "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL), } - if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL + if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL else None ) @@ -38,17 +32,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "timeout" not in kwargs: kwargs["timeout"] = httpx.Timeout( - SSRF_DEFAULT_TIME_OUT, - connect=SSRF_DEFAULT_CONNECT_TIME_OUT, - read=SSRF_DEFAULT_READ_TIME_OUT, - write=SSRF_DEFAULT_WRITE_TIME_OUT, + timeout=dify_config.SSRF_DEFAULT_TIME_OUT, + connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, + read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, + write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, ) retries = 0 while retries <= max_retries: try: - if SSRF_PROXY_ALL_URL: - with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client: + if dify_config.SSRF_PROXY_ALL_URL: + with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client: response = client.request(method=method, url=url, **kwargs) elif proxy_mounts: with httpx.Client(mounts=proxy_mounts) as client: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index dc95e4b509..29d47fc104 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,8 +1,8 @@ import logging -import os from collections.abc import Callable, Generator, Iterable, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload +from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration @@ -509,7 +509,7 @@ class LBModelManager: continue - if bool(os.environ.get("DEBUG", "False").lower() == "true"): + if dify_config.DEBUG: logger.info( f"Model LB\nid: {config.id}\nname:{config.name}\n" f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index b3eb4d4dfe..f5d4427e3e 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -12,11 +12,13 @@ from .message_entities import ( TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, + VideoPromptMessageContent, ) from .model_entities import ModelPropertyKey __all__ = [ "ImagePromptMessageContent", + "VideoPromptMessageContent", "PromptMessage", "PromptMessageRole", "LLMUsage", diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index cda1639661..3c244d368e 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -56,6 +56,7 @@ class PromptMessageContentType(Enum): TEXT = "text" IMAGE = "image" AUDIO = "audio" + VIDEO = "video" class PromptMessageContent(BaseModel): @@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent): type: PromptMessageContentType = PromptMessageContentType.TEXT +class VideoPromptMessageContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.VIDEO + data: str = Field(..., description="Base64 encoded video data") + format: str = Field(..., description="Video format") + + class AudioPromptMessageContent(PromptMessageContent): type: PromptMessageContentType = PromptMessageContentType.AUDIO data: str = Field(..., description="Base64 encoded audio data") diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index d8d794be18..83f4d2d57d 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -126,6 +126,6 @@ class OutputModeration(BaseModel): result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) return result except Exception as e: - logger.error("Moderation Output error: %s", e) + logger.exception("Moderation Output error: %s", e) return None diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 764944f799..986749f056 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -708,7 +708,7 @@ class TraceQueueManager: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception as e: - logging.error(f"Error adding trace task: {e}") + logging.exception(f"Error adding trace task: {e}") finally: self.start_timer() @@ -727,7 +727,7 @@ class TraceQueueManager: if tasks: self.send_to_celery(tasks) except Exception as e: - logging.error(f"Error processing trace tasks: {e}") + logging.exception(f"Error processing trace tasks: {e}") def start_timer(self): global trace_manager_timer diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index 3f88d2ca2b..98da5e3d5e 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector): try: self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() except Exception as e: - logger.error(e) + logger.exception(e) def delete_by_document_id(self, document_id: str): query = f""" diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index abd8261a69..30d7f09ec2 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -79,7 +79,7 @@ class LindormVectorStore(BaseVector): existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False) return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} except Exception as e: - logger.error(f"Error fetching batch {batch_ids}: {e}") + logger.exception(f"Error fetching batch {batch_ids}: {e}") return set() @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) @@ -96,7 +96,7 @@ class LindormVectorStore(BaseVector): ) return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} except Exception as e: - logger.error(f"Error fetching batch {batch_ids}: {e}") + logger.exception(f"Error fetching batch {batch_ids}: {e}") return set() if ids is None: @@ -177,7 +177,7 @@ class LindormVectorStore(BaseVector): else: logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.") except Exception as e: - logger.error(f"Error occurred while deleting the index: {e}") + logger.exception(f"Error occurred while deleting the index: {e}") raise e def text_exists(self, id: str) -> bool: @@ -201,7 +201,7 @@ class LindormVectorStore(BaseVector): try: response = self._client.search(index=self._collection_name, body=query) except Exception as e: - logger.error(f"Error executing search: {e}") + logger.exception(f"Error executing search: {e}") raise docs_and_scores = [] diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 080a1ef567..5a263d6e78 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -86,7 +86,7 @@ class MilvusVector(BaseVector): ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) pks.extend(ids) except MilvusException as e: - logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) + logger.exception("Failed to insert batch starting at entity: %s/%s", i, total_count) raise e return pks diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 1fca926a2d..2610b60a77 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -142,7 +142,7 @@ class MyScaleVector(BaseVector): for r in self._client.query(sql).named_results() ] except Exception as e: - logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") + logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return [] def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 0e0f107268..49eb00f140 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -129,7 +129,7 @@ class OpenSearchVector(BaseVector): if status == 404: logger.warning(f"Document not found for deletion: {doc_id}") else: - logger.error(f"Error deleting document: {error}") + logger.exception(f"Error deleting document: {error}") def delete(self) -> None: self._client.indices.delete(index=self._collection_name.lower()) @@ -158,7 +158,7 @@ class OpenSearchVector(BaseVector): try: response = self._client.search(index=self._collection_name.lower(), body=query) except Exception as e: - logger.error(f"Error executing search: {e}") + logger.exception(f"Error executing search: {e}") raise docs = [] diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index b3e93ce760..3ac65b88bb 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -89,7 +89,7 @@ class CacheEmbedding(Embeddings): db.session.rollback() except Exception as ex: db.session.rollback() - logger.error("Failed to embed documents: %s", ex) + logger.exception("Failed to embed documents: %s", ex) raise ex return text_embeddings diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index d4434ea28f..b59e7f94fd 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -28,7 +28,6 @@ logger = logging.getLogger(__name__) class WordExtractor(BaseExtractor): """Load docx files. - Args: file_path: Path to the file to load. """ @@ -51,9 +50,9 @@ class WordExtractor(BaseExtractor): self.web_path = self.file_path # TODO: use a better way to handle the file - self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115 - self.temp_file.write(r.content) - self.file_path = self.temp_file.name + with tempfile.NamedTemporaryFile(delete=False) as self.temp_file: + self.temp_file.write(r.content) + self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") @@ -230,7 +229,7 @@ class WordExtractor(BaseExtractor): for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: - logger.error(e) + logger.exception(e) def parse_paragraph(paragraph): paragraph_content = [] diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index b1249a0ff5..daa5d0242d 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -98,7 +98,7 @@ class ToolFileManager: response.raise_for_status() blob = response.content except Exception as e: - logger.error(f"Failed to download file from {file_url}: {e}") + logger.exception(f"Failed to download file from {file_url}: {e}") raise mimetype = guess_type(file_url)[0] or "octet/stream" diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 83a2e8ef0f..fc3699a0bc 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -526,7 +526,7 @@ class ToolManager: yield provider except Exception as e: - logger.error(f"load builtin provider error: {e}") + logger.exception(f"load builtin provider {provider} error: {e}") continue # set builtin providers loaded cls._builtin_providers_loaded = True diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index 722cf4b538..ea28037df0 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -127,7 +127,9 @@ class FeishuRequest: "folder_token": folder_token, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" @@ -135,7 +137,7 @@ class FeishuRequest: res = self._send_request(url, payload=payload) return res - def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> dict: + def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str: """ API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content Example Response: @@ -154,7 +156,9 @@ class FeishuRequest: } url = f"{self.API_BASE_URL}/document/get_document_content" res = self._send_request(url, method="GET", params=params) - return res.get("data").get("content") + if "data" in res: + return res.get("data").get("content") + return "" def list_document_blocks( self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500 @@ -170,7 +174,9 @@ class FeishuRequest: } url = f"{self.API_BASE_URL}/document/list_document_blocks" res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: """ @@ -186,7 +192,9 @@ class FeishuRequest: "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } res = self._send_request(url, params=params, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: url = f"{self.API_BASE_URL}/message/send_webhook_message" @@ -220,7 +228,9 @@ class FeishuRequest: "page_size": page_size, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def get_thread_messages( self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20 @@ -236,7 +246,9 @@ class FeishuRequest: "page_size": page_size, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: # 创建任务 @@ -249,7 +261,9 @@ class FeishuRequest: "description": description, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def update_task( self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str @@ -265,7 +279,9 @@ class FeishuRequest: "description": description, } res = self._send_request(url, method="PATCH", payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def delete_task(self, task_guid: str) -> dict: # 删除任务 @@ -297,7 +313,9 @@ class FeishuRequest: "page_size": page_size, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: url = f"{self.API_BASE_URL}/calendar/get_primary_calendar" @@ -305,7 +323,9 @@ class FeishuRequest: "user_id_type": user_id_type, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def create_event( self, @@ -328,7 +348,9 @@ class FeishuRequest: "attendee_ability": attendee_ability, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def update_event( self, @@ -374,7 +396,9 @@ class FeishuRequest: "page_size": page_size, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def search_events( self, @@ -395,7 +419,9 @@ class FeishuRequest: "page_size": page_size, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: # 参加日程参会人 @@ -406,7 +432,9 @@ class FeishuRequest: "need_notification": need_notification, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def create_spreadsheet( self, @@ -420,7 +448,9 @@ class FeishuRequest: "folder_token": folder_token, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def get_spreadsheet( self, @@ -434,7 +464,9 @@ class FeishuRequest: "user_id_type": user_id_type, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def list_spreadsheet_sheets( self, @@ -446,7 +478,9 @@ class FeishuRequest: "spreadsheet_token": spreadsheet_token, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def add_rows( self, @@ -466,7 +500,9 @@ class FeishuRequest: "values": values, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def add_cols( self, @@ -486,7 +522,9 @@ class FeishuRequest: "values": values, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def read_rows( self, @@ -508,7 +546,9 @@ class FeishuRequest: "user_id_type": user_id_type, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def read_cols( self, @@ -530,7 +570,9 @@ class FeishuRequest: "user_id_type": user_id_type, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def read_table( self, @@ -552,7 +594,9 @@ class FeishuRequest: "user_id_type": user_id_type, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def create_base( self, @@ -566,7 +610,9 @@ class FeishuRequest: "folder_token": folder_token, } res = self._send_request(url, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def add_records( self, @@ -588,7 +634,9 @@ class FeishuRequest: "records": convert_add_records(records), } res = self._send_request(url, params=params, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def update_records( self, @@ -610,7 +658,9 @@ class FeishuRequest: "records": convert_update_records(records), } res = self._send_request(url, params=params, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def delete_records( self, @@ -637,7 +687,9 @@ class FeishuRequest: "records": record_id_list, } res = self._send_request(url, params=params, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def search_record( self, @@ -701,7 +753,10 @@ class FeishuRequest: if automatic_fields: payload["automatic_fields"] = automatic_fields res = self._send_request(url, params=params, payload=payload) - return res.get("data") + + if "data" in res: + return res.get("data") + return res def get_base_info( self, @@ -713,7 +768,9 @@ class FeishuRequest: "app_token": app_token, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def create_table( self, @@ -741,7 +798,9 @@ class FeishuRequest: if default_view_name: payload["default_view_name"] = default_view_name res = self._send_request(url, params=params, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def delete_tables( self, @@ -774,8 +833,11 @@ class FeishuRequest: "table_ids": table_id_list, "table_names": table_name_list, } + res = self._send_request(url, params=params, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res def list_tables( self, @@ -791,7 +853,9 @@ class FeishuRequest: "page_size": page_size, } res = self._send_request(url, method="GET", params=params) - return res.get("data") + if "data" in res: + return res.get("data") + return res def read_records( self, @@ -819,4 +883,6 @@ class FeishuRequest: "user_id_type": user_id_type, } res = self._send_request(url, method="GET", params=params, payload=payload) - return res.get("data") + if "data" in res: + return res.get("data") + return res diff --git a/api/core/tools/utils/lark_api_utils.py b/api/core/tools/utils/lark_api_utils.py new file mode 100644 index 0000000000..30cb0cb141 --- /dev/null +++ b/api/core/tools/utils/lark_api_utils.py @@ -0,0 +1,820 @@ +import json +from typing import Optional + +import httpx + +from core.tools.errors import ToolProviderCredentialValidationError +from extensions.ext_redis import redis_client + + +def lark_auth(credentials): + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") + if not app_id or not app_secret: + raise ToolProviderCredentialValidationError("app_id and app_secret is required") + try: + assert LarkRequest(app_id, app_secret).tenant_access_token is not None + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) + + +class LarkRequest: + API_BASE_URL = "https://lark-plugin-api.solutionsuite.ai/lark-plugin" + + def __init__(self, app_id: str, app_secret: str): + self.app_id = app_id + self.app_secret = app_secret + + def convert_add_records(self, json_str): + try: + data = json.loads(json_str) + if not isinstance(data, list): + raise ValueError("Parsed data must be a list") + converted_data = [{"fields": json.dumps(item, ensure_ascii=False)} for item in data] + return converted_data + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + except Exception as e: + raise ValueError(f"An error occurred while processing the data: {e}") + + def convert_update_records(self, json_str): + try: + data = json.loads(json_str) + if not isinstance(data, list): + raise ValueError("Parsed data must be a list") + + converted_data = [ + {"fields": json.dumps(record["fields"], ensure_ascii=False), "record_id": record["record_id"]} + for record in data + if "fields" in record and "record_id" in record + ] + + if len(converted_data) != len(data): + raise ValueError("Each record must contain 'fields' and 'record_id'") + + return converted_data + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + except Exception as e: + raise ValueError(f"An error occurred while processing the data: {e}") + + @property + def tenant_access_token(self) -> str: + feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" + if redis_client.exists(feishu_tenant_access_token): + return redis_client.get(feishu_tenant_access_token).decode() + res = self.get_tenant_access_token(self.app_id, self.app_secret) + redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) + if "tenant_access_token" in res: + return res.get("tenant_access_token") + return "" + + def _send_request( + self, + url: str, + method: str = "post", + require_token: bool = True, + payload: Optional[dict] = None, + params: Optional[dict] = None, + ): + headers = { + "Content-Type": "application/json", + "user-agent": "Dify", + } + if require_token: + headers["tenant-access-token"] = f"{self.tenant_access_token}" + res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json() + if res.get("code") != 0: + raise Exception(res) + return res + + def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: + url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" + payload = {"app_id": app_id, "app_secret": app_secret} + res = self._send_request(url, require_token=False, payload=payload) + return res + + def create_document(self, title: str, content: str, folder_token: str) -> dict: + url = f"{self.API_BASE_URL}/document/create_document" + payload = { + "title": title, + "content": content, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def write_document(self, document_id: str, content: str, position: str = "end") -> dict: + url = f"{self.API_BASE_URL}/document/write_document" + payload = {"document_id": document_id, "content": content, "position": position} + res = self._send_request(url, payload=payload) + return res + + def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict: + params = { + "document_id": document_id, + "mode": mode, + "lang": lang, + } + url = f"{self.API_BASE_URL}/document/get_document_content" + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data").get("content") + return "" + + def list_document_blocks( + self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500 + ) -> dict: + params = { + "user_id_type": user_id_type, + "document_id": document_id, + "page_size": page_size, + "page_token": page_token, + } + url = f"{self.API_BASE_URL}/document/list_document_blocks" + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: + url = f"{self.API_BASE_URL}/message/send_bot_message" + params = { + "receive_id_type": receive_id_type, + } + payload = { + "receive_id": receive_id, + "msg_type": msg_type, + "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: + url = f"{self.API_BASE_URL}/message/send_webhook_message" + payload = { + "webhook": webhook, + "msg_type": msg_type, + "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), + } + res = self._send_request(url, require_token=False, payload=payload) + return res + + def get_chat_messages( + self, + container_id: str, + start_time: str, + end_time: str, + page_token: str, + sort_type: str = "ByCreateTimeAsc", + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/message/get_chat_messages" + params = { + "container_id": container_id, + "start_time": start_time, + "end_time": end_time, + "sort_type": sort_type, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def get_thread_messages( + self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20 + ) -> dict: + url = f"{self.API_BASE_URL}/message/get_thread_messages" + params = { + "container_id": container_id, + "sort_type": sort_type, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: + url = f"{self.API_BASE_URL}/task/create_task" + payload = { + "summary": summary, + "start_time": start_time, + "end_time": end_time, + "completed_at": completed_time, + "description": description, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_task( + self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str + ) -> dict: + url = f"{self.API_BASE_URL}/task/update_task" + payload = { + "task_guid": task_guid, + "summary": summary, + "start_time": start_time, + "end_time": end_time, + "completed_time": completed_time, + "description": description, + } + res = self._send_request(url, method="PATCH", payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_task(self, task_guid: str) -> dict: + url = f"{self.API_BASE_URL}/task/delete_task" + payload = { + "task_guid": task_guid, + } + res = self._send_request(url, method="DELETE", payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: + url = f"{self.API_BASE_URL}/task/add_members" + payload = { + "task_guid": task_guid, + "member_phone_or_email": member_phone_or_email, + "member_role": member_role, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: + url = f"{self.API_BASE_URL}/wiki/get_wiki_nodes" + payload = { + "space_id": space_id, + "parent_node_token": parent_node_token, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: + url = f"{self.API_BASE_URL}/calendar/get_primary_calendar" + params = { + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_event( + self, + summary: str, + description: str, + start_time: str, + end_time: str, + attendee_ability: str, + need_notification: bool = True, + auto_record: bool = False, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/create_event" + payload = { + "summary": summary, + "description": description, + "need_notification": need_notification, + "start_time": start_time, + "end_time": end_time, + "auto_record": auto_record, + "attendee_ability": attendee_ability, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_event( + self, + event_id: str, + summary: str, + description: str, + need_notification: bool, + start_time: str, + end_time: str, + auto_record: bool, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" + payload = {} + if summary: + payload["summary"] = summary + if description: + payload["description"] = description + if start_time: + payload["start_time"] = start_time + if end_time: + payload["end_time"] = end_time + if need_notification: + payload["need_notification"] = need_notification + if auto_record: + payload["auto_record"] = auto_record + res = self._send_request(url, method="PATCH", payload=payload) + return res + + def delete_event(self, event_id: str, need_notification: bool = True) -> dict: + url = f"{self.API_BASE_URL}/calendar/delete_event/{event_id}" + params = { + "need_notification": need_notification, + } + res = self._send_request(url, method="DELETE", params=params) + return res + + def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: + url = f"{self.API_BASE_URL}/calendar/list_events" + params = { + "start_time": start_time, + "end_time": end_time, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def search_events( + self, + query: str, + start_time: str, + end_time: str, + page_token: str, + user_id_type: str = "open_id", + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/search_events" + payload = { + "query": query, + "start_time": start_time, + "end_time": end_time, + "page_token": page_token, + "user_id_type": user_id_type, + "page_size": page_size, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: + url = f"{self.API_BASE_URL}/calendar/add_event_attendees" + payload = { + "event_id": event_id, + "attendee_phone_or_email": attendee_phone_or_email, + "need_notification": need_notification, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def create_spreadsheet( + self, + title: str, + folder_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/create_spreadsheet" + payload = { + "title": title, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_spreadsheet( + self, + spreadsheet_token: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/get_spreadsheet" + params = { + "spreadsheet_token": spreadsheet_token, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def list_spreadsheet_sheets( + self, + spreadsheet_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/list_spreadsheet_sheets" + params = { + "spreadsheet_token": spreadsheet_token, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def add_rows( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + length: int, + values: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/add_rows" + payload = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "length": length, + "values": values, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_cols( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + length: int, + values: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/add_cols" + payload = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "length": length, + "values": values, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def read_rows( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + start_row: int, + num_rows: int, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/read_rows" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "start_row": start_row, + "num_rows": num_rows, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_cols( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + start_col: int, + num_cols: int, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/read_cols" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "start_col": start_col, + "num_cols": num_cols, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_table( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + num_range: str, + query: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/read_table" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "range": num_range, + "query": query, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_base( + self, + name: str, + folder_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/create_base" + payload = { + "name": name, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_records( + self, + app_token: str, + table_id: str, + table_name: str, + records: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/base/add_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + } + payload = { + "records": self.convert_add_records(records), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_records( + self, + app_token: str, + table_id: str, + table_name: str, + records: str, + user_id_type: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/update_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + } + payload = { + "records": self.convert_update_records(records), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_records( + self, + app_token: str, + table_id: str, + table_name: str, + record_ids: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/delete_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + } + if not record_ids: + record_id_list = [] + else: + try: + record_id_list = json.loads(record_ids) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "records": record_id_list, + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def search_record( + self, + app_token: str, + table_id: str, + table_name: str, + view_id: str, + field_names: str, + sort: str, + filters: str, + page_token: str, + automatic_fields: bool = False, + user_id_type: str = "open_id", + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/base/search_record" + + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + "page_token": page_token, + "page_size": page_size, + } + + if not field_names: + field_name_list = [] + else: + try: + field_name_list = json.loads(field_names) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not sort: + sort_list = [] + else: + try: + sort_list = json.loads(sort) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not filters: + filter_dict = {} + else: + try: + filter_dict = json.loads(filters) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + payload = {} + + if view_id: + payload["view_id"] = view_id + if field_names: + payload["field_names"] = field_name_list + if sort: + payload["sort"] = sort_list + if filters: + payload["filter"] = filter_dict + if automatic_fields: + payload["automatic_fields"] = automatic_fields + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_base_info( + self, + app_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/get_base_info" + params = { + "app_token": app_token, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_table( + self, + app_token: str, + table_name: str, + default_view_name: str, + fields: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/create_table" + params = { + "app_token": app_token, + } + if not fields: + fields_list = [] + else: + try: + fields_list = json.loads(fields) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "name": table_name, + "fields": fields_list, + } + if default_view_name: + payload["default_view_name"] = default_view_name + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_tables( + self, + app_token: str, + table_ids: str, + table_names: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/delete_tables" + params = { + "app_token": app_token, + } + if not table_ids: + table_id_list = [] + else: + try: + table_id_list = json.loads(table_ids) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not table_names: + table_name_list = [] + else: + try: + table_name_list = json.loads(table_names) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + payload = { + "table_ids": table_id_list, + "table_names": table_name_list, + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def list_tables( + self, + app_token: str, + page_token: str, + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/base/list_tables" + params = { + "app_token": app_token, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_records( + self, + app_token: str, + table_id: str, + table_name: str, + record_ids: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/base/read_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + } + if not record_ids: + record_id_list = [] + else: + try: + record_id_list = json.loads(record_ids) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "record_ids": record_id_list, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="POST", params=params, payload=payload) + if "data" in res: + return res.get("data") + return res diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 053a339ba7..1433c8eaed 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -69,7 +69,7 @@ class BaseNode(Generic[GenericNodeData]): try: result = self._run() except Exception as e: - logger.error(f"Node {self.node_id} failed to run: {e}") + logger.exception(f"Node {self.node_id} failed to run: {e}") result = NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index d90dfcc766..80b322b068 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -97,15 +97,6 @@ class Executor: headers = self.variable_pool.convert_template(self.node_data.headers).text self.headers = _plain_text_to_dict(headers) - body = self.node_data.body - if body is None: - return - if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: - self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - if body.type == "form-data": - self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" - self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" - def _init_body(self): body = self.node_data.body if body is not None: @@ -154,9 +145,8 @@ class Executor: for k, v in files.items() if v.related_id is not None } - self.data = form_data - self.files = files + self.files = files or None def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.auth) @@ -217,6 +207,7 @@ class Executor: "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), "follow_redirects": True, } + # request_args = {k: v for k, v in request_args.items() if v is not None} response = getattr(ssrf_proxy, self.method)(**request_args) return response @@ -244,6 +235,13 @@ class Executor: raw += f"Host: {url_parts.netloc}\r\n" headers = self._assembling_headers() + body = self.node_data.body + boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" + if body: + if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: + headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + if body.type == "form-data": + headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" for k, v in headers.items(): if self.auth.type == "api-key": authorization_header = "Authorization" @@ -256,7 +254,6 @@ class Executor: body = "" if self.files: - boundary = self.boundary for k, v in self.files.items(): body += f"--{boundary}\r\n" body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' @@ -271,7 +268,6 @@ class Executor: elif self.data and self.node_data.body.type == "x-www-form-urlencoded": body = urlencode(self.data) elif self.data and self.node_data.body.type == "form-data": - boundary = self.boundary for key, value in self.data.items(): body += f"--{boundary}\r\n" body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 47b0e25d9c..eb4d1c9d87 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -14,6 +14,7 @@ from core.model_runtime.entities import ( PromptMessage, PromptMessageContentType, TextPromptMessageContent, + VideoPromptMessageContent, ) from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.model_entities import ModelType @@ -560,7 +561,9 @@ class LLMNode(BaseNode[LLMNodeData]): # cuz vision detail is related to the configuration from FileUpload feature. content_item.detail = vision_detail prompt_message_content.append(content_item) - elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent): + elif isinstance( + content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent + ): prompt_message_content.append(content_item) if len(prompt_message_content) > 1: diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 0489020e5e..744dfd3d8d 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -127,7 +127,7 @@ class QuestionClassifierNode(LLMNode): category_id = category_id_result except OutputParserError: - logging.error(f"Failed to parse result text: {result_text}") + logging.exception(f"Failed to parse result text: {result_text}") try: process_data = { "model_mode": model_config.mode, diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 67635b129e..58c917dbd3 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,3 +1,4 @@ +import posixpath from collections.abc import Generator import oss2 as aliyun_s3 @@ -50,9 +51,4 @@ class AliyunOssStorage(BaseStorage): self.client.delete_object(self.__wrapper_folder_filename(filename)) def __wrapper_folder_filename(self, filename) -> str: - if self.folder: - if self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - return filename + return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 2eb19c2667..5bd21be807 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -202,6 +202,10 @@ simple_conversation_fields = { "updated_at": TimestampField, } +conversation_delete_fields = { + "result": fields.String, +} + conversation_infinite_scroll_pagination_fields = { "limit": fields.Integer, "has_more": fields.Boolean, diff --git a/api/libs/smtp.py b/api/libs/smtp.py index bd7de7dd68..d57d99f3b7 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -39,13 +39,13 @@ class SMTPClient: smtp.sendmail(self._from, mail["to"], msg.as_string()) except smtplib.SMTPException as e: - logging.error(f"SMTP error occurred: {str(e)}") + logging.exception(f"SMTP error occurred: {str(e)}") raise except TimeoutError as e: - logging.error(f"Timeout occurred while sending email: {str(e)}") + logging.exception(f"Timeout occurred while sending email: {str(e)}") raise except Exception as e: - logging.error(f"Unexpected error occurred while sending email: {str(e)}") + logging.exception(f"Unexpected error occurred while sending email: {str(e)}") raise finally: if smtp: diff --git a/api/pyproject.toml b/api/pyproject.toml index 1f2758b544..f32add9e43 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -34,6 +34,7 @@ select = [ "RUF101", # redirected-noqa "S506", # unsafe-yaml-load "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception "UP", # pyupgrade rules "W191", # tab-indentation "W605", # invalid-escape-sequence diff --git a/api/services/account_service.py b/api/services/account_service.py index 24472c349a..68687eb3d2 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -821,7 +821,7 @@ class RegisterService: db.session.rollback() except Exception as e: db.session.rollback() - logging.error(f"Register failed: {e}") + logging.exception(f"Register failed: {e}") raise AccountRegisterError(f"Registration failed: {e}") from e return account diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 7bfe59afa0..f9e41988c0 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -160,4 +160,5 @@ class ConversationService: conversation = cls.get_conversation(app_model, conversation_id, user) conversation.is_deleted = True + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 53e599d5ab..e7903fc4eb 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -195,7 +195,7 @@ class ApiToolManageService: # try to parse schema, avoid SSRF attack ApiToolManageService.parser_api_schema(schema) except Exception as e: - logger.error(f"parse api schema error: {str(e)}") + logger.exception(f"parse api schema error: {str(e)}") raise ValueError("invalid schema, please check the url you provided") return {"schema": schema} diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 8c0302861e..8d5ddeb7a3 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -196,8 +196,7 @@ class ToolTransformService: username = user.name except Exception as e: - logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") - + logger.exception(f"failed to get user name for api provider {db_provider.id}: {str(e)}") # add provider into providers credentials = db_provider.credentials result = ToolProviderApiEntity( diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index 12c469a81a..7c19de6078 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -196,3 +196,72 @@ def test_extract_selectors_from_template_with_newline(): ) assert executor.params == {"test": "line1\nline2"} + + +def test_executor_with_form_data(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") + variable_pool.add(["pre_node_id", "number_field"], 42) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test Form Data", + method="post", + url="https://api.example.com/upload", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: multipart/form-data", + params="", + body=HttpRequestNodeBody( + type="form-data", + data=[ + BodyData( + key="text_field", + type="text", + value="{{#pre_node_id.text_field#}}", + ), + BodyData( + key="number_field", + type="text", + value="{{#pre_node_id.number_field#}}", + ), + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/upload" + assert "Content-Type" in executor.headers + assert "multipart/form-data" in executor.headers["Content-Type"] + assert executor.params == {} + assert executor.json is None + assert executor.files is None + assert executor.content is None + + # Check that the form data is correctly loaded in executor.data + assert isinstance(executor.data, dict) + assert "text_field" in executor.data + assert executor.data["text_field"] == "Hello, World!" + assert "number_field" in executor.data + assert executor.data["number_field"] == "42" + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /upload HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: multipart/form-data" in raw_request + assert "text_field" in raw_request + assert "Hello, World!" in raw_request + assert "number_field" in raw_request + assert "42" in raw_request diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 263230d049..02e23429ce 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -1115,7 +1115,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/retrieve" - targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ "query": "test", "retrieval_model": { "search_method": "keyword_search", diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index 9c25d1e7bb..e5d5f56120 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -1116,7 +1116,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/retrieve" - targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ "query": "test", "retrieval_model": { "search_method": "keyword_search", diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 639cb2fad1..2bb11a870c 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -468,8 +468,8 @@ const Configuration: FC = () => { transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], }, enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled), - allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], - allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), + allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video], + allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`), allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, fileUploadConfig: fileUploadConfigResponse, diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 05aaaa6bc2..32d841148a 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -1,6 +1,5 @@ import { useCallback, - useRef, useState, } from 'react' import Textarea from 'rc-textarea' @@ -63,7 +62,6 @@ const ChatInputArea = ({ isMultipleLine, } = useTextAreaHeight() const [query, setQuery] = useState('') - const isUseInputMethod = useRef(false) const [showVoiceInput, setShowVoiceInput] = useState(false) const filesStore = useFileStore() const { @@ -95,20 +93,11 @@ const ChatInputArea = ({ } } - const handleKeyUp = (e: React.KeyboardEvent) => { - if (e.key === 'Enter') { - e.preventDefault() - // prevent send message when using input method enter - if (!e.shiftKey && !isUseInputMethod.current) - handleSend() - } - } - const handleKeyDown = (e: React.KeyboardEvent) => { - isUseInputMethod.current = e.nativeEvent.isComposing - if (e.key === 'Enter' && !e.shiftKey) { - setQuery(query.replace(/\n$/, '')) + if (e.key === 'Enter' && !e.shiftKey && !e.nativeEvent.isComposing) { e.preventDefault() + setQuery(query.replace(/\n$/, '')) + handleSend() } } @@ -165,7 +154,6 @@ const ChatInputArea = ({ setQuery(e.target.value) handleTextareaResize() }} - onKeyUp={handleKeyUp} onKeyDown={handleKeyDown} onPaste={handleClipboardPasteFile} onDragEnter={handleDragFileEnter} diff --git a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx index d580c00102..4b24bcb931 100644 --- a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx @@ -120,7 +120,7 @@ const ConfigCredential: FC = ({ setTempCredential({ ...tempCredential, api_key_header: e.target.value })} - className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' + className='w-full h-10 px-3 text-sm font-normal border border-transparent bg-gray-100 rounded-lg grow outline-none focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs' placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!} /> @@ -129,7 +129,7 @@ const ConfigCredential: FC = ({ setTempCredential({ ...tempCredential, api_key_value: e.target.value })} - className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' + className='w-full h-10 px-3 text-sm font-normal border border-transparent bg-gray-100 rounded-lg grow outline-none focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs' placeholder={t('tools.createTool.authMethod.types.apiValuePlaceholder')!} /> diff --git a/web/app/components/tools/edit-custom-collection-modal/get-schema.tsx b/web/app/components/tools/edit-custom-collection-modal/get-schema.tsx index 7b0244e3e3..2552c67568 100644 --- a/web/app/components/tools/edit-custom-collection-modal/get-schema.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/get-schema.tsx @@ -70,7 +70,7 @@ const GetSchema: FC = ({
setImportUrl(e.target.value)} @@ -89,7 +89,7 @@ const GetSchema: FC = ({
)} -
+