RL Module

The Reinforcement Learning module (src/rl_module/) fine-tunes the LDM using Group Relative Policy Optimization (GRPO).

Algorithm Overview

GRPO optimizes the LDM policy to maximize expected rewards:

\mathcal{L}_{GRPO} = -\mathbb{E}\left[\min\left(r_t(\theta)A_t,\ \text{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)A_t\right)\right] + \beta D_{KL} - \gamma H

Where:

- r_t(θ) is the probability ratio between the current and the previous policy
- A_t is the group-relative advantage, computed across the samples in each group
- ε is the clip ratio (clip_ratio in rl_configs)
- β weights the KL-divergence penalty D_KL against the reference policy (kl_weight)
- γ weights the entropy bonus H
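The group-relative advantage is what distinguishes GRPO from standard PPO: rewards are standardized within each group of samples rather than estimated by a learned critic. A minimal sketch of this idea (the actual computation lives inside RLModule and may differ in detail):

import torch

def group_relative_advantage(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Standardize rewards within a group of generations drawn for the
    # same condition (group size set by num_group_samples in rl_configs).
    return (rewards - rewards.mean()) / (rewards.std() + eps)

adv = group_relative_advantage(torch.tensor([0.2, 0.8, 0.5, 0.1]))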

Key Classes

RLModule

PyTorch Lightning module for RL fine-tuning (src/rl_module/rl_module.py):

from src.rl_module import RLModule

# Load RL module from checkpoint
rl_module = RLModule.load_from_checkpoint(
    "path/to/rl.ckpt",
    weights_only=False,
)
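Once loaded, the module behaves like any other Lightning module. A hedged usage sketch (the Trainer arguments and the datamodule are illustrative only; the project's actual entry point is src/train_rl.py, shown under Training below):

import lightning.pytorch as pl  # or `import pytorch_lightning as pl` on older versions

trainer = pl.Trainer(max_epochs=1, accelerator="auto")
trainer.fit(rl_module, datamodule=datamodule)  # `datamodule` is hypothetical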

Key Methods:

RewardComponent (Base Class)

Abstract base for all reward components (src/rl_module/components.py):

import torch
from pymatgen.core import Structure  # assumed Structure type

from src.rl_module.components import RewardComponent

class MyCustomReward(RewardComponent):
    def compute(self, gen_structures: list[Structure], **kwargs) -> torch.Tensor:
        # Return a tensor with one reward per generated structure;
        # here, an illustrative choice: favor denser structures.
        return torch.tensor([s.density for s in gen_structures])
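A quick smoke test (the no-argument constructor and the pymatgen Structure type are assumptions):

from pymatgen.core import Lattice, Structure

s = Structure(Lattice.cubic(4.0), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
print(MyCustomReward().compute([s]))  # tensor of shape (1,)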

Built-in Reward Components

Component                     Description
CustomReward                  User-defined rewards
PredictorReward               Surrogate model predictions
CreativityReward              Unique + Novel structures
EnergyReward                  Low energy above hull
StructureDiversityReward      MMD-based diversity
CompositionDiversityReward    Composition diversity

ReinforceReward

Aggregates multiple reward components (src/rl_module/reward.py):

from src.rl_module.reward import ReinforceReward
from src.rl_module.components import CreativityReward, EnergyReward

reward = ReinforceReward(
    components=[
        CreativityReward(weight=1.0),
        EnergyReward(weight=0.5),
    ],
    normalize_fn="std",
)
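The exact aggregation is defined in src/rl_module/reward.py; a plausible sketch, with function name and details assumed, is a weighted sum of per-component rewards after normalization:

import torch

def aggregate(component_rewards: list[torch.Tensor], weights: list[float]) -> torch.Tensor:
    # Hypothetical: normalize each component's rewards ("std" strategy),
    # then combine them as a weighted sum.
    normed = [(r - r.mean()) / (r.std() + 1e-8) for r in component_rewards]
    return sum(w * r for w, r in zip(weights, normed))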

Normalization Strategies

Strategy         Description
std              Standardize to zero mean, unit variance
norm             Min-max normalization to [0, 1]
subtract_mean    Subtract mean only
clip             Clip to specified range
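A minimal sketch of what these strategies plausibly compute (epsilon handling and the clip bounds are assumptions):

import torch

def normalize(r: torch.Tensor, strategy: str, eps: float = 1e-8) -> torch.Tensor:
    if strategy == "std":            # zero mean, unit variance
        return (r - r.mean()) / (r.std() + eps)
    if strategy == "norm":           # min-max to [0, 1]
        return (r - r.min()) / (r.max() - r.min() + eps)
    if strategy == "subtract_mean":  # center only
        return r - r.mean()
    if strategy == "clip":           # clip to a range (bounds assumed)
        return r.clamp(-1.0, 1.0)
    raise ValueError(f"unknown strategy: {strategy}")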

Configuration

See configs/rl_module/ for RL configurations:

# configs/rl_module/rl_module.yaml (default)
_target_: src.rl_module.rl_module.RLModule
reward_fn:
  _target_: src.rl_module.reward.ReinforceReward
  normalize_fn: std
  components:
    - _target_: src.rl_module.components.CreativityReward
      weight: 1.0
    - _target_: src.rl_module.components.EnergyReward
      weight: 1.0
    - _target_: src.rl_module.components.StructureDiversityReward
      weight: 0.1
    - _target_: src.rl_module.components.CompositionDiversityReward
      weight: 1.0
rl_configs:
  clip_ratio: 0.001
  kl_weight: 1.0
  num_group_samples: 1
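Hydra resolves the _target_ keys recursively; a sketch of instantiating this config directly (assuming the YAML is self-contained and the path is relative to the repo root):

import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/rl_module/rl_module.yaml")
rl_module = hydra.utils.instantiate(cfg)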

Training

# De novo generation RL
python src/train_rl.py experiment=mp_20/rl_dng

# Band gap optimization RL
python src/train_rl.py experiment=alex_mp_20_bandgap/rl_bandgap

Training script: src/train_rl.py

See Training Guide and Custom Rewards for more details.