openai compatible api usage and id (#9800)

Co-authored-by: jinqi.guo <jinqi.guo@ubtrobot.com>
guogeer authored 2024-10-24 21:51:36 +08:00, committed by GitHub
parent 9986e4c6d0
commit 70ddc0ce43
2 changed files with 27 additions and 7 deletions

View File

@@ -105,6 +105,7 @@ class LLMResult(BaseModel):
     Model class for llm result.
     """
 
+    id: Optional[str] = None
     model: str
     prompt_messages: list[PromptMessage]
     message: AssistantPromptMessage
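The field is deliberately `Optional` with a `None` default, so existing call sites that never pass an `id` keep validating. A minimal sketch of that behavior, using a simplified stand-in model (the class name and sample ids below are illustrative, not from the patch):

from typing import Optional

from pydantic import BaseModel


class LLMResultSketch(BaseModel):
    # Simplified stand-in for LLMResult; only the fields relevant here.
    id: Optional[str] = None  # provider-assigned message id, if any
    model: str


# Old call sites that omit `id` still work, defaulting to None.
assert LLMResultSketch(model="gpt-4o-mini").id is None

# New call sites can carry the upstream id through unchanged.
assert LLMResultSketch(id="chatcmpl-123", model="gpt-4o-mini").id == "chatcmpl-123"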

View File

@@ -397,16 +397,21 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         chunk_index = 0
 
         def create_final_llm_result_chunk(
-            index: int, message: AssistantPromptMessage, finish_reason: str
+            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
         ) -> LLMResultChunk:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+            prompt_tokens = usage and usage.get("prompt_tokens")
+            if prompt_tokens is None:
+                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = usage and usage.get("completion_tokens")
+            if completion_tokens is None:
+                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
 
             # transform usage
             usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
             return LLMResultChunk(
+                id=id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
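The token-count fallback above prefers the provider-reported `usage` block and only estimates locally when it is missing. A hypothetical standalone helper showing the same pattern and its edge cases (`resolve_token_count` is not in the patch; it just isolates the logic):

from typing import Optional


def resolve_token_count(usage: Optional[dict], key: str, estimate: int) -> int:
    # `usage and usage.get(key)` short-circuits to None when usage is None,
    # and `.get` returns None when the key is absent; both cases fall back.
    reported = usage and usage.get(key)
    if reported is None:
        return estimate
    return reported


assert resolve_token_count(None, "prompt_tokens", 42) == 42  # no usage chunk seen yet
assert resolve_token_count({"prompt_tokens": 17}, "prompt_tokens", 42) == 17  # server value wins
assert resolve_token_count({"total_tokens": 20}, "prompt_tokens", 42) == 42  # key absent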
@@ -450,7 +455,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                             tool_call.function.arguments += new_tool_call.function.arguments
 
         finish_reason = None  # The default value of finish_reason is None
+        message_id, usage = None, None
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
@@ -462,20 +467,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                         continue
 
                     try:
-                        chunk_json = json.loads(decoded_chunk)
+                        chunk_json: dict = json.loads(decoded_chunk)
                     # stream ended
                     except json.JSONDecodeError as e:
                         yield create_final_llm_result_chunk(
+                            id=message_id,
                             index=chunk_index + 1,
                             message=AssistantPromptMessage(content=""),
                             finish_reason="Non-JSON encountered.",
+                            usage=usage,
                         )
                         break
+                    if chunk_json:
+                        if u := chunk_json.get("usage"):
+                            usage = u
                     if not chunk_json or len(chunk_json["choices"]) == 0:
                         continue
 
                     choice = chunk_json["choices"][0]
                     finish_reason = chunk_json["choices"][0].get("finish_reason")
+                    message_id = chunk_json.get("id")
                     chunk_index += 1
 
                     if "delta" in choice:
@@ -524,6 +535,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                             continue
 
                     yield LLMResultChunk(
+                        id=message_id,
                         model=model,
                         prompt_messages=prompt_messages,
                         delta=LLMResultChunkDelta(
@@ -536,6 +548,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
             if tools_calls:
                 yield LLMResultChunk(
+                    id=message_id,
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
@@ -545,17 +558,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
             )
 
         yield create_final_llm_result_chunk(
-            index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
+            id=message_id,
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason,
+            usage=usage,
         )
 
     def _handle_generate_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> LLMResult:
-        response_json = response.json()
+        response_json: dict = response.json()
         completion_type = LLMMode.value_of(credentials["mode"])
 
         output = response_json["choices"][0]
+        message_id = response_json.get("id")
 
         response_content = ""
         tool_calls = None
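In the blocking path the id is read once from the response body; using `.get("id")` instead of indexing means a compatible backend that omits the field yields `None` rather than raising `KeyError`. A tiny sketch with a canned body (the payload is illustrative):

# Stand-in for `response.json()` from an OpenAI-compatible /chat/completions call.
response_json: dict = {
    "id": "chatcmpl-xyz",
    "model": "gpt-4o-mini",
    "choices": [{"message": {"role": "assistant", "content": "Hello"}}],
}

output = response_json["choices"][0]
message_id = response_json.get("id")  # None when the backend omits it

assert message_id == "chatcmpl-xyz"
assert output["message"]["content"] == "Hello"
assert {}.get("id") is None  # an omitted id degrades gracefully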
@@ -593,6 +611,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         # transform response
         result = LLMResult(
+            id=message_id,
             model=response_json["model"],
             prompt_messages=prompt_messages,
             message=assistant_message,