Fix/ignore economy dataset (#1043)

Co-authored-by: jyong <jyong@dify.ai>

parent f9bec1edf8
commit a55ba6e614
@@ -92,11 +92,14 @@ class DatasetListApi(Resource):
                 model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
-            if item_model in model_names:
-                item['embedding_available'] = True
-            else:
-                item['embedding_available'] = False
+            if item['indexing_technique'] == 'high_quality':
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+                if item_model in model_names:
+                    item['embedding_available'] = True
+                else:
+                    item['embedding_available'] = False
+            else:
+                item['embedding_available'] = True
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
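A minimal sketch of the listing rule this hunk introduces (names as in the diff; model_names holds the "model:provider" strings built above). An economy dataset has no embedding model to validate, so it is always reported as available:

    def embedding_available(item: dict, model_names: list) -> bool:
        # economy datasets skip the embedding-model check entirely
        if item['indexing_technique'] != 'high_quality':
            return True
        item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
        return item_model in model_names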
@@ -122,14 +125,6 @@ class DatasetListApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")

         try:
             dataset = DatasetService.create_empty_dataset(
@@ -167,6 +162,11 @@ class DatasetApi(Resource):
     @account_initialization_required
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)

         parser = reqparse.RequestParser()
         parser.add_argument('name', nullable=False,
@@ -254,6 +254,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
@@ -275,7 +276,8 @@ class DatasetIndexingEstimateApi(Resource):
         try:
             response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                               args['process_rule'], args['doc_form'],
-                                                              args['doc_language'], args['dataset_id'])
+                                                              args['doc_language'], args['dataset_id'],
+                                                              args['indexing_technique'])
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
@@ -290,7 +292,8 @@ class DatasetIndexingEstimateApi(Resource):
             response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                 args['info_list']['notion_info_list'],
                                                                 args['process_rule'], args['doc_form'],
-                                                                args['doc_language'], args['dataset_id'])
+                                                                args['doc_language'], args['dataset_id'],
+                                                                args['indexing_technique'])
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
@@ -285,20 +285,6 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)

-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -339,15 +325,17 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
         args = parser.parse_args()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+        if args['indexing_technique'] == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)

         # validate args
         DocumentService.document_create_args_validate(args)
@@ -729,6 +717,12 @@ class DocumentDeleteApi(DocumentResource):
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
         document = self.get_document(dataset_id, document_id)

         try:
@@ -791,6 +785,12 @@ class DocumentStatusApi(DocumentResource):
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
         document = self.get_document(dataset_id, document_id)

         # The role of the current user in the ta table must be admin or owner
@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
         # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         # get file from request
         file = request.files['file']
         # check file
@@ -67,12 +67,13 @@ class DatesetDocumentStore:

         if max_position is None:
             max_position = 0
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=self._dataset.tenant_id,
-            model_provider_name=self._dataset.embedding_model_provider,
-            model_name=self._dataset.embedding_model
-        )
+        embedding_model = None
+        if self._dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=self._dataset.tenant_id,
+                model_provider_name=self._dataset.embedding_model_provider,
+                model_name=self._dataset.embedding_model
+            )

         for doc in docs:
             if not isinstance(doc, Document):
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
             )

             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0

             if not segment_document:
                 max_position += 1
@@ -1,10 +1,18 @@
+import json
+
 from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings

 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.dataset import Dataset
+from models.provider import Provider, ProviderType


 class IndexBuilder:
@@ -36,3 +44,12 @@ class IndexBuilder:
             )
         else:
             raise ValueError('Unknown indexing technique')
+
+    @classmethod
+    def get_default_high_quality_index(cls, dataset: Dataset):
+        embeddings = OpenAIEmbeddings(openai_api_key=' ')
+        return VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
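The new helper builds a vector-index handle without resolving the tenant's embedding model; the placeholder openai_api_key=' ' suffices because the handle is only used for cleanup, not for computing embeddings. A hedged usage sketch matching the clean_dataset_task hunk further below:

    # cleanup for a dataset whose provider credentials may already be gone
    if dataset.indexing_technique == 'high_quality':
        vector_index = IndexBuilder.get_default_high_quality_index(dataset)
        vector_index.delete()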
@@ -217,25 +217,29 @@ class IndexingRunner:
         db.session.commit()

     def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                               indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         tokens = 0
         preview_texts = []
         total_segments = 0
@@ -263,8 +267,8 @@ class IndexingRunner:
         for document in documents:
             if len(preview_texts) < 5:
                 preview_texts.append(document.page_content)
-            tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+            if indexing_technique == 'high_quality' or embedding_model:
+                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(
@@ -286,32 +290,35 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }

     def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                                 indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         # load data from notion
         tokens = 0
         preview_texts = []
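Both estimate paths now degrade to a zero-cost result when no embedding model is involved. A small sketch of the fallback, with field names and the 'USD' default taken from this diff:

    def estimate_price(embedding_model, tokens: int) -> dict:
        # economy estimates involve no embedding model, hence no token cost
        if embedding_model is None:
            return {'total_price': 0, 'currency': 'USD'}
        return {'total_price': '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
                'currency': embedding_model.get_currency()}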
@@ -356,8 +363,8 @@ class IndexingRunner:
         for document in documents:
             if len(preview_texts) < 5:
                 preview_texts.append(document.page_content)
-            tokens += embedding_model.get_num_tokens(document.page_content)
+            if indexing_technique == 'high_quality' or embedding_model:
+                tokens += embedding_model.get_num_tokens(document.page_content)

         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(
@@ -379,8 +386,8 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }
@@ -657,12 +664,13 @@ class IndexingRunner:
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
@@ -672,11 +680,11 @@ class IndexingRunner:
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
-            tokens += sum(
-                embedding_model.get_num_tokens(document.page_content)
-                for document in chunk_documents
-            )
+            if dataset.indexing_technique == 'high_quality' or embedding_model:
+                tokens += sum(
+                    embedding_model.get_num_tokens(document.page_content)
+                    for document in chunk_documents
+                )

             # save vector index
             if vector_index:
@@ -1,6 +1,5 @@
 from events.dataset_event import dataset_was_deleted
 from events.event_handlers.document_index_event import document_index_created
-from tasks.clean_dataset_task import clean_dataset_task
 import datetime
 import logging
 import time
@@ -0,0 +1,46 @@
+"""update_dataset_model_field_null_available
+
+Revision ID: 4bcffcd64aa4
+Revises: 853f9b9cd3b6
+Create Date: 2023-08-28 20:58:50.077056
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '4bcffcd64aa4'
+down_revision = '853f9b9cd3b6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.alter_column('embedding_model',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=True,
+               existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
+        batch_op.alter_column('embedding_model_provider',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=True,
+               existing_server_default=sa.text("'openai'::character varying"))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.alter_column('embedding_model_provider',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=False,
+               existing_server_default=sa.text("'openai'::character varying"))
+        batch_op.alter_column('embedding_model',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=False,
+               existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
+
+
@@ -36,10 +36,8 @@ class Dataset(db.Model):
     updated_by = db.Column(UUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    embedding_model = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
-    embedding_model_provider = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'openai'::character varying"))
+    embedding_model = db.Column(db.String(255), nullable=True)
+    embedding_model_provider = db.Column(db.String(255), nullable=True)

     @property
     def dataset_keyword_table(self):
@@ -10,6 +10,7 @@ from flask import current_app
 from sqlalchemy import func

 from core.index.index import IndexBuilder
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from flask_login import current_user
@@ -91,16 +92,18 @@ class DatasetService:
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
                 f'Dataset with name {name} already exists.')
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=current_user.current_tenant_id
-        )
+        embedding_model = None
+        if indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-        dataset.embedding_model = embedding_model.name
+        dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
+        dataset.embedding_model = embedding_model.name if embedding_model else None
         db.session.add(dataset)
         db.session.commit()
         return dataset
@@ -115,6 +118,23 @@ class DatasetService:
         else:
             return dataset
+
+    @staticmethod
+    def check_dataset_model_setting(dataset):
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ValueError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ValueError(f"The dataset in unavailable, due to: "
+                                 f"{ex.description}")

     @staticmethod
     def update_dataset(dataset_id, data, user):
         dataset = DatasetService.get_dataset(dataset_id)
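check_dataset_model_setting is the guard reused by the controller changes above; it is deliberately a no-op for economy datasets. A caller-side sketch of its contract, as defined in this hunk:

    dataset = DatasetService.get_dataset(dataset_id)
    # passes silently for 'economy'; raises ValueError for a 'high_quality'
    # dataset whose embedding model or provider can no longer be loaded
    DatasetService.check_dataset_model_setting(dataset)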
@@ -124,6 +144,19 @@ class DatasetService:
         if data['indexing_technique'] == 'economy':
             deal_dataset_vector_index_task.delay(dataset_id, 'remove')
         elif data['indexing_technique'] == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ValueError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ValueError(ex.description)
             deal_dataset_vector_index_task.delay(dataset_id, 'add')
         filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}

@@ -397,23 +430,23 @@ class DocumentService:

         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
-            count = 0
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
-                count = len(upload_file_list)
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info['pages'])
-            documents_count = DocumentService.get_tenant_documents_count()
-            total_count = documents_count + count
-            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if total_count > tenant_document_count:
-                raise ValueError(f"over document limit {tenant_document_count}.")
+            if 'original_document_id' not in document_data or not document_data['original_document_id']:
+                count = 0
+                if document_data["data_source"]["type"] == "upload_file":
+                    upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                    count = len(upload_file_list)
+                elif document_data["data_source"]["type"] == "notion_import":
+                    notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info['pages'])
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + count
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"over document limit {tenant_document_count}.")
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
             dataset.data_source_type = document_data["data_source"]["type"]
-            db.session.commit()

         if not dataset.indexing_technique:
             if 'indexing_technique' not in document_data \
@@ -421,6 +454,13 @@ class DocumentService:
                 raise ValueError("Indexing technique is required")

             dataset.indexing_technique = document_data["indexing_technique"]
+            if document_data["indexing_technique"] == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id
+                )
+                dataset.embedding_model = embedding_model.name
+                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+

         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -466,11 +506,11 @@ class DocumentService:
                     "upload_file_id": file_id,
                 }
                 document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                           document_data["data_source"]["type"],
                                                           document_data["doc_form"],
                                                           document_data["doc_language"],
                                                           data_source_info, created_from, position,
                                                           account, file_name, batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)
@@ -512,11 +552,11 @@ class DocumentService:
                         "type": page['type']
                     }
                     document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                               document_data["data_source"]["type"],
                                                               document_data["doc_form"],
                                                               document_data["doc_language"],
                                                               data_source_info, created_from, position,
                                                               account, page['page_name'], batch)
                     db.session.add(document)
                     db.session.flush()
                     document_ids.append(document.id)
@@ -536,9 +576,9 @@ class DocumentService:

     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
                        document_language: str, data_source_info: dict, created_from: str, position: int,
                        account: Account,
                        name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -567,6 +607,7 @@ class DocumentService:
     def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                         account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                         created_from: str = 'web'):
+        DatasetService.check_dataset_model_setting(dataset)
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
         if document.display_status != 'available':
             raise ValueError("Document is not available")
@@ -674,9 +715,11 @@ class DocumentService:
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
             if total_count > tenant_document_count:
                 raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=tenant_id
-        )
+        embedding_model = None
+        if document_data['indexing_technique'] == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=tenant_id
+            )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -684,8 +727,8 @@ class DocumentService:
             data_source_type=document_data["data_source"]["type"],
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
-            embedding_model=embedding_model.name,
-            embedding_model_provider=embedding_model.model_provider.provider_name
+            embedding_model=embedding_model.name if embedding_model else None,
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
         )

         db.session.add(dataset)
@@ -903,15 +946,15 @@ class SegmentService:
         content = args['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
-        # calc embedding use tokens
-        tokens = embedding_model.get_num_tokens(content)
+        tokens = 0
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+            # calc embedding use tokens
+            tokens = embedding_model.get_num_tokens(content)
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document_id == document.id
         ).scalar()
@@ -973,15 +1016,16 @@ class SegmentService:
                 kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
         else:
             segment_hash = helper.generate_text_hash(content)
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-
-            # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
             segment.content = content
             segment.index_node_hash = segment_hash
             segment.word_count = len(content)
@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
             raise ValueError('Document is not available.')
         document_segments = []
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+
         for segment in content:
             content = segment['content']
             doc_id = str(uuid.uuid4())
             segment_hash = helper.generate_text_hash(content)
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+            tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
             max_position = db.session.query(func.max(DocumentSegment.position)).filter(
                 DocumentSegment.document_id == dataset_document.id
             ).scalar()
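Besides gating on the indexing technique, this hunk hoists the embedding-model lookup out of the per-segment loop, so a batch import resolves the model once. The resulting token accounting, sketched:

    # one model lookup per batch; zero tokens when the dataset is 'economy'
    for segment in content:
        tokens = embedding_model.get_num_tokens(segment['content']) if embedding_model else 0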
@@ -3,8 +3,10 @@ import time

 import click
 from celery import shared_task
+from flask import current_app

 from core.index.index import IndexBuilder
+from core.index.vector_index.vector_index import VectorIndex
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin, Document
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()

-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        if vector_index:
+        if dataset.indexing_technique == 'high_quality':
+            vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
                 vector_index.delete()
             except Exception:
@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             raise Exception('Dataset not found')

         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             index.delete()
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

             if dataset_documents:
                 # save vector index
-                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
                 documents = []
                 for dataset_document in dataset_documents:
                     # delete from vector index
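In both task paths ignore_high_quality_check flips from True to False, so IndexBuilder.get_index now honors the dataset's own indexing technique. A hedged sketch of the assumed effect (the flag's semantics are inferred from its name and this usage):

    # with the check enforced, an 'economy' dataset yields no vector index
    index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
    if index:
        index.delete()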