Training Overview

This guide covers how to train Chemeleon2 models. The framework implements a three-stage training pipeline (VAE → LDM → RL), where each stage builds on the previous one; an optional property predictor can also be trained on latent vectors from the VAE encoder.

Training Pipeline

Stage       Purpose
VAE         Encode crystal structures into latent space
LDM         Learn to generate in latent space via diffusion
RL          Fine-tune the LDM with reward functions via reinforcement learning
Predictor   Predict properties from latent vectors using the VAE encoder (optional)
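
Run back to back, the three stages look like the sketch below. The experiment names come from configs/experiment/mp_20/ (see the directory structure in the next section); src/train_rl.py is assumed here to follow the same naming as the other entry points, and each later stage must be pointed at the checkpoint produced by the stage before it (see Checkpoint Management).

# Sketch: the three training stages on MP-20, run in order
python src/train_vae.py experiment=mp_20/vae_dng    # Stage 1: train the VAE
python src/train_ldm.py experiment=mp_20/ldm_null   # Stage 2: train the LDM in the VAE latent space
python src/train_rl.py experiment=mp_20/rl_dng      # Stage 3: RL fine-tuning of the LDM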

Configuration System

Chemeleon2 uses Hydra for configuration management. All configs are in the configs/ directory.

Directory Structure

configs/
├── train_vae.yaml                              # configs/train_vae.yaml
├── train_ldm.yaml                              # configs/train_ldm.yaml
├── train_rl.yaml                               # configs/train_rl.yaml
├── train_predictor.yaml                        # configs/train_predictor.yaml
├── experiment/                                 # Experiment-specific overrides
│   ├── mp_20/
│   │   ├── vae_dng.yaml                        # configs/experiment/mp_20/vae_dng.yaml
│   │   ├── ldm_null.yaml                       # configs/experiment/mp_20/ldm_null.yaml
│   │   └── rl_dng.yaml                         # configs/experiment/mp_20/rl_dng.yaml
│   └── alex_mp_20_bandgap/
│       ├── predictor_dft_band_gap.yaml         # configs/experiment/alex_mp_20_bandgap/predictor_dft_band_gap.yaml
│       └── rl_bandgap.yaml                     # configs/experiment/alex_mp_20_bandgap/rl_bandgap.yaml
├── data/                                       # Dataset configurations
├── vae_module/                                 # VAE architecture configs
├── ldm_module/                                 # LDM architecture configs
├── rl_module/                                  # RL configs
├── trainer/                                    # PyTorch Lightning trainer
├── logger/                                     # WandB logging (configs/logger/wandb.yaml)
└── callbacks/                                  # Training callbacks

View Resolved Configuration

Check the fully resolved config without running training:

python src/train_ldm.py experiment=mp_20/ldm_null --cfg job

Override Syntax

Override any config parameter from the command line:

# Override single parameter
python src/train_vae.py trainer.max_epochs=100

# Override multiple parameters
python src/train_ldm.py data.batch_size=64 trainer.max_epochs=500

# Use experiment config (loads all overrides from file)
python src/train_vae.py experiment=mp_20/vae_dng
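
An experiment file is simply a named bundle of overrides applied on top of the base config. The snippet below is a hypothetical sketch following Hydra's standard experiment pattern; the actual config groups and values are defined by the files under configs/experiment/.

# @package _global_
# Hypothetical sketch of an experiment file such as configs/experiment/mp_20/vae_dng.yaml;
# the real file defines the actual defaults and values.
defaults:
  - override /data: mp_20
  - override /vae_module: dng

trainer:
  max_epochs: 500
data:
  batch_size: 64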

Checkpoint Management

Chemeleon2 supports two ways to specify checkpoint paths.

Automatic Download from HuggingFace

Use the ${hub:...} syntax, in config files or on the command line, to download pre-trained checkpoints from HuggingFace automatically:

# In config files
ldm_module:
    vae_ckpt_path: ${hub:mp_20_vae}
    ldm_ckpt_path: ${hub:mp_20_ldm_base}
# In CLI
python src/train_ldm.py ldm_module.vae_ckpt_path='${hub:mp_20_vae}'

Local File Paths

Use existing checkpoint files on your system:

# In config files
ldm_module:
    vae_ckpt_path: ckpts/mp_20/vae/my_checkpoint.ckpt
# In CLI
python src/train_ldm.py ldm_module.vae_ckpt_path=ckpts/my_vae.ckpt

Where Checkpoints Are Saved

During training, checkpoints are automatically saved to:

logs/{task}/runs/{timestamp}/checkpoints/

Saving is handled by PyTorch Lightning’s ModelCheckpoint callback, configured in configs/callbacks/default.yaml.
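
To reuse a saved checkpoint in a later stage, pass its path as an override. The run directory and file name below are placeholders; substitute whatever your run actually wrote:

python src/train_ldm.py ldm_module.vae_ckpt_path='logs/<task>/runs/<timestamp>/checkpoints/<checkpoint>.ckpt'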

Experiment Tracking

Chemeleon2 uses Weights & Biases (wandb) for logging by default.

Setup

# First time: login to wandb
wandb login

Offline Mode

Run without an internet connection:

WANDB_MODE=offline python src/train_vae.py experiment=mp_20/vae_dng
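
Offline runs are written to the local wandb/ directory and can be uploaded later with wandb's sync command (the run directory below is a placeholder):

wandb sync wandb/offline-run-<timestamp>-<run_id>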

Custom Project/Run Names

python src/train_vae.py logger.wandb.project=my_project logger.wandb.name=my_run
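
Hydra overrides compose, so logger settings can be combined with an experiment config in a single command:

python src/train_vae.py experiment=mp_20/vae_dng logger.wandb.project=my_project logger.wandb.name=vae_dng_run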

Next Steps