From 33ba7e659be1bad68dae04f3ee95b2108474d074 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 18 Mar 2025 15:07:29 +0800 Subject: [PATCH 1/2] fix vector db sql injection (#16096) --- .../rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py | 4 ++++ api/core/rag/datasource/vdb/myscale/myscale_vector.py | 2 ++ api/core/rag/datasource/vdb/opengauss/opengauss.py | 6 ++++-- api/core/rag/datasource/vdb/pgvector/pgvector.py | 5 ++++- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 4d8f792941..0b2f4cf6e2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -194,6 +194,8 @@ class AnalyticdbVectorBySql: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") score_threshold = float(kwargs.get("score_threshold") or 0.0) with self._get_cursor() as cur: query_vector_str = json.dumps(query_vector) @@ -220,6 +222,8 @@ class AnalyticdbVectorBySql: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( f"""SELECT id, vector, page_content, metadata_, diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 556b952ec2..3223010966 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -125,6 +125,8 @@ class MyScaleVector(BaseVector): def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") score_threshold = float(kwargs.get("score_threshold") or 0.0) where_str = ( f"WHERE dist < {1 - score_threshold}" diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/core/rag/datasource/vdb/opengauss/opengauss.py index 7b51eb4bd8..2e5b7a31e4 100644 --- a/api/core/rag/datasource/vdb/opengauss/opengauss.py +++ b/api/core/rag/datasource/vdb/opengauss/opengauss.py @@ -155,7 +155,8 @@ class OpenGauss(BaseVector): :return: List of Documents that are nearest to the query vector. """ top_k = kwargs.get("top_k", 4) - + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" @@ -174,7 +175,8 @@ class OpenGauss(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 13c214bfd7..b6153a5b09 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -171,6 +171,8 @@ class PGVector(BaseVector): :return: List of Documents that are nearest to the query vector. """ top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: cur.execute( @@ -190,7 +192,8 @@ class PGVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") with self._get_cursor() as cur: if self.pg_bigm: cur.execute("SET pg_bigm.similarity_limit TO 0.000001") From 6f6ba2f025a5c2edb759d962d7a20e0117736abf Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 18 Mar 2025 15:07:53 +0800 Subject: [PATCH 2/2] fix(api): enhance provider model records handling for missing langgenius providers (#16089) --- api/core/provider_manager.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 7d8ff18983..099acfd7f4 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -149,6 +149,11 @@ class ProviderManager: provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) + provider_id_entity = ModelProviderID(provider_name) + if provider_id_entity.is_langgenius(): + provider_model_records.extend( + provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, []) + ) # Convert to custom configuration custom_configuration = self._to_custom_configuration( @@ -190,6 +195,20 @@ class ProviderManager: provider_name ) + provider_id_entity = ModelProviderID(provider_name) + + if provider_id_entity.is_langgenius(): + if provider_model_settings is not None: + provider_model_settings.extend( + provider_name_to_provider_model_settings_dict.get(provider_id_entity.provider_name, []) + ) + if provider_load_balancing_configs is not None: + provider_load_balancing_configs.extend( + provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_id_entity.provider_name, [] + ) + ) + # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, @@ -207,7 +226,7 @@ class ProviderManager: model_settings=model_settings, ) - provider_configurations[str(ModelProviderID(provider_name))] = provider_configuration + provider_configurations[str(provider_id_entity)] = provider_configuration # Return the encapsulated object return provider_configurations