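# OCR pipeline for handwritten text: a custom-trained YOLO model detects text
# lines, TrOCR transcribes each cropped line, adjacent-line BLEU scores filter
# overlapping detections, and RoBERTa's masked-LM logits are blended in for
# low-confidence tokens. Served as a Gradio app.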

import os
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForMaskedLM
from ultralytics import YOLO

yolo_weights_path = "final_wts.pt"

# Prefer CUDA, then Apple Silicon (MPS), then CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
trocr_model.config.num_beams = 2
yolo_model = YOLO(yolo_weights_path).to(device)
roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)
print(f'TrOCR, YOLO and RoBERTa models loaded on {device}')

CONFIDENCE_THRESHOLD = 0.72
BLEU_THRESHOLD = 0.6

def inference(image_path, debug=False, return_texts='final'):
    def get_cropped_images(image_path):
        # Detect text-line regions; save=True writes an annotated copy to disk.
        results = yolo_model(image_path, save=True)
        image = Image.open(image_path).convert("RGB")
        patches = []
        ys = []
        # Process boxes top-to-bottom (sorted by y-centre).
        for box in sorted(results[0].boxes, key=lambda x: x.xywh[0][1]):
            x_center, y_center, w, h = box.xywh[0].cpu().numpy()
            x, y = x_center - w / 2, y_center - h / 2
            patches.append(image.crop((x, y, x + w, y + h)))
            ys.append(y)
        # Path of the annotated image YOLO just saved (always written as .jpg).
        stem = os.path.splitext(os.path.basename(results[0].path))[0]
        bounding_box_path = os.path.join(str(results[0].save_dir), stem + '.jpg')
        return patches, ys, bounding_box_path

    def get_model_output(images):
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
        output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_logits=True, max_new_tokens=30)
        generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
        generated_tokens = [processor.tokenizer.convert_ids_to_tokens(seq) for seq in output.sequences]
        # output.logits is a per-step tuple; stack to (batch, steps, vocab).
        stacked_logits = torch.stack(output.logits, dim=1)
        return generated_texts, stacked_logits, generated_tokens

    def get_scores(logits):
        # Line confidence: mean over the sequence of each step's max softmax probability.
        scores = logits.softmax(-1).max(-1).values.mean(-1)
        return scores

    def post_process_texts(generated_texts):
        # Strip stray '# ' prefixes and ' #' suffixes the recogniser sometimes emits.
        for i in range(len(generated_texts)):
            if len(generated_texts[i]) > 2 and generated_texts[i].startswith('# '):
                generated_texts[i] = generated_texts[i][2:]
            if len(generated_texts[i]) > 2 and generated_texts[i].endswith(' #'):
                generated_texts[i] = generated_texts[i][:-2]
        return generated_texts

    def get_qualified_texts(generated_texts, scores, y, logits, tokens):
        qualified_texts = []
        for text, score, y_i, logits_i, tokens_i in zip(generated_texts, scores, y, logits, tokens):
            if score > CONFIDENCE_THRESHOLD:
                qualified_texts.append({
                    'text': text,
                    'score': score,
                    'y': y_i,
                    'logits': logits_i,
                    'tokens': tokens_i,
                })
        return qualified_texts

    def get_adjacent_bleu_scores(qualified_texts):
        # BLEU between each line and the next flags near-duplicate detections.
        def get_bleu_score(hypothesis, references):
            weights = [0.5, 0.5]  # bigram BLEU
            smoothing = SmoothingFunction()
            return sentence_bleu(references, hypothesis, weights=weights,
                                 smoothing_function=smoothing.method1)

        for i in range(len(qualified_texts)):
            hyp = qualified_texts[i]['text'].split()
            bleu = 0
            if i < len(qualified_texts) - 1:
                ref = qualified_texts[i + 1]['text'].split()
                bleu = get_bleu_score(hyp, [ref])
            qualified_texts[i]['bleu'] = bleu
        return qualified_texts

    def remove_overlapping_texts(qualified_texts):
        # Consecutive lines with BLEU >= BLEU_THRESHOLD are treated as overlapping
        # detections of the same text line; keep the higher-confidence one.
        final_texts = []
        new = True
        for i in range(len(qualified_texts)):
            if new:
                final_texts.append(qualified_texts[i])
            elif final_texts[-1]['score'] < qualified_texts[i]['score']:
                final_texts[-1] = qualified_texts[i]
            new = qualified_texts[i]['bleu'] < BLEU_THRESHOLD
        return final_texts

    def get_lm_logits(ocr_tokens, confidence):
        # Replace low-confidence tokens with RoBERTa's <mask> (id 50264) and let
        # the masked LM score them from context. Note: as written, only tokens
        # in the line at index 6 are masked.
        tokens = ocr_tokens.clone()
        indices = torch.where(confidence < 0.5)
        for i, j in zip(indices[0], indices[1]):
            if i != 6:
                continue
            tokens[i, j] = 50264
        # Flatten all lines into one sequence so RoBERTa sees cross-line context.
        inputs = tokens.reshape(1, -1)
        with torch.no_grad():
            outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones_like(inputs))
        lm_logits = outputs.logits
        return lm_logits.reshape(ocr_tokens.shape[0], ocr_tokens.shape[1], -1), indices

    cropped_images, y, bounding_box_path = get_cropped_images(image_path)
    if debug:
        print('Number of cropped images:', len(cropped_images))

    generated_texts, logits, gen_tokens = get_model_output(cropped_images)
    # Move scores to CPU so they can be placed in DataFrames directly.
    normalised_scores = get_scores(logits).cpu().numpy()

    # Raw recogniser output, kept for the before/after comparison in the UI.
    generated_df = pd.DataFrame({
        'text': generated_texts,
    })

    if return_texts == 'generated':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    generated_texts = post_process_texts(generated_texts)
    if return_texts == 'post_processed':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y, logits, gen_tokens)
    if return_texts == 'qualified':
        return pd.DataFrame(qualified_texts)

    qualified_texts = get_adjacent_bleu_scores(qualified_texts)
    if return_texts == 'qualified_with_bleu':
        return pd.DataFrame(qualified_texts)

    final_texts = remove_overlapping_texts(qualified_texts)
    final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])

    final_logits = [text['logits'] for text in final_texts]
    logits = torch.stack(final_logits, dim=0)
    tokens = logits.argmax(-1)
    confidence = logits.softmax(-1).max(-1).values

    if return_texts == 'final':
        return final_texts_df

    # Blend OCR and LM logits at the low-confidence positions.
    lm_logits, indices = get_lm_logits(tokens, confidence)
    combined_logits = logits.clone()
    for i, j in zip(indices[0], indices[1]):
        combined_logits[i, j] = logits[i, j] * 0.9 + lm_logits[i, j] * 0.1
    return final_texts_df, bounding_box_path, tokens, combined_logits, confidence, generated_df
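
# A minimal sketch of invoking the pipeline directly (bypassing the Gradio UI),
# e.g. to inspect intermediate stages; 'page.png' is a placeholder path:
#   df = inference('page.png', debug=True, return_texts='qualified_with_bleu')
#   print(df[['text', 'score', 'bleu']])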

def process_image(image):
    # Gradio hands us a PIL image; write it to a temp file so YOLO can read a path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_image:
        image.save(temp_image.name)
        image_path = temp_image.name
    df, bounding_path, tokens, logits, confidence, generated_df = inference(image_path, debug=False, return_texts='final_v2')
    text = df['text'].str.cat(sep='\n')
    before_text = generated_df['text'].str.cat(sep='\n')
    bounding_img = Image.open(bounding_path)
    return bounding_img, before_text, text

interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Bounding Box Image"),
        gr.Textbox(label="Extracted Text (Custom trained YOLO Object Detection + TrOCR Vision Transformer)"),
        gr.Textbox(label="Post Processed Text (BLEU score based filtering + RoBERTa contextual understanding)"),
    ],
    title="OCR Pipeline with YOLO, TrOCR and RoBERTa",
    description="Upload an image to detect text regions with YOLO, merge bounding boxes, and extract text using TrOCR, which is then post-processed with RoBERTa for contextual understanding.",
)

if __name__ == "__main__":
    interface.launch(share=True)