This guide covers how to train Chemeleon2 models. The framework implements a three-stage training pipeline where each stage builds upon the previous one.
Training Pipeline¶
| Stage | Purpose |
|---|---|
| VAE | Encode crystal structures into latent space |
| LDM | Learn to generate in latent space via diffusion |
| RL | Fine-tune the LDM with reward functions via reinforcement learning |
| Predictor | (Optional) Predict properties from latent vectors produced by the VAE encoder |
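Put together, an end-to-end run looks roughly like the commands below, using experiment configs listed later in this guide. The src/train_vae.py and src/train_ldm.py entry points appear elsewhere in this guide; src/train_rl.py and src/train_predictor.py are assumed here by analogy with their config files and may be named differently in your checkout.
# Sketch of a full pipeline run (stage 3 and 4 script names are assumed)
python src/train_vae.py experiment=mp_20/vae_dng        # Stage 1: VAE
python src/train_ldm.py experiment=mp_20/ldm_null       # Stage 2: LDM
python src/train_rl.py experiment=mp_20/rl_dng          # Stage 3: RL fine-tuning
python src/train_predictor.py experiment=alex_mp_20_bandgap/predictor_dft_band_gap   # Optional predictor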
Configuration System¶
Chemeleon2 uses Hydra for configuration management. All configs are in the configs/ directory.
Directory Structure¶
configs/
├── train_vae.yaml # configs/train_vae.yaml
├── train_ldm.yaml # configs/train_ldm.yaml
├── train_rl.yaml # configs/train_rl.yaml
├── train_predictor.yaml # configs/train_predictor.yaml
├── experiment/ # Experiment-specific overrides
│ ├── mp_20/
│ │ ├── vae_dng.yaml # configs/experiment/mp_20/vae_dng.yaml
│ │ ├── ldm_null.yaml # configs/experiment/mp_20/ldm_null.yaml
│ │ └── rl_dng.yaml # configs/experiment/mp_20/rl_dng.yaml
│ └── alex_mp_20_bandgap/
│ ├── predictor_dft_band_gap.yaml # configs/experiment/alex_mp_20_bandgap/predictor_dft_band_gap.yaml
│ └── rl_bandgap.yaml # configs/experiment/alex_mp_20_bandgap/rl_bandgap.yaml
├── data/ # Dataset configurations
├── vae_module/ # VAE architecture configs
├── ldm_module/ # LDM architecture configs
├── rl_module/ # RL configs
├── trainer/ # PyTorch Lightning trainer
├── logger/ # WandB logging (configs/logger/wandb.yaml)
└── callbacks/ # Training callbacks
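Each top-level config composes one option from each of these groups through a Hydra defaults list. The sketch below shows the general shape such a file might take; the group option names are illustrative, and the real configs/train_vae.yaml will differ.
# Hypothetical sketch of a top-level config such as configs/train_vae.yaml
# (option names are illustrative; see the actual file for the real entries)
defaults:
  - data: mp_20            # dataset config from configs/data/
  - vae_module: default    # architecture config from configs/vae_module/
  - trainer: default       # PyTorch Lightning trainer settings
  - logger: wandb          # configs/logger/wandb.yaml
  - callbacks: default     # configs/callbacks/default.yaml
  - _self_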
View Resolved Configuration¶
Check the fully resolved config without running training:
python src/train_ldm.py experiment=mp_20/ldm_null --cfg job
Override Syntax¶
Override any config parameter from the command line:
# Override single parameter
python src/train_vae.py trainer.max_epochs=100
# Override multiple parameters
python src/train_ldm.py data.batch_size=64 trainer.max_epochs=500
# Use experiment config (loads all overrides from file)
python src/train_vae.py experiment=mp_20/vae_dng
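Since the entry points are Hydra applications, generic Hydra command-line features also apply, for example appending a key that is not yet in the config with + or sweeping values with --multirun. The keys below are illustrative; check that they match your config before relying on them.
# Append a parameter that is not already defined in the config (illustrative key)
python src/train_vae.py +trainer.gradient_clip_val=1.0
# Sweep over several batch sizes with Hydra multirun
python src/train_ldm.py --multirun data.batch_size=32,64,128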
Checkpoint Management¶
Chemeleon2 supports two ways to specify checkpoint paths.
Automatic Download from HuggingFace¶
The ${hub:...} syntax automatically downloads pre-trained checkpoints from HuggingFace:
# In config files
ldm_module:
  vae_ckpt_path: ${hub:mp_20_vae}
  ldm_ckpt_path: ${hub:mp_20_ldm_base}

# In CLI
python src/train_ldm.py ldm_module.vae_ckpt_path='${hub:mp_20_vae}'
Local File Paths¶
Use existing checkpoint files on your system:
# In config files
ldm_module:
  vae_ckpt_path: ckpts/mp_20/vae/my_checkpoint.ckpt

# In CLI
python src/train_ldm.py ldm_module.vae_ckpt_path=ckpts/my_vae.ckpt
Where Checkpoints Are Saved¶
During training, checkpoints are automatically saved to:
logs/{task}/runs/{timestamp}/checkpoints/
Examples:
logs/train_vae/runs/2025-11-02_09-35-59/checkpoints/
logs/train_ldm/runs/2025-11-05_14-22-31/checkpoints/
logs/train_rl/runs/2025-11-10_08-15-42/checkpoints/
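To reuse one of these checkpoints in a later stage, pass its path with the local-path syntax shown above, for example training an LDM against a VAE checkpoint from an earlier run (the timestamp is illustrative):
python src/train_ldm.py experiment=mp_20/ldm_null \
  ldm_module.vae_ckpt_path=logs/train_vae/runs/2025-11-02_09-35-59/checkpoints/last.ckpt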
PyTorch Lightning’s ModelCheckpoint callback (configured in configs/callbacks/default.yaml) saves:
- last.ckpt: checkpoint from the most recent (last) epoch
- epoch_*.ckpt: best checkpoints based on validation metrics
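For reference, a ModelCheckpoint entry in a Hydra callbacks config typically looks something like the sketch below. The keys are standard Lightning ModelCheckpoint arguments, but the concrete values and the dirpath interpolation are assumptions; consult configs/callbacks/default.yaml for the real settings.
# Hypothetical sketch only; see configs/callbacks/default.yaml for the real settings
model_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  dirpath: ${paths.output_dir}/checkpoints   # assumed interpolation for logs/{task}/runs/{timestamp}/checkpoints/
  filename: "epoch_{epoch:03d}"              # produces epoch_*.ckpt files
  monitor: val/loss                          # assumed validation metric name
  mode: min
  save_last: true                            # writes last.ckpt
  save_top_k: 3                              # keeps the best 3 checkpoints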
Experiment Tracking¶
Chemeleon2 uses Weights & Biases (wandb) for logging by default.
Setup¶
# First time: login to wandb
wandb login
Offline Mode¶
Run without internet connection:
WANDB_MODE=offline python src/train_vae.py experiment=mp_20/vae_dng
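Runs recorded offline can be uploaded later with the wandb CLI; by default the offline run data lives under the wandb/ directory of your working directory:
# Upload previously recorded offline runs (directory names will differ)
wandb sync wandb/offline-run-*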
Custom Project/Run Names¶
python src/train_vae.py logger.wandb.project=my_project logger.wandb.name=my_run
Next Steps¶
VAE Training - First stage: encode crystals to latent space
LDM Training - Second stage: diffusion model in latent space
RL Training - Third stage: fine-tune with rewards
Predictor Training - Optional: property prediction