# Source code for flexrag.context_refine.summarizer
import re
from copy import deepcopy
from string import Template
from typing import Optional
import numpy as np
from flexrag.models import ENCODERS, GENERATORS, EncoderConfig, GeneratorConfig
from flexrag.prompt import ChatPrompt, ChatTurn
from flexrag.utils import TIME_METER, configure
from flexrag.utils.dataclasses import RetrievedContext
from .refiner import REFINERS, RefinerBase
# [docs]
@configure
class AbstractiveSummarizerConfig(GeneratorConfig):
    """The configuration for the ``AbstractiveSummarizer``.

    :param template: The template used to form the input text for the generator. Defaults to None.
        The template should follow the Python ``string.Template`` syntax.
        The supported keys for the template are: [content, query].
    :type template: Optional[str]
    :param chat_prompt: The chat prompt for the generator. Defaults to None.
        Only used when the generator is a chat-based generator.
    :type chat_prompt: Optional[ChatPrompt]
    :param substitute: Whether to substitute the original text with the summary. Defaults to True.
        If False, the summary will be stored in a new field named as refined_field + "_summary".
    :type substitute: bool
    :param concatenate_contexts: Whether to concatenate the contexts into one text. Defaults to False.
    :type concatenate_contexts: bool
    :param refined_field: The field to refine. Required: although the default is None,
        the ``AbstractiveSummarizer`` refuses to be built without it.
    :type refined_field: str

    The ``AbstractiveSummarizer`` supports multiple styles of summarizers, including T5, RECOMP, and LLM.

    For example, to summarize the contexts using a `T5 style summarizer <https://arxiv.org/abs/1910.10683>`_,
    you can run the following code:

    .. code-block:: python

        cfg = AbstractiveSummarizerConfig(
            template="summarize: ${content}",
            generator_type="hf",
            refined_field="text",
            hf_config=HFGeneratorConfig(
                model_path="google-t5/t5-small",
                model_type="seq2seq",
            )
        )
        summarizer = AbstractiveSummarizer(cfg)

    To summarize the contexts using a `RECOMP style summarizer <https://arxiv.org/abs/2310.04408>`_,
    you can run the following code:

    .. code-block:: python

        cfg = AbstractiveSummarizerConfig(
            template="Question: ${query}\\n Document: ${content}\\n Summary: ",
            generator_type="hf",
            refined_field="text",
            hf_config=HFGeneratorConfig(
                model_path="fangyuan/hotpotqa_abstractive_compressor",
                model_type="seq2seq",
            )
        )
        summarizer = AbstractiveSummarizer(cfg)

    To summarize the contexts using a `LLM style summarizer <https://arxiv.org/abs/2203.02155>`_,
    you can run the following code:

    .. code-block:: python

        cfg = AbstractiveSummarizerConfig(
            refined_field="text",
            template="Query: ${query}\\nText: ${content}",
            chat_prompt=ChatPrompt(
                system="You are a skillful summarizer. Please summarize the following text based on given query.",
            ),
            generator_type="openai",
            openai_config=OpenAIGeneratorConfig(api_key=api_key, model_name="gpt-3.5-turbo")
        )
        summarizer = AbstractiveSummarizer(cfg)
    """

    template: Optional[str] = None
    chat_prompt: Optional[ChatPrompt] = None
    substitute: bool = True
    concatenate_contexts: bool = False
    refined_field: Optional[str] = None
# [docs]
@REFINERS("abstractive_summarizer", config_class=AbstractiveSummarizerConfig)
class AbstractiveSummarizer(RefinerBase):
    """The ``AbstractiveSummarizer`` summarizes the contexts using a generator.

    The text stored under ``refined_field`` of each context is (optionally)
    rendered through a ``string.Template`` and sent to the generator, either
    as a plain prompt or as a user turn appended to a copy of ``chat_prompt``.
    """

    def __init__(self, cfg: AbstractiveSummarizerConfig):
        """Load the generator and cache the summarization options.

        :param cfg: The configuration of the summarizer.
        :type cfg: AbstractiveSummarizerConfig
        :raises AssertionError: If ``cfg.refined_field`` is None.
        """
        super().__init__(cfg)
        self.model = GENERATORS.load(cfg)
        self.template = Template(cfg.template) if cfg.template is not None else None
        self.chat_prompt = cfg.chat_prompt
        self.substitute = cfg.substitute
        self.concatenate = cfg.concatenate_contexts
        assert cfg.refined_field is not None, "The refined_field must be provided."
        self.refined_field = cfg.refined_field
        return

    @TIME_METER("abstractive_summarize")
    def refine(self, contexts: list[RetrievedContext]) -> list[RetrievedContext]:
        """Summarize the ``refined_field`` of the given contexts.

        The input contexts are never modified; deep copies are returned.

        :param contexts: The contexts to summarize.
        :type contexts: list[RetrievedContext]
        :return: Copies of the contexts carrying the summaries. When
            ``concatenate_contexts`` is enabled a single summary is generated,
            so only a copy of the first context is returned. An empty input
            yields an empty list.
        :rtype: list[RetrievedContext]
        :raises AssertionError: If ``concatenate_contexts`` is enabled and the
            contexts do not all share the same query.
        """
        # Nothing to summarize. This also protects the ``contexts[0]`` accesses
        # below and avoids sending an empty batch to the generator.
        if not contexts:
            return []

        # prepare input texts
        if self.concatenate:
            first_query = contexts[0].query
            assert all(
                first_query == context.query for context in contexts
            ), "All queries should be the same."
            args = [
                {
                    "content": " ".join(
                        context.data[self.refined_field] for context in contexts
                    ),
                    "query": first_query,
                }
            ]
        else:
            args = [
                {"content": context.data[self.refined_field], "query": context.query}
                for context in contexts
            ]

        if self.template is not None:
            # safe_substitute leaves unknown placeholders in place instead of raising.
            input_texts = [self.template.safe_substitute(**arg) for arg in args]
        else:
            input_texts = [arg["content"] for arg in args]

        # generate summaries
        if self.chat_prompt is not None:
            input_prompts = []
            for text in input_texts:
                # Copy so the shared base prompt is not mutated across calls.
                prompt = deepcopy(self.chat_prompt)
                prompt.update(ChatTurn(role="user", content=text))
                input_prompts.append(prompt)
            outputs = self.model.chat(input_prompts)
        else:
            outputs = self.model.generate(input_texts)
        # Only the first element of each per-input result is used as the summary.
        summaries = [output[0] for output in outputs]

        # update contexts
        summary_field = self.refined_field + "_summary"
        new_contexts = []
        # zip stops at the shorter sequence: in concatenate mode there is a
        # single summary, hence a single (first) context in the output.
        for context, summary in zip(contexts, summaries):
            context = deepcopy(context)
            if self.substitute:
                context.data[self.refined_field] = summary
            else:
                if self.concatenate:
                    # Keep the text that was actually summarized next to its summary.
                    context.data[self.refined_field] = args[0]["content"]
                context.data[summary_field] = summary
            new_contexts.append(context)
        return new_contexts
# [docs]
@configure
class RecompExtractiveSummarizerConfig(EncoderConfig):
    """Options controlling the ``RecompExtractiveSummarizer``.

    :param preserved_sents: How many sentences are kept in the summary. Defaults to 5.
    :type preserved_sents: int
    :param concatenate_contexts: If True, all contexts are joined into a single text
        before summarization. Defaults to False.
    :type concatenate_contexts: bool
    :param substitute: If True, the summary replaces the original text;
        otherwise it is written to ``refined_field + "_summary"``. Defaults to False.
    :type substitute: bool
    :param refined_field: Name of the context field to summarize. Required.
    :type refined_field: str

    The ``RecompExtractiveSummarizer`` is motivated by
    `RECOMP <https://arxiv.org/abs/2310.04408>`_.

    The snippet below loads a summarizer trained on the hotpotqa dataset:

    .. code-block:: python

        cfg = RecompExtractiveSummarizerConfig(
            encoder_type="hf",
            hf_config=HFEncoderConfig(
                model_path="fangyuan/hotpotqa_extractive_compressor",
            ),
            preserved_sents=5,
            refined_field="text",
        )
        summarizer = RecompExtractiveSummarizer(cfg)
    """

    preserved_sents: int = 5
    concatenate_contexts: bool = False
    substitute: bool = False
    refined_field: Optional[str] = None
# [docs]
@REFINERS("extractive_summarizer", config_class=RecompExtractiveSummarizerConfig)
class RecompExtractiveSummarizer(RefinerBase):
    """The ``RecompExtractiveSummarizer`` summarizes the contexts using an encoder.

    Each text is split into sentences, the sentences and the query are encoded,
    and the ``preserved_sents`` sentences with the highest query/sentence dot
    product are kept, in their original order.
    """

    def __init__(self, cfg: RecompExtractiveSummarizerConfig) -> None:
        """Load the encoder and cache the summarization options.

        :param cfg: The configuration of the summarizer.
        :type cfg: RecompExtractiveSummarizerConfig
        :raises AssertionError: If no encoder could be loaded or
            ``cfg.refined_field`` is None.
        """
        # Initialise the base class the same way AbstractiveSummarizer does.
        super().__init__(cfg)
        self.model = ENCODERS.load(cfg)
        assert self.model is not None, "The encoder model is not provided."
        self.concatenate = cfg.concatenate_contexts
        self.top_k = cfg.preserved_sents
        self.substitute = cfg.substitute
        assert cfg.refined_field is not None, "The refined_field must be provided."
        self.refined_field = cfg.refined_field
        return

    @staticmethod
    def _split_sentences(text: str) -> list[str]:
        """Split ``text`` after ``.``, ``!`` or ``?`` and drop fragments of 5 characters or fewer."""
        pieces = (piece.strip() for piece in re.split(r"(?<=[.!?])\s+", text))
        return [piece for piece in pieces if len(piece) > 5]

    @TIME_METER("extractive_summarize")
    def refine(self, contexts: list[RetrievedContext]) -> list[RetrievedContext]:
        """Extract the sentences most relevant to the query from each context.

        The input contexts are never modified; deep copies are returned.

        :param contexts: The contexts to summarize.
        :type contexts: list[RetrievedContext]
        :return: Copies of the contexts carrying the summaries. When
            ``concatenate_contexts`` is enabled the contexts are summarized as
            one text, so only a copy of the first context is returned. An empty
            input yields an empty list.
        :rtype: list[RetrievedContext]
        :raises AssertionError: If ``concatenate_contexts`` is enabled and the
            contexts do not all share the same query.
        """
        # Nothing to summarize. This also protects the ``contexts[0]`` accesses below.
        if not contexts:
            return []

        if self.concatenate:
            first_query = contexts[0].query
            assert all(
                first_query == context.query for context in contexts
            ), "All queries should be the same."
            contents = [
                " ".join(context.data[self.refined_field] for context in contexts)
            ]
            queries = [first_query]
        else:
            contents = [context.data[self.refined_field] for context in contexts]
            queries = [context.query for context in contexts]

        # cut the texts into sentences
        sent_lists = [self._split_sentences(text) for text in contents]

        # encode the sentences & query
        flat_sents = [sent for sents in sent_lists for sent in sents]
        if flat_sents:
            # presumably ``encode`` returns 2-D numpy arrays (one row per text);
            # the ``@`` / ``.T`` / ``np.argpartition`` usage relies on that.
            query_emb = self.model.encode(queries)
            sents_emb = self.model.encode(flat_sents)
            all_scores = query_emb @ sents_emb.T
        else:
            # No sentence survived the split: skip the encoder instead of
            # sending it an empty batch, and let every summary be empty.
            all_scores = np.zeros((len(queries), 0))

        # select topk sents
        new_ctxs = []
        offset = 0  # position of the current text's first sentence in flat_sents
        for n, (sents, row) in enumerate(zip(sent_lists, all_scores)):
            scores = row[offset : offset + len(sents)]
            offset += len(sents)
            if len(scores) <= self.top_k:
                # Keep everything. ``<=`` matters: argpartition rejects
                # kth == len(scores) as out of bounds.
                indices = np.arange(len(scores))
            else:
                indices = np.argpartition(-scores, self.top_k)[: self.top_k]
            # Restore the original sentence order.
            summary = " ".join(sents[i] for i in sorted(indices.tolist()))

            new_ctx = deepcopy(contexts[n])
            if self.substitute:
                new_ctx.data[self.refined_field] = summary
            else:
                new_ctx.data[self.refined_field + "_summary"] = summary
                if self.concatenate:
                    # Keep the text that was actually summarized next to its summary.
                    new_ctx.data[self.refined_field] = contents[0]
            new_ctxs.append(new_ctx)
        return new_ctxs