Add Redis lock on create collection in multi-threaded mode (#3054)

Co-authored-by: jyong <jyong@dify.ai>
Jyong · 2024-04-01 02:10:41 +08:00 · committed by GitHub
parent 1716ac562c
commit 84d118de07
4 changed files with 128 additions and 105 deletions
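All four files below apply the same guard around collection creation: take a Redis lock keyed by the collection name, check a Redis flag recording that the collection was already created, and only then run the create path, setting the flag with a one-hour TTL afterwards. Distilled into one place, the pattern looks roughly like the sketch below — the ensure_collection and do_create names are illustrative and not from the diff; the lock name, 20-second timeout, cache key, and ex=3600 are taken from the changes that follow.

    from typing import Callable

    from extensions.ext_redis import redis_client


    def ensure_collection(collection_name: str, do_create: Callable[[], None]) -> None:
        # Illustrative helper, not part of the commit; each vector store inlines this logic.
        lock_name = 'vector_indexing_lock_{}'.format(collection_name)
        # Serialize concurrent creators of the same collection across threads and workers.
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(collection_name)
            if redis_client.get(collection_exist_cache_key):
                # Another thread/worker already created it within the last hour.
                return
            do_create()
            # Record that the collection exists so later callers return early.
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

The Jieba keyword handler uses the same lock idea with a 'keyword_indexing_lock_{dataset_id}' name, but without the existence flag.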

View File: Jieba keyword handler (class Jieba)

@@ -8,6 +8,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
@@ -121,6 +122,8 @@ class Jieba(BaseKeyword):
db.session.commit()
def _get_dataset_keyword_table(self) -> Optional[dict]:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=20):
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
if dataset_keyword_table.keyword_table_dict:

View File: Milvus vector store (class MilvusVector)

@@ -8,6 +8,7 @@ from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -61,16 +62,6 @@ class MilvusVector(BaseVector):
'params': {"M": 8, "efConstruction": 64}
}
metadatas = [d.metadata for d in texts]
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
if not utility.has_collection(self._collection_name, using=alias):
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@@ -187,7 +178,22 @@ class MilvusVector(BaseVector):
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
) -> str:
):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user,
password=self._client_config.password)
if not utility.has_collection(self._collection_name, using=alias):
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
@@ -225,8 +231,7 @@ class MilvusVector(BaseVector):
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
return collection_name
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port)

View File: Qdrant vector store (class QdrantVector)

@@ -20,6 +20,7 @@ from qdrant_client.local.qdrant_local import QdrantLocal
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@@ -77,6 +78,17 @@ class QdrantVector(BaseVector):
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
@@ -84,12 +96,6 @@ class QdrantVector(BaseVector):
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
@@ -118,6 +124,7 @@ class QdrantVector(BaseVector):
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)

View File: Weaviate vector store (class WeaviateVector)

@@ -8,6 +8,7 @@ from pydantic import BaseModel, root_validator
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
@@ -79,15 +80,22 @@ class WeaviateVector(BaseVector):
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
self._create_collection()
# create vector
self.add_texts(texts, embeddings)
def _create_collection(self):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
schema = self._default_schema(self._collection_name)
# check whether the index already exists
if not self._client.schema.contains(schema):
# create collection
self._client.schema.create_class(schema)
# create vector
self.add_texts(texts, embeddings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
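
To see why the guard matters in multi-threaded indexing: without it, two threads can both observe that the collection is missing and both attempt to create it, which can fail or duplicate work. The following stand-alone illustration is hypothetical and not part of the commit — it uses an in-process threading.Lock and a set as stand-ins for redis_client.lock and the vector_indexing_{name} flag; the commit uses Redis so the guard also holds across processes and workers.

    import threading
    import time

    _created: set[str] = set()
    _guard = threading.Lock()


    def ensure_collection_once(name: str) -> None:
        # In-process analogue of the commit's redis_client.lock(...) guard.
        with _guard:
            if name in _created:
                # Analogue of the cached 'vector_indexing_{name}' flag.
                return
            time.sleep(0.1)           # simulate a slow create_collection call
            print('creating', name)   # executes exactly once
            _created.add(name)


    threads = [threading.Thread(target=ensure_collection_once, args=('demo',)) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()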