From 67f2c766bc01cccd5b74bdcbc1a72aed3e5090e2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 26 Feb 2025 19:56:19 +0800 Subject: [PATCH] dataset metadata update --- .../console/datasets/datasets_segments.py | 4 +- api/controllers/console/datasets/metadata.py | 143 ++++++++++ .../rag/datasource/keyword/jieba/jieba.py | 11 +- api/core/rag/datasource/retrieval_service.py | 19 +- .../vdb/analyticdb/analyticdb_vector.py | 2 +- .../vdb/analyticdb/analyticdb_vector_sql.py | 14 +- .../rag/datasource/vdb/baidu/baidu_vector.py | 20 +- .../datasource/vdb/chroma/chroma_vector.py | 10 +- .../vdb/elasticsearch/elasticsearch_vector.py | 6 + .../datasource/vdb/lindorm/lindorm_vector.py | 12 +- .../datasource/vdb/milvus/milvus_vector.py | 12 + .../datasource/vdb/myscale/myscale_vector.py | 4 + .../vdb/oceanbase/oceanbase_vector.py | 6 + .../vdb/opensearch/opensearch_vector.py | 6 + .../rag/datasource/vdb/oracle/oraclevector.py | 15 +- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 3 + .../rag/datasource/vdb/pgvector/pgvector.py | 12 + .../datasource/vdb/qdrant/qdrant_vector.py | 79 +++--- .../rag/datasource/vdb/relyt/relyt_vector.py | 10 +- .../datasource/vdb/tencent/tencent_vector.py | 7 +- .../tidb_on_qdrant/tidb_on_qdrant_vector.py | 16 ++ .../datasource/vdb/tidb_vector/tidb_vector.py | 6 + .../datasource/vdb/upstash/upstash_vector.py | 14 +- api/core/rag/datasource/vdb/vector_base.py | 2 +- .../vdb/vikingdb/vikingdb_vector.py | 6 +- .../vdb/weaviate/weaviate_vector.py | 32 ++- .../constant/built_in_field.py | 9 + api/core/rag/retrieval/dataset_retrieval.py | 7 + api/core/workflow/nodes/code/code_node.py | 8 +- .../nodes/knowledge_retrieval/entities.py | 45 ++- .../workflow/nodes/knowledge_retrieval/exc.py | 4 + .../knowledge_retrieval_node.py | 265 +++++++++++++++++- api/fields/dataset_fields.py | 3 + api/fields/document_fields.py | 9 + api/models/dataset.py | 38 ++- api/services/dataset_service.py | 49 ++-- .../knowledge_entities/knowledge_entities.py | 25 +- api/services/metadata_service.py | 182 ++++++++++++ api/tasks/update_documents_metadata_task.py | 121 ++++++++ 39 files changed, 1112 insertions(+), 124 deletions(-) create mode 100644 api/controllers/console/datasets/metadata.py create mode 100644 api/core/rag/index_processor/constant/built_in_field.py create mode 100644 api/services/metadata_service.py create mode 100644 api/tasks/update_documents_metadata_task.py diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 19255c618a..d2c94045ad 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -88,9 +88,9 @@ class DatasetDocumentSegmentListApi(Resource): if args["enabled"].lower() != "all": if args["enabled"].lower() == "true": - query = query.filter(DocumentSegment.enabled == True) # noqa: E712 + query = query.filter(DocumentSegment.enabled == True) elif args["enabled"].lower() == "false": - query = query.filter(DocumentSegment.enabled == False) # noqa: E712 + query = query.filter(DocumentSegment.enabled == False) segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py new file mode 100644 index 0000000000..c6f1768ec8 --- /dev/null +++ b/api/controllers/console/datasets/metadata.py @@ -0,0 +1,143 @@ +from flask_login import current_user # type: ignore # type: ignore +from flask_restful import 
Resource, marshal_with, reqparse # type: ignore +from werkzeug.exceptions import NotFound + +from controllers.console import api +from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from fields.dataset_fields import dataset_metadata_fields +from libs.login import login_required +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.knowledge_entities import ( + MetadataArgs, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError("Name must be between 1 to 40 characters.") + return name + + +def _validate_description_length(description): + if len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + +class DatasetListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + @marshal_with(dataset_metadata_fields) + def post(self, dataset_id): + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=True, location="json") + parser.add_argument("name", type=str, required=True, nullable=True, location="json") + args = parser.parse_args() + metadata_args = MetadataArgs(**args) + + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + DatasetService.check_dataset_permission(dataset, current_user) + + metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) + return metadata, 201 + + +class DatasetMetadataApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def patch(self, dataset_id, metadata_id): + parser = reqparse.RequestParser() + parser.add_argument("name", type=str, required=True, nullable=True, location="json") + args = parser.parse_args() + + dataset_id_str = str(dataset_id) + metadata_id_str = str(metadata_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + DatasetService.check_dataset_permission(dataset, current_user) + + metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) + return metadata, 200 + + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def delete(self, dataset_id, metadata_id): + dataset_id_str = str(dataset_id) + metadata_id_str = str(metadata_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + DatasetService.check_dataset_permission(dataset, current_user) + + MetadataService.delete_metadata(dataset_id_str, metadata_id_str) + return 200 + + +class DatasetMetadataBuiltInFieldApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def get(self): + built_in_fields = MetadataService.get_built_in_fields() + return built_in_fields, 200 + + +class DatasetMetadataBuiltInFieldActionApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def post(self, dataset_id, action): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + 
DatasetService.check_dataset_permission(dataset, current_user)
+
+        if action == "enable":
+            MetadataService.enable_built_in_field(dataset)
+        elif action == "disable":
+            MetadataService.disable_built_in_field(dataset)
+        return 200
+
+
+class DocumentMetadataApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def post(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_permission(dataset, current_user)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
+        args = parser.parse_args()
+        metadata_args = MetadataOperationData(**args)
+
+        MetadataService.update_documents_metadata(dataset, metadata_args)
+
+        return 200
+
+
+api.add_resource(DatasetListApi, "/datasets/<uuid:dataset_id>/metadata")
+api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
+api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
+api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
+api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/metadata")
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index 95a2316f1d..d6d0bd88b2 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
         keyword_table = self._get_dataset_keyword_table()
 
         k = kwargs.get("top_k", 4)
-
+        document_ids_filter = kwargs.get("document_ids_filter")
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
 
         documents = []
         for chunk_index in sorted_chunk_indices:
-            segment = (
-                db.session.query(DocumentSegment)
-                .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
-                .first()
+            segment_query = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
             )
+            if document_ids_filter:
+                segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
+            segment = segment_query.first()
 
             if segment:
                 documents.append(
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 927df0efc4..8bb2ed21e5 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -38,6 +38,7 @@ class RetrievalService:
         reranking_model: Optional[dict] = None,
         reranking_mode: str = "reranking_model",
         weights: Optional[dict] = None,
+        document_ids_filter: Optional[list[str]] = None,
     ):
         if not query:
             return []
@@ -61,6 +62,7 @@ class RetrievalService:
                     "top_k": top_k,
                     "all_documents": all_documents,
                     "exceptions": exceptions,
+                    "document_ids_filter": document_ids_filter,
                 },
             )
             threads.append(keyword_thread)
@@ -79,6 +81,7 @@ class RetrievalService:
                     "all_documents": all_documents,
                     "retrieval_method": retrieval_method,
                     "exceptions": exceptions,
+                    "document_ids_filter": document_ids_filter,
                 },
             )
             threads.append(embedding_thread)
@@ -98,6 +101,7 @@ class RetrievalService:
                     "reranking_model": reranking_model,
                     "all_documents": all_documents,
                     "exceptions": exceptions,
+                    "document_ids_filter": document_ids_filter,
                 },
            )
            threads.append(full_text_index_thread)
@@ -135,7 +139,14 @@ class RetrievalService:
 
    @classmethod
    def 
keyword_search( - cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + all_documents: list, + exceptions: list, + document_ids_filter: Optional[list[str]] = None, ): with flask_app.app_context(): try: @@ -145,7 +156,9 @@ class RetrievalService: keyword = Keyword(dataset=dataset) - documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) + documents = keyword.search( + cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter + ) all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @@ -162,6 +175,7 @@ class RetrievalService: all_documents: list, retrieval_method: str, exceptions: list, + document_ids_filter: Optional[list[str]] = None, ): with flask_app.app_context(): try: @@ -177,6 +191,7 @@ class RetrievalService: top_k=top_k, score_threshold=score_threshold, filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, ) if documents: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 603d3fdbcd..b9e488362e 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector): self.analyticdb_vector.delete_by_metadata_field(key, value) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - return self.analyticdb_vector.search_by_vector(query_vector) + return self.analyticdb_vector.search_by_vector(query_vector, **kwargs) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self.analyticdb_vector.search_by_full_text(query, **kwargs) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 4d8f792941..884fc0e3eb 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -194,6 +194,11 @@ class AnalyticdbVectorBySql: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "WHERE 1=1" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"AND metadata_->>'doc_id' IN ({doc_ids})" score_threshold = float(kwargs.get("score_threshold") or 0.0) with self._get_cursor() as cur: query_vector_str = json.dumps(query_vector) @@ -202,7 +207,7 @@ class AnalyticdbVectorBySql: f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, " f"t.page_content as page_content, t.metadata_ AS metadata_ " f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score " - f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t", + f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t", (query_vector_str,), ) documents = [] @@ -220,12 +225,17 @@ class AnalyticdbVectorBySql: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"AND metadata_->>'doc_id' IN ({doc_ids})" with 
self._get_cursor() as cur: cur.execute( f"""SELECT id, vector, page_content, metadata_, ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score FROM {self.table_name} - WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') + WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} ORDER BY score DESC LIMIT {top_k}""", (f"'{query}'", f"'{query}'"), diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index a658495af7..fd29166b1a 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -123,11 +123,21 @@ class BaiduVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] - anns = AnnSearch( - vector_field=self.field_vector, - vector_floats=query_vector, - params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), - ) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + anns = AnnSearch( + vector_field=self.field_vector, + vector_floats=query_vector, + params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), + filter=f"doc_id IN ({doc_ids})", + ) + else: + anns = AnnSearch( + vector_field=self.field_vector, + vector_floats=query_vector, + params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), + ) res = self._db.table(self._collection_name).search( anns=anns, projections=[self.field_id, self.field_text, self.field_metadata], diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 907c4d2285..0cf08363df 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -95,7 +95,15 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) - results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + results: QueryResult = collection.query( + query_embeddings=query_vector, + n_results=kwargs.get("top_k", 4), + where={"doc_id": {"$in": document_ids_filter}}, + ) + else: + results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) # Check if results contain data diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index cca696baee..93f5d8f547 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector): top_k = kwargs.get("top_k", 4) num_candidates = math.ceil(top_k * 1.5) knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + knn["filter"] = {"terms": {"metadata.doc_id": document_ids_filter}} results = self._client.search(index=self._collection_name, knn=knn, size=top_k) @@ 
-145,6 +148,9 @@ class ElasticSearchVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: query_str = {"match": {Field.CONTENT_KEY.value: query}} + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + query_str["filter"] = {"terms": {"metadata.doc_id": document_ids_filter}} results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] for hit in results["hits"]["hits"]: diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 66fba763e7..ace4f56cd0 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector): raise ValueError("All elements in query_vector should be floats") top_k = kwargs.get("top_k", 10) - query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs) + document_ids_filter = kwargs.get("document_ids_filter") + filters = [] + if document_ids_filter: + filters.append({"terms": {"metadata.doc_id": document_ids_filter}}) + query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs) + try: params = {} if self._using_ugc: @@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector): should = kwargs.get("should") minimum_should_match = kwargs.get("minimum_should_match", 0) top_k = kwargs.get("top_k", 10) - filters = kwargs.get("filter") + filters = kwargs.get("filter", []) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + filters.append({"terms": {"metadata.doc_id": document_ids_filter}}) routing = self._routing full_text_query = default_text_search_query( query_text=query, diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 9a184f7dd9..479f0fa279 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -218,12 +218,18 @@ class MilvusVector(BaseVector): """ Search for documents by vector similarity. 
""" + document_ids_filter = kwargs.get("document_ids_filter") + filter = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + filter = f'metadata["doc_id"] in ({doc_ids})' results = self._client.search( collection_name=self._collection_name, data=[query_vector], anns_field=Field.VECTOR.value, limit=kwargs.get("top_k", 4), output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + filter=filter, ) return self._process_search_results( @@ -239,6 +245,11 @@ class MilvusVector(BaseVector): if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") return [] + document_ids_filter = kwargs.get("document_ids_filter") + filter = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + filter = f'metadata["doc_id"] in ({doc_ids})' results = self._client.search( collection_name=self._collection_name, @@ -246,6 +257,7 @@ class MilvusVector(BaseVector): anns_field=Field.SPARSE_VECTOR.value, limit=kwargs.get("top_k", 4), output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + filter=filter, ) return self._process_search_results( diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 556b952ec2..bb4bed4f40 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -131,6 +131,10 @@ class MyScaleVector(BaseVector): if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" ) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_str = f"{where_str} AND metadata['doc_id'] in ({doc_ids})" sql = f""" SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 3c2d53ce78..055eff252c 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector): return [] def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = None + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause = f"metadata->>'$.doc_id' in ({doc_ids})" ef_search = kwargs.get("ef_search", self._hnsw_ef_search) if ef_search != self._hnsw_ef_search: self._client.set_ob_hnsw_ef_search(ef_search) @@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector): distance_func=func.l2_distance, output_column_names=["text", "metadata"], with_dist=True, + where_clause=where_clause, ) docs = [] for text, metadata, distance in cur: diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 72a1502205..7fe8d126af 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector): "size": kwargs.get("top_k", 4), "query": {"knn": {Field.VECTOR.value: 
{Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, } + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + query["query"] = {"terms": {"metadata.doc_id": document_ids_filter}} try: response = self._client.search(index=self._collection_name.lower(), body=query) @@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}} + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + full_text_query["query"]["terms"] = {"metadata.doc_id": document_ids_filter} response = self._client.search(index=self._collection_name.lower(), body=full_text_query) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index a58df7eb9f..e7ffa38668 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -185,10 +185,15 @@ class OracleVector(BaseVector): :return: List of Documents that are nearest to the query vector. """ top_k = kwargs.get("top_k", 4) + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause = f"WHERE metadata->>'doc_id' in ({doc_ids})" with self._get_cursor() as cur: cur.execute( f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" - f" ORDER BY distance fetch first {top_k} rows only", + f" {where_clause} ORDER BY distance fetch first {top_k} rows only", [numpy.array(query_vector)], ) docs = [] @@ -241,9 +246,15 @@ class OracleVector(BaseVector): if token not in stop_words: entities.append(token) with self._get_cursor() as cur: + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause = f" AND metadata->>'doc_id' in ({doc_ids}) " cur.execute( f"select meta, text, embedding FROM {self.table_name}" - f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", + f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} " + f"order by score(1) desc fetch first {top_k} rows only", [" ACCUM ".join(entities)], ) docs = [] diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 221bc68d68..2e520a9efb 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -189,6 +189,9 @@ class PGVectoRS(BaseVector): .limit(kwargs.get("top_k", 4)) .order_by("distance") ) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + stmt = stmt.where(self._table.meta["doc_id"].in_(document_ids_filter)) res = session.execute(stmt) results = [(row[0], row[1]) for row in res] diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index c8a1e4f90c..c51e800862 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -155,10 +155,16 @@ class PGVector(BaseVector): :return: List of Documents that are nearest to the query vector. 
""" top_k = kwargs.get("top_k", 4) + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause = f" WHERE metadata->>'doc_id' in ({doc_ids}) " with self._get_cursor() as cur: cur.execute( f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" + f" {where_clause}" f" ORDER BY distance LIMIT {top_k}", (json.dumps(query_vector),), ) @@ -176,10 +182,16 @@ class PGVector(BaseVector): top_k = kwargs.get("top_k", 5) with self._get_cursor() as cur: + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause = f" AND metadata->>'doc_id' in ({doc_ids}) " cur.execute( f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score FROM {self.table_name} WHERE to_tsvector(text) @@ plainto_tsquery(%s) + {where_clause} ORDER BY score DESC LIMIT {top_k}""", # f"'{query}'" is required in order to account for whitespace in query diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 93dcb280ed..9a9e110b6c 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -286,27 +286,26 @@ class QdrantVector(BaseVector): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse - for node_id in ids: - try: - filter = models.Filter( - must=[ - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchValue(value=node_id), - ), - ], - ) - self._client.delete( - collection_name=self._collection_name, - points_selector=FilterSelector(filter=filter), - ) - except UnexpectedResponse as e: - # Collection does not exist, so return - if e.status_code == 404: - return - # Some other error occurred, so re-raise the exception - else: - raise e + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=ids), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e def text_exists(self, id: str) -> bool: all_collection_name = [] @@ -331,6 +330,14 @@ class QdrantVector(BaseVector): ), ], ) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + filter.must.append( + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=document_ids_filter), + ) + ) results = self._client.search( collection_name=self._collection_name, query_vector=query_vector, @@ -377,6 +384,14 @@ class QdrantVector(BaseVector): ), ] ) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + scroll_filter.must.append( + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=document_ids_filter), + ) + ) response = self._client.scroll( collection_name=self._collection_name, scroll_filter=scroll_filter, @@ -393,28 +408,6 @@ class QdrantVector(BaseVector): return documents - def update_metadata(self, document_id: str, metadata: dict) -> None: - from qdrant_client.http import models - scroll_filter = models.Filter( - must=[ - models.FieldCondition( - key="group_id", - 
match=models.MatchValue(value=self._group_id), - ), - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchValue(value=document_id), - ), - ] - ) - self._client.set_payload( - collection_name=self._collection_name, - filter=scroll_filter, - payload={ - Field.METADATA_KEY.value: metadata, - }, - ) - def _reload_if_needed(self): if isinstance(self._client, QdrantLocal): self._client = cast(QdrantLocal, self._client) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index a3a20448ff..1643abdc71 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -223,8 +223,12 @@ class RelytVector(BaseVector): return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + document_ids_filter = kwargs.get("document_ids_filter") + filter = kwargs.get("filter", {}) + if document_ids_filter: + filter["doc_id"] = document_ids_filter results = self.similarity_search_with_score_by_vector( - k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter") + k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter ) # Organize results. @@ -246,9 +250,9 @@ class RelytVector(BaseVector): filter_condition = "" if filter is not None: conditions = [ - f"metadata->>{key!r} in ({', '.join(map(repr, value))})" + f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})" if len(value) > 1 - else f"metadata->>{key!r} = {value[0]!r}" + else f"metadata->>'{key!r}' = {value[0]!r}" for key, value in filter.items() ] filter_condition = f"WHERE {' AND '.join(conditions)}" diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 1a4fa7b87e..b08dd50fe8 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -145,11 +145,16 @@ class TencentVector(BaseVector): self._db.collection(self._collection_name).delete(document_ids=ids) def delete_by_metadata_field(self, key: str, value: str) -> None: - self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) + self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value]))) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + document_ids_filter = kwargs.get("document_ids_filter") + filter = None + if document_ids_filter: + filter = Filter(Filter.In("metadata.doc_id", document_ids_filter)) res = self._db.collection(self._collection_name).search( vectors=[query_vector], + filter=filter, params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), retrieve_vector=False, limit=kwargs.get("top_k", 4), diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 549f0175eb..f46ce2b1c7 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -326,6 +326,14 @@ class TidbOnQdrantVector(BaseVector): ), ], ) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + filter.must.append( + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=document_ids_filter), + ) + ) results = self._client.search( collection_name=self._collection_name, query_vector=query_vector, @@ -368,6 +376,14 @@ class 
TidbOnQdrantVector(BaseVector): ) ] ) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + scroll_filter.must.append( + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=document_ids_filter), + ) + ) response = self._client.scroll( collection_name=self._collection_name, scroll_filter=scroll_filter, diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 6dd4be65c8..e54de902d8 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -196,6 +196,11 @@ class TiDBVector(BaseVector): docs = [] tidb_dist_func = self._get_distance_func() + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause = f" WHERE meta->>'$.doc_id' in ({doc_ids}) " with Session(self._engine) as session: select_statement = sql_text(f""" @@ -206,6 +211,7 @@ class TiDBVector(BaseVector): text, {tidb_dist_func}(vector, :query_vector_str) AS distance FROM {self._collection_name} + {where_clause} ORDER BY distance ASC LIMIT :top_k ) t diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/core/rag/datasource/vdb/upstash/upstash_vector.py index 5c3fee98a9..0a4bef9f5a 100644 --- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py +++ b/api/core/rag/datasource/vdb/upstash/upstash_vector.py @@ -88,7 +88,19 @@ class UpstashVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) - result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + filter = f"doc_id in ({', '.join(f"'{id}'" for id in document_ids_filter)})" + else: + filter = "" + result = self.index.query( + vector=query_vector, + top_k=top_k, + include_metadata=True, + include_data=True, + include_vectors=False, + filter=filter, + ) docs = [] score_threshold = float(kwargs.get("score_threshold") or 0.0) for record in result: diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 2b10504630..8e5a646a09 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -48,7 +48,7 @@ class BaseVector(ABC): @abstractmethod def delete(self) -> None: raise NotImplementedError - + @abstractmethod def update_metadata(self, document_id: str, metadata: dict) -> None: raise NotImplementedError diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 9de8761a91..7f4c32b9c4 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -177,7 +177,11 @@ class VikingDBVector(BaseVector): query_vector, limit=kwargs.get("top_k", 4) ) score_threshold = float(kwargs.get("score_threshold") or 0.0) - return self._get_search_res(results, score_threshold) + docs = self._get_search_res(results, score_threshold) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + docs = [doc for doc in docs if doc.metadata.get("doc_id") in document_ids_filter] + return docs def _get_search_res(self, results, score_threshold) -> list[Document]: if len(results) == 0: diff --git 
a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 68d043a19f..7038e431d6 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -168,16 +168,16 @@ class WeaviateVector(BaseVector): # check whether the index already exists schema = self._default_schema(self._collection_name) if self._client.schema.contains(schema): - for uuid in ids: - try: - self._client.data_object.delete( - class_name=self._collection_name, - uuid=uuid, - ) - except weaviate.UnexpectedStatusCodeException as e: - # tolerate not found error - if e.status_code != 404: - raise e + try: + self._client.batch.delete_objects( + class_name=self._collection_name, + where={"operator": "ContainsAny", "path": ["id"], "valueTextArray": ids}, + output="minimal", + ) + except weaviate.UnexpectedStatusCodeException as e: + # tolerate not found error + if e.status_code != 404: + raise e def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """Look up similar documents by embedding vector in Weaviate.""" @@ -187,8 +187,10 @@ class WeaviateVector(BaseVector): query_obj = self._client.query.get(collection_name, properties) vector = {"vector": query_vector} - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + where_filter = {"operator": "ContainsAny", "path": ["doc_id"], "valueTextArray": document_ids_filter} + query_obj = query_obj.with_where(where_filter) result = ( query_obj.with_near_vector(vector) .with_limit(kwargs.get("top_k", 4)) @@ -233,8 +235,10 @@ class WeaviateVector(BaseVector): if kwargs.get("search_distance"): content["certainty"] = kwargs.get("search_distance") query_obj = self._client.query.get(collection_name, properties) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + where_filter = {"operator": "ContainsAny", "path": ["doc_id"], "valueTextArray": document_ids_filter} + query_obj = query_obj.with_where(where_filter) query_obj = query_obj.with_additional(["vector"]) properties = ["text"] result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py new file mode 100644 index 0000000000..bcd8ed6370 --- /dev/null +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class BuiltInField(str, Enum): + document_name = "document_name" + uploader = "uploader" + upload_date = "upload_date" + last_update_date = "last_update_date" + source = "source" diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index e1d36aad1f..1f1d291df8 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -239,6 +239,7 @@ class DatasetRetrieval: model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, message_id: Optional[str] = None, + metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, ): tools = [] for dataset in available_datasets: @@ -293,6 +294,11 @@ class DatasetRetrieval: document.metadata["dataset_name"] = dataset.name results.append(document) 
else: + document_ids_filter = None + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids retrieval_model_config = dataset.retrieval_model or default_retrieval_model # get top k @@ -324,6 +330,7 @@ class DatasetRetrieval: reranking_model=reranking_model, reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), weights=retrieval_model_config.get("weights", None), + document_ids_filter=document_ids_filter, ) self._on_query(query, [dataset_id], app_id, user_from, user_id) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 2f82bf8c38..f52835e835 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -195,7 +195,7 @@ class CodeNode(BaseNode[CodeNodeData]): if output_config.type == "object": # check if output is object if not isinstance(result.get(output_name), dict): - if isinstance(result.get(output_name), type(None)): + if result.get(output_name) is None: transformed_result[output_name] = None else: raise OutputValidationError( @@ -223,7 +223,7 @@ class CodeNode(BaseNode[CodeNodeData]): elif output_config.type == "array[number]": # check if array of number available if not isinstance(result[output_name], list): - if isinstance(result[output_name], type(None)): + if result[output_name] is None: transformed_result[output_name] = None else: raise OutputValidationError( @@ -244,7 +244,7 @@ class CodeNode(BaseNode[CodeNodeData]): elif output_config.type == "array[string]": # check if array of string available if not isinstance(result[output_name], list): - if isinstance(result[output_name], type(None)): + if result[output_name] is None: transformed_result[output_name] = None else: raise OutputValidationError( @@ -265,7 +265,7 @@ class CodeNode(BaseNode[CodeNodeData]): elif output_config.type == "array[object]": # check if array of object available if not isinstance(result[output_name], list): - if isinstance(result[output_name], type(None)): + if result[output_name] is None: transformed_result[output_name] = None else: raise OutputValidationError( diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 133af9c838..6f255c5165 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,8 +1,10 @@ +from collections.abc import Sequence from typing import Any, Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm.entities import VisionConfig class RerankingModelConfig(BaseModel): @@ -73,11 +75,44 @@ class SingleRetrievalConfig(BaseModel): model: ModelConfig +SupportedComparisonOperator = Literal[ + # for string or array + "contains", + "not contains", + "starts with", + "ends with", + "is", + "is not", + "empty", + "is not empty", + # for number + "=", + "≠", + ">", + "<", + "≥", + "≤", + # for time + "before", + "after", +] + + +class Condition(BaseModel): + """ + Conditon detail + """ + + metadata_name: str + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | None = None + + class MetadataFilteringCondition(BaseModel): """ Metadata Filtering Condition. 
""" - + logical_operator: Optional[Literal["and", "or"]] = "and" conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) @@ -93,5 +128,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None - metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_filtering_conditions: Optional[dict[str, Any]] = None \ No newline at end of file + metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" + metadata_model_config: Optional[ModelConfig] = None + metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None + vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/knowledge_retrieval/exc.py b/api/core/workflow/nodes/knowledge_retrieval/exc.py index 0c3b6e86fa..6bcdc32790 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/exc.py +++ b/api/core/workflow/nodes/knowledge_retrieval/exc.py @@ -16,3 +16,7 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError): class ModelQuotaExceededError(KnowledgeRetrievalNodeError): """Raised when the model provider quota is exceeded.""" + + +class InvalidModelTypeError(KnowledgeRetrievalNodeError): + """Raised when the model is not a Large Language Model.""" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5153abf0b0..31693d4834 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,6 +1,8 @@ +import json import logging +from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from sqlalchemy import func @@ -9,21 +11,38 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.event import ModelInvokeCompletedEvent +from core.workflow.nodes.knowledge_retrieval.template_prompts import ( + METADATA_FILTER_ASSISTANT_PROMPT_1, + METADATA_FILTER_ASSISTANT_PROMPT_2, + METADATA_FILTER_COMPLETION_PROMPT, + METADATA_FILTER_SYSTEM_PROMPT, + METADATA_FILTER_USER_PROMPT_1, + METADATA_FILTER_USER_PROMPT_3, +) +from 
core.workflow.nodes.list_operator.exc import InvalidConditionError +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from core.workflow.nodes.llm.node import LLMNode +from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2 from extensions.ext_database import db -from models.dataset import Dataset, Document +from libs.json_in_md_parser import parse_and_check_json_markdown +from models.dataset import Dataset, DatasetMetadata, Document from models.workflow import WorkflowNodeExecutionStatus from .entities import KnowledgeRetrievalNodeData from .exc import ( + InvalidModelTypeError, KnowledgeRetrievalNodeError, ModelCredentialsNotInitializedError, ModelNotExistError, @@ -42,13 +61,14 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): +class KnowledgeRetrievalNode(LLMNode): _node_data_cls = KnowledgeRetrievalNodeData _node_type = NodeType.KNOWLEDGE_RETRIEVAL def _run(self) -> NodeRunResult: + node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) if not isinstance(variable, StringSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -63,7 +83,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): ) # retrieve knowledge try: - results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) + results = self._fetch_dataset_retriever(node_data=node_data, query=query) outputs = {"result": results} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs @@ -95,8 +115,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) .filter( Document.indexing_status == "completed", - Document.enabled == True, # noqa: E712 - Document.archived == False, # noqa: E712 + Document.enabled == True, + Document.archived == False, Document.dataset_id.in_(dataset_ids), ) .group_by(Document.dataset_id) @@ -117,6 +137,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): if not dataset: continue available_datasets.append(dataset) + metadata_filter_document_ids = self._get_metadata_filter_condition( + [dataset.id for dataset in available_datasets], query, node_data + ) all_documents = [] dataset_retrieval = DatasetRetrieval() if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: @@ -146,6 +169,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): model_config=model_config, model_instance=model_instance, planning_strategy=planning_strategy, + metadata_filter_document_ids=metadata_filter_document_ids, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: if node_data.multiple_retrieval_config is None: @@ -221,8 +245,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = Document.query.filter( Document.id == segment.document_id, - Document.enabled == True, # noqa: E712 - Document.archived == False, # noqa: E712 + Document.enabled == True, + Document.archived == False, ).first() if dataset and document: source = { @@ -258,6 
+282,134 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                     item["metadata"]["position"] = position
         return retrieval_resource_list
 
+    def _get_metadata_filter_condition(
+        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
+    ) -> Optional[dict[str, list[str]]]:
+        document_query = db.session.query(Document.id, Document.dataset_id).filter(
+            Document.dataset_id.in_(dataset_ids),
+            Document.indexing_status == "completed",
+            Document.enabled == True,
+            Document.archived == False,
+        )
+        if node_data.metadata_filtering_mode == "disabled":
+            return None
+        elif node_data.metadata_filtering_mode == "automatic":
+            automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
+            if automatic_metadata_filters:
+                for filter in automatic_metadata_filters:
+                    document_query = self._process_metadata_filter_func(
+                        condition=filter.get("condition"),
+                        metadata_name=filter.get("metadata_name"),
+                        value=filter.get("value"),
+                        query=document_query,
+                    )
+        elif node_data.metadata_filtering_mode == "manual":
+            for condition in node_data.metadata_filtering_conditions.conditions:
+                metadata_name = condition.metadata_name
+                expected_value = condition.value
+                if isinstance(expected_value, str):
+                    expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).text
+                document_query = self._process_metadata_filter_func(
+                    condition=condition.comparison_operator,
+                    metadata_name=metadata_name,
+                    value=expected_value,
+                    query=document_query,
+                )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
+        documents = document_query.all()
+        # group by dataset_id
+        metadata_filter_document_ids = defaultdict(list)
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)
+        return metadata_filter_document_ids
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
+    ) -> Optional[list[dict[str, Any]]]:
+        # get all metadata fields
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
+        # get metadata model config
+        metadata_model_config = node_data.metadata_model_config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # fetch model config
+        model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config)
+        # fetch prompt messages
+        prompt_template = self._get_prompt_template(
+            node_data=node_data,
+            query=query or "",
+            metadata_fields=all_metadata_fields,
+        )
+        prompt_messages, stop = self._fetch_prompt_messages(
+            prompt_template=prompt_template,
+            sys_query=query,
+            memory=None,
+            model_config=model_config,
+            sys_files=[],
+            vision_enabled=node_data.vision.enabled,
+            vision_detail=node_data.vision.configs.detail,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
+
+        result_text = ""
+        try:
+            # handle invoke result
+            generator = self._invoke_llm(
+                node_data_model=metadata_model_config,
+                model_instance=model_instance,
+                prompt_messages=prompt_messages,
+                stop=stop,
+            )
+
+            for event in generator:
+                if isinstance(event, ModelInvokeCompletedEvent):
+                    result_text = event.text
+                    break
+
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
"metadata_name": item.get("metadata_field_name"), + "value": item.get("metadata_field_value"), + "condition": item.get("comparison_operator"), + } + ) + except Exception as e: + return None + return automatic_metadata_filters + + def _process_metadata_filter_func(*, condition: str, metadata_name: str, value: str, query): + match condition: + case "contains": + query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}%")) + case "not contains": + query = query.filter(Document.doc_metadata[metadata_name].notlike(f"%{value}%")) + case "start with": + query = query.filter(Document.doc_metadata[metadata_name].like(f"{value}%")) + case "end with": + query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}")) + case "is", "=": + query = query.filter(Document.doc_metadata[metadata_name] == value) + case "is not", "≠": + query = query.filter(Document.doc_metadata[metadata_name] != value) + case "is empty": + query = query.filter(Document.doc_metadata[metadata_name].is_(None)) + case "is not empty": + query = query.filter(Document.doc_metadata[metadata_name].isnot(None)) + case "before", "<": + query = query.filter(Document.doc_metadata[metadata_name] < value) + case "after", ">": + query = query.filter(Document.doc_metadata[metadata_name] > value) + case "≤", ">=": + query = query.filter(Document.doc_metadata[metadata_name] <= value) + case "≥", ">=": + query = query.filter(Document.doc_metadata[metadata_name] >= value) + case _: + raise InvalidConditionError(f"Invalid condition: {condition}") + @classmethod def _extract_variable_selector_to_variable_mapping( cls, @@ -343,3 +495,94 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): parameters=completion_params, stop=stop, ) + + def _calculate_rest_token( + self, + node_data: KnowledgeRetrievalNodeData, + query: str, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + prompt_template = self._get_prompt_template(node_data, query, None, 2000) + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={}, + query="", + files=[], + context=context, + memory_config=node_data.memory, + memory=None, + model_config=model_config, + ) + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template or "") + ) or 0 + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): + model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) + input_text = query + memory_str = "" + + prompt_messages: list[LLMNodeChatModelMessage] = [] + if model_mode == ModelMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( + role=PromptMessageRole.SYSTEM, 
text=METADATA_FILTER_SYSTEM_PROMPT + ) + prompt_messages.append(system_prompt_messages) + user_prompt_message_1 = LLMNodeChatModelMessage( + role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1 + ) + prompt_messages.append(user_prompt_message_1) + assistant_prompt_message_1 = LLMNodeChatModelMessage( + role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1 + ) + prompt_messages.append(assistant_prompt_message_1) + user_prompt_message_2 = LLMNodeChatModelMessage( + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 + ) + prompt_messages.append(user_prompt_message_2) + assistant_prompt_message_2 = LLMNodeChatModelMessage( + role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2 + ) + prompt_messages.append(assistant_prompt_message_2) + user_prompt_message_3 = LLMNodeChatModelMessage( + role=PromptMessageRole.USER, + text=METADATA_FILTER_USER_PROMPT_3.format( + input_text=input_text, + metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), + ), + ) + prompt_messages.append(user_prompt_message_3) + return prompt_messages + elif model_mode == ModelMode.COMPLETION: + return LLMNodeCompletionModelPromptTemplate( + text=METADATA_FILTER_COMPLETION_PROMPT.format( + input_text=input_text, + metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), + ) + ) + + else: + raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index b96074dc0d..90da34e22a 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -53,6 +53,8 @@ external_knowledge_info_fields = { "external_knowledge_api_endpoint": fields.String, } +doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String} + dataset_detail_fields = { "id": fields.String, "name": fields.String, @@ -76,6 +78,7 @@ dataset_detail_fields = { "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), + "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)), } dataset_query_detail_fields = { diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index f2250d964a..e052b300d9 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -3,6 +3,13 @@ from flask_restful import fields # type: ignore from fields.dataset_fields import dataset_fields from libs.helper import TimestampField +document_metadata_fields = { + "id": fields.String, + "name": fields.String, + "type": fields.String, + "value": fields.String, +} + document_fields = { "id": fields.String, "position": fields.Integer, @@ -25,6 +32,7 @@ document_fields = { "word_count": fields.Integer, "hit_count": fields.Integer, "doc_form": fields.String, + "doc_metadata_details": fields.List(fields.Nested(document_metadata_fields)), } document_with_segments_fields = { @@ -51,6 +59,7 @@ document_with_segments_fields = { "hit_count": fields.Integer, "completed_segments": fields.Integer, "total_segments": fields.Integer, + "doc_metadata_details": fields.List(fields.Nested(document_metadata_fields)), } dataset_and_document_fields = { diff --git a/api/models/dataset.py b/api/models/dataset.py index 7e712f5da8..a84cb87576 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -198,6 +198,19 @@ class Dataset(db.Model): # type: ignore[name-defined] "external_knowledge_api_endpoint": 
json.loads(external_knowledge_api.settings).get("endpoint", ""), } + @property + def doc_metadata(self): + dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() + + return [ + { + "id": dataset_metadata.id, + "name": dataset_metadata.name, + "type": dataset_metadata.type, + } + for dataset_metadata in dataset_metadatas + ] + @staticmethod def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") @@ -251,6 +264,7 @@ class Document(db.Model): # type: ignore[name-defined] db.Index("document_dataset_id_idx", "dataset_id"), db.Index("document_is_paused_idx", "is_paused"), db.Index("document_tenant_idx", "tenant_id"), + db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), ) # initial fields @@ -307,7 +321,7 @@ class Document(db.Model): # type: ignore[name-defined] archived_at = db.Column(db.DateTime, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = db.Column(db.String(40), nullable=True) - doc_metadata = db.Column(db.JSON, nullable=True) + doc_metadata = db.Column(JSONB, nullable=True) doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) doc_language = db.Column(db.String(255), nullable=True) @@ -410,6 +424,28 @@ class Document(db.Model): # type: ignore[name-defined] def last_update_date(self): return self.updated_at + @property + def doc_metadata_details(self): + if self.doc_metadata: + document_metadatas = ( + db.session.query(DatasetMetadata) + .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) + .filter( + DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id + ) + .all() + ) + metadata_list = [] + for metadata in document_metadatas: + metadata_dict = { + "id": metadata.id, + "name": metadata.name, + "type": metadata.type, + "value": self.doc_metadata.get(metadata.type), + } + metadata_list.append(metadata_dict) + return metadata_list + return None def process_rule_dict(self): if self.dataset_process_rule_id: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2c38937594..d4d0df954a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -585,28 +585,43 @@ class DocumentService: @staticmethod def get_document_by_ids(document_ids: list[str]) -> list[Document]: - documents = db.session.query(Document).filter(Document.id.in_(document_ids), - Document.enabled == True, - Document.indexing_status == "completed", - Document.archived == False, - ).all() + documents = ( + db.session.query(Document) + .filter( + Document.id.in_(document_ids), + Document.enabled == True, + Document.indexing_status == "completed", + Document.archived == False, + ) + .all() + ) return documents @staticmethod def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, - Document.enabled == True, - ).all() + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset_id, + Document.enabled == True, + ) + .all() + ) return documents - + @staticmethod def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, - Document.enabled == True, - Document.indexing_status == "completed", - Document.archived == False, - ).all() + 
documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset_id, + Document.enabled == True, + Document.indexing_status == "completed", + Document.archived == False, + ) + .all() + ) return documents @@ -688,7 +703,7 @@ class DocumentService: if document.tenant_id != current_user.current_tenant_id: raise ValueError("No permission.") - + if dataset.built_in_field_enabled: if document.doc_metadata: document.doc_metadata[BuiltInField.document_name] = name @@ -1097,7 +1112,9 @@ class DocumentService: BuiltInField.document_name: name, BuiltInField.uploader: account.name, BuiltInField.upload_date: datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - BuiltInField.last_update_date: datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + BuiltInField.last_update_date: datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S" + ), BuiltInField.source: data_source_type, } if metadata is not None: diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index f23f578a15..7d0f545f9e 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -130,9 +130,30 @@ class MetadataArgs(BaseModel): type: Literal["string", "number", "time"] name: str -class MetadataValue(BaseModel): + +class MetadataUpdateArgs(BaseModel): name: str value: str + class MetadataValueUpdateArgs(BaseModel): - fields: list[MetadataValue] \ No newline at end of file + fields: list[MetadataUpdateArgs] + + +class MetadataDetail(BaseModel): + id: str + name: str + value: str + + +class DocumentMetadataOperation(BaseModel): + document_id: str + metadata_list: list[MetadataDetail] + + +class MetadataOperationData(BaseModel): + """ + Metadata operation data + """ + + operation_data: list[DocumentMetadataOperation] diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py new file mode 100644 index 0000000000..9877e09fdd --- /dev/null +++ b/api/services/metadata_service.py @@ -0,0 +1,182 @@ +import datetime +from typing import Optional + +from flask_login import current_user # type: ignore + +from core.rag.index_processor.constant.built_in_field import BuiltInField +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding +from services.dataset_service import DocumentService +from services.entities.knowledge_entities.knowledge_entities import ( + MetadataArgs, + MetadataOperationData, +) +from tasks.update_documents_metadata_task import update_documents_metadata_task + + +class MetadataService: + @staticmethod + def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: + metadata = DatasetMetadata( + dataset_id=dataset_id, + type=metadata_args.type, + name=metadata_args.name, + created_by=current_user.id, + ) + db.session.add(metadata) + db.session.commit() + return metadata + + @staticmethod + def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: + lock_key = f"dataset_metadata_lock_{dataset_id}" + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() + if metadata is None: + raise ValueError("Metadata not found.") + old_name = metadata.name + metadata.name = name + metadata.updated_by = current_user.id + 
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + # update related documents + documents = [] + dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() + if dataset_metadata_bindings: + document_ids = [binding.document_id for binding in dataset_metadata_bindings] + documents = DocumentService.get_document_by_ids(document_ids) + for document in documents: + document.doc_metadata[name] = document.doc_metadata.pop(old_name) + db.session.add(document) + db.session.commit() + if document_ids: + update_documents_metadata_task.delay(dataset_id, document_ids, lock_key) + return metadata + + @staticmethod + def delete_metadata(dataset_id: str, metadata_id: str): + lock_key = f"dataset_metadata_lock_{dataset_id}" + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() + if metadata is None: + raise ValueError("Metadata not found.") + db.session.delete(metadata) + + # delete related documents + dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() + if dataset_metadata_bindings: + document_ids = [binding.document_id for binding in dataset_metadata_bindings] + documents = DocumentService.get_document_by_ids(document_ids) + for document in documents: + document.doc_metadata.pop(metadata.name) + db.session.add(document) + db.session.commit() + if document_ids: + update_documents_metadata_task.delay(dataset_id, document_ids, lock_key) + + @staticmethod + def get_built_in_fields(): + return [ + {"name": BuiltInField.document_name, "type": "string"}, + {"name": BuiltInField.uploader, "type": "string"}, + {"name": BuiltInField.upload_date, "type": "date"}, + {"name": BuiltInField.last_update_date, "type": "date"}, + {"name": BuiltInField.source, "type": "string"}, + ] + + @staticmethod + def enable_built_in_field(dataset: Dataset): + if dataset.built_in_fields: + return + lock_key = f"dataset_metadata_lock_{dataset.id}" + MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) + dataset.built_in_fields = True + db.session.add(dataset) + documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) + document_ids = [] + if documents: + for document in documents: + document.doc_metadata[BuiltInField.document_name] = document.name + document.doc_metadata[BuiltInField.uploader] = document.uploader + document.doc_metadata[BuiltInField.upload_date] = document.upload_date.strftime("%Y-%m-%d %H:%M:%S") + document.doc_metadata[BuiltInField.last_update_date] = document.last_update_date.strftime( + "%Y-%m-%d %H:%M:%S" + ) + document.doc_metadata[BuiltInField.source] = document.data_source_type + db.session.add(document) + document_ids.append(document.id) + db.session.commit() + if document_ids: + update_documents_metadata_task.delay(dataset.id, document_ids, lock_key) + + @staticmethod + def disable_built_in_field(dataset: Dataset): + if not dataset.built_in_fields: + return + lock_key = f"dataset_metadata_lock_{dataset.id}" + MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) + dataset.built_in_fields = False + db.session.add(dataset) + documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) + document_ids = [] + if documents: + for document in documents: + document.doc_metadata.pop(BuiltInField.document_name) + document.doc_metadata.pop(BuiltInField.uploader) + document.doc_metadata.pop(BuiltInField.upload_date) + 
document.doc_metadata.pop(BuiltInField.last_update_date)
+                document.doc_metadata.pop(BuiltInField.source)
+                db.session.add(document)
+                document_ids.append(document.id)
+            db.session.commit()
+        if document_ids:
+            update_documents_metadata_task.delay(dataset.id, document_ids, lock_key)
+
+    @staticmethod
+    def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
+        for operation in metadata_args.operation_data:
+            lock_key = f"document_metadata_lock_{operation.document_id}"
+            MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
+            document = DocumentService.get_document(operation.document_id)
+            if document is None:
+                raise ValueError("Document not found.")
+            document.doc_metadata = {}
+            for metadata_value in operation.metadata_list:
+                document.doc_metadata[metadata_value.name] = metadata_value.value
+            if dataset.built_in_fields:
+                document.doc_metadata[BuiltInField.document_name] = document.name
+                document.doc_metadata[BuiltInField.uploader] = document.uploader
+                document.doc_metadata[BuiltInField.upload_date] = document.upload_date.strftime("%Y-%m-%d %H:%M:%S")
+                document.doc_metadata[BuiltInField.last_update_date] = document.last_update_date.strftime(
+                    "%Y-%m-%d %H:%M:%S"
+                )
+                document.doc_metadata[BuiltInField.source] = document.data_source_type
+            # handle metadata bindings
+            DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
+            for metadata_value in operation.metadata_list:
+                dataset_metadata_binding = DatasetMetadataBinding(
+                    tenant_id=current_user.current_tenant_id,
+                    dataset_id=dataset.id,
+                    document_id=operation.document_id,
+                    metadata_id=metadata_value.id,
+                    created_by=current_user.id,
+                )
+                db.session.add(dataset_metadata_binding)
+            db.session.add(document)
+            db.session.commit()
+
+            update_documents_metadata_task.delay(dataset.id, [document.id], lock_key)
+
+    @staticmethod
+    def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
+        if dataset_id:
+            lock_key = f"dataset_metadata_lock_{dataset_id}"
+            if redis_client.get(lock_key):
+                raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
+            redis_client.set(lock_key, 1, ex=3600)
+        if document_id:
+            lock_key = f"document_metadata_lock_{document_id}"
+            if redis_client.get(lock_key):
+                raise ValueError("Another document metadata operation is running, please wait a moment.")
+            redis_client.set(lock_key, 1, ex=3600)
diff --git a/api/tasks/update_documents_metadata_task.py b/api/tasks/update_documents_metadata_task.py
new file mode 100644
index 0000000000..6f1bbf6a8b
--- /dev/null
+++ b/api/tasks/update_documents_metadata_task.py
@@ -0,0 +1,121 @@
+import logging
+import time
+from typing import Optional
+
+import click
+from celery import shared_task  # type: ignore
+
+from core.rag.index_processor.constant.built_in_field import BuiltInField
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import ChildDocument, Document
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.dataset import (
+    Document as DatasetDocument,
+)
+from models.dataset import (
+    DocumentSegment,
+)
+from services.dataset_service import DatasetService
+
+
+@shared_task(queue="dataset")
+def update_documents_metadata_task(
+    dataset_id: str,
+    document_ids: list[str],
+    lock_key: Optional[str] = None,
+):
+    """
+    Update documents metadata.
+ :param dataset_id: dataset id + :param document_ids: document ids + + Usage: update_documents_metadata_task.delay(dataset_id, document_ids) + """ + logging.info(click.style("Start update documents metadata: {}".format(dataset_id), fg="green")) + start_at = time.perf_counter() + + try: + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise ValueError("Dataset not found.") + documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.id.in_(document_ids), + DatasetDocument.enabled == True, + DatasetDocument.indexing_status == "completed", + DatasetDocument.archived == False, + ) + .all() + ) + if not documents: + raise ValueError("Documents not found.") + for dataset_document in documents: + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.enabled == True, + ) + .all() + ) + if not segments: + continue + # delete all documents in vector index + index_node_ids = [segment.index_node_id for segment in segments] + index_processor.clean(dataset, index_node_ids, with_keywords=False, delete_child_chunks=True) + # update documents metadata + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": dataset_document.id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": dataset_document.id, + "dataset_id": dataset_id, + }, + ) + if dataset.built_in_field_enabled: + child_document.metadata[BuiltInField.uploader] = dataset_document.created_by + child_document.metadata[BuiltInField.upload_date] = dataset_document.created_at + child_document.metadata[BuiltInField.last_update_date] = dataset_document.updated_at + child_document.metadata[BuiltInField.source] = dataset_document.data_source_type + child_document.metadata[BuiltInField.original_filename] = dataset_document.name + if dataset_document.doc_metadata: + child_document.metadata.update(dataset_document.doc_metadata) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) # noqa: B909 + # save vector index + index_processor.load(dataset, documents) + end_at = time.perf_counter() + logging.info( + click.style("Updated documents metadata: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Updated documents metadata failed") + finally: + if lock_key: + redis_client.delete(lock_key)