From 4d6b45427cd9637c3158be2c0f46600759097d63 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Wed, 20 Nov 2024 15:10:41 +0800
Subject: [PATCH] Support streaming output for OpenAI o1-preview and o1-mini
 (#10890)

---
 .../model_providers/openai/llm/llm.py        | 50 +------------------
 .../model_providers/openrouter/llm/llm.py    | 18 +------
 2 files changed, 3 insertions(+), 65 deletions(-)

diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index 68317d7179..f16f81c125 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

         # o1 compatibility
-        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]

-            if stream:
-                block_as_stream = True
-                stream = False
-
-                if "stream_options" in extra_model_kwargs:
-                    del extra_model_kwargs["stream_options"]
-
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]

@@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)

-        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
-
-        if block_as_stream:
-            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
-
-        return block_result
-
-    def _handle_chat_block_as_stream_response(
-        self,
-        block_result: LLMResult,
-        prompt_messages: list[PromptMessage],
-        stop: Optional[list[str]] = None,
-    ) -> Generator[LLMResultChunk, None, None]:
-        """
-        Handle llm chat response
-
-        :param model: model name
-        :param credentials: credentials
-        :param response: response
-        :param prompt_messages: prompt messages
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :return: llm response chunk generator
-        """
-        text = block_result.message.content
-        text = cast(str, text)
-
-        if stop:
-            text = self.enforce_stop_tokens(text, stop)
-
-        yield LLMResultChunk(
-            model=block_result.model,
-            prompt_messages=prompt_messages,
-            system_fingerprint=block_result.system_fingerprint,
-            delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(content=text),
-                finish_reason="stop",
-                usage=block_result.usage,
-            ),
-        )
+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

     def _handle_chat_generate_response(
         self,
diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py
index 736ab8e7a8..2d6ece8113 100644
--- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py
@@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
-
-        block_as_stream = False
-        if model.startswith("openai/o1"):
-            block_as_stream = True
-            stop = None
-
-        # invoke block as stream
-        if stream and block_as_stream:
-            return self._generate_block_as_stream(
-                model, credentials, prompt_messages, model_parameters, tools, stop, user
-            )
-        else:
-            return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

     def _generate_block_as_stream(
         self,
@@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         stop: Optional[list[str]] = None,
         user: Optional[str] = None,
     ) -> Generator:
-        resp: LLMResult = super()._generate(
-            model, credentials, prompt_messages, model_parameters, tools, stop, False, user
-        )
+        resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)

         yield LLMResultChunk(
             model=model,
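
For reference, the block-as-stream shim removed above existed because the o1-series models originally rejected stream=True, so streaming requests had to be downgraded to a blocking call and replayed as a single chunk. Below is a minimal, hypothetical sketch (not part of this patch) of the upstream behavior the simplified code path now relies on, written against the OpenAI Python SDK; the model name, token limit, and prompt are illustrative assumptions only.

    # Sketch: o1-series chat completions streamed directly, with
    # max_completion_tokens in place of max_tokens (assumes openai>=1.x
    # and OPENAI_API_KEY set in the environment).
    from openai import OpenAI

    client = OpenAI()

    stream = client.chat.completions.create(
        model="o1-mini",  # illustrative; o1-preview behaves the same way here
        messages=[{"role": "user", "content": "Explain streaming in one sentence."}],
        max_completion_tokens=256,  # o1 models do not accept max_tokens
        stream=True,
    )

    # Each chunk carries an incremental delta; print tokens as they arrive.
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)

Because the provider now streams natively, the patch can drop the stream/stream_options rewriting in the OpenAI provider and route OpenRouter's openai/o1 models through the ordinary super()._generate() path.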