add redis lock on create collection in multiple thread mode (#3054)
Co-authored-by: jyong <jyong@dify.ai>
parent 1716ac562c
commit 84d118de07
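When several threads index documents at the same time, each vector-store adapter (Milvus, Qdrant, Weaviate) and the Jieba keyword handler could race through its own check-then-create path and try to create the same collection or keyword table twice. This commit serializes those paths behind a Redis distributed lock; the vector adapters also record a `vector_indexing_{collection_name}` cache key with a one-hour TTL so that later calls return before touching the store at all.

The shape of the change is the same in every adapter. A minimal sketch of the pattern, assuming `redis_client` is a standard redis-py `Redis` instance (the `create_if_absent` helper and its `build` callback are illustrative names, not from the commit):

import redis

redis_client = redis.Redis()  # stand-in for extensions.ext_redis.redis_client

def create_if_absent(name: str, build) -> None:
    # Serialize concurrent creators behind a distributed lock; the 20s
    # timeout bounds how long a crashed holder can block everyone else.
    lock_name = 'vector_indexing_lock_{}'.format(name)
    with redis_client.lock(lock_name, timeout=20):
        cache_key = 'vector_indexing_{}'.format(name)
        if redis_client.get(cache_key):
            # A previous caller already created it; skip the expensive work.
            return
        build()  # e.g. create the collection and its indexes
        redis_client.set(cache_key, 1, ex=3600)  # remember for an hour

Note the trade-off in `timeout=20`: a `build()` that runs longer than 20 seconds can let a second creator in, which is why the adapters keep their backstop existence checks (`utility.has_collection`, the Qdrant collection list, `schema.contains`) inside the lock.

First, the Jieba keyword handler (`class Jieba(BaseKeyword)`):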
@@ -8,6 +8,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.keyword.keyword_base import BaseKeyword
 from core.rag.models.document import Document
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
 
 
@@ -121,26 +122,28 @@ class Jieba(BaseKeyword):
         db.session.commit()
 
     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            if dataset_keyword_table.keyword_table_dict:
-                return dataset_keyword_table.keyword_table_dict['__data__']['table']
-        else:
-            dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self.dataset.id,
-                keyword_table=json.dumps({
-                    '__type__': 'keyword_table',
-                    '__data__': {
-                        "index_id": self.dataset.id,
-                        "summary": None,
-                        "table": {}
-                    }
-                }, cls=SetEncoder)
-            )
-            db.session.add(dataset_keyword_table)
-            db.session.commit()
-
-        return {}
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=20):
+            dataset_keyword_table = self.dataset.dataset_keyword_table
+            if dataset_keyword_table:
+                if dataset_keyword_table.keyword_table_dict:
+                    return dataset_keyword_table.keyword_table_dict['__data__']['table']
+            else:
+                dataset_keyword_table = DatasetKeywordTable(
+                    dataset_id=self.dataset.id,
+                    keyword_table=json.dumps({
+                        '__type__': 'keyword_table',
+                        '__data__': {
+                            "index_id": self.dataset.id,
+                            "summary": None,
+                            "table": {}
+                        }
+                    }, cls=SetEncoder)
+                )
+                db.session.add(dataset_keyword_table)
+                db.session.commit()
+
+            return {}
 
     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
         for keyword in keywords:
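The keyword path takes a `keyword_indexing_lock_{dataset_id}` lock but, unlike the vector adapters below, sets no cache key: once the `DatasetKeywordTable` row exists the read is cheap, and the lock alone is enough to keep two threads from inserting duplicate rows for the same dataset.

Next, the Milvus adapter (`class MilvusVector(BaseVector)`):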
@@ -8,6 +8,7 @@ from pymilvus import MilvusClient, MilvusException, connections
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 
 logger = logging.getLogger(__name__)
 
@@ -61,17 +62,7 @@ class MilvusVector(BaseVector):
             'params': {"M": 8, "efConstruction": 64}
         }
         metadatas = [d.metadata for d in texts]
-        # Grab the existing collection if it exists
-        from pymilvus import utility
-        alias = uuid4().hex
-        if self._client_config.secure:
-            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        else:
-            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
-        if not utility.has_collection(self._collection_name, using=alias):
-            self.create_collection(embeddings, metadatas, index_params)
-
+        self.create_collection(embeddings, metadatas, index_params)
         self.add_texts(texts, embeddings)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -187,46 +178,60 @@ class MilvusVector(BaseVector):
 
     def create_collection(
         self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
-    ) -> str:
-        from pymilvus import CollectionSchema, DataType, FieldSchema
-        from pymilvus.orm.types import infer_dtype_bydata
-
-        # Determine embedding dim
-        dim = len(embeddings[0])
-        fields = []
-        if metadatas:
-            fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
-
-        # Create the text field
-        fields.append(
-            FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
-        )
-        # Create the primary key field
-        fields.append(
-            FieldSchema(
-                Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
-            )
-        )
-        # Create the vector field, supports binary or float vectors
-        fields.append(
-            FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
-        )
-
-        # Create the schema for the collection
-        schema = CollectionSchema(fields)
-
-        for x in schema.fields:
-            self._fields.append(x.name)
-        # Since primary field is auto-id, no need to track it
-        self._fields.remove(Field.PRIMARY_KEY.value)
-
-        # Create the collection
-        collection_name = self._collection_name
-        self._client.create_collection_with_schema(collection_name=collection_name,
-                                                   schema=schema, index_param=index_params,
-                                                   consistency_level=self._consistency_level)
-        return collection_name
+    ):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            # Grab the existing collection if it exists
+            from pymilvus import utility
+            alias = uuid4().hex
+            if self._client_config.secure:
+                uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+            else:
+                uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+            connections.connect(alias=alias, uri=uri, user=self._client_config.user,
+                                password=self._client_config.password)
+            if not utility.has_collection(self._collection_name, using=alias):
+                from pymilvus import CollectionSchema, DataType, FieldSchema
+                from pymilvus.orm.types import infer_dtype_bydata
+
+                # Determine embedding dim
+                dim = len(embeddings[0])
+                fields = []
+                if metadatas:
+                    fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
+
+                # Create the text field
+                fields.append(
+                    FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
+                )
+                # Create the primary key field
+                fields.append(
+                    FieldSchema(
+                        Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
+                    )
+                )
+                # Create the vector field, supports binary or float vectors
+                fields.append(
+                    FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
+                )
+
+                # Create the schema for the collection
+                schema = CollectionSchema(fields)
+
+                for x in schema.fields:
+                    self._fields.append(x.name)
+                # Since primary field is auto-id, no need to track it
+                self._fields.remove(Field.PRIMARY_KEY.value)
+
+                # Create the collection
+                collection_name = self._collection_name
+                self._client.create_collection_with_schema(collection_name=collection_name,
+                                                           schema=schema, index_param=index_params,
+                                                           consistency_level=self._consistency_level)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def _init_client(self, config) -> MilvusClient:
         if config.secure:
             uri = "https://" + str(config.host) + ":" + str(config.port)
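Two things moved here. The `utility.has_collection` probe that used to run in `create` now runs inside the locked `create_collection`, so check-then-create is atomic across workers. And `redis_client.set(collection_exist_cache_key, 1, ex=3600)` sits outside the `has_collection` branch, so even a caller that finds the collection already present records the cache key, and subsequent calls return without opening a Milvus connection at all.

The Qdrant adapter (`class QdrantVector(BaseVector)`) gets the same treatment: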
@@ -20,6 +20,7 @@ from qdrant_client.local.qdrant_local import QdrantLocal
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -77,6 +78,17 @@ class QdrantVector(BaseVector):
             vector_size = len(embeddings[0])
             # get collection name
             collection_name = self._collection_name
+            # create collection
+            self.create_collection(collection_name, vector_size)
+
+            self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(self, collection_name: str, vector_size: int):
+        lock_name = 'vector_indexing_lock_{}'.format(collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
             collection_name = collection_name or uuid.uuid4().hex
             all_collection_name = []
             collections_response = self._client.get_collections()
@@ -84,40 +96,35 @@ class QdrantVector(BaseVector):
             for collection in collection_list:
                 all_collection_name.append(collection.name)
             if collection_name not in all_collection_name:
-                # create collection
-                self.create_collection(collection_name, vector_size)
-
-            self.add_texts(texts, embeddings, **kwargs)
-
-    def create_collection(self, collection_name: str, vector_size: int):
-        from qdrant_client.http import models as rest
-        vectors_config = rest.VectorParams(
-            size=vector_size,
-            distance=rest.Distance[self._distance_func],
-        )
-        hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
-                                     max_indexing_threads=0, on_disk=False)
-        self._client.recreate_collection(
-            collection_name=collection_name,
-            vectors_config=vectors_config,
-            hnsw_config=hnsw_config,
-            timeout=int(self._client_config.timeout),
-        )
-
-        # create payload index
-        self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
-                                          field_schema=PayloadSchemaType.KEYWORD,
-                                          field_type=PayloadSchemaType.KEYWORD)
-        # creat full text index
-        text_index_params = TextIndexParams(
-            type=TextIndexType.TEXT,
-            tokenizer=TokenizerType.MULTILINGUAL,
-            min_token_len=2,
-            max_token_len=20,
-            lowercase=True
-        )
-        self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
-                                          field_schema=text_index_params)
+                from qdrant_client.http import models as rest
+                vectors_config = rest.VectorParams(
+                    size=vector_size,
+                    distance=rest.Distance[self._distance_func],
+                )
+                hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                             max_indexing_threads=0, on_disk=False)
+                self._client.recreate_collection(
+                    collection_name=collection_name,
+                    vectors_config=vectors_config,
+                    hnsw_config=hnsw_config,
+                    timeout=int(self._client_config.timeout),
+                )
+
+                # create payload index
+                self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
+                                                  field_schema=PayloadSchemaType.KEYWORD,
+                                                  field_type=PayloadSchemaType.KEYWORD)
+                # creat full text index
+                text_index_params = TextIndexParams(
+                    type=TextIndexType.TEXT,
+                    tokenizer=TokenizerType.MULTILINGUAL,
+                    min_token_len=2,
+                    max_token_len=20,
+                    lowercase=True
+                )
+                self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
+                                                  field_schema=text_index_params)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
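One asymmetry worth noting: the lock name is built from the `collection_name` parameter while the cache key uses `self._collection_name`. In this code path they hold the same value, since `create` passes `self._collection_name` in, so the double-check stays consistent.

Finally, the Weaviate adapter (`class WeaviateVector(BaseVector)`):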
@@ -8,6 +8,7 @@ from pydantic import BaseModel, root_validator
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
 
@@ -79,16 +80,23 @@ class WeaviateVector(BaseVector):
         }
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        schema = self._default_schema(self._collection_name)
-
-        # check whether the index already exists
-        if not self._client.schema.contains(schema):
-            # create collection
-            self._client.schema.create_class(schema)
-
+        # create collection
+        self._create_collection()
         # create vector
         self.add_texts(texts, embeddings)
 
+    def _create_collection(self):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            schema = self._default_schema(self._collection_name)
+            if not self._client.schema.contains(schema):
+                # create collection
+                self._client.schema.create_class(schema)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
         texts = [d.page_content for d in documents]