feat: Add Cohere Command R / R+ model support (#3333)

Authored by takatost on 2024-04-11 01:22:55 +08:00; committed by GitHub
parent bf63a43bda
commit 826c422ac4
16 changed files with 404 additions and 128 deletions

View File

@@ -1,3 +1,5 @@
+- command-r
+- command-r-plus
 - command-chat
 - command-light-chat
 - command-nightly-chat

View File

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:

View File

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
  - name: preamble_override
     label:

View File

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:

View File

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '0.3'

View File

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '0.3'

View File

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:

View File

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '1.0'

View File

@@ -0,0 +1,45 @@
+model: command-r-plus
+label:
+  en_US: command-r-plus
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '3'
+  output: '15'
+  unit: '0.000001'
+  currency: USD

View File

@@ -0,0 +1,45 @@
+model: command-r
+label:
+  en_US: command-r
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '0.5'
+  output: '1.5'
+  unit: '0.000001'
+  currency: USD

View File

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '1.0'

View File

@@ -1,20 +1,38 @@
+import json
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from typing import Optional, Union, cast
 
 import cohere
-from cohere.responses import Chat, Generations
-from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
-from cohere.responses.generation import StreamingGenerations, StreamingText
+from cohere import (
+    ChatMessage,
+    ChatStreamRequestToolResultsItem,
+    GenerateStreamedResponse,
+    GenerateStreamedResponse_StreamEnd,
+    GenerateStreamedResponse_StreamError,
+    GenerateStreamedResponse_TextGeneration,
+    Generation,
+    NonStreamedChatResponse,
+    StreamedChatResponse,
+    StreamedChatResponse_StreamEnd,
+    StreamedChatResponse_TextGeneration,
+    StreamedChatResponse_ToolCallsGeneration,
+    Tool,
+    ToolCall,
+    ToolParameterDefinitionsValue,
+)
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
@@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 credentials=credentials,
                 prompt_messages=prompt_messages,
                 model_parameters=model_parameters,
+                tools=tools,
                 stop=stop,
                 stream=stream,
                 user=user
@@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         if stop:
             model_parameters['end_sequences'] = stop
 
-        response = client.generate(
-            prompt=prompt_messages[0].content,
-            model=model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
+            response = client.generate_stream(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
+
             return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.generate(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
+            return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
+    def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
                                   prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
@@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         )
 
         # calculate num tokens
-        prompt_tokens = response.meta['billed_units']['input_tokens']
-        completion_tokens = response.meta['billed_units']['output_tokens']
+        prompt_tokens = int(response.meta.billed_units.input_tokens)
+        completion_tokens = int(response.meta.billed_units.output_tokens)
 
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         return response
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         index = 1
         full_assistant_content = ''
         for chunk in response:
-            if isinstance(chunk, StreamingText):
-                chunk = cast(StreamingText, chunk)
+            if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
+                chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif chunk is None:
+            elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
+                chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
+
                 # calculate num tokens
-                prompt_tokens = response.meta['billed_units']['input_tokens']
-                completion_tokens = response.meta['billed_units']['output_tokens']
+                prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+                completion_tokens = self._num_tokens_from_messages(
+                    model,
+                    credentials,
+                    [AssistantPromptMessage(content=full_assistant_content)]
+                )
 
                 # transform usage
                 usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=''),
-                        finish_reason=response.finish_reason,
+                        finish_reason=chunk.finish_reason,
                         usage=usage
                     )
                 )
 
                 break
+            elif isinstance(chunk, GenerateStreamedResponse_StreamError):
+                chunk = cast(GenerateStreamedResponse_StreamError, chunk)
+                raise InvokeBadRequestError(chunk.err)
 
     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters
+        :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
         :param user: unique user id
@@ -282,31 +319,46 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
-        if user:
-            model_parameters['user_name'] = user
+        if stop:
+            model_parameters['stop_sequences'] = stop
 
-        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+        if tools:
+            model_parameters['tools'] = self._convert_tools(tools)
+
+        message, chat_histories, tool_results \
+            = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
 
         # chat model
         real_model = model
         if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
             real_model = model.removesuffix('-chat')
 
-        response = client.chat(
-            message=message,
-            chat_history=chat_histories,
-            model=real_model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
+            response = client.chat_stream(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.chat(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
-                                       prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
+            return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
+                                       prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
         Handle llm chat response
@@ -315,14 +367,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response
         """
         assistant_text = response.text
 
+        tool_calls = []
+        if response.tool_calls:
+            for cohere_tool_call in response.tool_calls:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=cohere_tool_call.name,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=cohere_tool_call.name,
+                        arguments=json.dumps(cohere_tool_call.parameters)
+                    )
+                )
+                tool_calls.append(tool_call)
+
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
+            content=assistant_text,
+            tool_calls=tool_calls
         )
 
         # calculate num tokens
@@ -332,44 +397,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-        if stop:
-            # enforce stop tokens
-            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
-            assistant_prompt_message = AssistantPromptMessage(
-                content=assistant_text
-            )
-
         # transform response
         response = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage,
-            system_fingerprint=response.preamble
+            usage=usage
         )
 
         return response
 
-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
-                                              prompt_messages: list[PromptMessage],
-                                              stop: Optional[list[str]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Iterator[StreamedChatResponse],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm chat stream response
 
         :param model: model name
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
        :return: llm response chunk generator
         """
 
-        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
-                           preamble: Optional[str] = None) -> LLMResultChunk:
+        def final_response(full_text: str,
+                           tool_calls: list[AssistantPromptMessage.ToolCall],
+                           index: int,
+                           finish_reason: Optional[str] = None) -> LLMResultChunk:
             # calculate num tokens
             prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
 
             full_assistant_prompt_message = AssistantPromptMessage(
-                content=full_text
+                content=full_text,
+                tool_calls=tool_calls
             )
             completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
@@ -379,10 +438,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             return LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
-                system_fingerprint=preamble,
                 delta=LLMResultChunkDelta(
                     index=index,
-                    message=AssistantPromptMessage(content=''),
+                    message=AssistantPromptMessage(content='', tool_calls=tool_calls),
                     finish_reason=finish_reason,
                     usage=usage
                 )
@@ -390,9 +448,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         index = 1
         full_assistant_content = ''
+        tool_calls = []
         for chunk in response:
-            if isinstance(chunk, StreamTextGeneration):
-                chunk = cast(StreamTextGeneration, chunk)
+            if isinstance(chunk, StreamedChatResponse_TextGeneration):
+                chunk = cast(StreamedChatResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -403,12 +462,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     content=text
                 )
 
-                # stop
-                # notice: This logic can only cover few stop scenarios
-                if stop and text in stop:
-                    yield final_response(full_assistant_content, index, 'stop')
-                    break
-
                 full_assistant_content += text
 
                 yield LLMResultChunk(
@@ -421,39 +474,98 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif isinstance(chunk, StreamEnd):
-                chunk = cast(StreamEnd, chunk)
-                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
+            elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
+                chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
+
+                tool_calls = []
+                if chunk.tool_calls:
+                    for cohere_tool_call in chunk.tool_calls:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=cohere_tool_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=cohere_tool_call.name,
+                                arguments=json.dumps(cohere_tool_call.parameters)
+                            )
+                        )
+                        tool_calls.append(tool_call)
+            elif isinstance(chunk, StreamedChatResponse_StreamEnd):
+                chunk = cast(StreamedChatResponse_StreamEnd, chunk)
+                yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
                 index += 1
 
     def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-            -> tuple[str, list[dict]]:
+            -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
         """
         Convert prompt messages to message and chat histories
         :param prompt_messages: prompt messages
         :return:
         """
         chat_histories = []
+        latest_tool_call_n_outputs = []
         for prompt_message in prompt_messages:
-            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
+            if prompt_message.role == PromptMessageRole.ASSISTANT:
+                prompt_message = cast(AssistantPromptMessage, prompt_message)
+                if prompt_message.tool_calls:
+                    for tool_call in prompt_message.tool_calls:
+                        latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
+                            call=ToolCall(
+                                name=tool_call.function.name,
+                                parameters=json.loads(tool_call.function.arguments)
+                            ),
+                            outputs=[]
+                        ))
+                else:
+                    cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                    if cohere_prompt_message:
+                        chat_histories.append(cohere_prompt_message)
+            elif prompt_message.role == PromptMessageRole.TOOL:
+                prompt_message = cast(ToolPromptMessage, prompt_message)
+                if latest_tool_call_n_outputs:
+                    i = 0
+                    for tool_call_n_outputs in latest_tool_call_n_outputs:
+                        if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
+                            latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
+                                call=ToolCall(
+                                    name=tool_call_n_outputs.call.name,
+                                    parameters=tool_call_n_outputs.call.parameters
+                                ),
+                                outputs=[{
+                                    "result": prompt_message.content
+                                }]
+                            )
+                            break
+                        i += 1
+            else:
+                cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                if cohere_prompt_message:
+                    chat_histories.append(cohere_prompt_message)
+
+        if latest_tool_call_n_outputs:
+            new_latest_tool_call_n_outputs = []
+            for tool_call_n_outputs in latest_tool_call_n_outputs:
+                if tool_call_n_outputs.outputs:
+                    new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
+
+            latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
 
         # get latest message from chat histories and pop it
         if len(chat_histories) > 0:
             latest_message = chat_histories.pop()
-            message = latest_message['message']
+            message = latest_message.message
         else:
             raise ValueError('Prompt messages is empty')
 
-        return message, chat_histories
+        return message, chat_histories, latest_tool_call_n_outputs
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
         """
         Convert PromptMessage to dict for Cohere model
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "USER", "message": message.content}
+                chat_message = ChatMessage(role="USER", message=message.content)
             else:
                 sub_message_text = ''
                 for message_content in message.content:
@@ -461,20 +573,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                         message_content = cast(TextPromptMessageContent, message_content)
                         sub_message_text += message_content.data
 
-                message_dict = {"role": "USER", "message": sub_message_text}
+                chat_message = ChatMessage(role="USER", message=sub_message_text)
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "CHATBOT", "message": message.content}
+            if not message.content:
+                return None
+            chat_message = ChatMessage(role="CHATBOT", message=message.content)
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "USER", "message": message.content}
+            chat_message = ChatMessage(role="USER", message=message.content)
+        elif isinstance(message, ToolPromptMessage):
+            return None
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name:
-            message_dict["user_name"] = message.name
+        return chat_message
 
-        return message_dict
+    def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
+        """
+        Convert tools to Cohere model
+        """
+        cohere_tools = []
+        for tool in tools:
+            properties = tool.parameters['properties']
+            required_properties = tool.parameters['required']
+
+            parameter_definitions = {}
+            for p_key, p_val in properties.items():
+                required = False
+                if property in required_properties:
+                    required = True
+
+                desc = p_val['description']
+                if 'enum' in p_val:
+                    desc += (f"; Only accepts one of the following predefined options: "
+                             f"[{', '.join(p_val['enum'])}]")
+
+                parameter_definitions[p_key] = ToolParameterDefinitionsValue(
+                    description=desc,
+                    type=p_val['type'],
+                    required=required
+                )
+
+            cohere_tool = Tool(
+                name=tool.name,
+                description=tool.description,
+                parameter_definitions=parameter_definitions
+            )
+
+            cohere_tools.append(cohere_tool)
+
+        return cohere_tools
 
     def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
         """
@@ -493,12 +642,16 @@
             model=model
         )
 
-        return response.length
+        return len(response.tokens)
 
     def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
         """Calculate num tokens Cohere model."""
-        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
-        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
+        calc_messages = []
+        for message in messages:
+            cohere_message = self._convert_prompt_message_to_dict(message)
+            if cohere_message:
+                calc_messages.append(cohere_message)
+        message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
         message_str = "\n".join(message_strs)
 
         real_model = model
@@ -564,13 +717,21 @@
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
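As an aside, a hedged sketch (not from the commit) of the tool-calling flow that the new `_convert_tools` and `tool_results` plumbing targets, using only the v5 SDK types already imported above; the API key, tool name, and parameter definitions are made-up placeholders:

```python
import cohere
from cohere import Tool, ToolParameterDefinitionsValue

# Placeholder key and tool definition; mirrors what _convert_tools builds from a PromptMessageTool.
client = cohere.Client("YOUR_CO_API_KEY")

weather_tool = Tool(
    name="get_weather",
    description="Look up the current weather for a city.",
    parameter_definitions={
        "city": ToolParameterDefinitionsValue(
            description="Name of the city to query",
            type="str",
            required=True,
        )
    },
)

# First turn: the model may answer directly or emit tool calls.
response = client.chat(
    model="command-r",
    message="What's the weather in Berlin right now?",
    tools=[weather_tool],
)

# The runtime above serializes each call's parameters with json.dumps and later
# feeds tool outputs back through the `tool_results` request field.
for tool_call in response.tool_calls or []:
    print(tool_call.name, tool_call.parameters)
```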

View File

@@ -1,6 +1,7 @@
 from typing import Optional
 
 import cohere
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
 from core.model_runtime.errors.invoke import (
@@ -44,19 +45,21 @@ class CohereRerankModel(RerankModel):
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
-        results = client.rerank(
+        response = client.rerank(
             query=query,
             documents=docs,
             model=model,
-            top_n=top_n
+            top_n=top_n,
+            return_documents=True,
+            request_options=RequestOptions(max_retries=0)
         )
 
         rerank_documents = []
-        for idx, result in enumerate(results):
+        for idx, result in enumerate(response.results):
             # format document
             rerank_document = RerankDocument(
                 index=result.index,
-                text=result.document['text'],
+                text=result.document.text,
                 score=result.relevance_score,
             )
@@ -108,13 +111,21 @@
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError,
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
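For completeness, a small usage sketch (not part of the commit) of the v5 rerank call shape the handler above now expects; the model name, query, and documents are placeholders:

```python
import cohere
from cohere.core import RequestOptions

client = cohere.Client("YOUR_CO_API_KEY")  # placeholder key

response = client.rerank(
    model="rerank-english-v2.0",  # assumed model name, not taken from this diff
    query="What is the capital of France?",
    documents=["Berlin is in Germany.", "Paris is the capital of France."],
    top_n=1,
    return_documents=True,                      # matches the change above
    request_options=RequestOptions(max_retries=0),
)

# v5 returns typed objects, so document text is an attribute rather than a dict key.
for result in response.results:
    print(result.index, result.relevance_score, result.document.text)
```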

View File

@@ -3,7 +3,7 @@ from typing import Optional
 
 import cohere
 import numpy as np
-from cohere.responses import Tokens
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -52,8 +52,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                 text=text
             )
 
-            for j in range(0, tokenize_response.length, context_size):
-                tokens += [tokenize_response.token_strings[j: j + context_size]]
+            for j in range(0, len(tokenize_response), context_size):
+                tokens += [tokenize_response[j: j + context_size]]
                 indices += [i]
 
         batched_embeddings = []
@@ -127,9 +127,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         except Exception as e:
             raise self._transform_invoke_error(e)
 
-        return response.length
+        return len(response)
 
-    def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
+    def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
         """
         Tokenize text
 
         :param model: model name
@@ -138,17 +138,19 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         if not text:
-            return Tokens([], [], {})
+            return []
 
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
         response = client.tokenize(
             text=text,
-            model=model
+            model=model,
+            offline=False,
+            request_options=RequestOptions(max_retries=0)
         )
 
-        return response
+        return response.token_strings
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -184,10 +186,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         response = client.embed(
             texts=texts,
             model=model,
-            input_type='search_document' if len(texts) > 1 else 'search_query'
+            input_type='search_document' if len(texts) > 1 else 'search_query',
+            request_options=RequestOptions(max_retries=1)
         )
 
-        return response.embeddings, response.meta['billed_units']['input_tokens']
+        return response.embeddings, int(response.meta.billed_units.input_tokens)
 
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
@@ -231,13 +234,21 @@
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
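Likewise, a brief sketch (placeholders only, not from the commit) of the v5 tokenize/embed response shapes the embedding changes above rely on:

```python
import cohere
from cohere.core import RequestOptions

client = cohere.Client("YOUR_CO_API_KEY")  # placeholder key

tok = client.tokenize(
    text="Hello from Command R",
    model="embed-english-v3.0",  # assumed embedding model name
    offline=False,
    request_options=RequestOptions(max_retries=0),
)
# Token count now comes from the token list itself, not a .length attribute.
print(len(tok.tokens), tok.token_strings[:5])

emb = client.embed(
    texts=["Hello from Command R"],
    model="embed-english-v3.0",
    input_type="search_query",
    request_options=RequestOptions(max_retries=1),
)
# Billed units are attributes on response.meta in v5.
print(len(emb.embeddings[0]), int(emb.meta.billed_units.input_tokens))
```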

View File

@@ -232,8 +232,8 @@ class SimplePromptTransform(PromptTransform):
                 )
             ),
             max_token_limit=rest_tokens,
-            ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-            human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+            human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+            ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
         )
 
         # get prompt

View File

@@ -47,7 +47,8 @@ replicate~=0.22.0
 websocket-client~=1.7.0
 dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
-transformers~=4.31.0
+transformers~=4.35.0
+tokenizers~=0.15.0
 pandas==1.5.3
 xinference-client==0.9.4
 safetensors==0.3.2
@@ -55,7 +56,7 @@ zhipuai==1.0.7
 werkzeug~=3.0.1
 pymilvus==2.3.0
 qdrant-client==1.7.3
-cohere~=4.44
+cohere~=5.2.4
 pyyaml~=6.0.1
 numpy~=1.25.2
 unstructured[docx,pptx,msg,md,ppt]~=0.10.27