RL Training

Reinforcement Learning (RL) is the third stage of the Chemeleon2 pipeline. It fine-tunes the LDM to generate crystal structures that maximize user-defined reward functions.

What RL Does

The RL module fine-tunes the LDM to generate crystal structures optimized for specific material properties. For architectural details, see RL Module.

The key concepts are implemented in src/rl_module/rl_module.py.

Prerequisites

RL training requires both trained LDM and VAE checkpoints. The LDM is fine-tuned with reward signals, while the VAE decodes latent vectors to structures for reward computation.

```yaml
# In config files
rl_module:
  ldm_ckpt_path: ${hub:mp_20_ldm_base}  # Or use local path
  vae_ckpt_path: ${hub:mp_20_vae}
```

```bash
# In CLI
python src/train_rl.py \
  rl_module.ldm_ckpt_path='${hub:mp_20_ldm_base}' \
  rl_module.vae_ckpt_path='${hub:mp_20_vae}'
```

See Checkpoint Management for available checkpoints.

Quick Start

```bash
# Fine-tune with de novo generation reward (src/train_rl.py)
python src/train_rl.py experiment=mp_20/rl_dng
```

- Training script: src/train_rl.py
- Example config: configs/experiment/mp_20/rl_dng.yaml

Training Commands

Basic Training

```bash
# Use custom reward config
python src/train_rl.py custom_reward=rl_dng

# Override checkpoint paths (e.g., use alex_mp_20 model)
python src/train_rl.py custom_reward=rl_dng \
    rl_module.ldm_ckpt_path='${hub:alex_mp_20_ldm_base}' \
    rl_module.vae_ckpt_path='${hub:alex_mp_20_vae}'

# Override RL hyperparameters
python src/train_rl.py custom_reward=rl_dng \
    rl_module.rl_configs.num_group_samples=128 \
    data.batch_size=8
```

GRPO Algorithm

Chemeleon2 uses Group Relative Policy Optimization (GRPO) for efficient RL training:

  1. Sample Groups: Generate multiple structures per batch

  2. Compute Rewards: Evaluate all structures in the group

  3. Relative Ranking: Compare rewards within each group

  4. Policy Update: Reinforce high-reward structures relative to group
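
A minimal sketch of the relative-ranking step, assuming rewards are normalized within each group; the function name and tensor shapes are illustrative, not the project's API:

```python
# Illustrative sketch of group-relative advantages (not the project's API).
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """rewards: (num_groups, num_group_samples) -> per-group normalized advantages."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# One group of four sampled structures: above-average rewards become positive
# advantages (reinforced), below-average become negative (suppressed).
advantages = group_relative_advantages(torch.tensor([[0.2, 0.9, 0.5, 0.1]]))
```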

Key GRPO Hyperparameters

| Parameter | Default | Description |
| --- | --- | --- |
| num_group_samples | 64 | Structures per group |
| group_reward_norm | true | Normalize rewards within group (required for GRPO) |
| num_inner_batch | 2 | Number of inner batches for gradient accumulation |
| clip_ratio | 0.001 | PPO-style clipping ratio |
| kl_weight | 1.0 | KL divergence penalty weight |
| entropy_weight | 1e-5 | Entropy regularization weight |

```bash
# Example: adjust group size
python src/train_rl.py custom_reward=rl_dng \
    rl_module.rl_configs.num_group_samples=128
```
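
To see how clip_ratio, kl_weight, and entropy_weight interact, here is a minimal sketch assuming a standard PPO-style clipped surrogate; the project's exact objective may differ:

```python
# Sketch of a PPO-style clipped GRPO objective (assumed form, not the project's code).
import torch

def grpo_loss(log_probs, old_log_probs, advantages, kl_div, entropy,
              clip_ratio=0.001, kl_weight=1.0, entropy_weight=1e-5):
    # Probability ratio between the current and the sampling-time policy.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()
    # KL penalty keeps the policy close to the pretrained LDM;
    # the entropy bonus discourages premature mode collapse.
    return policy_loss + kl_weight * kl_div.mean() - entropy_weight * entropy.mean()
```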

Reward Configuration

Rewards are defined in the reward_fn section of the config (see configs/train_rl.yaml for defaults):

```yaml
rl_module:
  reward_fn:
    _target_: src.rl_module.reward.ReinforceReward
    normalize_fn: std           # Global normalization
    eps: 1e-4
    reference_dataset: mp-20    # For novelty/uniqueness metrics
    components:
      - _target_: src.rl_module.components.CreativityReward
        weight: 1.0
        normalize_fn: null
      - _target_: src.rl_module.components.EnergyReward
        weight: 1.0
        normalize_fn: norm
      - _target_: src.rl_module.components.StructureDiversityReward
        weight: 0.1
        normalize_fn: norm
      - _target_: src.rl_module.components.CompositionDiversityReward
        weight: 1.0
        normalize_fn: norm
```

See Custom Rewards Guide for detailed component documentation (src/rl_module/components.py).
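
For a rough sense of the pattern, a minimal custom component might look like the sketch below, assuming pymatgen Structure objects as input. The class layout and call signature are assumptions for illustration; the actual interface is documented in src/rl_module/components.py and the Custom Rewards Guide.

```python
# Hypothetical custom reward component (illustrative only).
from pymatgen.core import Structure

class AtomicDensityReward:
    """Scores each generated structure by its atomic density (atoms per cubic angstrom)."""

    def __init__(self, weight: float = 1.0, normalize_fn: str | None = "norm"):
        self.weight = weight
        self.normalize_fn = normalize_fn

    def __call__(self, structures: list[Structure]) -> list[float]:
        # One scalar score per structure; higher density -> higher reward.
        return [len(s) / s.volume for s in structures]
```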

Available Experiments

| Custom Reward Config | Dataset | Reward | Description |
| --- | --- | --- | --- |
| atomic_density | Alex MP-20 | Custom | Example: atomic density optimization (see Custom Reward tutorial) |
| rl_dng | MP-20 | DNG (multi-objective) | Paper's de novo generation (see DNG Reward tutorial) |
| rl_bandgap | Alex MP-20 | Predictor-based | Band gap optimization (see Predictor Reward tutorial) |

Training Tips

Monitoring

Monitor the key training metrics logged to WandB during fine-tuning.

Hyperparameter Tuning

| Issue | Solution |
| --- | --- |
| Unstable training | Increase num_group_samples, enable group_reward_norm |
| Mode collapse | Increase kl_weight, add diversity rewards |
| Slow convergence | Decrease kl_weight, increase reward weights |
| Poor structure quality | Add EnergyReward component |
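
These fixes can be applied as Hydra overrides. The example below assumes group_reward_norm and kl_weight live under rl_module.rl_configs, alongside the num_group_samples key shown earlier:

```bash
# Address unstable training and mode collapse via overrides
# (key paths for group_reward_norm and kl_weight are assumed)
python src/train_rl.py custom_reward=rl_dng \
    rl_module.rl_configs.num_group_samples=128 \
    rl_module.rl_configs.group_reward_norm=true \
    rl_module.rl_configs.kl_weight=2.0
```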

Typical Training

Next Steps