feat: support tidb vector (#4588)
This commit is contained in:
parent
602c4e51ec
commit
0797f9bc05
@ -112,6 +112,13 @@ PGVECTOR_USER=postgres
|
||||
PGVECTOR_PASSWORD=postgres
|
||||
PGVECTOR_DATABASE=postgres
|
||||
|
||||
# TiDB Vector configuration
|
||||
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
|
||||
TIDB_VECTOR_PORT=4000
|
||||
TIDB_VECTOR_USER=xxx.root
|
||||
TIDB_VECTOR_PASSWORD=xxxxxx
|
||||
TIDB_VECTOR_DATABASE=dify
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
|
@ -299,6 +299,13 @@ class Config:
|
||||
self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
|
||||
self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
|
||||
|
||||
# tidb-vector settings
|
||||
self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST')
|
||||
self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT')
|
||||
self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER')
|
||||
self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
|
||||
self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
|
||||
|
||||
# ------------------------
|
||||
# Mail Configurations.
|
||||
# ------------------------
|
||||
|
@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = current_app.config['VECTOR_STORE']
|
||||
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
|
||||
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
@ -497,7 +497,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
if vector_type in {'milvus', 'relyt', 'pgvector'}:
|
||||
if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
|
0
api/core/rag/datasource/vdb/tidb_vector/__init__.py
Normal file
0
api/core/rag/datasource/vdb/tidb_vector/__init__.py
Normal file
214
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Normal file
214
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Normal file
@ -0,0 +1,214 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from pydantic import BaseModel, root_validator
|
||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TiDBVectorConfig(BaseModel):
    """Connection settings for the TiDB instance used as a vector store."""

    host: str
    port: int
    user: str
    password: str
    database: str

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject the config when any connection setting is missing or empty.

        Raises:
            ValueError: naming the first TIDB_VECTOR_* setting that is absent.
        """
        required = (
            ('host', 'TIDB_VECTOR_HOST'),
            ('port', 'TIDB_VECTOR_PORT'),
            ('user', 'TIDB_VECTOR_USER'),
            ('password', 'TIDB_VECTOR_PASSWORD'),
            ('database', 'TIDB_VECTOR_DATABASE'),
        )
        for field_name, setting_name in required:
            # A field that failed its own validation (e.g. None for an unset
            # setting) is absent from ``values``; ``.get`` turns that into the
            # intended "is required" error instead of a KeyError.
            if not values.get(field_name):
                raise ValueError(f"config {setting_name} is required")
        return values
|
||||
|
||||
|
||||
class TiDBVector(BaseVector):
    """Vector store that keeps document chunks and their embeddings in a TiDB table.

    Each collection maps to one table named after the collection, with the
    columns ``id``, ``text``, ``meta`` (JSON), ``vector``, ``create_time`` and
    ``update_time``.

    NOTE: the collection name is interpolated into SQL as an identifier
    (identifiers cannot be bound as parameters), so it must come from trusted
    code. All values are passed as bound parameters.
    """

    # Number of rows sent per INSERT statement in ``add_texts``.
    _INSERT_BATCH_SIZE = 500

    def _table(self, dim: int) -> Table:
        """Return the SQLAlchemy ``Table`` for this collection with ``dim``-sized vectors."""
        from tidb_vector.sqlalchemy import VectorType

        return Table(
            self._collection_name,
            self._orm_base.metadata,
            Column('id', String(36), primary_key=True, nullable=False),
            Column(
                "vector",
                VectorType(dim),
                nullable=False,
                comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
            ),
            Column("text", TEXT, nullable=False),
            Column("meta", JSON, nullable=False),
            Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
            Column(
                "update_time",
                DateTime,
                server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
            ),
            # The same table is declared again on every call; allow redefinition.
            extend_existing=True,
        )

    def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
        """Prepare an engine for ``config`` and remember the collection settings.

        Args:
            collection_name: name of the table backing this collection.
            config: TiDB connection settings.
            distance_func: distance function name; compared case-insensitively.
        """
        from urllib.parse import quote

        super().__init__(collection_name)
        self._client_config = config
        # Percent-encode the credentials so characters such as '@', ':' or '/'
        # in a user name or password cannot corrupt the connection URL.
        user = quote(config.user, safe='')
        password = quote(config.password, safe='')
        self._url = (f"mysql+pymysql://{user}:{password}@{config.host}:{config.port}/{config.database}?"
                     f"ssl_verify_cert=true&ssl_verify_identity=true")
        self._distance_func = distance_func.lower()
        self._engine = create_engine(self._url)
        self._orm_base = declarative_base()
        # Default dimension, replaced by the real embedding size in ``create``.
        self._dimension = 1536

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Ensure the collection table exists, then insert ``texts`` with their ``embeddings``."""
        logger.info("create collection and add texts, collection_name: %s", self._collection_name)
        # Record the dimension first so it is set even if the insert fails.
        self._dimension = len(embeddings[0])
        self._create_collection(self._dimension)
        self.add_texts(texts, embeddings)

    def _create_collection(self, dimension: int):
        """Create the collection table if it does not exist yet.

        A Redis lock serialises concurrent creators and a Redis key (kept for
        one hour) short-circuits repeated calls.
        """
        logger.info("_create_collection, collection_name %s", self._collection_name)
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            with Session(self._engine) as session:
                session.begin()
                # No DROP TABLE here: the cache key above expires after an
                # hour, and dropping at that point would erase rows inserted
                # by earlier create() calls. CREATE ... IF NOT EXISTS is enough.
                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS {self._collection_name} (
                        id CHAR(36) PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSON NOT NULL,
                        vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
                        create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                        update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    );
                """)
                session.execute(create_statement)
                # TiDB vector does not support 'CREATE/ADD INDEX' yet.
                session.commit()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Insert ``documents`` and their ``embeddings`` in batches; return the row ids."""
        if not documents:
            return []

        table = self._table(len(embeddings[0]))
        ids = self._get_uuids(documents)
        rows = [
            {"id": row_id, "vector": embedding, "text": document.page_content, "meta": document.metadata}
            for row_id, document, embedding in zip(ids, documents, embeddings)
        ]

        with self._engine.connect() as conn:
            with conn.begin():
                # All batches share one transaction, as a single unit of work.
                for start in range(0, len(rows), self._INSERT_BATCH_SIZE):
                    batch = rows[start:start + self._INSERT_BATCH_SIZE]
                    conn.execute(insert(table).values(batch))
        return ids

    def text_exists(self, id: str) -> bool:
        """Return True when a row whose metadata ``doc_id`` equals ``id`` exists."""
        return bool(self.get_ids_by_metadata_field('doc_id', id))

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete the rows whose metadata ``doc_id`` is one of ``ids``."""
        if not ids:
            # Nothing to look up; also avoids rendering an empty IN list.
            return

        select_statement = sql_text(
            f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' IN :doc_ids; """
        ).bindparams(sqlalchemy.bindparam('doc_ids', expanding=True))
        with Session(self._engine) as session:
            result = session.execute(select_statement, {'doc_ids': list(ids)}).fetchall()

        if result:
            self._delete_by_ids([item[0] for item in result])

    def _delete_by_ids(self, ids: list[str]) -> bool:
        """Delete rows by primary key.

        Returns:
            True on success, False when the delete failed (the error is logged).

        Raises:
            ValueError: if ``ids`` is None.
        """
        if ids is None:
            raise ValueError("No ids provided to delete.")
        table = self._table(self._dimension)
        try:
            with self._engine.connect() as conn:
                with conn.begin():
                    conn.execute(table.delete().where(table.c.id.in_(ids)))
            return True
        except Exception:
            # Best-effort delete: report through the logger and signal failure
            # to the caller instead of raising.
            logger.exception("Delete operation failed")
            return False

    def delete_by_document_id(self, document_id: str):
        """Delete every row whose metadata ``document_id`` equals ``document_id``."""
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._delete_by_ids(ids)

    def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
        """Return the ids of rows whose metadata field ``key`` equals ``value``.

        ``key`` becomes part of the JSON path in the statement text and must be
        a trusted field name; ``value`` is sent as a bound parameter. Returns an
        empty list when nothing matches.
        """
        select_statement = sql_text(
            f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = :value; """
        )
        with Session(self._engine) as session:
            result = session.execute(select_statement, {'value': value}).fetchall()
        return [item[0] for item in result] if result else []

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete every row whose metadata field ``key`` equals ``value``."""
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._delete_by_ids(ids)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Return the nearest stored chunks to ``query_vector``.

        Keyword Args:
            top_k: maximum number of rows to fetch (default 5).
            score_threshold: only rows with distance below
                ``1 - score_threshold`` are returned (default 0.0).
        """
        top_k = kwargs.get("top_k", 5)
        score_threshold = kwargs.get("score_threshold") or 0.0
        # NOTE(review): ``1 - score`` looks designed for cosine distance; confirm
        # that it is also the intended cut-off when the distance function is l2.
        distance = 1 - score_threshold

        query_vector_str = "[" + ", ".join(format(x) for x in query_vector) + "]"
        logger.debug(
            "_collection_name: %s, score_threshold: %s, distance: %s",
            self._collection_name, score_threshold, distance,
        )

        # Anything other than 'l2' uses cosine distance.
        tidb_func = 'Vec_l2_distance' if self._distance_func == 'l2' else 'Vec_Cosine_distance'

        select_statement = sql_text(
            f"""SELECT meta, text FROM (
                    SELECT meta, text, {tidb_func}(vector, :query_vector) as distance
                    FROM {self._collection_name}
                    ORDER BY distance
                    LIMIT :top_k
                ) t WHERE distance < :distance;"""
        )
        params = {'query_vector': query_vector_str, 'top_k': int(top_k), 'distance': distance}
        with Session(self._engine) as session:
            rows = [(row[0], row[1]) for row in session.execute(select_statement, params)]

        return [Document(page_content=text, metadata=json.loads(meta)) for meta, text in rows]

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text search is not available for this store; always returns an empty list."""
        # tidb doesn't support bm25 search
        return []

    def delete(self) -> None:
        """Drop the table backing this collection, if it exists."""
        with Session(self._engine) as session:
            session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
            session.commit()
|
@ -187,6 +187,31 @@ class Vector:
|
||||
database=config.get("PGVECTOR_DATABASE"),
|
||||
),
|
||||
)
|
||||
elif vector_type == "tidb_vector":
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
index_struct_dict = {
|
||||
"type": 'tidb_vector',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
self._dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
return TiDBVector(
|
||||
collection_name=collection_name,
|
||||
config=TiDBVectorConfig(
|
||||
host=config.get('TIDB_VECTOR_HOST'),
|
||||
port=config.get('TIDB_VECTOR_PORT'),
|
||||
user=config.get('TIDB_VECTOR_USER'),
|
||||
password=config.get('TIDB_VECTOR_PASSWORD'),
|
||||
database=config.get('TIDB_VECTOR_DATABASE'),
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
|
@ -81,5 +81,7 @@ pgvecto-rs==0.1.4
|
||||
firecrawl-py==0.0.5
|
||||
oss2==2.18.5
|
||||
pgvector==0.2.5
|
||||
pymysql==1.1.1
|
||||
tidb-vector==0.0.9
|
||||
google-cloud-aiplatform==1.49.0
|
||||
vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
|
||||
|
@ -0,0 +1,63 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
|
||||
from models.dataset import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
@pytest.fixture
def tidb_vector():
    """Provide a TiDBVector for 'test_collection' built from placeholder connection settings."""
    placeholder_config = TiDBVectorConfig(
        host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
        port="4000",
        user="xxx.root",
        password="xxxxxx",
        database="dify",
    )
    return TiDBVector(collection_name='test_collection', config=placeholder_config)
|
||||
|
||||
|
||||
class TiDBVectorTest(AbstractVectorTest):
    """Test steps for TiDBVector, layered on the shared AbstractVectorTest.

    The database session is mocked by the fixtures in this module, so every
    lookup here is expected to come back empty.
    """

    def __init__(self, vector):
        super().__init__()
        self.vector = vector

    def text_exists(self):
        """The example document must not be reported as existing."""
        exist = self.vector.text_exists(self.example_doc_id)
        assert not exist

    def search_by_vector(self):
        """A vector search must return no hits."""
        hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 0

    def search_by_full_text(self):
        """A full-text search must return no hits."""
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0

    def get_ids_by_metadata_field(self):
        """A metadata lookup by document_id must return no ids."""
        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
        assert len(ids) == 0

    def delete_by_document_id(self):
        """Deleting by document id must complete without raising."""
        self.vector.delete_by_document_id(document_id=self.example_doc_id)
|
||||
|
||||
|
||||
def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
    """Run every TiDBVectorTest step against the fixture-provided vector store."""
    suite = TiDBVectorTest(vector=tidb_vector)
    suite.run_all_tests()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Replace the Session name used by the tidb_vector module with a MagicMock for the test's duration."""
    session_patch = patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock)
    with session_patch as patched_session:
        yield patched_session
|
||||
|
||||
|
||||
@pytest.fixture
def setup_tidbvector_mock(tidb_vector, mock_session):
    """Yield the vector store with create_engine and its engine's connect() patched out."""
    engine_factory_patch = patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine')
    with engine_factory_patch, patch.object(tidb_vector._engine, 'connect'):
        yield tidb_vector
|
@ -134,6 +134,12 @@ services:
|
||||
PGVECTOR_USER: postgres
|
||||
PGVECTOR_PASSWORD: difyai123456
|
||||
PGVECTOR_DATABASE: dify
|
||||
# TiDB vector configurations
|
||||
TIDB_VECTOR_HOST: tidb
|
||||
TIDB_VECTOR_PORT: 4000
|
||||
TIDB_VECTOR_USER: xxx.root
|
||||
TIDB_VECTOR_PASSWORD: xxxxxx
|
||||
TIDB_VECTOR_DATABASE: dify
|
||||
# Mail configuration, support: resend, smtp
|
||||
MAIL_TYPE: ''
|
||||
# default send from email address, if not specified
|
||||
@ -289,6 +295,12 @@ services:
|
||||
PGVECTOR_USER: postgres
|
||||
PGVECTOR_PASSWORD: difyai123456
|
||||
PGVECTOR_DATABASE: dify
|
||||
# TiDB vector configurations
|
||||
TIDB_VECTOR_HOST: tidb
|
||||
TIDB_VECTOR_PORT: 4000
|
||||
TIDB_VECTOR_USER: xxx.root
|
||||
TIDB_VECTOR_PASSWORD: xxxxxx
|
||||
TIDB_VECTOR_DATABASE: dify
|
||||
# Notion import configuration, support public and internal
|
||||
NOTION_INTEGRATION_TYPE: public
|
||||
NOTION_CLIENT_SECRET: you-client-secret
|
||||
|
Loading…
Reference in New Issue
Block a user