Commit: Upload app.py — Browse files
File: app.py (CHANGED)
@@ -27,6 +27,8 @@ from CLIP_Explainability.vit_cam import (
 27 
 28 from pytorch_grad_cam.grad_cam import GradCAM
 29 
 30 MAX_IMG_WIDTH = 500
 31 MAX_IMG_HEIGHT = 800
 32 
|
|
@@ -172,9 +174,10 @@ def init():
 172     ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
 173     ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
 174 
 175 -   (line removed in this revision; content collapsed in the diff view)
 176 -   (line removed)
 177 -   (line removed)
 178 
 179     st.session_state.ja_model = AutoModel.from_pretrained(
 180         ja_model_name, trust_remote_code=True
|
|
@@ -183,9 +186,10 @@ def init():
 183         ja_model_name, trust_remote_code=True
 184     )
 185 
 186 -   (line removed in this revision; content collapsed in the diff view)
 187 -   (line removed)
 188 -   (line removed)
 189 
 190     st.session_state.rn_model = legacy_multilingual_clip.load_model(
 191         "M-BERT-Base-69"
|
|
@@ -701,11 +705,12 @@ for image_id in batch:
 701             <div>""",
 702             unsafe_allow_html=True,
 703         )
 704 -       st.  (line removed in this revision; content collapsed in the diff view)
 705 - (line removed)
 706 - (line removed)
 707 - (line removed)
 708 - (line removed)
 709 - (line removed)
 710 - (line removed)
 711         col = (col + 1) % row_size
|
|
|
|
New revision (hunk 1):
 27 
 28 from pytorch_grad_cam.grad_cam import GradCAM
 29 
 30 + RUN_LITE = True  # Load vision model for CAM viz explainability for M-CLIP only
 31 +
 32 MAX_IMG_WIDTH = 500
 33 MAX_IMG_HEIGHT = 800
 34 
|
|
|
|
New revision (hunks 2 and 3, inside def init()):
 174     ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
 175     ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
 176 
 177 +   if not RUN_LITE:
 178 +       st.session_state.ja_image_model, st.session_state.ja_image_preprocess = (
 179 +           load(ja_model_path, device=device, jit=False)
 180 +       )
 181 
 182     st.session_state.ja_model = AutoModel.from_pretrained(
 183         ja_model_name, trust_remote_code=True

 186         ja_model_name, trust_remote_code=True
 187     )
 188 
 189 +   if not RUN_LITE:
 190 +       st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
 191 +           clip.load("RN50x4", device=device)
 192 +       )
 193 
 194     st.session_state.rn_model = legacy_multilingual_clip.load_model(
 195         "M-BERT-Base-69"
|
|
|
|
New revision (hunk 4, inside the per-image loop):
 705             <div>""",
 706             unsafe_allow_html=True,
 707         )
 708 +       if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
 709 +           st.button(
 710 +               "Explain this",
 711 +               on_click=image_modal,
 712 +               args=[image_id],
 713 +               use_container_width=True,
 714 +               key=image_id,
 715 +           )
 716         col = (col + 1) % row_size
|