feat: add mixedbread as a new model provider (#8523)
This commit is contained in:
parent
7c485f8bb8
commit
1ecf70dca0
@ -38,3 +38,4 @@
|
|||||||
- perfxcloud
|
- perfxcloud
|
||||||
- zhinao
|
- zhinao
|
||||||
- fireworks
|
- fireworks
|
||||||
|
- mixedbread
|
||||||
|
Binary file not shown.
After Width: | Height: | Size: 121 KiB |
Binary file not shown.
After Width: | Height: | Size: 36 KiB |
@ -0,0 +1,27 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MixedBreadProvider(ModelProvider):
|
||||||
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate provider credentials
|
||||||
|
if validate failed, raise exception
|
||||||
|
|
||||||
|
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
|
||||||
|
|
||||||
|
# Use `mxbai-embed-large-v1` model for validate,
|
||||||
|
model_instance.validate_credentials(model="mxbai-embed-large-v1", credentials=credentials)
|
||||||
|
except CredentialsValidateFailedError as ex:
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||||
|
raise ex
|
@ -0,0 +1,31 @@
|
|||||||
|
provider: mixedbread
|
||||||
|
label:
|
||||||
|
en_US: MixedBread
|
||||||
|
description:
|
||||||
|
en_US: Embedding and Rerank Model Supported
|
||||||
|
icon_small:
|
||||||
|
en_US: icon_s_en.png
|
||||||
|
icon_large:
|
||||||
|
en_US: icon_l_en.png
|
||||||
|
background: "#EFFDFD"
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: Get your API key from MixedBread AI
|
||||||
|
zh_Hans: 从 MixedBread 获取 API Key
|
||||||
|
url:
|
||||||
|
en_US: https://www.mixedbread.ai/
|
||||||
|
supported_model_types:
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
configurate_methods:
|
||||||
|
- predefined-model
|
||||||
|
provider_credential_schema:
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
@ -0,0 +1,4 @@
|
|||||||
|
model: mxbai-rerank-large-v1
|
||||||
|
model_type: rerank
|
||||||
|
model_properties:
|
||||||
|
context_size: 512
|
@ -0,0 +1,125 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||||
|
|
||||||
|
|
||||||
|
class MixedBreadRerankModel(RerankModel):
|
||||||
|
"""
|
||||||
|
Model class for MixedBread rerank model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
query: str,
|
||||||
|
docs: list[str],
|
||||||
|
score_threshold: Optional[float] = None,
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke rerank model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param query: search query
|
||||||
|
:param docs: docs for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n documents to return
|
||||||
|
:param user: unique user id
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
if len(docs) == 0:
|
||||||
|
return RerankResult(model=model, docs=[])
|
||||||
|
|
||||||
|
base_url = credentials.get("base_url", "https://api.mixedbread.ai/v1")
|
||||||
|
base_url = base_url.removesuffix("/")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.post(
|
||||||
|
base_url + "/reranking",
|
||||||
|
json={"model": model, "query": query, "input": docs, "top_k": top_n, "return_input": True},
|
||||||
|
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
results = response.json()
|
||||||
|
|
||||||
|
rerank_documents = []
|
||||||
|
for result in results["data"]:
|
||||||
|
rerank_document = RerankDocument(
|
||||||
|
index=result["index"],
|
||||||
|
text=result["input"],
|
||||||
|
score=result["score"],
|
||||||
|
)
|
||||||
|
if score_threshold is None or result["score"] >= score_threshold:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
|
||||||
|
return RerankResult(model=model, docs=rerank_documents)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._invoke(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
query="What is the capital of the United States?",
|
||||||
|
docs=[
|
||||||
|
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||||
|
"Census, Carson City had a population of 55,274.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||||
|
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||||
|
],
|
||||||
|
score_threshold=0.8,
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [httpx.ConnectError],
|
||||||
|
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||||
|
InvokeRateLimitError: [],
|
||||||
|
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||||
|
InvokeBadRequestError: [httpx.RequestError],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
"""
|
||||||
|
generate custom model entities from credentials
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.RERANK,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512"))},
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
@ -0,0 +1,8 @@
|
|||||||
|
model: mxbai-embed-2d-large-v1
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 512
|
||||||
|
pricing:
|
||||||
|
input: '0.0001'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: USD
|
@ -0,0 +1,8 @@
|
|||||||
|
model: mxbai-embed-large-v1
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 512
|
||||||
|
pricing:
|
||||||
|
input: '0.0001'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: USD
|
@ -0,0 +1,163 @@
|
|||||||
|
import time
|
||||||
|
from json import JSONDecodeError, dumps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
class MixedBreadTextEmbeddingModel(TextEmbeddingModel):
|
||||||
|
"""
|
||||||
|
Model class for MixedBread text embedding model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
api_base: str = "https://api.mixedbread.ai/v1"
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||||
|
) -> TextEmbeddingResult:
|
||||||
|
"""
|
||||||
|
Invoke text embedding model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:param user: unique user id
|
||||||
|
:return: embeddings result
|
||||||
|
"""
|
||||||
|
api_key = credentials["api_key"]
|
||||||
|
if not api_key:
|
||||||
|
raise CredentialsValidateFailedError("api_key is required")
|
||||||
|
|
||||||
|
base_url = credentials.get("base_url", self.api_base)
|
||||||
|
base_url = base_url.removesuffix("/")
|
||||||
|
|
||||||
|
url = base_url + "/embeddings"
|
||||||
|
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||||
|
|
||||||
|
data = {"model": model, "input": texts}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, data=dumps(data))
|
||||||
|
except Exception as e:
|
||||||
|
raise InvokeConnectionError(str(e))
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
try:
|
||||||
|
resp = response.json()
|
||||||
|
msg = resp["detail"]
|
||||||
|
if response.status_code == 401:
|
||||||
|
raise InvokeAuthorizationError(msg)
|
||||||
|
elif response.status_code == 429:
|
||||||
|
raise InvokeRateLimitError(msg)
|
||||||
|
elif response.status_code == 500:
|
||||||
|
raise InvokeServerUnavailableError(msg)
|
||||||
|
else:
|
||||||
|
raise InvokeBadRequestError(msg)
|
||||||
|
except JSONDecodeError as e:
|
||||||
|
raise InvokeServerUnavailableError(
|
||||||
|
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = response.json()
|
||||||
|
embeddings = resp["data"]
|
||||||
|
usage = resp["usage"]
|
||||||
|
except Exception as e:
|
||||||
|
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||||
|
|
||||||
|
result = TextEmbeddingResult(
|
||||||
|
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
|
"""
|
||||||
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||||
|
except Exception as e:
|
||||||
|
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [InvokeConnectionError],
|
||||||
|
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||||
|
InvokeRateLimitError: [InvokeRateLimitError],
|
||||||
|
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||||
|
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||||
|
"""
|
||||||
|
Calculate response usage
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param tokens: input tokens
|
||||||
|
:return: usage
|
||||||
|
"""
|
||||||
|
# get input price info
|
||||||
|
input_price_info = self.get_price(
|
||||||
|
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
usage = EmbeddingUsage(
|
||||||
|
tokens=tokens,
|
||||||
|
total_tokens=tokens,
|
||||||
|
unit_price=input_price_info.unit_price,
|
||||||
|
price_unit=input_price_info.unit,
|
||||||
|
total_price=input_price_info.total_amount,
|
||||||
|
currency=input_price_info.currency,
|
||||||
|
latency=time.perf_counter() - self.started_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
"""
|
||||||
|
generate custom model entities from credentials
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512"))},
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
@ -122,6 +122,7 @@ CODE_EXECUTION_API_KEY = "dify-sandbox"
|
|||||||
FIRECRAWL_API_KEY = "fc-"
|
FIRECRAWL_API_KEY = "fc-"
|
||||||
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
|
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
|
||||||
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
|
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
|
||||||
|
MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa"
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "dify-api"
|
name = "dify-api"
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.mixedbread.mixedbread import MixedBreadProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_provider_credentials():
|
||||||
|
provider = MixedBreadProvider()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"usage": {"prompt_tokens": 3, "total_tokens": 3},
|
||||||
|
"model": "mixedbread-ai/mxbai-embed-large-v1",
|
||||||
|
"data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}],
|
||||||
|
"object": "list",
|
||||||
|
"normalized": "true",
|
||||||
|
"encoding_format": "float",
|
||||||
|
"dimensions": 1024,
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")})
|
@ -0,0 +1,100 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.mixedbread.rerank.rerank import MixedBreadRerankModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
|
||||||
|
model = MixedBreadRerankModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(
|
||||||
|
model="mxbai-rerank-large-v1",
|
||||||
|
credentials={"api_key": "invalid_key"},
|
||||||
|
)
|
||||||
|
with patch("httpx.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"usage": {"prompt_tokens": 86, "total_tokens": 86},
|
||||||
|
"model": "mixedbread-ai/mxbai-rerank-large-v1",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"score": 0.06762695,
|
||||||
|
"input": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
|
||||||
|
"States Census, Carson City had a population of 55,274.",
|
||||||
|
"object": "text_document",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"score": 0.057403564,
|
||||||
|
"input": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific "
|
||||||
|
"Ocean that are a political division controlled by the United States. Its capital is "
|
||||||
|
"Saipan.",
|
||||||
|
"object": "text_document",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"object": "list",
|
||||||
|
"top_k": 2,
|
||||||
|
"return_input": True,
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
model.validate_credentials(
|
||||||
|
model="mxbai-rerank-large-v1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("MIXEDBREAD_API_KEY"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = MixedBreadRerankModel()
|
||||||
|
with patch("httpx.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"usage": {"prompt_tokens": 56, "total_tokens": 56},
|
||||||
|
"model": "mixedbread-ai/mxbai-rerank-large-v1",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"score": 0.6044922,
|
||||||
|
"input": "Kasumi is a girl name of Japanese origin meaning mist.",
|
||||||
|
"object": "text_document",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"score": 0.0703125,
|
||||||
|
"input": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a "
|
||||||
|
"team named PopiParty.",
|
||||||
|
"object": "text_document",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"object": "list",
|
||||||
|
"top_k": 2,
|
||||||
|
"return_input": "true",
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
result = model.invoke(
|
||||||
|
model="mxbai-rerank-large-v1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("MIXEDBREAD_API_KEY"),
|
||||||
|
},
|
||||||
|
query="Who is Kasumi?",
|
||||||
|
docs=[
|
||||||
|
"Kasumi is a girl name of Japanese origin meaning mist.",
|
||||||
|
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
|
||||||
|
"PopiParty.",
|
||||||
|
],
|
||||||
|
score_threshold=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, RerankResult)
|
||||||
|
assert len(result.docs) == 1
|
||||||
|
assert result.docs[0].index == 0
|
||||||
|
assert result.docs[0].score >= 0.5
|
@ -0,0 +1,78 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.mixedbread.text_embedding.text_embedding import MixedBreadTextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
|
||||||
|
model = MixedBreadTextEmbeddingModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(model="mxbai-embed-large-v1", credentials={"api_key": "invalid_key"})
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"usage": {"prompt_tokens": 3, "total_tokens": 3},
|
||||||
|
"model": "mixedbread-ai/mxbai-embed-large-v1",
|
||||||
|
"data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}],
|
||||||
|
"object": "list",
|
||||||
|
"normalized": "true",
|
||||||
|
"encoding_format": "float",
|
||||||
|
"dimensions": 1024,
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
model.validate_credentials(
|
||||||
|
model="mxbai-embed-large-v1", credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = MixedBreadTextEmbeddingModel()
|
||||||
|
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"usage": {"prompt_tokens": 6, "total_tokens": 6},
|
||||||
|
"model": "mixedbread-ai/mxbai-embed-large-v1",
|
||||||
|
"data": [
|
||||||
|
{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"},
|
||||||
|
{"embedding": [0.23333 for _ in range(1024)], "index": 1, "object": "embedding"},
|
||||||
|
],
|
||||||
|
"object": "list",
|
||||||
|
"normalized": "true",
|
||||||
|
"encoding_format": "float",
|
||||||
|
"dimensions": 1024,
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
result = model.invoke(
|
||||||
|
model="mxbai-embed-large-v1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("MIXEDBREAD_API_KEY"),
|
||||||
|
},
|
||||||
|
texts=["hello", "world"],
|
||||||
|
user="abc-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, TextEmbeddingResult)
|
||||||
|
assert len(result.embeddings) == 2
|
||||||
|
assert result.usage.total_tokens == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_num_tokens():
|
||||||
|
model = MixedBreadTextEmbeddingModel()
|
||||||
|
|
||||||
|
num_tokens = model.get_num_tokens(
|
||||||
|
model="mxbai-embed-large-v1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("MIXEDBREAD_API_KEY"),
|
||||||
|
},
|
||||||
|
texts=["ping"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert num_tokens == 1
|
@ -8,4 +8,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
|
|||||||
api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \
|
api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \
|
||||||
api/tests/integration_tests/model_runtime/upstage \
|
api/tests/integration_tests/model_runtime/upstage \
|
||||||
api/tests/integration_tests/model_runtime/fireworks \
|
api/tests/integration_tests/model_runtime/fireworks \
|
||||||
api/tests/integration_tests/model_runtime/nomic
|
api/tests/integration_tests/model_runtime/nomic \
|
||||||
|
api/tests/integration_tests/model_runtime/mixedbread
|
||||||
|
Loading…
Reference in New Issue
Block a user