World Models - Atari Agent

This is an implementation of World Models (Ha & Schmidhuber, 2018) trained on the Atari Breakout environment.

Model Description

The World Models architecture consists of three main components (a minimal sketch of how they interact follows the list):

  1. VAE (Variational Autoencoder): Compresses 64x64 RGB images into a 64-dimensional latent space
  2. MDRNN (Mixture-Density Recurrent Neural Network): Predicts the next latent representation given the current latent state and action
  3. Controller: A linear controller optimized using CMA-ES to maximize cumulative reward
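
As a rough sketch of how the three pieces fit together at each control step, the snippet below uses the sizes listed under Model Details with simple placeholder modules standing in for the real convolutional VAE and MDRNN defined in auto_train.py:

import torch
import torch.nn as nn

# Placeholder modules with the sizes from Model Details; the real
# convolutional VAE and mixture-density RNN live in auto_train.py.
latent_size, hidden_size, n_actions = 64, 256, 4
encode = nn.Linear(64 * 64 * 3, latent_size)                # stand-in for the VAE encoder
memory = nn.LSTMCell(latent_size + n_actions, hidden_size)  # stand-in for the MDRNN core
control = nn.Linear(latent_size + hidden_size, n_actions)   # the linear controller

obs = torch.rand(1, 64 * 64 * 3)                            # one flattened 64x64 RGB frame
h, c = torch.zeros(1, hidden_size), torch.zeros(1, hidden_size)

z = encode(obs)                                             # V: compress the observation to a latent z
logits = control(torch.cat([z, h], dim=-1))                 # C: choose an action from [z, h]
action = torch.zeros(1, n_actions)
action[0, logits.argmax()] = 1.0                            # one-hot discrete action
h, c = memory(torch.cat([z, action], dim=-1), (h, c))       # M: update the memory state with (z, a)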

Model Details

  • Latent Size: 64
  • Hidden Size: 256 (MDRNN)
  • Action Space: 4 (Atari discrete actions)
  • Architecture: Convolutional encoder/decoder for VAE, LSTM-based RNN
  • Optimization: CMA-ES for controller training
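
A minimal PyTorch sketch of modules with these sizes is shown below. The layer counts, channel widths, and the number of mixture components are illustrative assumptions; the authoritative definitions are in auto_train.py.

import torch
import torch.nn as nn

LATENT_SIZE, HIDDEN_SIZE, N_ACTIONS = 64, 256, 4

class VAE(nn.Module):
    # Convolutional encoder/decoder for 64x64 RGB frames (illustrative widths).
    def __init__(self, latent_size=LATENT_SIZE):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),    # 64 -> 31
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),   # 31 -> 14
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),  # 14 -> 6
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(), # 6 -> 2
        )
        self.fc_mu = nn.Linear(256 * 2 * 2, latent_size)
        self.fc_logvar = nn.Linear(256 * 2 * 2, latent_size)
        self.fc_decode = nn.Linear(latent_size, 1024)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 128, 5, stride=2), nn.ReLU(),  # 1 -> 5
            nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(),    # 5 -> 13
            nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ReLU(),     # 13 -> 30
            nn.ConvTranspose2d(32, 3, 6, stride=2), nn.Sigmoid(),   # 30 -> 64
        )

    def encode(self, x):
        h = self.encoder(x).flatten(start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

    def decode(self, z):
        return self.decoder(self.fc_decode(z).view(-1, 1024, 1, 1))

class MDRNN(nn.Module):
    # LSTM that outputs a Gaussian mixture over the next latent vector.
    def __init__(self, latent_size=LATENT_SIZE, n_actions=N_ACTIONS,
                 hidden_size=HIDDEN_SIZE, n_gaussians=5):
        super().__init__()
        self.lstm = nn.LSTM(latent_size + n_actions, hidden_size, batch_first=True)
        # Per mixture component: a mean and log-sigma per latent dim, plus a weight logit.
        self.gmm = nn.Linear(hidden_size, n_gaussians * (2 * latent_size + 1))

    def forward(self, z, action, hidden=None):
        out, hidden = self.lstm(torch.cat([z, action], dim=-1), hidden)
        return self.gmm(out), hidden

class Controller(nn.Module):
    # Single linear layer from [z, h] to action logits, trained with CMA-ES.
    def __init__(self, latent_size=LATENT_SIZE, hidden_size=HIDDEN_SIZE,
                 n_actions=N_ACTIONS):
        super().__init__()
        self.fc = nn.Linear(latent_size + hidden_size, n_actions)

    def forward(self, z, h):
        return self.fc(torch.cat([z, h], dim=-1))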

Usage

import torch
from pathlib import Path

# Load the combined checkpoint (map_location lets it load on CPU-only machines)
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# Access components
vae_state = checkpoint['vae_state_dict']
rnn_state = checkpoint['rnn_state_dict']
controller_state = checkpoint['controller_state_dict']

# Reconstruct models (see auto_train.py for architecture definitions)
# and load states into them
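
Continuing the snippet above, and assuming auto_train.py exposes classes named VAE, MDRNN, and Controller (names and constructor signatures are not verified here), loading could look like:

from auto_train import VAE, MDRNN, Controller  # hypothetical names; check auto_train.py

vae = VAE()
rnn = MDRNN()
controller = Controller()

vae.load_state_dict(vae_state)
rnn.load_state_dict(rnn_state)
controller.load_state_dict(controller_state)

# Put everything in inference mode before rolling out the agent
vae.eval()
rnn.eval()
controller.eval()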

Training Details

  • Environment: Atari Breakout
  • Harvest Episodes (data-collection rollouts): 100
  • VAE Epochs: 20
  • RNN Epochs: 20
  • CMA-ES Generations: 25
  • Population Size: 32
  • Framework: PyTorch
  • Training Script: auto_train.py
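
The controller stage can be sketched with the cma package and the settings above. Here evaluate_rollout is a hypothetical stand-in for a real Breakout rollout through the VAE and MDRNN, the initial step size is arbitrary, and the parameter count assumes the linear controller has a bias term:

import cma
import numpy as np

N_PARAMS = (64 + 256 + 1) * 4  # 320 inputs + bias, 4 outputs (bias assumed)

def evaluate_rollout(params: np.ndarray) -> float:
    # Stand-in: load `params` into the controller, run one or more Breakout
    # episodes through the VAE/MDRNN pipeline, and return the total reward.
    return 0.0

es = cma.CMAEvolutionStrategy(N_PARAMS * [0.0], 0.5, {"popsize": 32})
for generation in range(25):
    candidates = es.ask()                       # sample 32 parameter vectors
    returns = [evaluate_rollout(np.asarray(c)) for c in candidates]
    es.tell(candidates, [-r for r in returns])  # CMA-ES minimizes, so negate reward
best_params = es.result.xbest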

References

Ha, D., & Schmidhuber, J. (2018). World Models. arXiv:1803.10122.

License

MIT License
