From 8523b34be73b4c743b462a0dd416d73a9773361d Mon Sep 17 00:00:00 2001
From: Joshua <138381132+joshua20231026@users.noreply.github.com>
Date: Mon, 4 Mar 2024 17:31:01 +0800
Subject: [PATCH] add jina-reranker-v1-base-en (#2676)

---
 .../model_providers/jina/jina.yaml            |   5 +-
 .../model_providers/jina/rerank/__init__.py   |   0
 .../jina/rerank/jina-reranker-v1-base-en.yaml |   4 +
 .../model_providers/jina/rerank/rerank.py     | 105 ++++++++++++++++++
 4 files changed, 112 insertions(+), 2 deletions(-)
 create mode 100644 api/core/model_runtime/model_providers/jina/rerank/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v1-base-en.yaml
 create mode 100644 api/core/model_runtime/model_providers/jina/rerank/rerank.py

diff --git a/api/core/model_runtime/model_providers/jina/jina.yaml b/api/core/model_runtime/model_providers/jina/jina.yaml
index ad90344d53..935546234b 100644
--- a/api/core/model_runtime/model_providers/jina/jina.yaml
+++ b/api/core/model_runtime/model_providers/jina/jina.yaml
@@ -2,7 +2,7 @@ provider: jina
 label:
   en_US: Jina
 description:
-  en_US: Embedding Model Supported
+  en_US: Embedding and Rerank Model Supported
 icon_small:
   en_US: icon_s_en.svg
 icon_large:
@@ -13,9 +13,10 @@ help:
     en_US: Get your API key from Jina AI
     zh_Hans: 从 Jina 获取 API Key
   url:
-    en_US: https://jina.ai/embeddings/
+    en_US: https://jina.ai/
 supported_model_types:
   - text-embedding
+  - rerank
 configurate_methods:
   - predefined-model
 provider_credential_schema:
diff --git a/api/core/model_runtime/model_providers/jina/rerank/__init__.py b/api/core/model_runtime/model_providers/jina/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v1-base-en.yaml b/api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v1-base-en.yaml
new file mode 100644
index 0000000000..bd3f31fbd1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v1-base-en.yaml
@@ -0,0 +1,4 @@
+model: jina-reranker-v1-base-en
+model_type: rerank
+model_properties:
+  context_size: 8192
diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py
new file mode 100644
index 0000000000..f644ea6512
--- /dev/null
+++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py
@@ -0,0 +1,105 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class JinaRerankModel(RerankModel):
+    """
+    Rerank model backed by the Jina AI HTTP rerank endpoint.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                user: Optional[str] = None) -> RerankResult:
+        """
+        Rerank ``docs`` against ``query`` with a single POST to the Jina rerank API.
+
+        :param model: model name, sent as the ``model`` field of the request
+        :param credentials: model credentials; ``api_key`` is used as the bearer token
+        :param query: search query
+        :param docs: docs for reranking; an empty list returns an empty result without a request
+        :param score_threshold: when set, results scoring below it are dropped
+        :param top_n: top n documents to return, forwarded to the API unchanged
+        :param user: unique user id (not used by this method)
+        :return: rerank result holding the documents that passed the threshold
+        :raises InvokeServerUnavailableError: if the API answers with an error status
+        """
+        if not docs:
+            return RerankResult(model=model, docs=[])
+
+        try:
+            response = httpx.post(
+                "https://api.jina.ai/v1/rerank",
+                json={
+                    "model": model,
+                    "query": query,
+                    "documents": docs,
+                    "top_n": top_n
+                },
+                headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
+            )
+            response.raise_for_status()
+            results = response.json()
+
+            rerank_documents = [
+                RerankDocument(
+                    index=result['index'],
+                    text=result['document']['text'],
+                    score=result['relevance_score'],
+                )
+                for result in results['results']
+                if score_threshold is None or result['relevance_score'] >= score_threshold
+            ]
+            return RerankResult(model=model, docs=rerank_documents)
+        except httpx.HTTPStatusError as e:
+            # Chain the cause so the original httpx error stays attached to the raised one.
+            raise InvokeServerUnavailableError(str(e)) from e
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials by running a small real rerank request.
+
+        :param model: model name
+        :param credentials: model credentials
+        :raises CredentialsValidateFailedError: if the probe request fails for any reason
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex)) from ex
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map httpx exceptions to the unified invoke error types
+        """
+        return {
+            InvokeConnectionError: [httpx.ConnectError],
+            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [httpx.HTTPStatusError],
+            InvokeBadRequestError: [httpx.RequestError]
+        }