Dementia Brain MRI Generation using Conditional DCGAN

Model Description

This is a conditional Deep Convolutional GAN (cDCGAN) trained to generate synthetic brain MRI images for different Alzheimer's dementia severity stages.

Approach: noise + stage_label → synthetic_brain_MRI
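One common way to condition a generator on a class label is to append a one-hot encoding of the label to the noise vector. The sketch below assumes that scheme and the 100-dimensional noise listed under Model Architecture. The actual Generator in model_architecture.py may use a learned label embedding instead, and one_hot and conditioned_input are hypothetical helpers for illustration only.

```python
import random

NUM_STAGES = 4   # stages 0-3, see Severity Stages
Z_DIM = 100      # noise dimension listed under Model Architecture

def one_hot(stage, num_classes=NUM_STAGES):
    """Encode an integer stage label as a one-hot list of floats."""
    if not 0 <= stage < num_classes:
        raise ValueError(f"stage must be in [0, {num_classes - 1}], got {stage}")
    vec = [0.0] * num_classes
    vec[stage] = 1.0
    return vec

def conditioned_input(stage, z_dim=Z_DIM, seed=None):
    """Hypothetical helper: Gaussian noise concatenated with a one-hot stage label."""
    rng = random.Random(seed)
    noise = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]
    return noise + one_hot(stage)

x = conditioned_input(stage=2, seed=0)
print(len(x))   # 104 = 100 noise values + 4 label slots
print(x[-4:])   # [0.0, 0.0, 1.0, 0.0]
```

In PyTorch the same idea is typically expressed as a tensor concatenation along the feature dimension.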

Evaluation is qualitative only, based on visual inspection of generated samples.

Severity Stages (OASIS Dataset)

  • Stage 0: Non-Dementia (Normal)
  • Stage 1: Very Mild Dementia
  • Stage 2: Mild Dementia
  • Stage 3: Moderate Dementia
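The stage indices above are the integer labels passed to the generator. A small lookup table such as the following (a hypothetical helper, not part of this repo) avoids hard-coding bare numbers when calling generate_sample.

```python
# Integer labels used to condition the generator, as listed in Severity Stages.
STAGES = {
    0: "Non-Dementia (Normal)",
    1: "Very Mild Dementia",
    2: "Mild Dementia",
    3: "Moderate Dementia",
}

def stage_index(name):
    """Return the integer label for a stage name (case-insensitive)."""
    lookup = {label.lower(): idx for idx, label in STAGES.items()}
    try:
        return lookup[name.lower()]
    except KeyError:
        raise ValueError(
            f"unknown stage {name!r}; expected one of {list(STAGES.values())}"
        ) from None

print(stage_index("mild dementia"))  # 2
print(STAGES[3])                     # Moderate Dementia
```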

Quick Start

Option 1: Using the inference.py script

# Install dependencies
pip install torch torchvision huggingface_hub pillow

# Download model_architecture.py and inference.py from this repo
# Then run:
python inference.py

The script will automatically:

  • Download the model from HuggingFace Hub
  • Generate samples for all 4 dementia stages
  • Save the results as generated_stage_0.png through generated_stage_3.png
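The per-stage generation and saving steps can be sketched as a loop over the four stage labels. This is an illustrative reconstruction, not the contents of inference.py: save_all_stages is a hypothetical name, and it assumes a generation function like generate_sample from Option 2 that returns an object with a .save(path) method (a PIL Image).

```python
import os

NUM_STAGES = 4  # stages 0-3

def save_all_stages(generate, out_dir="."):
    """Call generate(stage) for each stage and save to generated_stage_<n>.png.

    `generate` is any callable returning an object with a .save(path) method,
    for example generate_sample from Option 2, which returns a PIL Image.
    Returns the list of paths written.
    """
    paths = []
    for stage in range(NUM_STAGES):
        path = os.path.join(out_dir, f"generated_stage_{stage}.png")
        generate(stage).save(path)
        paths.append(path)
    return paths
```

With the generator loaded as in Option 2, calling save_all_stages(generate_sample) would write one sample per stage to the current directory.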

Option 2: Python code

import torch
from huggingface_hub import hf_hub_download
from PIL import Image

# Download model architecture first
# (Get model_architecture.py from this repo)
from model_architecture import Generator

# Download generator checkpoint from HuggingFace
model_path = hf_hub_download(
    repo_id="Arga23/dementia-cgan-mri",
    filename="cDCGAN_generator.pth"
)

# Load checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(model_path, map_location=device)

# Initialize Generator
G = Generator(
    z_dim=checkpoint['z_dim'],
    num_classes=checkpoint['num_classes'],
    img_channels=1
).to(device)

G.load_state_dict(checkpoint['model'])
G.eval()

# Generate image for specific dementia stage
def generate_sample(stage):
    with torch.no_grad():
        z = torch.randn(1, checkpoint['z_dim'], 1, 1).to(device)
        label = torch.tensor([stage]).to(device)
        img = G(z, label)
        img = (img.squeeze().cpu() + 1) / 2
        img = torch.clamp(img, 0, 1)
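        # A 2-D uint8 array is read as an 8-bit grayscale ("L") image;
        # the explicit mode= argument is deprecated in recent Pillow releases.
        return Image.fromarray((img.numpy() * 255).astype('uint8'))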

# Example: Generate Stage 2 (Mild Dementia)
img = generate_sample(stage=2)
img.save('generated_mild_dementia.png')
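The last lines of generate_sample undo the [-1, 1] normalization before building the PIL image. The same arithmetic on plain Python floats makes the mapping explicit; to_uint8 is a hypothetical helper for illustration, not part of this repo.

```python
def to_uint8(values):
    """Map generator outputs in [-1, 1] to 8-bit pixel values in [0, 255].

    Mirrors the post-processing in generate_sample: shift and scale to [0, 1],
    clamp, multiply by 255, then truncate to an integer as astype('uint8') does.
    """
    pixels = []
    for v in values:
        x = (v + 1) / 2              # [-1, 1] -> [0, 1]
        x = min(max(x, 0.0), 1.0)    # guard against values slightly outside [-1, 1]
        pixels.append(int(x * 255))  # truncation, matching astype('uint8')
    return pixels

print(to_uint8([-1.0, 0.0, 1.0, 1.2]))  # [0, 127, 255, 255]
```

Because the conversion truncates, an output of 0.0 maps to 127 rather than 128; rounding before the cast would be an equally valid choice.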

Files in Repository

  • cDCGAN_generator.pth: Generator weights (inference-ready)
  • model_architecture.py: PyTorch Generator architecture
  • inference.py: Ready-to-use inference script (downloads from HuggingFace)

Model Architecture

  • Input: 100-dim noise vector + stage label (0-3)
  • Generator: ConvTranspose2d layers with BatchNorm + ReLU
  • Output: 128x128 grayscale MRI image (normalized to [-1, 1])
  • Training: Balanced sampling across all 4 stages
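To see how a 1×1 noise input can reach a 128×128 output, the spatial size after each ConvTranspose2d layer can be computed with the output-size formula from the PyTorch documentation. The layer stack below is an assumption (a typical DCGAN layout: one 4×4 projection layer followed by five stride-2 upsampling layers); the actual layer count and kernel settings in model_architecture.py may differ.

```python
def convtranspose2d_out(size, kernel_size, stride=1, padding=0,
                        output_padding=0, dilation=1):
    """Spatial output size of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# Assumed DCGAN-style stack as (kernel_size, stride, padding) tuples:
# one projection layer (1 -> 4), then five layers that each double the size.
layers = [(4, 1, 0)] + [(4, 2, 1)] * 5

size = 1  # noise enters as a 1x1 spatial map, shape (N, z_dim, 1, 1)
sizes = [size]
for k, s, p in layers:
    size = convtranspose2d_out(size, k, s, p)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64, 128]
```

With kernel 4, stride 2, and padding 1, each layer exactly doubles the spatial size, which is why this configuration is a common choice for DCGAN generators.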

Requirements

torch>=2.0.0
torchvision
huggingface_hub
pillow

Disclaimer

โš ๏ธ FOR RESEARCH AND EDUCATIONAL PURPOSES ONLY

Generated images should NOT be used for clinical diagnosis or medical decisions. This model was trained on the OASIS dataset for research purposes only.
