Fix: style checks and unittests (#12603)
parent 04dade2f9b
commit c03adcb154
@@ -60,20 +60,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
             if response.status_code not in STATUS_FORCELIST:
                 return response
             else:
-                logging.warning(
-                    f"Received status code {response.status_code} for URL {url} which is in the force list")
+                logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
 
         except httpx.RequestError as e:
-            logging.warning(f"Request to URL {url} failed on attempt {
-                retries + 1}: {e}")
+            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
             if max_retries == 0:
                 raise
 
         retries += 1
         if retries <= max_retries:
             time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
-    raise MaxRetriesExceededError(
-        f"Reached maximum retries ({max_retries}) for URL {url}")
+    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
 
 
 def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
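
The make_request logic reformatted above is a capped retry loop with exponential backoff. A minimal self-contained sketch of the same pattern follows; the logic mirrors the hunk, while the constant values and the exception class body are assumptions, not part of this diff:

import logging
import time

import httpx

SSRF_DEFAULT_MAX_RETRIES = 3  # assumed default
BACKOFF_FACTOR = 0.5  # assumed value
STATUS_FORCELIST = {429, 500, 502, 503, 504}  # assumed force list


class MaxRetriesExceededError(Exception):
    pass


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    retries = 0
    while retries <= max_retries:
        try:
            response = httpx.request(method, url, **kwargs)
            if response.status_code not in STATUS_FORCELIST:
                return response
            logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
        except httpx.RequestError as e:
            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
            if max_retries == 0:
                raise
        retries += 1
        if retries <= max_retries:
            # wait 0.5s, 1s, 2s, ... between attempts
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
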
@@ -17,8 +17,7 @@ from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO,
-                    format="%(asctime)s - %(levelname)s - %(message)s")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)
 
 ROUTING_FIELD = "routing_field"
@@ -135,8 +134,7 @@ class LindormVectorStore(BaseVector):
                 self._client.delete(index=self._collection_name, id=id, params=params)
                 self.refresh()
             else:
-                logger.warning(
-                    f"DELETE BY ID: ID {id} does not exist in the index.")
+                logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
 
     def delete(self) -> None:
         if self._using_ugc:
@@ -147,8 +145,7 @@ class LindormVectorStore(BaseVector):
             self.refresh()
         else:
             if self._client.indices.exists(index=self._collection_name):
-                self._client.indices.delete(
-                    index=self._collection_name, params={"timeout": 60})
+                self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                 logger.info("Delete index success")
             else:
                 logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
@@ -171,14 +168,13 @@ class LindormVectorStore(BaseVector):
             raise ValueError("All elements in query_vector should be floats")
 
         top_k = kwargs.get("top_k", 10)
-        query = default_vector_search_query(
-            query_vector=query_vector, k=top_k, **kwargs)
+        query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
         try:
             params = {}
             if self._using_ugc:
                 params["routing"] = self._routing
             response = self._client.search(index=self._collection_name, body=query, params=params)
-        except Exception as e:
+        except Exception:
             logger.exception(f"Error executing vector search, query: {query}")
             raise
 
@@ -224,8 +220,7 @@ class LindormVectorStore(BaseVector):
             routing=routing,
             routing_field=self._routing_field,
         )
-        response = self._client.search(
-            index=self._collection_name, body=full_text_query)
+        response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
         for hit in response["hits"]["hits"]:
             docs.append(
@@ -243,8 +238,7 @@ class LindormVectorStore(BaseVector):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
-                logger.info(
-                    f"Collection {self._collection_name} already exists.")
+                logger.info(f"Collection {self._collection_name} already exists.")
                 return
             if self._client.indices.exists(index=self._collection_name):
                 logger.info(f"{self._collection_name.lower()} already exists.")
@@ -264,13 +258,10 @@ class LindormVectorStore(BaseVector):
             hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
             ivfpq_m = kwargs.pop("ivfpq_m", dimension)
             nlist = kwargs.pop("nlist", 1000)
-            centroids_use_hnsw = kwargs.pop(
-                "centroids_use_hnsw", True if nlist >= 5000 else False)
+            centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
             centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
-            centroids_hnsw_ef_construct = kwargs.pop(
-                "centroids_hnsw_ef_construct", 500)
-            centroids_hnsw_ef_search = kwargs.pop(
-                "centroids_hnsw_ef_search", 100)
+            centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
+            centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
             mapping = default_text_mapping(
                 dimension,
                 method_name,
@@ -290,8 +281,7 @@ class LindormVectorStore(BaseVector):
                 using_ugc=self._using_ugc,
                 **kwargs,
             )
-            self._client.indices.create(
-                index=self._collection_name.lower(), body=mapping)
+            self._client.indices.create(index=self._collection_name.lower(), body=mapping)
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
             # logger.info(f"create index success: {self._collection_name}")
 
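
The create-collection code in the two hunks above combines a Redis lock with a short-lived cache flag so that concurrent workers create the index at most once. A minimal sketch of that pattern, assuming a plain redis-py client and with the actual index-creation call stubbed out:

import redis

redis_client = redis.Redis()  # assumed connection


def create_collection(collection_name: str) -> None:
    lock_name = f"vector_indexing_lock_{collection_name}"  # assumed lock-name scheme
    with redis_client.lock(lock_name, timeout=20):
        collection_exist_cache_key = f"vector_indexing_{collection_name}"
        if redis_client.get(collection_exist_cache_key):
            return  # another worker created it within the last hour
        # ... check whether the index exists and create it here ...
        redis_client.set(collection_exist_cache_key, 1, ex=3600)

The lock serializes creation across workers, and the cache flag (expiring after 3600 seconds, as in the diff) lets later callers skip the exists/create work entirely.
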
@@ -396,8 +386,7 @@ def default_text_search_query(
     # build complex search_query when either of must/must_not/should/filter is specified
     if must:
         if not isinstance(must, list):
-            raise RuntimeError(
-                f"unexpected [must] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
         if query_clause not in must:
             must.append(query_clause)
     else:
@@ -407,22 +396,19 @@ def default_text_search_query(
 
     if must_not:
         if not isinstance(must_not, list):
-            raise RuntimeError(
-                f"unexpected [must_not] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
         boolean_query["must_not"] = must_not
 
     if should:
         if not isinstance(should, list):
-            raise RuntimeError(
-                f"unexpected [should] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
         boolean_query["should"] = should
         if minimum_should_match != 0:
             boolean_query["minimum_should_match"] = minimum_should_match
 
     if filters:
         if not isinstance(filters, list):
-            raise RuntimeError(
-                f"unexpected [filter] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
         boolean_query["filter"] = filters
 
     search_query = {"size": k, "query": {"bool": boolean_query}}
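
For orientation, default_text_search_query assembles an OpenSearch-style bool query from whichever optional clauses are passed in; the final line of the hunk wraps them as {"size": k, "query": {"bool": ...}}. An illustrative sketch of the assembled structure (field names and clause contents are invented for the example):

k = 4
boolean_query = {
    "must": [{"match": {"text": "example query"}}],  # the query clause plus any extra must clauses
    "must_not": [{"term": {"metadata.archived": True}}],
    "should": [{"term": {"metadata.lang": "en"}}],
    "minimum_should_match": 1,  # only set when non-zero
    "filter": [{"term": {"metadata.doc_id": "42"}}],
}
search_query = {"size": k, "query": {"bool": boolean_query}}
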
@@ -44,13 +44,11 @@ class QuestionClassifierNode(LLMNode):
         variable_pool = self.graph_runtime_state.variable_pool
 
         # extract variables
-        variable = variable_pool.get(
-            node_data.query_variable_selector) if node_data.query_variable_selector else None
+        variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
         query = variable.value if variable else None
         variables = {"query": query}
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(
-            node_data.model)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
         memory = self._fetch_memory(
             node_data_memory=node_data.memory,
@@ -58,8 +56,7 @@ class QuestionClassifierNode(LLMNode):
         )
         # fetch instruction
         node_data.instruction = node_data.instruction or ""
-        node_data.instruction = variable_pool.convert_template(
-            node_data.instruction).text
+        node_data.instruction = variable_pool.convert_template(node_data.instruction).text
 
         files = (
             self._fetch_files(
@@ -181,15 +178,12 @@ class QuestionClassifierNode(LLMNode):
         variable_mapping = {"query": node_data.query_variable_selector}
         variable_selectors = []
         if node_data.instruction:
-            variable_template_parser = VariableTemplateParser(
-                template=node_data.instruction)
-            variable_selectors.extend(
-                variable_template_parser.extract_variable_selectors())
+            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
-        variable_mapping = {node_id + "." + key: value for key,
-                            value in variable_mapping.items()}
+        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
 
         return variable_mapping
 
@@ -210,8 +204,7 @@ class QuestionClassifierNode(LLMNode):
         context: Optional[str],
     ) -> int:
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(
-            node_data, query, None, 2000)
+        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs={},
@@ -224,15 +217,13 @@ class QuestionClassifierNode(LLMNode):
         )
         rest_tokens = 2000
 
-        model_context_tokens = model_config.model_schema.model_properties.get(
-            ModelPropertyKey.CONTEXT_SIZE)
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
         if model_context_tokens:
             model_instance = ModelInstance(
                 provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
             )
 
-            curr_message_tokens = model_instance.get_llm_num_tokens(
-                prompt_messages)
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
 
             max_tokens = 0
             for parameter_rule in model_config.model_schema.parameter_rules:
@@ -273,8 +264,7 @@ class QuestionClassifierNode(LLMNode):
         prompt_messages: list[LLMNodeChatModelMessage] = []
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = LLMNodeChatModelMessage(
-                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(
-                    histories=memory_str)
+                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
             )
             prompt_messages.append(system_prompt_messages)
             user_prompt_message_1 = LLMNodeChatModelMessage(
@@ -315,5 +305,4 @@ class QuestionClassifierNode(LLMNode):
             )
 
         else:
-            raise InvalidModelTypeError(
-                f"Model mode {model_mode} not support.")
+            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
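
The token-budget hunk above computes how much of the context window remains for the classifier: the model's context size minus the tokens already in the prompt and the reserved completion budget. A hedged sketch of that arithmetic (names mirror the hunk; the zero floor is an assumption):

def calculate_rest_tokens(model_context_tokens: int, curr_message_tokens: int, max_tokens: int) -> int:
    # context window minus reserved completion tokens minus prompt tokens
    rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
    return max(rest_tokens, 0)  # never report a negative budget


# e.g. an 8192-token context, a 1200-token prompt, 512 tokens reserved for output
assert calculate_rest_tokens(8192, 1200, 512) == 6480
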
@@ -7,6 +7,12 @@ env =
     CODE_EXECUTION_API_KEY = dify-sandbox
     CODE_EXECUTION_ENDPOINT = http://127.0.0.1:8194
     CODE_MAX_STRING_LENGTH = 80000
+    PLUGIN_API_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
+    PLUGIN_DAEMON_URL=http://127.0.0.1:5002
+    PLUGIN_MAX_PACKAGE_SIZE=15728640
+    INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
+    MARKETPLACE_ENABLED=true
+    MARKETPLACE_API_URL=https://marketplace.dify.ai
     FIRECRAWL_API_KEY = fc-
     FIREWORKS_API_KEY = fw_aaaaaaaaaaaaaaaaaaaa
     GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
@@ -68,8 +68,7 @@ def test_executor_with_json_body_and_object_variable():
         system_variables={},
         user_inputs={},
     )
-    variable_pool.add(["pre_node_id", "object"], {
-        "name": "John Doe", "age": 30, "email": "john@example.com"})
+    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
 
     # Prepare the node data
    node_data = HttpRequestNodeData(
@@ -124,8 +123,7 @@ def test_executor_with_json_body_and_nested_object_variable():
         system_variables={},
         user_inputs={},
     )
-    variable_pool.add(["pre_node_id", "object"], {
-        "name": "John Doe", "age": 30, "email": "john@example.com"})
+    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
 
     # Prepare the node data
     node_data = HttpRequestNodeData(
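
The two tests above seed the pool with variable_pool.add(["pre_node_id", "object"], {...}), i.e. values are addressed by a node-id-plus-name selector path. A toy sketch of such a selector-keyed pool (purely illustrative, not Dify's actual VariablePool):

from typing import Any


class ToyVariablePool:
    def __init__(self) -> None:
        self._store: dict[tuple[str, ...], Any] = {}

    def add(self, selector: list[str], value: Any) -> None:
        self._store[tuple(selector)] = value

    def get(self, selector: list[str]) -> Any:
        return self._store.get(tuple(selector))


pool = ToyVariablePool()
pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
assert pool.get(["pre_node_id", "object"])["name"] == "John Doe"
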
@@ -18,14 +18,6 @@ from models.enums import UserFrom
 from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 
-def test_plain_text_to_dict():
-    assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
-    assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
-    assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
-    assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {
-        "aa": "bb", "cc": "dd"}
-
-
 def test_http_request_node_binary_file(monkeypatch):
     data = HttpRequestNodeData(
         title="test",
@@ -191,8 +183,7 @@ def test_http_request_node_form_with_file(monkeypatch):
 
     def attr_checker(*args, **kwargs):
         assert kwargs["data"] == {"name": "test"}
-        assert kwargs["files"] == {
-            "file": (None, b"test", "application/octet-stream")}
+        assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
         return httpx.Response(200, content=b"")
 
     monkeypatch.setattr(
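
The final (truncated) hunk stubs the HTTP layer with pytest's monkeypatch so attr_checker can assert on the outgoing data/files kwargs. A standalone sketch of the same technique (the patched target is an assumption; the real test patches whichever client call the node makes):

import httpx


def test_form_with_file(monkeypatch):
    def attr_checker(*args, **kwargs):
        assert kwargs["data"] == {"name": "test"}
        assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
        return httpx.Response(200, content=b"")

    # replace the real network call with the checker (assumed target)
    monkeypatch.setattr(httpx, "request", attr_checker)

    response = httpx.request("POST", "http://example.com", data={"name": "test"},
                             files={"file": (None, b"test", "application/octet-stream")})
    assert response.status_code == 200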