diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index ebc5d31e7e..554a0bc0f9 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -234,6 +234,33 @@ class DatasetApi(Resource):
         )
         parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
         parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
+
+        parser.add_argument(
+            "external_retrieval_model",
+            type=dict,
+            required=False,
+            nullable=True,
+            location="json",
+            help="Invalid external retrieval model.",
+        )
+
+        parser.add_argument(
+            "external_knowledge_id",
+            type=str,
+            required=False,
+            nullable=True,
+            location="json",
+            help="Invalid external knowledge id.",
+        )
+
+        parser.add_argument(
+            "external_knowledge_api_id",
+            type=str,
+            required=False,
+            nullable=True,
+            location="json",
+            help="Invalid external knowledge api id.",
+        )
         args = parser.parse_args()
         data = request.get_json()
 
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index 4eb43af332..1bc7ffdf49 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -152,7 +152,9 @@ class ExternalApiUseCheckApi(Resource):
     def get(self, external_knowledge_api_id):
         external_knowledge_api_id = str(external_knowledge_api_id)
 
-        external_api_template_is_using, count = ExternalDatasetService.external_api_template_use_check(external_knowledge_api_id)
+        external_api_template_is_using, count = ExternalDatasetService.external_api_template_use_check(
+            external_knowledge_api_id
+        )
         return {"is_using": external_api_template_is_using, "count": count}, 200
 
 
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index c23b52cf55..ae61ba7112 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -112,11 +112,7 @@ class DatasetRetrieval:
                 continue
 
             # pass if dataset is not available
-            if (
-                dataset
-                and dataset.available_document_count == 0
-                and dataset.provider != "external"
-            ):
+            if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
                 continue
 
             available_datasets.append(dataset)
diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py
index 9cf8da7acd..7aee087d78 100644
--- a/api/fields/dataset_fields.py
+++ b/api/fields/dataset_fields.py
@@ -41,6 +41,13 @@ dataset_retrieval_model_fields = {
 
 tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
 
+external_knowledge_info_fields = {
+    "external_knowledge_id": fields.String,
+    "external_knowledge_api_id": fields.String,
+    "external_knowledge_api_name": fields.String,
+    "external_knowledge_api_endpoint": fields.String,
+}
+
 dataset_detail_fields = {
     "id": fields.String,
     "name": fields.String,
@@ -61,6 +68,7 @@ dataset_detail_fields = {
     "embedding_available": fields.Boolean,
     "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
     "tags": fields.List(fields.Nested(tag_fields)),
+    "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
 }
 
 dataset_query_detail_fields = {
diff --git a/api/models/dataset.py b/api/models/dataset.py
index f5e8be970c..c61d467956 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -171,6 +171,42 @@ class Dataset(db.Model):
 
         return tags or []
 
+    @property
+    def external_knowledge_info(self):
+        """Describe the external knowledge source bound to this dataset.
+
+        Returns None unless the provider is "external" and both the binding and
+        the API template it references exist; otherwise returns a dict with the
+        external knowledge id and the template's id, name and endpoint.
+        """
+        if self.provider != "external":
+            return None
+        external_knowledge_binding = (
+            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
+        )
+        if not external_knowledge_binding:
+            return None
+        external_api_template = (
+            db.session.query(ExternalApiTemplates)
+            .filter(ExternalApiTemplates.id == external_knowledge_binding.external_api_template_id)
+            .first()
+        )
+        if not external_api_template:
+            return None
+        # Settings may be empty, malformed JSON, or not a JSON object; report an
+        # empty endpoint rather than failing serialization of the whole dataset.
+        try:
+            settings = json.loads(external_api_template.settings) if external_api_template.settings else None
+        except JSONDecodeError:
+            settings = None
+        endpoint = settings.get("endpoint", "") if isinstance(settings, dict) else ""
+        return {
+            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
+            "external_knowledge_api_id": external_api_template.id,
+            "external_knowledge_api_name": external_api_template.name,
+            "external_knowledge_api_endpoint": endpoint,
+        }
+
     @staticmethod
     def gen_collection_name_by_id(dataset_id: str) -> str:
         normalized_dataset_id = dataset_id.replace("-", "_")
@@ -734,21 +770,23 @@ class ExternalApiTemplates(db.Model):
             return json.loads(self.settings) if self.settings else None
         except JSONDecodeError:
             return None
-    
+
     @property
     def dataset_bindings(self):
-        external_knowledge_bindings = db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.external_api_template_id == self.id).all()
+        external_knowledge_bindings = (
+            db.session.query(ExternalKnowledgeBindings)
+            .filter(ExternalKnowledgeBindings.external_api_template_id == self.id)
+            .all()
+        )
         dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
         datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
         dataset_bindings = []
         for dataset in datasets:
-            dataset_bindings.append({
-                "id": dataset.id,
-                "name": dataset.name
-            })
+            dataset_bindings.append({"id": dataset.id, "name": dataset.name})
 
         return dataset_bindings
 
+
 class ExternalKnowledgeBindings(db.Model):
     __tablename__ = "external_knowledge_bindings"
     __table_args__ = (
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 699d2d64f4..a3a5d7b84f 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -184,7 +184,7 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def get_dataset(dataset_id):
+    def get_dataset(dataset_id) -> Dataset:
         return Dataset.query.filter_by(id=dataset_id).first()
 
     @staticmethod
@@ -225,81 +225,121 @@ class DatasetService:
 
     @staticmethod
     def update_dataset(dataset_id, data, user):
-        data.pop("partial_member_list", None)
-        filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
+        """Apply ``data`` to the dataset identified by ``dataset_id`` and return it.
+
+        External datasets (provider == "external") update only their name,
+        description, retrieval model and external knowledge binding. All other
+        datasets go through the indexing-technique / embedding-model update path
+        and may enqueue a vector index task.
+
+        Raises ValueError when required external knowledge fields are missing, the
+        binding cannot be found, or the requested embedding model cannot be used.
+        """
         dataset = DatasetService.get_dataset(dataset_id)
+
         DatasetService.check_dataset_permission(dataset, user)
-        action = None
-        if dataset.indexing_technique != data["indexing_technique"]:
-            # if update indexing_technique
-            if data["indexing_technique"] == "economy":
-                action = "remove"
-                filtered_data["embedding_model"] = None
-                filtered_data["embedding_model_provider"] = None
-                filtered_data["collection_binding_id"] = None
-            elif data["indexing_technique"] == "high_quality":
-                action = "add"
-                # get embedding model setting
-                try:
-                    model_manager = ModelManager()
-                    embedding_model = model_manager.get_model_instance(
-                        tenant_id=current_user.current_tenant_id,
-                        provider=data["embedding_model_provider"],
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=data["embedding_model"],
-                    )
-                    filtered_data["embedding_model"] = embedding_model.model
-                    filtered_data["embedding_model_provider"] = embedding_model.provider
-                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                        embedding_model.provider, embedding_model.model
-                    )
-                    filtered_data["collection_binding_id"] = dataset_collection_binding.id
-                except LLMBadRequestError:
-                    raise ValueError(
-                        "No Embedding Model available. Please configure a valid provider "
-                        "in the Settings -> Model Provider."
-                    )
-                except ProviderTokenNotInitError as ex:
-                    raise ValueError(ex.description)
-        else:
+        if dataset.provider == "external":
+            # Validate the payload and resolve the binding before mutating any ORM
+            # object, so a rejected request leaves nothing pending in the session.
+            external_knowledge_id = data.get("external_knowledge_id", None)
+            if not external_knowledge_id:
+                raise ValueError("External knowledge id is required.")
+            external_knowledge_api_id = data.get("external_knowledge_api_id", None)
+            if not external_knowledge_api_id:
+                raise ValueError("External knowledge api id is required.")
+            external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id).first()
+            if not external_knowledge_binding:
+                raise ValueError("External knowledge binding not found.")
+
+            dataset.retrieval_model = data.get("external_retrieval_model", None)
+            dataset.name = data.get("name", dataset.name)
+            dataset.description = data.get("description", "")
+            dataset.updated_by = user.id
+            dataset.updated_at = datetime.datetime.now()
+            db.session.add(dataset)
+            # The binding model stores the API reference as external_api_template_id.
             if (
-                data["embedding_model_provider"] != dataset.embedding_model_provider
-                or data["embedding_model"] != dataset.embedding_model
+                external_knowledge_binding.external_knowledge_id != external_knowledge_id
+                or external_knowledge_binding.external_api_template_id != external_knowledge_api_id
             ):
-                action = "update"
-                try:
-                    model_manager = ModelManager()
-                    embedding_model = model_manager.get_model_instance(
-                        tenant_id=current_user.current_tenant_id,
-                        provider=data["embedding_model_provider"],
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=data["embedding_model"],
-                    )
-                    filtered_data["embedding_model"] = embedding_model.model
-                    filtered_data["embedding_model_provider"] = embedding_model.provider
-                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                        embedding_model.provider, embedding_model.model
-                    )
-                    filtered_data["collection_binding_id"] = dataset_collection_binding.id
-                except LLMBadRequestError:
-                    raise ValueError(
-                        "No Embedding Model available. Please configure a valid provider "
-                        "in the Settings -> Model Provider."
-                    )
-                except ProviderTokenNotInitError as ex:
-                    raise ValueError(ex.description)
+                external_knowledge_binding.external_knowledge_id = external_knowledge_id
+                external_knowledge_binding.external_api_template_id = external_knowledge_api_id
+                db.session.add(external_knowledge_binding)
+            db.session.commit()
+        else:
+            data.pop("partial_member_list", None)
+            filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
+            action = None
+            if dataset.indexing_technique != data["indexing_technique"]:
+                # if update indexing_technique
+                if data["indexing_technique"] == "economy":
+                    action = "remove"
+                    filtered_data["embedding_model"] = None
+                    filtered_data["embedding_model_provider"] = None
+                    filtered_data["collection_binding_id"] = None
+                elif data["indexing_technique"] == "high_quality":
+                    action = "add"
+                    # get embedding model setting
+                    try:
+                        model_manager = ModelManager()
+                        embedding_model = model_manager.get_model_instance(
+                            tenant_id=current_user.current_tenant_id,
+                            provider=data["embedding_model_provider"],
+                            model_type=ModelType.TEXT_EMBEDDING,
+                            model=data["embedding_model"],
+                        )
+                        filtered_data["embedding_model"] = embedding_model.model
+                        filtered_data["embedding_model_provider"] = embedding_model.provider
+                        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                            embedding_model.provider, embedding_model.model
+                        )
+                        filtered_data["collection_binding_id"] = dataset_collection_binding.id
+                    except LLMBadRequestError:
+                        raise ValueError(
+                            "No Embedding Model available. Please configure a valid provider "
+                            "in the Settings -> Model Provider."
+                        )
+                    except ProviderTokenNotInitError as ex:
+                        raise ValueError(ex.description)
+            else:
+                if (
+                    data["embedding_model_provider"] != dataset.embedding_model_provider
+                    or data["embedding_model"] != dataset.embedding_model
+                ):
+                    action = "update"
+                    try:
+                        model_manager = ModelManager()
+                        embedding_model = model_manager.get_model_instance(
+                            tenant_id=current_user.current_tenant_id,
+                            provider=data["embedding_model_provider"],
+                            model_type=ModelType.TEXT_EMBEDDING,
+                            model=data["embedding_model"],
+                        )
+                        filtered_data["embedding_model"] = embedding_model.model
+                        filtered_data["embedding_model_provider"] = embedding_model.provider
+                        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                            embedding_model.provider, embedding_model.model
+                        )
+                        filtered_data["collection_binding_id"] = dataset_collection_binding.id
+                    except LLMBadRequestError:
+                        raise ValueError(
+                            "No Embedding Model available. Please configure a valid provider "
+                            "in the Settings -> Model Provider."
+                        )
+                    except ProviderTokenNotInitError as ex:
+                        raise ValueError(ex.description)
 
-        filtered_data["updated_by"] = user.id
-        filtered_data["updated_at"] = datetime.datetime.now()
+            filtered_data["updated_by"] = user.id
+            filtered_data["updated_at"] = datetime.datetime.now()
 
-        # update Retrieval model
-        filtered_data["retrieval_model"] = data["retrieval_model"]
+            # update Retrieval model
+            filtered_data["retrieval_model"] = data["retrieval_model"]
 
-        dataset.query.filter_by(id=dataset_id).update(filtered_data)
+            dataset.query.filter_by(id=dataset_id).update(filtered_data)
 
-        db.session.commit()
-        if action:
-            deal_dataset_vector_index_task.delay(dataset_id, action)
+            db.session.commit()
+            if action:
+                deal_dataset_vector_index_task.delay(dataset_id, action)
         return dataset
 
     @staticmethod