| import gradio as gr | |
| from comet import download_model, load_from_checkpoint | |
| import os | |
# Resolve which COMET checkpoint to load.
# - If HF_MODEL_PATH is set and non-empty, use it directly.
# - Otherwise fetch "wasanx/ComeTH" and use the checkpoint path that
#   download_model returns.
# NOTE: the previous form, os.environ.get(VAR, download_model(...)), evaluated the
# default argument eagerly, so the download ran even when the variable was set.
model_path = os.environ.get("HF_MODEL_PATH") or download_model("wasanx/ComeTH")
model = load_from_checkpoint(model_path)
def score_translation(src_text, mt_text):
    """Score a single source/candidate translation pair with the loaded COMET model.

    Args:
        src_text: Source sentence (the UI labels it English).
        mt_text: Candidate translation (the UI labels it Thai).

    Returns:
        The first entry of the prediction's "scores" (a float for COMET models).
        Returns None without invoking the model when either field is empty or
        whitespace-only, which clears the output component.
    """
    import torch  # local import: only used to choose the inference device

    # Skip the model entirely for blank input; Gradio may also pass None.
    if not (src_text and src_text.strip()) or not (mt_text and mt_text.strip()):
        return None

    translations = [{"src": src_text, "mt": mt_text}]
    # A hard-coded gpus=1 fails on CPU-only hosts, so fall back to CPU (0)
    # when no CUDA device is available.
    gpus = 1 if torch.cuda.is_available() else 0
    results = model.predict(translations, batch_size=1, gpus=gpus)
    return results["scores"][0]
# English sources shared by both example sets. Index i of good_examples and
# index i of bad_examples score the same sentence, so the two can be compared
# side by side.
_EXAMPLE_SOURCES = (
    "The weather is beautiful today.",
    "I need to go to the hospital.",
    "This restaurant serves delicious food.",
    "Can you help me find the nearest train station?",
)

# Thai candidates shown in the UI under "Good Translation Examples".
good_examples = [
    [source, candidate]
    for source, candidate in zip(
        _EXAMPLE_SOURCES,
        (
            "วันนี้อากาศดีมาก",
            "ฉันต้องไปโรงพยาบาล",
            "ร้านอาหารนี้เสิร์ฟอาหารอร่อย",
            "คุณช่วยฉันหาสถานีรถไฟที่ใกล้ที่สุดได้ไหม",
        ),
    )
]

# Thai candidates shown in the UI under "Bad Translation Examples": each one's
# meaning departs from its English source.
bad_examples = [
    [source, candidate]
    for source, candidate in zip(
        _EXAMPLE_SOURCES,
        (
            "วันนี้อากาศแย่สุดๆ",
            "ฉันอยากกินข้าว",
            "ร้านนี้ไม่อร่อยเลย",
            "คุณพูดภาษาอังกฤษได้ไหม?",
        ),
    )
]
# Global stylesheet handed to gr.Blocks(css=...).
# It loads JetBrains Mono (weights 400 and 700) from Google Fonts and forces it,
# with !important, onto every element via the universal selector.
# NOTE(review): JetBrains Mono likely has no Thai glyphs, so Thai text will fall
# back to another installed font. Confirm this rendering is acceptable.
font_css = """
@import url("https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&display=swap");
* {
font-family: 'JetBrains Mono', monospace !important;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=font_css) as demo:
    gr.Markdown("# ComeTH Translation Quality Estimator")

    # Side-by-side layout: the text inputs and trigger button on the left,
    # the score display and its explanation on the right.
    with gr.Row():
        with gr.Column(scale=1):
            src_input = gr.Textbox(
                label="Source Text (English)",
                placeholder="Enter English text here...",
            )
            mt_input = gr.Textbox(
                label="Candidate Translation (Thai)",
                placeholder="Enter Thai translation here...",
            )
            score_button = gr.Button("Evaluate Translation", variant="primary")
        with gr.Column(scale=1):
            score_output = gr.Label(label="Quality Scores")
            gr.Markdown(
                "### Higher scores indicate better translation quality across multiple dimensions"
            )

    # Two clickable example galleries (good first, then bad). Both fill the same
    # input boxes and route through score_translation into the same output.
    for _section_title, _section_rows in (
        ("## Good Translation Examples", good_examples),
        ("## Bad Translation Examples", bad_examples),
    ):
        gr.Markdown(_section_title)
        gr.Examples(
            examples=_section_rows,
            inputs=[src_input, mt_input],
            outputs=score_output,
            fn=score_translation,
        )

    # Manual evaluation path for whatever the user typed.
    score_button.click(
        fn=score_translation,
        inputs=[src_input, mt_input],
        outputs=score_output,
    )
# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()