From 86d3fff666adb627abf433dd45bd0d673c38263c Mon Sep 17 00:00:00 2001 From: kurokobo Date: Tue, 18 Mar 2025 15:37:07 +0900 Subject: [PATCH 1/4] fix: respect resolution settings for vision for basic chatbot, text generator, and parameter extractor node (#16041) --- api/core/app/apps/base_app_runner.py | 9 ++++++- api/core/app/apps/chat/app_runner.py | 13 ++++++++++ api/core/app/apps/completion/app_runner.py | 13 ++++++++++ api/core/prompt/advanced_prompt_transform.py | 25 +++++++++++++++---- api/core/prompt/simple_prompt_transform.py | 23 +++++++++++++---- .../parameter_extractor_node.py | 12 +++++++++ 6 files changed, 84 insertions(+), 11 deletions(-) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 07a248d77a..8c6b29731e 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -17,7 +17,11 @@ from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, +) from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.moderation.input_moderation import InputModeration @@ -141,6 +145,7 @@ class AppRunner: query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages @@ -167,6 +172,7 @@ class AppRunner: context=context, memory=memory, model_config=model_config, + image_detail_config=image_detail_config, ) else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) @@ -201,6 +207,7 @@ class AppRunner: memory_config=memory_config, memory=memory, model_config=model_config, + image_detail_config=image_detail_config, ) stop = model_config.stop diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 425f1ab7ef..46c8031633 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,6 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db @@ -50,6 +51,16 @@ class ChatAppRunner(AppRunner): query = application_generate_entity.query files = application_generate_entity.files + image_detail_config = ( + application_generate_entity.file_upload_config.image_config.detail + if ( + application_generate_entity.file_upload_config + and application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + # Pre-calculate the number of tokens of the 
prompt messages, # and return the rest number of tokens by model context token size limit and max token size limit. # If the rest number of tokens is not enough, raise exception. @@ -85,6 +96,7 @@ class ChatAppRunner(AppRunner): files=files, query=query, memory=memory, + image_detail_config=image_detail_config, ) # moderation @@ -182,6 +194,7 @@ class ChatAppRunner(AppRunner): query=query, context=context, memory=memory, + image_detail_config=image_detail_config, ) # check hosting moderation diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 41278b75b4..0ed06c9c98 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db @@ -43,6 +44,16 @@ class CompletionAppRunner(AppRunner): query = application_generate_entity.query files = application_generate_entity.files + image_detail_config = ( + application_generate_entity.file_upload_config.image_config.detail + if ( + application_generate_entity.file_upload_config + and application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + # Pre-calculate the number of tokens of the prompt messages, # and return the rest number of tokens by model context token size limit and max token size limit. # If the rest number of tokens is not enough, raise exception. 
@@ -66,6 +77,7 @@ class CompletionAppRunner(AppRunner): inputs=inputs, files=files, query=query, + image_detail_config=image_detail_config, ) # moderation @@ -140,6 +152,7 @@ class CompletionAppRunner(AppRunner): files=files, query=query, context=context, + image_detail_config=image_detail_config, ) # check hosting moderation diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 87c7a79fb0..c7427f797e 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -46,6 +46,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> list[PromptMessage]: prompt_messages = [] @@ -59,6 +60,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, memory=memory, model_config=model_config, + image_detail_config=image_detail_config, ) elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): prompt_messages = self._get_chat_model_prompt_messages( @@ -70,6 +72,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, memory=memory, model_config=model_config, + image_detail_config=image_detail_config, ) return prompt_messages @@ -84,6 +87,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> list[PromptMessage]: """ Get completion model prompt messages. @@ -124,7 +128,9 @@ class AdvancedPromptTransform(PromptTransform): prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file)) + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: @@ -142,6 +148,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> list[PromptMessage]: """ Get chat model prompt messages. 
@@ -197,7 +204,9 @@ class AdvancedPromptTransform(PromptTransform): prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) for file in files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file)) + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) @@ -209,19 +218,25 @@ class AdvancedPromptTransform(PromptTransform): # get last user message content and add files prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] for file in files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file)) + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) last_message.content = prompt_message_contents else: prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file)) + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file)) + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) elif query: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index e75877de9b..421b14e0df 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.file import file_manager from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, PromptMessage, PromptMessageContent, SystemPromptMessage, @@ -60,6 +61,7 @@ class SimplePromptTransform(PromptTransform): context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} @@ -74,6 +76,7 @@ class SimplePromptTransform(PromptTransform): context=context, memory=memory, model_config=model_config, + image_detail_config=image_detail_config, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -85,6 +88,7 @@ class SimplePromptTransform(PromptTransform): context=context, memory=memory, model_config=model_config, + image_detail_config=image_detail_config, ) return prompt_messages, stops @@ -175,6 +179,7 @@ class SimplePromptTransform(PromptTransform): files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages: list[PromptMessage] = [] @@ -204,9 +209,9 @@ class 
SimplePromptTransform(PromptTransform): ) if query: - prompt_messages.append(self.get_last_user_message(query, files)) + prompt_messages.append(self.get_last_user_message(query, files, image_detail_config)) else: - prompt_messages.append(self.get_last_user_message(prompt, files)) + prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config)) return prompt_messages, None @@ -220,6 +225,7 @@ class SimplePromptTransform(PromptTransform): files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( @@ -262,14 +268,21 @@ class SimplePromptTransform(PromptTransform): if stops is not None and len(stops) == 0: stops = None - return [self.get_last_user_message(prompt, files)], stops + return [self.get_last_user_message(prompt, files, image_detail_config)], stops - def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage: + def get_last_user_message( + self, + prompt: str, + files: Sequence["File"], + image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + ) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file)) + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) prompt_message = UserPromptMessage(content=prompt_message_contents) else: diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index e147caacf3..7b1b8cf483 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance +from core.model_runtime.entities import ImagePromptMessageContent from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -129,6 +130,7 @@ class ParameterExtractorNode(LLMNode): model_config=model_config, memory=memory, files=files, + vision_detail=node_data.vision.configs.detail, ) else: # use prompt engineering @@ -139,6 +141,7 @@ class ParameterExtractorNode(LLMNode): model_config=model_config, memory=memory, files=files, + vision_detail=node_data.vision.configs.detail, ) prompt_message_tools = [] @@ -267,6 +270,7 @@ class ParameterExtractorNode(LLMNode): model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], files: Sequence[File], + vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. 
@@ -289,6 +293,7 @@ class ParameterExtractorNode(LLMNode): memory_config=node_data.memory, memory=None, model_config=model_config, + image_detail_config=vision_detail, ) # find last user message @@ -347,6 +352,7 @@ class ParameterExtractorNode(LLMNode): model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], files: Sequence[File], + vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> list[PromptMessage]: """ Generate prompt engineering prompt. @@ -361,6 +367,7 @@ class ParameterExtractorNode(LLMNode): model_config=model_config, memory=memory, files=files, + vision_detail=vision_detail, ) elif model_mode == ModelMode.CHAT: return self._generate_prompt_engineering_chat_prompt( @@ -370,6 +377,7 @@ class ParameterExtractorNode(LLMNode): model_config=model_config, memory=memory, files=files, + vision_detail=vision_detail, ) else: raise InvalidModelModeError(f"Invalid model mode: {model_mode}") @@ -382,6 +390,7 @@ class ParameterExtractorNode(LLMNode): model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], files: Sequence[File], + vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> list[PromptMessage]: """ Generate completion prompt. @@ -402,6 +411,7 @@ class ParameterExtractorNode(LLMNode): memory_config=node_data.memory, memory=memory, model_config=model_config, + image_detail_config=vision_detail, ) return prompt_messages @@ -414,6 +424,7 @@ class ParameterExtractorNode(LLMNode): model_config: ModelConfigWithCredentialsEntity, memory: Optional[TokenBufferMemory], files: Sequence[File], + vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, ) -> list[PromptMessage]: """ Generate chat prompt. @@ -441,6 +452,7 @@ class ParameterExtractorNode(LLMNode): memory_config=node_data.memory, memory=None, model_config=model_config, + image_detail_config=vision_detail, ) # find last user message From 750ec55646e913bdea41abdeb2d6db3d24e0ef44 Mon Sep 17 00:00:00 2001 From: yihong Date: Tue, 18 Mar 2025 14:57:14 +0800 Subject: [PATCH 2/4] doc: auto correct the doc using autocorrect close #16091 (#16092) Signed-off-by: yihong0618 --- CONTRIBUTING_CN.md | 4 ++-- README_CN.md | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 52606c7896..d7199bd928 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -26,7 +26,7 @@ | [@jyong](https://github.com/JohnJyong) | RAG 流水线设计 | | [@GarfieldDai](https://github.com/GarfieldDai) | 构建 workflow 编排 | | [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 让我们的前端更易用 | - | [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验, 综合事项联系人 | + | [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验,综合事项联系人 | | [@takatost](https://github.com/takatost) | 产品整体方向和架构 | 事项优先级: @@ -47,7 +47,7 @@ | ------------------------------------------------------------ | --------------- | | 核心功能的 Bugs(例如无法登录、应用无法工作、安全漏洞) | 紧急 | | 非紧急 bugs, 性能提升 | 中等优先级 | - | 小幅修复(错别字, 能正常工作但存在误导的 UI) | 低优先级 | + | 小幅修复 (错别字,能正常工作但存在误导的 UI) | 低优先级 | ## 安装 diff --git a/README_CN.md b/README_CN.md index 6c57b9f59c..a05ef17365 100644 --- a/README_CN.md +++ b/README_CN.md @@ -79,7 +79,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。 **5. 
Agent 智能体**: - 您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具,如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。 + 您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了 50 多种内置工具,如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。 **6. LLMOps**: 随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。 @@ -112,7 +112,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 仅限 OpenAI - RAG引擎 + RAG 引擎 ✅ ✅ ✅ @@ -234,7 +234,7 @@ docker compose up -d 对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。 -> 我们正在寻找贡献者来帮助将Dify翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 +> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 **Contributors** From 33ba7e659be1bad68dae04f3ee95b2108474d074 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 18 Mar 2025 15:07:29 +0800 Subject: [PATCH 3/4] fix vector db sql injection (#16096) --- .../rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py | 4 ++++ api/core/rag/datasource/vdb/myscale/myscale_vector.py | 2 ++ api/core/rag/datasource/vdb/opengauss/opengauss.py | 6 ++++-- api/core/rag/datasource/vdb/pgvector/pgvector.py | 5 ++++- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 4d8f792941..0b2f4cf6e2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -194,6 +194,8 @@ class AnalyticdbVectorBySql: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") score_threshold = float(kwargs.get("score_threshold") or 0.0) with self._get_cursor() as cur: query_vector_str = json.dumps(query_vector) @@ -220,6 +222,8 @@ class AnalyticdbVectorBySql: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( f"""SELECT id, vector, page_content, metadata_, diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 556b952ec2..3223010966 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -125,6 +125,8 @@ class MyScaleVector(BaseVector): def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") score_threshold = float(kwargs.get("score_threshold") or 0.0) where_str = ( f"WHERE dist < {1 - score_threshold}" diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/core/rag/datasource/vdb/opengauss/opengauss.py index 7b51eb4bd8..2e5b7a31e4 100644 --- a/api/core/rag/datasource/vdb/opengauss/opengauss.py +++ 
b/api/core/rag/datasource/vdb/opengauss/opengauss.py @@ -155,7 +155,8 @@ class OpenGauss(BaseVector): :return: List of Documents that are nearest to the query vector. """ top_k = kwargs.get("top_k", 4) - + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" @@ -174,7 +175,8 @@ class OpenGauss(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 13c214bfd7..b6153a5b09 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -171,6 +171,8 @@ class PGVector(BaseVector): :return: List of Documents that are nearest to the query vector. """ top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( @@ -190,7 +192,8 @@ class PGVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: if self.pg_bigm: cur.execute("SET pg_bigm.similarity_limit TO 0.000001") From 6f6ba2f025a5c2edb759d962d7a20e0117736abf Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 18 Mar 2025 15:07:53 +0800 Subject: [PATCH 4/4] fix(api): enhance provider model records handling for missing langgenius providers (#16089) --- api/core/provider_manager.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 7d8ff18983..099acfd7f4 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -149,6 +149,11 @@ class ProviderManager: provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) + provider_id_entity = ModelProviderID(provider_name) + if provider_id_entity.is_langgenius(): + provider_model_records.extend( + provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, []) + ) # Convert to custom configuration custom_configuration = self._to_custom_configuration( @@ -190,6 +195,20 @@ class ProviderManager: provider_name ) + provider_id_entity = ModelProviderID(provider_name) + + if provider_id_entity.is_langgenius(): + if provider_model_settings is not None: + provider_model_settings.extend( + provider_name_to_provider_model_settings_dict.get(provider_id_entity.provider_name, []) + ) + if provider_load_balancing_configs is not None: + provider_load_balancing_configs.extend( + provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_id_entity.provider_name, [] + ) + ) + # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, @@ -207,7 +226,7 @@ 
class ProviderManager: model_settings=model_settings, ) - provider_configurations[str(ModelProviderID(provider_name))] = provider_configuration + provider_configurations[str(provider_id_entity)] = provider_configuration # Return the encapsulated object return provider_configurations
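A short note on the vector-store hardening in patch 3/4: each touched adapter (AnalyticDB, MyScale, OpenGauss, PGVector) reads top_k from **kwargs before building its query with an f-string, so the added guard rejects anything that is not a positive integer before it can reach the SQL text. The sketch below is a minimal illustration of that guard only — the helper name is hypothetical and the table name is assumed to come from trusted internal configuration, as it does in the real classes; it is not code taken from the patch.

from typing import Any


def build_vector_search_sql(table_name: str, top_k: Any) -> str:
    # The embedding itself stays a %s placeholder so the database driver binds
    # it safely; only top_k is interpolated into the statement text, which is
    # why it must be validated as a positive integer first.
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError("top_k must be a positive integer")
    # table_name is assumed to come from trusted configuration, never from
    # user input, mirroring how the vector-store classes construct it.
    return (
        f"SELECT meta, text, embedding <=> %s AS distance "
        f"FROM {table_name} ORDER BY distance LIMIT {top_k}"
    )


# Usage: a malicious value such as "4; DROP TABLE embeddings" is rejected
# because it is not an int, while build_vector_search_sql("embeddings", 4)
# returns a statement whose only runtime parameter is the query vector.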