from typing import Annotated, Optional
import numpy as np
from flexrag.models import ENCODERS, EncoderConfig
from flexrag.models.tokenizer import TOKENIZERS, TokenizerConfig
from flexrag.utils import LOGGER_MANAGER, Choices, configure
from .chunker_base import CHUNKERS, Chunk, ChunkerBase
from .sentence_splitter import SENTENCE_SPLITTERS, SentenceSplitterConfig
logger = LOGGER_MANAGER.get_logger("flexrag.chunking.semantic_chunker")
[docs]
@configure
class SemanticChunkerConfig(SentenceSplitterConfig, EncoderConfig, TokenizerConfig):
"""Configuration for SemanticChunker.
:param max_tokens: The maximum number of tokens in each chunk. Default is None.
:type max_tokens: Optional[int]
:param threshold: The threshold for semantic similarity. Default is None.
If provided, the `threshold_percentile` and `max_tokens` will be ignored.
:type threshold: Optional[float]
:param threshold_percentile: The ratio of the threshold for semantic similarity. Default is None.
Should be a value between 0 and 100. Higher values will result in more chunks. 5 is a good starting point.
If provided, the `max_tokens` will be ignored.
:type threshold_percentile: Optional[float]
:param similarity_window: The window size for calculating semantic similarity. Default is None.
:type similarity_window: Optional[int]
:param similarity_function: The similarity function to use. Default is "COS".
Available choices are "L2" for the reciprocal of euclidean distance, "IP" for inner product, and "COS" for cosine similarity.
:type similarity_function: str
The similarity higher than the threshold will be considered as coherent, and the chunks will be split at the points where the similarity is below the threshold.
Thus, at least one of `max_tokens`, `threshold`, or `threshold_percentile` should be provided.
If `threshold` is provided, the chunks will be split directly based on the threshold.
If `threshold_percentile` is provided, the threshold will be calculated automatically based on the similarity distribution.
If `max_tokens` is provided, the threshold will be calculated to ensure the chunks are within the token limit.
For example, to split the text into chunks with a maximum of 512 tokens, you can use the following configuration:
>>> from flexrag.chunking import SemanticChunker, SemanticChunkerConfig
>>> from flexrag.models import HFEncoderConfig
>>> config = SemanticChunkerConfig(
... max_tokens=512,
... encoder_type="hf",
... hf_config=HFEncoderConfig(model_path="BAAI/bge-small-en-v1.5"),
... )
>>> chunker = SemanticChunker(config)
To split the text into chunks with a threshold_percentile of 5%, you can use the following configuration:
>>> config = SemanticChunkerConfig(
... threshold_percentile=5,
... encoder_type="hf",
... hf_config=HFEncoderConfig(model_path="BAAI/bge-small-en-v1.5"),
... )
>>> chunker = SemanticChunker(config)
To split the text into chunks with a given threshold, you can use the following configuration:
>>> config = SemanticChunkerConfig(
... threshold=0.8,
... encoder_type="hf",
... hf_config=HFEncoderConfig(model_path="BAAI/bge-small-en-v1.5"),
... )
>>> chunker = SemanticChunker(config)
"""
max_tokens: Optional[int] = None
max_tokens_per_sentence: Optional[int] = None
threshold: Optional[float] = None
threshold_percentile: Optional[float] = None
similarity_window: int = 1
similarity_function: Annotated[str, Choices("L2", "IP", "COS")] = "COS"
[docs]
@CHUNKERS("semantic_chunker", config_class=SemanticChunkerConfig)
class SemanticChunker(ChunkerBase):
"""SemanticChunker splits text into sentences and then groups them into chunks based on semantic similarity.
This chunker is inspired by the Greg Kamradt's wonderful notebook:
https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb
"""
def __init__(self, cfg: SemanticChunkerConfig) -> None:
# set the basic configurations
self.max_tokens = cfg.max_tokens if cfg.max_tokens is not None else float("inf")
self.threshold = cfg.threshold
self.similarity_window = cfg.similarity_window
self.threshold_percentile = cfg.threshold_percentile
self.similarity_function = cfg.similarity_function
if cfg.max_tokens_per_sentence is not None:
self.max_tokens_per_sentence = cfg.max_tokens_per_sentence
elif self.max_tokens is not None:
self.max_tokens_per_sentence = cfg.max_tokens
else:
self.max_tokens_per_sentence = None
# load the sentence splitter
self.splitter = SENTENCE_SPLITTERS.load(cfg)
# load the encoder
self.encoder = ENCODERS.load(cfg)
# load the tokenizer
self.tokenizer = TOKENIZERS.load(cfg)
return
[docs]
def chunk(self, text: str, return_str: bool = False) -> list[Chunk]:
# split the text into sentences
sentences = self._split_sentences(text)
if len(sentences) == 1:
chunks = [Chunk(text, 0, len(text))]
if return_str:
return [chunk.text for chunk in chunks]
return chunks
# combine the sentences to calculate the embeddings
combined_sentences = []
for i in range(len(sentences)):
combined_sentences.append(
" ".join(
sentences[
max(0, i - self.similarity_window) : i
+ self.similarity_window
+ 1
]
)
)
embeddings = self.encoder.encode(combined_sentences)
# calculate the similarity between the combined sentences
emb1 = embeddings[1:]
emb2 = embeddings[:-1]
match self.similarity_function:
case "L2":
similarity = 1 / np.linalg.norm(emb1 - emb2, axis=1)
case "IP":
similarity = np.einsum("ij,ij->i", emb1, emb2)
case "COS":
similarity = np.einsum("ij,ij->i", emb1, emb2) / (
np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1)
)
case _:
raise ValueError(
f"Unknown similarity function: {self.similarity_function}"
)
# calculate the threshold
if self.threshold is not None:
threshold = self.threshold
elif self.threshold_percentile is not None:
threshold = np.percentile(similarity, self.threshold_percentile)
else:
assert (
self.max_tokens is not None
), "At least one of max_tokens, threshold, or threshold_percentile should be provided."
threshold = None
# group the sentences into chunks based on the threshold
if threshold is not None:
chunks = self._group_sentences(sentences, similarity, threshold)
else:
# check if the max_tokens is feasible
max_tokens = max(len(self.tokenizer.tokenize(sent)) for sent in sentences)
if max_tokens == self.max_tokens:
chunks = sentences
chunks = self._form_chunks(chunks)
if return_str:
return [chunk.text for chunk in chunks]
return chunks
elif max_tokens > self.max_tokens:
logger.warning(
f"The maximum number of tokens in a sentence is {max_tokens}, "
f"which is greater than the specified max_tokens {self.max_tokens}."
)
chunks = sentences
chunks = self._form_chunks(chunks)
if return_str:
return [chunk.text for chunk in chunks]
return chunks
# try to find the threshold that best fits the max_tokens
thresholds = np.sort(similarity)
left_pointer = 0
right_threshold = len(thresholds) - 1
while True:
mid_pointer = (left_pointer + right_threshold) // 2
threshold = thresholds[mid_pointer] + 1e-6
chunks_ = self._group_sentences(sentences, similarity, threshold)
max_tokens = max(
len(self.tokenizer.tokenize(chunk)) for chunk in chunks_
)
if left_pointer >= right_threshold:
break
if max_tokens > self.max_tokens:
left_pointer = mid_pointer + 1
else:
right_threshold = mid_pointer
# use the last threshold that fits the max_tokens
if max_tokens > self.max_tokens:
if mid_pointer + 1 < len(thresholds):
threshold = thresholds[mid_pointer + 1] + 1e-6
else:
logger.warning("Cannot find a suitable threshold.")
threshold = thresholds[mid_pointer] + 1e-6
else:
threshold = thresholds[mid_pointer] + 1e-6
chunks = self._group_sentences(sentences, similarity, threshold)
chunks = self._form_chunks(chunks)
if return_str:
return [chunk.text for chunk in chunks]
return chunks
def _group_sentences(
self, sentences: list[str], similarity: np.ndarray, threshold: float
) -> list[str]:
chunks = []
chunk = sentences[0]
for i in range(1, len(sentences)):
if similarity[i - 1] < threshold:
chunks.append(chunk)
chunk = sentences[i]
else:
if self.splitter.reversible:
chunk += sentences[i]
else:
chunk += " " + sentences[i]
chunks.append(chunk)
return chunks
def _split_sentences(self, text: str) -> list[str]:
"""Split the text into sentences."""
sents = self.splitter.split(text)
if self.max_tokens_per_sentence is None:
return sents
new_sents = []
for sent in sents:
tokens = self.tokenizer.tokenize(sent)
if len(tokens) <= self.max_tokens_per_sentence:
new_sents.append(sent)
continue
splitted_sents = []
for i in range(0, len(tokens), self.max_tokens_per_sentence):
splitted_sents.append(
self.tokenizer.detokenize(
tokens[i : i + self.max_tokens_per_sentence]
)
)
new_sents.extend(splitted_sents)
return new_sents
def _form_chunks(self, texts: list[str]) -> list[Chunk]:
chunks = []
current_index = 0
for text in texts:
if self.splitter.reversible:
chunks.append(
Chunk(
text=text,
start=current_index,
end=current_index + len(text),
)
)
current_index += len(text)
else:
chunks.append(Chunk(text=text))
return chunks