Commit: d939bae
Message: Deploy V15 Clean (Removed binary files history)
Files changed:
- .gitignore +31 -0
- Dockerfile +13 -0
- README.md +24 -0
- requirements.txt +21 -0
- src/api.py +366 -0
- src/clean_external_data.py +506 -0
- src/clean_external_data_task1.py +94 -0
- src/explore.py +90 -0
- src/explore_speaking.py +71 -0
- src/pronunciation.py +111 -0
- src/train.py +151 -0
.gitignore
ADDED
@@ -0,0 +1,31 @@
# File .gitignore
*.mp3
*.wav
*.m4a
data/

# Virtual environment
venv/

# Python cache
__pycache__/
*.pyc
.cache/

# OS files
.DS_Store
Thumbs.db

# IDE files
.vscode/
.idea/

# Local data files and trained models
# (We only push code, not heavy data/models)
*.json
*.csv
best_model/
ielts_grader_model/

# .env file (contains API keys)
.env
Dockerfile
ADDED
@@ -0,0 +1,13 @@
FROM python:3.10
WORKDIR /app
RUN apt-get update && apt-get install -y \
    espeak-ng \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
RUN mkdir -p /app/.cache && chmod 777 /app/.cache
ENV TRANSFORMERS_CACHE=/app/.cache
EXPOSE 7860
CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
ADDED
@@ -0,0 +1,24 @@
---
title: IELTS Grader API
emoji: 🚀
colorFrom: blue
colorTo: indigo
sdk: docker
pinned: false
app_port: 7860
---

# IELTS Grader AI API

This is the backend API that grades IELTS Writing Task 1 and Task 2 using AI.

## Key features

- **Task 1:** Chart analysis (Vision AI) and scoring.
- **Task 2:** Scoring of discursive essays.
- **Feedback:** Detailed comments and guided error corrections.

## Usage

The API runs on port 7860.
Main endpoints: `POST /grade/writing` and `POST /grade/speaking`.
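To make the usage section above concrete, here is a minimal client sketch for the writing endpoint. It is an illustration, not part of the commit: it assumes the Space is reachable at `http://localhost:7860`, and the prompt and essay strings are placeholders; the request fields mirror the `WritingRequest` model in `src/api.py`.

```python
# Sketch of a client for POST /grade/writing (assumed local URL, placeholder text).
import httpx

payload = {
    "task_type": 2,                     # 1 = chart task (requires image_url), 2 = essay task
    "prompt": "Some people think ...",  # placeholder IELTS prompt
    "essay": "In my opinion, ...",      # placeholder essay text
}

resp = httpx.post("http://localhost:7860/grade/writing", json=payload, timeout=120.0)
resp.raise_for_status()
result = resp.json()
print(result["overallScore"], result["criteriaScores"])
```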
requirements.txt
ADDED
@@ -0,0 +1,21 @@
# Train
transformers[torch]
datasets
scikit-learn
huggingface_hub
tqdm
torch
pydantic

# API
fastapi
uvicorn[standard]
httpx
python-dotenv

# Library
openai
librosa
openai-whisper
numpy
python-multipart
src/api.py
ADDED
@@ -0,0 +1,366 @@
import uvicorn
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from transformers import pipeline
import torch
import os
import json
import httpx
import shutil
import whisper
import librosa
import numpy as np
from dotenv import load_dotenv
from typing import Optional, List
import uuid

try:
    from src.pronunciation import grade_pronunciation_advanced
except ImportError:
    from pronunciation import grade_pronunciation_advanced

load_dotenv()

SCORER_MODEL_ID_TASK1 = "diminch/ielts-task1-grader-ai-v2"
SCORER_MODEL_ID_TASK2 = "diminch/ielts-grader-ai-v2"

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"API running on: {DEVICE}")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"

if not OPENAI_API_KEY:
    print("WARNING: OPENAI_API_KEY not found in .env")

print("Loading Whisper...")
try:
    whisper_model = whisper.load_model("base", device=DEVICE)
    print("Whisper Loaded.")
except Exception as e:
    print(f"Error loading Whisper: {e}")
    whisper_model = None

pipelines = {}
def load_writing_model(task_name, model_id):
    try:
        print(f"Loading {task_name}: {model_id}...")
        pipelines[task_name] = pipeline(
            "text-classification", model=model_id, tokenizer=model_id,
            device=DEVICE, return_all_scores=True
        )
        print(f"Loaded {task_name}.")
    except Exception as e:
        print(f"Error loading {task_name}: {e}")
        pipelines[task_name] = None

load_writing_model("task1", SCORER_MODEL_ID_TASK1)
load_writing_model("task2", SCORER_MODEL_ID_TASK2)

class WritingRequest(BaseModel):
    task_type: int
    prompt: str
    essay: str
    image_url: Optional[str] = None

class WritingScores(BaseModel):
    taskResponse: float
    coherenceCohesion: float
    lexicalResource: float
    grammaticalRange: float

class ShortFeedbackWriting(BaseModel):
    taskResponse: str
    coherenceCohesion: str
    lexicalResource: str
    grammaticalRange: str

class WritingResponse(BaseModel):
    overallScore: float
    imageDescription: Optional[str] = None
    criteriaScores: WritingScores
    shortFeedback: ShortFeedbackWriting
    detailedFeedback: str

class SpeakingScores(BaseModel):
    fluencyCoherence: float
    lexicalResource: float
    grammaticalRange: float
    pronunciation: float

class PronunciationWord(BaseModel):
    word: str
    score: int
    phonemes_expected: str
    phonemes_actual: str
    is_correct: bool
    error_type: Optional[str] = None

class SpeakingResponse(BaseModel):
    overallScore: float
    transcript: str
    refinedTranscript: str
    betterVersion: str
    criteriaScores: SpeakingScores
    shortFeedback: dict
    detailedFeedback: str
    pronunciationBreakdown: List[PronunciationWord]

def round_to_half(score: float) -> float:
    return round(score * 2) / 2

async def analyze_chart_image(image_url: str, prompt_text: str) -> str:
    """Vision AI for Task 1"""
    if not image_url: return "No image provided."
    print("Analyzing chart image...")

    headers = { "Authorization": f"Bearer {OPENAI_API_KEY}", "Content-Type": "application/json" }
    vision_prompt = f"""
    Act as a data analyst. Describe this IELTS Writing Task 1 image in detail.
    Focus strictly on the main trends, comparisons, and specific data points mentioned in the prompt: "{prompt_text}".
    Output a factual description paragraph representing the 'Ground Truth' of the image.
    """
    payload = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": [
            {"type": "text", "text": vision_prompt},
            {"type": "image_url", "image_url": {"url": image_url}}
        ]}],
        "max_tokens": 500
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            resp = await client.post(OPENAI_API_URL, headers=headers, json=payload)
            return resp.json()['choices'][0]['message']['content']
        except Exception as e:
            print(f"Vision Error: {e}")
            return ""

async def generate_writing_feedback(prompt: str, essay: str, scores: WritingScores, task_type: int, img_desc: str = "") -> dict:
    print("Generating Writing feedback...")
    scores_dict = scores.model_dump()

    context_info = ""
    criterion_1_name = "Task Response"
    if task_type == 1:
        context_info = f"IMAGE GROUND TRUTH: {img_desc}\n(Check if the student accurately reported this data)"
        criterion_1_name = "Task Achievement"

    system_prompt = f"""
    You are a strict, expert IELTS Examiner.

    TASK INFO:
    - Type: Task {task_type}
    - Prompt: "{prompt}"
    {context_info}

    STUDENT ESSAY:
    "{essay}"

    SCORES GIVEN (0-9):
    {json.dumps(scores_dict)}

    YOUR GOAL:
    Provide a deeply analytical and educational feedback JSON.

    INSTRUCTIONS FOR 'detailedFeedback':
    The 'detailedFeedback' field MUST be a long Markdown string structured as follows:

    1. **General Overview**: A brief summary of why the essay got this band score.
    2. **Strengths & Weaknesses**: Bullet points highlighting what was done well and what was missing in each criterion (one by one, four criteria in total).
    3. **Specific Corrections (CRITICAL)**:
       - Identify 3-4 specific errors (grammar, vocab, or data accuracy).
       - For each error, show the "Original Text" -> "Correction" -> "Explanation".
       - Example: *Original: "The data shows an increase." -> Better: "The data illustrates a significant upward trend." (Explanation: Use more precise academic vocabulary).*
    4. **Actionable Advice**: Give 2-3 concrete steps the student should take to improve their score next time.

    Output JSON format:
    {{
        "shortFeedback": {{
            "{criterion_1_name}": "...",
            "Coherence and Cohesion": "...",
            "Lexical Resource": "...",
            "Grammatical Range and Accuracy": "..."
        }},
        "detailedFeedback": "MARKDOWN STRING..."
    }}
    """

    payload = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "system", "content": system_prompt}],
        "response_format": {"type": "json_object"}
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(OPENAI_API_URL, headers={"Authorization": f"Bearer {OPENAI_API_KEY}"}, json=payload)
        return json.loads(resp.json()['choices'][0]['message']['content'])

app = FastAPI(title="IELTS Full-Stack AI API (V15.0)")

@app.post("/grade/writing", response_model=WritingResponse)
async def grade_writing(request: WritingRequest):
    model = pipelines.get(f"task{request.task_type}")
    if not model: raise HTTPException(500, "Model not ready.")

    image_desc = ""
    if request.task_type == 1:
        if not request.image_url: raise HTTPException(400, "Task 1 requires image_url.")
        image_desc = await analyze_chart_image(request.image_url, request.prompt)
        final_input = f"PROMPT: {request.prompt}\n\nIMAGE CONTEXT: {image_desc} [SEP] {request.essay}"
    else:
        final_input = f"{request.prompt} [SEP] {request.essay}"

    results = model(final_input, truncation=True, max_length=512)[0]
    raw = {item['label']: item['score'] for item in results}

    def r(x): return round(x * 2) / 2

    scores = WritingScores(
        taskResponse=r(raw.get('LABEL_0', 1.0)),
        coherenceCohesion=r(raw.get('LABEL_1', 1.0)),
        lexicalResource=r(raw.get('LABEL_2', 1.0)),
        grammaticalRange=r(raw.get('LABEL_3', 1.0))
    )
    overall = r((scores.taskResponse + scores.coherenceCohesion +
                 scores.lexicalResource + scores.grammaticalRange) / 4)

    # Feedback
    fb = await generate_writing_feedback(request.prompt, request.essay, scores, request.task_type, image_desc)
    sf = fb.get("shortFeedback", {})

    tr_fb = sf.get("Task Response") or sf.get("Task Achievement") or "No feedback"

    return WritingResponse(
        overallScore=overall,
        imageDescription=image_desc if request.task_type == 1 else None,
        criteriaScores=scores,
        shortFeedback=ShortFeedbackWriting(
            taskResponse=tr_fb,
            coherenceCohesion=sf.get("Coherence and Cohesion", ""),
            lexicalResource=sf.get("Lexical Resource", ""),
            grammaticalRange=sf.get("Grammatical Range and Accuracy", "")
        ),
        detailedFeedback=fb.get("detailedFeedback", "")
    )

async def grade_speaking_with_gpt(transcript: str, metrics: dict, ipa_data: dict, prompt_text: str) -> dict:
    """
    Generate Speaking feedback with Pronunciation Breakdown array.
    """
    print("Generating Speaking feedback...")

    system_prompt = f"""
    You are an expert IELTS Speaking Examiner and Phonetician.

    INPUT DATA:
    - Question: "{prompt_text}"
    - Transcript (Whisper): "{transcript}"
    - Raw Audio IPA (Actual): /{ipa_data.get('actual_ipa', '')}/
    - Expected IPA (Standard): /{ipa_data.get('expected_ipa', '')}/

    METRICS:
    - Speed: {metrics['wpm']:.1f} WPM
    - Pauses: {metrics['pause_ratio']*100:.1f}%

    YOUR TASK:
    1. Score the 4 criteria (0-9).
    2. **Pronunciation Breakdown**: Map words from Transcript to the IPA. Identify mispronounced words.
       - Compare Actual vs Expected IPA for each word.
       - Assign a score (1-10) for each word's pronunciation.
       - Flag errors (e.g., 'severe_substitution' if user said 'trip' but meant 'subject').

    OUTPUT JSON FORMAT (This is sample structure, replace with actual data):
    {{
        "scores": {{ "fluencyCoherence": 0.0, "lexicalResource": 0.0, "grammaticalRange": 0.0, "pronunciation": 0.0 }},
        "shortFeedback": {{ "Fluency": "...", "Vocabulary": "...", "Grammar": "...", "Pronunciation": "..." }},
        "detailedFeedback": "MARKDOWN string...",
        "refinedTranscript": "Corrected version...",
        "betterVersion": "Upgraded Band 8 version...",
        "pronunciationBreakdown": [
            {{
                "word": "subject",
                "score": 3,
                "phonemes_expected": "s ʌ b dʒ ɛ k t",
                "phonemes_actual": "t r ɪ p",
                "is_correct": false,
                "error_type": "severe_substitution"
            }},
            ... (more words)
        ]
    }}
    """

    payload = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "system", "content": system_prompt}],
        "response_format": {"type": "json_object"}
    }

    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(OPENAI_API_URL, headers={"Authorization": f"Bearer {OPENAI_API_KEY}"}, json=payload)
        return json.loads(resp.json()['choices'][0]['message']['content'])

@app.post("/grade/speaking", response_model=SpeakingResponse)
async def grade_speaking(audio: UploadFile = File(...), prompt: str = Form(...)):
    temp_filename = f"temp_{uuid.uuid4()}.wav"
    try:
        with open(temp_filename, "wb") as buffer:
            shutil.copyfileobj(audio.file, buffer)

        # 1. Whisper & Acoustic Metrics
        if not whisper_model: raise HTTPException(500, "Whisper missing")
        res = whisper_model.transcribe(temp_filename)
        transcript = res["text"].strip()

        y, sr = librosa.load(temp_filename)
        duration = librosa.get_duration(y=y, sr=sr)
        word_count = len(transcript.split())
        wpm = (word_count / duration) * 60 if duration > 0 else 0
        non_silent = librosa.effects.split(y, top_db=20)
        silent_time = duration - sum([(e-s)/sr for s,e in non_silent])
        pause_ratio = silent_time / duration if duration > 0 else 0

        metrics = {"wpm": wpm, "pause_ratio": pause_ratio}

        # 2. IPA Analysis (Subprocess based)
        ipa_data = grade_pronunciation_advanced(temp_filename, transcript)

        # 3. GPT Analysis
        gpt_result = await grade_speaking_with_gpt(transcript, metrics, ipa_data, prompt)
        scores = gpt_result.get("scores", {})

        # 4. Response
        criteria = SpeakingScores(
            fluencyCoherence=round_to_half(scores.get("fluencyCoherence", 0)),
            lexicalResource=round_to_half(scores.get("lexicalResource", 0)),
            grammaticalRange=round_to_half(scores.get("grammaticalRange", 0)),
            pronunciation=round_to_half(scores.get("pronunciation", 0))
        )
        overall = round_to_half((criteria.fluencyCoherence + criteria.lexicalResource +
                                 criteria.grammaticalRange + criteria.pronunciation) / 4)

        return SpeakingResponse(
            overallScore=overall,
            transcript=transcript,
            refinedTranscript=gpt_result.get("refinedTranscript", ""),
            betterVersion=gpt_result.get("betterVersion", ""),
            criteriaScores=criteria,
            shortFeedback=gpt_result.get("shortFeedback", {}),
            detailedFeedback=gpt_result.get("detailedFeedback", ""),
            pronunciationBreakdown=gpt_result.get("pronunciationBreakdown", [])
        )

    except Exception as e:
        print(f"Speaking Error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(500, str(e))
    finally:
        if os.path.exists(temp_filename): os.remove(temp_filename)

@app.get("/")
def read_root():
    return {"message": "IELTS API is running."}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
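For reference, a hedged sketch of how the speaking endpoint above could be called: it sends the audio file as multipart form data together with the `prompt` form field, matching the `UploadFile`/`Form` parameters of `grade_speaking`. The base URL, file path, and question are placeholders, not part of the commit.

```python
# Sketch of a client for POST /grade/speaking (placeholder URL, file path, and question).
import httpx

audio_path = "answer.wav"  # hypothetical recording of the candidate's answer

with open(audio_path, "rb") as f:
    files = {"audio": ("answer.wav", f, "audio/wav")}
    data = {"prompt": "Describe a book you recently read."}
    resp = httpx.post("http://localhost:7860/grade/speaking",
                      files=files, data=data, timeout=300.0)

resp.raise_for_status()
report = resp.json()
print(report["transcript"])
print(report["criteriaScores"], report["overallScore"])
```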
src/clean_external_data.py
ADDED
@@ -0,0 +1,506 @@
import json
import re
import os
from datasets import load_dataset
from tqdm import tqdm

# Regex to capture a score (e.g. 7, 7.5, or 6.0)
FLOAT_RE = r"(\d+(?:\.\d+)?)"

def to_float_safe(x):
    """Safely convert to float; return None on error."""
    try:
        val = float(x)
        # Check the score is in the valid 0-9 range
        if 0 <= val <= 9:
            return val
        return None
    except Exception:
        return None


def parse_chillies_dataset(dataset):
    """
    Parser for 'chillies/IELTS-writing-task-2-evaluation'.
    Format: **Task Achievement: [7]** or **Overall Band Score: [7.5]**
    """
    print("Processing dataset 'chillies'...")
    cleaned = []
    bad_examples = 0

    patterns = {
        "task_response": re.compile(
            r"\*\*Task Achievement:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
            re.I
        ),
        "coherence_cohesion": re.compile(
            r"\*\*Coherence and Cohesion:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
            re.I
        ),
        "lexical_resource": re.compile(
            r"\*\*Lexical Resource:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
            re.I
        ),
        "grammatical_range": re.compile(
            r"\*\*Grammatical Range and Accuracy:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
            re.I
        ),
    }

    for item in tqdm(dataset, desc="Parsing chillies"):
        try:
            prompt = item.get('prompt', '').strip()
            essay = item.get('essay', '').strip()
            evaluation_text = item.get('evaluation', '')

            if not (prompt and essay and evaluation_text and len(essay) > 50):
                bad_examples += 1
                continue

            scores = {}
            for key, pattern in patterns.items():
                match = pattern.search(evaluation_text)
                if match:
                    score_str = match.group(1)
                    scores[key] = to_float_safe(score_str)
                else:
                    scores[key] = None

            if all(scores.values()):
                standard_scores = {
                    "task_response": scores["task_response"],
                    "coherence_cohesion": scores["coherence_cohesion"],
                    "lexical_resource": scores["lexical_resource"],
                    "grammatical_range": scores["grammatical_range"]
                }
                cleaned.append({
                    "prompt_text": prompt,
                    "essay_text": essay,
                    "scores": standard_scores
                })
            else:
                bad_examples += 1
        except Exception:
            bad_examples += 1

    print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
    return cleaned


def parse_123harr_dataset(dataset):
    """
    Parser for '123Harr/IELTS-WT2-LLaMa3-1k'.
    Extracts scores from the 'formatted' field.
    """
    print("Processing dataset '123Harr'...")
    cleaned = []
    bad_examples = 0

    prompt_essay_re = re.compile(
        r"<\|start_header_id\|>user<\|end_header_id\|>\n\n(.*?)<\|eot_id\|>",
        re.S
    )

    score_patterns = {
        "task_response": re.compile(
            r"(?:###|##|\*\*)?Task Achievement(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
            re.I | re.M
        ),
        "coherence_cohesion": re.compile(
            r"(?:###|##|\*\*)?Coherence and Cohesion(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
            re.I | re.M
        ),
        "lexical_resource": re.compile(
            r"(?:###|##|\*\*)?Lexical Resource(?:\s*\(Vocabulary\))?(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
            re.I | re.M
        ),
        "grammatical_range": re.compile(
            r"(?:###|##|\*\*)?Grammatical Range and Accuracy(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
            re.I | re.M
        ),
    }

    for item in tqdm(dataset, desc="Parsing 123Harr"):
        try:
            formatted_text = item.get('formatted', '')

            if not formatted_text:
                bad_examples += 1
                continue

            matches = prompt_essay_re.findall(formatted_text)

            if len(matches) < 2:
                bad_examples += 1
                continue

            prompt = matches[0].strip()
            essay = matches[1].strip()

            if not prompt or not essay or len(essay) < 50:
                bad_examples += 1
                continue

            scores = {}
            for key, pattern in score_patterns.items():
                match = pattern.search(formatted_text)
                if match:
                    score_str = match.group(match.lastindex) if match.lastindex else match.group(1)
                    scores[key] = to_float_safe(score_str)
                else:
                    scores[key] = None

            if all(scores.values()):
                standard_scores = {
                    "task_response": scores["task_response"],
                    "coherence_cohesion": scores["coherence_cohesion"],
                    "lexical_resource": scores["lexical_resource"],
                    "grammatical_range": scores["grammatical_range"]
                }
                cleaned.append({
                    "prompt_text": prompt,
                    "essay_text": essay,
                    "scores": standard_scores
                })
            else:
                bad_examples += 1
        except Exception:
            bad_examples += 1

    print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
    return cleaned


def parse_dpo_dataset(dataset):
    """
    Parser for 'chillies/DPO_ielts_writing'.
    """
    print("Processing dataset 'DPO'...")
    cleaned = []
    bad_examples = 0

    patterns_primary = {
        "task_response": re.compile(
            r"##\s*Task Achievement:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
            re.I
        ),
        "coherence_cohesion": re.compile(
            r"##\s*Coherence and Cohesion:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
            re.I
        ),
        "lexical_resource": re.compile(
            r"##\s*Lexical Resource(?:\s*\(Vocabulary\))?:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
            re.I
        ),
        "grammatical_range": re.compile(
            r"##\s*Grammatical Range and Accuracy:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
            re.I
        ),
    }

    patterns_fallback = {
        "task_response": re.compile(r"(?:\*\*)?Task Achievement(?:\*\*)?:\s*" + FLOAT_RE, re.I),
        "coherence_cohesion": re.compile(r"(?:\*\*)?Coherence and Cohesion(?:\*\*)?:\s*" + FLOAT_RE, re.I),
        "lexical_resource": re.compile(r"(?:\*\*)?Lexical Resource(?:\s*\(Vocabulary\))?(?:\*\*)?:\s*" + FLOAT_RE, re.I),
        "grammatical_range": re.compile(r"(?:\*\*)?Grammatical Range and Accuracy(?:\*\*)?:\s*" + FLOAT_RE, re.I),
    }

    for item in tqdm(dataset, desc="Parsing DPO"):
        try:
            prompt = item.get('prompt', '').strip()
            essay = item.get('essay', '').strip()
            chosen_text = item.get('chosen', '')

            if not (prompt and essay and chosen_text and len(essay) > 50):
                bad_examples += 1
                continue

            scores = {}

            for key, pattern in patterns_primary.items():
                match = pattern.search(chosen_text)
                if match:
                    scores[key] = to_float_safe(match.group(1))
                else:
                    scores[key] = None

            if not all(scores.values()):
                scores = {}
                for key, pattern in patterns_fallback.items():
                    match = pattern.search(chosen_text)
                    if match:
                        scores[key] = to_float_safe(match.group(1))
                    else:
                        scores[key] = None

            if all(scores.values()):
                standard_scores = {
                    "task_response": scores["task_response"],
                    "coherence_cohesion": scores["coherence_cohesion"],
                    "lexical_resource": scores["lexical_resource"],
                    "grammatical_range": scores["grammatical_range"]
                }
                cleaned.append({
                    "prompt_text": prompt,
                    "essay_text": essay,
                    "scores": standard_scores
                })
            else:
                bad_examples += 1
        except Exception:
            bad_examples += 1

    print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
    return cleaned


def parse_hadeel_dataset(dataset):
    """
    Parser for 'hadeelbkh/tokenized-IELTS-writing-task-2-evaluation'.
    """
    print("Processing dataset 'hadeel'...")
    cleaned = []
    bad_examples = 0

    patterns = {
        "task_response": re.compile(
            r"(?:\*\*)?task achievement(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
            re.I
        ),
        "coherence_cohesion": re.compile(
            r"(?:\*\*)?coherence and cohesion(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
            re.I
        ),
        "lexical_resource": re.compile(
            r"(?:\*\*)?lexical resource(?:\s*\(vocabulary\))?(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
            re.I
        ),
        "grammatical_range": re.compile(
            r"(?:\*\*)?grammatical range and accuracy(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
            re.I
        ),
    }

    for item in tqdm(dataset, desc="Parsing hadeel"):
        try:
            prompt = item.get('prompt', '').strip()
            essay = item.get('essay', '').strip()
            evaluation_text = item.get('evaluation', '')

            if not (prompt and essay and evaluation_text and len(essay) > 50):
                bad_examples += 1
                continue

            scores = {}
            for key, pattern in patterns.items():
                match = pattern.search(evaluation_text)
                if match:
                    score_str = match.group(1)
                    scores[key] = to_float_safe(score_str)
                else:
                    scores[key] = None

            if all(scores.values()):
                standard_scores = {
                    "task_response": scores["task_response"],
                    "coherence_cohesion": scores["coherence_cohesion"],
                    "lexical_resource": scores["lexical_resource"],
                    "grammatical_range": scores["grammatical_range"]
                }
                cleaned.append({
                    "prompt_text": prompt,
                    "essay_text": essay,
                    "scores": standard_scores
                })
            else:
                bad_examples += 1
        except Exception:
            bad_examples += 1

    print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
    return cleaned


def parse_vietanh_dataset(dataset):
    """
    Parser for 'vietanh0802/ielts_writing_training_data_prepared'.
    Format: <s>[INST] ... ### Prompt: ... ### Essay: ... [/INST] ...
    """
    print("Processing dataset 'vietanh'...")
    cleaned = []
    bad_examples = 0

    prompt_re = re.compile(r"### Prompt:\s*(.*?)(?=### Essay:|$)", re.S | re.I)
    essay_re = re.compile(r"### Essay:\s*(.*?)(?=\[/INST\]|$)", re.S | re.I)

    score_patterns = {
        "task_response": re.compile(
            r"(?:\*\*)?Task Achievement(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
            re.I
        ),
        "coherence_cohesion": re.compile(
            r"(?:\*\*)?Coherence and Cohesion(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
            re.I
        ),
        "lexical_resource": re.compile(
            r"(?:\*\*)?Lexical Resource(?:\s*\(Vocabulary\))?(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
            re.I
        ),
        "grammatical_range": re.compile(
            r"(?:\*\*)?Grammatical Range and Accuracy(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
            re.I
        ),
    }

    for item in tqdm(dataset, desc="Parsing vietanh"):
        try:
            training_text = item.get('training_text', '')

            if not training_text:
                bad_examples += 1
                continue

            prompt_match = prompt_re.search(training_text)
            if not prompt_match:
                bad_examples += 1
                continue
            prompt = prompt_match.group(1).strip()

            essay_match = essay_re.search(training_text)
            if not essay_match:
                bad_examples += 1
                continue
            essay = essay_match.group(1).strip()

            if not prompt or not essay or len(essay) < 50:
                bad_examples += 1
                continue

            scores = {}
            for key, pattern in score_patterns.items():
                match = pattern.search(training_text)
                if match:
                    scores[key] = to_float_safe(match.group(1))
                else:
                    scores[key] = None

            if all(scores.values()):
                standard_scores = {
                    "task_response": scores["task_response"],
                    "coherence_cohesion": scores["coherence_cohesion"],
                    "lexical_resource": scores["lexical_resource"],
                    "grammatical_range": scores["grammatical_range"]
                }
                cleaned.append({
                    "prompt_text": prompt,
                    "essay_text": essay,
                    "scores": standard_scores
                })
            else:
                bad_examples += 1
        except Exception:
            bad_examples += 1

    print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
    return cleaned


def main():
    print("Loading datasets from Hugging Face...\n")
    cache_dir = "./.cache/huggingface_datasets"

    all_data = []

    # Dataset 1: chillies/IELTS-writing-task-2-evaluation
    try:
        ds_chillies = load_dataset(
            "chillies/IELTS-writing-task-2-evaluation",
            split="train",
            cache_dir=cache_dir
        )
        all_data.append(("chillies", parse_chillies_dataset(ds_chillies)))
    except Exception as e:
        print(f"✗ Error loading chillies: {e}\n")

    # Dataset 2: 123Harr/IELTS-WT2-LLaMa3-1k
    try:
        ds_123harr = load_dataset(
            "123Harr/IELTS-WT2-LLaMa3-1k",
            split="train",
            cache_dir=cache_dir
        )
        all_data.append(("123Harr", parse_123harr_dataset(ds_123harr)))
    except Exception as e:
        print(f"✗ Error loading 123Harr: {e}\n")

    # Dataset 3: chillies/DPO_ielts_writing
    try:
        ds_chillies_2 = load_dataset(
            "chillies/DPO_ielts_writing",
            split="train",
            cache_dir=cache_dir
        )
        all_data.append(("DPO", parse_dpo_dataset(ds_chillies_2)))
    except Exception as e:
        print(f"✗ Error loading DPO: {e}\n")

    # Dataset 4: hadeelbkh/tokenized-IELTS-writing-task-2-evaluation
    try:
        ds_hadeel = load_dataset(
            "hadeelbkh/tokenized-IELTS-writing-task-2-evaluation-DialoGPT-medium",
            split="train",
            cache_dir=cache_dir
        )
        all_data.append(("hadeel", parse_hadeel_dataset(ds_hadeel)))
    except Exception as e:
        print(f"✗ Error loading hadeel: {e}\n")

    # Dataset 5: vietanh0802/ielts_writing_training_data_prepared
    try:
        ds_vietanh = load_dataset(
            "vietanh0802/ielts_writing_training_data_prepared",
            split="train",
            cache_dir=cache_dir
        )
        all_data.append(("vietanh", parse_vietanh_dataset(ds_vietanh)))
    except Exception as e:
        print(f"✗ Error loading vietanh: {e}\n")

    # Totals
    print("\n" + "="*60)
    print("--- SUMMARY ---")
    print("="*60)
    total = 0
    for name, data in all_data:
        count = len(data)
        total += count
        print(f"Dataset ({name:15}): {count:5d} samples")

    print("="*60)
    print(f"Total valid samples: {total}")
    print("="*60)

    final_dataset = []
    for name, data in all_data:
        final_dataset.extend(data)

    if not final_dataset:
        print("✗ Error: no data was normalized. Please check the script.")
        return

    output_dir = "data"
    output_path = os.path.join(output_dir, "dataset_for_scorer.json")

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"✓ Created directory {output_dir}")

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_dataset, f, ensure_ascii=False, indent=2)

    print(f"✓ Wrote {len(final_dataset)} samples to '{output_path}'.")
    print("\n✓ Done! You can now run 'src/train.py' on Colab!")


if __name__ == "__main__":
    main()
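To make the score-extraction approach above concrete, here is a tiny self-contained check of the 'chillies'-style pattern on an invented evaluation string; the sample text is made up purely to show the regex and the capture group at work.

```python
# Quick sanity check of the score-extraction regex on an invented evaluation string.
import re

FLOAT_RE = r"(\d+(?:\.\d+)?)"
sample_eval = "**Task Achievement: [7]** ... **Coherence and Cohesion: [6.5]**"  # invented

pattern = re.compile(r"\*\*Task Achievement:\s*\[?(" + FLOAT_RE + r")\]?\*\*", re.I)
match = pattern.search(sample_eval)
print(float(match.group(1)))  # prints 7.0
```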
src/clean_external_data_task1.py
ADDED
@@ -0,0 +1,94 @@
# File: src/clean_external_data_task1.py
import json
import os
from datasets import load_dataset
from tqdm import tqdm

def to_float_safe(x):
    """Convert a string to float (handles both '9' and '9.0')."""
    try:
        if x is None: return None
        val = float(x)
        if 0 <= val <= 9: return val
        return None
    except ValueError:
        return None

def parse_hai2131_dataset(dataset):
    """
    Standard parser for 'hai2131/IELTS-essays-task-1'.
    Combines: Prompt + Image Description + Essay
    """
    print("Processing dataset 'hai2131'...")
    cleaned = []
    bad_examples = 0

    for item in tqdm(dataset, desc="Parsing hai2131"):
        try:
            # 1. Read the input fields
            prompt = item.get("subject") or ""
            # IMPORTANT: take the image description
            img_desc = item.get("image_description") or ""
            essay = item.get("content") or ""

            # 2. Build the combined input text (Prompt + Context + Essay)
            # The model reads this whole string
            full_prompt_text = f"PROMPT: {prompt}\n\nIMAGE CONTEXT: {img_desc}"

            # 3. Read the scores (this dataset stores them as strings like "9")
            scores = {
                "task_response": to_float_safe(item.get("task_response_score")),
                "coherence_cohesion": to_float_safe(item.get("coherence_cohesion_score")),
                "lexical_resource": to_float_safe(item.get("lexical_resource_score")),
                "grammatical_range": to_float_safe(item.get("grammatical_range_accuracy_score"))
            }

            # 4. Validate
            if essay and all(scores.values()):
                cleaned.append({
                    "prompt_text": full_prompt_text,  # Special combined input for Task 1
                    "essay_text": essay,
                    "scores": scores
                })
            else:
                bad_examples += 1

        except Exception:
            bad_examples += 1

    print(f"hai2131: kept {len(cleaned)} samples, skipped {bad_examples}")
    return cleaned

def main():
    print("🚀 STARTING TASK 1 DATASET PROCESSING (hai2131)")

    # Cache dir to avoid re-downloading
    cache_dir = "./.cache/huggingface_datasets_task1"

    try:
        # Load the dataset (fast this time, it is small)
        dataset = load_dataset("hai2131/IELTS-essays-task-1", split="train", cache_dir=cache_dir)

        # Process
        final_dataset = parse_hai2131_dataset(dataset)

        if not final_dataset:
            print("ERROR: no data.")
            return

        # Save the file
        output_dir = "data"
        if not os.path.exists(output_dir): os.makedirs(output_dir)
        output_path = os.path.join(output_dir, "dataset_for_scorer_task1.json")

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(final_dataset, f, ensure_ascii=False, indent=2)

        print(f"\n✅ DONE! Saved {len(final_dataset)} samples to '{output_path}'.")
        print("💡 Note: 'prompt_text' now contains both the prompt AND the image description.")

    except Exception as e:
        print(f"❌ Error: {e}")

if __name__ == "__main__":
    main()
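For clarity, a sketch of what one normalized Task 1 record written by the script above might look like; all field values here are invented examples, only the field names and the combined `prompt_text` layout come from the parser.

```python
# Invented example of one record in data/dataset_for_scorer_task1.json.
record = {
    "prompt_text": "PROMPT: The chart shows car ownership...\n\nIMAGE CONTEXT: A bar chart comparing...",
    "essay_text": "The bar chart illustrates...",
    "scores": {
        "task_response": 7.0,
        "coherence_cohesion": 6.5,
        "lexical_resource": 7.0,
        "grammatical_range": 6.5,
    },
}
print(record["prompt_text"].split("\n\n")[1])  # the IMAGE CONTEXT part of the input
```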
src/explore.py
ADDED
@@ -0,0 +1,90 @@
# File: src/explore.py (FIXED - ONLY RUNS hai2131)
import json
import os
import sys
from datasets import load_dataset
from itertools import islice
import traceback

DATASET_LIST = {
    # "trantac": "TraTacXiMuoi/Ielts_writing_task1_academic",  # temporarily disabled due to network errors
    "hai2131": "hai2131/IELTS-essays-task-1"
}

NUM_SAMPLES_TO_VIEW = 2
SPLIT_NAME = "train"

def safe_value_to_string(value):
    """Convert a value into a safely printable form."""
    if value is None:
        return None
    if isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, dict):
        return value
    if isinstance(value, list):
        return value
    # For other object types (images, audio, etc.)
    return f"<{type(value).__name__}>"

def explore_dataset(name: str, path: str, split: str, n: int):
    """
    Load the first N samples of a Hugging Face dataset and print its structure.
    """
    print("="*80)
    print(f"🕵️ Exploring dataset: {name}")
    print(f"   Path: {path}")
    print(f"   Split: {split}")
    print("="*80)

    try:
        # Load the first N samples (no streaming needed:
        # the hai2131 dataset is only ~8MB, so a direct download is faster)
        dataset = load_dataset(path, split=f"{split}[:{n}]")

        print(f"\n✅ Loaded successfully. Structure (Features):")
        # Print the columns and their data types
        print(dataset.features)

        print(f"\n--- Viewing the first {n} samples ---")

        for i, item in enumerate(dataset):
            print(f"\n--- Sample {i+1} ---")

            printable_item = {}
            for key, value in item.items():
                printable_item[key] = safe_value_to_string(value)

            print(json.dumps(printable_item, ensure_ascii=False, indent=2))

    except Exception as e:
        print(f"\n❌ ERROR while loading or reading dataset '{name}':")
        print(f"   {e}")
        traceback.print_exc()

def list_available_splits(path):
    # No longer needed since we load the split directly
    pass

def main():
    print("🚀 STARTING IELTS DATASET EXPLORATION (hai2131 only)")
    print("="*80)

    for name, path in DATASET_LIST.items():
        try:
            explore_dataset(name, path, SPLIT_NAME, NUM_SAMPLES_TO_VIEW)
        except KeyboardInterrupt:
            print("\n⚠️ Interrupted by user")
            break
        except Exception as e:
            print(f"\n❌ Unexpected error: {e}")
            traceback.print_exc()

        print("\n" + "-"*80)

    print("\n" + "="*80)
    print("✅ EXPLORATION COMPLETE")
    print("="*80)

if __name__ == "__main__":
    main()
src/explore_speaking.py
ADDED
@@ -0,0 +1,71 @@
import whisper
import librosa
import numpy as np
import os
import warnings

warnings.filterwarnings("ignore")

def analyze_speaking_audio(audio_path):
    print(f"🎤 Analyzing file: {audio_path}")

    # --- 1. Load the Whisper model (ASR) ---
    print("⏳ Loading Whisper model (may take a while the first time)...")
    model = whisper.load_model("base")

    # --- 2. Transcribe (speech to text) ---
    print("📝 Transcribing speech...")
    result = model.transcribe(audio_path, fp16=False)
    transcript = result["text"].strip()

    print("\n" + "="*40)
    print("TRANSCRIPT:")
    print(f"'{transcript}'")
    print("="*40 + "\n")

    # --- 3. Fluency analysis ---
    # Use librosa to analyse the audio signal
    y, sr = librosa.load(audio_path)
    duration = librosa.get_duration(y=y, sr=sr)

    # Word count
    word_count = len(transcript.split())

    # Speaking rate (Words Per Minute - WPM)
    wpm = (word_count / duration) * 60

    # Pause detection
    # top_db: decibel threshold below which audio is treated as silence
    non_silent_intervals = librosa.effects.split(y, top_db=20)
    silent_duration = duration - sum([ (end-start)/sr for start, end in non_silent_intervals ])
    pause_ratio = silent_duration / duration

    print("ACOUSTIC METRICS:")
    print(f"- Duration: {duration:.2f} s")
    print(f"- Word Count: {word_count}")
    print(f"- Speed: {wpm:.2f} WPM (IELTS 6.0+ is usually > 100 WPM)")
    print(f"- Silence: {silent_duration:.2f} s ({pause_ratio*100:.1f}%)")

    # --- 4. Rough estimate ---
    fluency_score_est = 0
    if wpm > 120: fluency_score_est = 7.0
    elif wpm > 100: fluency_score_est = 6.0
    elif wpm > 80: fluency_score_est = 5.0
    else: fluency_score_est = 4.0

    print(f"\n💡 Rough Fluency estimate: ~{fluency_score_est}")

    return {
        "transcript": transcript,
        "wpm": wpm,
        "pause_ratio": pause_ratio
    }

if __name__ == "__main__":
    sample_audio = "data/test_speaking.m4a"

    if os.path.exists(sample_audio):
        analyze_speaking_audio(sample_audio)
    else:
        print(f"File '{sample_audio}' not found.")
        print("Record a short English audio clip, save it there, and run again.")
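As a quick illustration of the fluency metrics computed above, here is the same arithmetic worked with made-up numbers (90 words over a 50-second clip with 8 seconds of silence); the numbers are invented, only the formulas come from the script.

```python
# Worked example of the WPM and pause-ratio formulas with invented numbers.
word_count = 90
duration = 50.0          # seconds of audio
silent_duration = 8.0    # seconds below the top_db threshold

wpm = (word_count / duration) * 60          # 108.0 words per minute
pause_ratio = silent_duration / duration    # 0.16 -> 16% of the clip is silence
print(f"{wpm:.1f} WPM, {pause_ratio*100:.1f}% pauses")
```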
src/pronunciation.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import librosa
import numpy as np
import os
import traceback
import subprocess
import shutil

from transformers import (
    Wav2Vec2ForCTC,
    AutoTokenizer,
    Wav2Vec2FeatureExtractor
)

print("Loading Pronunciation module...")

MODEL_ID = "facebook/wav2vec2-lv-60-espeak-cv-ft"
model = None
tokenizer = None
feature_extractor = None

def find_espeak_exe():
    # Look for espeak-ng on PATH first, then in common Windows install locations.
    candidates = [
        r"C:\Program Files\eSpeak NG\espeak-ng.exe",
        r"C:\Program Files (x86)\eSpeak NG\espeak-ng.exe",
        r"D:\Program Files\eSpeak NG\espeak-ng.exe"
    ]
    path_in_env = shutil.which("espeak-ng")
    if path_in_env: return path_in_env

    for path in candidates:
        if os.path.exists(path):
            return path
    return None

ESPEAK_PATH = find_espeak_exe()
if ESPEAK_PATH:
    print(f"Found eSpeak at: {ESPEAK_PATH}")
else:
    print("WARNING: eSpeak-ng not found. IPA generation will fail.")

try:
    print("Loading Feature Extractor...")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)

    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    print("Loading Acoustic Model...")
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

    print("Pronunciation module ready.")
except Exception as e:
    print(f"Failed to load AI model: {e}")

def get_expected_ipa(text):
    """Call espeak-ng via subprocess to get the reference IPA for a piece of text."""
    if not ESPEAK_PATH:
        return "N/A"

    try:
        cmd = [ESPEAK_PATH, "-v", "en-us", "-q", "--ipa", text]

        # On Windows, suppress the console window of the child process.
        startupinfo = None
        if os.name == 'nt':
            startupinfo = subprocess.STARTUPINFO()
            startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            encoding='utf-8',
            startupinfo=startupinfo
        )

        if result.returncode == 0:
            return result.stdout.strip().replace('\n', ' ')
        else:
            return "N/A"

    except Exception as e:
        print(f"Subprocess error: {e}")
        return "N/A"

def grade_pronunciation_advanced(audio_path, reference_text):
    """
    Return the IPA actually spoken (decoded from the audio) and the reference IPA (from the text).
    """
    actual_ipa = "N/A"
    if model and tokenizer and feature_extractor:
        try:
            y, sr = librosa.load(audio_path, sr=16000)

            input_values = feature_extractor(y, sampling_rate=16000, return_tensors="pt").input_values
            with torch.no_grad():
                logits = model(input_values).logits

            predicted_ids = torch.argmax(logits, dim=-1)
            actual_ipa = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        except Exception as e:
            print(f"AI IPA Error: {e}")
            actual_ipa = "Error"

    expected_ipa = get_expected_ipa(reference_text)

    return {
        "actual_ipa": actual_ipa,
        "expected_ipa": expected_ipa
    }
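A minimal usage sketch for the module above; the file path and sentence are illustrative, and the import assumes src/ is importable as a package.

# Hypothetical usage sketch (not part of the committed file).
from src.pronunciation import grade_pronunciation_advanced

ipa = grade_pronunciation_advanced("data/test_speaking.m4a", "I live in a big city")
print("Spoken IPA:   ", ipa["actual_ipa"])    # phoneme sequence decoded from the audio by wav2vec2
print("Expected IPA: ", ipa["expected_ipa"])  # reference phonemes generated by espeak-ng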
src/train.py
ADDED
@@ -0,0 +1,151 @@
import json
import numpy as np
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EvalPrediction
)
from sklearn.metrics import mean_squared_error, mean_absolute_error
from huggingface_hub import HfFolder, notebook_login

MODEL_NAME = "roberta-base"
DATASET_PATH = "/content/data/dataset_for_scorer.json"
MODEL_OUTPUT_DIR = "./ielts_grader_model"
HUB_MODEL_ID = "diminch/ielts-grader-ai"

def load_and_prepare_data(dataset_path):
    print(f"Loading data from {dataset_path}...")
    with open(dataset_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    processed_data = []
    for item in raw_data:
        text = item['prompt_text'] + " [SEP] " + item['essay_text']

        labels = [
            float(item['scores']['task_response']),
            float(item['scores']['coherence_cohesion']),
            float(item['scores']['lexical_resource']),
            float(item['scores']['grammatical_range'])
        ]

        processed_data.append({"text": text, "label": labels})

    print(f"Total: {len(processed_data)} samples.")

    dataset = Dataset.from_list(processed_data)

    train_test_split = dataset.train_test_split(test_size=0.1)

    dataset_dict = DatasetDict({
        'train': train_test_split['train'],
        'test': train_test_split['test']
    })

    return dataset_dict

def tokenize_data(dataset_dict, tokenizer):
    print("Tokenizing data...")
    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=512
        )

    tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)
    return tokenized_datasets

def compute_metrics(p: EvalPrediction):
    preds = p.predictions
    labels = p.label_ids

    rmse_tr = np.sqrt(mean_squared_error(labels[:, 0], preds[:, 0]))
    rmse_cc = np.sqrt(mean_squared_error(labels[:, 1], preds[:, 1]))
    rmse_lr = np.sqrt(mean_squared_error(labels[:, 2], preds[:, 2]))
    rmse_gra = np.sqrt(mean_squared_error(labels[:, 3], preds[:, 3]))

    mae_tr = mean_absolute_error(labels[:, 0], preds[:, 0])
    mae_cc = mean_absolute_error(labels[:, 1], preds[:, 1])
    mae_lr = mean_absolute_error(labels[:, 2], preds[:, 2])
    mae_gra = mean_absolute_error(labels[:, 3], preds[:, 3])

    avg_rmse = np.mean([rmse_tr, rmse_cc, rmse_lr, rmse_gra])

    return {
        "avg_rmse": avg_rmse,
        "rmse_task_response": rmse_tr,
        "rmse_coherence_cohesion": rmse_cc,
        "rmse_lexical_resource": rmse_lr,
        "rmse_grammatical_range": rmse_gra,
        "mae_task_response": mae_tr,
        "mae_coherence_cohesion": mae_cc,
        # ... the remaining MAE metrics (mae_lr, mae_gra) could be added here
    }

def main():
    print("Please paste your Hugging Face token (with 'write' permission):")
    # (On Colab, this shows an input box)
    # notebook_login()
    # Or, when running locally, run 'huggingface-cli login' first

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    dataset_dict = load_and_prepare_data(DATASET_PATH)
    tokenized_datasets = tokenize_data(dataset_dict, tokenizer)

    print("Loading base model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=4,
        problem_type="regression"
    )

    training_args = TrainingArguments(
        output_dir=MODEL_OUTPUT_DIR,
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
        eval_strategy="epoch",  # Changed evaluation_strategy to eval_strategy
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="avg_rmse",
        greater_is_better=False,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="end",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )

    print("--- STARTING TRAINING ---")
    trainer.train()
    print("--- TRAINING COMPLETE ---")

    print("--- EVALUATING ON THE TEST SET ---")
    eval_results = trainer.evaluate()
    print(json.dumps(eval_results, indent=2))

    print("Pushing the best model to the Hugging Face Hub...")
    trainer.push_to_hub()
    print(f"Done! Your model is on the Hub: https://huggingface.co/{HUB_MODEL_ID}")

if __name__ == "__main__":
    import os
    if not os.path.exists(DATASET_PATH):
        print(f"ERROR: File {DATASET_PATH} not found.")
    else:
        main()
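For clarity, load_and_prepare_data expects dataset_for_scorer.json to be a JSON array of records shaped roughly as below; the field names come from the code above, while the values are illustrative only.

# Illustrative record shape for dataset_for_scorer.json (values are made up).
sample_record = {
    "prompt_text": "Some people think that ...",
    "essay_text": "In recent years, ...",
    "scores": {
        "task_response": 6.5,
        "coherence_cohesion": 6.0,
        "lexical_resource": 6.5,
        "grammatical_range": 6.0
    }
}
# The file itself is a JSON array of such records, loaded with json.load().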