sugiv committed on
Commit
24dc13b
·
verified ·
1 Parent(s): 03a7258

Add comprehensive inference example script

Browse files
Files changed (1) hide show
  1. inference_example.py +103 -0
inference_example.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CardVault+ Inference Example
4
+ Simple example showing how to use the CardVault+ model for card extraction
5
+ """
6
+
7
+ import torch
8
+ from transformers import AutoProcessor, AutoModelForVision2Seq
9
+ from PIL import Image, ImageDraw
10
+ import json
11
+
12
def create_sample_card():
    """Build a synthetic credit-card-style image for trying out the extractor.

    Returns:
        A 400x250 RGB ``PIL.Image.Image`` with a light-blue background and
        four black text fields: bank name, card number, holder name and
        expiry date.
    """
    # (position, text) pairs for every field drawn on the card face.
    card_fields = (
        ((20, 50), "SAMPLE BANK"),
        ((20, 100), "1234 5678 9012 3456"),
        ((20, 150), "JOHN DOE"),
        ((300, 150), "12/25"),
    )

    card = Image.new('RGB', (400, 250), color='lightblue')
    canvas = ImageDraw.Draw(card)
    for position, text in card_fields:
        canvas.text(position, text, fill='black')

    return card
25
+
26
def extract_card_info(image_path_or_pil=None):
    """Extract structured information from a card image with CardVault+.

    The processor and model for ``sugiv/cardvaultplus`` are loaded on every
    call, so callers processing many images pay the load cost each time.

    Args:
        image_path_or_pil: Source of the card image.
            * ``None`` -- use the synthetic card from ``create_sample_card()``.
            * ``str`` or ``os.PathLike`` -- path of an image file to open.
            * anything else -- handed to the processor unchanged (expected
              to be an already-loaded PIL image).

    Returns:
        dict with three keys:
            ``'full_response'``  -- the decoded model output (str);
            ``'extracted_json'`` -- the JSON object parsed from the response,
                                    or ``None`` if none could be parsed;
            ``'success'``        -- ``True`` iff ``'extracted_json'`` is not
                                    ``None``.
    """
    import os  # local import: only needed for the PathLike check below

    # Load the model
    print("Loading CardVault+ model...")
    model_id = "sugiv/cardvaultplus"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Load image
    if image_path_or_pil is None:
        print("Creating sample card image...")
        image = create_sample_card()
    elif isinstance(image_path_or_pil, (str, os.PathLike)):
        # Image.open() is lazy and keeps the file open; convert() forces the
        # pixel data to be read so the handle can be released by the `with`,
        # and normalises palette/alpha/greyscale files to 3-channel RGB.
        with Image.open(image_path_or_pil) as opened:
            image = opened.convert("RGB")
    else:
        image = image_path_or_pil

    # Prepare extraction prompt
    prompt = "<image>Extract structured information from this card/document in JSON format."

    # Process the image and prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Move every tensor to whichever device device_map="auto" put the model on
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

    # Generate extraction (greedy decoding, no gradients needed for inference)
    print("Extracting information...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id
        )

    # Decode response
    response = processor.decode(outputs[0], skip_special_tokens=True)

    # Extract JSON if present
    extracted_json = _parse_leading_json_object(response)

    return {
        'full_response': response,
        'extracted_json': extracted_json,
        'success': extracted_json is not None
    }


def _parse_leading_json_object(text):
    """Return the JSON object starting at the first ``{`` in *text*, or None.

    Anything following the object is ignored, so trailing model output (even
    if it contains ``}``) does not prevent parsing. Only the first ``{`` is
    tried: a truncated outer object yields ``None`` rather than some inner
    fragment of it.
    """
    start = text.find('{')
    if start == -1:
        return None
    try:
        parsed, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        # Malformed or truncated JSON is an expected outcome, not an error.
        return None
    return parsed
87
+
88
if __name__ == "__main__":
    # Demo run: calling without an argument uses the generated sample card.
    result = extract_card_info()

    banner = "=" * 50
    print(banner)
    print("CardVault+ Extraction Results")
    print(banner)
    print(f"Success: {result['success']}")
    print(f"Full Response: {result['full_response']}")

    # Pretty-print the parsed payload only when something was extracted.
    parsed = result['extracted_json']
    if parsed:
        print("Extracted JSON:")
        print(json.dumps(parsed, indent=2))

    # Example with your own image:
    # result = extract_card_info("path/to/your/card.jpg")