feat: add support for Bedrock LLAMA3 (#3890)

This commit is contained in:
longzhihun 2024-04-27 13:13:09 +08:00 committed by GitHub
parent 08a65d74d5
commit 43a5ba9415
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 70 additions and 27 deletions

View File

@ -8,6 +8,8 @@
- anthropic.claude-3-haiku-v1:0 - anthropic.claude-3-haiku-v1:0
- cohere.command-light-text-v14 - cohere.command-light-text-v14
- cohere.command-text-v14 - cohere.command-text-v14
- meta.llama3-8b-instruct-v1:0
- meta.llama3-70b-instruct-v1:0
- meta.llama2-13b-chat-v1 - meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1 - meta.llama2-70b-chat-v1
- mistral.mistral-large-2402-v1:0 - mistral.mistral-large-2402-v1:0

View File

@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:return:md = genai.GenerativeModel(model) :return:md = genai.GenerativeModel(model)
""" """
prefix = model.split('.')[0] prefix = model.split('.')[0]
model_name = model.split('.')[1]
if isinstance(messages, str): if isinstance(messages, str):
prompt = messages prompt = messages
else: else:
prompt = self._convert_messages_to_prompt(messages, prefix) prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt) return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message, model_prefix)
for message in messages
)
return text.rstrip()
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
""" """
@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str: def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
""" """
Convert a single message to a string. Convert a single message to a string.
@ -446,6 +431,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
ai_prompt = "\n\nAssistant:" ai_prompt = "\n\nAssistant:"
elif model_prefix == "meta": elif model_prefix == "meta":
# LLAMA3
if model_name.startswith("llama3"):
human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
ai_prompt = "\n\nAssistant:"
else:
# LLAMA2
human_prompt_prefix = "\n[INST]" human_prompt_prefix = "\n[INST]"
human_prompt_postfix = "[\\INST]\n" human_prompt_postfix = "[\\INST]\n"
ai_prompt = "" ai_prompt = ""
@ -478,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_text return message_text
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str: def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
""" """
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
:param messages: List of PromptMessage to combine. :param messages: List of PromptMessage to combine.
:param model_name: specific model name.Optional,just to distinguish llama2 and llama3
:return: Combined string with necessary human_prompt and ai_prompt tags. :return: Combined string with necessary human_prompt and ai_prompt tags.
""" """
if not messages: if not messages:
@ -493,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
messages.append(AssistantPromptMessage(content="")) messages.append(AssistantPromptMessage(content=""))
text = "".join( text = "".join(
self._convert_one_message_to_text(message, model_prefix) self._convert_one_message_to_text(message, model_prefix, model_name)
for message in messages for message in messages
) )
# trim off the trailing ' ' that might come from the "Assistant: " # trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip() return text.rstrip()
def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
""" """
Create payload for bedrock api call depending on model provider Create payload for bedrock api call depending on model provider
""" """
payload = dict() payload = dict()
model_prefix = model.split('.')[0]
model_name = model.split('.')[1]
if model_prefix == "amazon": if model_prefix == "amazon":
payload["textGenerationConfig"] = { **model_parameters } payload["textGenerationConfig"] = { **model_parameters }
@ -544,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif model_prefix == "meta": elif model_prefix == "meta":
payload = { **model_parameters } payload = { **model_parameters }
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
else: else:
raise ValueError(f"Got unknown model prefix {model_prefix}") raise ValueError(f"Got unknown model prefix {model_prefix}")
@ -579,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
) )
model_prefix = model.split('.')[0] model_prefix = model.split('.')[0]
payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream) payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
# need workaround for ai21 models which doesn't support streaming # need workaround for ai21 models which doesn't support streaming
if stream and model_prefix != "ai21": if stream and model_prefix != "ai21":

View File

@ -0,0 +1,23 @@
model: meta.llama3-70b-instruct-v1:0
label:
en_US: Llama 3 Instruct 70B
model_type: llm
model_properties:
mode: completion
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00265'
output: '0.0035'
unit: '0.00001'
currency: USD

View File

@ -0,0 +1,23 @@
model: meta.llama3-8b-instruct-v1:0
label:
en_US: Llama 3 Instruct 8B
model_type: llm
model_properties:
mode: completion
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.0004'
output: '0.0006'
unit: '0.0001'
currency: USD