Compare commits
2 Commits
main
...
fix/remove
Author | SHA1 | Date | |
---|---|---|---|
![]() |
6e6604d28c | ||
![]() |
e7b792f537 |
@@ -45,17 +45,6 @@ class HitTestingService:
|
|||||||
if not retrieval_model:
|
if not retrieval_model:
|
||||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||||
|
|
||||||
# get embedding model
|
|
||||||
model_manager = ModelManager()
|
|
||||||
embedding_model = model_manager.get_model_instance(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
provider=dataset.embedding_model_provider,
|
|
||||||
model=dataset.embedding_model
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = CacheEmbedding(embedding_model)
|
|
||||||
|
|
||||||
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
@@ -67,6 +56,7 @@ class HitTestingService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|
||||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||||
|
|
||||||
dataset_query = DatasetQuery(
|
dataset_query = DatasetQuery(
|
||||||
@@ -80,20 +70,10 @@ class HitTestingService:
|
|||||||
db.session.add(dataset_query)
|
db.session.add(dataset_query)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
|
return cls.compact_retrieve_response(dataset, query, all_documents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]):
|
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
|
||||||
text_embeddings = [
|
|
||||||
embeddings.embed_query(query)
|
|
||||||
]
|
|
||||||
|
|
||||||
text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
|
|
||||||
|
|
||||||
tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
|
|
||||||
|
|
||||||
query_position = tsne_position_data.pop(0)
|
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
records = []
|
records = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
@@ -113,7 +93,6 @@ class HitTestingService:
|
|||||||
record = {
|
record = {
|
||||||
"segment": segment,
|
"segment": segment,
|
||||||
"score": document.metadata.get('score', None),
|
"score": document.metadata.get('score', None),
|
||||||
"tsne_position": tsne_position_data[i]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
records.append(record)
|
records.append(record)
|
||||||
@@ -123,7 +102,6 @@ class HitTestingService:
|
|||||||
return {
|
return {
|
||||||
"query": {
|
"query": {
|
||||||
"content": query,
|
"content": query,
|
||||||
"tsne_position": query_position,
|
|
||||||
},
|
},
|
||||||
"records": records
|
"records": records
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user