from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class GrammarCorrector:
    """Grammar correction backed by a pretrained Hugging Face seq2seq model.

    The tokenizer and model are loaded once in the constructor and reused
    for every call to :meth:`correct`.
    """

    # Task prefixes for checkpoints known to expect something other than the
    # fallback. The vennify checkpoint's model card prompts with "grammar: ".
    _KNOWN_PREFIXES = {"vennify/t5-base-grammar-correction": "grammar: "}

    # Prefix used for any checkpoint not listed above (the historical
    # behaviour of this class, kept so other models are unaffected).
    _FALLBACK_PREFIX = "gec: "

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        prefix: "str | None" = None,
    ):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face hub id or local path of the checkpoint.
            prefix: Task prefix prepended to every input. When ``None`` the
                prefix is looked up from the known checkpoints, falling back
                to ``"gec: "``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if prefix is None:
            prefix = self._KNOWN_PREFIXES.get(model_name, self._FALLBACK_PREFIX)
        self.prefix = prefix

    def correct(self, text: str, max_length: int = 128, num_beams: int = 5) -> str:
        """Return a grammar-corrected version of ``text``.

        Args:
            text: The text to correct.
            max_length: Upper bound passed to ``generate`` for the output
                length, in tokens. Longer corrections are cut off.
            num_beams: Beam width passed to ``generate``.

        Returns:
            The decoded first generated sequence with special tokens
            removed. Empty or whitespace-only input is returned unchanged
            without running the model.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"text must be str, not {type(text).__name__}")

        # Nothing to correct; skip the comparatively expensive beam search.
        if not text.strip():
            return text

        input_ids = self.tokenizer.encode(self.prefix + text, return_tensors="pt")
        outputs = self.model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)