From f804adbff3410a0af4672a12e9b74dbc9e03e2f3 Mon Sep 17 00:00:00 2001 From: miendinh <22139872+miendinh@users.noreply.github.com> Date: Sat, 25 May 2024 12:40:25 +0700 Subject: [PATCH] feat: Support for Vertex AI - load Default Application Configuration (#4641) Co-authored-by: miendinh Co-authored-by: crazywoola <427733928@qq.com> --- .../model_providers/vertex_ai/llm/llm.py | 7 +++++-- .../vertex_ai/text_embedding/text_embedding.py | 16 ++++++++++------ .../model_providers/vertex_ai/vertex_ai.yaml | 4 ++-- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 5e3905af98..7d2fbd087a 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -164,10 +164,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): config_kwargs["stop_sequences"] = stop service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) - service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + if service_account_info: + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + else: + aiplatform.init(project=project_id, location=location) history = [] system_instruction = GEMINI_BLOCK_MODE_PROMPT diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index ece63806c3..2404ba5894 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -41,15 +41,16 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): :return: embeddings result """ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) - service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + if service_account_info: + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + else: + aiplatform.init(project=project_id, location=location) client = VertexTextEmbeddingModel.from_pretrained(model) - - embeddings_batch, embedding_used_tokens = self._embedding_invoke( client=client, texts=texts @@ -103,10 +104,13 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): """ try: service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) - service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + if service_account_info: + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + else: + aiplatform.init(project=project_id, location=location) client = VertexTextEmbeddingModel.from_pretrained(model) diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml index 8b7f216b55..27a4d03fe2 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml @@ -36,8 +36,8 @@ provider_credential_schema: en_US: Enter your Google Cloud Location - variable: vertex_service_account_key label: - en_US: Service Account Key + en_US: Service Account Key (Leave blank if you use Application Default Credentials) type: secret-input - required: true + required: false placeholder: en_US: Enter your Google Cloud Service Account Key in base64 format