app.py
CHANGED
@@ -54,9 +54,20 @@ st.sidebar.markdown("""
 # -------------------- CACHING FUNCTIONS --------------------
 @st.cache_resource
 def load_mask_filling_model():
-
-
-
+    try:
+        tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
+        model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
+
+        # Create pipeline and verify mask token
+        pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
+
+        # Debug: print mask token for verification
+        print(f"Mask token being used: {tokenizer.mask_token}")
+
+        return pipe
+    except Exception as e:
+        st.error(f"Failed to load mask filling model: {str(e)}")
+        return None
 
 @st.cache_resource
 def load_pos_model():
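A quick way to sanity-check the new loader outside Streamlit is a minimal sketch along these lines (the model id dsfsi/PuoBERTa and the top_k=5 setting come from the hunk above; the None check matters because the loader now returns None instead of raising):

from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

def load_mask_filling_model():
    # Mirrors the cached loader above, minus the Streamlit error reporting.
    try:
        tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
        model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
        return pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
    except Exception as e:
        print(f"Failed to load mask filling model: {e}")
        return None

filler = load_mask_filling_model()
if filler is not None:
    # The examples later in this diff use <mask>, which matches a RoBERTa-style tokenizer.
    print("Mask token:", filler.tokenizer.mask_token)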
@@ -77,6 +88,23 @@ def load_news_classification_model():
     return pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
 
 # -------------------- UTILITY FUNCTIONS --------------------
+
+def get_correct_mask_token(text, tokenizer):
+    """Get the correct mask token format for the given tokenizer"""
+    mask_token = tokenizer.mask_token
+
+    # Replace common mask token formats with the correct one
+    text = text.replace("[MASK]", mask_token)
+    text = text.replace("<mask>", mask_token)
+    text = text.replace("<mask>", mask_token)
+
+    return text
+
+# Then in your mask filling section, use:
+# corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
+# results = mask_filler(corrected_input)
+
+
 def merge_entities(output):
     """Merge consecutive entities of the same type"""
     merged = []
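The helper normalizes whatever mask spelling the user typed to the tokenizer's own mask token (the second <mask> replacement appears to duplicate the first and is a no-op). Assuming the filler pipeline from the sketch above, usage looks roughly like this; for a single mask, the fill-mask pipeline returns a list of dicts with score, token, token_str and sequence keys:

text = "Ke rata go [MASK] dijo tsa Batswana."  # user typed a BERT-style mask
corrected = get_correct_mask_token(text, filler.tokenizer)
# With a RoBERTa-style tokenizer this becomes "Ke rata go <mask> dijo tsa Batswana."
for r in filler(corrected):
    print(f"{r['token_str']:>15}  {r['score']:.3f}")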
@@ -166,25 +194,31 @@ tab1, tab2, tab3, tab4 = st.tabs(["🎭 Mask Filling", "🏷️ POS Tagging", "
 # -------------------- MASK FILLING TAB --------------------
 with tab1:
     st.header("Mask Filling")
-    st.write("Fill in the blanks in Setswana sentences using
+    st.write("Fill in the blanks in Setswana sentences using `<mask>` token.")
 
     mask_examples = [
-        "Ke rata go
-        "Botswana ke naga e e
-        "Bana ba
-        "Re tshwanetse go
+        "Ke rata go <mask> dijo tsa Batswana.",
+        "Botswana ke naga e e <mask> mo Afrika Borwa.",
+        "Bana ba <mask> sekolo ka Mosupologo.",
+        "Re tshwanetse go <mask> tikologo ya rona."
     ]
 
     mask_input = get_input_text("mask", mask_examples)
 
     if st.button("Fill Masks", key="mask_button") and mask_input.strip():
-
-
+        # Check for both mask formats and convert if needed
+        if "[MASK]" in mask_input:
+            mask_input = mask_input.replace("[MASK]", "<mask>")
+            st.info("Converted [MASK] to <mask> format")
+        elif "<mask>" not in mask_input:
+            st.warning("Please include <mask> token in your text.")
         else:
             with st.spinner("Filling masks..."):
                 try:
                     mask_filler = load_mask_filling_model()
-
+                    corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
+                    results = mask_filler(corrected_input)
+                    # results = mask_filler(mask_input)
 
                     st.subheader("Predictions")
                     for i, result in enumerate(results, 1):
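One behaviour worth flagging in this hunk: the conversion sits in the first branch of an if/elif/else chain, so a text containing [MASK] is converted and reported, but that click never reaches the else: block that actually runs the pipeline. A possible restructuring is sketched below; it is illustrative only, reuses the app's existing imports and helpers, and the st.write line is an assumption, since the real prediction rendering lies outside the hunk:

if st.button("Fill Masks", key="mask_button") and mask_input.strip():
    # Convert first, then validate, so a [MASK] input still gets filled on the same click.
    if "[MASK]" in mask_input:
        mask_input = mask_input.replace("[MASK]", "<mask>")
        st.info("Converted [MASK] to <mask> format")
    if "<mask>" not in mask_input:
        st.warning("Please include <mask> token in your text.")
    else:
        with st.spinner("Filling masks..."):
            try:
                mask_filler = load_mask_filling_model()
                results = mask_filler(get_correct_mask_token(mask_input, mask_filler.tokenizer))
                st.subheader("Predictions")
                for i, result in enumerate(results, 1):
                    st.write(f"{i}. {result['token_str']} ({result['score']:.3f})")
            except Exception as e:
                st.error(f"Error: {str(e)}")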
@@ -193,6 +227,13 @@ with tab1:
 
                 except Exception as e:
                     st.error(f"Error: {str(e)}")
+                    # Debug information
+                    st.info(f"Input text: {mask_input}")
+                    try:
+                        mask_filler = load_mask_filling_model()
+                        st.info(f"Model mask token: {mask_filler.tokenizer.mask_token}")
+                    except:
+                        pass
 
 # -------------------- POS TAGGING TAB --------------------
 with tab2:
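The debug fallback re-calls the loader, which is cheap because of @st.cache_resource, but the bare except: also swallows KeyboardInterrupt and SystemExit and does not guard against the loader returning None. A slightly narrower variant, illustrative only:

try:
    mask_filler = load_mask_filling_model()  # cached, so this second call is cheap
    if mask_filler is not None:
        st.info(f"Model mask token: {mask_filler.tokenizer.mask_token}")
except Exception:
    pass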