Compare commits

...

56 Commits

Author SHA1 Message Date
Jyong
086aeea181
Update build-push.yml 2025-03-03 14:12:54 +08:00
jyong
1d7c4a87d0 Merge branch 'feat/support-knowledge-metadata' into dev/plugin-deploy 2025-03-03 13:36:13 +08:00
jyong
9042b368e9 add metadata migration 2025-03-03 13:35:51 +08:00
NFish
f1bcd26c69
Update build-push.yml
remove api image build task
2025-02-28 22:23:30 +08:00
Jyong
3dcd8b6330
Update build-push.yml 2025-02-28 18:32:41 +08:00
NFish
10c088029c Merge branch 'feat/beta-offline-notice' into dev/plugin-deploy 2025-02-28 18:11:39 +08:00
NFish
73b1adf862 fix: update notice date 2025-02-28 18:10:48 +08:00
NFish
ae76dbd92c fix: update offline notice style 2025-02-28 18:10:13 +08:00
NFish
782df0c383 fix: update offline notice style 2025-02-28 18:08:26 +08:00
NFish
089207240e fix: add offline notice 2025-02-28 18:08:09 +08:00
NFish
53d30d537f Revert "Merge branch 'fix/adjust-price-frontend' into dev/plugin-deploy"
This reverts commit 7710d8e83b, reversing
changes made to 96cf0ed5af.
2025-02-28 18:02:51 +08:00
NFish
53512a4650 Revert "Merge branch 'feat/compliance-report-download' into dev/plugin-deploy"
This reverts commit 202a246e83, reversing
changes made to 7710d8e83b.
2025-02-28 18:01:35 +08:00
jyong
1fb7dcda24 Merge remote-tracking branch 'origin/dev/plugin-deploy' into dev/plugin-deploy 2025-02-28 16:57:49 +08:00
jyong
3c3e0a35f4 add metadata migration 2025-02-28 16:34:11 +08:00
NFish
202a246e83 Merge branch 'feat/compliance-report-download' into dev/plugin-deploy 2025-02-28 12:26:27 +08:00
NFish
08b968eca5 fix: workspace selector style 2025-02-28 12:25:50 +08:00
NFish
b1ac71db3e Merge branch 'main' into feat/compliance-report-download 2025-02-28 11:45:02 +08:00
NFish
7710d8e83b Merge branch 'fix/adjust-price-frontend' into dev/plugin-deploy 2025-02-27 18:14:08 +08:00
NFish
cf75fcdffc fix: merge main 2025-02-27 18:04:24 +08:00
NFish
6e8601b52c Merge branch 'main' into fix/adjust-price-frontend 2025-02-27 17:57:55 +08:00
jyong
96cf0ed5af add metadata migration 2025-02-27 17:21:54 +08:00
jyong
46a798bea8 dataset metadata fix 2025-02-27 16:23:02 +08:00
jyong
9e258c495d dataset metadata fix 2025-02-27 15:30:37 +08:00
jyong
c53786d229 dataset metadata update 2025-02-26 19:59:57 +08:00
jyong
17f23f4798 Merge branch 'main' into feat/support-knowledge-metadata
# Conflicts:
#	api/core/rag/datasource/retrieval_service.py
#	api/core/workflow/nodes/code/code_node.py
#	api/services/dataset_service.py
2025-02-26 19:59:14 +08:00
jyong
67f2c766bc dataset metadata update 2025-02-26 19:56:19 +08:00
jyong
5f995fac32 metadata update 2025-02-20 17:13:44 +08:00
jyong
f88f9d6970 metadata 2025-02-19 15:50:28 +08:00
jyong
d2cc502c71 knowledge metadata 2025-02-17 18:17:26 +08:00
NFish
b88194d1c6 fix: regenerate icons; replace iso icon 2025-02-10 16:38:56 +08:00
NFish
2b95e54d54 fix: add compliance file name ellipsis support 2025-02-10 16:34:59 +08:00
NFish
9bff9b5c9e fix: keep menus under open state when compliance is downloading 2025-02-07 14:16:51 +08:00
NFish
3dd2c170e7 fix: only saas version can download compliance 2025-02-07 12:24:52 +08:00
NFish
88f41f164f feat: support user download compliance files 2025-02-06 16:47:01 +08:00
NFish
cd932519b3 fix: add icon to user profile dropdown menu item 2025-02-05 15:49:50 +08:00
NFish
2ff2b08739 Merge branch 'main' into feat/compliance-report-download 2025-02-05 11:23:03 +08:00
NFish
a4a45421cc fix: update sandbox log history value in jp 2025-01-23 17:16:54 +08:00
NFish
aafab1b59e fix: update sandbox log histroy value 2025-01-23 17:09:44 +08:00
NFish
7f49f96c3f fix: update team members value 2025-01-23 17:02:49 +08:00
NFish
5673f03db5 fix: update documentsRequestQuota value 2025-01-23 16:30:39 +08:00
NFish
278adbc10e fix: update jp translate error 2025-01-23 14:49:04 +08:00
NFish
5d4e517397 fix: update billing button disabled style 2025-01-23 14:44:26 +08:00
NFish
c2671c16a8 fix: update bill page background opacity 2025-01-23 11:41:10 +08:00
NFish
10991cbc03 fix: update bill page background image 2025-01-23 11:06:01 +08:00
NFish
3fcf7e88b0 fix: UI adjust 2025-01-22 20:01:00 +08:00
NFish
ffa5af1356 fix: supports number format 2025-01-22 19:41:18 +08:00
NFish
066516b54d fix: update document limit tooltip content 2025-01-14 18:57:03 +08:00
NFish
49415e5e7f fix: update Knowledge Request Ratelimit tooltip text 2025-01-14 16:38:30 +08:00
NFish
a697bbdfa7 fix: update i18n 2025-01-09 10:25:04 +08:00
NFish
d5c31f8728 fix: update billing i18n in setting modal 2025-01-08 17:58:51 +08:00
NFish
508005b741 fix: replace contact sales url address 2025-01-08 14:32:16 +08:00
NFish
4f0ecdbb6e fix: use repeat-linear-gradient for GridMask to improve darkmode support 2025-01-07 14:23:19 +08:00
NFish
ab2e69faef fix: plan item can not show all content if language is jp 2025-01-07 12:23:17 +08:00
NFish
e46a3343b8 fix: new upgrade page 2025-01-07 11:42:41 +08:00
NFish
47637da734 wip: adjust self hosted page style 2025-01-06 10:47:38 +08:00
NFish
525bde28f6 fix: adjust cloud service 2025-01-03 16:18:24 +08:00
67 changed files with 2089 additions and 1562 deletions

View File

@ -5,6 +5,7 @@ on:
branches:
- "main"
- "deploy/dev"
- "dev/plugin-deploy"
release:
types: [published]

View File

@ -617,7 +617,7 @@ class DocumentDetailApi(DocumentResource):
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
@ -678,7 +678,7 @@ class DocumentDetailApi(DocumentResource):
"disabled_by": document.disabled_by,
"archived": document.archived,
"doc_type": document.doc_type,
"doc_metadata": document.doc_metadata,
"doc_metadata": document.doc_metadata_details,
"segment_count": document.segment_count,
"average_segment_length": document.average_segment_length,
"hit_count": document.hit_count,

View File

@ -0,0 +1,143 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 and 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def post(self, dataset_id):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
metadata_args = MetadataArgs(**args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return metadata, 201
class DatasetMetadataApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, dataset_id, metadata_id):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
return metadata, 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, dataset_id, metadata_id):
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return 200
class DatasetMetadataBuiltInFieldApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
built_in_fields = MetadataService.get_built_in_fields()
return built_in_fields, 200
class DatasetMetadataBuiltInFieldActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id, action):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
return 200
class DocumentMetadataApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
args = parser.parse_args()
metadata_args = MetadataOperationData(**args)
MetadataService.update_documents_metadata(dataset, metadata_args)
return 200
api.add_resource(DatasetListApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/metadata")
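These five routes make up the console-side metadata API. As a quick orientation, here is a minimal sketch of exercising them with requests; the base URL, bearer token, and UUIDs are placeholders, and the console-API prefix is an assumption rather than something this diff specifies.

import requests

BASE = "https://dify.example/console/api"  # assumed console API prefix, not from this diff
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credentials
DATASET_ID = "00000000-0000-0000-0000-000000000000"  # placeholder UUID

# Create a metadata field on a dataset, then rename it.
created = requests.post(
    f"{BASE}/datasets/{DATASET_ID}/metadata",
    json={"type": "string", "name": "department"},
    headers=HEADERS,
).json()
requests.patch(
    f"{BASE}/datasets/{DATASET_ID}/metadata/{created['id']}",
    json={"name": "team"},
    headers=HEADERS,
)

# List the built-in fields, then enable them for the dataset.
requests.get(f"{BASE}/datasets/metadata/built-in", headers=HEADERS)
requests.post(f"{BASE}/datasets/{DATASET_ID}/metadata/built-in/enable", headers=HEADERS)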

View File

@ -180,7 +180,7 @@ class ToolProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
super().__init__(value, is_hardcoded)
if self.organization == "langgenius":
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
if self.provider_name in ["jina", "siliconflow", "stepfun"]:
self.plugin_name = f"{self.provider_name}_tool"

View File

@ -111,12 +111,6 @@ class ProviderManager:
# Get all provider model records of the workspace
provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id)
for provider_name in list(provider_name_to_provider_model_records_dict.keys()):
provider_id = ModelProviderID(provider_name)
if str(provider_id) not in provider_name_to_provider_model_records_dict:
provider_name_to_provider_model_records_dict[str(provider_id)] = (
provider_name_to_provider_model_records_dict[provider_name]
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(tenant_id)

View File

@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
keyword_table = self._get_dataset_keyword_table()
k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
.first()
segment_query = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment:
documents.append(

View File

@ -42,6 +42,7 @@ class RetrievalService:
reranking_model: Optional[dict] = None,
reranking_mode: str = "reranking_model",
weights: Optional[dict] = None,
document_ids_filter: Optional[list[str]] = None,
):
if not query:
return []
@ -65,6 +66,7 @@ class RetrievalService:
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
@ -80,6 +82,7 @@ class RetrievalService:
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
@ -131,7 +134,14 @@ class RetrievalService:
@classmethod
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
exceptions: list,
document_ids_filter: Optional[list[str]] = None,
):
with flask_app.app_context():
try:
@ -140,7 +150,10 @@ class RetrievalService:
raise ValueError("dataset not found")
keyword = Keyword(dataset=dataset)
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
documents = keyword.search(
cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
)
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@ -157,6 +170,7 @@ class RetrievalService:
all_documents: list,
retrieval_method: str,
exceptions: list,
document_ids_filter: Optional[list[str]] = None,
):
with flask_app.app_context():
try:
@ -171,6 +185,7 @@ class RetrievalService:
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
if documents:
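Threading document_ids_filter through retrieve lets callers pre-scope retrieval to documents that passed a metadata filter. A minimal call sketch, assuming the leading parameters (retrieval_method, dataset_id, query, top_k) that this excerpt elides; the IDs are placeholders.

from core.rag.datasource.retrieval_service import RetrievalService

docs = RetrievalService.retrieve(
    retrieval_method="semantic_search",
    dataset_id="<dataset-uuid>",  # placeholder
    query="quarterly revenue",
    top_k=4,
    document_ids_filter=["<doc-uuid-1>", "<doc-uuid-2>"],  # restrict to these documents
)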

View File

@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_vector(query_vector)
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_full_text(query, **kwargs)

View File

@ -194,6 +194,11 @@ class AnalyticdbVectorBySql:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = "WHERE 1=1"
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
with self._get_cursor() as cur:
query_vector_str = json.dumps(query_vector)
@ -202,7 +207,7 @@ class AnalyticdbVectorBySql:
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
f"t.page_content as page_content, t.metadata_ AS metadata_ "
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
(query_vector_str,),
)
documents = []
@ -220,12 +225,17 @@ class AnalyticdbVectorBySql:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
with self._get_cursor() as cur:
cur.execute(
f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
ORDER BY score DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),

View File

@ -123,11 +123,21 @@ class BaiduVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
filter=f"document_id IN ({document_ids})",
)
else:
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
res = self._db.table(self._collection_name).search(
anns=anns,
projections=[self.field_id, self.field_text, self.field_metadata],

View File

@ -95,7 +95,15 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
results: QueryResult = collection.query(
query_embeddings=query_vector,
n_results=kwargs.get("top_k", 4),
where={"document_id": {"$in": document_ids_filter}},
)
else:
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Check if results contain data

View File

@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector):
top_k = kwargs.get("top_k", 4)
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:

View File

@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
document_ids_filter = kwargs.get("document_ids_filter")
filters = []
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
try:
params = {}
if self._using_ugc:
@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector):
should = kwargs.get("should")
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
filters = kwargs.get("filter", [])
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,

View File

@ -218,12 +218,18 @@ class MilvusVector(BaseVector):
"""
Search for documents by vector similarity.
"""
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] in ({document_ids})'
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
filter=filter,
)
return self._process_search_results(
@ -239,6 +245,11 @@ class MilvusVector(BaseVector):
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] in ({document_ids})'
results = self._client.search(
collection_name=self._collection_name,
@ -246,6 +257,7 @@ class MilvusVector(BaseVector):
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
filter=filter,
)
return self._process_search_results(

View File

@ -131,6 +131,10 @@ class MyScaleVector(BaseVector):
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}

View File

@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector):
return []
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = None
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector):
distance_func=func.l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=where_clause,
)
docs = []
for text, metadata, distance in cur:

View File

@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector):
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
try:
response = self._client.search(index=self._collection_name.lower(), body=query)
@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)

View File

@ -185,10 +185,15 @@ class OracleVector(BaseVector):
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
f" ORDER BY distance fetch first {top_k} rows only",
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
[numpy.array(query_vector)],
)
docs = []
@ -241,9 +246,15 @@ class OracleVector(BaseVector):
if token not in stop_words:
entities.append(token)
with self._get_cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
cur.execute(
f"select meta, text, embedding FROM {self.table_name}"
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
f"order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)],
)
docs = []

View File

@ -189,6 +189,9 @@ class PGVectoRS(BaseVector):
.limit(kwargs.get("top_k", 4))
.order_by("distance")
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
res = session.execute(stmt)
results = [(row[0], row[1]) for row in res]

View File

@ -155,10 +155,16 @@ class PGVector(BaseVector):
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" {where_clause}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
@ -176,10 +182,16 @@ class PGVector(BaseVector):
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
{where_clause}
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query

View File

@ -286,27 +286,26 @@ class QdrantVector(BaseVector):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=ids),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
@ -331,6 +330,14 @@ class QdrantVector(BaseVector):
),
],
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
@ -377,6 +384,14 @@ class QdrantVector(BaseVector):
),
]
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
scroll_filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
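The delete rewrite above collapses N round trips (one Filter per id) into a single call with MatchAny. The same pattern in isolation; the endpoint, collection name, and ids are placeholders.

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint
client.delete(
    collection_name="example_collection",  # placeholder
    points_selector=models.FilterSelector(
        filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.doc_id",
                    match=models.MatchAny(any=["id-1", "id-2", "id-3"]),
                )
            ]
        )
    ),
)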

View File

@ -223,8 +223,12 @@ class RelytVector(BaseVector):
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = kwargs.get("filter", {})
if document_ids_filter:
filter["document_id"] = document_ids_filter
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
)
# Organize results.
@ -246,9 +250,9 @@ class RelytVector(BaseVector):
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
else f"metadata->>'{key!r}' = {value[0]!r}"
for key, value in filter.items()
]
filter_condition = f"WHERE {' AND '.join(conditions)}"

View File

@ -145,11 +145,16 @@ class TencentVector(BaseVector):
self._db.collection(self._collection_name).delete(document_ids=ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
res = self._db.collection(self._collection_name).search(
vectors=[query_vector],
filter=filter,
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),

View File

@ -326,6 +326,14 @@ class TidbOnQdrantVector(BaseVector):
),
],
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
@ -368,6 +376,14 @@ class TidbOnQdrantVector(BaseVector):
)
]
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
scroll_filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,

View File

@ -196,6 +196,11 @@ class TiDBVector(BaseVector):
docs = []
tidb_dist_func = self._get_distance_func()
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
with Session(self._engine) as session:
select_statement = sql_text(f"""
@ -206,6 +211,7 @@ class TiDBVector(BaseVector):
text,
{tidb_dist_func}(vector, :query_vector_str) AS distance
FROM {self._collection_name}
{where_clause}
ORDER BY distance ASC
LIMIT :top_k
) t

View File

@ -88,7 +88,20 @@ class UpstashVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f"document_id in ({document_ids})"
else:
filter = ""
result = self.index.query(
vector=query_vector,
top_k=top_k,
include_metadata=True,
include_data=True,
include_vectors=False,
filter=filter,
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in result:

View File

@ -49,6 +49,10 @@ class BaseVector(ABC):
def delete(self) -> None:
raise NotImplementedError
@abstractmethod
def update_metadata(self, document_id: str, metadata: dict) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata and "doc_id" in text.metadata:

View File

@ -177,7 +177,11 @@ class VikingDBVector(BaseVector):
query_vector, limit=kwargs.get("top_k", 4)
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)
docs = self._get_search_res(results, score_threshold)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
return docs
def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:
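Unlike the stores that push the filter into the query, VikingDB filters client-side after the search returns. The same post-filtering step reduced to a runnable snippet:

from dataclasses import dataclass, field

@dataclass
class Doc:
    page_content: str
    metadata: dict = field(default_factory=dict)

docs = [Doc("alpha", {"document_id": "d1"}), Doc("beta", {"document_id": "d2"})]
document_ids_filter = ["d2"]

# Same step as above: drop hits whose document_id is not in the allow-list.
docs = [d for d in docs if d.metadata.get("document_id") in document_ids_filter]
assert [d.page_content for d in docs] == ["beta"]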

View File

@ -168,16 +168,16 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
for uuid in ids:
try:
self._client.data_object.delete(
class_name=self._collection_name,
uuid=uuid,
)
except weaviate.UnexpectedStatusCodeException as e:
# tolerate not found error
if e.status_code != 404:
raise e
try:
self._client.batch.delete_objects(
class_name=self._collection_name,
where={"operator": "ContainsAny", "path": ["id"], "valueTextArray": ids},
output="minimal",
)
except weaviate.UnexpectedStatusCodeException as e:
# tolerate not found error
if e.status_code != 404:
raise e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
@ -187,8 +187,10 @@ class WeaviateVector(BaseVector):
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
query_obj = query_obj.with_where(where_filter)
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
@ -233,8 +235,10 @@ class WeaviateVector(BaseVector):
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
query_obj = query_obj.with_where(where_filter)
query_obj = query_obj.with_additional(["vector"])
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()

View File

@ -0,0 +1,9 @@
from enum import Enum
class BuiltInField(str, Enum):
document_name = "document_name"
uploader = "uploader"
upload_date = "upload_date"
last_update_date = "last_update_date"
source = "source"

View File

@ -203,7 +203,6 @@ class DatasetRetrieval:
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": record.score or 0.0,
"doc_metadata": document.doc_metadata,
}
if invoke_from.to_source() == "dev":
@ -238,6 +237,7 @@ class DatasetRetrieval:
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
):
tools = []
for dataset in available_datasets:
@ -292,6 +292,11 @@ class DatasetRetrieval:
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k
@ -323,6 +328,7 @@ class DatasetRetrieval:
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
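The new metadata_filter_document_ids argument is a per-dataset allow-list; the lookup above pulls out the slice for the dataset being queried. Its shape, with placeholder ids:

metadata_filter_document_ids = {
    "dataset-a": ["doc-1", "doc-7"],
    "dataset-b": ["doc-3"],
}

dataset_id = "dataset-a"
document_ids = metadata_filter_document_ids.get(dataset_id, [])
document_ids_filter = document_ids or None  # None means no restriction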

View File

@ -185,7 +185,7 @@ class ToolInvokeMessage(BaseModel):
"""
plain text, image url or link url
"""
message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
meta: dict[str, Any] | None = None
@field_validator("message", mode="before")

View File

@ -123,7 +123,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
"doc_metadata": document.doc_metadata,
}
if self.retriever_from == "dev":

View File

@ -172,7 +172,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": record.score or 0.0,
"doc_metadata": document.doc_metadat, # type: ignore
}
if self.retriever_from == "dev":

View File

@ -8,12 +8,12 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
@ -156,38 +156,16 @@ class AgentNode(ToolNode):
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
value_param = param.get("value", {})
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
provider_type=ToolProviderType.BUILT_IN,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
tool_parameters=tool.get("parameters", {}),
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
)
@ -200,26 +178,11 @@ class AgentNode(ToolNode):
tool_runtime.entity.description.llm = (
extra.get("descrption", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"provider_type": provider_type.value,
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
}
)
value = tool_value

View File

@ -1,4 +1,3 @@
from enum import Enum
from typing import Any, Literal, Union
from pydantic import BaseModel
@ -17,8 +16,3 @@ class AgentNodeData(BaseNodeData):
type: Literal["mixed", "variable", "constant"]
agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1

View File

@ -1,8 +1,10 @@
from collections.abc import Sequence
from typing import Any, Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig
class RerankingModelConfig(BaseModel):
@ -73,6 +75,48 @@ class SingleRetrievalConfig(BaseModel):
model: ModelConfig
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"starts with",
"ends with",
"is",
"is not",
"empty",
"is not empty",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class Condition(BaseModel):
"""
Condition detail
"""
metadata_name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
@ -84,3 +128,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)
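A minimal construction of the new filtering config from these models; the field names and values are illustrative:

filtering = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(metadata_name="department", comparison_operator="is", value="finance"),
        Condition(metadata_name="upload_date", comparison_operator="after", value="2025-01-01"),
    ],
)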

View File

@ -16,3 +16,7 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError):
class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
"""Raised when the model provider quota is exceeded."""
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
"""Raised when the model is not a Large Language Model."""

View File

@ -1,6 +1,8 @@
import json
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import func
@ -9,21 +11,38 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.list_operator.exc import InvalidConditionError
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from models.dataset import Dataset, Document
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document
from models.workflow import WorkflowNodeExecutionStatus
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
@ -42,13 +61,14 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
class KnowledgeRetrievalNode(LLMNode):
_node_data_cls = KnowledgeRetrievalNodeData
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self) -> NodeRunResult:
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -63,7 +83,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
)
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@ -117,6 +137,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
@ -146,6 +169,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy,
metadata_filter_document_ids=metadata_filter_document_ids,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config is None:
@ -240,7 +264,6 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"segment_word_count": segment.word_count,
"segment_position": segment.position,
"segment_index_node_hash": segment.index_node_hash,
"doc_metadata": document.doc_metadata,
},
"title": document.name,
}
@ -259,6 +282,134 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
item["metadata"]["position"] = position
return retrieval_resource_list
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> Optional[dict[str, list[str]]]:
document_query = db.session.query(Document.id, Document.dataset_id).filter(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
if node_data.metadata_filtering_mode == "disabled":
return None
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
if automatic_metadata_filters:
for filter in automatic_metadata_filters:
document_query = self._process_metadata_filter_func(
filter.get("condition"), filter.get("metadata_name"), filter.get("value"), document_query
)
elif node_data.metadata_filtering_mode == "manual":
for condition in node_data.metadata_filtering_conditions.conditions:
metadata_name = condition.metadata_name
expected_value = condition.value
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).text
document_query = self._process_metadata_filter_func(
condition.comparison_operator, metadata_name, expected_value, document_query
)
else:
raise ValueError("Invalid metadata filtering mode")
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list)
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id)
return metadata_filter_document_ids
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
# get metadata model config
metadata_model_config = node_data.metadata_model_config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
query=query or "",
metadata_fields=all_metadata_fields,
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
model_config=model_config,
sys_files=[],
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
result_text = ""
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
break
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
return None
return automatic_metadata_filters
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, query):
match condition:
case "contains":
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}%"))
case "not contains":
query = query.filter(Document.doc_metadata[metadata_name].notlike(f"%{value}%"))
case "starts with":
query = query.filter(Document.doc_metadata[metadata_name].like(f"{value}%"))
case "ends with":
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}"))
case "is" | "=":
query = query.filter(Document.doc_metadata[metadata_name] == value)
case "is not" | "≠":
query = query.filter(Document.doc_metadata[metadata_name] != value)
case "empty":
query = query.filter(Document.doc_metadata[metadata_name].is_(None))
case "is not empty":
query = query.filter(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
query = query.filter(Document.doc_metadata[metadata_name] < value)
case "after" | ">":
query = query.filter(Document.doc_metadata[metadata_name] > value)
case "≤" | "<=":
query = query.filter(Document.doc_metadata[metadata_name] <= value)
case "≥" | ">=":
query = query.filter(Document.doc_metadata[metadata_name] >= value)
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
return query
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -344,3 +495,94 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
parameters=completion_params,
stop=stop,
)
def _calculate_rest_token(
self,
node_data: KnowledgeRetrievalNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config,
)
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)
input_text = query
memory_str = ""
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.")
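End to end, automatic mode asks the LLM for a JSON object carrying a metadata_map and keeps only the entries that name known fields. The parsing step reduced to a runnable sketch; parse_and_check_json_markdown additionally tolerates fenced LLM output, which plain json.loads does not.

import json

all_metadata_fields = ["department", "region"]  # from DatasetMetadata rows
result_text = (
    '{"metadata_map": [{"metadata_field_name": "department", '
    '"metadata_field_value": "finance", "comparison_operator": "is"}]}'
)

result = json.loads(result_text)
automatic_metadata_filters = [
    {
        "metadata_name": item["metadata_field_name"],
        "value": item["metadata_field_value"],
        "condition": item["comparison_operator"],
    }
    for item in result.get("metadata_map", [])
    if item.get("metadata_field_name") in all_metadata_fields
]
assert automatic_metadata_filters[0]["condition"] == "is"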

View File

@ -459,7 +459,6 @@ class LLMNode(BaseNode[LLMNodeData]):
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
}
return source

View File

@ -53,6 +53,8 @@ external_knowledge_info_fields = {
"external_knowledge_api_endpoint": fields.String,
}
doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
dataset_detail_fields = {
"id": fields.String,
"name": fields.String,
@@ -76,6 +78,8 @@ dataset_detail_fields = {
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
"doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
"built_in_field_enabled": fields.Boolean,
}
dataset_query_detail_fields = {
@@ -87,3 +91,9 @@ dataset_query_detail_fields = {
"created_by": fields.String,
"created_at": TimestampField,
}
dataset_metadata_fields = {
"id": fields.String,
"type": fields.String,
"name": fields.String,
}


@@ -3,6 +3,13 @@ from flask_restful import fields # type: ignore
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
document_metadata_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"value": fields.String,
}
document_fields = {
"id": fields.String,
"position": fields.Integer,
@@ -25,6 +32,7 @@ document_fields = {
"word_count": fields.Integer,
"hit_count": fields.Integer,
"doc_form": fields.String,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
}
document_with_segments_fields = {
@@ -51,6 +59,7 @@ document_with_segments_fields = {
"hit_count": fields.Integer,
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
}
dataset_and_document_fields = {
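
Aside: the attribute="doc_metadata_details" option used above tells flask_restful to read the model's doc_metadata_details property while serializing it under the doc_metadata key. A minimal sketch of that mapping, reusing document_metadata_fields from this diff (FakeDoc is a hypothetical stand-in):

from flask_restful import fields, marshal

class FakeDoc:
    doc_metadata_details = [{"id": "m1", "name": "author", "type": "string", "value": "alice"}]

out = marshal(FakeDoc(), {
    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
})
# out == {"doc_metadata": [{"id": "m1", "name": "author", "type": "string", "value": "alice"}]}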


@@ -7,7 +7,6 @@ document_fields = {
"data_source_type": fields.String,
"name": fields.String,
"doc_type": fields.String,
"doc_metadata": fields.Raw,
}
segment_fields = {


@@ -0,0 +1,90 @@
"""add_metadata_function
Revision ID: d20049ed0af6
Revises: 08ec4f75af5e
Create Date: 2025-02-27 09:17:48.903213
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd20049ed0af6'
down_revision = '08ec4f75af5e'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_metadata_bindings',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
)
with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False)
batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False)
batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False)
op.create_table('dataset_metadatas',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
)
with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.alter_column('doc_metadata',
existing_type=postgresql.JSON(astext_type=sa.Text()),
type_=postgresql.JSONB(astext_type=sa.Text()),
existing_nullable=True)
batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_index('document_metadata_idx', postgresql_using='gin')
batch_op.alter_column('doc_metadata',
existing_type=postgresql.JSONB(astext_type=sa.Text()),
type_=postgresql.JSON(astext_type=sa.Text()),
existing_nullable=True)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('built_in_field_enabled')
with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
batch_op.drop_index('dataset_metadata_tenant_idx')
batch_op.drop_index('dataset_metadata_dataset_idx')
op.drop_table('dataset_metadatas')
with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
batch_op.drop_index('dataset_metadata_binding_tenant_idx')
batch_op.drop_index('dataset_metadata_binding_metadata_idx')
batch_op.drop_index('dataset_metadata_binding_document_idx')
batch_op.drop_index('dataset_metadata_binding_dataset_idx')
op.drop_table('dataset_metadata_bindings')
# ### end Alembic commands ###
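
The JSON-to-JSONB conversion plus the GIN index above are what make metadata filtering indexable. A hedged sketch of the query shape this enables (the db session and the Document model are assumed from the models change below):

docs = (
    db.session.query(Document)
    .filter(Document.doc_metadata.contains({"author": "alice"}))
    .all()
)
# SQLAlchemy renders .contains() on a JSONB column as the @> containment
# operator, which Postgres can answer from the document_metadata_idx GIN index.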


@@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -60,6 +61,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
@property
def dataset_keyword_table(self):
@@ -197,6 +199,19 @@ class Dataset(db.Model): # type: ignore[name-defined]
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}
@property
def doc_metadata(self):
dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()
return [
{
"id": dataset_metadata.id,
"name": dataset_metadata.name,
"type": dataset_metadata.type,
}
for dataset_metadata in dataset_metadatas
]
@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")
@@ -250,6 +265,7 @@ class Document(db.Model): # type: ignore[name-defined]
db.Index("document_dataset_id_idx", "dataset_id"),
db.Index("document_is_paused_idx", "is_paused"),
db.Index("document_tenant_idx", "tenant_id"),
db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
)
# initial fields
@@ -306,7 +322,7 @@ class Document(db.Model): # type: ignore[name-defined]
archived_at = db.Column(db.DateTime, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = db.Column(db.String(40), nullable=True)
doc_metadata = db.Column(db.JSON, nullable=True)
doc_metadata = db.Column(JSONB, nullable=True)
doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
doc_language = db.Column(db.String(255), nullable=True)
@@ -397,6 +413,78 @@ class Document(db.Model): # type: ignore[name-defined]
)
@property
def uploader(self):
user = db.session.query(Account).filter(Account.id == self.created_by).first()
return user.name if user else None
@property
def upload_date(self):
return self.created_at
@property
def last_update_date(self):
return self.updated_at
@property
def doc_metadata_details(self):
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.filter(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
)
metadata_list = []
for metadata in document_metadatas:
metadata_dict = {
"id": metadata.id,
"name": metadata.name,
"type": metadata.type,
"value": self.doc_metadata.get(metadata.type),
}
metadata_list.append(metadata_dict)
# deal built-in fields
metadata_list.extend(self.get_built_in_fields())
return metadata_list
return None
def get_built_in_fields(self):
built_in_fields = []
built_in_fields.append({
"id": "built-in",
"name": BuiltInField.document_name,
"type": "string",
"value": self.name,
})
built_in_fields.append({
"id": "built-in",
"name": BuiltInField.uploader,
"type": "string",
"value": self.uploader,
})
built_in_fields.append({
"id": "built-in",
"name": BuiltInField.upload_date,
"type": "date",
"value": self.created_at,
})
built_in_fields.append({
"id": "built-in",
"name": BuiltInField.last_update_date,
"type": "date",
"value": self.updated_at,
})
built_in_fields.append({
"id": "built-in",
"name": BuiltInField.source,
"type": "string",
"value": self.data_source_info,
})
return built_in_fields
def process_rule_dict(self):
if self.dataset_process_rule_id:
return self.dataset_process_rule.to_dict()
@@ -930,3 +1018,41 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
document_id = db.Column(StringUUID, nullable=False)
notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetMetadata(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
db.Index("dataset_metadata_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_by = db.Column(StringUUID, nullable=False)
updated_by = db.Column(StringUUID, nullable=True)
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
db.Index("dataset_metadata_binding_document_idx", "document_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
metadata_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = db.Column(StringUUID, nullable=False)
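
Taken together, DatasetMetadata stores a field definition per dataset while DatasetMetadataBinding records which documents carry it. A hedged usage sketch (the tenant_id, user_id, dataset and document wiring is assumed):

metadata = DatasetMetadata(
    tenant_id=tenant_id,
    dataset_id=dataset.id,
    type="string",
    name="author",
    created_by=user_id,
)
db.session.add(metadata)
db.session.flush()  # populate metadata.id before the binding references it

binding = DatasetMetadataBinding(
    tenant_id=tenant_id,
    dataset_id=dataset.id,
    metadata_id=metadata.id,
    document_id=document.id,
    created_by=user_id,
)
db.session.add(binding)
db.session.commit()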

api/poetry.lock (generated; 2067 changed lines; diff suppressed because it is too large)


@@ -16,6 +16,7 @@ from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -599,9 +600,45 @@ class DocumentService:
return document
@staticmethod
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
return documents
@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all()
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
.all()
)
return documents
@staticmethod
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
return documents
@@ -684,7 +721,11 @@ class DocumentService:
if document.tenant_id != current_user.current_tenant_id:
raise ValueError("No permission.")
document.name = name
if dataset.built_in_field_enabled:
if document.doc_metadata:
document.doc_metadata[BuiltInField.document_name] = name
else:
document.name = name
db.session.add(document)
db.session.commit()
@@ -1084,9 +1125,22 @@ class DocumentService:
doc_form=document_form,
doc_language=document_language,
)
doc_metadata = {}
if dataset.built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
BuiltInField.source: data_source_type,
}
if metadata is not None:
document.doc_metadata = metadata.doc_metadata
doc_metadata.update(metadata.doc_metadata)
document.doc_type = metadata.doc_type
if doc_metadata:
document.doc_metadata = doc_metadata
return document
@staticmethod
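
The BuiltInField constants used throughout these changes are not shown in this diff; presumably they form a string enum roughly like the sketch below (member values are assumptions, inferred from their usage here):

from enum import Enum

class BuiltInField(str, Enum):
    # assumed members; the real definition lives in
    # core/rag/index_processor/constant/built_in_field.py
    document_name = "document_name"
    uploader = "uploader"
    upload_date = "upload_date"
    last_update_date = "last_update_date"
    source = "source"
    original_filename = "original_filename"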


@@ -124,3 +124,36 @@ class SegmentUpdateArgs(BaseModel):
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
content: str
class MetadataArgs(BaseModel):
type: Literal["string", "number", "time"]
name: str
class MetadataUpdateArgs(BaseModel):
name: str
value: str
class MetadataValueUpdateArgs(BaseModel):
fields: list[MetadataUpdateArgs]
class MetadataDetail(BaseModel):
id: str
name: str
value: str
class DocumentMetadataOperation(BaseModel):
document_id: str
metadata_list: list[MetadataDetail]
class MetadataOperationData(BaseModel):
"""
Metadata operation data
"""
operation_data: list[DocumentMetadataOperation]
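
A hedged example of the request payload these Pydantic models validate (the IDs are placeholders):

payload = MetadataOperationData(
    operation_data=[
        DocumentMetadataOperation(
            document_id="<document-uuid>",
            metadata_list=[MetadataDetail(id="<metadata-uuid>", name="author", value="alice")],
        )
    ]
)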


@@ -0,0 +1,182 @@
import datetime
from typing import Optional
from flask_login import current_user # type: ignore
from core.rag.index_processor.constant.built_in_field import BuiltInField
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
from tasks.update_documents_metadata_task import update_documents_metadata_task
class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
metadata = DatasetMetadata(
dataset_id=dataset_id,
type=metadata_args.type,
name=metadata_args.name,
created_by=current_user.id,
)
db.session.add(metadata)
db.session.commit()
return metadata
@staticmethod
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:
lock_key = f"dataset_metadata_lock_{dataset_id}"
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
metadata.name = name
metadata.updated_by = current_user.id
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update related documents
documents = []
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
document.doc_metadata[name] = document.doc_metadata.pop(old_name)
db.session.add(document)
db.session.commit()
if document_ids:
update_documents_metadata_task.delay(dataset_id, document_ids, lock_key)
return metadata
@staticmethod
def delete_metadata(dataset_id: str, metadata_id: str):
lock_key = f"dataset_metadata_lock_{dataset_id}"
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# delete related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
document.doc_metadata.pop(metadata.name)
db.session.add(document)
db.session.commit()
if document_ids:
update_documents_metadata_task.delay(dataset_id, document_ids, lock_key)
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "date"},
{"name": BuiltInField.last_update_date, "type": "date"},
{"name": BuiltInField.source, "type": "string"},
]
@staticmethod
def enable_built_in_field(dataset: Dataset):
if dataset.built_in_field_enabled:
return
lock_key = f"dataset_metadata_lock_{dataset.id}"
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = True
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
document_ids = []
if documents:
for document in documents:
if document.doc_metadata is None:
document.doc_metadata = {}
document.doc_metadata[BuiltInField.document_name] = document.name
document.doc_metadata[BuiltInField.uploader] = document.uploader
document.doc_metadata[BuiltInField.upload_date] = document.upload_date.strftime("%Y-%m-%d %H:%M:%S")
document.doc_metadata[BuiltInField.last_update_date] = document.last_update_date.strftime(
"%Y-%m-%d %H:%M:%S"
)
document.doc_metadata[BuiltInField.source] = document.data_source_type
db.session.add(document)
document_ids.append(document.id)
db.session.commit()
if document_ids:
update_documents_metadata_task.delay(dataset.id, document_ids, lock_key)
@staticmethod
def disable_built_in_field(dataset: Dataset):
if not dataset.built_in_field_enabled:
return
lock_key = f"dataset_metadata_lock_{dataset.id}"
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = False
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
document_ids = []
if documents:
for document in documents:
document.doc_metadata.pop(BuiltInField.document_name)
document.doc_metadata.pop(BuiltInField.uploader)
document.doc_metadata.pop(BuiltInField.upload_date)
document.doc_metadata.pop(BuiltInField.last_update_date)
document.doc_metadata.pop(BuiltInField.source)
db.session.add(document)
document_ids.append(document.id)
db.session.commit()
if document_ids:
update_documents_metadata_task.delay(dataset.id, document_ids, lock_key)
@staticmethod
def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
for operation in metadata_args.operation_data:
lock_key = f"document_metadata_lock_{operation.document_id}"
MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
document = DocumentService.get_document(operation.document_id)
if document is None:
raise ValueError("Document not found.")
document.doc_metadata = {}
for metadata_value in operation.metadata_list:
document.doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
document.doc_metadata[BuiltInField.document_name] = document.name
document.doc_metadata[BuiltInField.uploader] = document.uploader
document.doc_metadata[BuiltInField.upload_date] = document.upload_date.strftime("%Y-%m-%d %H:%M:%S")
document.doc_metadata[BuiltInField.last_update_date] = document.last_update_date.strftime(
"%Y-%m-%d %H:%M:%S"
)
document.doc_metadata[BuiltInField.source] = document.data_source_type
# handle metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
document_id=operation.document_id,
metadata_id=metadata_value.id,
created_by=current_user.id,
)
db.session.add(dataset_metadata_binding)
db.session.add(document)
db.session.commit()
update_documents_metadata_task.delay(dataset.id, [document.id], lock_key)
@staticmethod
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
if dataset_id:
lock_key = f"dataset_metadata_lock_{dataset_id}"
if redis_client.get(lock_key):
raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
redis_client.set(lock_key, 1, ex=3600)
if document_id:
lock_key = f"document_metadata_lock_{document_id}"
if redis_client.get(lock_key):
raise ValueError("Another document metadata operation is running, please wait a moment.")
redis_client.set(lock_key, 1, ex=3600)
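
One caveat: the GET-then-SET pair above is not atomic, so two concurrent callers can both pass the check. A hedged alternative with the same redis_client, acquiring the lock in one round trip via SET with nx:

if not redis_client.set(lock_key, 1, ex=3600, nx=True):
    raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")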


@@ -0,0 +1,121 @@
import logging
import time
from typing import Optional
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import (
Document as DatasetDocument,
)
from models.dataset import (
DocumentSegment,
)
from services.dataset_service import DatasetService
@shared_task(queue="dataset")
def update_documents_metadata_task(
dataset_id: str,
document_ids: list[str],
lock_key: Optional[str] = None,
):
"""
Update documents metadata.
:param dataset_id: dataset id
:param document_ids: document ids
Usage: update_documents_metadata_task.delay(dataset_id, document_ids)
"""
logging.info(click.style("Start update documents metadata: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
try:
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise ValueError("Dataset not found.")
documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.id.in_(document_ids),
DatasetDocument.enabled == True,
DatasetDocument.indexing_status == "completed",
DatasetDocument.archived == False,
)
.all()
)
if not documents:
raise ValueError("Documents not found.")
for dataset_document in documents:
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.all()
)
if not segments:
continue
# delete all documents in vector index
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=False, delete_child_chunks=True)
# update documents metadata
rag_documents = []  # separate name avoids shadowing the list being iterated above
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": dataset_document.id,
"dataset_id": dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": dataset_document.id,
"dataset_id": dataset_id,
},
)
if dataset.built_in_field_enabled:
child_document.metadata[BuiltInField.uploader] = dataset_document.created_by
child_document.metadata[BuiltInField.upload_date] = dataset_document.created_at
child_document.metadata[BuiltInField.last_update_date] = dataset_document.updated_at
child_document.metadata[BuiltInField.source] = dataset_document.data_source_type
child_document.metadata[BuiltInField.original_filename] = dataset_document.name
if dataset_document.doc_metadata:
child_document.metadata.update(dataset_document.doc_metadata)
child_documents.append(child_document)
document.children = child_documents
rag_documents.append(document)
# save vector index
index_processor.load(dataset, rag_documents)
end_at = time.perf_counter()
logging.info(
click.style("Updated documents metadata: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
)
except Exception:
logging.exception("Updated documents metadata failed")
finally:
if lock_key:
redis_client.delete(lock_key)
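
A hedged usage sketch tying this task to the lock pattern in MetadataService above: the caller acquires the Redis lock, then hands the key to the task so the finally block releases it:

lock_key = f"dataset_metadata_lock_{dataset.id}"
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)  # acquires the lock
update_documents_metadata_task.delay(dataset.id, [d.id for d in documents], lock_key)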


@@ -29,7 +29,7 @@ server {
include proxy.conf;
}
location /e/ {
location /e {
proxy_pass http://plugin_daemon:5002;
proxy_set_header Dify-Hook-Url $scheme://$host$request_uri;
include proxy.conf;
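
Note on the change above: location /e is a prefix match, so it now covers the bare /e path as well as /e/...; the previous location /e/ never matched /e itself. Because nginx picks the longest matching prefix, it is worth double-checking that no other route beginning with /e (for example /explore) should take precedence.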


@@ -94,7 +94,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
},
]
return navs
}, [])
}, [t])
useEffect(() => {
if (appDetail) {
@@ -120,7 +120,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
}).finally(() => {
setIsLoadingAppDetail(false)
})
}, [appId, pathname])
}, [appId, router, setAppDetail])
useEffect(() => {
if (!appDetailRes || isLoadingCurrentWorkspace || isLoadingAppDetail)
@@ -148,7 +148,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
}
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [appDetailRes, isCurrentWorkspaceEditor, isLoadingAppDetail, isLoadingCurrentWorkspace, systemFeatures.enable_web_sso_switch_component])
}, [appDetailRes, appId, getNavigations, isCurrentWorkspaceEditor, isLoadingAppDetail, isLoadingCurrentWorkspace, router, setAppDetail, systemFeatures.enable_web_sso_switch_component])
useUnmount(() => {
setAppDetail()
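
The dependency-array updates above bring each hook in line with what it actually reads (t, appId, router, setAppDetail, getNavigations), satisfying react-hooks/exhaustive-deps instead of relying on stale closures.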


@@ -3,13 +3,13 @@
import { useSelectedLayoutSegment } from 'next/navigation'
import Link from 'next/link'
import classNames from '@/utils/classnames'
import type { RemixiconComponentType } from '@remixicon/react'
export type NavIcon = React.ComponentType<
React.PropsWithoutRef<React.ComponentProps<'svg'>> & {
title?: string | undefined
titleId?: string | undefined
}> | RemixiconComponentType
}
>
export type NavLinkProps = {
name: string


@@ -94,7 +94,7 @@ const Configuration: FC = () => {
})))
const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig)
const latestPublishedAt = useMemo(() => appDetail?.model_config?.updated_at, [appDetail])
const latestPublishedAt = useMemo(() => appDetail?.model_config.updated_at, [appDetail])
const [formattingChanged, setFormattingChanged] = useState(false)
const { setShowAccountSettingModal } = useModalContext()
const [hasFetchedDetail, setHasFetchedDetail] = useState(false)


@@ -128,7 +128,7 @@ const Apps = ({
icon_background,
description,
}) => {
const { export_data, mode } = await fetchAppDetail(
const { export_data } = await fetchAppDetail(
currApp?.app.id as string,
)
try {
@@ -151,7 +151,7 @@ const Apps = ({
if (app.app_id)
await handleCheckPluginDependencies(app.app_id)
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id, mode }, push)
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id }, push)
}
catch (e) {
Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') })


@@ -211,9 +211,7 @@ const Paragraph = (paragraph: any) => {
return (
<>
<ImageGallery srcs={[children_node[0].properties.src]} />
{
Array.isArray(paragraph.children) ? <p>{paragraph.children.slice(1)}</p> : null
}
<p>{paragraph.children.slice(1)}</p>
</>
)
}


@@ -59,7 +59,7 @@ const Toast = ({
}`}
/>
<div className={`flex ${size === 'md' ? 'gap-1' : 'gap-0.5'}`}>
<div className={`flex justify-center items-center ${size === 'md' ? 'p-0.5' : 'p-1'}`}>
<div className={`flex justify-center items-start ${size === 'md' ? 'p-0.5' : 'p-1'}`}>
{type === 'success' && <RiCheckboxCircleFill className={`${size === 'md' ? 'w-5 h-5' : 'w-4 h-4'} text-text-success`} aria-hidden="true" />}
{type === 'error' && <RiErrorWarningFill className={`${size === 'md' ? 'w-5 h-5' : 'w-4 h-4'} text-text-destructive`} aria-hidden="true" />}
{type === 'warning' && <RiAlertFill className={`${size === 'md' ? 'w-5 h-5' : 'w-4 h-4'} text-text-warning-secondary`} aria-hidden="true" />}


@@ -126,7 +126,7 @@ const Apps = ({
icon_background,
description,
}) => {
const { export_data, mode } = await fetchAppDetail(
const { export_data } = await fetchAppDetail(
currApp?.app.id as string,
)
try {
@@ -149,7 +149,7 @@ const Apps = ({
if (app.app_id)
await handleCheckPluginDependencies(app.app_id)
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id, mode }, push)
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id }, push)
}
catch (e) {
Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') })


@@ -0,0 +1,24 @@
'use client'
import { RiCloseLine } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import type { PropsWithChildren } from 'react'
import { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
export default function OfflineNotice({ children }: PropsWithChildren) {
const { t } = useTranslation()
const [showOfflineNotice, { setFalse }] = useBoolean(true)
useEffect(() => {
const timer = setTimeout(setFalse, 60000)
return () => clearTimeout(timer)
}, [setFalse])
return <>
{showOfflineNotice && <div className='px-4 py-2 flex items-center justify-start gap-x-2 bg-[#FFFAEB] border-b-[0.5px] border-b-[#FEF0C7]'>
<div className='rounded-[12px] flex items-center justify-center px-2 py-0.5 h-[22px] bg-[#f79009] text-white text-[11px] not-italic font-medium leading-[18px]'>{t('common.offlineNoticeTitle')}</div>
<div className='grow font-medium leading-[18px] text-[12px] not-italic text-[#344054]'>{t('common.offlineNotice')}</div>
<RiCloseLine className='size-4 text-[#667085] cursor-pointer' onClick={setFalse} />
</div>}
{children}
</>
}


@@ -7,6 +7,7 @@ import { TanstackQueryIniter } from '@/context/query-client'
import { ThemeProvider } from 'next-themes'
import './styles/globals.css'
import './styles/markdown.scss'
import OfflineNotice from './components/offline-notice'
export const metadata = {
title: 'Dify',
@@ -61,7 +62,9 @@ const LocaleLayout = ({
disableTransitionOnChange
>
<I18nServer>
{children}
<OfflineNotice>
{children}
</OfflineNotice>
</I18nServer>
</ThemeProvider>
</TanstackQueryIniter>


@@ -120,7 +120,7 @@ export const ProviderContextProvider = ({
if (localStorage.getItem('anthropic_quota_notice') === 'true')
return
if (dayjs().isAfter(dayjs('2025-03-17')))
if (dayjs().isAfter(dayjs('2025-03-11')))
return
if (providersData?.data && providersData.data.length > 0) {


@@ -1,4 +1,6 @@
const translation = {
offlineNoticeTitle: 'Important Notice',
offlineNotice: 'Dify v1.0.0 is now officially released. Effective March 5, 2025, the current environment will no longer be accessible, and all data will be permanently deleted. Please ensure that you back up any necessary data prior to this date to avoid any loss.',
api: {
success: 'Success',
actionSuccess: 'Action succeeded',


@@ -1,4 +1,6 @@
const translation = {
offlineNoticeTitle: '重要通知',
offlineNotice: 'Dify v1.0.0 现已正式发布。自 2025年 3 月 5 日起,当前环境将不可访问,所有数据将被永久删除。请务必在此日期之前备份所有必要数据,以避免任何数据丢失。',
api: {
success: '成功',
actionSuccess: '操作成功',


@ -1,16 +0,0 @@
import { buildProviderQuery } from './_tools_util'
describe('makeProviderQuery', () => {
test('collectionName without special chars', () => {
expect(buildProviderQuery('ABC')).toBe('provider=ABC')
})
test('should escape &', () => {
expect(buildProviderQuery('ABC&DEF')).toBe('provider=ABC%26DEF')
})
test('should escape /', () => {
expect(buildProviderQuery('ABC/DEF')).toBe('provider=ABC%2FDEF')
})
test('should escape ?', () => {
expect(buildProviderQuery('ABC?DEF')).toBe('provider=ABC%3FDEF')
})
})


@ -1,5 +0,0 @@
export const buildProviderQuery = (collectionName: string): string => {
const query = new URLSearchParams()
query.set('provider', collectionName)
return query.toString()
}


@@ -10,7 +10,6 @@ import type {
} from '@/app/components/tools/types'
import type { ToolWithProvider } from '@/app/components/workflow/types'
import type { Label } from '@/app/components/tools/labels/constant'
import { buildProviderQuery } from './_tools_util'
export const fetchCollectionList = () => {
return get<Collection[]>('/workspaces/current/tool-providers')
@@ -25,13 +24,11 @@ export const fetchBuiltInToolList = (collectionName: string) => {
}
export const fetchCustomToolList = (collectionName: string) => {
const query = buildProviderQuery(collectionName)
return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?${query}`)
return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`)
}
export const fetchModelToolList = (collectionName: string) => {
const query = buildProviderQuery(collectionName)
return get<Tool[]>(`/workspaces/current/tool-provider/model/tools?${query}`)
return get<Tool[]>(`/workspaces/current/tool-provider/model/tools?provider=${collectionName}`)
}
export const fetchWorkflowToolList = (appID: string) => {
@@ -68,8 +65,7 @@ export const parseParamsSchema = (schema: string) => {
}
export const fetchCustomCollection = (collectionName: string) => {
const query = buildProviderQuery(collectionName)
return get<CustomCollectionBackend>(`/workspaces/current/tool-provider/api/get?${query}`)
return get<CustomCollectionBackend>(`/workspaces/current/tool-provider/api/get?provider=${collectionName}`)
}
export const createCustomCollection = (collection: CustomCollectionBackend) => {
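
Note: the deleted buildProviderQuery helper (and its tests above) existed to URL-encode provider names via URLSearchParams; inlining provider=${collectionName} means names containing &, / or ? are sent unescaped again.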


@@ -76,7 +76,7 @@ export const correctToolProvider = (provider: string) => {
if (provider.includes('/'))
return provider
if (['stepfun', 'jina', 'siliconflow', 'gitee_ai'].includes(provider))
if (['stepfun', 'jina', 'siliconflow'].includes(provider))
return `langgenius/${provider}_tool/${provider}`
return `langgenius/${provider}/${provider}`
@@ -84,6 +84,6 @@ export const correctToolProvider = (provider: string) => {
export const canFindTool = (providerId: string, oldToolId?: string) => {
return providerId === oldToolId
|| providerId === `langgenius/${oldToolId}/${oldToolId}`
|| providerId === `langgenius/${oldToolId}_tool/${oldToolId}`
|| providerId === `langgenius/${oldToolId}/${oldToolId}`
|| providerId === `langgenius/${oldToolId}_tool/${oldToolId}`
}