From 46e95e83091dcbb5e33572befdddd06cefd8edaa Mon Sep 17 00:00:00 2001
From: k-zaku
Date: Tue, 21 Jan 2025 16:29:13 +0900
Subject: [PATCH] fix: OpenAI o1 Bad Request Error (#12839)

---
 .../model_providers/openai/llm/llm.py | 49 ++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index 86042de6f4..634dbc5535 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast
 
@@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
+        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
+            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+                    if "stream_options" in extra_model_kwargs:
+                        del extra_model_kwargs["stream_options"]
+
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
 
@@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response as a single-chunk stream
+
+        Used for o1 models that reject streaming requests.
+        :param block_result: blocking chat completion result
+        :param prompt_messages: prompt messages
+        :param stop: stop words to enforce on the output text
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=text),
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
 
     def _handle_chat_generate_response(
         self,
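
Note: the re.match gate in the patch only forces block-as-stream for the bare "o1" model and its dated snapshots (e.g. "o1-2024-12-17"); names such as "o1-mini" or "o1-preview" do not match the pattern and keep the caller's original streaming setting. A minimal standalone sketch of that gating, using the same regex as the patch (the model names below are illustrative, not an exhaustive list):

    import re

    # Same pattern as in the patch: bare "o1" or an "o1-YYYY-MM-DD" snapshot.
    O1_BLOCK_ONLY = re.compile(r"^o1(-\d{4}-\d{2}-\d{2})?$")

    for name in ["o1", "o1-2024-12-17", "o1-mini", "o1-preview", "o1-mini-2024-09-12"]:
        # Matching models are sent as a blocking request and replayed as one stream chunk.
        mode = "block_as_stream" if O1_BLOCK_ONLY.match(name) else "stream unchanged"
        print(f"{name}: {mode}")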