fix: OpenAI o1 Bad Request Error (#12839)
parent a7b9375877
commit 46e95e8309
@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast
 
@@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
+        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
+            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+
+                    if "stream_options" in extra_model_kwargs:
+                        del extra_model_kwargs["stream_options"]
+
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
 
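The hunk above avoids the Bad Request error in two ways: o1-series models reject the max_tokens parameter in favor of max_completion_tokens, and the base o1 model (plain or date-suffixed, unlike o1-mini and o1-preview) did not support streaming at the time of this fix, so a requested stream is downgraded to a blocking call and replayed afterwards. A minimal sketch of the same logic in isolation, using a hypothetical helper adapt_o1_request that is not part of this codebase:

import re

def adapt_o1_request(model: str, model_parameters: dict, stream: bool) -> tuple[bool, bool]:
    """Adjust model_parameters in place for o1 models; return (stream, block_as_stream)."""
    block_as_stream = False
    if model.startswith("o1"):
        if "max_tokens" in model_parameters:
            # o1 models reject max_tokens; the accepted name is max_completion_tokens.
            model_parameters["max_completion_tokens"] = model_parameters.pop("max_tokens")
        if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model) and stream:
            # Matches "o1" and dated snapshots such as "o1-2024-12-17", but not
            # "o1-mini" or "o1-preview", which keep real streaming.
            block_as_stream = True
            stream = False
    return stream, block_as_stream

For example, adapt_o1_request("o1-2024-12-17", {"max_tokens": 1024}, stream=True) returns (False, True) and leaves {"max_completion_tokens": 1024} in the dict, while an o1-mini request keeps streaming and only gets the max_tokens remap.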
@@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Convert a blocking chat result into a single-chunk stream.
+
+        :param block_result: result of the blocking chat call
+        :param prompt_messages: prompt messages
+        :param stop: stop words to enforce on the blocking result
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+            block_result.message.content = text
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
+
     def _handle_chat_generate_response(
         self,
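The new _handle_chat_block_as_stream_response keeps the caller-facing contract intact: a caller that requested a stream still receives a generator, just one that yields a single terminal chunk carrying the whole answer, its usage, and finish_reason "stop". A stripped-down sketch of the pattern, with stand-in dataclasses instead of Dify's LLMResult and LLMResultChunk types:

from collections.abc import Generator
from dataclasses import dataclass
from typing import Optional

@dataclass
class Result:  # stand-in for LLMResult
    model: str
    text: str

@dataclass
class Chunk:  # stand-in for LLMResultChunk
    model: str
    text: str
    finish_reason: str

def block_as_stream(result: Result, stop: Optional[list[str]] = None) -> Generator[Chunk, None, None]:
    text = result.text
    if stop:
        # Emulate server-side stop sequences client-side: cut at the first
        # occurrence of any stop word, as enforce_stop_tokens does upstream.
        for s in stop:
            text = text.split(s)[0]
    # The entire blocking answer goes out as one final chunk.
    yield Chunk(model=result.model, text=text, finish_reason="stop")

# Consumers iterate exactly as they would over a real stream:
for chunk in block_as_stream(Result("o1", "Hello END world"), stop=["END"]):
    print(chunk.text)  # prints "Hello "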