From 43a5ba9415cf10b672edeff49ab6ebc28180732c Mon Sep 17 00:00:00 2001
From: longzhihun <38651850@qq.com>
Date: Sat, 27 Apr 2024 13:13:09 +0800
Subject: [PATCH] feat: add support for Bedrock LLAMA3 (#3890)

---
 .../bedrock/llm/_position.yaml                |  2 +
 .../model_providers/bedrock/llm/llm.py        | 49 +++++++++----------
 .../llm/meta.llama3-70b-instruct-v1.yaml      | 23 +++++++++
 .../llm/meta.llama3-8b-instruct-v1.yaml       | 23 +++++++++
 4 files changed, 70 insertions(+), 27 deletions(-)
 create mode 100644 api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-70b-instruct-v1.yaml
 create mode 100644 api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-8b-instruct-v1.yaml

diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
index 24665553b9..7f4d2035cc 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
@@ -8,6 +8,8 @@
 - anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
+- meta.llama3-8b-instruct-v1:0
+- meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
 - mistral.mistral-large-2402-v1:0
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 48723fdf88..81a9ce2f00 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :return:md = genai.GenerativeModel(model)
         """
         prefix = model.split('.')[0]
-
+        model_name = model.split('.')[1]
         if isinstance(messages, str):
             prompt = messages
         else:
-            prompt = self._convert_messages_to_prompt(messages, prefix)
+            prompt = self._convert_messages_to_prompt(messages, prefix, model_name)

         return self._get_num_tokens_by_gpt2(prompt)

-    def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Google model
-
-        :param messages: List of PromptMessage to combine.
-        :return: Combined string with necessary human_prompt and ai_prompt tags.
-        """
-        messages = messages.copy()  # don't mutate the original list
-
-        text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
-            for message in messages
-        )
-
-        return text.rstrip()

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

-    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
+    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Convert a single message to a string.

@@ -446,10 +431,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             ai_prompt = "\n\nAssistant:"

         elif model_prefix == "meta":
-            human_prompt_prefix = "\n[INST]"
-            human_prompt_postfix = "[\\INST]\n"
-            ai_prompt = ""
-
+            # LLAMA3
+            if model_name.startswith("llama3"):
+                human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
+                human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+                ai_prompt = ""
+            else:
+                # LLAMA2
+                human_prompt_prefix = "\n[INST]"
+                human_prompt_postfix = "[\\INST]\n"
+                ai_prompt = ""
+
         elif model_prefix == "mistral":
             human_prompt_prefix = "[INST]"
             human_prompt_postfix = "[\\INST]\n"
@@ -478,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

         return message_text

-    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models

         :param messages: List of PromptMessage to combine.
+        :param model_name: specific model name. Optional; used only to distinguish llama2 from llama3.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
         if not messages:
@@ -493,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             messages.append(AssistantPromptMessage(content=""))

         text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
+            self._convert_one_message_to_text(message, model_prefix, model_name)
             for message in messages
         )

         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()

-    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+    def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
         """
         Create payload for bedrock api call depending on model provider
         """
         payload = dict()
+        model_prefix = model.split('.')[0]
+        model_name = model.split('.')[1]

         if model_prefix == "amazon":
             payload["textGenerationConfig"] = { **model_parameters }
@@ -544,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

         elif model_prefix == "meta":
             payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)

         else:
             raise ValueError(f"Got unknown model prefix {model_prefix}")
@@ -579,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )

         model_prefix = model.split('.')[0]
-        payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
+        payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)

         # need workaround for ai21 models which doesn't support streaming
         if stream and model_prefix != "ai21":
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-70b-instruct-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-70b-instruct-v1.yaml
new file mode 100644
index 0000000000..204662690e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-70b-instruct-v1.yaml
@@ -0,0 +1,23 @@
+model: meta.llama3-70b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 70B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00265'
+  output: '0.0035'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-8b-instruct-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-8b-instruct-v1.yaml
new file mode 100644
index 0000000000..dd4f666a5f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-8b-instruct-v1.yaml
@@ -0,0 +1,23 @@
+model: meta.llama3-8b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 8B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.0004'
+  output: '0.0006'
+  unit: '0.001'
+  currency: USD
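
Reviewer note, not part of the patch: a minimal sketch of the prompt string the
new llama3 branch assembles for a single user turn. It assumes the surrounding
_convert_one_message_to_text joins human_prompt_prefix, the message content, and
human_prompt_postfix in that order, as the existing meta path does (exact
whitespace handling follows that helper). render_user_turn is a hypothetical
name used only for this illustration; it is not a function in llm.py.

    # Sketch of the Llama 3 chat template wired up above; illustrative only.
    LLAMA3_USER_PREFIX = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    LLAMA3_ASSISTANT_OPEN = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    def render_user_turn(content: str) -> str:
        # A user message becomes prefix + content + assistant header, leaving
        # the prompt open so the model generates the assistant turn.
        return LLAMA3_USER_PREFIX + content + LLAMA3_ASSISTANT_OPEN

    print(render_user_turn("Hello"))
    # <|eot_id|><|start_header_id|>user<|end_header_id|>
    #
    # Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #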