diminch commited on
Commit
d939bae
·
0 Parent(s):

Deploy V15 Clean (Removed binary files history)

Browse files
.gitignore ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File .gitignore
2
+ *.mp3
3
+ *.wav
4
+ *.m4a
5
+ data/
6
+
7
+ # Môi trường ảo
8
+ venv/
9
+
10
+ # Python cache
11
+ __pycache__/
12
+ *.pyc
13
+ .cache/
14
+
15
+ # File hệ điều hành
16
+ .DS_Store
17
+ Thumbs.db
18
+
19
+ # File IDE
20
+ .vscode/
21
+ .idea/
22
+
23
+ # File dữ liệu local và model đã huấn luyện
24
+ # (Chúng ta chỉ push code, không push data/model nặng)
25
+ *.json
26
+ *.csv
27
+ best_model/
28
+ ielts_grader_model/
29
+
30
+ # File .env (chứa API keys)
31
+ .env
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+ WORKDIR /app
3
+ RUN apt-get update && apt-get install -y \
4
+ espeak-ng \
5
+ ffmpeg \
6
+ && rm -rf /var/lib/apt/lists/*
7
+ COPY requirements.txt .
8
+ RUN pip install --no-cache-dir -r requirements.txt
9
+ COPY . .
10
+ RUN mkdir -p /app/.cache && chmod 777 /app/.cache
11
+ ENV TRANSFORMERS_CACHE=/app/.cache
12
+ EXPOSE 7860
13
+ CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: IELTS Grader API
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ app_port: 7860
9
+ ---
10
+
11
+ # IELTS Grader AI API
12
+
13
+ Đây là Backend API chấm điểm IELTS Writing Task 1 và Task 2 sử dụng AI.
14
+
15
+ ## Các tính năng chính:
16
+
17
+ - **Task 1:** Phân tích biểu đồ (Vision AI) và chấm điểm.
18
+ - **Task 2:** Chấm điểm bài luận nghị luận xã hội.
19
+ - **Feedback:** Cung cấp nhận xét chi tiết và hướng dẫn sửa lỗi.
20
+
21
+ ## Cách sử dụng
22
+
23
+ API chạy tại port 7860.
24
+ Endpoint chính: `POST /grade`
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Train
2
+ transformers[torch]
3
+ datasets
4
+ scikit-learn
5
+ huggingface_hub
6
+ tqdm
7
+ torch
8
+ pydantic
9
+
10
+ # API
11
+ fastapi
12
+ uvicorn[standard]
13
+ httpx
14
+ python-dotenv
15
+
16
+ # Library
17
+ openai
18
+ librosa
19
+ openai-whisper
20
+ numpy
21
+ python-multipart
src/api.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
3
+ from pydantic import BaseModel, Field
4
+ from transformers import pipeline
5
+ import torch
6
+ import os
7
+ import json
8
+ import httpx
9
+ import shutil
10
+ import whisper
11
+ import librosa
12
+ import numpy as np
13
+ from dotenv import load_dotenv
14
+ from typing import Optional, List
15
+ import uuid
16
+
17
+ try:
18
+ from src.pronunciation import grade_pronunciation_advanced
19
+ except ImportError:
20
+ from pronunciation import grade_pronunciation_advanced
21
+
22
+ load_dotenv()
23
+
24
+ SCORER_MODEL_ID_TASK1 = "diminch/ielts-task1-grader-ai-v2"
25
+ SCORER_MODEL_ID_TASK2 = "diminch/ielts-grader-ai-v2"
26
+
27
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
28
+ print(f"API running on: {DEVICE}")
29
+
30
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
31
+ OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
32
+
33
+ if not OPENAI_API_KEY:
34
+ print("WARNING: OPENAI_API_KEY not found in .env")
35
+
36
+ print("Loading Whisper...")
37
+ try:
38
+ whisper_model = whisper.load_model("base", device=DEVICE)
39
+ print("Whisper Loaded.")
40
+ except Exception as e:
41
+ print(f"Error loading Whisper: {e}")
42
+ whisper_model = None
43
+
44
+ pipelines = {}
45
+ def load_writing_model(task_name, model_id):
46
+ try:
47
+ print(f"Loading {task_name}: {model_id}...")
48
+ pipelines[task_name] = pipeline(
49
+ "text-classification", model=model_id, tokenizer=model_id,
50
+ device=DEVICE, return_all_scores=True
51
+ )
52
+ print(f"Loaded {task_name}.")
53
+ except Exception as e:
54
+ print(f"Error loading {task_name}: {e}")
55
+ pipelines[task_name] = None
56
+
57
+ load_writing_model("task1", SCORER_MODEL_ID_TASK1)
58
+ load_writing_model("task2", SCORER_MODEL_ID_TASK2)
59
+
60
+ class WritingRequest(BaseModel):
61
+ task_type: int
62
+ prompt: str
63
+ essay: str
64
+ image_url: Optional[str] = None
65
+
66
+ class WritingScores(BaseModel):
67
+ taskResponse: float
68
+ coherenceCohesion: float
69
+ lexicalResource: float
70
+ grammaticalRange: float
71
+
72
+ class ShortFeedbackWriting(BaseModel):
73
+ taskResponse: str
74
+ coherenceCohesion: str
75
+ lexicalResource: str
76
+ grammaticalRange: str
77
+
78
+ class WritingResponse(BaseModel):
79
+ overallScore: float
80
+ imageDescription: Optional[str] = None
81
+ criteriaScores: WritingScores
82
+ shortFeedback: ShortFeedbackWriting
83
+ detailedFeedback: str
84
+
85
+ class SpeakingScores(BaseModel):
86
+ fluencyCoherence: float
87
+ lexicalResource: float
88
+ grammaticalRange: float
89
+ pronunciation: float
90
+
91
+ class PronunciationWord(BaseModel):
92
+ word: str
93
+ score: int
94
+ phonemes_expected: str
95
+ phonemes_actual: str
96
+ is_correct: bool
97
+ error_type: Optional[str] = None
98
+
99
+ class SpeakingResponse(BaseModel):
100
+ overallScore: float
101
+ transcript: str
102
+ refinedTranscript: str
103
+ betterVersion: str
104
+ criteriaScores: SpeakingScores
105
+ shortFeedback: dict
106
+ detailedFeedback: str
107
+ pronunciationBreakdown: List[PronunciationWord]
108
+
109
+ def round_to_half(score: float) -> float:
110
+ return round(score * 2) / 2
111
+
112
+ async def analyze_chart_image(image_url: str, prompt_text: str) -> str:
113
+ """Vision AI for Task 1"""
114
+ if not image_url: return "No image provided."
115
+ print("Analyzing chart image...")
116
+
117
+ headers = { "Authorization": f"Bearer {OPENAI_API_KEY}", "Content-Type": "application/json" }
118
+ vision_prompt = f"""
119
+ Act as a data analyst. Describe this IELTS Writing Task 1 image in detail.
120
+ Focus strictly on the main trends, comparisons, and specific data points mentioned in the prompt: "{prompt_text}".
121
+ Output a factual description paragraph representing the 'Ground Truth' of the image.
122
+ """
123
+ payload = {
124
+ "model": "gpt-4o",
125
+ "messages": [{"role": "user", "content": [
126
+ {"type": "text", "text": vision_prompt},
127
+ {"type": "image_url", "image_url": {"url": image_url}}
128
+ ]}],
129
+ "max_tokens": 500
130
+ }
131
+ async with httpx.AsyncClient(timeout=60.0) as client:
132
+ try:
133
+ resp = await client.post(OPENAI_API_URL, headers=headers, json=payload)
134
+ return resp.json()['choices'][0]['message']['content']
135
+ except Exception as e:
136
+ print(f"Vision Error: {e}")
137
+ return ""
138
+
139
+ async def generate_writing_feedback(prompt: str, essay: str, scores: WritingScores, task_type: int, img_desc: str = "") -> dict:
140
+ print("Generating Writing feedback...")
141
+ scores_dict = scores.model_dump()
142
+
143
+ context_info = ""
144
+ criterion_1_name = "Task Response"
145
+ if task_type == 1:
146
+ context_info = f"IMAGE GROUND TRUTH: {img_desc}\n(Check if the student accurately reported this data)"
147
+ criterion_1_name = "Task Achievement"
148
+
149
+ system_prompt = f"""
150
+ You are a strict, expert IELTS Examiner.
151
+
152
+ TASK INFO:
153
+ - Type: Task {task_type}
154
+ - Prompt: "{prompt}"
155
+ {context_info}
156
+
157
+ STUDENT ESSAY:
158
+ "{essay}"
159
+
160
+ SCORES GIVEN (0-9):
161
+ {json.dumps(scores_dict)}
162
+
163
+ YOUR GOAL:
164
+ Provide a deeply analytical and educational feedback JSON.
165
+
166
+ INSTRUCTIONS FOR 'detailedFeedback':
167
+ The 'detailedFeedback' field MUST be a long Markdown string structured as follows:
168
+
169
+ 1. **General Overview**: A brief summary of why the essay got this band score.
170
+ 2. **Strengths & Weaknesses**: Bullet points highlighting what was done well and what was missing in each criteria (one by one, four criterias in total).
171
+ 3. **Specific Corrections (CRITICAL)**:
172
+ - Identify 3-4 specific errors (grammar, vocab, or data accuracy).
173
+ - For each error, show the "Original Text" -> "Correction" -> "Explanation".
174
+ - Example: *Original: "The data shows an increase." -> Better: "The data illustrates a significant upward trend." (Explanation: Use more precise academic vocabulary).*
175
+ 4. **Actionable Advice**: Give 2-3 concrete steps the student should take to improve their score next time.
176
+
177
+ Output JSON format:
178
+ {{
179
+ "shortFeedback": {{
180
+ "{criterion_1_name}": "...",
181
+ "Coherence and Cohesion": "...",
182
+ "Lexical Resource": "...",
183
+ "Grammatical Range and Accuracy": "..."
184
+ }},
185
+ "detailedFeedback": "MARKDOWN STRING..."
186
+ }}
187
+ """
188
+
189
+ payload = {
190
+ "model": "gpt-4o-mini",
191
+ "messages": [{"role": "system", "content": system_prompt}],
192
+ "response_format": {"type": "json_object"}
193
+ }
194
+ async with httpx.AsyncClient(timeout=60.0) as client:
195
+ resp = await client.post(OPENAI_API_URL, headers={"Authorization": f"Bearer {OPENAI_API_KEY}"}, json=payload)
196
+ return json.loads(resp.json()['choices'][0]['message']['content'])
197
+
198
+ app = FastAPI(title="IELTS Full-Stack AI API (V15.0)")
199
+
200
+ @app.post("/grade/writing", response_model=WritingResponse)
201
+ async def grade_writing(request: WritingRequest):
202
+ model = pipelines.get(f"task{request.task_type}")
203
+ if not model: raise HTTPException(500, "Model not ready.")
204
+
205
+ image_desc = ""
206
+ if request.task_type == 1:
207
+ if not request.image_url: raise HTTPException(400, "Task 1 requires image_url.")
208
+ image_desc = await analyze_chart_image(request.image_url, request.prompt)
209
+ final_input = f"PROMPT: {request.prompt}\n\nIMAGE CONTEXT: {image_desc} [SEP] {request.essay}"
210
+ else:
211
+ final_input = f"{request.prompt} [SEP] {request.essay}"
212
+
213
+ results = model(final_input, truncation=True, max_length=512)[0]
214
+ raw = {item['label']: item['score'] for item in results}
215
+
216
+ def r(x): return round(x * 2) / 2
217
+
218
+ scores = WritingScores(
219
+ taskResponse=r(raw.get('LABEL_0', 1.0)),
220
+ coherenceCohesion=r(raw.get('LABEL_1', 1.0)),
221
+ lexicalResource=r(raw.get('LABEL_2', 1.0)),
222
+ grammaticalRange=r(raw.get('LABEL_3', 1.0))
223
+ )
224
+ overall = r((scores.taskResponse + scores.coherenceCohesion +
225
+ scores.lexicalResource + scores.grammaticalRange) / 4)
226
+
227
+ # Feedback
228
+ fb = await generate_writing_feedback(request.prompt, request.essay, scores, request.task_type, image_desc)
229
+ sf = fb.get("shortFeedback", {})
230
+
231
+ tr_fb = sf.get("Task Response") or sf.get("Task Achievement") or "No feedback"
232
+
233
+ return WritingResponse(
234
+ overallScore=overall,
235
+ imageDescription=image_desc if request.task_type == 1 else None,
236
+ criteriaScores=scores,
237
+ shortFeedback=ShortFeedbackWriting(
238
+ taskResponse=tr_fb,
239
+ coherenceCohesion=sf.get("Coherence and Cohesion", ""),
240
+ lexicalResource=sf.get("Lexical Resource", ""),
241
+ grammaticalRange=sf.get("Grammatical Range and Accuracy", "")
242
+ ),
243
+ detailedFeedback=fb.get("detailedFeedback", "")
244
+ )
245
+
246
+ async def grade_speaking_with_gpt(transcript: str, metrics: dict, ipa_data: dict, prompt_text: str) -> dict:
247
+ """
248
+ Generate Speaking feedback with Pronunciation Breakdown array.
249
+ """
250
+ print("Generating Speaking feedback...")
251
+
252
+ system_prompt = f"""
253
+ You are an expert IELTS Speaking Examiner and Phonetician.
254
+
255
+ INPUT DATA:
256
+ - Question: "{prompt_text}"
257
+ - Transcript (Whisper): "{transcript}"
258
+ - Raw Audio IPA (Actual): /{ipa_data.get('actual_ipa', '')}/
259
+ - Expected IPA (Standard): /{ipa_data.get('expected_ipa', '')}/
260
+
261
+ METRICS:
262
+ - Speed: {metrics['wpm']:.1f} WPM
263
+ - Pauses: {metrics['pause_ratio']*100:.1f}%
264
+
265
+ YOUR TASK:
266
+ 1. Score the 4 criteria (0-9).
267
+ 2. **Pronunciation Breakdown**: Map words from Transcript to the IPA. Identify mispronounced words.
268
+ - Compare Actual vs Expected IPA for each word.
269
+ - Assign a score (1-10) for each word's pronunciation.
270
+ - Flag errors (e.g., 'severe_substitution' if user said 'trip' but meant 'subject').
271
+
272
+ OUTPUT JSON FORMAT (This is sample structure, replace with actual data):
273
+ {{
274
+ "scores": {{ "fluencyCoherence": 0.0, "lexicalResource": 0.0, "grammaticalRange": 0.0, "pronunciation": 0.0 }},
275
+ "shortFeedback": {{ "Fluency": "...", "Vocabulary": "...", "Grammar": "...", "Pronunciation": "..." }},
276
+ "detailedFeedback": "MARKDOWN string...",
277
+ "refinedTranscript": "Corrected version...",
278
+ "betterVersion": "Upgraded Band 8 version...",
279
+ "pronunciationBreakdown": [
280
+ {{
281
+ "word": "subject",
282
+ "score": 3,
283
+ "phonemes_expected": "s ʌ b dʒ ɛ k t",
284
+ "phonemes_actual": "t r ɪ p",
285
+ "is_correct": false,
286
+ "error_type": "severe_substitution"
287
+ }},
288
+ ... (more words)
289
+ ]
290
+ }}
291
+ """
292
+
293
+ payload = {
294
+ "model": "gpt-4o-mini",
295
+ "messages": [{"role": "system", "content": system_prompt}],
296
+ "response_format": {"type": "json_object"}
297
+ }
298
+
299
+ async with httpx.AsyncClient(timeout=60.0) as client:
300
+ resp = await client.post(OPENAI_API_URL, headers={"Authorization": f"Bearer {OPENAI_API_KEY}"}, json=payload)
301
+ return json.loads(resp.json()['choices'][0]['message']['content'])
302
+
303
+ @app.post("/grade/speaking", response_model=SpeakingResponse)
304
+ async def grade_speaking(audio: UploadFile = File(...), prompt: str = Form(...)):
305
+ temp_filename = f"temp_{uuid.uuid4()}.wav"
306
+ try:
307
+ with open(temp_filename, "wb") as buffer:
308
+ shutil.copyfileobj(audio.file, buffer)
309
+
310
+ # 1. Whisper & Acoustic Metrics
311
+ if not whisper_model: raise HTTPException(500, "Whisper missing")
312
+ res = whisper_model.transcribe(temp_filename)
313
+ transcript = res["text"].strip()
314
+
315
+ y, sr = librosa.load(temp_filename)
316
+ duration = librosa.get_duration(y=y, sr=sr)
317
+ word_count = len(transcript.split())
318
+ wpm = (word_count / duration) * 60 if duration > 0 else 0
319
+ non_silent = librosa.effects.split(y, top_db=20)
320
+ silent_time = duration - sum([(e-s)/sr for s,e in non_silent])
321
+ pause_ratio = silent_time / duration if duration > 0 else 0
322
+
323
+ metrics = {"wpm": wpm, "pause_ratio": pause_ratio}
324
+
325
+ # 2. IPA Analysis (Subprocess based)
326
+ ipa_data = grade_pronunciation_advanced(temp_filename, transcript)
327
+
328
+ # 3. GPT Analysis
329
+ gpt_result = await grade_speaking_with_gpt(transcript, metrics, ipa_data, prompt)
330
+ scores = gpt_result.get("scores", {})
331
+
332
+ # 4. Response
333
+ criteria = SpeakingScores(
334
+ fluencyCoherence=round_to_half(scores.get("fluencyCoherence", 0)),
335
+ lexicalResource=round_to_half(scores.get("lexicalResource", 0)),
336
+ grammaticalRange=round_to_half(scores.get("grammaticalRange", 0)),
337
+ pronunciation=round_to_half(scores.get("pronunciation", 0))
338
+ )
339
+ overall = round_to_half((criteria.fluencyCoherence + criteria.lexicalResource +
340
+ criteria.grammaticalRange + criteria.pronunciation) / 4)
341
+
342
+ return SpeakingResponse(
343
+ overallScore=overall,
344
+ transcript=transcript,
345
+ refinedTranscript=gpt_result.get("refinedTranscript", ""),
346
+ betterVersion=gpt_result.get("betterVersion", ""),
347
+ criteriaScores=criteria,
348
+ shortFeedback=gpt_result.get("shortFeedback", {}),
349
+ detailedFeedback=gpt_result.get("detailedFeedback", ""),
350
+ pronunciationBreakdown=gpt_result.get("pronunciationBreakdown", [])
351
+ )
352
+
353
+ except Exception as e:
354
+ print(f"Speaking Error: {e}")
355
+ import traceback
356
+ traceback.print_exc()
357
+ raise HTTPException(500, str(e))
358
+ finally:
359
+ if os.path.exists(temp_filename): os.remove(temp_filename)
360
+
361
+ @app.get("/")
362
+ def read_root():
363
+ return {"message": "IELTS API is running."}
364
+
365
+ if __name__ == "__main__":
366
+ uvicorn.run(app, host="0.0.0.0", port=8000)
src/clean_external_data.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import os
4
+ from datasets import load_dataset
5
+ from tqdm import tqdm
6
+
7
+ # Regex để bắt điểm (ví dụ: 7 hoặc 7.5 hoặc 6.0)
8
+ FLOAT_RE = r"(\d+(?:\.\d+)?)"
9
+
10
+ def to_float_safe(x):
11
+ """Chuyển đổi an toàn sang float, nếu lỗi trả về None"""
12
+ try:
13
+ val = float(x)
14
+ # Kiểm tra điểm hợp lệ (0-9)
15
+ if 0 <= val <= 9:
16
+ return val
17
+ return None
18
+ except Exception:
19
+ return None
20
+
21
+
22
+ def parse_chillies_dataset(dataset):
23
+ """
24
+ Parser cho 'chillies/IELTS-writing-task-2-evaluation'.
25
+ Format: **Task Achievement: [7]** hoặc **Overall Band Score: [7.5]**
26
+ """
27
+ print("Đang xử lý dataset 'chillies'...")
28
+ cleaned = []
29
+ bad_examples = 0
30
+
31
+ patterns = {
32
+ "task_response": re.compile(
33
+ r"\*\*Task Achievement:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
34
+ re.I
35
+ ),
36
+ "coherence_cohesion": re.compile(
37
+ r"\*\*Coherence and Cohesion:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
38
+ re.I
39
+ ),
40
+ "lexical_resource": re.compile(
41
+ r"\*\*Lexical Resource:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
42
+ re.I
43
+ ),
44
+ "grammatical_range": re.compile(
45
+ r"\*\*Grammatical Range and Accuracy:\s*\[?(" + FLOAT_RE + r")\]?\*\*",
46
+ re.I
47
+ ),
48
+ }
49
+
50
+ for item in tqdm(dataset, desc="Parsing chillies"):
51
+ try:
52
+ prompt = item.get('prompt', '').strip()
53
+ essay = item.get('essay', '').strip()
54
+ evaluation_text = item.get('evaluation', '')
55
+
56
+ if not (prompt and essay and evaluation_text and len(essay) > 50):
57
+ bad_examples += 1
58
+ continue
59
+
60
+ scores = {}
61
+ for key, pattern in patterns.items():
62
+ match = pattern.search(evaluation_text)
63
+ if match:
64
+ score_str = match.group(1)
65
+ scores[key] = to_float_safe(score_str)
66
+ else:
67
+ scores[key] = None
68
+
69
+ if all(scores.values()):
70
+ standard_scores = {
71
+ "task_response": scores["task_response"],
72
+ "coherence_cohesion": scores["coherence_cohesion"],
73
+ "lexical_resource": scores["lexical_resource"],
74
+ "grammatical_range": scores["grammatical_range"]
75
+ }
76
+ cleaned.append({
77
+ "prompt_text": prompt,
78
+ "essay_text": essay,
79
+ "scores": standard_scores
80
+ })
81
+ else:
82
+ bad_examples += 1
83
+ except Exception:
84
+ bad_examples += 1
85
+
86
+ print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
87
+ return cleaned
88
+
89
+
90
+ def parse_123harr_dataset(dataset):
91
+ """
92
+ Parser cho '123Harr/IELTS-WT2-LLaMa3-1k'.
93
+ Lấy scores từ 'formatted' field
94
+ """
95
+ print("Đang xử lý dataset '123Harr'...")
96
+ cleaned = []
97
+ bad_examples = 0
98
+
99
+ prompt_essay_re = re.compile(
100
+ r"<\|start_header_id\|>user<\|end_header_id\|>\n\n(.*?)<\|eot_id\|>",
101
+ re.S
102
+ )
103
+
104
+ score_patterns = {
105
+ "task_response": re.compile(
106
+ r"(?:###|##|\*\*)?Task Achievement(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
107
+ re.I | re.M
108
+ ),
109
+ "coherence_cohesion": re.compile(
110
+ r"(?:###|##|\*\*)?Coherence and Cohesion(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
111
+ re.I | re.M
112
+ ),
113
+ "lexical_resource": re.compile(
114
+ r"(?:###|##|\*\*)?Lexical Resource(?:\s*\(Vocabulary\))?(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
115
+ re.I | re.M
116
+ ),
117
+ "grammatical_range": re.compile(
118
+ r"(?:###|##|\*\*)?Grammatical Range and Accuracy(?:\*\*)?:[\s\S]*?(?:Suggested Band Score|Band Score)?[\s\S]*?" + FLOAT_RE + r"(?:\s|$)",
119
+ re.I | re.M
120
+ ),
121
+ }
122
+
123
+ for item in tqdm(dataset, desc="Parsing 123Harr"):
124
+ try:
125
+ formatted_text = item.get('formatted', '')
126
+
127
+ if not formatted_text:
128
+ bad_examples += 1
129
+ continue
130
+
131
+ matches = prompt_essay_re.findall(formatted_text)
132
+
133
+ if len(matches) < 2:
134
+ bad_examples += 1
135
+ continue
136
+
137
+ prompt = matches[0].strip()
138
+ essay = matches[1].strip()
139
+
140
+ if not prompt or not essay or len(essay) < 50:
141
+ bad_examples += 1
142
+ continue
143
+
144
+ scores = {}
145
+ for key, pattern in score_patterns.items():
146
+ match = pattern.search(formatted_text)
147
+ if match:
148
+ score_str = match.group(match.lastindex) if match.lastindex else match.group(1)
149
+ scores[key] = to_float_safe(score_str)
150
+ else:
151
+ scores[key] = None
152
+
153
+ if all(scores.values()):
154
+ standard_scores = {
155
+ "task_response": scores["task_response"],
156
+ "coherence_cohesion": scores["coherence_cohesion"],
157
+ "lexical_resource": scores["lexical_resource"],
158
+ "grammatical_range": scores["grammatical_range"]
159
+ }
160
+ cleaned.append({
161
+ "prompt_text": prompt,
162
+ "essay_text": essay,
163
+ "scores": standard_scores
164
+ })
165
+ else:
166
+ bad_examples += 1
167
+ except Exception:
168
+ bad_examples += 1
169
+
170
+ print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
171
+ return cleaned
172
+
173
+
174
+ def parse_dpo_dataset(dataset):
175
+ """
176
+ Parser cho 'chillies/DPO_ielts_writing'.
177
+ """
178
+ print("Đang xử lý dataset 'DPO'...")
179
+ cleaned = []
180
+ bad_examples = 0
181
+
182
+ patterns_primary = {
183
+ "task_response": re.compile(
184
+ r"##\s*Task Achievement:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
185
+ re.I
186
+ ),
187
+ "coherence_cohesion": re.compile(
188
+ r"##\s*Coherence and Cohesion:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
189
+ re.I
190
+ ),
191
+ "lexical_resource": re.compile(
192
+ r"##\s*Lexical Resource(?:\s*\(Vocabulary\))?:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
193
+ re.I
194
+ ),
195
+ "grammatical_range": re.compile(
196
+ r"##\s*Grammatical Range and Accuracy:[\s\S]*?Suggested Band Score:\s*" + FLOAT_RE,
197
+ re.I
198
+ ),
199
+ }
200
+
201
+ patterns_fallback = {
202
+ "task_response": re.compile(r"(?:\*\*)?Task Achievement(?:\*\*)?:\s*" + FLOAT_RE, re.I),
203
+ "coherence_cohesion": re.compile(r"(?:\*\*)?Coherence and Cohesion(?:\*\*)?:\s*" + FLOAT_RE, re.I),
204
+ "lexical_resource": re.compile(r"(?:\*\*)?Lexical Resource(?:\s*\(Vocabulary\))?(?:\*\*)?:\s*" + FLOAT_RE, re.I),
205
+ "grammatical_range": re.compile(r"(?:\*\*)?Grammatical Range and Accuracy(?:\*\*)?:\s*" + FLOAT_RE, re.I),
206
+ }
207
+
208
+ for item in tqdm(dataset, desc="Parsing DPO"):
209
+ try:
210
+ prompt = item.get('prompt', '').strip()
211
+ essay = item.get('essay', '').strip()
212
+ chosen_text = item.get('chosen', '')
213
+
214
+ if not (prompt and essay and chosen_text and len(essay) > 50):
215
+ bad_examples += 1
216
+ continue
217
+
218
+ scores = {}
219
+
220
+ for key, pattern in patterns_primary.items():
221
+ match = pattern.search(chosen_text)
222
+ if match:
223
+ scores[key] = to_float_safe(match.group(1))
224
+ else:
225
+ scores[key] = None
226
+
227
+ if not all(scores.values()):
228
+ scores = {}
229
+ for key, pattern in patterns_fallback.items():
230
+ match = pattern.search(chosen_text)
231
+ if match:
232
+ scores[key] = to_float_safe(match.group(1))
233
+ else:
234
+ scores[key] = None
235
+
236
+ if all(scores.values()):
237
+ standard_scores = {
238
+ "task_response": scores["task_response"],
239
+ "coherence_cohesion": scores["coherence_cohesion"],
240
+ "lexical_resource": scores["lexical_resource"],
241
+ "grammatical_range": scores["grammatical_range"]
242
+ }
243
+ cleaned.append({
244
+ "prompt_text": prompt,
245
+ "essay_text": essay,
246
+ "scores": standard_scores
247
+ })
248
+ else:
249
+ bad_examples += 1
250
+ except Exception:
251
+ bad_examples += 1
252
+
253
+ print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
254
+ return cleaned
255
+
256
+
257
+ def parse_hadeel_dataset(dataset):
258
+ """
259
+ Parser cho 'hadeelbkh/tokenized-IELTS-writing-task-2-evaluation'.
260
+ """
261
+ print("Đang xử lý dataset 'hadeel'...")
262
+ cleaned = []
263
+ bad_examples = 0
264
+
265
+ patterns = {
266
+ "task_response": re.compile(
267
+ r"(?:\*\*)?task achievement(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
268
+ re.I
269
+ ),
270
+ "coherence_cohesion": re.compile(
271
+ r"(?:\*\*)?coherence and cohesion(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
272
+ re.I
273
+ ),
274
+ "lexical_resource": re.compile(
275
+ r"(?:\*\*)?lexical resource(?:\s*\(vocabulary\))?(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
276
+ re.I
277
+ ),
278
+ "grammatical_range": re.compile(
279
+ r"(?:\*\*)?grammatical range and accuracy(?:\*\*)?:\s*-?\s*(" + FLOAT_RE + r")",
280
+ re.I
281
+ ),
282
+ }
283
+
284
+ for item in tqdm(dataset, desc="Parsing hadeel"):
285
+ try:
286
+ prompt = item.get('prompt', '').strip()
287
+ essay = item.get('essay', '').strip()
288
+ evaluation_text = item.get('evaluation', '')
289
+
290
+ if not (prompt and essay and evaluation_text and len(essay) > 50):
291
+ bad_examples += 1
292
+ continue
293
+
294
+ scores = {}
295
+ for key, pattern in patterns.items():
296
+ match = pattern.search(evaluation_text)
297
+ if match:
298
+ score_str = match.group(1)
299
+ scores[key] = to_float_safe(score_str)
300
+ else:
301
+ scores[key] = None
302
+
303
+ if all(scores.values()):
304
+ standard_scores = {
305
+ "task_response": scores["task_response"],
306
+ "coherence_cohesion": scores["coherence_cohesion"],
307
+ "lexical_resource": scores["lexical_resource"],
308
+ "grammatical_range": scores["grammatical_range"]
309
+ }
310
+ cleaned.append({
311
+ "prompt_text": prompt,
312
+ "essay_text": essay,
313
+ "scores": standard_scores
314
+ })
315
+ else:
316
+ bad_examples += 1
317
+ except Exception:
318
+ bad_examples += 1
319
+
320
+ print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
321
+ return cleaned
322
+
323
+
324
+ def parse_vietanh_dataset(dataset):
325
+ """
326
+ Parser cho 'vietanh0802/ielts_writing_training_data_prepared'.
327
+ Format: <s>[INST] ... ### Prompt: ... ### Essay: ... [/INST] ...
328
+ """
329
+ print("Đang xử lý dataset 'vietanh'...")
330
+ cleaned = []
331
+ bad_examples = 0
332
+
333
+ prompt_re = re.compile(r"### Prompt:\s*(.*?)(?=### Essay:|$)", re.S | re.I)
334
+ essay_re = re.compile(r"### Essay:\s*(.*?)(?=\[/INST\]|$)", re.S | re.I)
335
+
336
+ score_patterns = {
337
+ "task_response": re.compile(
338
+ r"(?:\*\*)?Task Achievement(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
339
+ re.I
340
+ ),
341
+ "coherence_cohesion": re.compile(
342
+ r"(?:\*\*)?Coherence and Cohesion(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
343
+ re.I
344
+ ),
345
+ "lexical_resource": re.compile(
346
+ r"(?:\*\*)?Lexical Resource(?:\s*\(Vocabulary\))?(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
347
+ re.I
348
+ ),
349
+ "grammatical_range": re.compile(
350
+ r"(?:\*\*)?Grammatical Range and Accuracy(?:\*\*)?:\s*\[?(" + FLOAT_RE + r")\]?",
351
+ re.I
352
+ ),
353
+ }
354
+
355
+ for item in tqdm(dataset, desc="Parsing vietanh"):
356
+ try:
357
+ training_text = item.get('training_text', '')
358
+
359
+ if not training_text:
360
+ bad_examples += 1
361
+ continue
362
+
363
+ prompt_match = prompt_re.search(training_text)
364
+ if not prompt_match:
365
+ bad_examples += 1
366
+ continue
367
+ prompt = prompt_match.group(1).strip()
368
+
369
+ essay_match = essay_re.search(training_text)
370
+ if not essay_match:
371
+ bad_examples += 1
372
+ continue
373
+ essay = essay_match.group(1).strip()
374
+
375
+ if not prompt or not essay or len(essay) < 50:
376
+ bad_examples += 1
377
+ continue
378
+
379
+ scores = {}
380
+ for key, pattern in score_patterns.items():
381
+ match = pattern.search(training_text)
382
+ if match:
383
+ scores[key] = to_float_safe(match.group(1))
384
+ else:
385
+ scores[key] = None
386
+
387
+ if all(scores.values()):
388
+ standard_scores = {
389
+ "task_response": scores["task_response"],
390
+ "coherence_cohesion": scores["coherence_cohesion"],
391
+ "lexical_resource": scores["lexical_resource"],
392
+ "grammatical_range": scores["grammatical_range"]
393
+ }
394
+ cleaned.append({
395
+ "prompt_text": prompt,
396
+ "essay_text": essay,
397
+ "scores": standard_scores
398
+ })
399
+ else:
400
+ bad_examples += 1
401
+ except Exception:
402
+ bad_examples += 1
403
+
404
+ print(f" ✓ kept {len(cleaned)} samples, skipped {bad_examples}")
405
+ return cleaned
406
+
407
+
408
+ def main():
409
+ print("Đang tải các dataset từ Hugging Face...\n")
410
+ cache_dir = "./.cache/huggingface_datasets"
411
+
412
+ all_data = []
413
+
414
+ # Dataset 1: chillies/IELTS-writing-task-2-evaluation
415
+ try:
416
+ ds_chillies = load_dataset(
417
+ "chillies/IELTS-writing-task-2-evaluation",
418
+ split="train",
419
+ cache_dir=cache_dir
420
+ )
421
+ all_data.append(("chillies", parse_chillies_dataset(ds_chillies)))
422
+ except Exception as e:
423
+ print(f"✗ Lỗi tải chillies: {e}\n")
424
+
425
+ # Dataset 2: 123Harr/IELTS-WT2-LLaMa3-1k
426
+ try:
427
+ ds_123harr = load_dataset(
428
+ "123Harr/IELTS-WT2-LLaMa3-1k",
429
+ split="train",
430
+ cache_dir=cache_dir
431
+ )
432
+ all_data.append(("123Harr", parse_123harr_dataset(ds_123harr)))
433
+ except Exception as e:
434
+ print(f"✗ Lỗi tải 123Harr: {e}\n")
435
+
436
+ # Dataset 3: chillies/DPO_ielts_writing
437
+ try:
438
+ ds_chillies_2 = load_dataset(
439
+ "chillies/DPO_ielts_writing",
440
+ split="train",
441
+ cache_dir=cache_dir
442
+ )
443
+ all_data.append(("DPO", parse_dpo_dataset(ds_chillies_2)))
444
+ except Exception as e:
445
+ print(f"✗ Lỗi tải DPO: {e}\n")
446
+
447
+ # Dataset 4: hadeelbkh/tokenized-IELTS-writing-task-2-evaluation
448
+ try:
449
+ ds_hadeel = load_dataset(
450
+ "hadeelbkh/tokenized-IELTS-writing-task-2-evaluation-DialoGPT-medium",
451
+ split="train",
452
+ cache_dir=cache_dir
453
+ )
454
+ all_data.append(("hadeel", parse_hadeel_dataset(ds_hadeel)))
455
+ except Exception as e:
456
+ print(f"✗ Lỗi tải hadeel: {e}\n")
457
+
458
+ # Dataset 5: vietanh0802/ielts_writing_training_data_prepared
459
+ try:
460
+ ds_vietanh = load_dataset(
461
+ "vietanh0802/ielts_writing_training_data_prepared",
462
+ split="train",
463
+ cache_dir=cache_dir
464
+ )
465
+ all_data.append(("vietanh", parse_vietanh_dataset(ds_vietanh)))
466
+ except Exception as e:
467
+ print(f"✗ Lỗi tải vietanh: {e}\n")
468
+
469
+ # Tính tổng
470
+ print("\n" + "="*60)
471
+ print("--- TỔNG HỢP ---")
472
+ print("="*60)
473
+ total = 0
474
+ for name, data in all_data:
475
+ count = len(data)
476
+ total += count
477
+ print(f"Dataset ({name:15}): {count:5d} mẫu")
478
+
479
+ print("="*60)
480
+ print(f"Tổng cộng mẫu hợp lệ: {total}")
481
+ print("="*60)
482
+
483
+ final_dataset = []
484
+ for name, data in all_data:
485
+ final_dataset.extend(data)
486
+
487
+ if not final_dataset:
488
+ print("✗ Lỗi: Không có dữ liệu nào được chuẩn hóa. Vui lòng kiểm tra lại script.")
489
+ return
490
+
491
+ output_dir = "data"
492
+ output_path = os.path.join(output_dir, "dataset_for_scorer.json")
493
+
494
+ if not os.path.exists(output_dir):
495
+ os.makedirs(output_dir)
496
+ print(f"✓ Đã tạo thư mục {output_dir}")
497
+
498
+ with open(output_path, "w", encoding="utf-8") as f:
499
+ json.dump(final_dataset, f, ensure_ascii=False, indent=2)
500
+
501
+ print(f"✓ Đã ghi {len(final_dataset)} mẫu vào file '{output_path}'.")
502
+ print("\n✓ Hoàn tất! Bây giờ bạn có thể chạy 'src/train.py' trên Colab!")
503
+
504
+
505
+ if __name__ == "__main__":
506
+ main()
src/clean_external_data_task1.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File: src/clean_external_data_task1.py
2
+ import json
3
+ import os
4
+ from datasets import load_dataset
5
+ from tqdm import tqdm
6
+
7
+ def to_float_safe(x):
8
+ """Chuyển đổi string sang float (xử lý cả '9' và '9.0')"""
9
+ try:
10
+ if x is None: return None
11
+ val = float(x)
12
+ if 0 <= val <= 9: return val
13
+ return None
14
+ except ValueError:
15
+ return None
16
+
17
+ def parse_hai2131_dataset(dataset):
18
+ """
19
+ Parser chuẩn cho 'hai2131/IELTS-essays-task-1'.
20
+ Kết hợp: Prompt + Image Description + Essay
21
+ """
22
+ print("Đang xử lý dataset 'hai2131'...")
23
+ cleaned = []
24
+ bad_examples = 0
25
+
26
+ for item in tqdm(dataset, desc="Parsing hai2131"):
27
+ try:
28
+ # 1. Lấy thông tin đầu vào
29
+ prompt = item.get("subject") or ""
30
+ # QUAN TRỌNG: Lấy mô tả ảnh
31
+ img_desc = item.get("image_description") or ""
32
+ essay = item.get("content") or ""
33
+
34
+ # 2. Tạo input text kết hợp (Prompt + Context + Essay)
35
+ # Model sẽ đọc toàn bộ chuỗi này
36
+ full_prompt_text = f"PROMPT: {prompt}\n\nIMAGE CONTEXT: {img_desc}"
37
+
38
+ # 3. Lấy điểm số (Dataset này để điểm dạng string "9")
39
+ scores = {
40
+ "task_response": to_float_safe(item.get("task_response_score")),
41
+ "coherence_cohesion": to_float_safe(item.get("coherence_cohesion_score")),
42
+ "lexical_resource": to_float_safe(item.get("lexical_resource_score")),
43
+ "grammatical_range": to_float_safe(item.get("grammatical_range_accuracy_score"))
44
+ }
45
+
46
+ # 4. Kiểm tra hợp lệ
47
+ if essay and all(scores.values()):
48
+ cleaned.append({
49
+ "prompt_text": full_prompt_text, # Input đặc biệt cho Task 1
50
+ "essay_text": essay,
51
+ "scores": scores
52
+ })
53
+ else:
54
+ bad_examples += 1
55
+
56
+ except Exception:
57
+ bad_examples += 1
58
+
59
+ print(f"hai2131: kept {len(cleaned)} samples, skipped {bad_examples}")
60
+ return cleaned
61
+
62
+ def main():
63
+ print("🚀 BẮT ĐẦU XỬ LÝ DATASET TASK 1 (hai2131)")
64
+
65
+ # Cache dir để tránh tải lại nhiều lần
66
+ cache_dir = "./.cache/huggingface_datasets_task1"
67
+
68
+ try:
69
+ # Tải dataset (lần này sẽ nhanh vì nó nhỏ gọn)
70
+ dataset = load_dataset("hai2131/IELTS-essays-task-1", split="train", cache_dir=cache_dir)
71
+
72
+ # Xử lý
73
+ final_dataset = parse_hai2131_dataset(dataset)
74
+
75
+ if not final_dataset:
76
+ print("LỖI: Không có dữ liệu.")
77
+ return
78
+
79
+ # Lưu file
80
+ output_dir = "data"
81
+ if not os.path.exists(output_dir): os.makedirs(output_dir)
82
+ output_path = os.path.join(output_dir, "dataset_for_scorer_task1.json")
83
+
84
+ with open(output_path, "w", encoding="utf-8") as f:
85
+ json.dump(final_dataset, f, ensure_ascii=False, indent=2)
86
+
87
+ print(f"\n✅ HOÀN TẤT! Đã lưu {len(final_dataset)} mẫu vào '{output_path}'.")
88
+ print("💡 Lưu ý: 'prompt_text' bây giờ chứa cả Đề bài VÀ Mô tả ảnh.")
89
+
90
+ except Exception as e:
91
+ print(f"❌ Lỗi: {e}")
92
+
93
+ if __name__ == "__main__":
94
+ main()
src/explore.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File: src/explore.py (SỬA LỖI - CHỈ CHẠY hai2131)
2
+ import json
3
+ import os
4
+ import sys
5
+ from datasets import load_dataset
6
+ from itertools import islice
7
+ import traceback
8
+
9
+ DATASET_LIST = {
10
+ # "trantac": "TraTacXiMuoi/Ielts_writing_task1_academic", # Tạm thời tắt do lỗi mạng
11
+ "hai2131": "hai2131/IELTS-essays-task-1"
12
+ }
13
+
14
+ NUM_SAMPLES_TO_VIEW = 2
15
+ SPLIT_NAME = "train"
16
+
17
+ def safe_value_to_string(value):
18
+ """Chuyển đổi value thành string an toàn"""
19
+ if value is None:
20
+ return None
21
+ if isinstance(value, (str, int, float, bool)):
22
+ return value
23
+ if isinstance(value, dict):
24
+ return value
25
+ if isinstance(value, list):
26
+ return value
27
+ # Đối với các object khác (ảnh, audio, etc)
28
+ return f"<{type(value).__name__}>"
29
+
30
+ def explore_dataset(name: str, path: str, split: str, n: int):
31
+ """
32
+ Tải N mẫu đầu tiên của một dataset từ Hugging Face và in cấu trúc của nó.
33
+ """
34
+ print("="*80)
35
+ print(f"🕵️ Đang khám phá dataset: {name}")
36
+ print(f" Path: {path}")
37
+ print(f" Split: {split}")
38
+ print("="*80)
39
+
40
+ try:
41
+ # Tải N mẫu đầu tiên (không dùng streaming nữa,
42
+ # vì dataset hai2131 chỉ 8MB, tải luôn cho nhanh)
43
+ dataset = load_dataset(path, split=f"{split}[:{n}]")
44
+
45
+ print(f"\n✅ Tải thành công. Cấu trúc (Features):")
46
+ # In ra các cột và kiểu dữ liệu
47
+ print(dataset.features)
48
+
49
+ print(f"\n--- Đang xem {n} mẫu đầu tiên ---")
50
+
51
+ for i, item in enumerate(dataset):
52
+ print(f"\n--- Mẫu {i+1} ---")
53
+
54
+ printable_item = {}
55
+ for key, value in item.items():
56
+ printable_item[key] = safe_value_to_string(value)
57
+
58
+ print(json.dumps(printable_item, ensure_ascii=False, indent=2))
59
+
60
+ except Exception as e:
61
+ print(f"\n❌ LỖI khi tải hoặc đọc dataset '{name}':")
62
+ print(f" {e}")
63
+ traceback.print_exc()
64
+
65
+ def list_available_splits(path):
66
+ # Hàm này không cần thiết nữa nếu chúng ta tải trực tiếp
67
+ pass
68
+
69
+ def main():
70
+ print("🚀 BẮT ĐẦU KHÁM PHÁ IELTS DATASETS (CHỈ hai2131)")
71
+ print("="*80)
72
+
73
+ for name, path in DATASET_LIST.items():
74
+ try:
75
+ explore_dataset(name, path, SPLIT_NAME, NUM_SAMPLES_TO_VIEW)
76
+ except KeyboardInterrupt:
77
+ print("\n⚠️ Bị gián đoạn bởi người dùng")
78
+ break
79
+ except Exception as e:
80
+ print(f"\n❌ Lỗi không mong muốn: {e}")
81
+ traceback.print_exc()
82
+
83
+ print("\n" + "-"*80)
84
+
85
+ print("\n" + "="*80)
86
+ print("✅ KHÁM PHÁ HOÀN TẤT")
87
+ print("="*80)
88
+
89
+ if __name__ == "__main__":
90
+ main()
src/explore_speaking.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import whisper
2
+ import librosa
3
+ import numpy as np
4
+ import os
5
+ import warnings
6
+
7
+ warnings.filterwarnings("ignore")
8
+
9
+ def analyze_speaking_audio(audio_path):
10
+ print(f"🎤 Đang phân tích file: {audio_path}")
11
+
12
+ # --- 1. Load Model Whisper (ASR) ---
13
+ print("⏳ Đang tải model Whisper (có thể lâu lần đầu)...")
14
+ model = whisper.load_model("base")
15
+
16
+ # --- 2. Transcribe (Chuyển giọng thành chữ) ---
17
+ print("📝 Đang chuyển đổi giọng nói...")
18
+ result = model.transcribe(audio_path, fp16=False)
19
+ transcript = result["text"].strip()
20
+
21
+ print("\n" + "="*40)
22
+ print("TRANSCRIPT:")
23
+ print(f"'{transcript}'")
24
+ print("="*40 + "\n")
25
+
26
+ # --- 3. Phân tích Fluency (Trôi chảy) ---
27
+ # Dùng librosa để phân tích tín hiệu âm thanh
28
+ y, sr = librosa.load(audio_path)
29
+ duration = librosa.get_duration(y=y, sr=sr)
30
+
31
+ # Đếm số từ
32
+ word_count = len(transcript.split())
33
+
34
+ # Tính tốc độ nói (Words Per Minute - WPM)
35
+ wpm = (word_count / duration) * 60
36
+
37
+ # Phát hiện khoảng lặng (Pauses)
38
+ # top_db: ngưỡng decibel để coi là im lặng
39
+ non_silent_intervals = librosa.effects.split(y, top_db=20)
40
+ silent_duration = duration - sum([ (end-start)/sr for start, end in non_silent_intervals ])
41
+ pause_ratio = silent_duration / duration
42
+
43
+ print("ACOUSTIC METRICS:")
44
+ print(f"- Thời lượng (Duration): {duration:.2f} giây")
45
+ print(f"- Số từ (Word Count): {word_count}")
46
+ print(f"- Tốc độ (Speed): {wpm:.2f} WPM (Chuẩn IELTS 6.0+ thường > 100 WPM)")
47
+ print(f"- Thời gian im lặng: {silent_duration:.2f} giây ({pause_ratio*100:.1f}%)")
48
+
49
+ # --- 4. Đánh giá sơ bộ ---
50
+ fluency_score_est = 0
51
+ if wpm > 120: fluency_score_est = 7.0
52
+ elif wpm > 100: fluency_score_est = 6.0
53
+ elif wpm > 80: fluency_score_est = 5.0
54
+ else: fluency_score_est = 4.0
55
+
56
+ print(f"\n💡 Đánh giá sơ bộ Fluency: ~{fluency_score_est}")
57
+
58
+ return {
59
+ "transcript": transcript,
60
+ "wpm": wpm,
61
+ "pause_ratio": pause_ratio
62
+ }
63
+
64
+ if __name__ == "__main__":
65
+ sample_audio = "data/test_speaking.m4a"
66
+
67
+ if os.path.exists(sample_audio):
68
+ analyze_speaking_audio(sample_audio)
69
+ else:
70
+ print(f"Không tìm thấy file '{sample_audio}'.")
71
+ print("Hãy tạo một file ghi âm tiếng Anh, lưu vào đó và chạy lại.")
src/pronunciation.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import librosa
3
+ import numpy as np
4
+ import os
5
+ import traceback
6
+ import subprocess
7
+ import shutil
8
+
9
+ from transformers import (
10
+ Wav2Vec2ForCTC,
11
+ AutoTokenizer,
12
+ Wav2Vec2FeatureExtractor
13
+ )
14
+
15
+ print("Loading Pronunciation module...")
16
+
17
+ MODEL_ID = "facebook/wav2vec2-lv-60-espeak-cv-ft"
18
+ model = None
19
+ tokenizer = None
20
+ feature_extractor = None
21
+
22
+ def find_espeak_exe():
23
+ candidates = [
24
+ r"C:\Program Files\eSpeak NG\espeak-ng.exe",
25
+ r"C:\Program Files (x86)\eSpeak NG\espeak-ng.exe",
26
+ r"D:\Program Files\eSpeak NG\espeak-ng.exe"
27
+ ]
28
+ path_in_env = shutil.which("espeak-ng")
29
+ if path_in_env: return path_in_env
30
+
31
+ for path in candidates:
32
+ if os.path.exists(path):
33
+ return path
34
+ return None
35
+
36
+ ESPEAK_PATH = find_espeak_exe()
37
+ if ESPEAK_PATH:
38
+ print(f"Found eSpeak at: {ESPEAK_PATH}")
39
+ else:
40
+ print("WARNING: eSpeak-ng not found. IPA generation will fail.")
41
+
42
+ try:
43
+ print("Loading Feature Extractor...")
44
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
45
+
46
+ print("Loading Tokenizer...")
47
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
48
+
49
+ print("Loading Acoustic Model...")
50
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
51
+
52
+ print("Pronunciation module ready.")
53
+ except Exception as e:
54
+ print(f"Failed to load AI model: {e}")
55
+
56
+ def get_expected_ipa(text):
57
+ """Gọi subprocess espeak-ng.exe để lấy IPA chuẩn từ văn bản."""
58
+ if not ESPEAK_PATH:
59
+ return "N/A"
60
+
61
+ try:
62
+ cmd = [ESPEAK_PATH, "-v", "en-us", "-q", "--ipa", text]
63
+
64
+ startupinfo = None
65
+ if os.name == 'nt':
66
+ startupinfo = subprocess.STARTUPINFO()
67
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
68
+
69
+ result = subprocess.run(
70
+ cmd,
71
+ capture_output=True,
72
+ text=True,
73
+ encoding='utf-8',
74
+ startupinfo=startupinfo
75
+ )
76
+
77
+ if result.returncode == 0:
78
+ return result.stdout.strip().replace('\n', ' ')
79
+ else:
80
+ return "N/A"
81
+
82
+ except Exception as e:
83
+ print(f"Subprocess error: {e}")
84
+ return "N/A"
85
+
86
+ def grade_pronunciation_advanced(audio_path, reference_text):
87
+ """
88
+ Trả về chuỗi IPA thực tế (Audio) và IPA chuẩn (Text).
89
+ """
90
+ actual_ipa = "N/A"
91
+ if model and tokenizer and feature_extractor:
92
+ try:
93
+ y, sr = librosa.load(audio_path, sr=16000)
94
+
95
+ input_values = feature_extractor(y, sampling_rate=16000, return_tensors="pt").input_values
96
+ with torch.no_grad():
97
+ logits = model(input_values).logits
98
+
99
+ predicted_ids = torch.argmax(logits, dim=-1)
100
+ actual_ipa = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
101
+
102
+ except Exception as e:
103
+ print(f"AI IPA Error: {e}")
104
+ actual_ipa = "Error"
105
+
106
+ expected_ipa = get_expected_ipa(reference_text)
107
+
108
+ return {
109
+ "actual_ipa": actual_ipa,
110
+ "expected_ipa": expected_ipa
111
+ }
src/train.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ from datasets import Dataset, DatasetDict
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification,
7
+ TrainingArguments,
8
+ Trainer,
9
+ EvalPrediction
10
+ )
11
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
12
+ from huggingface_hub import HfFolder, notebook_login
13
+
14
+ MODEL_NAME = "roberta-base"
15
+ DATASET_PATH = "/content/data/dataset_for_scorer.json"
16
+ MODEL_OUTPUT_DIR = "./ielts_grader_model"
17
+ HUB_MODEL_ID = "diminch/ielts-grader-ai"
18
+
19
+ def load_and_prepare_data(dataset_path):
20
+ print(f"Đang tải dữ liệu từ {dataset_path}...")
21
+ with open(dataset_path, "r", encoding="utf-8") as f:
22
+ raw_data = json.load(f)
23
+
24
+ processed_data = []
25
+ for item in raw_data:
26
+ text = item['prompt_text'] + " [SEP] " + item['essay_text']
27
+
28
+ labels = [
29
+ float(item['scores']['task_response']),
30
+ float(item['scores']['coherence_cohesion']),
31
+ float(item['scores']['lexical_resource']),
32
+ float(item['scores']['grammatical_range'])
33
+ ]
34
+
35
+ processed_data.append({"text": text, "label": labels})
36
+
37
+ print(f"Tổng cộng {len(processed_data)} mẫu.")
38
+
39
+ dataset = Dataset.from_list(processed_data)
40
+
41
+ train_test_split = dataset.train_test_split(test_size=0.1)
42
+
43
+ dataset_dict = DatasetDict({
44
+ 'train': train_test_split['train'],
45
+ 'test': train_test_split['test']
46
+ })
47
+
48
+ return dataset_dict
49
+
50
+ def tokenize_data(dataset_dict, tokenizer):
51
+ print("Đang tokenize dữ liệu...")
52
+ def tokenize_function(examples):
53
+ return tokenizer(
54
+ examples['text'],
55
+ padding="max_length",
56
+ truncation=True,
57
+ max_length=512
58
+ )
59
+
60
+ tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)
61
+ return tokenized_datasets
62
+
63
+ def compute_metrics(p: EvalPrediction):
64
+ preds = p.predictions
65
+ labels = p.label_ids
66
+
67
+ rmse_tr = np.sqrt(mean_squared_error(labels[:, 0], preds[:, 0]))
68
+ rmse_cc = np.sqrt(mean_squared_error(labels[:, 1], preds[:, 1]))
69
+ rmse_lr = np.sqrt(mean_squared_error(labels[:, 2], preds[:, 2]))
70
+ rmse_gra = np.sqrt(mean_squared_error(labels[:, 3], preds[:, 3]))
71
+
72
+ mae_tr = mean_absolute_error(labels[:, 0], preds[:, 0])
73
+ mae_cc = mean_absolute_error(labels[:, 1], preds[:, 1])
74
+ mae_lr = mean_absolute_error(labels[:, 2], preds[:, 2])
75
+ mae_gra = mean_absolute_error(labels[:, 3], preds[:, 3])
76
+
77
+ avg_rmse = np.mean([rmse_tr, rmse_cc, rmse_lr, rmse_gra])
78
+
79
+ return {
80
+ "avg_rmse": avg_rmse,
81
+ "rmse_task_response": rmse_tr,
82
+ "rmse_coherence_cohesion": rmse_cc,
83
+ "rmse_lexical_resource": rmse_lr,
84
+ "rmse_grammatical_range": rmse_gra,
85
+ "mae_task_response": mae_tr,
86
+ "mae_coherence_cohesion": mae_cc,
87
+ # ... có thể thêm các MAE khác
88
+ }
89
+
90
+ def main():
91
+ print("Vui lòng dán token Hugging Face (quyền 'write') của bạn:")
92
+ # (Nếu chạy trên Colab, nó sẽ hiện ô input)
93
+ # notebook_login()
94
+ # Hoặc nếu chạy local, dùng 'huggingface-cli login' trước
95
+
96
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
97
+
98
+ dataset_dict = load_and_prepare_data(DATASET_PATH)
99
+ tokenized_datasets = tokenize_data(dataset_dict, tokenizer)
100
+
101
+ print("Đang tải mô hình nền tảng...")
102
+ model = AutoModelForSequenceClassification.from_pretrained(
103
+ MODEL_NAME,
104
+ num_labels=4,
105
+ problem_type="regression"
106
+ )
107
+
108
+ training_args = TrainingArguments(
109
+ output_dir=MODEL_OUTPUT_DIR,
110
+ learning_rate=2e-5,
111
+ per_device_train_batch_size=8,
112
+ per_device_eval_batch_size=8,
113
+ num_train_epochs=3,
114
+ weight_decay=0.01,
115
+ eval_strategy="epoch", # Changed evaluation_strategy to eval_strategy
116
+ save_strategy="epoch",
117
+ load_best_model_at_end=True,
118
+ metric_for_best_model="avg_rmse",
119
+ greater_is_better=False,
120
+ push_to_hub=True,
121
+ hub_model_id=HUB_MODEL_ID,
122
+ hub_strategy="end",
123
+ )
124
+
125
+ trainer = Trainer(
126
+ model=model,
127
+ args=training_args,
128
+ train_dataset=tokenized_datasets["train"],
129
+ eval_dataset=tokenized_datasets["test"],
130
+ compute_metrics=compute_metrics,
131
+ tokenizer=tokenizer,
132
+ )
133
+
134
+ print("--- BẮT ĐẦU HUẤN LUYỆN ---")
135
+ trainer.train()
136
+ print("--- HUẤN LUYỆN HOÀN TẤT ---")
137
+
138
+ print("--- ĐÁNH GIÁ TRÊN TẬP TEST ---")
139
+ eval_results = trainer.evaluate()
140
+ print(json.dumps(eval_results, indent=2))
141
+
142
+ print("Đang đẩy model tốt nhất lên Hugging Face Hub...")
143
+ trainer.push_to_hub()
144
+ print(f"Hoàn tất! Model của bạn đã ở trên Hub: https://huggingface.co/{HUB_MODEL_ID}")
145
+
146
+ if __name__ == "__main__":
147
+ import os
148
+ if not os.path.exists(DATASET_PATH):
149
+ print(f"LỖI: Không tìm thấy file {DATASET_PATH}.")
150
+ else:
151
+ main()