"""
This module contains a wrapper class for the sacreBLEU metric from https://github.com/mjpost/sacreBLEU.
"""
from typing import Union, Optional, List, Dict
from string2string.misc.default_tokenizer import Tokenizer
from sacrebleu import corpus_bleu
# Tokenizers that sacreBLEU ships with, keyed by the short name a caller passes as ``tokenizer_name``.
# Each value is the "module.ClassName" string listed for that name in sacreBLEU's own table, copied from
# https://github.com/mjpost/sacrebleu/blob/4f4124642c4eb0b7120e50119c669f0570a326a7/sacrebleu/metrics/bleu.py#L18
ALLOWED_TOKENIZERS: Dict[str, str] = {
    'none': 'tokenizer_none.NoneTokenizer',
    'zh': 'tokenizer_zh.TokenizerZh',
    '13a': 'tokenizer_13a.Tokenizer13a',
    'intl': 'tokenizer_intl.TokenizerV14International',
    'char': 'tokenizer_char.TokenizerChar',
    'ja-mecab': 'tokenizer_ja_mecab.TokenizerJaMecab',
    'ko-mecab': 'tokenizer_ko_mecab.TokenizerKoMecab',
    'spm': 'tokenizer_spm.TokenizerSPM',
    'flores101': 'tokenizer_spm.Flores101Tokenizer',
    'flores200': 'tokenizer_spm.Flores200Tokenizer',
}
class sacreBLEU:
    """
    This class is a thin wrapper around sacreBLEU's corpus-level BLEU function (``sacrebleu.corpus_bleu``).
    """

    def __init__(self) -> None:
        """
        Initializes the sacreBLEU class. No state is stored on the instance.
        """
        pass

    def compute(
        self,
        predictions: List[str],
        references: List[List[str]],
        smooth_method: str = 'exp',
        smooth_value: Optional[float] = None,
        lowercase: bool = False,
        tokenizer_name: Optional[str] = 'none',
        use_effective_order: bool = False,
        return_only: Optional[List[str]] = None,
    ) -> Dict:
        """
        Returns the BLEU score between a list of predictions and list of list of references.

        Arguments:
            predictions (List[str]): The predictions, one string per example.
            references (List[List[str]]): The references (or ground truth strings); ``references[i]`` holds the
                references for ``predictions[i]``. Every example must have the same number of references.
            smooth_method (str): The smoothing method. Default is "exp". Other options are "floor", "add-k" and "none".
            smooth_value (Optional[float]): The smoothing value for floor and add-k smoothing. Default is None.
            lowercase (bool): Whether to lowercase the text. Default is False.
            tokenizer_name (Optional[str]): The tokenizer name. Default is "none". Other options are "zh", "13a",
                "intl", "char", "ja-mecab", "ko-mecab", "spm", "flores101" and "flores200". With "none" (or None)
                no tokenizer is forwarded to sacreBLEU, so sacreBLEU's own default tokenizer is used.
            use_effective_order (bool): Whether to use the effective order. Default is False.
            return_only (Optional[List[str]]): The list of BLEU score components to return. Default is None, which
                means ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len'].

        Returns:
            Dict: A mapping from each requested component name to the attribute of the same name on the object
            returned by ``sacrebleu.corpus_bleu``. The 'score' component is reported on sacreBLEU's 0-100 scale.

        Raises:
            ValueError: If the number of predictions does not match the number of references.
            ValueError: If no predictions/references are given.
            ValueError: If the tokenizer name is invalid.
            ValueError: If the reference lists do not all have the same size.
        """
        # Check that the number of predictions matches the number of references
        if len(predictions) != len(references):
            raise ValueError('The number of predictions does not match the number of references.')

        # An empty corpus has no first reference list to take the size from; report that explicitly
        # instead of failing later with an IndexError.
        if not references:
            raise ValueError('At least one prediction and one reference list are required.')

        # The parameter is annotated as Optional, so None is treated the same as the default 'none'.
        if tokenizer_name is None:
            tokenizer_name = 'none'

        # Check that the tokenizer name is valid
        if tokenizer_name not in ALLOWED_TOKENIZERS:
            raise ValueError('The tokenizer name is invalid.')

        # Check that the size of each reference list is the same
        reference_size = len(references[0])
        if any(len(reference) != reference_size for reference in references):
            raise ValueError('The size of each reference list is not the same.')

        # sacrebleu.corpus_bleu expects one list per reference "stream" (the k-th reference of every example),
        # so the per-example reference lists are transposed.
        transformed_references = [list(stream) for stream in zip(*references)]

        # sacreBLEU selects its tokenizer by the short name (e.g. '13a'), not by the "module.ClassName" value
        # stored in ALLOWED_TOKENIZERS, so the name itself is forwarded.
        # NOTE(review): for 'none' the argument is deliberately omitted, which makes sacreBLEU use its own default
        # tokenizer ('13a' in current sacreBLEU releases - confirm) rather than its 'none' tokenizer. This keeps
        # the existing default scores unchanged; confirm that this is the intended meaning of 'none'.
        tokenizer_kwargs = {} if tokenizer_name == 'none' else {'tokenize': tokenizer_name}

        # Compute the BLEU score using sacrebleu.corpus_bleu
        # This function returns "BLEUScore(score, correct, total, precisions, bp, sys_len, ref_len)"
        bleu_score = corpus_bleu(
            hypotheses=predictions,
            references=transformed_references,
            smooth_method=smooth_method,
            smooth_value=smooth_value,
            lowercase=lowercase,
            use_effective_order=use_effective_order,
            **tokenizer_kwargs,
        )

        # The default is built per call so that no list object is shared between calls.
        if return_only is None:
            return_only = ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len']

        # Collect the requested BLEU score components
        return {component: getattr(bleu_score, component) for component in return_only}
# Example usage (references[i] is the list of references for predictions[i]):
# predictions = ["hello there general kenobi", "foo bar foobar"]
# references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]]
# sbleu = sacreBLEU()
# bleu_score = sbleu.compute(predictions, references)
# print(bleu_score)