from collections import defaultdict
from dataclasses import field
from typing import Annotated, Optional
import pytrec_eval
from flexrag.text_process import AnswerSimplifier
from flexrag.utils import TIME_METER, Choices, configure
from flexrag.utils.dataclasses import Context, RetrievedContext
from .metrics_base import METRICS, MetricsBase
try:
    from .lib_rel import get_contain_map

    has_librel = True
except ImportError:
    # The accelerated ``lib_rel`` module is optional; without it the pure-Python
    # path below is used. Only import failures are treated as "not available",
    # so KeyboardInterrupt / SystemExit are no longer swallowed.
    has_librel = False


def get_contain_map_py(evidences: list[str], retrieved: list[str]) -> list[list[bool]]:
    """Build a containment map between retrieved texts and evidences.

    Delegates to ``lib_rel.get_contain_map`` when that module was imported
    successfully; otherwise the map is computed with Python substring checks.

    :param evidences: The strings to look for.
    :type evidences: list[str]
    :param retrieved: The texts to search in.
    :type retrieved: list[str]
    :return: In the Python fallback, one row per retrieved text where
        ``row[j]`` is True iff ``evidences[j]`` is a substring of that text.
        The ``lib_rel`` implementation is presumably laid out the same way.
    :rtype: list[list[bool]]
    """
    if has_librel:
        return get_contain_map(evidences, retrieved)
    return [[evd in ret for evd in evidences] for ret in retrieved]
[docs]
@configure
class SuccessRateConfig:
    """Configuration for ``SuccessRate`` metric.

    This metric computes, for each sample, whether any of the retrieved contexts
    contains any of the golden responses.

    :param eval_field: The key of the context data field to evaluate. Defaults to None.
        If None, only strings are supported as the `retrieved_contexts`.
    :type eval_field: Optional[str]
    :param simplify: Whether to pass both the retrieved contexts and the golden
        responses through ``AnswerSimplifier`` before matching. Defaults to True.
    :type simplify: bool
    """

    eval_field: Optional[str] = None
    simplify: bool = True
[docs]
@METRICS("retrieval_success_rate", config_class=SuccessRateConfig)
class SuccessRate(MetricsBase):
    """The SuccessRate metric computes whether the retrieved contexts contain any of the golden responses."""

    def __init__(self, cfg: SuccessRateConfig) -> None:
        """Store the evaluated field and build the optional text simplifier.

        :param cfg: The metric configuration.
        :type cfg: SuccessRateConfig
        """
        self.eval_field = cfg.eval_field
        self.simplifier = AnswerSimplifier() if cfg.simplify else None

    @TIME_METER("metrics.retrieval_success_rate")
    def compute(
        self,
        golden_responses: list[list[str]] = None,
        retrieved_contexts: list[list[str | Context]] = None,
        **kwargs,
    ) -> tuple[dict[str, float], dict]:
        """Compute the fraction of samples whose contexts contain a golden response.

        :param golden_responses: The golden responses of every sample.
        :type golden_responses: list[list[str]]
        :param retrieved_contexts: The retrieved contexts of every sample, given as
            strings, ``Context`` objects or dicts holding a ``"data"`` mapping.
        :type retrieved_contexts: list[list[str | Context]]
        :return: ``{"retrieval_success_rate": score}`` and
            ``{"success_map": per-sample flags}``. The score is 0.0 when there
            are no samples.
        :rtype: tuple[dict[str, float], dict]
        """
        success_map: list[bool] = []
        for golds, ctxs in zip(golden_responses, retrieved_contexts):
            # A sample without any retrieved context can never succeed.
            if not ctxs:
                success_map.append(False)
                continue

            # Reduce structured contexts to the text of the evaluated field.
            # ``elif`` matters: values already extracted from ``Context`` objects
            # must not be re-interpreted as serialized contexts.
            if isinstance(ctxs[0], Context):
                assert (
                    self.eval_field is not None
                ), "eval_field must be set to evaluate Context objects"
                ctxs = [ctx.data[self.eval_field] for ctx in ctxs]
            elif isinstance(ctxs[0], dict):
                ctxs = [ctx["data"][self.eval_field] for ctx in ctxs]

            if self.simplifier is not None:
                ctxs = [self.simplifier(ctx) for ctx in ctxs]
                golds = [self.simplifier(gold) for gold in golds]

            rel_map = get_contain_map_py(golds, ctxs)
            # Success as soon as any (context, golden response) pair matches;
            # no need to flatten the whole map first.
            success_map.append(any(any(row) for row in rel_map))

        # Guard the empty evaluation set instead of dividing by zero.
        score = sum(success_map) / len(success_map) if success_map else 0.0
        return {"retrieval_success_rate": score}, {"success_map": success_map}
def pytrec_evaluate(
retrieved_contexts: list[list[RetrievedContext | str]],
golden_contexts: list[list[Context | str]],
k_values: list[int] = [1, 5, 10],
measure: Annotated[
str, Choices("recall", "precision", "ndcg", "map", "mrr")
] = "recall",
) -> tuple[dict[str, float], dict]:
"""Evaluate the retrieval results using pytrec_eval.
:param retrieved_contexts: The retrieved contexts.
:type retrieved_contexts: list[list[RetrievedContext]]
:param golden_contexts: The golden contexts.
:type golden_contexts: list[list[Context]]
:param k_values: The k values for evaluation. Defaults to [1, 5, 10].
:type k_values: list[int], optional
:param measure: The evaluation measure. Defaults to "recall".
Available choices are "recall", "precision", "ndcg", and "map".
:type measure: str, optional
:return: The evaluation scores and details.
:rtype: tuple[dict[str, float], dict]
"""
# convert flexrag format to pytrec_eval format
qrels: dict[str, dict[str, int]] = {}
retrieved: dict[str, dict[str, float]] = {}
ctx_ids: dict[str, str] = {} # label the context if it is a string
for n, (gctxs, rctxs) in enumerate(zip(golden_contexts, retrieved_contexts)):
qrels[str(n)] = {}
retrieved[str(n)] = {}
for ctx in gctxs:
if isinstance(ctx, str):
ctx_id = ctx_ids.get(ctx, str(len(ctx_ids)))
ctx_ids[ctx] = ctx_id
ctx_score = 1
else:
ctx_id = ctx.context_id
ctx_score = ctx.meta_data.get("score", 1)
qrels[str(n)][ctx_id] = ctx_score
for ctx in rctxs:
if isinstance(ctx, str):
ctx_id = ctx_ids.get(ctx, str(len(ctx_ids)))
ctx_ids[ctx] = ctx_id
ctx_score = 1.0
else:
ctx_id = ctx.context_id
ctx_score = ctx.score or 1.0
retrieved[str(n)][ctx_id] = ctx_score
# prepare pytrec_eval measure_strings
k_values = [str(k) for k in k_values]
measures_: set[str] # for pytrec_eval args
measure_strs_: list[str] # for pytrec_eval results
measure_strs: list[str] # for FlexRAG results
match measure:
case "recall":
measures_ = {f"recall." + ",".join(k_values)}
measure_strs_ = [f"recall_{k}" for k in k_values]
measure_strs = [f"Recall@{k}" for k in k_values]
case "precision":
measures_ = {f"P." + ",".join(k_values)}
measure_strs_ = [f"P_{k}" for k in k_values]
measure_strs = [f"Precision@{k}" for k in k_values]
case "ndcg":
if len(k_values) > 1:
measures_ = {f"ndcg_cut." + ",".join(k_values)}
measure_strs_ = [f"ndcg_cut_{k}" for k in k_values]
measure_strs = [f"nDCG@{k}" for k in k_values]
else:
measures_ = {"ndcg"}
measure_strs_ = ["ndcg"]
measure_strs = ["nDCG"]
case "map":
if len(k_values) > 1:
measures_ = {f"map_cut." + ",".join(k_values)}
measure_strs_ = [f"map_cut_{k}" for k in k_values]
measure_strs = [f"MAP@{k}" for k in k_values]
else:
measures_ = {"map"}
measure_strs_ = ["map"]
measure_strs = ["MAP"]
case "mrr":
measures_ = {"recip_rank"}
measure_strs_ = ["recip_rank"]
measure_strs = ["MRR"]
case _:
raise ValueError(f"Invalid measure: {measure}")
# evaluate using pytrec_eval
evaluator = pytrec_eval.RelevanceEvaluator(
query_relevance=qrels,
measures=measures_,
)
details = evaluator.evaluate(retrieved)
# extract the results
scores = defaultdict(list)
for key in retrieved.keys():
for measure_str_, measure_str in zip(measure_strs_, measure_strs):
scores[measure_str].append(details[key][measure_str_])
scores = {k: sum(v) / len(v) for k, v in scores.items()}
return scores, details
[docs]
@configure
class RetrievalRecallConfig:
    """Configuration for ``RetrievalRecall`` metric.

    This metric computes the recall of the retrieved contexts at every cutoff in ``k_values``.
    The computation is based on `pytrec_eval <https://github.com/cvangysel/pytrec_eval>`_.

    :param k_values: The k values (rank cutoffs) for evaluation. Defaults to [1, 5, 10].
    :type k_values: list[int]
    """

    # default_factory builds a fresh list for every config instance
    k_values: list[int] = field(default_factory=lambda: [1, 5, 10])
[docs]
@METRICS("retrieval_recall", config_class=RetrievalRecallConfig)
class RetrievalRecall(MetricsBase):
    """Recall@k of the retrieved contexts, computed through ``pytrec_evaluate``."""

    def __init__(self, cfg: RetrievalRecallConfig) -> None:
        """Remember the cutoffs configured in ``cfg.k_values``."""
        self.k_values = cfg.k_values

    @TIME_METER("metrics.retrieval_recall")
    def compute(
        self,
        retrieved_contexts: list[list[RetrievedContext]] = None,
        golden_contexts: list[list[Context]] = None,
        **kwargs,
    ) -> tuple[dict[str, float], dict]:
        """Score the retrieved contexts against the golden contexts.

        :param retrieved_contexts: The retrieved contexts of every query.
        :param golden_contexts: The golden contexts of every query.
        :return: The averaged ``Recall@k`` scores and the per-query details.
        """
        outcome = pytrec_evaluate(
            measure="recall",
            k_values=self.k_values,
            golden_contexts=golden_contexts,
            retrieved_contexts=retrieved_contexts,
        )
        scores, details = outcome
        return scores, details
[docs]
@configure
class RetrievalPrecisionConfig:
    """Configuration for ``RetrievalPrecision`` metric.

    This metric computes the precision of the retrieved contexts at every cutoff in ``k_values``.
    The computation is based on `pytrec_eval <https://github.com/cvangysel/pytrec_eval>`_.

    :param k_values: The k values (rank cutoffs) for evaluation. Defaults to [1, 5, 10].
    :type k_values: list[int]
    """

    # default_factory builds a fresh list for every config instance
    k_values: list[int] = field(default_factory=lambda: [1, 5, 10])
[docs]
@METRICS("retrieval_precision", config_class=RetrievalPrecisionConfig)
class RetrievalPrecision(MetricsBase):
    """The RetrievalPrecision metric computes the precision of the retrieved contexts."""

    def __init__(self, cfg: RetrievalPrecisionConfig) -> None:
        """Remember the cutoffs configured in ``cfg.k_values``.

        :param cfg: The metric configuration.
        :type cfg: RetrievalPrecisionConfig
        """
        self.k_values = cfg.k_values

    @TIME_METER("metrics.retrieval_precision")
    def compute(
        self,
        retrieved_contexts: list[list[RetrievedContext]] = None,
        golden_contexts: list[list[Context]] = None,
        **kwargs,
    ) -> tuple[dict[str, float], dict]:
        """Compute Precision@k for every configured cutoff.

        :param retrieved_contexts: The retrieved contexts of every query.
        :type retrieved_contexts: list[list[RetrievedContext]]
        :param golden_contexts: The golden contexts of every query.
        :type golden_contexts: list[list[Context]]
        :return: The averaged ``Precision@k`` scores and the per-query details
            produced by ``pytrec_evaluate``.
        :rtype: tuple[dict[str, float], dict]
        """
        # The return annotation now matches what pytrec_evaluate actually
        # returns, consistent with the other retrieval metrics.
        scores, details = pytrec_evaluate(
            retrieved_contexts=retrieved_contexts,
            golden_contexts=golden_contexts,
            k_values=self.k_values,
            measure="precision",
        )
        return scores, details
[docs]
@configure
class RetrievalMAPConfig:
    """Configuration for ``RetrievalMAP`` metric.

    This metric computes the MAP of the retrieved contexts.
    The computation is based on `pytrec_eval <https://github.com/cvangysel/pytrec_eval>`_.

    :param k_values: The k values for evaluation. Defaults to an empty list.
        ``pytrec_evaluate`` reports ``MAP@k`` for every value only when more than
        one value is given; otherwise the un-cut ``MAP`` is reported.
    :type k_values: list[int]
    """

    # empty by default: no cutoffs, i.e. plain MAP
    k_values: list[int] = field(default_factory=list)
[docs]
@METRICS("retrieval_map", config_class=RetrievalMAPConfig)
class RetrievalMAP(MetricsBase):
    """Mean Average Precision (MAP) of the retrieved contexts, computed through ``pytrec_evaluate``."""

    def __init__(self, cfg: RetrievalMAPConfig) -> None:
        """Remember the cutoffs configured in ``cfg.k_values``."""
        self.k_values = cfg.k_values

    @TIME_METER("metrics.retrieval_map")
    def compute(
        self,
        retrieved_contexts: list[list[RetrievedContext]] = None,
        golden_contexts: list[list[Context]] = None,
        **kwargs,
    ) -> tuple[dict[str, float], dict]:
        """Score the retrieved contexts against the golden contexts.

        :param retrieved_contexts: The retrieved contexts of every query.
        :param golden_contexts: The golden contexts of every query.
        :return: The averaged MAP scores and the per-query details.
        """
        outcome = pytrec_evaluate(
            measure="map",
            k_values=self.k_values,
            golden_contexts=golden_contexts,
            retrieved_contexts=retrieved_contexts,
        )
        scores, details = outcome
        return scores, details
[docs]
@configure
class RetrievalNDCGConfig:
    """Configuration for ``RetrievalNDCG`` metric.

    This metric computes the nDCG of the retrieved contexts.
    The computation is based on `pytrec_eval <https://github.com/cvangysel/pytrec_eval>`_.

    :param k_values: The k values for evaluation. Defaults to an empty list.
        ``pytrec_evaluate`` reports ``nDCG@k`` for every value only when more than
        one value is given; otherwise the un-cut ``nDCG`` is reported.
    :type k_values: list[int]
    """

    # empty by default: no cutoffs, i.e. plain nDCG
    k_values: list[int] = field(default_factory=list)
[docs]
@METRICS("retrieval_ndcg", config_class=RetrievalNDCGConfig)
class RetrievalNDCG(MetricsBase):
    """Normalized Discounted Cumulative Gain (nDCG) of the retrieved contexts, computed through ``pytrec_evaluate``."""

    def __init__(self, cfg: RetrievalNDCGConfig) -> None:
        """Remember the cutoffs configured in ``cfg.k_values``."""
        self.k_values = cfg.k_values

    @TIME_METER("metrics.retrieval_ndcg")
    def compute(
        self,
        retrieved_contexts: list[list[RetrievedContext]] = None,
        golden_contexts: list[list[Context]] = None,
        **kwargs,
    ) -> tuple[dict[str, float], dict]:
        """Score the retrieved contexts against the golden contexts.

        :param retrieved_contexts: The retrieved contexts of every query.
        :param golden_contexts: The golden contexts of every query.
        :return: The averaged nDCG scores and the per-query details.
        """
        outcome = pytrec_evaluate(
            measure="ndcg",
            k_values=self.k_values,
            golden_contexts=golden_contexts,
            retrieved_contexts=retrieved_contexts,
        )
        scores, details = outcome
        return scores, details
@METRICS("retrieval_mrr")
class RetrievalMRR(MetricsBase):
    """Mean Reciprocal Rank (MRR) of the retrieved contexts, computed through ``pytrec_evaluate``."""

    @TIME_METER("metrics.retrieval_mrr")
    def compute(
        self,
        retrieved_contexts: list[list[RetrievedContext]] = None,
        golden_contexts: list[list[Context]] = None,
        **kwargs,
    ) -> tuple[dict[str, float], dict]:
        """Score the retrieved contexts against the golden contexts.

        :param retrieved_contexts: The retrieved contexts of every query.
        :param golden_contexts: The golden contexts of every query.
        :return: The averaged ``MRR`` score and the per-query details.
        """
        # No k_values are passed: pytrec_evaluate's own default applies.
        outcome = pytrec_evaluate(
            measure="mrr",
            golden_contexts=golden_contexts,
            retrieved_contexts=retrieved_contexts,
        )
        scores, details = outcome
        return scores, details