
Custom Rewards Overview

This guide explains how to configure and customize reward functions for RL training in Chemeleon2.

Why Verifiable Rewards?

Generative models for crystal structure generation face a fundamental objective misalignment: likelihood-based sampling inherently favors high-density regions of known compounds, while scientific discovery requires targeted exploration of underexplored regions where novel materials reside.

Reward functions enable the model to optimize for verifiable scientific objectives beyond likelihood maximization, such as thermodynamic stability (energy above the convex hull), novelty and uniqueness, structural and compositional diversity, and target material properties.

For implementation details and the GRPO algorithm, see the RL Module architecture guide.

Quick Start

# Run DNG reward training (multi-objective)
python src/train_rl.py custom_reward=rl_dng

# Or with custom hyperparameters
python src/train_rl.py custom_reward=rl_dng \
    rl_module.rl_configs.num_group_samples=128

Quick Decision Guide

Choose the tutorial based on your use case:

| Use Case | Tutorial | Description |
| --- | --- | --- |
| Simple custom logic | Atomic Density | Modify the CustomReward class |
| Multi-objective (paper) | DNG Reward | Creativity + stability + diversity |
| Property optimization | Predictor Reward | Train a predictor, use it as reward |

Built-in Reward Components

All components are in src/rl_module/components.py:

| Component | Description | Required Metrics |
| --- | --- | --- |
| CustomReward | User-defined reward function | None |
| CreativityReward | Rewards unique and novel structures | unique, novel |
| EnergyReward | Penalizes high energy above the convex hull | e_above_hull |
| StructureDiversityReward | Rewards diverse crystal geometries (MMD) | structure_diversity |
| CompositionDiversityReward | Rewards diverse chemical compositions (MMD) | composition_diversity |

RewardComponent Base Class

All reward components inherit from RewardComponent:

from abc import ABC, abstractmethod

import torch


class RewardComponent(ABC, torch.nn.Module):
    def __init__(
        self,
        weight: float = 1.0,          # Relative importance
        normalize_fn: str | None = None,  # Normalization strategy
        eps: float = 1e-4,            # Numerical stability
    ):
        ...

    @abstractmethod
    def compute(self, **kwargs) -> torch.Tensor:
        """Compute raw reward values."""
        pass

    def forward(self, **kwargs) -> torch.Tensor:
        """Compute, normalize, and weight the reward."""
        rewards = self.compute(**kwargs)
        if self.normalize_fn:
            rewards = self._normalize(rewards)
        return rewards * self.weight

Available kwargs in compute()

| Argument | Type | Description |
| --- | --- | --- |
| gen_structures | list[Structure] | Generated pymatgen Structure objects |
| batch_gen | CrystalBatch | Batched tensor representation |
| metrics_obj | Metrics | Pre-computed metrics (if required_metrics is set) |
| device | torch.device | Current device |
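Put together, a component's compute() can accept any subset of these as keyword arguments. A signature-only sketch (the subclass name is hypothetical, and the base-class import path is assumed from the module layout described above):

```python
import torch
from pymatgen.core import Structure

from src.rl_module.components import RewardComponent  # assumed import path for the base class


class MyReward(RewardComponent):  # hypothetical subclass
    def compute(
        self,
        gen_structures: list[Structure] | None = None,  # generated pymatgen structures
        batch_gen=None,                                  # CrystalBatch tensor representation
        metrics_obj=None,                                # pre-computed Metrics, if required_metrics is set
        device: torch.device | None = None,              # current device
        **kwargs,
    ) -> torch.Tensor:
        ...
```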

Normalization Options

Each component can apply normalization via normalize_fn:

| Option | Formula | Use Case |
| --- | --- | --- |
| norm | (x - min) / (max - min) | Scale to [0, 1] |
| std | (x - mean) / std | Zero mean, unit variance |
| subtract_mean | x - mean | Center around zero |
| clip | clamp(x, -1, 1) | Bound extreme values |
| null | No change | Already normalized (e.g., CreativityReward) |
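Read as code, the table corresponds to something like the following. This is an illustration of the formulas, not necessarily the exact _normalize implementation, and the use of eps for numerical stability is an assumption based on the constructor argument:

```python
import torch


def normalize(rewards: torch.Tensor, normalize_fn: str | None, eps: float = 1e-4) -> torch.Tensor:
    """Illustrative normalization matching the table above."""
    if normalize_fn == "norm":
        return (rewards - rewards.min()) / (rewards.max() - rewards.min() + eps)
    if normalize_fn == "std":
        return (rewards - rewards.mean()) / (rewards.std() + eps)
    if normalize_fn == "subtract_mean":
        return rewards - rewards.mean()
    if normalize_fn == "clip":
        return rewards.clamp(-1.0, 1.0)
    return rewards  # null: already normalized
```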

ReinforceReward Aggregation

The ReinforceReward class (see src/rl_module/reward.py) combines multiple components:

reward_fn:
  _target_: src.rl_module.reward.ReinforceReward
  normalize_fn: std           # Global normalization after combining
  eps: 1e-4
  reference_dataset: mp-20    # For novelty/uniqueness metrics
  components:
    - _target_: src.rl_module.components.CreativityReward
      weight: 1.0 # Weight for this component (default 1.0)
      normalize_fn: null # Component normalization
    - _target_: src.rl_module.components.EnergyReward
      weight: 0.5
      normalize_fn: norm # Component normalization

How Rewards Are Combined

  1. Each component computes its reward

  2. Component-level normalization is applied (if specified)

  3. Rewards are multiplied by weights

  4. All weighted rewards are summed

  5. Global normalization is applied (if specified)
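A minimal sketch of these five steps, assuming each entry in components is a RewardComponent as defined above and using the std option for the global normalization (as in the example config):

```python
import torch


def aggregate_rewards(components, eps: float = 1e-4, **kwargs) -> torch.Tensor:
    """Illustrative aggregation: steps 1-3 happen inside each component's forward(),
    steps 4-5 are the sum and the global normalization (here: the 'std' option)."""
    total = torch.zeros(())
    for component in components:
        # forward() = compute -> component-level normalize -> multiply by weight
        total = total + component(**kwargs)
    # Global normalization after combining (matching normalize_fn: std above)
    return (total - total.mean()) / (total.std() + eps)
```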

Component Details

CustomReward

Placeholder for user-defined logic. Modify directly in src/rl_module/components.py:

class CustomReward(RewardComponent):
    def compute(self, gen_structures: list[Structure], **kwargs) -> torch.Tensor:
        # Your custom logic here
        rewards = [your_function(s) for s in gen_structures]
        return torch.as_tensor(rewards)
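For example, the Atomic Density tutorial takes this route; a hedged sketch in the same spirit (the specific objective here, atoms per unit-cell volume, is illustrative rather than the tutorial's exact code):

```python
import torch
from pymatgen.core import Structure


class CustomReward(RewardComponent):  # modified in place inside src/rl_module/components.py
    def compute(self, gen_structures: list[Structure], **kwargs) -> torch.Tensor:
        # Illustrative objective: atoms per unit-cell volume, so denser packing scores higher.
        rewards = [len(s) / s.volume for s in gen_structures]
        return torch.as_tensor(rewards, dtype=torch.float32)
```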

CreativityReward

Rewards structures that are both unique (not duplicated within the generated batch) and novel (absent from the training set).
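A minimal sketch of the idea, assuming the pre-computed metrics expose per-structure unique and novel flags (the attribute names and the required_metrics format are assumptions, not the actual Metrics API); because the result is already in [0, 1], it pairs with normalize_fn: null as noted in the normalization table:

```python
import torch

from src.rl_module.components import RewardComponent  # assumed import path for the base class


class IndicatorCreativityReward(RewardComponent):  # illustrative, not the built-in class
    required_metrics = ("unique", "novel")  # assumed format for declaring required metrics

    def compute(self, metrics_obj=None, device=None, **kwargs) -> torch.Tensor:
        # Assumed accessors: per-structure boolean flags on the pre-computed metrics object.
        unique = torch.as_tensor(metrics_obj.unique, dtype=torch.float32, device=device)
        novel = torch.as_tensor(metrics_obj.novel, dtype=torch.float32, device=device)
        return unique * novel  # 1.0 only when both unique and novel; already in [0, 1]
```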

EnergyReward

Penalizes structures with high energy above the convex hull (e_above_hull).
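Conceptually this amounts to rewarding the negative of e_above_hull, so structures on or near the hull score highest; a hedged sketch (accessor names are assumptions, and the built-in component may additionally clip or rescale):

```python
import torch

from src.rl_module.components import RewardComponent  # assumed import path for the base class


class NegativeHullEnergyReward(RewardComponent):  # illustrative, not the built-in class
    required_metrics = ("e_above_hull",)  # assumed format for declaring required metrics

    def compute(self, metrics_obj=None, device=None, **kwargs) -> torch.Tensor:
        # Assumed accessor: per-structure energy above the convex hull (typically eV/atom).
        e_hull = torch.as_tensor(metrics_obj.e_above_hull, dtype=torch.float32, device=device)
        return -e_hull  # more stable (lower e_above_hull) -> higher reward
```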

StructureDiversityReward

Encourages diverse crystal geometries using the Maximum Mean Discrepancy (MMD).
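For reference, the standard biased estimator of squared MMD with an RBF kernel between two sets of feature vectors is sketched below; how Chemeleon2 featurizes geometries and which sets it compares are not shown here, so treat this purely as an illustration of the MMD statistic:

```python
import torch


def rbf_mmd2(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Biased estimator of squared MMD between feature sets x (n, d) and y (m, d)."""
    def rbf(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma**2))

    return rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean()
```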

CompositionDiversityReward

Encourages diverse chemical compositions, also using MMD.
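Both diversity components are added to the reward config in the same way as the components shown earlier; a sketch with placeholder weights and normalization choices:

```yaml
components:
  - _target_: src.rl_module.components.StructureDiversityReward
    weight: 0.5          # placeholder weight
    normalize_fn: std
  - _target_: src.rl_module.components.CompositionDiversityReward
    weight: 0.5          # placeholder weight
    normalize_fn: std
```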

PredictorReward

Uses a trained predictor as a surrogate model:

- _target_: src.rl_module.components.PredictorReward
  weight: 1.0
  predictor:
    _target_: src.vae_module.predictor_module.PredictorModule.load_from_checkpoint
    checkpoint_path: "ckpts/predictor.ckpt"
    map_location: "cpu"
  target_name: dft_band_gap  # Must match predictor's target_conditions key
  target_value: 3.0    # Optional: optimize toward this value
  clip_min: 0.0        # Optional: bound predictions
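One plausible reading of these options, inferred from the field names rather than the source: when target_value is set, predictions closer to the target earn higher reward; otherwise the (optionally clipped) prediction itself is the reward. A sketch under that assumption:

```python
import torch


def predictor_reward(preds: torch.Tensor, target_value: float | None = None,
                     clip_min: float | None = None) -> torch.Tensor:
    # Assumed behavior, inferred from the option names above.
    if clip_min is not None:
        preds = preds.clamp(min=clip_min)       # clip_min: bound predictions
    if target_value is not None:
        return -(preds - target_value).abs()    # target_value: closer to the target -> higher reward
    return preds                                # otherwise reward the raw prediction
```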

RL Configuration

Configure RL training behavior via rl_module.rl_configs (see configs/rl_module/rl_module.yaml):

rl_module:
  rl_configs:
    clip_ratio: 0.001
    kl_weight: 1.0
    entropy_weight: 1e-5
    num_group_samples: 64
    group_reward_norm: true
    num_inner_batch: 2

Parameter Details

| Parameter | Default | Effect |
| --- | --- | --- |
| clip_ratio | 0.001 | PPO clipping ratio. ↑ = larger policy updates (faster but unstable), ↓ = conservative updates (stable but slow) |
| kl_weight | 1.0 | KL divergence penalty. ↑ = stays closer to original policy, ↓ = allows more deviation |
| entropy_weight | 1e-5 | Entropy bonus. ↑ = more exploration/diversity, ↓ = more exploitation |
| num_group_samples | 1 | Samples per group for GRPO. ↑ = stable gradients (slow), ↓ = noisy gradients (fast) |
| group_reward_norm | false | Group reward normalization. true = GRPO (relative ranking), false = REINFORCE (absolute reward) |
| num_inner_batch | 2 | Gradient accumulation steps. ↑ = larger effective batch size |
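Any of these parameters can be overridden from the command line with the same Hydra syntax used in the Quick Start; the values below are arbitrary examples, not recommendations:

```bash
python src/train_rl.py custom_reward=rl_dng \
    rl_module.rl_configs.clip_ratio=0.005 \
    rl_module.rl_configs.kl_weight=0.5 \
    rl_module.rl_configs.num_group_samples=128 \
    rl_module.rl_configs.group_reward_norm=true
```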

Choosing a Starting Checkpoint

When starting RL training, you can initialize from either a base LDM checkpoint or one already fine-tuned with the DNG reward, for either the MP-20 or Alex-MP-20 dataset:

| Checkpoint | Description |
| --- | --- |
| ${hub:mp_20_ldm_base} | LDM trained on MP-20 dataset without RL fine-tuning |
| ${hub:mp_20_ldm_rl} | LDM fine-tuned with DNG reward on MP-20 dataset |
| ${hub:alex_mp_20_ldm_base} | LDM trained on Alex-MP-20 dataset without RL fine-tuning |
| ${hub:alex_mp_20_ldm_rl} | LDM fine-tuned with DNG reward on Alex-MP-20 dataset |

Example configuration:

# Use mp-20 model
rl_module:
  ldm_ckpt_path: ${hub:mp_20_ldm_base}  # or ${hub:mp_20_ldm_rl}
  vae_ckpt_path: ${hub:mp_20_vae}

# Use alex-mp-20 model
rl_module:
  ldm_ckpt_path: ${hub:alex_mp_20_ldm_base} # or ${hub:alex_mp_20_ldm_rl}
  vae_ckpt_path: ${hub:alex_mp_20_vae}

Tutorials