refactor: implement update_segment_by_id method for segment updates with validation and checks

This commit is contained in:
ZeroZ_JQ 2025-03-18 17:57:52 +08:00
parent eb22554448
commit 8203899907
2 changed files with 71 additions and 4 deletions

View File

@ -192,9 +192,13 @@ class DatasetSegmentApi(DatasetApiResource):
parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
SegmentService.segment_create_args_validate(args["segment"], document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]),
segment,
document,
dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
class ChildChunkApi (DatasetApiResource):

View File

@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError, ProviderNotInitializeError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
@ -2180,6 +2180,69 @@ class SegmentService:
return segments, total
@classmethod
def update_segment_by_id(
cls,
tenant_id: str,
dataset_id: str,
document_id: str,
segment_id: str,
segment_data: dict,
user_id: str
) -> tuple[DocumentSegment, Document]:
"""Update a segment by its ID with validation and checks."""
# check dataset
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check embedding model setting if high quality
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=user_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == segment_id,
DocumentSegment.tenant_id == user_id
).first()
if not segment:
raise NotFound("Segment not found.")
# validate and update segment
cls.segment_create_args_validate(segment_data, document)
updated_segment = cls.update_segment(
SegmentUpdateArgs(**segment_data),
segment,
document,
dataset
)
return updated_segment, document
class DatasetCollectionBindingService:
@classmethod