Dementia Brain MRI Generation using Conditional DCGAN
Model Description
This is a conditional Deep Convolutional GAN (cDCGAN) trained to generate synthetic brain MRI images for different Alzheimer's dementia severity stages.
Approach: noise + stage_label → synthetic_brain_MRI
Model evaluation is conducted qualitatively through visual inspection.
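Because evaluation relies on looking at samples, it helps to view every stage side by side. The sketch below tiles grayscale samples into a single image with one row per severity stage. It assumes each sample is already a 2D uint8 array of the same size (for example, np.array(...) of an image returned by generate_sample in Option 2). The tile_by_stage helper and its pad argument are hypothetical and not part of this repo.

```python
import numpy as np


def tile_by_stage(samples_by_stage, pad=2):
    """Tile grayscale samples into one image: one row per stage, one column per sample.

    samples_by_stage: list (one entry per stage) of lists of 2D uint8 arrays,
    all with the same shape. This input layout is an assumption of the sketch.
    """
    h, w = samples_by_stage[0][0].shape
    rows = len(samples_by_stage)
    cols = max(len(row) for row in samples_by_stage)
    # Black canvas; `pad` pixels of black separate neighbouring tiles.
    grid = np.zeros((rows * (h + pad) - pad, cols * (w + pad) - pad), dtype=np.uint8)
    for r, row in enumerate(samples_by_stage):
        for c, img in enumerate(row):
            y, x = r * (h + pad), c * (w + pad)
            grid[y:y + h, x:x + w] = img
    return grid
```

The resulting array can be saved with PIL, e.g. Image.fromarray(grid).save("stage_grid.png"). Keeping one stage per row makes stage-dependent differences easier to spot than comparing individual files.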
Severity Stages (OASIS Dataset)
- Stage 0: Non-Dementia (Normal)
- Stage 1: Very Mild Dementia
- Stage 2: Mild Dementia
- Stage 3: Moderate Dementia
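The stage indices above are the integer labels passed to the generator. A small lookup table like the following catches out-of-range labels before they reach the model. STAGE_NAMES and check_stage are hypothetical names for illustration; they are not defined in model_architecture.py or inference.py.

```python
# Stage labels as listed above (hypothetical helper, not part of this repo).
STAGE_NAMES = {
    0: "Non-Dementia (Normal)",
    1: "Very Mild Dementia",
    2: "Mild Dementia",
    3: "Moderate Dementia",
}


def check_stage(stage: int) -> int:
    """Return `stage` unchanged if it is a valid class index, else raise ValueError."""
    if not isinstance(stage, int) or stage not in STAGE_NAMES:
        raise ValueError(f"stage must be one of {sorted(STAGE_NAMES)}, got {stage!r}")
    return stage
```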
Quick Start
Option 1: Using inference.py script
```bash
# Install dependencies
pip install torch torchvision huggingface_hub pillow
# Download model_architecture.py and inference.py from this repo, then run:
python inference.py
```
The script will automatically:
- Download the model from the Hugging Face Hub
- Generate samples for all 4 dementia stages
- Save them as generated_stage_0.png, generated_stage_1.png, etc.
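A minimal sketch of the generate-and-save loop described above, assuming a generate_fn that maps a stage index to a PIL image (such as generate_sample from Option 2). save_all_stages is a hypothetical helper; the actual inference.py may be structured differently.

```python
from pathlib import Path
from typing import Callable, List

from PIL import Image


def save_all_stages(generate_fn: Callable[[int], Image.Image],
                    out_dir: str = ".", num_stages: int = 4) -> List[Path]:
    """Generate one image per stage and save it as generated_stage_<k>.png.

    generate_fn is assumed to take a stage index (0..num_stages-1) and
    return a PIL image; this helper is illustrative, not inference.py itself.
    """
    paths = []
    for stage in range(num_stages):
        path = Path(out_dir) / f"generated_stage_{stage}.png"
        generate_fn(stage).save(path)
        paths.append(path)
    return paths
```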
Option 2: Python code
```python
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

# Download model_architecture.py from this repo first and place it next to this script
from model_architecture import Generator

# Download the generator checkpoint from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="Arga23/dementia-cgan-mri",
    filename="cDCGAN_generator.pth"
)

# Load the checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(model_path, map_location=device)

# Initialize the Generator and load its weights
G = Generator(
    z_dim=checkpoint['z_dim'],
    num_classes=checkpoint['num_classes'],
    img_channels=1
).to(device)
G.load_state_dict(checkpoint['model'])
G.eval()

# Generate one image for a given dementia stage (0-3)
def generate_sample(stage):
    with torch.no_grad():
        z = torch.randn(1, checkpoint['z_dim'], 1, 1, device=device)
        label = torch.tensor([stage], device=device)
        img = G(z, label)
        # Map the generator output from [-1, 1] to [0, 1]
        img = (img.squeeze().cpu() + 1) / 2
        img = torch.clamp(img, 0, 1)
    # A 2D uint8 array is converted to an 8-bit grayscale ("L") image
    return Image.fromarray((img.numpy() * 255).astype('uint8'))

# Example: generate a Stage 2 (Mild Dementia) image
img = generate_sample(stage=2)
img.save('generated_mild_dementia.png')
```
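To inspect several samples of the same stage at once, the single-image snippet above can be batched. This sketch assumes, like that snippet, that G takes noise of shape (n, z_dim, 1, 1) plus integer labels and returns images in [-1, 1]. generate_batch and its seed argument are hypothetical, not part of this repo.

```python
from typing import Optional

import torch


@torch.no_grad()
def generate_batch(G: torch.nn.Module, stage: int, n: int, z_dim: int,
                   device: str = "cpu", seed: Optional[int] = None) -> torch.Tensor:
    """Return n samples for one stage as a uint8 tensor of shape (n, H, W)."""
    # A seeded CPU generator makes results reproducible across devices;
    # with seed=None the global RNG is used instead.
    gen = torch.Generator().manual_seed(seed) if seed is not None else None
    z = torch.randn(n, z_dim, 1, 1, generator=gen).to(device)
    labels = torch.full((n,), stage, dtype=torch.long, device=device)
    imgs = G(z, labels)                      # assumed (n, 1, H, W) in [-1, 1]
    imgs = ((imgs + 1) / 2).clamp(0, 1)      # rescale to [0, 1]
    return (imgs[:, 0] * 255).round().to(torch.uint8).cpu()
```

With the generator loaded above, a call might look like generate_batch(G, stage=2, n=16, z_dim=checkpoint['z_dim'], device=device, seed=0).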
Files in Repository
- cDCGAN_generator.pth: Generator weights (inference-ready)
- model_architecture.py: PyTorch Generator architecture
- inference.py: Ready-to-use inference script (downloads the model from the Hugging Face Hub)
Model Architecture
- Input: 100-dim noise vector + stage label (0-3)
- Generator: ConvTranspose2d layers with BatchNorm + ReLU
- Output: 128x128 grayscale MRI image (normalized to [-1, 1])
- Training: Balanced sampling across all 4 stages
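The bullets above can be illustrated with a minimal conditional generator. This is a sketch under stated assumptions, not the repo's model_architecture.py: the stage label is mapped through a learned nn.Embedding and concatenated with the noise along the channel dimension (one common conditioning scheme; one-hot labels are another), and the channel widths (base, embed_dim) are guesses. Its weights are not compatible with cDCGAN_generator.pth.

```python
import torch
import torch.nn as nn


class CondGenerator(nn.Module):
    """Hypothetical conditional DCGAN generator: (noise, stage label) -> 1x128x128 image."""

    def __init__(self, z_dim=100, num_classes=4, embed_dim=16, base=32, img_channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)

        def up(cin, cout, stride, pad):
            # ConvTranspose2d -> BatchNorm -> ReLU block
            return [nn.ConvTranspose2d(cin, cout, 4, stride, pad, bias=False),
                    nn.BatchNorm2d(cout), nn.ReLU(inplace=True)]

        self.net = nn.Sequential(
            *up(z_dim + embed_dim, base * 16, 1, 0),          # 1x1   -> 4x4
            *up(base * 16, base * 8, 2, 1),                   # 4x4   -> 8x8
            *up(base * 8, base * 4, 2, 1),                    # 8x8   -> 16x16
            *up(base * 4, base * 2, 2, 1),                    # 16x16 -> 32x32
            *up(base * 2, base, 2, 1),                        # 32x32 -> 64x64
            nn.ConvTranspose2d(base, img_channels, 4, 2, 1),  # 64x64 -> 128x128
            nn.Tanh(),                                        # output in [-1, 1]
        )

    def forward(self, z, labels):
        # z: (N, z_dim, 1, 1); labels: (N,) integer stage indices
        emb = self.label_emb(labels).view(labels.size(0), -1, 1, 1)
        return self.net(torch.cat([z, emb], dim=1))
```

Starting from a 1×1 spatial input, one stride-1 layer produces 4×4 and five stride-2 layers double it to 128×128. The final Tanh matches the [-1, 1] output range, which is why the Option 2 code rescales with (img + 1) / 2.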
Requirements
```text
torch>=2.0.0
torchvision
huggingface_hub
pillow
```
Disclaimer
⚠️ FOR RESEARCH AND EDUCATIONAL PURPOSES ONLY
Generated images should NOT be used for clinical diagnosis or medical decisions. This model is trained on the OASIS dataset for research purposes only.