Source code for flexrag.metrics.matching_metrics
from abc import abstractmethod
from collections import Counter
from flexrag.text_process import AnswerSimplifier
from flexrag.utils import TIME_METER, configure
from .metrics_base import METRICS, MetricsBase
@configure
class MatchingMetricsConfig:
    """Options shared by the matching-based metrics.

    :param simplify: If True, answers are simplified before the matching score is computed. Defaults to True.
    :type simplify: bool
    """

    simplify: bool = True
class MatchingMetrics(MetricsBase):
    """Base class for metrics that score each response against its golden responses.

    Subclasses set ``name`` and implement :meth:`compute_item`; :meth:`compute`
    averages the per-item scores.

    :param cfg: The configuration. When ``cfg.simplify`` is true, responses and
        golden responses are passed through ``AnswerSimplifier`` before scoring.
    :type cfg: MatchingMetricsConfig
    """

    name: str

    def __init__(self, cfg: MatchingMetricsConfig) -> None:
        # ``None`` means "score the raw strings as given".
        self.simplifier = AnswerSimplifier() if cfg.simplify else None
        return

    @abstractmethod
    def compute_item(self, golds: list[str], response: str) -> float:
        """Score a single response against its golden responses.

        :param golds: The golden responses for this item.
        :type golds: list[str]
        :param response: The predicted response for this item.
        :type response: str
        :return: The score of this item.
        :rtype: float
        """
        return

    @TIME_METER("metrics.matching_score")
    def compute(
        self, responses: list[str], golden_responses: list[list[str]], **kwargs
    ) -> tuple[dict[str, float], dict[str, list[float]]]:
        """Compute the mean matching score over all responses.

        :param responses: The predicted responses.
        :type responses: list[str]
        :param golden_responses: The golden responses of each item, aligned with ``responses``.
        :type golden_responses: list[list[str]]
        :return: ``{self.name: mean_score}`` and ``{"item_score": per_item_scores}``.
            The mean is 0.0 when there is nothing to score.
        :rtype: tuple[dict[str, float], dict[str, list[float]]]
        """
        if self.simplifier is not None:
            responses = [self.simplifier(response) for response in responses]
            golden_responses = [
                [self.simplifier(gold) for gold in golds] for golds in golden_responses
            ]

        # zip stops at the shorter input, so surplus entries in the longer
        # list are not scored.
        matching_list = [
            self.compute_item(golds, response)
            for golds, response in zip(golden_responses, responses)
        ]

        # An empty evaluation set has no mean; report 0.0 rather than
        # raising ZeroDivisionError.
        if matching_list:
            matching_score = sum(matching_list) / len(matching_list)
        else:
            matching_score = 0.0
        return {self.name: matching_score}, {"item_score": matching_list}
# ExactMatch and Accuracy need no options beyond the shared matching
# configuration, so their config names are plain aliases of it.
ExactMatchConfig = MatchingMetricsConfig
AccuracyConfig = MatchingMetricsConfig
[docs]
@METRICS("generation_em", config_class=ExactMatchConfig)
class ExactMatch(MatchingMetrics):
    """ExactMatch metric: 1.0 when the predicted response equals one of the golden responses, otherwise 0.0."""

    name = "generation_em"

    def compute_item(self, golds: list[str], response: str) -> float:
        """Return 1.0 if ``response`` is identical to any entry of ``golds``, else 0.0."""
        return 1.0 if response in golds else 0.0
[docs]
@METRICS("generation_accuracy", config_class=AccuracyConfig)
class Accuracy(MatchingMetrics):
    """Accuracy metric: 1.0 when the predicted response contains at least one golden response as a substring, otherwise 0.0."""

    name = "generation_accuracy"

    def compute_item(self, golds: list[str], response: str) -> float:
        """Return 1.0 if any entry of ``golds`` occurs inside ``response``, else 0.0."""
        for gold in golds:
            if gold in response:
                return 1.0
        return 0.0
def f1_recall_precision(golds: list[str], response: str) -> tuple[float, float, float]:
    """Compute token-overlap F1, recall and precision of a response.

    The response and every golden response are split on whitespace and compared
    as bags of tokens (repeated tokens count up to their multiplicity). Each of
    the three scores is maximised over the golden responses independently, so
    the returned values may originate from different golden responses.

    :param golds: The golden responses.
    :type golds: list[str]
    :param response: The predicted response.
    :type response: str
    :return: ``(f1, recall, precision)``; all 0.0 when no golden response shares
        a token with the response.
    :rtype: tuple[float, float, float]
    """
    pred_counter = Counter(response.split())
    # The response length does not depend on the gold, so count it once.
    pred_total = sum(pred_counter.values())

    best_f1 = 0.0
    best_recall = 0.0
    best_precision = 0.0
    for gold in golds:
        gold_counter = Counter(gold.split())
        # Multiset intersection: each shared token counts min(gold, pred) times.
        common = sum((gold_counter & pred_counter).values())
        if common == 0:
            # No overlap contributes nothing; skipping also guarantees that
            # both denominators below are non-zero.
            continue
        p = common / pred_total
        r = common / sum(gold_counter.values())
        f1_ = (2 * p * r) / (p + r)
        best_precision = max(p, best_precision)
        best_recall = max(r, best_recall)
        best_f1 = max(best_f1, f1_)
    return best_f1, best_recall, best_precision
# F1, Recall and Precision need no options beyond the shared matching
# configuration, so their config names are plain aliases of it.
F1Config = MatchingMetricsConfig
RecallConfig = MatchingMetricsConfig
PrecisionConfig = MatchingMetricsConfig
[docs]
@METRICS("generation_f1", config_class=F1Config)
class F1(MatchingMetrics):
    """F1 metric: the token-level F1 score of the predicted response against the golden responses."""

    name = "generation_f1"

    def compute_item(self, golds: list[str], response: str) -> float:
        """Return the F1 component of ``f1_recall_precision`` for this item."""
        f1_score, _, _ = f1_recall_precision(golds, response)
        return f1_score
[docs]
@METRICS("generation_recall", config_class=RecallConfig)
class Recall(MatchingMetrics):
    """Recall metric: the token-level recall of the predicted response against the golden responses."""

    name = "generation_recall"

    def compute_item(self, golds: list[str], response: str) -> float:
        """Return the recall component of ``f1_recall_precision`` for this item."""
        _, recall_score, _ = f1_recall_precision(golds, response)
        return recall_score
[docs]
@METRICS("generation_precision", config_class=PrecisionConfig)
class Precision(MatchingMetrics):
    """Precision metric: the token-level precision of the predicted response against the golden responses."""

    name = "generation_precision"

    def compute_item(self, golds: list[str], response: str) -> float:
        """Return the precision component of ``f1_recall_precision`` for this item."""
        _, _, precision_score = f1_recall_precision(golds, response)
        return precision_score