ianalin123 committed on
Commit
db670b9
·
1 Parent(s): f32bf64

feat: add Modal eval script for GRPO checkpoints

Browse files

Evaluates base model and LoRA checkpoints on origami folding tasks via Modal cloud.

Files changed (1) hide show
  1. modal_eval.py +174 -0
modal_eval.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal eval script for origami GRPO checkpoints.
2
+
3
+ Run:
4
+ modal run modal_eval.py # latest checkpoint, all tasks
5
+ modal run modal_eval.py --checkpoint checkpoint-20 # specific checkpoint
6
+ modal run modal_eval.py --checkpoint base # base model (no LoRA)
7
+ modal run modal_eval.py --n-samples 20 --tasks quarter_fold,letter_fold
8
+ """
9
+
10
+ import os
11
+ import subprocess
12
+ import sys
13
+ import time
14
+ from pathlib import Path
15
+
16
+ import modal
17
+ from modal_train import OUTPUTS_DIR, app, image, volume
18
+
19
+ ALL_TASKS = ["triangle", "half_fold", "quarter_fold", "letter_fold"]
20
+
21
+
22
def _wait_for_server(server_url, server_proc=None, attempts=45, delay=1.0):
    """Block until ``{server_url}/health`` answers 200.

    Args:
        server_url: Base URL of the env server.
        server_proc: Optional ``subprocess.Popen`` handle of a locally spawned
            server; if it exits while we are waiting we fail immediately.
        attempts: Maximum number of health probes.
        delay: Seconds to sleep between probes.

    Raises:
        RuntimeError: The server process died or never reported healthy.
    """
    import requests as req

    for _ in range(attempts):
        if server_proc is not None and server_proc.poll() is not None:
            raise RuntimeError(
                f"Env server exited early with code {server_proc.returncode}"
            )
        try:
            if req.get(f"{server_url}/health", timeout=2).status_code == 200:
                return
        except req.RequestException:
            # Expected while the server is still starting up; keep polling.
            pass
        time.sleep(delay)
    raise RuntimeError(
        f"Env server at {server_url} did not become healthy after {attempts} attempts"
    )


def _resolve_checkpoint(checkpoint):
    """Translate the ``checkpoint`` option into an adapter path.

    Args:
        checkpoint: ``"base"`` for no adapter, a directory name relative to
            ``OUTPUTS_DIR``, or ``""`` to auto-select.

    Returns:
        The adapter directory as a string, or ``None`` for the base model.

    Raises:
        FileNotFoundError: An explicitly named checkpoint does not exist.
        ValueError: Auto-selection found no checkpoint in ``OUTPUTS_DIR``.
    """
    root = Path(OUTPUTS_DIR)

    if checkpoint == "base":
        print("Evaluating base model (no LoRA)")
        return None

    if checkpoint:
        explicit = root / checkpoint
        if not explicit.exists():
            raise FileNotFoundError(f"Checkpoint not found: {explicit}")
        print(f"Evaluating checkpoint: {checkpoint}")
        return str(explicit)

    # Auto-select: highest-numbered ``checkpoint-N`` wins. Entries whose suffix
    # is not an integer are skipped instead of crashing the sort.
    numbered = []
    for candidate in root.glob("checkpoint-*"):
        suffix = candidate.name.split("-")[-1]
        if suffix.isdigit():
            numbered.append((int(suffix), candidate))
    if numbered:
        latest = sorted(numbered, key=lambda pair: pair[0])[-1][1]
        print(f"Using latest checkpoint: {latest.name}")
        return str(latest)

    # glob order is filesystem-dependent; sort so the pick is deterministic.
    finals = sorted(root.glob("*-lora-final"))
    if finals:
        print(f"Using: {finals[-1].name}")
        return str(finals[-1])

    raise ValueError("No checkpoint found in volume. Pass --checkpoint base to eval base model.")


def _reward_stats(rewards):
    """Return ``(mean, population_std)`` of a non-empty list of rewards."""
    count = len(rewards)
    mean = sum(rewards) / count
    variance = sum((r - mean) ** 2 for r in rewards) / count
    return mean, variance ** 0.5


@app.function(
    image=image,
    gpu="B200",
    timeout=3600,
    volumes={OUTPUTS_DIR: volume},
)
def evaluate(
    checkpoint: str = "",
    n_samples: int = 10,
    server_url: str = "",
    tasks: str = "all",
    model_name: str = "unsloth/Qwen3-32B",
):
    """Sample completions for each task and score them against the env server.

    Args:
        checkpoint: ``"base"`` (no LoRA), a directory name under
            ``OUTPUTS_DIR``, or ``""`` to pick the latest checkpoint.
        n_samples: Number of generations per task; must be at least 1.
        server_url: Base URL of an already running env server. When empty, a
            uvicorn server is started locally on port 8000 and stopped on exit.
        tasks: ``"all"`` or a comma-separated list of task names.
        model_name: Base model identifier passed to unsloth.

    Returns:
        Dict mapping task name to ``{"mean", "std", "valid_pct"}``.

    Raises:
        ValueError: ``n_samples`` < 1, no task names given, or no checkpoint found.
        FileNotFoundError: The named checkpoint directory does not exist.
        RuntimeError: The locally started env server never became healthy.
    """
    import torch
    import requests as req
    from training.train_grpo import build_prompt
    from training.reward import extract_fold_json
    from origami_server.models import OrigamiAction
    from client import OrigamiEnv

    # Validate cheap inputs before spending time on server or model start-up.
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if tasks == "all":
        task_list = list(ALL_TASKS)
    else:
        # Drop blanks so a trailing comma does not become an empty task name.
        task_list = [t.strip() for t in tasks.split(",") if t.strip()]
    if not task_list:
        raise ValueError("No task names given. Pass --tasks all or a comma-separated list.")

    server_proc = None
    try:
        # ── Env server ────────────────────────────────────────────────────────
        # Spawned inside the try so the finally clause reaps it even when the
        # health wait fails.
        if not server_url:
            server_url = "http://localhost:8000"
            server_proc = subprocess.Popen(
                [sys.executable, "-m", "uvicorn", "origami_server.app:app",
                 "--host", "0.0.0.0", "--port", "8000"],
                cwd="/app",
            )
            _wait_for_server(server_url, server_proc)

        # ── Resolve checkpoint path ───────────────────────────────────────────
        ckpt_path = _resolve_checkpoint(checkpoint)

        # ── Load model ────────────────────────────────────────────────────────
        from unsloth import FastLanguageModel
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            load_in_4bit=False,
            max_seq_length=1024,
        )
        if ckpt_path:
            model.load_adapter(ckpt_path)
        FastLanguageModel.for_inference(model)

        # ── Evaluate each task ────────────────────────────────────────────────
        results = {}

        for task_name in task_list:
            task_resp = req.get(f"{server_url}/tasks/{task_name}", timeout=30)
            # Fail loudly on an unknown task instead of prompting with an error payload.
            task_resp.raise_for_status()
            prompt_text = build_prompt(task_resp.json())
            messages = [
                # NOTE(review): "/no_think" presumably disables the model's
                # thinking mode — confirm against the training prompt.
                {"role": "system", "content": "/no_think"},
                {"role": "user", "content": prompt_text},
            ]
            input_ids = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to("cuda")
            attention_mask = torch.ones_like(input_ids)
            prompt_len = input_ids.shape[1]

            rewards, valid = [], 0
            for sample_no in range(1, n_samples + 1):
                with torch.no_grad():
                    out = model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=512,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=tokenizer.eos_token_id,
                    )
                # Decode only the newly generated tokens, not the prompt.
                response = tokenizer.decode(
                    out[0][prompt_len:], skip_special_tokens=True
                )
                fold_data = extract_fold_json(response)
                if fold_data is None:
                    print(f" [{task_name}] sample {sample_no}: invalid JSON")
                    rewards.append(0.0)
                    continue
                valid += 1
                try:
                    with OrigamiEnv(base_url=server_url) as env:
                        env.reset(task_name=task_name)
                        result = env.step(OrigamiAction(fold_data=fold_data))
                        r = result.reward if result.reward is not None else 0.0
                        rewards.append(r)
                        print(f" [{task_name}] sample {sample_no}: reward={r:.2f}")
                except Exception as e:
                    # Best effort: one failed env call is scored -1.0 rather
                    # than aborting the whole evaluation run.
                    print(f" [{task_name}] sample {sample_no}: env error — {e}")
                    rewards.append(-1.0)

            mean_r, std_r = _reward_stats(rewards)
            results[task_name] = {"mean": mean_r, "std": std_r, "valid_pct": valid / n_samples * 100}
            print(f" {task_name:15s} reward={mean_r:.2f}±{std_r:.2f} valid={valid}/{n_samples}")

        print("\n=== SUMMARY ===")
        for name, r in results.items():
            # 20-cell bar scaled against 21, which looks like the maximum
            # attainable reward — TODO confirm. Negative means give an empty bar.
            bar = "█" * int(r["mean"] / 21 * 20)
            print(f" {name:15s} {r['mean']:5.2f}/21 {bar}")

        return results

    finally:
        if server_proc is not None:
            server_proc.terminate()
            # Reap the child so it does not linger; escalate if it ignores SIGTERM.
            try:
                server_proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server_proc.kill()
158
+
159
+
160
@app.local_entrypoint()
def eval_main(
    checkpoint: str = "",
    n_samples: int = 10,
    server_url: str = "",
    tasks: str = "all",
    model: str = "unsloth/Qwen3-32B",
):
    """Local entrypoint: forward the given options to the remote ``evaluate`` call."""
    # The local option is named ``model``; ``evaluate`` receives it as ``model_name``.
    remote_kwargs = {
        "checkpoint": checkpoint,
        "n_samples": n_samples,
        "server_url": server_url,
        "tasks": tasks,
        "model_name": model,
    }
    evaluate.remote(**remote_kwargs)