diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py
index d72d1bd83a..1a4cc15371 100644
--- a/api/core/model_runtime/model_providers/wenxin/_common.py
+++ b/api/core/model_runtime/model_providers/wenxin/_common.py
@@ -120,6 +120,7 @@ class _CommonWenxin:
         "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en",
         "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh",
         "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k",
+        "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base",
     }
 
     function_calling_supports = [
diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py b/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml
new file mode 100644
index 0000000000..ef4b07d767
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml
@@ -0,0 +1,8 @@
+model: bce-reranker-base_v1
+model_type: rerank
+model_properties:
+  context_size: 4096
+pricing:
+  input: '0.0005'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py
new file mode 100644
index 0000000000..b22aead22b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py
@@ -0,0 +1,147 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
+
+
+class WenxinRerank(_CommonWenxin):
+    def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
+        access_token = self._get_access_token()
+        url = f"{self.api_bases[model]}?access_token={access_token}"
+
+        try:
+            response = httpx.post(
+                url,
+                json={"model": model, "query": query, "documents": docs, "top_n": top_n},
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return response.json()
+        except httpx.HTTPStatusError as e:
+            # NOTE(review): all HTTP status errors (incl. 4xx auth failures) surface as
+            # "server unavailable" — confirm this is intended, since it shadows the
+            # InvokeAuthorizationError mapping in WenxinRerankModel._invoke_error_mapping.
+            raise InvokeServerUnavailableError(str(e))
+
+
+class WenxinRerankModel(RerankModel):
+    """
+    Model class for wenxin rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n documents to return
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+
+        api_key = credentials["api_key"]
+        secret_key = credentials["secret_key"]
+
+        wenxin_rerank: WenxinRerank = WenxinRerank(api_key, secret_key)
+
+        try:
+            results = wenxin_rerank.rerank(model, query, docs, top_n)
+
+            rerank_documents = []
+            for result in results["results"]:
+                index = result["index"]
+                if "document" in result:
+                    text = result["document"]
+                else:
+                    # the API may omit the document text; fall back to the input doc
+                    text = docs[index]
+
+                rerank_document = RerankDocument(
+                    index=index,
+                    text=text,
+                    score=result["relevance_score"],
+                )
+
+                if score_threshold is None or result["relevance_score"] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+
+            return RerankResult(model=model, docs=rerank_documents)
+        except httpx.HTTPStatusError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        """
+        return {
+            InvokeConnectionError: [httpx.ConnectError],
+            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [httpx.HTTPStatusError],
+            InvokeBadRequestError: [httpx.RequestError],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.RERANK,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096"))},
+        )
+
+        return entity
diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml
index 6a6b38e6a1..d8acfd8120 100644
--- a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml
@@ -18,6 +18,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - rerank
 configurate_methods:
   - predefined-model
 provider_credential_schema:
diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py
new file mode 100644
index 0000000000..33c803e8e1
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py
@@ -0,0 +1,21 @@
+import os
+from time import sleep
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel
+
+
+def test_invoke_bce_reranker_base_v1():
+    sleep(3)
+    model = WenxinRerankModel()
+
+    response = model.invoke(
+        model="bce-reranker-base_v1",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        query="What is Deep Learning?",
+        docs=["Deep Learning is ...", "My Book is ..."],
+        user="abc-123",
+    )
+
+    assert isinstance(response, RerankResult)
+    assert len(response.docs) == 2