Spaces:
Build error
Build error
| import os | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from PIL import Image | |
| import json | |
| import gradio as gr | |
| import requests | |
# Path to the model file and Hugging Face URL
model_path = 'food_classification_model.pth'
model_url = "https://huggingface.co/KabeerAmjad/food_classification_model/resolve/main/food_classification_model.pth"

# Download the model file if it's not already available.
if not os.path.exists(model_path):
    print(f"Downloading the model from {model_url}...")
    # Stream into a temporary file and rename it into place only after the whole
    # body arrived with a 2xx status. Otherwise an HTTP error page or a truncated
    # download would be saved as model_path, and the exists() check above would
    # skip re-downloading it on every later start.
    tmp_path = model_path + ".part"
    with requests.get(model_url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(tmp_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, model_path)
    print("Model downloaded successfully.")
# Build the ResNet-50 architecture without pretrained ImageNet weights: every
# parameter is replaced by the custom checkpoint below, so fetching the
# ImageNet weights would only slow down startup.
model = models.resnet50(weights=None)

# Load the model's custom state_dict
try:
    # NOTE(review): torch.load unpickles a file downloaded from the network.
    # If the checkpoint is a plain state_dict, consider weights_only=True
    # (torch >= 1.13) so loading cannot execute arbitrary code.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # A fine-tuned classifier usually has a different number of outputs than
    # ImageNet's 1000 classes. Resize the final layer to match the checkpoint's
    # head so the load does not fail with a size mismatch on fc.weight/fc.bias.
    fc_weight = state_dict.get("fc.weight") if isinstance(state_dict, dict) else None
    if fc_weight is not None and fc_weight.shape[0] != model.fc.out_features:
        model.fc = torch.nn.Linear(model.fc.in_features, fc_weight.shape[0])
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except RuntimeError as e:
    # Best-effort: the app still starts, but predictions are meaningless
    # without the custom weights.
    print("Error loading state_dict:", e)
    print("Ensure that the saved model architecture matches ResNet50.")

# Set evaluation mode after any module replacement above, so that every
# submodule (including a freshly created fc) is switched.
model.eval()
# Image preprocessing pipeline applied to every uploaded picture:
#   1. resize so the shorter side is 256 px,
#   2. center-crop a 224x224 patch,
#   3. convert the PIL image to a float tensor,
#   4. normalize each channel with the standard ImageNet mean/std.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
preprocess = transforms.Compose(_preprocess_steps)
# Load the class-index -> label mapping. Default to an empty mapping so that
# predict() reports an unknown class instead of hitting a NameError when the
# file is missing or malformed.
labels = {}
try:
    with open("config.json", encoding="utf-8") as f:
        config = json.load(f)
    # Accept either a bare {"0": "label", ...} mapping or a Hugging Face-style
    # config that nests that mapping under "id2label".
    if isinstance(config, dict) and isinstance(config.get("id2label"), dict):
        config = config["id2label"]
    labels = config
    print("Labels loaded successfully.")
except (OSError, ValueError) as e:
    # OSError: file missing/unreadable; ValueError: invalid JSON or encoding.
    print("Error loading labels:", e)
| # Function to predict image class | |
| def predict(image): | |
| try: | |
| print("Starting prediction...") | |
| # Convert the uploaded file to a PIL image | |
| input_image = image.convert("RGB") | |
| print(f"Image converted to RGB: {input_image.size}") | |
| # Preprocess the image | |
| input_tensor = preprocess(input_image) | |
| input_batch = input_tensor.unsqueeze(0) # Add batch dimension | |
| print(f"Input tensor shape after unsqueeze: {input_batch.shape}") | |
| # Check if a GPU is available and move the input and model to GPU | |
| if torch.cuda.is_available(): | |
| input_batch = input_batch.to('cuda') | |
| model.to('cuda') | |
| print("Using GPU for inference.") | |
| else: | |
| print("GPU not available, using CPU.") | |
| # Perform inference | |
| with torch.no_grad(): | |
| output = model(input_batch) | |
| print(f"Inference output shape: {output.shape}") | |
| # Get the predicted class with the highest score | |
| _, predicted_idx = torch.max(output, 1) | |
| predicted_idx = predicted_idx.item() | |
| print(f"Predicted class index: {predicted_idx}") | |
| # Check if the predicted index exists in labels | |
| if str(predicted_idx) in labels: | |
| predicted_class = labels[str(predicted_idx)] | |
| else: | |
| predicted_class = f"Unknown class index: {predicted_idx}. Please check the label mapping." | |
| print(predicted_class) | |
| return f"Predicted class: {predicted_class}" | |
| except Exception as e: | |
| print(f"Error during prediction: {e}") | |
| return f"An error occurred during prediction: {e}" | |
# Gradio UI: one PIL image input, one plain-text output produced by predict().
image_input = gr.Image(type="pil")
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="text",
    title="Food Classification Model",
    description="Upload an image of food to classify it.",
)

# Start the Gradio web server for the app.
iface.launch()