Update app.py
app.py CHANGED
@@ -6,43 +6,74 @@ import torch
 
 from Model import TRCaptionNet, clip_transform
 
-model_ckpt = "./checkpoints/TRCaptionNet_L14_berturk_tasviret.pth"
 
-
-device = 
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# device = "cpu"
+
+preprocess_tasviret = clip_transform(336)
+model_tasviret = TRCaptionNet({
+    "max_length": 35,
+    "clip": "ViT-L/14@336px",
+    "bert": "dbmdz/bert-base-turkish-cased",
+    "proj": True,
+    "proj_num_head": 16
+})
+model_ckpt = "./checkpoints/TRCaptionNet-TasvirEt_L14_334_berturk.pth"
+model_tasviret.load_state_dict(torch.load(model_ckpt, map_location=device)["model"], strict=True)
+model_tasviret = model_tasviret.to(device)
+model_tasviret.eval()
 
 preprocess = clip_transform(224)
 model = TRCaptionNet({
     "max_length": 35,
     "clip": "ViT-L/14",
-    "bert": "bert
+    "bert": "dbmdz/bert-base-turkish-cased",
     "proj": True,
     "proj_num_head": 16
 })
+model_ckpt = "./checkpoints/TRCaptionNet_L14_berturk.pth"
 model.load_state_dict(torch.load(model_ckpt, map_location=device)["model"], strict=True)
 model = model.to(device)
 model.eval()
 
 
+
 def inference(raw_image, min_length, repetition_penalty):
+    batch = preprocess_tasviret(raw_image).unsqueeze(0).to(device)
+    caption_tasviret = model_tasviret.generate(batch, min_length=min_length, repetition_penalty=repetition_penalty)[0]
+
     batch = preprocess(raw_image).unsqueeze(0).to(device)
     caption = model.generate(batch, min_length=min_length, repetition_penalty=repetition_penalty)[0]
-    return caption
+
+    return [caption, caption_tasviret]
 
 
 inputs = [gr.Image(type='pil', interactive=True,),
-          gr.Slider(minimum=
+          gr.Slider(minimum=4, maximum=22, value=8, label="MINIMUM CAPTION LENGTH", step=1),
           gr.Slider(minimum=1, maximum=2, value=1.6, label="REPETITION PENALTY")]
-
-
+
+outputs = [gr.components.Textbox(label="Caption"), gr.components.Textbox(label="Caption-TasvirEt")]
+title = "TRCaptionNet-TasvirEt"
 paper_link = ""
 github_link = "https://github.com/serdaryildiz/TRCaptionNet"
-
+IEEE_link = "https://github.com/serdaryildiz/TRCaptionNet"
+
+description = f"<p style='text-align: center'><a href='{IEEE_link}' target='_blank'> SIU2024: Turkish Image Captioning with Vision Transformer Based Encoders and Text Decoders</a> "
+description += f"<p style='text-align: center'><a href='{github_link}' target='_blank'>TRCaptionNet</a> : A novel and accurate deep Turkish image captioning model with vision transformer based image encoders and deep linguistic text decoders"
+
 examples = [
     ["images/test1.jpg"],
     ["images/test2.jpg"],
     ["images/test3.jpg"],
-    ["images/test4.jpg"]
+    ["images/test4.jpg"],
+    ["images/test5.jpg"],
+    ["images/test6.jpg"],
+    ["images/test7.jpg"],
+    ["images/test8.jpg"],
+    ["images/test9.jpg"],
+    ["images/test10.jpg"],
+    ["images/test11.jpg"],
 ]
 article = f"<p style='text-align: center'><a href='{paper_link}' target='_blank'>Paper</a> | <a href='{github_link}' target='_blank'>Github Repo</a></p>"
 css = ".output-image, .input-image, .image-preview {height: 600px !important}"
@@ -56,4 +87,3 @@ iface = gr.Interface(fn=inference,
                      article=article,
                      css=css)
 iface.launch()
-
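Note that the second hunk elides the middle of the `gr.Interface(...)` call (new lines 80-86), so only `fn=inference` (in the hunk header) and the trailing `article`/`css` arguments are visible. As a sketch only, the variables defined earlier suggest a wiring along these lines; every keyword argument besides `fn`, `article`, and `css` is an inference, not part of the recorded diff:

# Sketch, not the committed code: new lines 80-86 are hidden by the diff,
# so inputs/outputs/title/description/examples below are assumptions
# based on the variables defined earlier in app.py.
iface = gr.Interface(fn=inference,
                     inputs=inputs,
                     outputs=outputs,
                     title=title,
                     description=description,
                     examples=examples,
                     article=article,
                     css=css)
iface.launch()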
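Outside Gradio, the updated two-model `inference` function can be exercised directly; it returns the base caption first and the TasvirEt caption second. A minimal smoke test, assuming both checkpoints loaded and the example images exist locally (the slider defaults 8 and 1.6 are taken from the diff above):

# Assumes app.py's models are already loaded and images/test1.jpg exists.
from PIL import Image

img = Image.open("images/test1.jpg").convert("RGB")
caption, caption_tasviret = inference(img, min_length=8, repetition_penalty=1.6)
print("ViT-L/14 (224px) caption:", caption)
print("ViT-L/14@336px (TasvirEt) caption:", caption_tasviret)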