Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
91afacd
1
Parent(s):
6e0ff3e
- app.py +13 -17
- gemma_multiline.py +37 -13
app.py
CHANGED
|
@@ -34,24 +34,20 @@ def process_pdf_gemma(pdf_path, model_name, progress=gr.Progress()):
|
|
| 34 |
return gemma_handler.process_pdf(pdf_path, model_name, progress)
|
| 35 |
|
| 36 |
@spaces.GPU
|
| 37 |
-
def
|
| 38 |
-
return gemma_multiline_handler.
|
| 39 |
|
| 40 |
@spaces.GPU
|
| 41 |
-
def
|
| 42 |
-
|
| 43 |
|
| 44 |
@spaces.GPU
|
| 45 |
-
def
|
| 46 |
-
|
| 47 |
|
| 48 |
@spaces.GPU
|
| 49 |
-
def
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
@spaces.GPU
|
| 53 |
-
def process_pdf_multiline_stream(pdf, temp, top_p, repetition_penalty):
|
| 54 |
-
yield from gemma_multiline_handler.process_pdf_stream(pdf, temp, top_p, repetition_penalty)
|
| 55 |
|
| 56 |
# Example images for document-level OCR
|
| 57 |
document_examples = [
|
|
@@ -144,12 +140,12 @@ with gr.Blocks(title="Dhivehi Image to Text",css=css) as demo:
|
|
| 144 |
|
| 145 |
generate_button.click(
|
| 146 |
fn=process_image_multiline,
|
| 147 |
-
inputs=[image_input, temperature_slider, top_p_slider, repetition_penalty_slider],
|
| 148 |
outputs=text_output
|
| 149 |
)
|
| 150 |
|
| 151 |
show_event = stream_button.click(fn=show_stop_button_image, outputs=[stop_button, stream_button, generate_button])
|
| 152 |
-
gen_event = show_event.then(fn=process_image_multiline_stream, inputs=[image_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=text_output)
|
| 153 |
gen_event.then(fn=hide_stop_button_image, outputs=[stop_button, stream_button, generate_button])
|
| 154 |
stop_button.click(fn=hide_stop_button_image, outputs=[stop_button, stream_button, generate_button], cancels=[gen_event])
|
| 155 |
|
|
@@ -186,16 +182,16 @@ with gr.Blocks(title="Dhivehi Image to Text",css=css) as demo:
|
|
| 186 |
|
| 187 |
pdf_generate_button.click(
|
| 188 |
fn=process_pdf_multiline,
|
| 189 |
-
inputs=[pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider],
|
| 190 |
outputs=pdf_text_output
|
| 191 |
)
|
| 192 |
|
| 193 |
pdf_show_event = pdf_stream_button.click(fn=show_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button])
|
| 194 |
-
pdf_gen_event = pdf_show_event.then(fn=process_pdf_multiline_stream, inputs=[pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=pdf_text_output)
|
| 195 |
pdf_gen_event.then(fn=hide_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button])
|
| 196 |
pdf_stop_button.click(fn=hide_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button], cancels=[pdf_gen_event])
|
| 197 |
|
| 198 |
-
model_path_dropdown.change(fn=load_model_multiline, inputs=model_path_dropdown)
|
| 199 |
|
| 200 |
with gr.Tab("PaliGemma"):
|
| 201 |
model_dropdown_paligemma = gr.Dropdown(
|
|
|
|
| 34 |
return gemma_handler.process_pdf(pdf_path, model_name, progress)
|
| 35 |
|
| 36 |
@spaces.GPU
|
| 37 |
+
def process_image_multiline(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
|
| 38 |
+
return gemma_multiline_handler.generate_text_from_image(model_name, image, temp, top_p, repetition_penalty, progress)
|
| 39 |
|
| 40 |
@spaces.GPU
|
| 41 |
+
def process_image_multiline_stream(model_name, image, temp, top_p, repetition_penalty, progress=gr.Progress()):
|
| 42 |
+
yield from gemma_multiline_handler.generate_text_stream(model_name, image, temp, top_p, repetition_penalty, progress)
|
| 43 |
|
| 44 |
@spaces.GPU
|
| 45 |
+
def process_pdf_multiline(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
|
| 46 |
+
return gemma_multiline_handler.process_pdf(model_name, pdf, temp, top_p, repetition_penalty, progress)
|
| 47 |
|
| 48 |
@spaces.GPU
|
| 49 |
+
def process_pdf_multiline_stream(model_name, pdf, temp, top_p, repetition_penalty, progress=gr.Progress()):
|
| 50 |
+
yield from gemma_multiline_handler.process_pdf_stream(model_name, pdf, temp, top_p, repetition_penalty, progress)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Example images for document-level OCR
|
| 53 |
document_examples = [
|
|
|
|
| 140 |
|
| 141 |
generate_button.click(
|
| 142 |
fn=process_image_multiline,
|
| 143 |
+
inputs=[model_path_dropdown, image_input, temperature_slider, top_p_slider, repetition_penalty_slider],
|
| 144 |
outputs=text_output
|
| 145 |
)
|
| 146 |
|
| 147 |
show_event = stream_button.click(fn=show_stop_button_image, outputs=[stop_button, stream_button, generate_button])
|
| 148 |
+
gen_event = show_event.then(fn=process_image_multiline_stream, inputs=[model_path_dropdown, image_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=text_output)
|
| 149 |
gen_event.then(fn=hide_stop_button_image, outputs=[stop_button, stream_button, generate_button])
|
| 150 |
stop_button.click(fn=hide_stop_button_image, outputs=[stop_button, stream_button, generate_button], cancels=[gen_event])
|
| 151 |
|
|
|
|
| 182 |
|
| 183 |
pdf_generate_button.click(
|
| 184 |
fn=process_pdf_multiline,
|
| 185 |
+
inputs=[model_path_dropdown, pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider],
|
| 186 |
outputs=pdf_text_output
|
| 187 |
)
|
| 188 |
|
| 189 |
pdf_show_event = pdf_stream_button.click(fn=show_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button])
|
| 190 |
+
pdf_gen_event = pdf_show_event.then(fn=process_pdf_multiline_stream, inputs=[model_path_dropdown, pdf_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=pdf_text_output)
|
| 191 |
pdf_gen_event.then(fn=hide_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button])
|
| 192 |
pdf_stop_button.click(fn=hide_stop_button_pdf, outputs=[pdf_stop_button, pdf_stream_button, pdf_generate_button], cancels=[pdf_gen_event])
|
| 193 |
|
| 194 |
+
# model_path_dropdown.change(fn=load_model_multiline, inputs=model_path_dropdown)
|
| 195 |
|
| 196 |
with gr.Tab("PaliGemma"):
|
| 197 |
model_dropdown_paligemma = gr.Dropdown(
|
gemma_multiline.py
CHANGED
|
@@ -17,14 +17,14 @@ class GemmaMultilineHandler:
|
|
| 17 |
def __init__(self):
|
| 18 |
self.model = None
|
| 19 |
self.processor = None
|
| 20 |
-
self.
|
| 21 |
self.instruction = 'Extract the dhivehi text from the image'
|
| 22 |
|
| 23 |
def load_model(self, model_name: str):
|
| 24 |
if not model_name:
|
| 25 |
self.model = None
|
| 26 |
self.processor = None
|
| 27 |
-
self.
|
| 28 |
print("Model name is empty. No model loaded.")
|
| 29 |
return
|
| 30 |
|
|
@@ -33,8 +33,8 @@ class GemmaMultilineHandler:
|
|
| 33 |
print(f"Model '{model_name}' not found.")
|
| 34 |
return
|
| 35 |
|
| 36 |
-
if
|
| 37 |
-
print(f"Model
|
| 38 |
return
|
| 39 |
|
| 40 |
try:
|
|
@@ -44,12 +44,12 @@ class GemmaMultilineHandler:
|
|
| 44 |
torch_dtype=torch.bfloat16,
|
| 45 |
)
|
| 46 |
self.processor = AutoProcessor.from_pretrained(model_path)
|
| 47 |
-
self.
|
| 48 |
print(f"Model loaded from {model_path}")
|
| 49 |
except Exception as e:
|
| 50 |
self.model = None
|
| 51 |
self.processor = None
|
| 52 |
-
self.
|
| 53 |
print(f"Failed to load model: {e}")
|
| 54 |
|
| 55 |
def process_vision_info(self, messages: list[dict]) -> list[Image.Image]:
|
|
@@ -65,9 +65,15 @@ class GemmaMultilineHandler:
|
|
| 65 |
image_inputs.append(image.convert("RGB"))
|
| 66 |
return image_inputs
|
| 67 |
|
| 68 |
-
def generate_text_from_image(self, image: Image.Image, temperature: float = 0.8, top_p: float = 1.0, repetition_penalty: float = 1.2) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if self.model is None or self.processor is None:
|
| 70 |
-
return "Model not loaded. Please
|
| 71 |
|
| 72 |
messages = [
|
| 73 |
{
|
|
@@ -114,7 +120,13 @@ class GemmaMultilineHandler:
|
|
| 114 |
)
|
| 115 |
return output_text[0]
|
| 116 |
|
| 117 |
-
def generate_text_stream(self, image: Image.Image, temperature: float = 0.8, top_p: float = 1.0, repetition_penalty: float = 1.2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
if self.model is None or self.processor is None:
|
| 119 |
yield "Model not loaded. Please provide a model path."
|
| 120 |
return
|
|
@@ -173,7 +185,13 @@ class GemmaMultilineHandler:
|
|
| 173 |
clean_text = generated_text[response_start_index:]
|
| 174 |
yield clean_text.strip()
|
| 175 |
|
| 176 |
-
def process_pdf(self, pdf_path, temperature, top_p, repetition_penalty):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
if self.model is None or self.processor is None:
|
| 178 |
return "Model not loaded. Please load a model first."
|
| 179 |
if pdf_path is None:
|
|
@@ -185,14 +203,20 @@ class GemmaMultilineHandler:
|
|
| 185 |
pix = page.get_pixmap()
|
| 186 |
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
| 187 |
doc.close()
|
| 188 |
-
return self.generate_text_from_image(image, temperature, top_p, repetition_penalty)
|
| 189 |
else:
|
| 190 |
doc.close()
|
| 191 |
return "PDF has no pages."
|
| 192 |
except Exception as e:
|
| 193 |
return f"Failed to process PDF: {e}"
|
| 194 |
|
| 195 |
-
def process_pdf_stream(self, pdf_path, temperature, top_p, repetition_penalty):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
if self.model is None or self.processor is None:
|
| 197 |
yield "Model not loaded. Please load a model first."
|
| 198 |
return
|
|
@@ -206,7 +230,7 @@ class GemmaMultilineHandler:
|
|
| 206 |
pix = page.get_pixmap()
|
| 207 |
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
| 208 |
doc.close()
|
| 209 |
-
yield from self.generate_text_stream(image, temperature, top_p, repetition_penalty)
|
| 210 |
else:
|
| 211 |
doc.close()
|
| 212 |
yield "PDF has no pages."
|
|
|
|
| 17 |
def __init__(self):
|
| 18 |
self.model = None
|
| 19 |
self.processor = None
|
| 20 |
+
self.current_model_name = None
|
| 21 |
self.instruction = 'Extract the dhivehi text from the image'
|
| 22 |
|
| 23 |
def load_model(self, model_name: str):
|
| 24 |
if not model_name:
|
| 25 |
self.model = None
|
| 26 |
self.processor = None
|
| 27 |
+
self.current_model_name = None
|
| 28 |
print("Model name is empty. No model loaded.")
|
| 29 |
return
|
| 30 |
|
|
|
|
| 33 |
print(f"Model '{model_name}' not found.")
|
| 34 |
return
|
| 35 |
|
| 36 |
+
if model_name == self.current_model_name and self.model is not None:
|
| 37 |
+
print(f"Model '{model_name}' is already loaded.")
|
| 38 |
return
|
| 39 |
|
| 40 |
try:
|
|
|
|
| 44 |
torch_dtype=torch.bfloat16,
|
| 45 |
)
|
| 46 |
self.processor = AutoProcessor.from_pretrained(model_path)
|
| 47 |
+
self.current_model_name = model_name
|
| 48 |
print(f"Model loaded from {model_path}")
|
| 49 |
except Exception as e:
|
| 50 |
self.model = None
|
| 51 |
self.processor = None
|
| 52 |
+
self.current_model_name = None
|
| 53 |
print(f"Failed to load model: {e}")
|
| 54 |
|
| 55 |
def process_vision_info(self, messages: list[dict]) -> list[Image.Image]:
|
|
|
|
| 65 |
image_inputs.append(image.convert("RGB"))
|
| 66 |
return image_inputs
|
| 67 |
|
| 68 |
+
def generate_text_from_image(self, model_name: str, image: Image.Image, temperature: float = 0.8, top_p: float = 1.0, repetition_penalty: float = 1.2, progress=None) -> str:
|
| 69 |
+
if model_name != self.current_model_name:
|
| 70 |
+
try:
|
| 71 |
+
if progress: progress(0, desc=f"Loading {model_name}...")
|
| 72 |
+
except: pass
|
| 73 |
+
self.load_model(model_name)
|
| 74 |
+
|
| 75 |
if self.model is None or self.processor is None:
|
| 76 |
+
return "Model not loaded. Please select a model."
|
| 77 |
|
| 78 |
messages = [
|
| 79 |
{
|
|
|
|
| 120 |
)
|
| 121 |
return output_text[0]
|
| 122 |
|
| 123 |
+
def generate_text_stream(self, model_name: str, image: Image.Image, temperature: float = 0.8, top_p: float = 1.0, repetition_penalty: float = 1.2, progress=None):
|
| 124 |
+
if model_name != self.current_model_name:
|
| 125 |
+
try:
|
| 126 |
+
if progress: progress(0, desc=f"Loading {model_name}...")
|
| 127 |
+
except: pass
|
| 128 |
+
self.load_model(model_name)
|
| 129 |
+
|
| 130 |
if self.model is None or self.processor is None:
|
| 131 |
yield "Model not loaded. Please provide a model path."
|
| 132 |
return
|
|
|
|
| 185 |
clean_text = generated_text[response_start_index:]
|
| 186 |
yield clean_text.strip()
|
| 187 |
|
| 188 |
+
def process_pdf(self, model_name: str, pdf_path, temperature, top_p, repetition_penalty, progress=None):
|
| 189 |
+
if model_name != self.current_model_name:
|
| 190 |
+
try:
|
| 191 |
+
if progress: progress(0, desc=f"Loading {model_name}...")
|
| 192 |
+
except: pass
|
| 193 |
+
self.load_model(model_name)
|
| 194 |
+
|
| 195 |
if self.model is None or self.processor is None:
|
| 196 |
return "Model not loaded. Please load a model first."
|
| 197 |
if pdf_path is None:
|
|
|
|
| 203 |
pix = page.get_pixmap()
|
| 204 |
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
| 205 |
doc.close()
|
| 206 |
+
return self.generate_text_from_image(model_name, image, temperature, top_p, repetition_penalty, progress)
|
| 207 |
else:
|
| 208 |
doc.close()
|
| 209 |
return "PDF has no pages."
|
| 210 |
except Exception as e:
|
| 211 |
return f"Failed to process PDF: {e}"
|
| 212 |
|
| 213 |
+
def process_pdf_stream(self, model_name: str, pdf_path, temperature, top_p, repetition_penalty, progress=None):
|
| 214 |
+
if model_name != self.current_model_name:
|
| 215 |
+
try:
|
| 216 |
+
if progress: progress(0, desc=f"Loading {model_name}...")
|
| 217 |
+
except: pass
|
| 218 |
+
self.load_model(model_name)
|
| 219 |
+
|
| 220 |
if self.model is None or self.processor is None:
|
| 221 |
yield "Model not loaded. Please load a model first."
|
| 222 |
return
|
|
|
|
| 230 |
pix = page.get_pixmap()
|
| 231 |
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
| 232 |
doc.close()
|
| 233 |
+
yield from self.generate_text_stream(model_name, image, temperature, top_p, repetition_penalty, progress)
|
| 234 |
else:
|
| 235 |
doc.close()
|
| 236 |
yield "PDF has no pages."
|