Merge branch 'feat/support-knowledge-metadata' into deploy/dev
This commit is contained in:
commit
7217e31a7b
@ -395,6 +395,7 @@ class DatasetRetrieval:
|
||||
weights: Optional[dict[str, Any]] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
@ -434,6 +435,11 @@ class DatasetRetrieval:
|
||||
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
document_ids_filter = None
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
@ -442,6 +448,7 @@ class DatasetRetrieval:
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
@ -537,7 +544,8 @@ class DatasetRetrieval:
|
||||
db.session.add_all(dataset_queries)
|
||||
db.session.commit()
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list,
|
||||
document_ids_filter: Optional[list[str]] = None):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
@ -590,6 +598,7 @@ class DatasetRetrieval:
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
@ -834,7 +843,7 @@ class DatasetRetrieval:
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
@ -847,7 +856,6 @@ class DatasetRetrieval:
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._get_prompt_template(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
mode=metadata_model_config.mode,
|
||||
metadata_fields=all_metadata_fields,
|
||||
@ -888,7 +896,7 @@ class DatasetRetrieval:
|
||||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(*, condition: str, metadata_name: str, value: str, query):
|
||||
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, query):
|
||||
match condition:
|
||||
case "contains":
|
||||
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}%"))
|
||||
|
@ -214,6 +214,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
reranking_model=reranking_model,
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
)
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
|
Loading…
Reference in New Issue
Block a user