VAE Module

The Variational Autoencoder module (src/vae_module/) encodes crystal structures into continuous latent representations and decodes them back.

Architecture

Key Classes

VAEModule

The main PyTorch Lightning module implementing the VAE (src/vae_module/vae_module.py):

from src.vae_module import VAEModule

# Load pre-trained VAE
vae = VAEModule.load_from_checkpoint("path/to/checkpoint.ckpt", weights_only=False)

# Encode crystal batch to latent distribution
encoded = vae.encode(batch)
posterior = encoded["posterior"]
z = posterior.sample()

# Decode latent vectors to crystal properties
encoded["x"] = z
decoder_out = vae.decode(encoded)

# Reconstruct crystal structures
batch_recon = vae.reconstruct(decoder_out, batch)

Key methods (as used in the example above):

- load_from_checkpoint(path, weights_only=False) — restores a trained VAE from a Lightning checkpoint.
- encode(batch) — encodes a crystal batch into a latent posterior distribution.
- decode(encoded) — decodes sampled latent vectors back into crystal properties.
- reconstruct(decoder_out, batch) — rebuilds crystal structures from the decoder output.

DiagonalGaussianDistribution

Represents the latent distribution with diagonal covariance. In the example above, vae.encode(batch) returns it as encoded["posterior"], and posterior.sample() draws latent vectors from it.
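
A minimal sketch of how such a distribution is commonly implemented (the parameter layout, method names, and reductions below are assumptions, not necessarily the module's actual interface):

import torch

class DiagonalGaussianDistribution:
    """Gaussian over latent vectors with diagonal covariance."""

    def __init__(self, parameters: torch.Tensor):
        # Assumption: the encoder emits mean and log-variance concatenated on the last dim.
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=-1)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self) -> torch.Tensor:
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        return self.mean + self.std * torch.randn_like(self.std)

    def kl(self) -> torch.Tensor:
        # KL(q(z|x) || N(0, I)), summed over latent dimensions.
        return 0.5 * torch.sum(self.mean**2 + self.logvar.exp() - 1.0 - self.logvar, dim=-1)

    def mode(self) -> torch.Tensor:
        # Most likely latent vector (the mean).
        return self.mean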

Loss Functions

The VAE training minimizes:

$$\mathcal{L} = \lambda_1 \mathcal{L}_{recon} + \lambda_2 \mathcal{L}_{KL}$$

Where $\lambda_1$ and $\lambda_2$ weight the two terms: $\mathcal{L}_{recon}$ measures how well the decoded crystal properties reconstruct the input batch, and $\mathcal{L}_{KL}$ is the KL divergence between the latent posterior and the standard normal prior.
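
As an illustrative sketch only (the reconstruction criterion, the posterior's kl() helper, and the argument names below are assumptions rather than the module's actual implementation):

import torch.nn.functional as F

def vae_loss(decoder_out, target, posterior, lambda_recon, lambda_kl):
    # Reconstruction term: how closely the decoded properties match the inputs.
    recon_loss = F.mse_loss(decoder_out, target)
    # KL term: pulls the latent posterior toward the standard normal prior.
    kl_loss = posterior.kl().mean()
    return lambda_recon * recon_loss + lambda_kl * kl_loss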

Configuration

See configs/vae_module/ for VAE configurations:

# configs/vae_module/vae_module.yaml (default)
_target_: src.vae_module.vae_module.VAEModule
encoder:
  _target_: src.vae_module.encoders.transformer.TransformerEncoder
  d_model: 512
  nhead: 8
  num_layers: 8
decoder:
  _target_: src.vae_module.decoders.transformer.TransformerDecoder
  d_model: 512
  nhead: 8
  num_layers: 8
latent_dim: 8
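
Hydra builds the module from the _target_ entries in this file; a minimal sketch of doing the same programmatically (assuming the file is loaded directly rather than composed into a full training config):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load the config group shown above and build the module from its _target_ entries.
cfg = OmegaConf.load("configs/vae_module/vae_module.yaml")
vae = instantiate(cfg)  # recursively instantiates the encoder and decoder as well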

Training

python src/train_vae.py experiment=mp_20/vae_dng
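
Additional Hydra overrides can be appended on the command line; for example (assuming the vae_module config group shown above), the latent dimensionality could be changed:

python src/train_vae.py experiment=mp_20/vae_dng vae_module.latent_dim=16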

Training script: src/train_vae.py

See Training Guide for more details.