Merge branch 'feat/support-knowledge-metadata' into deploy/dev

This commit is contained in:
jyong 2025-03-10 15:39:03 +08:00
commit 7217e31a7b
2 changed files with 13 additions and 4 deletions

View File

@@ -395,6 +395,7 @@ class DatasetRetrieval:
weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
):
if not available_datasets:
return []
@@ -434,6 +435,11 @@ class DatasetRetrieval:
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
@@ -442,6 +448,7 @@ class DatasetRetrieval:
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
},
)
threads.append(retrieval_thread)
@@ -537,7 +544,8 @@ class DatasetRetrieval:
db.session.add_all(dataset_queries)
db.session.commit()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list,
document_ids_filter: Optional[list[str]] = None):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
@@ -590,6 +598,7 @@ class DatasetRetrieval:
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)
all_documents.extend(documents)
@@ -834,7 +843,7 @@ class DatasetRetrieval:
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> list[dict[str, Any]]:
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
@@ -847,7 +856,6 @@ class DatasetRetrieval:
# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
model_instance=model_instance,
model_config=model_config,
mode=metadata_model_config.mode,
metadata_fields=all_metadata_fields,
@@ -888,7 +896,7 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
def _process_metadata_filter_func(*, condition: str, metadata_name: str, value: str, query):
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, query):
match condition:
case "contains":
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}%"))

View File

@@ -214,6 +214,7 @@ class KnowledgeRetrievalNode(LLMNode):
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]