Disable activation viz for RN model, to save memory
CLIP_Explainability/app.py  +17 -6
@@ -27,7 +27,7 @@ from CLIP_Explainability.vit_cam import (
 
 from pytorch_grad_cam.grad_cam import GradCAM
 
-RUN_LITE =
+RUN_LITE = True  # Load models for CAM viz for M-CLIP and J-CLIP only
 
 MAX_IMG_WIDTH = 500
 MAX_IMG_HEIGHT = 800
@@ -110,6 +110,10 @@ def clip_search(search_query):
 
 
 def string_search():
+    st.session_state.disable_uploader = (
+        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+    )
+
     if "search_field_value" in st.session_state:
         clip_search(st.session_state.search_field_value)
 
@@ -179,10 +183,9 @@ def init():
     ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
     ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
 
-
-
-
-    )
+    st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
+        ja_model_path, device=device, jit=False
+    )
 
     st.session_state.ja_model = AutoModel.from_pretrained(
         ja_model_name, trust_remote_code=True
@@ -216,6 +219,9 @@ def init():
     st.session_state.search_image_ids = []
     st.session_state.search_image_scores = {}
     st.session_state.text_table_df = None
+    st.session_state.disable_uploader = (
+        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+    )
 
     with st.spinner("Loading models and data, please wait..."):
         load_image_features()
@@ -745,6 +751,7 @@ with controls[3]:
         key="uploaded_image",
         label_visibility="collapsed",
         on_change=vis_uploaded_image,
+        disabled=st.session_state.disable_uploader,
     )
 
 
@@ -779,7 +786,9 @@ for image_id in batch:
             <div>""",
             unsafe_allow_html=True,
         )
-        if not
+        if not (
+            RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+        ):
             st.button(
                 "Explain this",
                 on_click=vis_known_image,
@@ -787,4 +796,6 @@ for image_id in batch:
                 use_container_width=True,
                 key=image_id,
             )
+        else:
+            st.empty()
         col = (col + 1) % row_size
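
Taken together, the diff applies one gating pattern in three places: a `disable_uploader` flag computed in both `string_search()` and `init()` (so it is correct on first load and after every search), the image uploader's `disabled=` argument, and a conditional around the per-result "Explain this" button. Below is a minimal, hypothetical Streamlit sketch of that pattern, not the app itself; the widgets, labels, and model list other than the legacy-model string are illustrative stand-ins.

# Minimal sketch of the gating pattern above (hypothetical stand-in for
# app.py; everything except the RUN_LITE flag and the legacy-model label
# is illustrative).
import streamlit as st

RUN_LITE = True  # lite mode: CAM viz models loaded for M-CLIP/J-CLIP only
LEGACY_MODEL = "Legacy (multilingual ResNet)"

# Illustrative model choices; only the legacy label matters for the gate.
st.session_state.active_model = st.selectbox(
    "Model", ["M-CLIP (ViT)", "J-CLIP (ViT)", LEGACY_MODEL]
)

# Recomputed on each rerun so the flag tracks the model choice: uploads
# are disabled when the legacy ResNet is active in lite mode.
st.session_state.disable_uploader = (
    RUN_LITE and st.session_state.active_model == LEGACY_MODEL
)

st.file_uploader(
    "Image to explain",
    type=["png", "jpg", "jpeg"],
    disabled=st.session_state.disable_uploader,
)

# The same condition gates the per-result button; st.empty() keeps the
# grid cell occupied so the result layout does not shift.
if not (RUN_LITE and st.session_state.active_model == LEGACY_MODEL):
    st.button("Explain this")
else:
    st.empty()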