feat: fix azure completion choices return empty (#708)
This commit is contained in:
parent
a856ef387b
commit
e18211ffea
@ -1,5 +1,7 @@
|
|||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
|
||||||
from langchain.llms import AzureOpenAI
|
from langchain.llms import AzureOpenAI
|
||||||
|
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
|
||||||
|
update_token_usage
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
|
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
|
||||||
|
|
||||||
@ -67,3 +69,58 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_kwargs_from_model_params(cls, params: dict):
|
def get_kwargs_from_model_params(cls, params: dict):
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Call out to OpenAI's endpoint with k unique prompts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: The prompts to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The full LLM output.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
response = openai.generate(["Tell me a joke."])
|
||||||
|
"""
|
||||||
|
params = self._invocation_params
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
||||||
|
choices = []
|
||||||
|
token_usage: Dict[str, int] = {}
|
||||||
|
# Get the token usage from the response.
|
||||||
|
# Includes prompt, completion, and total tokens used.
|
||||||
|
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||||
|
for _prompts in sub_prompts:
|
||||||
|
if self.streaming:
|
||||||
|
if len(_prompts) > 1:
|
||||||
|
raise ValueError("Cannot stream results with multiple prompts.")
|
||||||
|
params["stream"] = True
|
||||||
|
response = _streaming_response_template()
|
||||||
|
for stream_resp in completion_with_retry(
|
||||||
|
self, prompt=_prompts, **params
|
||||||
|
):
|
||||||
|
if len(stream_resp["choices"]) > 0:
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(
|
||||||
|
stream_resp["choices"][0]["text"],
|
||||||
|
verbose=self.verbose,
|
||||||
|
logprobs=stream_resp["choices"][0]["logprobs"],
|
||||||
|
)
|
||||||
|
_update_response(response, stream_resp)
|
||||||
|
choices.extend(response["choices"])
|
||||||
|
else:
|
||||||
|
response = completion_with_retry(self, prompt=_prompts, **params)
|
||||||
|
choices.extend(response["choices"])
|
||||||
|
if not self.streaming:
|
||||||
|
# Can't update token usage if streaming
|
||||||
|
update_token_usage(_keys, response, token_usage)
|
||||||
|
return self.create_llm_result(choices, prompts, token_usage)
|
Loading…
Reference in New Issue
Block a user