From a76fde3d23f8162d56cc71e16d992df441f09026 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 23 Aug 2023 19:47:50 +0800 Subject: [PATCH] feat: optimize hf inference endpoint (#975) --- .../models/llm/huggingface_hub_model.py | 12 +++--- .../providers/huggingface_hub_provider.py | 16 ++++++-- .../llms/huggingface_endpoint_llm.py | 39 +++++++++++++++++++ .../test_huggingface_hub_provider.py | 3 +- 4 files changed, 59 insertions(+), 11 deletions(-) create mode 100644 api/core/third_party/langchain/llms/huggingface_endpoint_llm.py diff --git a/api/core/model_providers/models/llm/huggingface_hub_model.py b/api/core/model_providers/models/llm/huggingface_hub_model.py index 16aec70c30..7e800e3fea 100644 --- a/api/core/model_providers/models/llm/huggingface_hub_model.py +++ b/api/core/model_providers/models/llm/huggingface_hub_model.py @@ -1,16 +1,14 @@ -import decimal -from functools import wraps from typing import List, Optional, Any from langchain import HuggingFaceHub from langchain.callbacks.manager import Callbacks -from langchain.llms import HuggingFaceEndpoint from langchain.schema import LLMResult from core.model_providers.error import LLMBadRequestError from core.model_providers.models.llm.base import BaseLLM -from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.message import PromptMessage from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM class HuggingfaceHubModel(BaseLLM): @@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM): def _init_client(self) -> Any: provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints': - client = HuggingFaceEndpoint( + client = HuggingFaceEndpointLLM( endpoint_url=self.credentials['huggingfacehub_endpoint_url'], - task='text2text-generation', + task=self.credentials['task_type'], model_kwargs=provider_model_kwargs, huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'], - callbacks=self.callbacks, + callbacks=self.callbacks ) else: client = HuggingFaceHub( diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py index ded94e2a44..0c638a4e6b 100644 --- a/api/core/model_providers/providers/huggingface_hub_provider.py +++ b/api/core/model_providers/providers/huggingface_hub_provider.py @@ -2,7 +2,6 @@ import json from typing import Type from huggingface_hub import HfApi -from langchain.llms import HuggingFaceEndpoint from core.helper import encrypter from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType @@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.models.base import BaseProviderModel +from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM from models.provider import ProviderType @@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider): if 'huggingfacehub_endpoint_url' not in credentials: raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.') + if 'task_type' not in credentials: + raise CredentialsValidateFailedError('Task Type must be provided.') + + if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"): + raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.') + try: - llm = HuggingFaceEndpoint( + llm = HuggingFaceEndpointLLM( endpoint_url=credentials['huggingfacehub_endpoint_url'], - task="text2text-generation", + task=credentials['task_type'], model_kwargs={"temperature": 0.5, "max_new_tokens": 200}, huggingfacehub_api_token=credentials['huggingfacehub_api_token'] ) @@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider): } credentials = json.loads(provider_model.encrypted_config) + + if 'task_type' not in credentials: + credentials['task_type'] = 'text-generation' + if credentials['huggingfacehub_api_token']: credentials['huggingfacehub_api_token'] = encrypter.decrypt_token( self.provider.tenant_id, diff --git a/api/core/third_party/langchain/llms/huggingface_endpoint_llm.py b/api/core/third_party/langchain/llms/huggingface_endpoint_llm.py new file mode 100644 index 0000000000..71ee684e3d --- /dev/null +++ b/api/core/third_party/langchain/llms/huggingface_endpoint_llm.py @@ -0,0 +1,39 @@ +from typing import Dict + +from langchain.llms import HuggingFaceEndpoint +from pydantic import Extra, root_validator + +from langchain.utils import get_from_dict_or_env + + +class HuggingFaceEndpointLLM(HuggingFaceEndpoint): + """HuggingFace Endpoint models. + + To use, you should have the ``huggingface_hub`` python package installed, and the + environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass + it as a named parameter to the constructor. + + Only supports `text-generation` and `text2text-generation` for now. + + Example: + .. code-block:: python + + from langchain.llms import HuggingFaceEndpoint + endpoint_url = ( + "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud" + ) + hf = HuggingFaceEndpoint( + endpoint_url=endpoint_url, + huggingfacehub_api_token="my-api-key" + ) + """ + + @root_validator(allow_reuse=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + huggingfacehub_api_token = get_from_dict_or_env( + values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" + ) + + values["huggingfacehub_api_token"] = huggingfacehub_api_token + return values diff --git a/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py index 61456f64f4..7f77d3c212 100644 --- a/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py +++ b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py @@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = { INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = { 'huggingfacehub_api_type': 'inference_endpoints', 'huggingfacehub_api_token': 'valid_key', - 'huggingfacehub_endpoint_url': 'valid_url' + 'huggingfacehub_endpoint_url': 'valid_url', + 'task_type': 'text-generation' } def encrypt_side_effect(tenant_id, encrypt_key):