Compare commits

...

2 Commits

Author SHA1 Message Date
jyong
6e6604d28c delete remove tsne position 2024-07-02 14:55:05 +08:00
jyong
e7b792f537 delete remove tsne position 2024-07-02 14:52:41 +08:00

View File

@ -45,17 +45,6 @@ class HitTestingService:
if not retrieval_model: if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get embedding model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=query,
@ -67,6 +56,7 @@ class HitTestingService:
) )
end = time.perf_counter() end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
dataset_query = DatasetQuery( dataset_query = DatasetQuery(
@ -80,20 +70,10 @@ class HitTestingService:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents) return cls.compact_retrieve_response(dataset, query, all_documents)
@classmethod @classmethod
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]): def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
text_embeddings = [
embeddings.embed_query(query)
]
text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
query_position = tsne_position_data.pop(0)
i = 0 i = 0
records = [] records = []
for document in documents: for document in documents:
@ -113,7 +93,6 @@ class HitTestingService:
record = { record = {
"segment": segment, "segment": segment,
"score": document.metadata.get('score', None), "score": document.metadata.get('score', None),
"tsne_position": tsne_position_data[i]
} }
records.append(record) records.append(record)
@ -123,7 +102,6 @@ class HitTestingService:
return { return {
"query": { "query": {
"content": query, "content": query,
"tsne_position": query_position,
}, },
"records": records "records": records
} }