Compare commits
56 Commits
main
...
dev/plugin
SHA1
086aeea181
1d7c4a87d0
9042b368e9
f1bcd26c69
3dcd8b6330
10c088029c
73b1adf862
ae76dbd92c
782df0c383
089207240e
53d30d537f
53512a4650
1fb7dcda24
3c3e0a35f4
202a246e83
08b968eca5
b1ac71db3e
7710d8e83b
cf75fcdffc
6e8601b52c
96cf0ed5af
46a798bea8
9e258c495d
c53786d229
17f23f4798
67f2c766bc
5f995fac32
f88f9d6970
d2cc502c71
b88194d1c6
2b95e54d54
9bff9b5c9e
3dd2c170e7
88f41f164f
cd932519b3
2ff2b08739
a4a45421cc
aafab1b59e
7f49f96c3f
5673f03db5
278adbc10e
5d4e517397
c2671c16a8
10991cbc03
3fcf7e88b0
ffa5af1356
066516b54d
49415e5e7f
a697bbdfa7
d5c31f8728
508005b741
4f0ecdbb6e
ab2e69faef
e46a3343b8
47637da734
525bde28f6
.github/workflows/build-push.yml (vendored, 1 change)
@@ -5,6 +5,7 @@ on:
     branches:
       - "main"
       - "deploy/dev"
+      - "dev/plugin-deploy"
   release:
     types: [published]
@@ -617,7 +617,7 @@ class DocumentDetailApi(DocumentResource):
            raise InvalidMetadataError(f"Invalid metadata value: {metadata}")

        if metadata == "only":
-            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
+            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
        elif metadata == "without":
            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
            document_process_rules = document.dataset_process_rule.to_dict()
@@ -678,7 +678,7 @@ class DocumentDetailApi(DocumentResource):
            "disabled_by": document.disabled_by,
            "archived": document.archived,
            "doc_type": document.doc_type,
-            "doc_metadata": document.doc_metadata,
+            "doc_metadata": document.doc_metadata_details,
            "segment_count": document.segment_count,
            "average_segment_length": document.average_segment_length,
            "hit_count": document.hit_count,
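The two hunks above replace `document.doc_metadata` with `document.doc_metadata_details` in the document-detail response. A minimal sketch of how a client could observe the change, assuming a hypothetical console host, path, and token (none of these are defined in this diff):

```python
import requests

# Hypothetical values; the actual base URL, route, and auth scheme are not part of this diff.
BASE_URL = "https://example.com/console/api"
headers = {"Authorization": "Bearer <console-token>"}

# With metadata="only", the response now carries doc_metadata_details instead of raw doc_metadata.
resp = requests.get(
    f"{BASE_URL}/datasets/<dataset-uuid>/documents/<document-uuid>",
    params={"metadata": "only"},
    headers=headers,
)
print(resp.json().get("doc_metadata"))
```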
api/controllers/console/datasets/metadata.py (new file, 143 lines)
@@ -0,0 +1,143 @@
from flask_login import current_user  # type: ignore  # type: ignore
from flask_restful import Resource, marshal_with, reqparse  # type: ignore
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
    MetadataArgs,
    MetadataOperationData,
)
from services.metadata_service import MetadataService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class DatasetListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def post(self, dataset_id):
        parser = reqparse.RequestParser()
        parser.add_argument("type", type=str, required=True, nullable=True, location="json")
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataArgs(**args)

        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
        return metadata, 201


class DatasetMetadataApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def patch(self, dataset_id, metadata_id):
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()

        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
        return metadata, 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, dataset_id, metadata_id):
        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
        return 200


class DatasetMetadataBuiltInFieldApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        built_in_fields = MetadataService.get_built_in_fields()
        return built_in_fields, 200


class DatasetMetadataBuiltInFieldActionApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id, action):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        if action == "enable":
            MetadataService.enable_built_in_field(dataset)
        elif action == "disable":
            MetadataService.disable_built_in_field(dataset)
        return 200


class DocumentMetadataApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        parser = reqparse.RequestParser()
        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataOperationData(**args)

        MetadataService.update_documents_metadata(dataset, metadata_args)

        return 200


api.add_resource(DatasetListApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/metadata")
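The new controller registers five REST resources for dataset-level metadata management. A minimal sketch of exercising them from a script, assuming a hypothetical console host, token, and IDs (only the paths come from the `api.add_resource` calls above; the "string" field type is illustrative):

```python
import requests

BASE = "https://example.com/console/api"  # hypothetical host/prefix
H = {"Authorization": "Bearer <console-token>"}
dataset_id = "<dataset-uuid>"

# Create a metadata field on a dataset (DatasetListApi.post).
requests.post(f"{BASE}/datasets/{dataset_id}/metadata", json={"type": "string", "name": "author"}, headers=H)

# Rename an existing field (DatasetMetadataApi.patch).
requests.patch(f"{BASE}/datasets/{dataset_id}/metadata/<metadata-uuid>", json={"name": "writer"}, headers=H)

# List the built-in fields (DatasetMetadataBuiltInFieldApi.get).
print(requests.get(f"{BASE}/datasets/metadata/built-in", headers=H).json())
```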
@@ -180,7 +180,7 @@ class ToolProviderID(GenericProviderID):
    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        super().__init__(value, is_hardcoded)
        if self.organization == "langgenius":
-            if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
+            if self.provider_name in ["jina", "siliconflow", "stepfun"]:
                self.plugin_name = f"{self.provider_name}_tool"
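For these hard-coded langgenius providers, the plugin name is derived by suffixing `_tool`. A small sketch of the mapping, assuming the identifier uses an `organization/provider_name/...` style string as elsewhere in the codebase (the exact constructor input is not shown in this hunk):

```python
# Illustrative only: the identifier format is assumed, not shown in this hunk.
provider_id = ToolProviderID("langgenius/jina/jina")
print(provider_id.plugin_name)  # expected to be "jina_tool" per the branch above
```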
@@ -111,12 +111,6 @@ class ProviderManager:

        # Get all provider model records of the workspace
        provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id)
-        for provider_name in list(provider_name_to_provider_model_records_dict.keys()):
-            provider_id = ModelProviderID(provider_name)
-            if str(provider_id) not in provider_name_to_provider_model_records_dict:
-                provider_name_to_provider_model_records_dict[str(provider_id)] = (
-                    provider_name_to_provider_model_records_dict[provider_name]
-                )

        # Get all provider entities
        model_provider_factory = ModelProviderFactory(tenant_id)
@@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
        keyword_table = self._get_dataset_keyword_table()

        k = kwargs.get("top_k", 4)
+        document_ids_filter = kwargs.get("document_ids_filter")
        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)

        documents = []
        for chunk_index in sorted_chunk_indices:
-            segment = (
-                db.session.query(DocumentSegment)
-                .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
-                .first()
+            segment_query = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
            )
+            if document_ids_filter:
+                segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
+            segment = segment_query.first()

            if segment:
                documents.append(
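With this change the keyword index can be narrowed to specific documents before segments are loaded. A minimal sketch of a call site, assuming a `dataset` object and document IDs already exist in the caller's context (illustrative values):

```python
# Illustrative: dataset and the document IDs come from the caller's own context.
keyword = Keyword(dataset=dataset)
documents = keyword.search(
    "billing policy",  # query
    top_k=4,
    document_ids_filter=["<doc-uuid-1>", "<doc-uuid-2>"],  # only segments from these documents are considered
)
```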
@@ -42,6 +42,7 @@ class RetrievalService:
        reranking_model: Optional[dict] = None,
        reranking_mode: str = "reranking_model",
        weights: Optional[dict] = None,
+        document_ids_filter: Optional[list[str]] = None,
    ):
        if not query:
            return []
@@ -65,6 +66,7 @@ class RetrievalService:
                        top_k=top_k,
                        all_documents=all_documents,
                        exceptions=exceptions,
+                        document_ids_filter=document_ids_filter,
                    )
                )
            if RetrievalMethod.is_support_semantic_search(retrieval_method):
@@ -80,6 +82,7 @@ class RetrievalService:
                        all_documents=all_documents,
                        retrieval_method=retrieval_method,
                        exceptions=exceptions,
+                        document_ids_filter=document_ids_filter,
                    )
                )
            if RetrievalMethod.is_support_fulltext_search(retrieval_method):
@@ -131,7 +134,14 @@ class RetrievalService:

    @classmethod
    def keyword_search(
-        cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
+        cls,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        top_k: int,
+        all_documents: list,
+        exceptions: list,
+        document_ids_filter: Optional[list[str]] = None,
    ):
        with flask_app.app_context():
            try:
@@ -140,7 +150,10 @@ class RetrievalService:
                    raise ValueError("dataset not found")

                keyword = Keyword(dataset=dataset)
-                documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
+                documents = keyword.search(
+                    cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
+                )
                all_documents.extend(documents)
            except Exception as e:
                exceptions.append(str(e))
@@ -157,6 +170,7 @@ class RetrievalService:
        all_documents: list,
        retrieval_method: str,
        exceptions: list,
+        document_ids_filter: Optional[list[str]] = None,
    ):
        with flask_app.app_context():
            try:
@@ -171,6 +185,7 @@ class RetrievalService:
                    top_k=top_k,
                    score_threshold=score_threshold,
                    filter={"group_id": [dataset.id]},
+                    document_ids_filter=document_ids_filter,
                )

                if documents:
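All of the threaded search entry points now accept `document_ids_filter` and forward it to the keyword and vector searches. A minimal sketch of a caller, assuming the classmethod whose parameters appear in the first hunk is the service's `retrieve()` entry point and that the other arguments are configured as usual (illustrative values only):

```python
# Illustrative call; document_ids_filter is the new keyword introduced by this diff.
results = RetrievalService.retrieve(
    retrieval_method="semantic_search",
    dataset_id="<dataset-uuid>",
    query="refund policy",
    top_k=4,
    document_ids_filter=["<doc-uuid-1>", "<doc-uuid-2>"],
)
```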
@@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
        self.analyticdb_vector.delete_by_metadata_field(key, value)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        return self.analyticdb_vector.search_by_vector(query_vector)
+        return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return self.analyticdb_vector.search_by_full_text(query, **kwargs)
@@ -194,6 +194,11 @@ class AnalyticdbVectorBySql:

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = "WHERE 1=1"
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            query_vector_str = json.dumps(query_vector)
@@ -202,7 +207,7 @@ class AnalyticdbVectorBySql:
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
-                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
+                f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
@@ -220,12 +225,17 @@ class AnalyticdbVectorBySql:

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
-                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
+                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
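Across the SQL-backed stores the pattern is the same: turn the document-ID list into a quoted IN-list and splice it into the WHERE clause. A standalone sketch of that clause construction (the helper name is illustrative and does not exist in the diff; the sketch adds a leading space before AND so the concatenated SQL stays valid):

```python
def build_document_filter(document_ids_filter: list[str] | None) -> str:
    # Mirrors the inline pattern used above: quote each ID and join into an IN (...) clause.
    if not document_ids_filter:
        return ""
    quoted = ", ".join(f"'{doc_id}'" for doc_id in document_ids_filter)
    return f" AND metadata_->>'document_id' IN ({quoted})"

print(build_document_filter(["a1", "b2"]))  # -> " AND metadata_->>'document_id' IN ('a1', 'b2')"
```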
@@ -123,11 +123,21 @@ class BaiduVector(BaseVector):

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
-        anns = AnnSearch(
-            vector_field=self.field_vector,
-            vector_floats=query_vector,
-            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
-        )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            anns = AnnSearch(
+                vector_field=self.field_vector,
+                vector_floats=query_vector,
+                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+                filter=f"document_id IN ({document_ids})",
+            )
+        else:
+            anns = AnnSearch(
+                vector_field=self.field_vector,
+                vector_floats=query_vector,
+                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+            )
        res = self._db.table(self._collection_name).search(
            anns=anns,
            projections=[self.field_id, self.field_text, self.field_metadata],
@@ -95,7 +95,15 @@ class ChromaVector(BaseVector):

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        collection = self._client.get_or_create_collection(self._collection_name)
-        results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            results: QueryResult = collection.query(
+                query_embeddings=query_vector,
+                n_results=kwargs.get("top_k", 4),
+                where={"document_id": {"$in": document_ids_filter}},
+            )
+        else:
+            results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
        score_threshold = float(kwargs.get("score_threshold") or 0.0)

        # Check if results contain data
@@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector):
        top_k = kwargs.get("top_k", 4)
        num_candidates = math.ceil(top_k * 1.5)
        knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}

        results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

@@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector):

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        query_str = {"match": {Field.CONTENT_KEY.value: query}}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
        results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
        docs = []
        for hit in results["hits"]["hits"]:
@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filters = []
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
|
||||
|
||||
try:
|
||||
params = {}
|
||||
if self._using_ugc:
|
||||
@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector):
|
||||
should = kwargs.get("should")
|
||||
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
filters = kwargs.get("filter")
|
||||
filters = kwargs.get("filter", [])
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
routing = self._routing
|
||||
full_text_query = default_text_search_query(
|
||||
query_text=query,
|
||||
|
@ -218,12 +218,18 @@ class MilvusVector(BaseVector):
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
@ -239,6 +245,11 @@ class MilvusVector(BaseVector):
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
@ -246,6 +257,7 @@ class MilvusVector(BaseVector):
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
|
@@ -131,6 +131,10 @@ class MyScaleVector(BaseVector):
            if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
            else ""
        )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
        sql = f"""
            SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
            {where_str} ORDER BY dist {order.value} LIMIT {top_k}
@@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector):
        return []

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = None
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f"metadata->>'$.document_id' in ({document_ids})"
        ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
        if ef_search != self._hnsw_ef_search:
            self._client.set_ob_hnsw_ef_search(ef_search)
@@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector):
            distance_func=func.l2_distance,
            output_column_names=["text", "metadata"],
            with_dist=True,
+            where_clause=where_clause,
        )
        docs = []
        for text, metadata, distance in cur:
@@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector):
            "size": kwargs.get("top_k", 4),
            "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
        }
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}

        try:
            response = self._client.search(index=self._collection_name.lower(), body=query)
@@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector):

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}

        response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
@ -185,10 +185,15 @@ class OracleVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||
f" ORDER BY distance fetch first {top_k} rows only",
|
||||
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
@ -241,9 +246,15 @@ class OracleVector(BaseVector):
|
||||
if token not in stop_words:
|
||||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"select meta, text, embedding FROM {self.table_name}"
|
||||
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
||||
f"order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)],
|
||||
)
|
||||
docs = []
|
||||
|
@@ -189,6 +189,9 @@ class PGVectoRS(BaseVector):
                .limit(kwargs.get("top_k", 4))
                .order_by("distance")
            )
+            document_ids_filter = kwargs.get("document_ids_filter")
+            if document_ids_filter:
+                stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
            res = session.execute(stmt)
            results = [(row[0], row[1]) for row in res]
@ -155,10 +155,16 @@ class PGVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||
f" {where_clause}"
|
||||
f" ORDER BY distance LIMIT {top_k}",
|
||||
(json.dumps(query_vector),),
|
||||
)
|
||||
@ -176,10 +182,16 @@ class PGVector(BaseVector):
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
{where_clause}
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
|
@ -286,27 +286,26 @@ class QdrantVector(BaseVector):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
@ -331,6 +330,14 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@ -377,6 +384,14 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
scroll_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
@ -223,8 +223,12 @@ class RelytVector(BaseVector):
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = kwargs.get("filter", {})
|
||||
if document_ids_filter:
|
||||
filter["document_id"] = document_ids_filter
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
|
||||
)
|
||||
|
||||
# Organize results.
|
||||
@ -246,9 +250,9 @@ class RelytVector(BaseVector):
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
conditions = [
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
|
||||
f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
|
||||
if len(value) > 1
|
||||
else f"metadata->>{key!r} = {value[0]!r}"
|
||||
else f"metadata->>'{key!r}' = {value[0]!r}"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
filter_condition = f"WHERE {' AND '.join(conditions)}"
|
||||
|
@@ -145,11 +145,16 @@ class TencentVector(BaseVector):
        self._db.collection(self._collection_name).delete(document_ids=ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
-        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
+        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = None
+        if document_ids_filter:
+            filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
        res = self._db.collection(self._collection_name).search(
            vectors=[query_vector],
+            filter=filter,
            params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
            retrieve_vector=False,
            limit=kwargs.get("top_k", 4),
@ -326,6 +326,14 @@ class TidbOnQdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@ -368,6 +376,14 @@ class TidbOnQdrantVector(BaseVector):
|
||||
)
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
scroll_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
@@ -196,6 +196,11 @@ class TiDBVector(BaseVector):

        docs = []
        tidb_dist_func = self._get_distance_func()
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "

        with Session(self._engine) as session:
            select_statement = sql_text(f"""
@@ -206,6 +211,7 @@ class TiDBVector(BaseVector):
                    text,
                    {tidb_dist_func}(vector, :query_vector_str) AS distance
                FROM {self._collection_name}
+                {where_clause}
                ORDER BY distance ASC
                LIMIT :top_k
            ) t
@@ -88,7 +88,20 @@ class UpstashVector(BaseVector):

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
-        result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f"document_id in ({document_ids})"
+        else:
+            filter = ""
+        result = self.index.query(
+            vector=query_vector,
+            top_k=top_k,
+            include_metadata=True,
+            include_data=True,
+            include_vectors=False,
+            filter=filter,
+        )
        docs = []
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        for record in result:
@@ -49,6 +49,10 @@ class BaseVector(ABC):
    def delete(self) -> None:
        raise NotImplementedError

+    @abstractmethod
+    def update_metadata(self, document_id: str, metadata: dict) -> None:
+        raise NotImplementedError
+
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        for text in texts.copy():
            if text.metadata and "doc_id" in text.metadata:
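Every concrete vector store now has to implement `update_metadata`. A minimal sketch of what a subclass override could look like, assuming an in-memory store (purely illustrative; none of the stores in this diff are implemented this way, and the other abstract members are omitted for brevity):

```python
class InMemoryVector(BaseVector):
    def __init__(self):
        self._metadata_by_document: dict[str, dict] = {}

    def update_metadata(self, document_id: str, metadata: dict) -> None:
        # Merge the new metadata into whatever is already stored for this document.
        self._metadata_by_document.setdefault(document_id, {}).update(metadata)
```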
@@ -177,7 +177,11 @@ class VikingDBVector(BaseVector):
            query_vector, limit=kwargs.get("top_k", 4)
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
-        return self._get_search_res(results, score_threshold)
+        docs = self._get_search_res(results, score_threshold)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
+        return docs

    def _get_search_res(self, results, score_threshold) -> list[Document]:
        if len(results) == 0:
@ -168,16 +168,16 @@ class WeaviateVector(BaseVector):
|
||||
# check whether the index already exists
|
||||
schema = self._default_schema(self._collection_name)
|
||||
if self._client.schema.contains(schema):
|
||||
for uuid in ids:
|
||||
try:
|
||||
self._client.data_object.delete(
|
||||
class_name=self._collection_name,
|
||||
uuid=uuid,
|
||||
)
|
||||
except weaviate.UnexpectedStatusCodeException as e:
|
||||
# tolerate not found error
|
||||
if e.status_code != 404:
|
||||
raise e
|
||||
try:
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where={"operator": "ContainsAny", "path": ["id"], "valueTextArray": ids},
|
||||
output="minimal",
|
||||
)
|
||||
except weaviate.UnexpectedStatusCodeException as e:
|
||||
# tolerate not found error
|
||||
if e.status_code != 404:
|
||||
raise e
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
@ -187,8 +187,10 @@ class WeaviateVector(BaseVector):
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
||||
vector = {"vector": query_vector}
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(kwargs.get("top_k", 4))
|
||||
@ -233,8 +235,10 @@ class WeaviateVector(BaseVector):
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ["text"]
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
|
||||
|
api/core/rag/index_processor/constant/built_in_field.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from enum import Enum


class BuiltInField(str, Enum):
    document_name = "document_name"
    uploader = "uploader"
    upload_date = "upload_date"
    last_update_date = "last_update_date"
    source = "source"
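`BuiltInField` is a string-valued enum, so its members compare equal to their plain string names. A small usage sketch:

```python
# Because BuiltInField subclasses str, members can be used wherever plain field names are expected.
field = BuiltInField.document_name
print(field.value)                      # "document_name"
print(field == "document_name")         # True
print([f.value for f in BuiltInField])  # all five built-in field names
```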
@ -203,7 +203,6 @@ class DatasetRetrieval:
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": record.score or 0.0,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
@ -238,6 +237,7 @@ class DatasetRetrieval:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
@ -292,6 +292,11 @@ class DatasetRetrieval:
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
# get top k
|
||||
@ -323,6 +328,7 @@ class DatasetRetrieval:
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
|
@@ -185,7 +185,7 @@ class ToolInvokeMessage(BaseModel):
    """
    plain text, image url or link url
    """
-    message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
+    message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
    meta: dict[str, Any] | None = None

    @field_validator("message", mode="before")
@@ -123,7 +123,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                    "segment_id": segment.id,
                    "retriever_from": self.retriever_from,
                    "score": document_score_list.get(segment.index_node_id, None),
-                    "doc_metadata": document.doc_metadata,
                }

                if self.retriever_from == "dev":
@@ -172,7 +172,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                    "segment_id": segment.id,
                    "retriever_from": self.retriever_from,
                    "score": record.score or 0.0,
-                    "doc_metadata": document.doc_metadat,  # type: ignore
                }

                if self.retriever_from == "dev":
@ -8,12 +8,12 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||
@ -156,38 +156,16 @@ class AgentNode(ToolNode):
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
|
||||
value_param = param.get("value", {})
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
tool_parameters=tool.get("parameters", {}),
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
)
|
||||
|
||||
@ -200,26 +178,11 @@ class AgentNode(ToolNode):
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("descrption", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"provider_type": provider_type.value,
|
||||
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
|
@@ -1,4 +1,3 @@
-from enum import Enum
 from typing import Any, Literal, Union

 from pydantic import BaseModel
@@ -17,8 +16,3 @@ class AgentNodeData(BaseNodeData):
        type: Literal["mixed", "variable", "constant"]

    agent_parameters: dict[str, AgentInput]
-
-
-class ParamsAutoGenerated(Enum):
-    CLOSE = 0
-    OPEN = 1
@@ -1,8 +1,10 @@
+from collections.abc import Sequence
 from typing import Any, Literal, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from core.workflow.nodes.base import BaseNodeData
+from core.workflow.nodes.llm.entities import VisionConfig


 class RerankingModelConfig(BaseModel):
@@ -73,6 +75,48 @@ class SingleRetrievalConfig(BaseModel):
    model: ModelConfig


+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "starts with",
+    "ends with",
+    "is",
+    "is not",
+    "empty",
+    "is not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class Condition(BaseModel):
+    """
+    Conditon detail
+    """
+
+    metadata_name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None = None
+
+
+class MetadataFilteringCondition(BaseModel):
+    """
+    Metadata Filtering Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+
+
 class KnowledgeRetrievalNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
@@ -84,3 +128,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
    retrieval_mode: Literal["single", "multiple"]
    multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
    single_retrieval_config: Optional[SingleRetrievalConfig] = None
+    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+    metadata_model_config: Optional[ModelConfig] = None
+    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
+    vision: VisionConfig = Field(default_factory=VisionConfig)
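These models describe the manual metadata-filtering configuration a knowledge-retrieval node can carry. A minimal construction sketch (values are illustrative; in a workflow they arrive via node configuration rather than direct construction):

```python
# Illustrative values; "is" must be one of the SupportedComparisonOperator literals above.
condition = Condition(metadata_name="author", comparison_operator="is", value="Alice")
filtering = MetadataFilteringCondition(logical_operator="and", conditions=[condition])
print(filtering.model_dump())
```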
@@ -16,3 +16,7 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError):

 class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
    """Raised when the model provider quota is exceeded."""
+
+
+class InvalidModelTypeError(KnowledgeRetrievalNodeError):
+    """Raised when the model is not a Large Language Model."""
@ -1,6 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
@ -9,21 +11,38 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||
METADATA_FILTER_COMPLETION_PROMPT,
|
||||
METADATA_FILTER_SYSTEM_PROMPT,
|
||||
METADATA_FILTER_USER_PROMPT_1,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.workflow.nodes.list_operator.exc import InvalidConditionError
|
||||
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import Dataset, DatasetMetadata, Document
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .exc import (
|
||||
InvalidModelTypeError,
|
||||
KnowledgeRetrievalNodeError,
|
||||
ModelCredentialsNotInitializedError,
|
||||
ModelNotExistError,
|
||||
@ -42,13 +61,14 @@ default_retrieval_model = {
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
class KnowledgeRetrievalNode(LLMNode):
|
||||
_node_data_cls = KnowledgeRetrievalNodeData
|
||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
|
||||
# extract variables
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
|
||||
if not isinstance(variable, StringSegment):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -63,7 +83,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
)
|
||||
# retrieve knowledge
|
||||
try:
|
||||
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
||||
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
|
||||
outputs = {"result": results}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
||||
@ -117,6 +137,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
metadata_filter_document_ids = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
@ -146,6 +169,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
planning_strategy=planning_strategy,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
@ -240,7 +264,6 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
"segment_word_count": segment.word_count,
|
||||
"segment_position": segment.position,
|
||||
"segment_index_node_hash": segment.index_node_hash,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
},
|
||||
"title": document.name,
|
||||
}
|
||||
@ -259,6 +282,134 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
item["metadata"]["position"] = position
|
||||
return retrieval_resource_list
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> dict[str, list[str]]:
|
||||
document_query = db.session.query(Document.id).filter(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
|
||||
if automatic_metadata_filters:
|
||||
for filter in automatic_metadata_filters:
|
||||
self._process_metadata_filter_func(
|
||||
filter.get("condition"), filter.get("metadata_name"), filter.get("value"), document_query
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
for condition in node_data.metadata_filtering_conditions.conditions:
|
||||
metadata_name = condition.metadata_name
|
||||
expected_value = condition.value
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).text
|
||||
self._process_metadata_filter_func(
|
||||
condition.comparison_operator, metadata_name, expected_value, document_query
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
documnents = document_query.all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list)
|
||||
for document in documnents:
|
||||
metadata_filter_document_ids[document.dataset_id].append(document.id)
|
||||
return metadata_filter_document_ids
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> list[dict[str, Any]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
|
||||
# get metadata model config
|
||||
metadata_model_config = node_data.metadata_model_config
|
||||
if metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config)
|
||||
# fetch prompt messages
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
query=query or "",
|
||||
metadata_fields=all_metadata_fields,
|
||||
)
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query=query,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
sys_files=[],
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
try:
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
break
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
automatic_metadata_filters = []
|
||||
if "metadata_map" in result_text_json:
|
||||
metadata_map = result_text_json["metadata_map"]
|
||||
for item in metadata_map:
|
||||
if item.get("metadata_field_name") in all_metadata_fields:
|
||||
automatic_metadata_filters.append(
|
||||
{
|
||||
"metadata_name": item.get("metadata_field_name"),
|
||||
"value": item.get("metadata_field_value"),
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception:
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, query):
match condition:
case "contains":
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}%"))
case "not contains":
query = query.filter(Document.doc_metadata[metadata_name].notlike(f"%{value}%"))
case "start with":
query = query.filter(Document.doc_metadata[metadata_name].like(f"{value}%"))
case "end with":
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}"))
case "is" | "=":
query = query.filter(Document.doc_metadata[metadata_name] == value)
case "is not" | "≠":
query = query.filter(Document.doc_metadata[metadata_name] != value)
case "is empty":
query = query.filter(Document.doc_metadata[metadata_name].is_(None))
case "is not empty":
query = query.filter(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
query = query.filter(Document.doc_metadata[metadata_name] < value)
case "after" | ">":
query = query.filter(Document.doc_metadata[metadata_name] > value)
case "≤" | "<=":
query = query.filter(Document.doc_metadata[metadata_name] <= value)
case "≥" | ">=":
query = query.filter(Document.doc_metadata[metadata_name] >= value)
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
return query
|
||||
|
||||
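For orientation, a minimal usage sketch of the helper above (not part of the diff; the field name "author" and the value "alice" are invented), showing how a single manual "=" condition narrows the document query:

# illustrative only: apply one manual "=" condition and collect the matching ids
document_query = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids))
document_query = self._process_metadata_filter_func("=", "author", "alice", document_query)
matching_ids = [doc.id for doc in document_query.all()]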
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@ -344,3 +495,94 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
|
||||
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
|
||||
prompt_messages: list[LLMNodeChatModelMessage] = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=METADATA_FILTER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_3)
|
||||
return prompt_messages
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return LLMNodeCompletionModelPromptTemplate(
|
||||
text=METADATA_FILTER_COMPLETION_PROMPT.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.")
|
||||
|
@ -459,7 +459,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
||||
"content": context_dict.get("content"),
|
||||
"page": metadata.get("page"),
|
||||
"doc_metadata": metadata.get("doc_metadata"),
|
||||
}
|
||||
|
||||
return source
|
||||
|
@ -53,6 +53,8 @@ external_knowledge_info_fields = {
|
||||
"external_knowledge_api_endpoint": fields.String,
|
||||
}
|
||||
|
||||
doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||
|
||||
dataset_detail_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
@ -76,6 +78,8 @@ dataset_detail_fields = {
|
||||
"doc_form": fields.String,
|
||||
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
|
||||
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
|
||||
"doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
|
||||
"built_in_field_enabled": fields.Boolean,
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
@ -87,3 +91,9 @@ dataset_query_detail_fields = {
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
dataset_metadata_fields = {
|
||||
"id": fields.String,
|
||||
"type": fields.String,
|
||||
"name": fields.String,
|
||||
}
|
||||
|
@ -3,6 +3,13 @@ from flask_restful import fields # type: ignore
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
document_metadata_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"value": fields.String,
|
||||
}
|
||||
|
||||
document_fields = {
|
||||
"id": fields.String,
|
||||
"position": fields.Integer,
|
||||
@ -25,6 +32,7 @@ document_fields = {
|
||||
"word_count": fields.Integer,
|
||||
"hit_count": fields.Integer,
|
||||
"doc_form": fields.String,
|
||||
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
@ -51,6 +59,7 @@ document_with_segments_fields = {
|
||||
"hit_count": fields.Integer,
|
||||
"completed_segments": fields.Integer,
|
||||
"total_segments": fields.Integer,
|
||||
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
|
||||
}
|
||||
|
||||
dataset_and_document_fields = {
|
||||
|
@ -7,7 +7,6 @@ document_fields = {
|
||||
"data_source_type": fields.String,
|
||||
"name": fields.String,
|
||||
"doc_type": fields.String,
|
||||
"doc_metadata": fields.Raw,
|
||||
}
|
||||
|
||||
segment_fields = {
|
||||
|
@ -0,0 +1,90 @@
|
||||
"""add_metadata_function
|
||||
|
||||
Revision ID: d20049ed0af6
|
||||
Revises: 08ec4f75af5e
|
||||
Create Date: 2025-02-27 09:17:48.903213
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd20049ed0af6'
|
||||
down_revision = '08ec4f75af5e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('dataset_metadata_bindings',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('document_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
|
||||
)
|
||||
with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
|
||||
batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False)
|
||||
batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False)
|
||||
batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False)
|
||||
batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False)
|
||||
|
||||
op.create_table('dataset_metadatas',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('type', sa.String(length=255), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
|
||||
)
|
||||
with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
|
||||
batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False)
|
||||
batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.alter_column('doc_metadata',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
existing_nullable=True)
|
||||
batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.drop_index('document_metadata_idx', postgresql_using='gin')
|
||||
batch_op.alter_column('doc_metadata',
|
||||
existing_type=postgresql.JSONB(astext_type=sa.Text()),
|
||||
type_=postgresql.JSON(astext_type=sa.Text()),
|
||||
existing_nullable=True)
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.drop_column('built_in_field_enabled')
|
||||
|
||||
with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
|
||||
batch_op.drop_index('dataset_metadata_tenant_idx')
|
||||
batch_op.drop_index('dataset_metadata_dataset_idx')
|
||||
|
||||
op.drop_table('dataset_metadatas')
|
||||
with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_index('dataset_metadata_binding_tenant_idx')
|
||||
batch_op.drop_index('dataset_metadata_binding_metadata_idx')
|
||||
batch_op.drop_index('dataset_metadata_binding_document_idx')
|
||||
batch_op.drop_index('dataset_metadata_binding_dataset_idx')
|
||||
|
||||
op.drop_table('dataset_metadata_bindings')
|
||||
# ### end Alembic commands ###
|
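The migration converts documents.doc_metadata to JSONB and adds a GIN index, so metadata lookups can be index-assisted. A sketch of a containment query that can use the index (the "author" field is invented, not part of the diff):

# illustrative only: documents whose JSONB metadata contains {"author": "alice"}
db.session.query(Document.id).filter(Document.doc_metadata.contains({"author": "alice"})).all()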
@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_storage import storage
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
@ -60,6 +61,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
|
||||
embedding_model_provider = db.Column(db.String(255), nullable=True)
|
||||
collection_binding_id = db.Column(StringUUID, nullable=True)
|
||||
retrieval_model = db.Column(JSONB, nullable=True)
|
||||
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
@ -197,6 +199,19 @@ class Dataset(db.Model): # type: ignore[name-defined]
|
||||
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
|
||||
}
|
||||
|
||||
@property
|
||||
def doc_metadata(self):
|
||||
dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": dataset_metadata.id,
|
||||
"name": dataset_metadata.name,
|
||||
"type": dataset_metadata.type,
|
||||
}
|
||||
for dataset_metadata in dataset_metadatas
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def gen_collection_name_by_id(dataset_id: str) -> str:
|
||||
normalized_dataset_id = dataset_id.replace("-", "_")
|
||||
@ -250,6 +265,7 @@ class Document(db.Model): # type: ignore[name-defined]
|
||||
db.Index("document_dataset_id_idx", "dataset_id"),
|
||||
db.Index("document_is_paused_idx", "is_paused"),
|
||||
db.Index("document_tenant_idx", "tenant_id"),
|
||||
db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
@ -306,7 +322,7 @@ class Document(db.Model): # type: ignore[name-defined]
|
||||
archived_at = db.Column(db.DateTime, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
doc_type = db.Column(db.String(40), nullable=True)
|
||||
doc_metadata = db.Column(db.JSON, nullable=True)
|
||||
doc_metadata = db.Column(JSONB, nullable=True)
|
||||
doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
|
||||
doc_language = db.Column(db.String(255), nullable=True)
|
||||
|
||||
@ -397,6 +413,78 @@ class Document(db.Model): # type: ignore[name-defined]
|
||||
)
|
||||
|
||||
@property
|
||||
def uploader(self):
|
||||
user = db.session.query(Account).filter(Account.id == self.created_by).first()
|
||||
return user.name if user else None
|
||||
|
||||
@property
|
||||
def upload_date(self):
|
||||
return self.created_at
|
||||
|
||||
@property
|
||||
def last_update_date(self):
|
||||
return self.updated_at
|
||||
|
||||
@property
|
||||
def doc_metadata_details(self):
|
||||
if self.doc_metadata:
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
|
||||
.filter(
|
||||
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
|
||||
)
|
||||
.all()
|
||||
)
|
||||
metadata_list = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict = {
|
||||
"id": metadata.id,
|
||||
"name": metadata.name,
|
||||
"type": metadata.type,
|
||||
"value": self.doc_metadata.get(metadata.type),
|
||||
}
|
||||
metadata_list.append(metadata_dict)
|
||||
# deal built-in fields
|
||||
metadata_list.extend(self.get_built_in_fields())
|
||||
|
||||
return metadata_list
|
||||
return None
|
||||
|
||||
def get_built_in_fields(self):
|
||||
built_in_fields = []
|
||||
built_in_fields.append({
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.document_name,
|
||||
"type": "string",
|
||||
"value": self.name,
|
||||
})
|
||||
built_in_fields.append({
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.uploader,
|
||||
"type": "string",
|
||||
"value": self.uploader,
|
||||
})
|
||||
built_in_fields.append({
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.upload_date,
|
||||
"type": "date",
|
||||
"value": self.created_at,
|
||||
})
|
||||
built_in_fields.append({
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.last_update_date,
|
||||
"type": "date",
|
||||
"value": self.updated_at,
|
||||
})
|
||||
built_in_fields.append({
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.source,
|
||||
"type": "string",
|
||||
"value": self.data_source_info,
|
||||
})
|
||||
return built_in_fields
|
||||
|
||||
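For reference, a sketch of the list shape doc_metadata_details returns for a document with one custom field plus the built-in entries (ids and values are invented, not part of the diff):

# illustrative only
[
    {"id": "9c1f...", "name": "author", "type": "string", "value": "alice"},
    {"id": "built-in", "name": BuiltInField.document_name, "type": "string", "value": "report.pdf"},
    {"id": "built-in", "name": BuiltInField.uploader, "type": "string", "value": "alice"},
    # upload_date, last_update_date and source follow the same pattern
]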
def process_rule_dict(self):
|
||||
if self.dataset_process_rule_id:
|
||||
return self.dataset_process_rule.to_dict()
|
||||
@ -930,3 +1018,41 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
|
||||
document_id = db.Column(StringUUID, nullable=False)
|
||||
notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class DatasetMetadata(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_metadatas"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
|
||||
db.Index("dataset_metadata_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
dataset_id = db.Column(StringUUID, nullable=False)
|
||||
type = db.Column(db.String(255), nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
|
||||
|
||||
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_metadata_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
|
||||
db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
|
||||
db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
|
||||
db.Index("dataset_metadata_binding_document_idx", "document_id"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
dataset_id = db.Column(StringUUID, nullable=False)
|
||||
metadata_id = db.Column(StringUUID, nullable=False)
|
||||
document_id = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
|
2067
api/poetry.lock
generated
File diff suppressed because it is too large
@ -16,6 +16,7 @@ from configs import dify_config
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -599,9 +600,45 @@ class DocumentService:
|
||||
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_ids),
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all()
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@ -684,7 +721,11 @@ class DocumentService:
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
raise ValueError("No permission.")
|
||||
|
||||
document.name = name
|
||||
if dataset.built_in_field_enabled:
|
||||
if document.doc_metadata:
|
||||
document.doc_metadata[BuiltInField.document_name] = name
|
||||
else:
|
||||
document.name = name
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
@ -1084,9 +1125,22 @@ class DocumentService:
|
||||
doc_form=document_form,
|
||||
doc_language=document_language,
|
||||
)
|
||||
doc_metadata = {}
|
||||
if dataset.built_in_field_enabled:
|
||||
doc_metadata = {
|
||||
BuiltInField.document_name: name,
|
||||
BuiltInField.uploader: account.name,
|
||||
BuiltInField.upload_date: datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.last_update_date: datetime.datetime.now(datetime.timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
BuiltInField.source: data_source_type,
|
||||
}
|
||||
if metadata is not None:
|
||||
document.doc_metadata = metadata.doc_metadata
|
||||
doc_metadata.update(metadata.doc_metadata)
|
||||
document.doc_type = metadata.doc_type
|
||||
if doc_metadata:
|
||||
document.doc_metadata = doc_metadata
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
|
@ -124,3 +124,36 @@ class SegmentUpdateArgs(BaseModel):
|
||||
class ChildChunkUpdateArgs(BaseModel):
|
||||
id: Optional[str] = None
|
||||
content: str
|
||||
|
||||
|
||||
class MetadataArgs(BaseModel):
|
||||
type: Literal["string", "number", "time"]
|
||||
name: str
|
||||
|
||||
|
||||
class MetadataUpdateArgs(BaseModel):
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
class MetadataValueUpdateArgs(BaseModel):
|
||||
fields: list[MetadataUpdateArgs]
|
||||
|
||||
|
||||
class MetadataDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
class DocumentMetadataOperation(BaseModel):
|
||||
document_id: str
|
||||
metadata_list: list[MetadataDetail]
|
||||
|
||||
|
||||
class MetadataOperationData(BaseModel):
|
||||
"""
|
||||
Metadata operation data
|
||||
"""
|
||||
|
||||
operation_data: list[DocumentMetadataOperation]
|
||||
|
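These models describe the payload the metadata endpoints pass to MetadataService.update_documents_metadata below; a minimal construction sketch (ids and values invented, not part of the diff):

# illustrative only: assign one metadata value to one document
payload = MetadataOperationData(
    operation_data=[
        DocumentMetadataOperation(
            document_id="<document-id>",
            metadata_list=[MetadataDetail(id="<metadata-id>", name="author", value="alice")],
        )
    ]
)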
182
api/services/metadata_service.py
Normal file
@ -0,0 +1,182 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
|
||||
from services.dataset_service import DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from tasks.update_documents_metadata_task import update_documents_metadata_task
|
||||
|
||||
|
||||
class MetadataService:
|
||||
@staticmethod
|
||||
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
|
||||
metadata = DatasetMetadata(
|
||||
dataset_id=dataset_id,
|
||||
type=metadata_args.type,
|
||||
name=metadata_args.name,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(metadata)
|
||||
db.session.commit()
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
old_name = metadata.name
|
||||
metadata.name = name
|
||||
metadata.updated_by = current_user.id
|
||||
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
|
||||
# update related documents
|
||||
documents = []
|
||||
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
|
||||
if dataset_metadata_bindings:
|
||||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
document.doc_metadata[name] = document.doc_metadata.pop(old_name)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
if document_ids:
|
||||
update_documents_metadata_task.delay(dataset_id, document_ids, lock_key)
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def delete_metadata(dataset_id: str, metadata_id: str):
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
db.session.delete(metadata)
|
||||
|
||||
# delete related documents
|
||||
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
|
||||
if dataset_metadata_bindings:
|
||||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
document.doc_metadata.pop(metadata.name)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
if document_ids:
|
||||
update_documents_metadata_task.delay(dataset_id, document_ids, lock_key)
|
||||
|
||||
@staticmethod
|
||||
def get_built_in_fields():
|
||||
return [
|
||||
{"name": BuiltInField.document_name, "type": "string"},
|
||||
{"name": BuiltInField.uploader, "type": "string"},
|
||||
{"name": BuiltInField.upload_date, "type": "date"},
|
||||
{"name": BuiltInField.last_update_date, "type": "date"},
|
||||
{"name": BuiltInField.source, "type": "string"},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def enable_built_in_field(dataset: Dataset):
|
||||
if dataset.built_in_field_enabled:
|
||||
return
|
||||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
document_ids = []
|
||||
if documents:
|
||||
for document in documents:
|
||||
document.doc_metadata[BuiltInField.document_name] = document.name
|
||||
document.doc_metadata[BuiltInField.uploader] = document.uploader
|
||||
document.doc_metadata[BuiltInField.upload_date] = document.upload_date.strftime("%Y-%m-%d %H:%M:%S")
|
||||
document.doc_metadata[BuiltInField.last_update_date] = document.last_update_date.strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
document.doc_metadata[BuiltInField.source] = document.data_source_type
|
||||
db.session.add(document)
|
||||
document_ids.append(document.id)
|
||||
db.session.commit()
|
||||
if document_ids:
|
||||
update_documents_metadata_task.delay(dataset.id, document_ids, lock_key)
|
||||
|
||||
@staticmethod
|
||||
def disable_built_in_field(dataset: Dataset):
|
||||
if not dataset.built_in_field_enabled:
|
||||
return
|
||||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
document_ids = []
|
||||
if documents:
|
||||
for document in documents:
|
||||
document.doc_metadata.pop(BuiltInField.document_name)
|
||||
document.doc_metadata.pop(BuiltInField.uploader)
|
||||
document.doc_metadata.pop(BuiltInField.upload_date)
|
||||
document.doc_metadata.pop(BuiltInField.last_update_date)
|
||||
document.doc_metadata.pop(BuiltInField.source)
|
||||
db.session.add(document)
|
||||
document_ids.append(document.id)
|
||||
db.session.commit()
|
||||
if document_ids:
|
||||
update_documents_metadata_task.delay(dataset.id, document_ids, lock_key)
|
||||
|
||||
@staticmethod
|
||||
def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
|
||||
for operation in metadata_args.operation_data:
|
||||
lock_key = f"document_metadata_lock_{operation.document_id}"
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
|
||||
document = DocumentService.get_document(operation.document_id)
|
||||
if document is None:
|
||||
raise ValueError("Document not found.")
|
||||
document.doc_metadata = {}
|
||||
for metadata_value in operation.metadata_list:
document.doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
|
||||
document.doc_metadata[BuiltInField.document_name] = document.name
|
||||
document.doc_metadata[BuiltInField.uploader] = document.uploader
|
||||
document.doc_metadata[BuiltInField.upload_date] = document.upload_date.strftime("%Y-%m-%d %H:%M:%S")
|
||||
document.doc_metadata[BuiltInField.last_update_date] = document.last_update_date.strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
document.doc_metadata[BuiltInField.source] = document.data_source_type
|
||||
# rebuild the metadata bindings for this document
|
||||
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
|
||||
for metadata_value in operation.metadata_list:
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=operation.document_id,
|
||||
metadata_id=metadata_value.id,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
update_documents_metadata_task.delay(dataset.id, [document.id], lock_key)
|
||||
|
||||
@staticmethod
|
||||
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
|
||||
if dataset_id:
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
if redis_client.get(lock_key):
|
||||
raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
|
||||
redis_client.set(lock_key, 1, ex=3600)
|
||||
if document_id:
|
||||
lock_key = f"document_metadata_lock_{document_id}"
|
||||
if redis_client.get(lock_key):
|
||||
raise ValueError("Another document metadata operation is running, please wait a moment.")
|
||||
redis_client.set(lock_key, 1, ex=3600)
|
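The lock check above pairs with the Celery task below: callers acquire a per-dataset (or per-document) Redis key before mutating metadata, pass it to the task, and the task deletes it in its finally block. A minimal sketch of that lifecycle (illustrative, not part of the diff):

# illustrative only
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)  # raises if a lock is held, otherwise sets it
lock_key = f"dataset_metadata_lock_{dataset.id}"
update_documents_metadata_task.delay(dataset.id, document_ids, lock_key)  # the task releases the key when done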
121
api/tasks/update_documents_metadata_task.py
Normal file
@ -0,0 +1,121 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import (
|
||||
Document as DatasetDocument,
|
||||
)
|
||||
from models.dataset import (
|
||||
DocumentSegment,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def update_documents_metadata_task(
|
||||
dataset_id: str,
|
||||
document_ids: list[str],
|
||||
lock_key: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Update documents metadata.
|
||||
:param dataset_id: dataset id
|
||||
:param document_ids: document ids
|
||||
|
||||
Usage: update_documents_metadata_task.delay(dataset_id, document_ids)
|
||||
"""
|
||||
logging.info(click.style("Start update documents metadata: {}".format(dataset_id), fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError("Dataset not found.")
|
||||
documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.filter(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.id.in_(document_ids),
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not documents:
|
||||
raise ValueError("Documents not found.")
|
||||
for dataset_document in documents:
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not segments:
|
||||
continue
|
||||
# delete all documents in vector index
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=False, delete_child_chunks=True)
|
||||
# update documents metadata
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": dataset_document.id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.child_chunks
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": dataset_document.id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset.built_in_field_enabled:
|
||||
child_document.metadata[BuiltInField.uploader] = dataset_document.created_by
|
||||
child_document.metadata[BuiltInField.upload_date] = dataset_document.created_at
|
||||
child_document.metadata[BuiltInField.last_update_date] = dataset_document.updated_at
|
||||
child_document.metadata[BuiltInField.source] = dataset_document.data_source_type
|
||||
child_document.metadata[BuiltInField.document_name] = dataset_document.name
|
||||
if dataset_document.doc_metadata:
|
||||
child_document.metadata.update(dataset_document.doc_metadata)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
documents.append(document) # noqa: B909
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style("Updated documents metadata: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Updated documents metadata failed")
|
||||
finally:
|
||||
if lock_key:
|
||||
redis_client.delete(lock_key)
|
@ -29,7 +29,7 @@ server {
|
||||
include proxy.conf;
|
||||
}
|
||||
|
||||
location /e/ {
|
||||
location /e {
|
||||
proxy_pass http://plugin_daemon:5002;
|
||||
proxy_set_header Dify-Hook-Url $scheme://$host$request_uri;
|
||||
include proxy.conf;
|
||||
|
@ -94,7 +94,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
|
||||
},
|
||||
]
|
||||
return navs
|
||||
}, [])
|
||||
}, [t])
|
||||
|
||||
useEffect(() => {
|
||||
if (appDetail) {
|
||||
@ -120,7 +120,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
|
||||
}).finally(() => {
|
||||
setIsLoadingAppDetail(false)
|
||||
})
|
||||
}, [appId, pathname])
|
||||
}, [appId, router, setAppDetail])
|
||||
|
||||
useEffect(() => {
|
||||
if (!appDetailRes || isLoadingCurrentWorkspace || isLoadingAppDetail)
|
||||
@ -148,7 +148,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
|
||||
}
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [appDetailRes, isCurrentWorkspaceEditor, isLoadingAppDetail, isLoadingCurrentWorkspace, systemFeatures.enable_web_sso_switch_component])
|
||||
}, [appDetailRes, appId, getNavigations, isCurrentWorkspaceEditor, isLoadingAppDetail, isLoadingCurrentWorkspace, router, setAppDetail, systemFeatures.enable_web_sso_switch_component])
|
||||
|
||||
useUnmount(() => {
|
||||
setAppDetail()
|
||||
|
@ -3,13 +3,13 @@
|
||||
import { useSelectedLayoutSegment } from 'next/navigation'
|
||||
import Link from 'next/link'
|
||||
import classNames from '@/utils/classnames'
|
||||
import type { RemixiconComponentType } from '@remixicon/react'
|
||||
|
||||
export type NavIcon = React.ComponentType<
|
||||
React.PropsWithoutRef<React.ComponentProps<'svg'>> & {
|
||||
title?: string | undefined
|
||||
titleId?: string | undefined
|
||||
}> | RemixiconComponentType
|
||||
}
|
||||
>
|
||||
|
||||
export type NavLinkProps = {
|
||||
name: string
|
||||
|
@ -94,7 +94,7 @@ const Configuration: FC = () => {
|
||||
})))
|
||||
const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig)
|
||||
|
||||
const latestPublishedAt = useMemo(() => appDetail?.model_config?.updated_at, [appDetail])
|
||||
const latestPublishedAt = useMemo(() => appDetail?.model_config.updated_at, [appDetail])
|
||||
const [formattingChanged, setFormattingChanged] = useState(false)
|
||||
const { setShowAccountSettingModal } = useModalContext()
|
||||
const [hasFetchedDetail, setHasFetchedDetail] = useState(false)
|
||||
|
@ -128,7 +128,7 @@ const Apps = ({
|
||||
icon_background,
|
||||
description,
|
||||
}) => {
|
||||
const { export_data, mode } = await fetchAppDetail(
|
||||
const { export_data } = await fetchAppDetail(
|
||||
currApp?.app.id as string,
|
||||
)
|
||||
try {
|
||||
@ -151,7 +151,7 @@ const Apps = ({
|
||||
if (app.app_id)
|
||||
await handleCheckPluginDependencies(app.app_id)
|
||||
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
|
||||
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id, mode }, push)
|
||||
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id }, push)
|
||||
}
|
||||
catch (e) {
|
||||
Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
|
||||
|
@ -211,9 +211,7 @@ const Paragraph = (paragraph: any) => {
|
||||
return (
|
||||
<>
|
||||
<ImageGallery srcs={[children_node[0].properties.src]} />
|
||||
{
|
||||
Array.isArray(paragraph.children) ? <p>{paragraph.children.slice(1)}</p> : null
|
||||
}
|
||||
<p>{paragraph.children.slice(1)}</p>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ const Toast = ({
|
||||
}`}
|
||||
/>
|
||||
<div className={`flex ${size === 'md' ? 'gap-1' : 'gap-0.5'}`}>
|
||||
<div className={`flex justify-center items-center ${size === 'md' ? 'p-0.5' : 'p-1'}`}>
|
||||
<div className={`flex justify-center items-start ${size === 'md' ? 'p-0.5' : 'p-1'}`}>
|
||||
{type === 'success' && <RiCheckboxCircleFill className={`${size === 'md' ? 'w-5 h-5' : 'w-4 h-4'} text-text-success`} aria-hidden="true" />}
|
||||
{type === 'error' && <RiErrorWarningFill className={`${size === 'md' ? 'w-5 h-5' : 'w-4 h-4'} text-text-destructive`} aria-hidden="true" />}
|
||||
{type === 'warning' && <RiAlertFill className={`${size === 'md' ? 'w-5 h-5' : 'w-4 h-4'} text-text-warning-secondary`} aria-hidden="true" />}
|
||||
|
@ -126,7 +126,7 @@ const Apps = ({
|
||||
icon_background,
|
||||
description,
|
||||
}) => {
|
||||
const { export_data, mode } = await fetchAppDetail(
|
||||
const { export_data } = await fetchAppDetail(
|
||||
currApp?.app.id as string,
|
||||
)
|
||||
try {
|
||||
@ -149,7 +149,7 @@ const Apps = ({
|
||||
if (app.app_id)
|
||||
await handleCheckPluginDependencies(app.app_id)
|
||||
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
|
||||
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id, mode }, push)
|
||||
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id }, push)
|
||||
}
|
||||
catch (e) {
|
||||
Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
|
||||
|
24
web/app/components/offline-notice.tsx
Normal file
@ -0,0 +1,24 @@
|
||||
'use client'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import type { PropsWithChildren } from 'react'
|
||||
import { useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
export default function OfflineNotice({ children }: PropsWithChildren) {
|
||||
const { t } = useTranslation()
|
||||
const [showOfflineNotice, { setFalse }] = useBoolean(true)
|
||||
|
||||
useEffect(() => {
|
||||
const timer = setTimeout(setFalse, 60000)
|
||||
return () => clearTimeout(timer)
|
||||
}, [setFalse])
|
||||
return <>
|
||||
{showOfflineNotice && <div className='px-4 py-2 flex items-center justify-start gap-x-2 bg-[#FFFAEB] border-b-[0.5px] border-b-[#FEF0C7]'>
|
||||
<div className='rounded-[12px] flex items-center justify-center px-2 py-0.5 h-[22px] bg-[#f79009] text-white text-[11px] not-italic font-medium leading-[18px]'>{t('common.offlineNoticeTitle')}</div>
|
||||
<div className='grow font-medium leading-[18px] text-[12px] not-italic text-[#344054]'>{t('common.offlineNotice')}</div>
|
||||
<RiCloseLine className='size-4 text-[#667085] cursor-pointer' onClick={setFalse} />
|
||||
</div>}
|
||||
{children}
|
||||
</>
|
||||
}
|
@ -7,6 +7,7 @@ import { TanstackQueryIniter } from '@/context/query-client'
|
||||
import { ThemeProvider } from 'next-themes'
|
||||
import './styles/globals.css'
|
||||
import './styles/markdown.scss'
|
||||
import OfflineNotice from './components/offline-notice'
|
||||
|
||||
export const metadata = {
|
||||
title: 'Dify',
|
||||
@ -61,7 +62,9 @@ const LocaleLayout = ({
|
||||
disableTransitionOnChange
|
||||
>
|
||||
<I18nServer>
|
||||
{children}
|
||||
<OfflineNotice>
|
||||
{children}
|
||||
</OfflineNotice>
|
||||
</I18nServer>
|
||||
</ThemeProvider>
|
||||
</TanstackQueryIniter>
|
||||
|
@ -120,7 +120,7 @@ export const ProviderContextProvider = ({
|
||||
if (localStorage.getItem('anthropic_quota_notice') === 'true')
|
||||
return
|
||||
|
||||
if (dayjs().isAfter(dayjs('2025-03-17')))
|
||||
if (dayjs().isAfter(dayjs('2025-03-11')))
|
||||
return
|
||||
|
||||
if (providersData?.data && providersData.data.length > 0) {
|
||||
|
@ -1,4 +1,6 @@
|
||||
const translation = {
|
||||
offlineNoticeTitle: 'Important Notice',
|
||||
offlineNotice: 'Dify v1.0.0 is now officially released. Effective March 5, 2025, the current environment will no longer be accessible, and all data will be permanently deleted. Please ensure that you back up any necessary data prior to this date to avoid any loss.',
|
||||
api: {
|
||||
success: 'Success',
|
||||
actionSuccess: 'Action succeeded',
|
||||
|
@ -1,4 +1,6 @@
|
||||
const translation = {
|
||||
offlineNoticeTitle: '重要通知',
|
||||
offlineNotice: 'Dify v1.0.0 现已正式发布。自 2025年 3 月 5 日起,当前环境将不可访问,所有数据将被永久删除。请务必在此日期之前备份所有必要数据,以避免任何数据丢失。',
|
||||
api: {
|
||||
success: '成功',
|
||||
actionSuccess: '操作成功',
|
||||
|
@ -1,16 +0,0 @@
|
||||
import { buildProviderQuery } from './_tools_util'
|
||||
|
||||
describe('makeProviderQuery', () => {
|
||||
test('collectionName without special chars', () => {
|
||||
expect(buildProviderQuery('ABC')).toBe('provider=ABC')
|
||||
})
|
||||
test('should escape &', () => {
|
||||
expect(buildProviderQuery('ABC&DEF')).toBe('provider=ABC%26DEF')
|
||||
})
|
||||
test('should escape /', () => {
|
||||
expect(buildProviderQuery('ABC/DEF')).toBe('provider=ABC%2FDEF')
|
||||
})
|
||||
test('should escape ?', () => {
|
||||
expect(buildProviderQuery('ABC?DEF')).toBe('provider=ABC%3FDEF')
|
||||
})
|
||||
})
|
@ -1,5 +0,0 @@
|
||||
export const buildProviderQuery = (collectionName: string): string => {
|
||||
const query = new URLSearchParams()
|
||||
query.set('provider', collectionName)
|
||||
return query.toString()
|
||||
}
|
@ -10,7 +10,6 @@ import type {
|
||||
} from '@/app/components/tools/types'
|
||||
import type { ToolWithProvider } from '@/app/components/workflow/types'
|
||||
import type { Label } from '@/app/components/tools/labels/constant'
|
||||
import { buildProviderQuery } from './_tools_util'
|
||||
|
||||
export const fetchCollectionList = () => {
|
||||
return get<Collection[]>('/workspaces/current/tool-providers')
|
||||
@ -25,13 +24,11 @@ export const fetchBuiltInToolList = (collectionName: string) => {
|
||||
}
|
||||
|
||||
export const fetchCustomToolList = (collectionName: string) => {
|
||||
const query = buildProviderQuery(collectionName)
|
||||
return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?${query}`)
|
||||
return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`)
|
||||
}
|
||||
|
||||
export const fetchModelToolList = (collectionName: string) => {
|
||||
const query = buildProviderQuery(collectionName)
|
||||
return get<Tool[]>(`/workspaces/current/tool-provider/model/tools?${query}`)
|
||||
return get<Tool[]>(`/workspaces/current/tool-provider/model/tools?provider=${collectionName}`)
|
||||
}
|
||||
|
||||
export const fetchWorkflowToolList = (appID: string) => {
|
||||
@ -68,8 +65,7 @@ export const parseParamsSchema = (schema: string) => {
|
||||
}
|
||||
|
||||
export const fetchCustomCollection = (collectionName: string) => {
|
||||
const query = buildProviderQuery(collectionName)
|
||||
return get<CustomCollectionBackend>(`/workspaces/current/tool-provider/api/get?${query}`)
|
||||
return get<CustomCollectionBackend>(`/workspaces/current/tool-provider/api/get?provider=${collectionName}`)
|
||||
}
|
||||
|
||||
export const createCustomCollection = (collection: CustomCollectionBackend) => {
|
||||
|
@ -76,7 +76,7 @@ export const correctToolProvider = (provider: string) => {
|
||||
if (provider.includes('/'))
|
||||
return provider
|
||||
|
||||
if (['stepfun', 'jina', 'siliconflow', 'gitee_ai'].includes(provider))
|
||||
if (['stepfun', 'jina', 'siliconflow'].includes(provider))
|
||||
return `langgenius/${provider}_tool/${provider}`
|
||||
|
||||
return `langgenius/${provider}/${provider}`
|
||||
@ -84,6 +84,6 @@ export const correctToolProvider = (provider: string) => {
|
||||
|
||||
export const canFindTool = (providerId: string, oldToolId?: string) => {
|
||||
return providerId === oldToolId
|
||||
|| providerId === `langgenius/${oldToolId}/${oldToolId}`
|
||||
|| providerId === `langgenius/${oldToolId}_tool/${oldToolId}`
|
||||
|| providerId === `langgenius/${oldToolId}/${oldToolId}`
|
||||
|| providerId === `langgenius/${oldToolId}_tool/${oldToolId}`
|
||||
}
|
||||
|