From 0d2f7dd6884b1d51d364c277c178fccfad602e9f Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 10 Mar 2025 19:30:22 +0800 Subject: [PATCH] fix metadata --- api/core/rag/entities/metadata_entities.py | 45 ++++++++ api/core/rag/retrieval/dataset_retrieval.py | 101 ++++++++++++------ .../knowledge_retrieval_node.py | 72 ++++++++----- api/services/external_knowledge_service.py | 4 +- 4 files changed, 165 insertions(+), 57 deletions(-) create mode 100644 api/core/rag/entities/metadata_entities.py diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py new file mode 100644 index 0000000000..9ff32b98c0 --- /dev/null +++ b/api/core/rag/entities/metadata_entities.py @@ -0,0 +1,45 @@ +from collections.abc import Sequence +from typing import Literal, Optional + +from pydantic import BaseModel, Field + +SupportedComparisonOperator = Literal[ + # for string or array + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", + # for number + "=", + "≠", + ">", + "<", + "≥", + "≤", + # for time + "before", + "after", +] + + +class Condition(BaseModel): + """ + Condition detail 
+ """ + + logical_operator: Optional[Literal["and", "or"]] = "and" + conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index b52a24dee7..45e520d323 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -7,6 +7,8 @@ from collections.abc import Generator, Mapping from typing import Any, Optional, Union, cast from flask import Flask, current_app +from sqlalchemy import Integer, and_, or_ +from sqlalchemy import cast as sqlalchemy_cast from core.app.app_config.entities import ( DatasetEntity, @@ -34,6 +36,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.context_entities import DocumentContext +from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -144,7 +147,7 @@ class DatasetRetrieval: else: inputs = {} available_datasets_ids = [dataset.id for dataset in available_datasets] - metadata_filter_document_ids = self._get_metadata_filter_condition( + metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( available_datasets_ids, query, tenant_id, @@ -154,6 +157,7 @@ class DatasetRetrieval: retrieve_config.metadata_filtering_conditions, inputs, ) + all_documents = [] user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: @@ -169,6 +173,7 @@ class DatasetRetrieval: planning_strategy, message_id, metadata_filter_document_ids, 
+ metadata_condition, ) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: all_documents = self.multiple_retrieve( @@ -186,6 +191,7 @@ class DatasetRetrieval: retrieve_config.reranking_enabled or True, message_id, metadata_filter_document_ids, + metadata_condition, ) dify_documents = [item for item in all_documents if item.provider == "dify"] @@ -279,6 +285,7 @@ class DatasetRetrieval: planning_strategy: PlanningStrategy, message_id: Optional[str] = None, metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, + metadata_condition: Optional[MetadataCondition] = None, ): tools = [] for dataset in available_datasets: @@ -319,6 +326,7 @@ class DatasetRetrieval: dataset_id=dataset_id, query=query, external_retrieval_parameters=dataset.retrieval_model, + metadata_condition=metadata_condition, ) for external_document in external_documents: document = Document( @@ -333,11 +341,15 @@ class DatasetRetrieval: document.metadata["dataset_name"] = dataset.name results.append(document) else: + if metadata_condition and not metadata_filter_document_ids: + return [] document_ids_filter = None if metadata_filter_document_ids: document_ids = metadata_filter_document_ids.get(dataset.id, []) if document_ids: document_ids_filter = document_ids + else: + return [] retrieval_model_config = dataset.retrieval_model or default_retrieval_model # get top k @@ -395,6 +407,7 @@ class DatasetRetrieval: reranking_enable: bool = True, message_id: Optional[str] = None, metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, + metadata_condition: Optional[MetadataCondition] = None, ): if not available_datasets: return [] @@ -435,10 +448,15 @@ class DatasetRetrieval: for dataset in available_datasets: index_type = dataset.indexing_technique document_ids_filter = None - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids + if 
dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: + continue + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue retrieval_thread = threading.Thread( target=self._retriever, kwargs={ @@ -448,6 +466,7 @@ class DatasetRetrieval: "top_k": top_k, "all_documents": all_documents, "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, }, ) threads.append(retrieval_thread) @@ -529,7 +548,7 @@ class DatasetRetrieval: db.session.commit() def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, - document_ids_filter: Optional[list[str]] = None): + document_ids_filter: Optional[list[str]] = None, metadata_condition: Optional[MetadataCondition] = None): with flask_app.app_context(): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() @@ -542,6 +561,7 @@ class DatasetRetrieval: dataset_id=dataset_id, query=query, external_retrieval_parameters=dataset.retrieval_model, + metadata_condition=metadata_condition, ) for external_document in external_documents: document = Document( @@ -781,43 +801,61 @@ class DatasetRetrieval: metadata_model_config: ModelConfig, metadata_filtering_conditions: Optional[MetadataFilteringCondition], inputs: dict, - ) -> Optional[dict[str, list[str]]]: + ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: document_query = db.session.query(DatasetDocument).filter( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) + filters = [] + metadata_condition = None if metadata_filtering_mode == "disabled": - return None + return None, None elif metadata_filtering_mode == "automatic": automatic_metadata_filters = self._automatic_metadata_filter_func( dataset_ids, query, 
tenant_id, user_id, metadata_model_config ) if automatic_metadata_filters: + conditions = [] for filter in automatic_metadata_filters: - document_query = self._process_metadata_filter_func( - filter.get("condition"), filter.get("metadata_name"), filter.get("value"), document_query + self._process_metadata_filter_func( + filter.get("condition"), filter.get("metadata_name"), filter.get("value"), filters ) + conditions.append(Condition( + name=filter.get("metadata_name"), + comparison_operator=filter.get("condition"), + value=filter.get("value"), + )) + metadata_condition = MetadataCondition( + logical_operator=metadata_filtering_conditions.logical_operator if metadata_filtering_conditions else "and", + conditions=conditions, + ) elif metadata_filtering_mode == "manual": if metadata_filtering_conditions: + metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) for condition in metadata_filtering_conditions.conditions: metadata_name = condition.name expected_value = condition.value if expected_value: if isinstance(expected_value, str): expected_value = self._replace_metadata_filter_value(expected_value, inputs) - document_query = self._process_metadata_filter_func( - condition.comparison_operator, metadata_name, expected_value, document_query + filters = self._process_metadata_filter_func( + condition.comparison_operator, metadata_name, expected_value, filters ) else: raise ValueError("Invalid metadata filtering mode") - documnents = document_query.all() + if filters: + if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "or": + document_query = document_query.filter(or_(*filters)) + else: + document_query = document_query.filter(and_(*filters)) + documents = document_query.all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) - for document in documnents: + for document in documents: metadata_filter_document_ids[document.dataset_id].append(document.id) - return metadata_filter_document_ids + return metadata_filter_document_ids, metadata_condition def 
_replace_metadata_filter_value(self, text: str, inputs: dict) -> str: def replacer(match): @@ -882,41 +920,42 @@ class DatasetRetrieval: return None return automatic_metadata_filters - def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, query): + def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, filters: list): match condition: case "contains": - query = query.filter(DatasetDocument.doc_metadata[metadata_name].like(f'"%{value}%"')) + filters.append(DatasetDocument.doc_metadata[metadata_name].like(f'"%{value}%"')) case "not contains": - query = query.filter(DatasetDocument.doc_metadata[metadata_name].notlike(f'"%{value}%"')) + filters.append(DatasetDocument.doc_metadata[metadata_name].notlike(f'"%{value}%"')) case "start with": - query = query.filter(DatasetDocument.doc_metadata[metadata_name].like(f'"{value}%"')) + filters.append(DatasetDocument.doc_metadata[metadata_name].like(f'"{value}%"')) + case "end with": - query = query.filter(DatasetDocument.doc_metadata[metadata_name].like(f'"%{value}"')) + filters.append(DatasetDocument.doc_metadata[metadata_name].like(f'"%{value}"')) case "is" | "=": if isinstance(value, str): - query = query.filter(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"') + filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"') else: - query = query.filter(DatasetDocument.doc_metadata[metadata_name] == value) + filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value) case "is not" | "≠": if isinstance(value, str): - query = query.filter(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"') + filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"') else: - query = query.filter(DatasetDocument.doc_metadata[metadata_name] != value) + filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value) case "is empty": - query = 
query.filter(DatasetDocument.doc_metadata[metadata_name].is_(None)) + filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None)) case "is not empty": - query = query.filter(DatasetDocument.doc_metadata[metadata_name].isnot(None)) + filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None)) case "before" | "<": - query = query.filter(DatasetDocument.doc_metadata[metadata_name] < value) + filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value) case "after" | ">": - query = query.filter(DatasetDocument.doc_metadata[metadata_name] > value) + filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value) case "≤" | ">=": - query = query.filter(DatasetDocument.doc_metadata[metadata_name] <= value) + filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value) case "≥" | ">=": - query = query.filter(DatasetDocument.doc_metadata[metadata_name] >= value) + filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value) case _: pass - return query + return filters def _fetch_model_config( self, tenant_id: str, model: ModelConfig diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e935cbb15b..fedf00458b 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Optional, cast -from sqlalchemy import func +from sqlalchemy import and_, func, or_ from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -16,6 +16,7 @@ from core.model_runtime.entities.model_entities 
import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment @@ -135,7 +136,7 @@ class KnowledgeRetrievalNode(LLMNode): if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids = self._get_metadata_filter_condition( + metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( [dataset.id for dataset in available_datasets], query, node_data ) all_documents = [] @@ -168,6 +169,7 @@ class KnowledgeRetrievalNode(LLMNode): model_instance=model_instance, planning_strategy=planning_strategy, metadata_filter_document_ids=metadata_filter_document_ids, + metadata_condition=metadata_condition, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: if node_data.multiple_retrieval_config is None: @@ -215,6 +217,7 @@ class KnowledgeRetrievalNode(LLMNode): weights=weights, reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_filter_document_ids=metadata_filter_document_ids, + metadata_condition=metadata_condition, ) dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] @@ -283,7 +286,7 @@ class KnowledgeRetrievalNode(LLMNode): def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> Optional[dict[str, list[str]]]: + ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: document_query = db.session.query(Document).filter( 
Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", @@ -291,33 +294,53 @@ Document.archived == False, ) + filters = [] + metadata_condition = None if node_data.metadata_filtering_mode == "disabled": - return None + return None, None elif node_data.metadata_filtering_mode == "automatic": automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) if automatic_metadata_filters: + conditions = [] + filters = [] for filter in automatic_metadata_filters: - document_query = self._process_metadata_filter_func( - filter.get("condition"), filter.get("metadata_name"), filter.get("value"), document_query + self._process_metadata_filter_func( + filter.get("condition"), filter.get("metadata_name"), filter.get("value"), filters ) + conditions.append(Condition( + name=filter.get("metadata_name"), + comparison_operator=filter.get("condition"), + value=filter.get("value"), + )) + metadata_condition = MetadataCondition( + logical_operator="or", + conditions=conditions, + ) elif node_data.metadata_filtering_mode == "manual": if node_data.metadata_filtering_conditions: + metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) for condition in node_data.metadata_filtering_conditions.conditions: metadata_name = condition.name expected_value = condition.value if expected_value: if isinstance(expected_value, str): expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).text + + filters = self._process_metadata_filter_func( + condition.comparison_operator, metadata_name, expected_value, filters ) else: raise ValueError("Invalid metadata filtering mode") + if filters: + if node_data.metadata_filtering_conditions and node_data.metadata_filtering_conditions.logical_operator == "and": + document_query = document_query.filter(and_(*filters)) + else: + document_query = document_query.filter(or_(*filters)) documnents = document_query.all() # group by 
dataset_id metadata_filter_document_ids = defaultdict(list) for document in documnents: metadata_filter_document_ids[document.dataset_id].append(document.id) - return metadata_filter_document_ids + return metadata_filter_document_ids, metadata_condition def _automatic_metadata_filter_func( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData @@ -382,41 +403,42 @@ class KnowledgeRetrievalNode(LLMNode): return [] return automatic_metadata_filters - def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, query): + def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, filters: list): match condition: case "contains": - query = query.filter(Document.doc_metadata[metadata_name].like(f'"%{value}%"')) + filters.append(Document.doc_metadata[metadata_name].like(f'"%{value}%"')) case "not contains": - query = query.filter(Document.doc_metadata[metadata_name].notlike(f'"%{value}%"')) + filters.append(Document.doc_metadata[metadata_name].notlike(f'"%{value}%"')) case "start with": - query = query.filter(Document.doc_metadata[metadata_name].like(f'"{value}%"')) + filters.append(Document.doc_metadata[metadata_name].like(f'"{value}%"')) case "end with": - query = query.filter(Document.doc_metadata[metadata_name].like(f'"%{value}"')) + filters.append(Document.doc_metadata[metadata_name].like(f'"%{value}"')) case "=" | "is": if isinstance(value, str): - query = query.filter(Document.doc_metadata[metadata_name] == f'"{value}"') + filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') else: - query = query.filter(Document.doc_metadata[metadata_name] == value) + filters.append(Document.doc_metadata[metadata_name] == value) case "is not" | "≠": if isinstance(value, str): - query = query.filter(Document.doc_metadata[metadata_name] != f'"{value}"') + filters.append(Document.doc_metadata[metadata_name] != f'"{value}"') else: - query = query.filter(Document.doc_metadata[metadata_name] != 
value) + filters.append(Document.doc_metadata[metadata_name] != value) case "is empty": - query = query.filter(Document.doc_metadata[metadata_name].is_(None)) + filters.append(Document.doc_metadata[metadata_name].is_(None)) case "is not empty": - query = query.filter(Document.doc_metadata[metadata_name].isnot(None)) + filters.append(Document.doc_metadata[metadata_name].isnot(None)) case "before" | "<": - query = query.filter(Document.doc_metadata[metadata_name] < value) + filters.append(Document.doc_metadata[metadata_name] < value) case "after" | ">": - query = query.filter(Document.doc_metadata[metadata_name] > value) + filters.append(Document.doc_metadata[metadata_name] > value) case "≤" | ">=": - query = query.filter(Document.doc_metadata[metadata_name] <= value) + filters.append(Document.doc_metadata[metadata_name] <= value) case "≥" | ">=": - query = query.filter(Document.doc_metadata[metadata_name] >= value) + filters.append(Document.doc_metadata[metadata_name] >= value) case _: pass - return query + return filters + @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 8916a951c7..822b458189 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -8,6 +8,7 @@ import validators from constants import HIDDEN_VALUE from core.helper import ssrf_proxy +from core.rag.entities.metadata_entities import MetadataCondition from extensions.ext_database import db from models.dataset import ( Dataset, @@ -245,7 +246,7 @@ class ExternalDatasetService: @staticmethod def fetch_external_knowledge_retrieval( - tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict + tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict, metadata_condition: Optional[MetadataCondition] = None ) -> list: external_knowledge_binding = 
ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id @@ -272,6 +273,7 @@ class ExternalDatasetService: }, "query": query, "knowledge_id": external_knowledge_binding.external_knowledge_id, + "metadata_condition": metadata_condition.model_dump() if metadata_condition else None, } response = ExternalDatasetService.process_external_api(