refactor: implement update_segment_by_id method for segment updates with validation and checks
This commit is contained in:
parent
eb22554448
commit
8203899907
@ -192,9 +192,13 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
SegmentService.segment_create_args_validate(args["segment"], document)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
updated_segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs(**args["segment"]),
|
||||
segment,
|
||||
document,
|
||||
dataset
|
||||
)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
class ChildChunkApi (DatasetApiResource):
|
||||
|
@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError, ProviderNotInitializeError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
@ -2180,6 +2180,69 @@ class SegmentService:
|
||||
|
||||
return segments, total
|
||||
|
||||
@classmethod
|
||||
def update_segment_by_id(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
segment_id: str,
|
||||
segment_data: dict,
|
||||
user_id: str
|
||||
) -> tuple[DocumentSegment, Document]:
|
||||
"""Update a segment by its ID with validation and checks."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
# check embedding model setting if high quality
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=user_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
# check segment
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == segment_id,
|
||||
DocumentSegment.tenant_id == user_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate and update segment
|
||||
cls.segment_create_args_validate(segment_data, document)
|
||||
updated_segment = cls.update_segment(
|
||||
SegmentUpdateArgs(**segment_data),
|
||||
segment,
|
||||
document,
|
||||
dataset
|
||||
)
|
||||
|
||||
return updated_segment, document
|
||||
|
||||
|
||||
class DatasetCollectionBindingService:
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user