feat: add support for Bedrock LLAMA3 (#3890)

parent 08a65d74d5
commit 43a5ba9415
@@ -8,6 +8,8 @@
 - anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
+- meta.llama3-8b-instruct-v1:0
+- meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
 - mistral.mistral-large-2402-v1:0
@@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :return:
         """
         prefix = model.split('.')[0]
+        model_name = model.split('.')[1]
         if isinstance(messages, str):
             prompt = messages
         else:
-            prompt = self._convert_messages_to_prompt(messages, prefix)
+            prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
 
         return self._get_num_tokens_by_gpt2(prompt)
 
-    def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Google model
-
-        :param messages: List of PromptMessage to combine.
-        :return: Combined string with necessary human_prompt and ai_prompt tags.
-        """
-        messages = messages.copy()  # don't mutate the original list
-
-        text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
-            for message in messages
-        )
-
-        return text.rstrip()
-
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
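
Note on the new model_name lookup: Bedrock model IDs have the form provider.model, so a single split('.') yields both the routing prefix and the concrete model name used above to tell Llama 2 apart from Llama 3. A minimal standalone sketch of that parsing (the helper name is illustrative, not part of this PR):

    # Sketch of the ID parsing this diff relies on; parse_bedrock_model_id
    # is an illustrative name, not a helper from the PR.
    def parse_bedrock_model_id(model: str) -> tuple[str, str]:
        prefix, model_name = model.split('.', 1)
        return prefix, model_name

    assert parse_bedrock_model_id("meta.llama3-8b-instruct-v1:0") == ("meta", "llama3-8b-instruct-v1:0")
    assert parse_bedrock_model_id("meta.llama2-13b-chat-v1")[1].startswith("llama2")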
@@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
+    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Convert a single message to a string.
 
@@ -446,10 +431,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             ai_prompt = "\n\nAssistant:"
 
         elif model_prefix == "meta":
-            human_prompt_prefix = "\n[INST]"
-            human_prompt_postfix = "[\\INST]\n"
-            ai_prompt = ""
+            # LLAMA3
+            if model_name.startswith("llama3"):
+                human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
+                human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+                ai_prompt = "\n\nAssistant:"
+            else:
+                # LLAMA2
+                human_prompt_prefix = "\n[INST]"
+                human_prompt_postfix = "[\\INST]\n"
+                ai_prompt = ""
 
         elif model_prefix == "mistral":
             human_prompt_prefix = "<s>[INST]"
             human_prompt_postfix = "[\\INST]\n"
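
For reference, the Llama 3 header tokens introduced above wrap each user turn as shown in this hand-built sketch (render is a hypothetical helper, not a method from the PR; it uses the same prefix/postfix strings as the new branch):

    # Hand-built illustration of the Llama 3 chat template the new branch emits.
    def render(user_text: str) -> str:
        human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        return f"{human_prompt_prefix}{user_text}{human_prompt_postfix}"

    print(render("What is the capital of France?"))
    # <|eot_id|><|start_header_id|>user<|end_header_id|>
    #
    # What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>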
@@ -478,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
 
         :param messages: List of PromptMessage to combine.
+        :param model_name: specific model name. Optional; used only to distinguish llama2 from llama3.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
         if not messages:
@@ -493,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             messages.append(AssistantPromptMessage(content=""))
 
         text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
+            self._convert_one_message_to_text(message, model_prefix, model_name)
             for message in messages
         )
 
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()
 
-    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+    def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
         """
         Create payload for bedrock api call depending on model provider
         """
         payload = dict()
+        model_prefix = model.split('.')[0]
+        model_name = model.split('.')[1]
 
         if model_prefix == "amazon":
             payload["textGenerationConfig"] = { **model_parameters }
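
With the refactored _create_payload, the request body assembled for a meta.* model carries the rendered prompt alongside the pass-through parameters. A representative value, matching the Bedrock Llama request shape (the literals here are illustrative only):

    # Representative meta.* request body after this change; values illustrative.
    payload = {
        "temperature": 0.5,
        "top_p": 0.9,
        "max_gen_len": 512,
        "prompt": "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello"
                  "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    }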
@@ -544,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         elif model_prefix == "meta":
             payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
 
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix}")
@@ -579,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )
 
         model_prefix = model.split('.')[0]
-        payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
+        payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
 
         # need workaround for ai21 models, which don't support streaming
         if stream and model_prefix != "ai21":
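
End to end, one of the newly supported model IDs can be exercised directly against Bedrock roughly as follows (a boto3 sketch, assuming AWS credentials and Bedrock model access are already configured; this is not code from the PR):

    # Sketch: invoking the new Llama 3 model ID via boto3.
    import json
    import boto3

    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    body = {
        "prompt": "<|start_header_id|>user<|end_header_id|>\n\nHi!<|eot_id|>"
                  "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "max_gen_len": 128,
        "temperature": 0.5,
    }
    response = client.invoke_model(
        modelId="meta.llama3-8b-instruct-v1:0",
        body=json.dumps(body),
    )
    print(json.loads(response["body"].read())["generation"])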
@@ -0,0 +1,23 @@
+model: meta.llama3-70b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 70B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00265'
+  output: '0.0035'
+  unit: '0.00001'
+  currency: USD
@@ -0,0 +1,23 @@
+model: meta.llama3-8b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 8B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.0004'
+  output: '0.0006'
+  unit: '0.0001'
+  currency: USD
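
The parameter_rules in these two files are what the runtime uses to validate user-supplied settings such as max_gen_len. A minimal sketch of how such a rule could be enforced (apply_rule is hypothetical, not Dify's actual validator):

    # Hypothetical enforcement of the max_gen_len rule above
    # (required, default 512, min 1, max 2048); not Dify's validator.
    from typing import Optional

    def apply_rule(value: Optional[int], default: int = 512, lo: int = 1, hi: int = 2048) -> int:
        if value is None:
            return default  # required rule falls back to its default
        return max(lo, min(hi, value))

    assert apply_rule(None) == 512
    assert apply_rule(10_000) == 2048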