# Source code for flexrag.text_process.basic_processors
import re
import string
from dataclasses import field
from typing import Optional
from flexrag.utils import configure
from .processor import PROCESSORS, Processor, TextUnit
from .utils import UnifiedTokenizer, UTokenizerConfig
[文档]
@configure
class TokenNormalizerConfig:
    """Configuration for :class:`TokenNormalizer`.

    Each field is passed as the keyword argument of the same name to
    ``sacremoses.MosesPunctNormalizer`` when the normalizer is built.
    """

    # language code handed to the Moses punctuation normalizer
    lang: str = "en"
    # Moses normalizer switches, forwarded unchanged
    penn: bool = True
    norm_quote_commas: bool = True
    norm_numbers: bool = True
    pre_replace_unicode_punct: bool = False
    post_remove_control_chars: bool = False
    perl_parity: bool = False
[文档]
@PROCESSORS("token_normalize", config_class=TokenNormalizerConfig)
class TokenNormalizer(Processor):
    """Normalize punctuation with ``sacremoses.MosesPunctNormalizer``."""

    def __init__(self, cfg: TokenNormalizerConfig) -> None:
        """Build the Moses punctuation normalizer from ``cfg``.

        :param cfg: options forwarded one-to-one to ``MosesPunctNormalizer``.
        """
        # Deferred import: sacremoses is only needed once this processor is created.
        from sacremoses import MosesPunctNormalizer

        self.normalizer = MosesPunctNormalizer(
            lang=cfg.lang,
            penn=cfg.penn,
            norm_quote_commas=cfg.norm_quote_commas,
            norm_numbers=cfg.norm_numbers,
            pre_replace_unicode_punct=cfg.pre_replace_unicode_punct,
            post_remove_control_chars=cfg.post_remove_control_chars,
            perl_parity=cfg.perl_parity,
        )

    def process(self, input_text: TextUnit) -> TextUnit:
        """Replace ``input_text.content`` with its normalized form and return it."""
        normalized = self.normalizer.normalize(input_text.content)
        input_text.content = normalized
        return input_text
[文档]
@PROCESSORS("chinese_simplify")
class ChineseSimplifier(Processor):
    """Convert Chinese text with an ``opencc.OpenCC`` converter.

    The converter is created with OpenCC's default configuration.
    NOTE(review): given the registry name, that default is presumably
    traditional -> simplified ("t2s"); confirm for the installed package.
    """

    def __init__(self):
        # Deferred import: opencc is only needed once this processor is created.
        import opencc

        self.converter = opencc.OpenCC()

    def process(self, input_text: TextUnit) -> TextUnit:
        """Convert ``input_text.content`` in place and return ``input_text``."""
        converted = self.converter.convert(input_text.content)
        input_text.content = converted
        return input_text
[文档]
@PROCESSORS("lowercase")
class Lowercase(Processor):
    """Lower-case the text content."""

    def process(self, input_text: TextUnit) -> TextUnit:
        """Lower-case ``input_text.content`` in place and return ``input_text``."""
        lowered = input_text.content.lower()
        input_text.content = lowered
        return input_text
[文档]
@PROCESSORS("unify")
class Unifier(Processor):
    """Transliterate text with ``unidecode``."""

    def __init__(self) -> None:
        # Deferred import: unidecode is only needed once this processor is created.
        from unidecode import unidecode as _unidecode

        self.unidecode = _unidecode

    def process(self, input_text: TextUnit) -> TextUnit:
        """Replace ``input_text.content`` with its ``unidecode`` form and return it."""
        transliterated = self.unidecode(input_text.content)
        input_text.content = transliterated
        return input_text
[文档]
@configure
class TruncatorConfig:
    """Configuration for :class:`Truncator`; a limit left as ``None`` is skipped.

    :param max_chars: maximum number of characters to keep.
    :param max_bytes: maximum number of UTF-8 encoded bytes to keep.
    :param max_tokens: maximum number of tokens to keep, counted with the
        tokenizer described by ``tokenizer_config``.
    :param tokenizer_config: tokenizer settings, only used for ``max_tokens``.
    """

    max_chars: Optional[int] = None
    max_bytes: Optional[int] = None
    max_tokens: Optional[int] = None
    # default_factory so every config gets its own tokenizer-config instance
    tokenizer_config: UTokenizerConfig = field(default_factory=UTokenizerConfig)
[文档]
@PROCESSORS("truncate", config_class=TruncatorConfig)
class Truncator(Processor):
    """Truncate text by tokens, characters and/or UTF-8 bytes.

    Limits are applied in that order: tokens, then characters, then bytes.
    A limit that is ``None`` is skipped.
    """

    def __init__(self, cfg: TruncatorConfig) -> None:
        """Build the truncator from ``cfg``.

        :param cfg: truncation limits and tokenizer settings.
        :raises ValueError: if ``max_chars``, ``max_bytes`` and ``max_tokens``
            are all ``None``.
        """
        self.max_chars = cfg.max_chars
        self.max_bytes = cfg.max_bytes
        self.max_tokens = cfg.max_tokens
        # Any single limit is sufficient, including a token-only configuration.
        # Raised explicitly (not `assert`) so the check survives `python -O`.
        if self.max_chars is None and self.max_bytes is None and self.max_tokens is None:
            raise ValueError(
                "At least one of max_chars, max_bytes and max_tokens should be set"
            )
        if self.max_tokens is not None:
            self.tokenizer = UnifiedTokenizer(tokenizer_type=cfg.tokenizer_config)

    def process(self, input_text: TextUnit) -> TextUnit:
        """Apply the configured limits to ``input_text.content`` in place.

        :param input_text: unit whose ``content`` is truncated.
        :returns: the same ``input_text`` object.
        """
        if self.max_tokens is not None:
            tokens = self.tokenizer.tokenize(input_text.content)
            # Only round-trip through the tokenizer when something must be cut:
            # detokenize(tokenize(x)) may not reproduce x exactly.
            if len(tokens) > self.max_tokens:
                input_text.content = self.tokenizer.detokenize(tokens[: self.max_tokens])
        if self.max_chars is not None:
            input_text.content = input_text.content[: self.max_chars]
        if self.max_bytes is not None:
            input_bytes = input_text.content.encode("utf-8")
            if len(input_bytes) > self.max_bytes:
                # A byte cut may split a multi-byte character at the end;
                # errors="ignore" drops that incomplete trailing sequence. The
                # bytes come from str.encode, so nothing else can be invalid.
                input_text.content = input_bytes[: self.max_bytes].decode(
                    "utf-8", errors="ignore"
                )
        return input_text
[文档]
@PROCESSORS("simplify_answer")
class AnswerSimplifier(Processor):
    """Normalize an answer string for lenient comparison.

    Lower-cases the text, removes the English articles ``a``/``an``/``the``,
    strips ASCII punctuation (``string.punctuation``) and finally collapses
    all whitespace runs into single spaces (no leading/trailing spaces).
    """

    # Built once at class creation instead of on every call.
    _ARTICLES_RE = re.compile(r"\b(a|an|the)\b")
    _PUNCT_TABLE = str.maketrans("", "", string.punctuation)

    def process(self, input_text: TextUnit) -> TextUnit:
        """Simplify ``input_text.content`` in place and return ``input_text``."""
        text = input_text.content.lower()
        text = self._ARTICLES_RE.sub(" ", text)
        text = text.translate(self._PUNCT_TABLE)
        # Collapse whitespace last: removing punctuation can leave double,
        # leading or trailing spaces (e.g. "hello , world" -> "hello  world").
        input_text.content = " ".join(text.split())
        return input_text