Add embedding models in fireworks provider (#8728)
Parent: 4669eb24be
Commit: 91f70d0bd9
@@ -15,6 +15,7 @@ help:
     en_US: https://fireworks.ai/account/api-keys
 supported_model_types:
   - llm
+  - text-embedding
 configurate_methods:
   - predefined-model
 provider_credential_schema:
@@ -0,0 +1,12 @@
model: WhereIsAI/UAE-Large-V1
label:
  zh_Hans: UAE-Large-V1
  en_US: UAE-Large-V1
model_type: text-embedding
model_properties:
  context_size: 512
  max_chunks: 1
pricing:
  input: '0.008'
  unit: '0.000001'
  currency: 'USD'
@@ -0,0 +1,12 @@
model: thenlper/gte-base
label:
  zh_Hans: GTE-base
  en_US: GTE-base
model_type: text-embedding
model_properties:
  context_size: 512
  max_chunks: 1
pricing:
  input: '0.008'
  unit: '0.000001'
  currency: 'USD'
@@ -0,0 +1,12 @@
model: thenlper/gte-large
label:
  zh_Hans: GTE-large
  en_US: GTE-large
model_type: text-embedding
model_properties:
  context_size: 512
  max_chunks: 1
pricing:
  input: '0.008'
  unit: '0.000001'
  currency: 'USD'
@@ -0,0 +1,12 @@
model: nomic-ai/nomic-embed-text-v1.5
label:
  zh_Hans: nomic-embed-text-v1.5
  en_US: nomic-embed-text-v1.5
model_type: text-embedding
model_properties:
  context_size: 8192
  max_chunks: 16
pricing:
  input: '0.008'
  unit: '0.000001'
  currency: 'USD'
@@ -0,0 +1,12 @@
model: nomic-ai/nomic-embed-text-v1
label:
  zh_Hans: nomic-embed-text-v1
  en_US: nomic-embed-text-v1
model_type: text-embedding
model_properties:
  context_size: 8192
  max_chunks: 16
pricing:
  input: '0.008'
  unit: '0.000001'
  currency: 'USD'
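The model_properties above drive how the implementation that follows truncates and batches input: context_size caps the approximate token count per text, and max_chunks caps how many texts go into a single embeddings request. A minimal standalone sketch of that batching arithmetic; the helper name here is illustrative and not part of the runtime API:

def batch_texts(texts: list[str], max_chunks: int) -> list[list[str]]:
    # Split the inputs into request-sized batches of at most max_chunks texts each,
    # mirroring the range(0, len(inputs), max_chunks) loop in _invoke below.
    return [texts[i : i + max_chunks] for i in range(0, len(texts), max_chunks)]

# nomic-embed-text-v1/v1.5 accept 16 texts per request; the 512-context models accept only 1.
assert len(batch_texts(["doc"] * 20, max_chunks=16)) == 2
assert len(batch_texts(["doc"] * 20, max_chunks=1)) == 20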
@@ -0,0 +1,151 @@
import time
from collections.abc import Mapping
from typing import Optional, Union

import numpy as np
from openai import OpenAI

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks


class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
    """
    Model class for Fireworks text embedding model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        """

        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)

        extra_model_kwargs = {}
        if user:
            extra_model_kwargs["user"] = user

        extra_model_kwargs["encoding_format"] = "float"

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        for i, text in enumerate(texts):
            # Here token count is only an approximation based on the GPT2 tokenizer
            # TODO: Optimize for better token estimation and chunking
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        for i in _iter:
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                client=client,
                texts=inputs[i : i + max_chunks],
                extra_model_kwargs=extra_model_kwargs,
            )
            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
        return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # transform credentials to kwargs for model instance
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)

            # call embedding model
            self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(
        self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
    ) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model
        :param model: model name
        :param client: model client
        :param texts: texts to embed
        :param extra_model_kwargs: extra model kwargs
        :return: embeddings and used tokens
        """
        response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs)
        return [data.embedding for data in response.data], response.usage.total_tokens

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        input_price_info = self.get_price(
            model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT
        )

        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage
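For orientation, the class above is a thin wrapper over Fireworks' OpenAI-compatible embeddings endpoint. A minimal sketch of the equivalent direct call follows; the credential mapping (fireworks_api_key to api_key) and the base URL are assumptions here, since the actual mapping lives in _CommonFireworks._to_credential_kwargs:

from openai import OpenAI

# Assumptions: the key comes from the provider credential "fireworks_api_key",
# and the base URL below is Fireworks' OpenAI-compatible inference endpoint.
client = OpenAI(
    api_key="YOUR_FIREWORKS_API_KEY",
    base_url="https://api.fireworks.ai/inference/v1",
)

response = client.embeddings.create(
    model="nomic-ai/nomic-embed-text-v1.5",
    input=["hello", "world"],
    encoding_format="float",  # the wrapper always requests float encoding
)

vectors = [item.embedding for item in response.data]
print(len(vectors), len(vectors[0]), response.usage.total_tokens)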
@@ -0,0 +1,54 @@
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fireworks.text_embedding.text_embedding import FireworksTextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
    model = FireworksTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": "invalid_key"}
        )

    model.validate_credentials(
        model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}
    )


@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
    model = FireworksTextEmbeddingModel()

    result = model.invoke(
        model="nomic-ai/nomic-embed-text-v1.5",
        credentials={
            "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
        },
        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
        user="foo",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 4
    assert result.usage.total_tokens == 2


def test_get_num_tokens():
    model = FireworksTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="nomic-ai/nomic-embed-text-v1.5",
        credentials={
            "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
        },
        texts=["hello", "world"],
    )

    assert num_tokens == 2
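These integration tests read the Fireworks key from the environment. A minimal sketch for running just this module locally; the file path is an assumption inferred from the import path above:

import os
import sys

import pytest

# Assumed location, mirroring the tests.integration_tests... import path above.
TEST_FILE = "tests/integration_tests/model_runtime/fireworks/test_text_embedding.py"

if not os.environ.get("FIREWORKS_API_KEY"):
    sys.exit("Set FIREWORKS_API_KEY before running these integration tests.")

sys.exit(pytest.main(["-v", TEST_FILE]))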