diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 36e2d6cd2d..6884e7fe80 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -92,11 +92,14 @@ class DatasetListApi(Resource):
             model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
-            if item_model in model_names:
-                item['embedding_available'] = True
+            if item['indexing_technique'] == 'high_quality':
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+                if item_model in model_names:
+                    item['embedding_available'] = True
+                else:
+                    item['embedding_available'] = False
             else:
-                item['embedding_available'] = False
+                item['embedding_available'] = True
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
@@ -122,14 +125,6 @@ class DatasetListApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")

         try:
             dataset = DatasetService.create_empty_dataset(
@@ -167,6 +162,11 @@ class DatasetApi(Resource):
     @account_initialization_required
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)

         parser = reqparse.RequestParser()
         parser.add_argument('name', nullable=False,
@@ -254,6 +254,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
@@ -275,7 +276,8 @@
             try:
                 response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                                   args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'])
+                                                                  args['doc_language'], args['dataset_id'],
+                                                                  args['indexing_technique'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
@@ -290,7 +292,8 @@
                 response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                     args['info_list']['notion_info_list'],
                                                                     args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'])
+                                                                    args['doc_language'], args['dataset_id'],
+                                                                    args['indexing_technique'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index b2a1d3a681..89d8cd6a6f 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -285,20 +285,6 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)

-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -339,15 +325,17 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
         args = parser.parse_args()
-
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+        if args['indexing_technique'] == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)

         # validate args
         DocumentService.document_create_args_validate(args)
@@ -729,6 +717,12 @@ class DocumentDeleteApi(DocumentResource):
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
         document = self.get_document(dataset_id, document_id)

         try:
@@ -791,6 +785,12 @@ class DocumentStatusApi(DocumentResource):
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
         document = self.get_document(dataset_id, document_id)

         # The role of the current user in the ta table must be admin or owner
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index e31ea030e1..7dac492d1d 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
-
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)

         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
         # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        # check segment
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
+        # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         # get file from request
         file = request.files['file']
         # check file
diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py
index ca5f172ece..563a8c1922 100644
--- a/api/core/docstore/dataset_docstore.py
+++ b/api/core/docstore/dataset_docstore.py
@@ -67,12 +67,13 @@ class DatesetDocumentStore:
         if max_position is None:
             max_position = 0
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=self._dataset.tenant_id,
-            model_provider_name=self._dataset.embedding_model_provider,
-            model_name=self._dataset.embedding_model
-        )
+        embedding_model = None
+        if self._dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=self._dataset.tenant_id,
+                model_provider_name=self._dataset.embedding_model_provider,
+                model_name=self._dataset.embedding_model
+            )

         for doc in docs:
             if not isinstance(doc, Document):
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
             )

             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0

             if not segment_document:
                 max_position += 1
diff --git a/api/core/index/index.py b/api/core/index/index.py
index 26b6a84dfe..be5ca31510 100644
--- a/api/core/index/index.py
+++ b/api/core/index/index.py
@@ -1,10 +1,18 @@
+import json
+
 from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings

 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.dataset import Dataset
+from models.provider import Provider, ProviderType


 class IndexBuilder:
@@ -35,4 +43,13 @@ class IndexBuilder:
                 )
             )
         else:
-            raise ValueError('Unknown indexing technique')
\ No newline at end of file
+            raise ValueError('Unknown indexing technique')
+
+    @classmethod
+    def get_default_high_quality_index(cls, dataset: Dataset):
+        embeddings = OpenAIEmbeddings(openai_api_key=' ')
+        return VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 8892f4508e..ea4ce31db0 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -217,25 +217,29 @@ class IndexingRunner:
         db.session.commit()

     def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                               indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         tokens = 0
         preview_texts = []
         total_segments = 0
@@ -263,8 +267,8 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-
-                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+                if indexing_technique == 'high_quality' or embedding_model:
+                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

             if doc_form and doc_form == 'qa_model':
                 text_generation_model = ModelFactory.get_text_generation_model(
@@ -286,32 +290,35 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }

     def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                                 indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
-
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         # load data from notion
         tokens = 0
         preview_texts = []
@@ -356,8 +363,8 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-
-                tokens += embedding_model.get_num_tokens(document.page_content)
+                if indexing_technique == 'high_quality' or embedding_model:
+                    tokens += embedding_model.get_num_tokens(document.page_content)

             if doc_form and doc_form == 'qa_model':
                 text_generation_model = ModelFactory.get_text_generation_model(
@@ -379,8 +386,8 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }
@@ -657,12 +664,13 @@ class IndexingRunner:
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
@@ -672,11 +680,11 @@ class IndexingRunner:
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
-
-            tokens += sum(
-                embedding_model.get_num_tokens(document.page_content)
-                for document in chunk_documents
-            )
+            if dataset.indexing_technique == 'high_quality' or embedding_model:
+                tokens += sum(
+                    embedding_model.get_num_tokens(document.page_content)
+                    for document in chunk_documents
+                )

             # save vector index
             if vector_index:
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index 081bc593a6..a5f5c4d8f4 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -1,6 +1,5 @@
 from events.dataset_event import dataset_was_deleted
 from events.event_handlers.document_index_event import document_index_created
-from tasks.clean_dataset_task import clean_dataset_task
 import datetime
 import logging
 import time
diff --git a/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py
new file mode 100644
index 0000000000..0753c26d1c
--- /dev/null
+++ b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py
@@ -0,0 +1,46 @@
+"""update_dataset_model_field_null_available
+
+Revision ID: 4bcffcd64aa4
+Revises: 853f9b9cd3b6
+Create Date: 2023-08-28 20:58:50.077056
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '4bcffcd64aa4'
+down_revision = '853f9b9cd3b6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.alter_column('embedding_model',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=True,
+               existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
+        batch_op.alter_column('embedding_model_provider',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=True,
+               existing_server_default=sa.text("'openai'::character varying"))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.alter_column('embedding_model_provider',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=False,
+               existing_server_default=sa.text("'openai'::character varying"))
+        batch_op.alter_column('embedding_model',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=False,
+               existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
+
+    # ### end Alembic commands ###
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 6f7891a163..338eb173cf 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -36,10 +36,8 @@ class Dataset(db.Model):
     updated_by = db.Column(UUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

-    embedding_model = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
-    embedding_model_provider = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'openai'::character varying"))
+    embedding_model = db.Column(db.String(255), nullable=True)
+    embedding_model_provider = db.Column(db.String(255), nullable=True)

     @property
     def dataset_keyword_table(self):
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index bfdd1a452f..d4bc8d833a 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -10,6 +10,7 @@ from flask import current_app
 from sqlalchemy import func

 from core.index.index import IndexBuilder
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from flask_login import current_user
@@ -91,16 +92,18 @@ class DatasetService:
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
                 f'Dataset with name {name} already exists.')
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=current_user.current_tenant_id
-        )
+        embedding_model = None
+        if indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-        dataset.embedding_model = embedding_model.name
+        dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
+        dataset.embedding_model = embedding_model.name if embedding_model else None
         db.session.add(dataset)
         db.session.commit()
         return dataset
@@ -115,6 +118,23 @@ class DatasetService:
         else:
             return dataset

+    @staticmethod
+    def check_dataset_model_setting(dataset):
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ValueError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ValueError(f"The dataset is unavailable, due to: "
+                                 f"{ex.description}")
+
     @staticmethod
     def update_dataset(dataset_id, data, user):
@@ -124,6 +144,19 @@ class DatasetService:
             if data['indexing_technique'] == 'economy':
                 deal_dataset_vector_index_task.delay(dataset_id, 'remove')
             elif data['indexing_technique'] == 'high_quality':
+                # check embedding model setting
+                try:
+                    ModelFactory.get_embedding_model(
+                        tenant_id=current_user.current_tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except LLMBadRequestError:
+                    raise ValueError(
+                        f"No Embedding Model available. Please configure a valid provider "
+                        f"in the Settings -> Model Provider.")
+                except ProviderTokenNotInitError as ex:
+                    raise ValueError(ex.description)
                 deal_dataset_vector_index_task.delay(dataset_id, 'add')

         filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
@@ -397,23 +430,23 @@ class DocumentService:
         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
-            count = 0
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
-                count = len(upload_file_list)
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info['pages'])
-            documents_count = DocumentService.get_tenant_documents_count()
-            total_count = documents_count + count
-            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if total_count > tenant_document_count:
-                raise ValueError(f"over document limit {tenant_document_count}.")
+            if 'original_document_id' not in document_data or not document_data['original_document_id']:
+                count = 0
+                if document_data["data_source"]["type"] == "upload_file":
+                    upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                    count = len(upload_file_list)
+                elif document_data["data_source"]["type"] == "notion_import":
+                    notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info['pages'])
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + count
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"over document limit {tenant_document_count}.")

         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
             dataset.data_source_type = document_data["data_source"]["type"]
-            db.session.commit()

         if not dataset.indexing_technique:
             if 'indexing_technique' not in document_data \
@@ -421,6 +454,13 @@ class DocumentService:
                 raise ValueError("Indexing technique is required")

             dataset.indexing_technique = document_data["indexing_technique"]
+            if document_data["indexing_technique"] == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id
+                )
+                dataset.embedding_model = embedding_model.name
+                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -466,11 +506,11 @@ class DocumentService:
                         "upload_file_id": file_id,
                     }
                     document = DocumentService.build_document(dataset, dataset_process_rule.id,
-                                                               document_data["data_source"]["type"],
-                                                               document_data["doc_form"],
-                                                               document_data["doc_language"],
-                                                               data_source_info, created_from, position,
-                                                               account, file_name, batch)
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, file_name, batch)
                     db.session.add(document)
                     db.session.flush()
                     document_ids.append(document.id)
@@ -512,11 +552,11 @@ class DocumentService:
                             "type": page['type']
                         }
                         document = DocumentService.build_document(dataset, dataset_process_rule.id,
-                                                                   document_data["data_source"]["type"],
-                                                                   document_data["doc_form"],
-                                                                   document_data["doc_language"],
-                                                                   data_source_info, created_from, position,
-                                                                   account, page['page_name'], batch)
+                                                                  document_data["data_source"]["type"],
+                                                                  document_data["doc_form"],
+                                                                  document_data["doc_language"],
+                                                                  data_source_info, created_from, position,
+                                                                  account, page['page_name'], batch)
                         db.session.add(document)
                         db.session.flush()
                         document_ids.append(document.id)
@@ -536,9 +576,9 @@ class DocumentService:

     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
-                        document_language: str, data_source_info: dict, created_from: str, position: int,
-                        account: Account,
-                        name: str, batch: str):
+                       document_language: str, data_source_info: dict, created_from: str, position: int,
+                       account: Account,
+                       name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -567,6 +607,7 @@ class DocumentService:
     def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                         account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                         created_from: str = 'web'):
+        DatasetService.check_dataset_model_setting(dataset)
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
         if document.display_status != 'available':
             raise ValueError("Document is not available")
@@ -674,9 +715,11 @@ class DocumentService:
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
             if total_count > tenant_document_count:
                 raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=tenant_id
-        )
+        embedding_model = None
+        if document_data['indexing_technique'] == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=tenant_id
+            )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -684,8 +727,8 @@ class DocumentService:
             data_source_type=document_data["data_source"]["type"],
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
-            embedding_model=embedding_model.name,
-            embedding_model_provider=embedding_model.model_provider.provider_name
+            embedding_model=embedding_model.name if embedding_model else None,
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
         )

         db.session.add(dataset)
@@ -903,15 +946,15 @@ class SegmentService:
         content = args['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
-
-        # calc embedding use tokens
-        tokens = embedding_model.get_num_tokens(content)
+        tokens = 0
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+            # calc embedding use tokens
+            tokens = embedding_model.get_num_tokens(content)
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document_id == document.id
         ).scalar()
@@ -973,15 +1016,16 @@ class SegmentService:
                 kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
         else:
             segment_hash = helper.generate_text_hash(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )

-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-
-            # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
             segment.content = content
             segment.index_node_hash = segment_hash
             segment.word_count = len(content)
@@ -1013,7 +1057,7 @@ class SegmentService:
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is deleting.")
-        
+
         # enabled segment need to delete index
         if segment.enabled:
             # send delete segment index task
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index 86421a0115..864d7a8044 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
             raise ValueError('Document is not available.')
         document_segments = []
-        for segment in content:
-            content = segment['content']
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=dataset.tenant_id,
                 model_provider_name=dataset.embedding_model_provider,
                 model_name=dataset.embedding_model
             )
+        for segment in content:
+            content = segment['content']
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)

             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+            tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
             max_position = db.session.query(func.max(DocumentSegment.position)).filter(
                 DocumentSegment.document_id == dataset_document.id
             ).scalar()
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index bd40d20c4e..dea9059b00 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -3,8 +3,10 @@ import time

 import click
 from celery import shared_task
+from flask import current_app

 from core.index.index import IndexBuilder
+from core.index.vector_index.vector_index import VectorIndex
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin, Document
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()

-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        if vector_index:
+        if dataset.indexing_technique == 'high_quality':
+            vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
                 vector_index.delete()
             except Exception:
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index c92b353097..96d1dc9096 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             raise Exception('Dataset not found')

         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             index.delete()
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

         if dataset_documents:
             # save vector index
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             documents = []
             for dataset_document in dataset_documents:
                 # delete from vector index
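
Reviewer note (not part of the patch): nearly every hunk above applies the same guard — resolve an embedding model only when the dataset (or the requested indexing technique) is 'high_quality', and fall back to zero-cost token accounting otherwise. A minimal sketch of that pattern, using the attribute and method names from the diff; the helper name estimate_tokens and the injected get_embedding_model callable are illustrative stand-ins, not part of the codebase:

    def estimate_tokens(dataset, texts, get_embedding_model):
        # get_embedding_model stands in for ModelFactory.get_embedding_model
        embedding_model = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model = get_embedding_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        # 'economy' datasets skip embeddings entirely, so they incur no token cost
        return sum(
            embedding_model.get_num_tokens(text) if embedding_model else 0
            for text in texts
        )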