Merge d88624755a into a30945312a
commit c22c544e7f
@@ -1,5 +1,7 @@
 import base64
 import logging
+import random
+import time
 from typing import Any, Optional, cast
 
 import numpy as np
@@ -19,11 +21,49 @@ from models.dataset import Embedding
 logger = logging.getLogger(__name__)
 
 
+def retry(max_retries: int, base_delay: int, max_wait_time: int = 16):
+    """
+    A retry decorator that uses an exponential backoff algorithm.
+    :param max_retries: The maximum number of retries.
+    :param base_delay: The base delay time in seconds.
+    :param max_wait_time: The maximum wait time in seconds.
+    :return: The decorated function.
+    """
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            retries = 0
+            while retries <= max_retries:
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    if retries == max_retries:
+                        raise e
+                    logger.warning(f"Attempt {retries + 1} failed: {e}")
+                    retries += 1
+                    delay = base_delay * 2**retries
+                    if delay > max_wait_time:
+                        delay = max_wait_time
+                    delay += random.uniform(0, 1)
+                    logger.info(f"Retrying in {delay:.2f} seconds...")
+                    time.sleep(delay)
+
+        return wrapper
+
+    return decorator
+
+
 class CacheEmbedding(Embeddings):
     def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
         self._model_instance = model_instance
         self._user = user
 
+    @retry(max_retries=10, base_delay=1, max_wait_time=16)
+    def text_embedding(self, texts: list[str]):
+        return self._model_instance.invoke_text_embedding(
+            texts=texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
+        )
+
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs in batches of 10."""
         # use doc embedding cache or store if not exists
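
For reference, a minimal usage sketch of the retry decorator added above, assuming the diff's retry and logger are in scope. flaky_call and _attempts are hypothetical names, used only to illustrate the backoff schedule (2 s, 4 s, 8 s, then capped at 16 s, each plus up to 1 s of random jitter):

# Hypothetical sketch (not part of the diff): exercising the retry
# decorator defined above. flaky_call and _attempts are made-up names.
_attempts = {"n": 0}

@retry(max_retries=3, base_delay=1, max_wait_time=16)
def flaky_call() -> str:
    _attempts["n"] += 1
    if _attempts["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

# The first two attempts fail and are retried after roughly 2s and 4s
# (base_delay * 2**retries, capped at max_wait_time, plus 0-1s jitter);
# the third attempt succeeds and its return value is passed through.
print(flaky_call())  # -> ok
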
@@ -58,9 +98,7 @@ class CacheEmbedding(Embeddings):
                 for i in range(0, len(embedding_queue_texts), max_chunks):
                     batch_texts = embedding_queue_texts[i : i + max_chunks]
 
-                    embedding_result = self._model_instance.invoke_text_embedding(
-                        texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
-                    )
+                    embedding_result = self.text_embedding(batch_texts)
 
                     for vector in embedding_result.embeddings:
                         try:
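
One point worth noting on the chosen parameters: with max_retries=10, base_delay=1 and max_wait_time=16, a persistently failing embedding call now blocks for roughly two minutes before the exception propagates out of embed_documents. A small sketch of that arithmetic, mirroring the delay computation in wrapper() above (the names here are illustrative, not from the diff):

# Hypothetical sketch (not part of the diff): worst-case sleep budget for
# @retry(max_retries=10, base_delay=1, max_wait_time=16) on text_embedding.
delays = [min(1 * 2**r, 16) for r in range(1, 11)]  # retries 1..10
print(delays)       # [2, 4, 8, 16, 16, 16, 16, 16, 16, 16]
print(sum(delays))  # 126 -> up to ~126s of sleep (plus 0-10s total jitter)
                    # across 11 attempts before the final exception is raised
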