def retry(max_retries: int, base_delay: float, max_wait_time: float = 16):
    """
    Build a decorator that retries the wrapped callable with exponential backoff.

    The wrapped callable is invoked at most ``max_retries + 1`` times. After
    the k-th failed attempt (k starts at 1) the wrapper sleeps for
    ``min(base_delay * 2**k, max_wait_time)`` seconds plus a random jitter of
    up to one second, so a single sleep may exceed ``max_wait_time`` by up to
    one second. When the final attempt fails, its exception is re-raised
    unchanged.

    :param max_retries: The maximum number of retries after the first attempt.
    :param base_delay: The base delay time in seconds.
    :param max_wait_time: The cap, in seconds, applied to the delay before jitter.
    :return: A decorator that wraps a callable with the retry behaviour.
    :raises ValueError: If ``max_retries`` is negative.
    """
    # Imported locally so this block does not depend on the file's import list.
    import functools

    # A negative budget previously skipped the loop entirely: the wrapped
    # callable was never invoked and the wrapper silently returned None.
    if max_retries < 0:
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    def decorator(func):
        # Keep the wrapped callable's __name__, __doc__ and signature metadata.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # attempt is 1-based; the last allowed attempt is max_retries + 1.
            for attempt in range(1, max_retries + 2):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt > max_retries:
                        # Retry budget exhausted: propagate with the original traceback.
                        raise
                    logger.warning("Attempt %s failed: %s", attempt, e)
                # Cap the exponential part first, then add jitter so concurrent
                # callers do not all retry at the same instant.
                delay = min(base_delay * 2**attempt, max_wait_time) + random.uniform(0, 1)
                logger.info("Retrying in %.2f seconds...", delay)
                time.sleep(delay)

        return wrapper

    return decorator
batch_texts = embedding_queue_texts[i : i + max_chunks] - embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT - ) + embedding_result = self.text_embedding(batch_texts) for vector in embedding_result.embeddings: try: