The Predictor is an optional module that learns to predict material properties from VAE latent vectors. It enables property-based RL rewards without expensive property calculations during training.
## What Predictor Does
The Predictor operates in the VAE’s latent space:
```text
Crystal Structure → VAE Encoder → Latent z → Predictor → Property Value
```

Key points (see `src/vae_module/predictor_module.py`):

- Pre-trained VAE encoder: frozen weights loaded from a trained VAE checkpoint
- MLP predictor: trainable multi-layer projection network that maps latent vectors to property values
- Surrogate model: approximates expensive property calculations (e.g., DFT) for fast inference
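To make the frozen-encoder/trainable-head split concrete, here is a toy sketch in plain Python. It is not the code in `predictor_module.py`. `frozen_encode` is a hypothetical stand-in for the VAE encoder: a fixed linear map over made-up 4-dimensional features. A single linear layer trained with SGD replaces the MLP to keep the example short. Only the head's parameters change during training, and the encoder is run once and reused.

```python
import random

# Hypothetical stand-in for the frozen VAE encoder: a fixed linear map from a
# 4-dim "structure" feature vector to a 2-dim latent z. It is never updated.
ENCODER_W = [[0.5, -0.2, 0.1, 0.3], [0.0, 0.4, -0.3, 0.2]]


def frozen_encode(x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in ENCODER_W]


def predict(head_w, head_b, z):
    # Linear head standing in for the trainable MLP predictor.
    return sum(w * zi for w, zi in zip(head_w, z)) + head_b


def mse(head_w, head_b, data):
    return sum((predict(head_w, head_b, frozen_encode(x)) - y) ** 2 for x, y in data) / len(data)


def train_head(data, epochs=200, lr=0.05, seed=0):
    rng = random.Random(seed)
    head_w = [rng.uniform(-0.1, 0.1) for _ in ENCODER_W]
    head_b = 0.0
    # Encode once: the encoder is frozen, so the latents never change.
    latents = [(frozen_encode(x), y) for x, y in data]
    for _ in range(epochs):
        for z, y in latents:
            err = predict(head_w, head_b, z) - y
            # SGD step on 0.5 * err**2, updating only the head parameters.
            head_w = [w - lr * err * zi for w, zi in zip(head_w, z)]
            head_b -= lr * err
    return head_w, head_b


# Toy data whose "property" is an exact linear function of the latent z.
rng = random.Random(1)
data = []
for _ in range(50):
    x = [rng.uniform(-1, 1) for _ in range(4)]
    z = frozen_encode(x)
    data.append((x, 2.0 * z[0] - 1.0 * z[1] + 0.5))

head_w, head_b = train_head(data)
print(f"MSE before: {mse([0.0, 0.0], 0.0, data):.4f}, after: {mse(head_w, head_b, data):.6f}")
```

Because the encoder is frozen in a setup like this, its outputs can be computed once and cached. Training the head is then much cheaper than training the VAE itself.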
## Prerequisites
Predictor training requires:

- A trained VAE checkpoint
- A dataset with property labels (e.g., band gap, formation energy)
## Quick Start
```bash
# Train band gap predictor (src/train_predictor.py)
python src/train_predictor.py experiment=alex_mp_20_bandgap/predictor_dft_band_gap
```

- Training script: `src/train_predictor.py`
- Example config: `configs/experiment/alex_mp_20_bandgap/predictor_dft_band_gap.yaml`
## Dataset Preparation
### Required Data Format
Your dataset CSV files need:
- `material_id`: unique identifier
- `cif`: crystal structure in CIF format
- Property column(s): e.g., `band_gap`, `formation_energy`
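Before training, it can help to check each split against these requirements. The helper below is a hypothetical sketch that uses only Python's `csv` module; `check_split` is not part of the repository. It reports missing columns, duplicate `material_id` values, and empty property cells.

```python
import csv
import io


def check_split(f, properties):
    """Return a list of problems found in one CSV split (hypothetical helper)."""
    reader = csv.DictReader(f)
    header = reader.fieldnames or []
    problems = [f"missing column: {c}" for c in ("material_id", "cif", *properties)
                if c not in header]
    if problems:
        return problems  # no point checking rows without the required columns
    seen = set()
    for row in reader:
        mid = row["material_id"]
        if mid in seen:
            problems.append(f"duplicate material_id: {mid}")
        seen.add(mid)
        for p in properties:
            if not (row[p] or "").strip():
                problems.append(f"{mid}: empty {p}")
    return problems


# With a real file, open it with newline="" so multi-line quoted CIF fields parse:
#   with open("data/my_dataset/train.csv", newline="") as f: check_split(f, ["band_gap"])
sample = io.StringIO('material_id,cif,band_gap\nmp-1234,"data_...",2.5\nmp-5678,"data_...",0.0\n')
print(check_split(sample, ["band_gap"]))
```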
Place the three splits in a single directory:

```text
data/my_dataset/
├── train.csv
├── val.csv
└── test.csv
```

Example CSV:
```csv
material_id,cif,band_gap
mp-1234,"data_...",2.5
mp-5678,"data_...",0.0
```

### Compute Normalization Statistics
Calculate the mean and standard deviation of your target property on the training split:

```python
import pandas as pd

df = pd.read_csv("data/my_dataset/train.csv")
print(f"band_gap mean: {df['band_gap'].mean():.3f}")
print(f"band_gap std: {df['band_gap'].std():.3f}")
```

You will need these values for the `target_conditions` section of the config.
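The same statistics can also be computed without pandas. This sketch uses only the standard library. `statistics.stdev` is the sample standard deviation, which matches pandas' default (`ddof=1`). `target_stats` and `to_target_conditions_yaml` are hypothetical helpers; the second prints the numbers in the shape of the `target_conditions` block shown in the Configuration section.

```python
import csv
import io
import statistics


def target_stats(f, properties):
    """Mean and sample std (ddof=1, matching pandas' default) for each property."""
    rows = list(csv.DictReader(f))
    stats = {}
    for p in properties:
        values = [float(r[p]) for r in rows if r[p]]  # skip empty cells
        stats[p] = (statistics.mean(values), statistics.stdev(values))
    return stats


def to_target_conditions_yaml(stats):
    """Format the stats as a target_conditions block for the predictor config."""
    lines = ["predictor_module:", "  target_conditions:"]
    for name, (mean, std) in stats.items():
        lines += [f"    {name}:", f"      mean: {mean:.3f}", f"      std: {std:.3f}"]
    return "\n".join(lines)


# In practice, pass open("data/my_dataset/train.csv", newline="") instead of this toy split.
train_csv = io.StringIO("material_id,cif,band_gap\nmp-1,x,1.0\nmp-2,y,2.0\nmp-3,z,3.0\n")
print(to_target_conditions_yaml(target_stats(train_csv, ["band_gap"])))
```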
## Training Commands
### Basic Training
```bash
# Use experiment config
python src/train_predictor.py experiment=alex_mp_20_bandgap/predictor_dft_band_gap

# Override parameters
python src/train_predictor.py experiment=alex_mp_20_bandgap/predictor_dft_band_gap \
  data.batch_size=512 \
  trainer.max_epochs=500
```

## Configuration
### Example Config
Create `configs/experiment/my_dataset/predictor_bandgap.yaml`:
```yaml
# @package _global_
# Predictor training for band gap

data:
  _target_: src.data.datamodule.DataModule
  data_dir: ${paths.data_dir}/my_dataset
  batch_size: 256
  dataset_type: "my_dataset"
  target_condition: band_gap

predictor_module:
  vae:
    checkpoint_path: ${hub:mp_20_vae}
  target_conditions:
    band_gap:
      mean: 1.5  # From your dataset statistics
      std: 1.2   # From your dataset statistics

logger:
  wandb:
    name: "predictor_bandgap"
```

### Key Hyperparameters
| Parameter | Default | Description |
|---|---|---|
| `hidden_dim` | Dynamic | Projection network dimensions (`input_dim//4`, `input_dim//2`) |
| `num_layers` | 3 | Number of projection layers |
| `dropout` | 0.1 | Dropout rate |
| `use_encoder_features` | True | Concatenate encoder hidden features |
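As a worked example of the `hidden_dim` rule, the sketch below computes the two widths from `input_dim`. It assumes, without confirmation here, that `input_dim` is the latent size plus the size of the concatenated encoder features when `use_encoder_features` is `True`. The sizes 64 and 448 are made-up illustrations, and `projection_dims` is a hypothetical helper. Check `src/vae_module/predictor_module.py` for how the layers are actually wired.

```python
def projection_dims(latent_dim, encoder_feature_dim=0, use_encoder_features=True):
    """Hypothetical helper: (input_dim, (input_dim // 4, input_dim // 2)).

    Assumes input_dim = latent_dim, plus encoder_feature_dim when encoder
    features are concatenated onto the latent vector.
    """
    input_dim = latent_dim + (encoder_feature_dim if use_encoder_features else 0)
    return input_dim, (input_dim // 4, input_dim // 2)


print(projection_dims(64, 448))                              # with encoder features
print(projection_dims(64, 448, use_encoder_features=False))  # latent only
```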
### Multiple Properties
Train a predictor for multiple properties simultaneously:
```yaml
predictor_module:
  target_conditions:
    band_gap:
      mean: 1.5
      std: 1.2
    formation_energy:
      mean: -0.5
      std: 0.3
```

## Training Tips
### Monitoring
Key metrics in WandB:
- `train/loss`: overall MSE loss
- `train/band_gap_loss`: per-property loss
- `val/loss`: validation loss; if it rises while `train/loss` keeps falling, the predictor is overfitting
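The following sketch assumes the `mean`/`std` values in `target_conditions` are used to standardize each target, which is the usual reason to put them in the config. Under that assumption, each property's loss is measured in units of its own standard deviation, so properties with very different scales contribute comparably to `train/loss`. `loss_breakdown` is a hypothetical helper, and averaging the per-property losses into `train/loss` is also an assumption.

```python
def standardize(value, mean, std):
    return (value - mean) / std


def loss_breakdown(preds, targets, target_conditions):
    """Per-property MSE on standardized values, plus their average.

    preds/targets: {property: [values in original units]}
    target_conditions: {property: {"mean": ..., "std": ...}}, as in the config.
    Assumption: train/loss is the unweighted mean of the per-property losses.
    """
    metrics = {}
    for name, stats in target_conditions.items():
        m, s = stats["mean"], stats["std"]
        errs = [standardize(p, m, s) - standardize(t, m, s)
                for p, t in zip(preds[name], targets[name])]
        metrics[f"train/{name}_loss"] = sum(e * e for e in errs) / len(errs)
    metrics["train/loss"] = sum(metrics.values()) / len(metrics)
    return metrics


target_conditions = {
    "band_gap": {"mean": 1.5, "std": 1.2},
    "formation_energy": {"mean": -0.5, "std": 0.3},
}
# Band gap errors of +/-1.2 equal one std, so they standardize to +/-1.0.
result = loss_breakdown(
    preds={"band_gap": [2.7, 0.3], "formation_energy": [-0.5]},
    targets={"band_gap": [1.5, 1.5], "formation_energy": [-0.5]},
    target_conditions=target_conditions,
)
print(result)
```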
### Typical Training
- Duration: up to 1000 epochs (default), with early stopping after 200 epochs without improvement
- Batch size: 256 (default); can be increased to 512 for faster training
- Learning rate: 1e-3 (default)
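The early-stopping rule above amounts to a patience counter on the monitored metric. Below is a minimal sketch in plain Python. It assumes the monitored value is the validation loss and that any strict decrease counts as an improvement. `stopping_epoch` is a hypothetical helper, not the training script's code.

```python
def stopping_epoch(val_losses, patience=200, max_epochs=1000):
    """Return how many epochs run before patience-based early stopping.

    Stops once `patience` consecutive epochs pass without a new best loss,
    or after max_epochs, whichever comes first.
    """
    best = float("inf")
    since_best = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best:
            best, since_best = loss, 0
        else:
            since_best += 1
            if since_best >= patience:
                return epoch
    return min(len(val_losses), max_epochs)


# Improves for 10 epochs, then plateaus: training stops 200 epochs after the best one.
losses = [1.0 - 0.01 * i for i in range(10)] + [0.95] * 300
print(stopping_epoch(losses))
```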
### Verifying Quality
After training, check prediction quality:
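```python
from src.vae_module.predictor_module import PredictorModule

# Load the trained predictor on CPU and switch to inference mode
predictor = PredictorModule.load_from_checkpoint(
    "ckpts/predictor.ckpt",
    map_location="cpu",
)
predictor.eval()

# Check the validation MAE against the spread (std) of the target property:
# a useful predictor should clearly beat always predicting the training mean
```

## Available Experiments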
| Experiment | Dataset | Target | Description |
|---|---|---|---|
| `alex_mp_20_bandgap/predictor_dft_band_gap` | Alex MP-20 | DFT band gap | Band gap prediction |
## Using Predictor for RL
After training, use the predictor as a reward signal:
```yaml
# In RL config
reward_fn:
  components:
    - _target_: src.rl_module.components.PredictorReward
      weight: 1.0
      predictor:
        _target_: src.vae_module.predictor_module.PredictorModule.load_from_checkpoint
        checkpoint_path: "ckpts/predictor.ckpt"
        map_location: "cpu"
      target_name: band_gap
      target_value: 3.0  # Optimize toward this value
```

See Predictor Reward Tutorial for the complete workflow.
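`PredictorReward` defines how `target_value` is turned into a reward. The sketch below only illustrates one common shaping choice, a Gaussian bump around the target, and is not that class's implementation. `predictor_reward`, `scale`, and the candidate values are hypothetical.

```python
import math


def predictor_reward(predicted, target_value, weight=1.0, scale=1.0):
    """Hypothetical shaping: equals `weight` at the target, decays with distance.

    `scale` (in the property's units) controls how quickly the reward falls off.
    """
    return weight * math.exp(-(((predicted - target_value) / scale) ** 2))


# Rank candidate structures by their predicted band gap (made-up numbers).
predicted_gaps = {"cand_a": 0.0, "cand_b": 2.5, "cand_c": 3.1, "cand_d": 5.0}
rewards = {name: predictor_reward(gap, target_value=3.0) for name, gap in predicted_gaps.items()}
best = max(rewards, key=rewards.get)
print(best, round(rewards[best], 3))
```

A bounded reward like this stays in (0, weight], which keeps it on a scale similar to other weighted components. A linear penalty such as `-abs(predicted - target_value)` is a common alternative.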
## Next Steps
- Predictor Reward Tutorial: complete RL workflow with the predictor
- RL Training: general RL training guide