diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index e25b08ff8a..19a6700318 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -14,7 +14,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields -from models.dataset import ChildChunk, Dataset +from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs from services.errors.chunk import ( @@ -336,9 +336,10 @@ class DatasetChildChunkApi(DatasetApiResource): # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ChildChunk.query.filter( - ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id - ).first() + child_chunk = SegmentService.get_child_chunk_by_id( + child_chunk_id=child_chunk_id, + tenant_id=current_user.current_tenant_id + ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -376,9 +377,10 @@ class DatasetChildChunkApi(DatasetApiResource): # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ChildChunk.query.filter( - ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id - ).first() + child_chunk = SegmentService.get_child_chunk_by_id( + child_chunk_id=child_chunk_id, + tenant_id=current_user.current_tenant_id + ) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2a36990e2a..261aedce8a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2155,6 +2155,14 @@ class SegmentService: query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + @classmethod + def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: + """Get a child chunk by its ID.""" + return ChildChunk.query.filter( + ChildChunk.id == child_chunk_id, + ChildChunk.tenant_id == tenant_id + ).first() + @classmethod def get_segments( cls,