fix: OpenAI o1 Bad Request Error (#12839)

Authored by k-zaku on 2025-01-21 16:29:13 +09:00; committed by GitHub
parent a7b9375877
commit 46e95e8309
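The o1 series rejects several Chat Completions parameters that earlier chat models accept, so requests built the gpt-4 way come back as 400 Bad Request. This commit keeps renaming max_tokens to max_completion_tokens and dropping the unsupported stop parameter for every o1 model, and additionally, for the base o1 snapshots (which did not accept stream=true through this API at the time), downgrades the call to a blocking request and replays the result as a one-chunk stream. A minimal sketch of the parameter fix-up, outside Dify's classes (adapt_params_for_o1 is an illustrative name, not part of the codebase):

    def adapt_params_for_o1(model: str, params: dict) -> dict:
        # o1 models reject max_tokens and expect max_completion_tokens instead
        if model.startswith("o1") and "max_tokens" in params:
            params["max_completion_tokens"] = params.pop("max_tokens")
        return params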


@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast
 
@@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
+        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
+            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+
+                if "stream_options" in extra_model_kwargs:
+                    del extra_model_kwargs["stream_options"]
+
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
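The new regex decides which model IDs get the stream fallback: it matches the bare "o1" and dated snapshots such as o1-2024-12-17, but not o1-mini or o1-preview, which keep native streaming. A quick check (model IDs are the published OpenAI names, used here for illustration):

    import re

    O1_BASE = re.compile(r"^o1(-\d{4}-\d{2}-\d{2})?$")

    assert O1_BASE.match("o1")              # base model: falls back to blocking
    assert O1_BASE.match("o1-2024-12-17")   # dated snapshot: falls back too
    assert not O1_BASE.match("o1-mini")     # keeps native streaming
    assert not O1_BASE.match("o1-preview")  # keeps native streaming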
@@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle an llm chat response that was fetched blocking but must be
+        replayed as a stream
+
+        :param block_result: complete result from the blocking call
+        :param prompt_messages: prompt messages
+        :param stop: stop words, enforced on the final text if given
+        :return: llm response chunk generator yielding a single chunk
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+            # write the truncated text back so the yielded message carries it
+            block_result.message.content = text
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
+
     def _handle_chat_generate_response(
         self,
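For callers the contract is unchanged: ask for stream=True and you still get a generator, it just yields a single terminal chunk. A framework-free sketch of the same block-as-stream pattern (block_as_stream and its inline stop handling are illustrative; Dify's enforce_stop_tokens is a separate helper):

    from collections.abc import Generator

    def block_as_stream(text: str, stop: list[str] | None = None) -> Generator[str, None, None]:
        # Re-emit a completed (blocking) result as a one-chunk stream.
        if stop:
            # cut the text at the first stop sequence that appears, if any
            for s in stop:
                idx = text.find(s)
                if idx != -1:
                    text = text[:idx]
        yield text

    # Callers that requested streaming iterate as usual:
    for chunk in block_as_stream("final answer STOP trailing", stop=["STOP"]):
        print(chunk)  # prints "final answer "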