import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
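
# Load the distilled checkpoint in bfloat16; device_map="auto" lets
# Accelerate place the weights on the available GPU(s) or the CPU.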
model_path = "/workspace/distillation/gpt-oss-distilled/results/checkpoint-171"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
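
# A single-turn conversation to prompt the model with.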
messages = [
    {"role": "user", "content": "Explain quantum computing in simple terms"}
]
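
# Use the tokenizer's chat template when one is defined; otherwise fall back
# to a plain "User:/Assistant:" prompt. Checking chat_template rather than
# hasattr matters: recent tokenizers always define apply_chat_template, and
# calling it without a template raises a ValueError.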
if getattr(tokenizer, "chat_template", None) is not None:
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
    text = "User: Explain quantum computing in simple terms\nAssistant:"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
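
# Sample up to 150 new tokens with nucleus sampling (temperature=0.7,
# top_p=0.9) and a mild repetition penalty; pad_token_id is set to the EOS
# token to silence the missing-pad-token warning.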
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,
    )
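
# Decode only the newly generated tokens, slicing off the prompt.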
response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
print("Response:", response)