dhe1raj commited on
Commit
6581479
·
verified ·
1 Parent(s): c67a1ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
+ # =======================
7
+ # Configuration
8
+ # =======================
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ MODEL_PATH = "cattle_breed_efficientnetb3_pytorch.pth" # Upload this to the Space
11
+ CLASS_NAMES = ["Gir", "Deoni", "Murrah"]
12
+
13
+ # =======================
14
+ # Load Model
15
+ # =======================
16
+ model = models.efficientnet_b3(pretrained=False)
17
+ model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
18
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
19
+ model.to(device)
20
+ model.eval()
21
+
22
+ # =======================
23
+ # Image Preprocessing
24
+ # =======================
25
+ transform = transforms.Compose([
26
+ transforms.Resize((300, 300)),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize([0.485, 0.456, 0.406],
29
+ [0.229, 0.224, 0.225])
30
+ ])
31
+
32
+ # =======================
33
+ # Prediction Function
34
+ # =======================
35
+ def predict(image):
36
+ image = image.convert("RGB")
37
+ img_tensor = transform(image).unsqueeze(0).to(device)
38
+ with torch.no_grad():
39
+ output = model(img_tensor)
40
+ pred_idx = torch.argmax(output, dim=1).item()
41
+ return CLASS_NAMES[pred_idx]
42
+
43
+ # =======================
44
+ # Gradio Interface
45
+ # =======================
46
+ iface = gr.Interface(
47
+ fn=predict,
48
+ inputs=gr.Image(type="pil"),
49
+ outputs="text",
50
+ title="Indian Bovine Breed Classifier",
51
+ description="Upload an image of a cow and the model will predict its breed."
52
+ )
53
+
54
+ iface.launch()