diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 5d34c80113..67f9b6384d 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -76,16 +76,20 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) def recursive_split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" + final_chunks = [] - # Get appropriate separator to use separator = self._separators[-1] - for _s in self._separators: + new_separators = [] + + for i, _s in enumerate(self._separators): if _s == "": separator = _s break if _s in text: separator = _s + new_separators = self._separators[i + 1 :] break + # Now that we have the separator, split the text if separator: if separator == " ": @@ -94,23 +98,52 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = text.split(separator) else: splits = list(text) - # Now go merging things, recursively splitting longer texts. + splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits + _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) - for s, s_len in zip(splits, s_lens): - if s_len < self._chunk_size: - _good_splits.append(s) - _good_splits_lengths.append(s_len) - else: - if _good_splits: - merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) - final_chunks.extend(merged_text) - _good_splits = [] - _good_splits_lengths = [] - other_info = self.recursive_split_text(s) - final_chunks.extend(other_info) - if _good_splits: - merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) - final_chunks.extend(merged_text) + if _separator != "": + for s, s_len in zip(splits, s_lens): + if s_len < self._chunk_size: + _good_splits.append(s) + _good_splits_lengths.append(s_len) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) + final_chunks.extend(merged_text) + _good_splits = [] + _good_splits_lengths = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) + final_chunks.extend(merged_text) + else: + current_part = "" + current_length = 0 + overlap_part = "" + overlap_part_length = 0 + for s, s_len in zip(splits, s_lens): + if current_length + s_len <= self._chunk_size - self._chunk_overlap: + current_part += s + current_length += s_len + elif current_length + s_len <= self._chunk_size: + current_part += s + current_length += s_len + overlap_part += s + overlap_part_length += s_len + else: + final_chunks.append(current_part) + current_part = overlap_part + s + current_length = s_len + overlap_part_length + overlap_part = "" + overlap_part_length = 0 + if current_part: + final_chunks.append(current_part) + return final_chunks