marksverdhei committed
Commit 9b8727e · verified · 1 Parent(s): a366cfb

Upload errant_gec.py with huggingface_hub

Files changed (1)
  1. errant_gec.py +225 -0
errant_gec.py ADDED
@@ -0,0 +1,225 @@
"""ERRANT metric for Grammatical Error Correction evaluation.

This metric uses the ERRANT (ERRor ANnotation Toolkit) to evaluate
grammatical error correction systems by comparing edit operations
between source, prediction, and reference sentences.
"""

import datasets
import evaluate


_CITATION = """\
@inproceedings{bryant-etal-2017-automatic,
    title = "Automatic Annotation and Evaluation of Error Types for Grammatical Error Correction",
    author = "Bryant, Christopher and
      Felice, Mariano and
      Briscoe, Ted",
    booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2017",
    address = "Vancouver, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/P17-1074",
    doi = "10.18653/v1/P17-1074",
    pages = "793--805",
}
"""

_DESCRIPTION = """\
ERRANT (ERRor ANnotation Toolkit) is a metric for evaluating grammatical error
correction (GEC) systems. It computes precision, recall, and F-score by comparing
the edit operations needed to transform source sentences into predictions versus
the edit operations needed to transform source sentences into references.

This metric requires three inputs:
- sources: The original (uncorrected) sentences
- predictions: The model's corrected sentences
- references: The gold standard corrected sentences

The metric extracts edits using the ERRANT library and computes:
- Precision: What fraction of predicted edits are correct
- Recall: What fraction of gold edits were predicted
- F0.5: F-score with beta=0.5 (weighting precision twice as much as recall)
"""

_KWARGS_DESCRIPTION = """
Args:
    sources: list of source (original/uncorrected) sentences
    predictions: list of predicted (corrected) sentences
    references: list of reference (gold corrected) sentences
    lang: language code for the spaCy model (default: "en")
        - "en": English (requires en_core_web_sm)
        - "nb": Norwegian Bokmål (requires nb_core_news_sm)
        - "de": German (requires de_core_news_sm)
        - etc. (any language with a spaCy model)
    beta: beta value for the F-score calculation (default: 0.5)

Returns:
    precision: fraction of predicted edits that are correct
    recall: fraction of gold edits that were predicted
    f{beta}: F-score for the specified beta value ("f0.5" with the default beta)

Examples:
    >>> import evaluate
    >>> errant_gec = evaluate.load("marksverdhei/errant_gec")
    >>> results = errant_gec.compute(
    ...     sources=["This are a sentence ."],
    ...     predictions=["This is a sentence ."],
    ...     references=["This is a sentence ."],
    ...     lang="en"
    ... )
    >>> print(results)
    {'precision': 1.0, 'recall': 1.0, 'f0.5': 1.0}
"""

# Map language codes to spaCy model names
SPACY_MODELS = {
    "en": "en_core_web_sm",
    "nb": "nb_core_news_sm",
    "nn": "nb_core_news_sm",  # Use Bokmål model for Nynorsk as fallback
    "de": "de_core_news_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "nl": "nl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ru": "ru_core_news_sm",
    "zh": "zh_core_web_sm",
    "ja": "ja_core_news_sm",
}
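# Illustrative note (not part of the mapping above): the spaCy model for the
# chosen language must be installed before computing, e.g.
#
#     python -m spacy download en_core_web_sm
#
# The mapping can also be extended for other languages that ship a small spaCy
# pipeline, e.g. SPACY_MODELS["da"] = "da_core_news_sm" for Danish, assuming
# that package is available in your spaCy version.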


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Errant(evaluate.Metric):
    """ERRANT metric for grammatical error correction evaluation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._annotators = {}  # Cache annotators per language

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "sources": datasets.Value("string"),
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://github.com/chrisjbryant/errant"],
        )

    def _get_annotator(self, lang: str):
        """Get or create an ERRANT annotator for the specified language."""
        if lang in self._annotators:
            return self._annotators[lang]

        import errant
        import spacy

        model_name = SPACY_MODELS.get(lang, f"{lang}_core_news_sm")

        try:
            nlp = spacy.load(model_name)
        except OSError:
            raise OSError(
                f"spaCy model '{model_name}' not found. "
                f"Please install it with: python -m spacy download {model_name}"
            )

        # ERRANT uses "en" as its base language; the spaCy model we pass in
        # supplies the language-specific tokenization rules.
        annotator = errant.load("en", nlp)
        self._annotators[lang] = annotator
        return annotator
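    # Illustrative sketch of a pre-flight check before calling compute, assuming
    # spacy.util.is_package reports whether a model package is installed:
    #
    #     import spacy
    #     if not spacy.util.is_package("en_core_web_sm"):
    #         print("Run: python -m spacy download en_core_web_sm")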

    def _get_edits(self, annotator, orig_doc, cor_doc):
        """Extract edits between original and corrected documents.

        Returns a set of (o_start, o_end, o_str, c_str) tuples.
        """
        edits = annotator.annotate(orig_doc, cor_doc)
        edit_set = set()
        for edit in edits:
            # Skip noop edits (no actual change)
            if edit.o_str == edit.c_str:
                continue
            # Use span positions and strings as edit identifier
            edit_set.add((edit.o_start, edit.o_end, edit.o_str, edit.c_str))
        return edit_set
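    # Illustrative example: for the original "This are a sentence ." and the
    # corrected "This is a sentence .", ERRANT typically produces one replacement
    # edit, so the returned set would look roughly like {(1, 2, "are", "is")}
    # (token span 1-2 of the original, "are" replaced by "is").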

    def _compute_fscore(self, tp: int, fp: int, fn: int, beta: float = 0.5) -> dict:
        """Compute precision, recall, and F-score."""
        precision = float(tp) / (tp + fp) if (tp + fp) > 0 else 1.0
        recall = float(tp) / (tp + fn) if (tp + fn) > 0 else 1.0

        if precision + recall > 0:
            f_score = float((1 + beta**2) * precision * recall) / (
                (beta**2 * precision) + recall
            )
        else:
            f_score = 0.0

        return {
            "precision": precision,
            "recall": recall,
            f"f{beta}": f_score,
        }
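    # Illustrative worked example: with tp=3, fp=1, fn=2 and the default beta=0.5,
    # precision = 3/4 = 0.75, recall = 3/5 = 0.6, and
    # F0.5 = (1 + 0.25) * 0.75 * 0.6 / (0.25 * 0.75 + 0.6) ≈ 0.714.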

    def _compute(
        self,
        sources: list[str],
        predictions: list[str],
        references: list[str],
        lang: str = "en",
        beta: float = 0.5,
    ) -> dict:
        """Compute ERRANT scores for the given inputs.

        Args:
            sources: Original (uncorrected) sentences
            predictions: Model's corrected sentences
            references: Gold standard corrected sentences
            lang: Language code for spaCy model
            beta: Beta value for F-score (default 0.5)

        Returns:
            Dictionary with precision, recall, and f{beta} scores
        """
        if not (len(sources) == len(predictions) == len(references)):
            raise ValueError(
                f"Inputs must have the same length. Got sources={len(sources)}, "
                f"predictions={len(predictions)}, references={len(references)}"
            )

        annotator = self._get_annotator(lang)

        total_tp = 0
        total_fp = 0
        total_fn = 0

        for source, prediction, reference in zip(sources, predictions, references):
            # Parse sentences
            orig_doc = annotator.parse(source)
            hyp_doc = annotator.parse(prediction)
            ref_doc = annotator.parse(reference)

            # Get edit sets
            hyp_edits = self._get_edits(annotator, orig_doc, hyp_doc)
            ref_edits = self._get_edits(annotator, orig_doc, ref_doc)

            # Compute TP, FP, FN for this sample
            tp = len(ref_edits & hyp_edits)
            fp = len(hyp_edits - ref_edits)
            fn = len(ref_edits - hyp_edits)

            total_tp += tp
            total_fp += fp
            total_fn += fn

        return self._compute_fscore(total_tp, total_fp, total_fn, beta=beta)
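
A minimal usage sketch, assuming the script is saved locally as errant_gec.py and that the evaluate, errant, and spacy packages plus the en_core_web_sm model are installed (evaluate.load should also accept the Hub id shown in the docstring):

    import evaluate

    errant_gec = evaluate.load("errant_gec.py")  # or "marksverdhei/errant_gec"
    results = errant_gec.compute(
        sources=["She see the dog .", "I has a apple ."],
        predictions=["She sees the dog .", "I have an apple ."],
        references=["She sees the dog .", "I have an apple ."],
        lang="en",
        beta=0.5,
    )
    print(results)  # expected keys: precision, recall, f0.5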