diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index afe0ca7c69..3b3f199a0a 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -349,7 +349,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
         document = self.get_document(dataset_id, document_id)
 
         if document.indexing_status in ['completed', 'error']:
-            raise DocumentAlreadyFinishedError()
+            # NOTE(review): this call does not match IndexingRunner.calculate_tokens(
+            # tenant_id, tokens, dataset_id=None, indexing_technique='economy'), and its
+            # result is discarded -- TODO confirm the intended arguments and return value.
+            indexing_runner.calculate_tokens(document)
 
         data_process_rule = document.dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 83dbacbfcc..fa904799a3 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -214,6 +214,72 @@ class IndexingRunner:
             dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             db.session.commit()
 
+    def calculate_tokens(self, tenant_id: str, tokens: int, dataset_id: str = None,
+                         indexing_technique: str = 'economy') -> dict:
+        """
+        Estimate the embedding price for an already-known token count.
+
+        No text is extracted or split here: ``total_segments`` is always 0 and
+        ``preview`` is always empty; only the price fields are computed.
+
+        :param tenant_id: tenant id passed to the model manager lookups
+        :param tokens: number of tokens to price
+        :param dataset_id: optional dataset whose indexing technique and
+            embedding model settings are consulted
+        :param indexing_technique: requested technique; an embedding model is
+            resolved when this or the dataset's technique is 'high_quality'
+        :return: dict with ``total_segments``, ``tokens``, ``total_price``,
+            ``currency`` and ``preview``; ``total_price`` is the int 0 when no
+            embedding model is resolved, otherwise a fixed-point string
+        :raises ValueError: if ``dataset_id`` is given but no dataset matches
+        """
+        dataset = None
+        if dataset_id:
+            dataset = Dataset.query.filter_by(id=dataset_id).first()
+            if not dataset:
+                raise ValueError('Dataset not found.')
+
+        high_quality = indexing_technique == 'high_quality' or (
+            dataset is not None and dataset.indexing_technique == 'high_quality'
+        )
+
+        embedding_model_instance = None
+        if high_quality:
+            if dataset is not None and dataset.embedding_model_provider:
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+                )
+            else:
+                # No dataset-level embedding provider: use the default model instance.
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+
+        total_price = 0
+        currency = 'USD'
+        if embedding_model_instance:
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
+            embedding_price_info = embedding_model_type_instance.get_price(
+                model=embedding_model_instance.model,
+                credentials=embedding_model_instance.credentials,
+                price_type=PriceType.INPUT,
+                tokens=tokens
+            )
+            total_price = '{:f}'.format(embedding_price_info.total_amount)
+            currency = embedding_price_info.currency
+
+        return {
+            "total_segments": 0,
+            "tokens": tokens,
+            "total_price": total_price,
+            "currency": currency,
+            "preview": []
+        }
+
     def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                           doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                           indexing_technique: str = 'economy') -> dict: