Advantage Functions¶
retrain uses a composable advantage pipeline: an episode-level advantage function produces per-completion scores, then optional token-level transforms redistribute credit across individual tokens.
The 5 conditions tested in campaigns correspond to a progressive ablation of this pipeline: GRPO baseline, MaxRL, MaxRL+GTPO, MaxRL+GTPO+HICRA, and MaxRL+GTPO+SEPA.
Pipeline¶
Rewards (per completion)
│
▼
Episode-level advantage (GRPO or MaxRL)
│
▼
Token-level expansion (GTPO entropy weighting)
│
▼
Optional transform (HICRA or SEPA)
│
▼
Token-level advantages (fed to training loss)
Episode-level advantages¶
GRPO¶
Group Relative Policy Optimization. Centers rewards around the group mean:
A_i = r_i - mean(r)
Simple and effective. Completions with above-mean reward get a positive advantage, and below-mean completions get a negative one. No further normalization is applied.
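As a minimal sketch of the centering step described above (the function name `grpo_advantages` is illustrative and not necessarily the name retrain uses internally):

```python
def grpo_advantages(rewards):
    """Center each reward on the group mean; no variance normalization."""
    if not rewards:
        return []
    mean_r = sum(rewards) / len(rewards)
    return [r - mean_r for r in rewards]


# One correct completion out of four: the correct one is pushed up,
# the others are pushed down by the same small amount.
print(grpo_advantages([1.0, 0.0, 0.0, 0.0]))  # [0.75, -0.25, -0.25, -0.25]
```

The advantages in a group always sum to zero, so a group with identical rewards contributes nothing.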
MaxRL¶
Inverse success-rate reweighting. Normalizes by the group mean reward:
When the model is mostly wrong (low mean reward), the denominator is small, amplifying the signal from rare correct completions. When the model is mostly right, advantages shrink -- the model learns less from easy problems.
Returns zero if the group mean is near zero (all wrong).
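The behaviour described above can be sketched as follows. This assumes rewards in [0, 1] and the form (r - mean) / mean. The function name and the `eps` threshold are assumptions made for illustration, not retrain's actual values.

```python
def maxrl_advantages(rewards, eps=1e-6):
    """Mean-centered rewards divided by the group mean (assumed form)."""
    if not rewards:
        return []
    mean_r = sum(rewards) / len(rewards)
    if mean_r < eps:
        # All completions wrong: there is no signal to amplify.
        return [0.0] * len(rewards)
    return [(r - mean_r) / mean_r for r in rewards]


# Hard prompt (1 of 4 correct): the rare success gets a large advantage.
print(maxrl_advantages([1.0, 0.0, 0.0, 0.0]))  # [3.0, -1.0, -1.0, -1.0]
# Easy prompt (3 of 4 correct): the successes get a much smaller advantage.
print(maxrl_advantages([1.0, 1.0, 1.0, 0.0]))
```

Under this form, a correct completion at success rate p gets advantage (1 - p) / p, which grows as p shrinks.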
Custom episode-level advantage¶
Set advantage_mode to a dotted path. The target can be a plain function:
def hipa_like_advantages(rewards):
    if not rewards:
        return []
    mean_r = sum(rewards) / len(rewards)
    return [2.0 * (r - mean_r) for r in rewards]
If your function needs extra knobs, accept a second params argument and pass
advantage_params when calling compute_composable_advantages(...) in Python.
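A two-argument variant might look like the sketch below. It assumes `params` behaves like a dict, as `ctx.params` does in the transform example. The function name and the `scale` key are hypothetical.

```python
def scaled_advantages(rewards, params):
    """Mean-centered advantages with a configurable scale (hypothetical knob)."""
    if not rewards:
        return []
    scale = float(params.get("scale", 1.0))
    mean_r = sum(rewards) / len(rewards)
    return [scale * (r - mean_r) for r in rewards]


# Calling the plugin directly, outside the pipeline, for a quick check.
print(scaled_advantages([1.0, 0.0], {"scale": 2.0}))  # [1.0, -1.0]
```

Reading every knob with a default keeps the function usable when no parameters are supplied.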
Token-level transforms¶
Custom transform (context-style)¶
Set transform_mode to a dotted path pointing to a function that accepts a
TransformContext and returns TransformOutput.
from retrain import TransformOutput

def my_transform(ctx):
    scale = float(ctx.params.get("scale", 1.0))
    token_advs = []
    for i, logprobs in enumerate(ctx.logprobs_G):
        token_advs.append([ctx.episode_advantages[i] * scale for _ in logprobs])
    return TransformOutput(token_advs=token_advs)
[algorithm]
transform_mode = "plugins.my_transform.my_transform"
[algorithm.transform_params]
scale = 2.0
GTPO¶
Group-relative Token-level Policy Optimization. Weights token advantages by normalized per-token entropy:
High-entropy tokens (where the model is uncertain) get amplified. Low-entropy tokens (confident predictions) get dampened. This focuses learning on the tokens where the model's decision actually matters.
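To make the idea concrete, here is one plausible form of entropy weighting. This is an illustrative sketch, not retrain's exact formula: it assumes each token's weight is 1 + beta * (H(t) / mean(H) - 1), so the weights average to 1 and beta = 0 recovers a uniform spread of the episode advantage.

```python
def gtpo_weight_tokens(episode_advantage, entropies, beta=0.1):
    """Spread one episode-level advantage over tokens, scaled by relative entropy."""
    if not entropies:
        return []
    mean_h = sum(entropies) / len(entropies)
    if mean_h <= 0.0:
        # No uncertainty signal at all: fall back to uniform credit.
        return [episode_advantage] * len(entropies)
    return [
        episode_advantage * (1.0 + beta * (h / mean_h - 1.0))
        for h in entropies
    ]


# The high-entropy token gets more credit, the zero-entropy token gets less.
print(gtpo_weight_tokens(1.0, [2.0, 1.0, 0.0], beta=0.5))  # [1.5, 1.0, 0.5]
```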
The uncertainty_kind config controls which entropy proxy is used:
Built-in uncertainty signals:
- surprisal (default) — sampled-token -logprob. Requires only logprobs. Noisy for tail samples where the model was confident but the sampler drew unluckily.
- predictive_variance — Bernoulli variance p * (1 - p) where p = exp(logprob). Free from existing logprobs, peaks at genuine uncertainty (p ≈ 0.5), decays for both confident and tail-sample tokens. Aliases: pred_var, bernoulli_variance.
- shannon_entropy — true per-position entropy H(t) = −Σ pᵢ log pᵢ computed from the full ~150k-dimensional vocabulary distribution on GPU. Unlike surprisal (a function of the single sampled token's logprob), this captures the model's true distributional uncertainty at each position. Requires inference_engine = "pytorch" and backend = "local".
Custom uncertainty signals can be provided via dotted plugin paths (e.g. my_module.my_uncertainty).
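The two logprob-only signals can be computed directly from their definitions above. The sketch below uses standalone helper functions for illustration. It does not show the plugin signature retrain expects for custom signals.

```python
import math


def surprisal(logprob):
    """Negative log-probability of the sampled token."""
    return -logprob


def predictive_variance(logprob):
    """Bernoulli variance p * (1 - p), with p = exp(logprob)."""
    p = math.exp(logprob)
    return p * (1.0 - p)


# Confident token, genuinely uncertain token, unlucky tail sample.
for p in (0.99, 0.5, 0.01):
    lp = math.log(p)
    print(f"p={p:.2f}  surprisal={surprisal(lp):.3f}  "
          f"pred_var={predictive_variance(lp):.4f}")
```

Surprisal is largest for the tail sample (p = 0.01), while predictive variance is largest at p = 0.5 and small at both extremes.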
Real entropy vs surprisal¶
Surprisal (-log p) is a single scalar — the negative log-probability of whichever token was sampled. It's a noisy proxy for uncertainty: a low-probability sample from a confident distribution gives high surprisal even though the model was sure.
Shannon entropy H(t) = -Σ pᵢ log pᵢ uses the full softmax distribution over the vocabulary. It captures genuine distributional uncertainty regardless of which token was sampled. The predictive variance experiment confirmed that no function of a single logprob can approximate this — the information lives in the ~150k-dimensional distribution, not the scalar.
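A small numeric illustration of the difference, using a toy four-token vocabulary in place of the real one (the distributions are made up for the example):

```python
import math


def shannon_entropy(probs):
    """H = -sum(p * log p) over a full distribution (natural log)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


confident = [0.97, 0.01, 0.01, 0.01]  # model is nearly sure of token 0
uniform = [0.25, 0.25, 0.25, 0.25]    # model has no preference

# Surprisal when the sampler happens to draw one of the 1% tokens.
tail_surprisal = -math.log(confident[1])

print(f"entropy (confident): {shannon_entropy(confident):.3f}")
print(f"entropy (uniform):   {shannon_entropy(uniform):.3f}")
print(f"surprisal of tail sample from confident: {tail_surprisal:.3f}")
```

The tail sample's surprisal exceeds even the uniform distribution's entropy, although the distribution it came from has low entropy.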
The PyTorch engine computes H(t) on GPU alongside logprobs (-(probs * log_probs).sum(dim=-1)) and passes one float per token to the advantage pipeline. No full distribution is transferred to CPU.
[algorithm]
uncertainty_kind = "shannon_entropy"
[inference]
engine = "pytorch"
[backend]
backend = "local"
The strength of the GTPO entropy weighting is controlled by beta.
HICRA¶
Hierarchical Credit Assignment. Amplifies advantages for planning tokens:
Here mask(t) = 1 for tokens identified as planning (thinking, self-correction, strategy) and 0 for execution tokens. The amplification is proportional to the magnitude of the existing advantage, so it preserves the sign.
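One form consistent with this description is A'(t) = A(t) + alpha * |A(t)| * mask(t). The sketch below assumes that form. The name `alpha` and its default are hypothetical, and the sign is only guaranteed to be preserved while alpha < 1.

```python
def hicra_transform(token_advs, planning_mask, alpha=0.2):
    """Add alpha * |A(t)| to the advantage of planning tokens (assumed form)."""
    return [
        a + alpha * abs(a) * m
        for a, m in zip(token_advs, planning_mask)
    ]


advs = [1.0, -1.0, 0.5]
mask = [1, 1, 0]  # first two tokens are planning, the last is execution
print(hicra_transform(advs, mask, alpha=0.2))
```

Execution tokens (mask 0) pass through unchanged. Planning tokens are shifted by a fraction of their own magnitude.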
SEPA¶
Selective Entropy Pooling of Attention. Pulls execution-token entropies toward their mean before GTPO weighting:
H_pooled(t) = lambda * mean(H_exec) + (1 - lambda) * H(t) if execution token
H_pooled(t) = H(t) if planning token
This reduces entropy variance among execution tokens, letting GTPO focus its differentiation on planning tokens. Lambda ramps from 0 to 1 over training. See SEPA for scheduling details.
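The pooling rule above translates directly into code. In this sketch the function name and argument layout are illustrative, and `lam` stands for lambda:

```python
def sepa_pool(entropies, planning_mask, lam):
    """Pull execution-token entropies toward their mean; leave planning tokens alone."""
    exec_h = [h for h, m in zip(entropies, planning_mask) if not m]
    if not exec_h:
        return list(entropies)
    mean_exec = sum(exec_h) / len(exec_h)
    return [
        h if m else lam * mean_exec + (1.0 - lam) * h
        for h, m in zip(entropies, planning_mask)
    ]


entropies = [2.0, 0.0, 4.0, 3.0]
mask = [0, 0, 0, 1]  # only the last token is a planning token
print(sepa_pool(entropies, mask, 0.0))  # [2.0, 0.0, 4.0, 3.0]
print(sepa_pool(entropies, mask, 0.5))  # [2.0, 1.0, 3.0, 3.0]
print(sepa_pool(entropies, mask, 1.0))  # [2.0, 2.0, 2.0, 3.0]
```

At lambda = 1 every execution token carries the same entropy, so any remaining entropy-based differentiation falls on planning tokens.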
Valid combinations¶
| advantage_mode | transform_mode | What it does |
|---|---|---|
| grpo | none | Baseline GRPO |
| maxrl | none | MaxRL without token-level transforms |
| maxrl | gtpo | MaxRL + entropy-weighted credit assignment |
| maxrl | gtpo_hicra | MaxRL + GTPO + planning token amplification |
| maxrl | gtpo_sepa | MaxRL + GTPO + selective entropy pooling (recommended) |
These are the 5 conditions used in campaign sweeps. See Campaigns.
Note
GRPO can also be combined with gtpo, gtpo_hicra, or gtpo_sepa, but the standard conditions use MaxRL for the non-baseline transforms.
Full algorithm override¶
Use algorithm_mode when you want to replace the full pipeline in one plugin:
When algorithm_mode is set, it takes precedence over advantage_mode and
transform_mode.
Planning tokens¶
HICRA and SEPA both rely on identifying which tokens are "planning" (thinking, self-correction) vs "execution" (direct computation). retrain detects planning tokens via strategic gram matching -- a sliding window over token text that checks for patterns like:
- "wait let me", "let me think", "on second thought"
- "let me check", "let me verify", "is this right"
- "double check", "try another approach", "go back and"
- "that's not right", "that doesn't work"
- "the key is", "the key insight", "notice that"
The full default list has 18 grams. You can override them via the strategic_grams config field, either as a list of strings or as a single comma-separated string.
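A rough sketch of sliding-window gram matching is shown below. It is not retrain's implementation: the function name, the `max_window` limit, the three-gram subset, and the assumption that token strings carry their own leading spaces are all illustrative.

```python
# A small subset of the default grams, for illustration only.
STRATEGIC_GRAMS = ["let me think", "double check", "notice that"]


def planning_mask(tokens, grams=STRATEGIC_GRAMS, max_window=5):
    """Mark tokens that fall inside a window whose text equals a strategic gram."""
    targets = {" ".join(g.lower().split()) for g in grams}
    mask = [0] * len(tokens)
    for start in range(len(tokens)):
        stop = min(start + max_window, len(tokens))
        for end in range(start + 1, stop + 1):
            # Join token text, then normalize case and whitespace.
            text = " ".join("".join(tokens[start:end]).lower().split())
            if text in targets:
                for i in range(start, end):
                    mask[i] = 1
    return mask


tokens = ["Hmm", ",", " let", " me", " think", " about", " it"]
print(planning_mask(tokens))  # [0, 0, 1, 1, 1, 0, 0]
```

The resulting 0/1 list plays the role of mask(t) for HICRA and of the planning/execution split for SEPA.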
Uninformative groups¶
Groups where all completions have the same reward (all correct or all wrong) are skipped -- they produce zero advantage and waste a training step. The trainer logs these as "skipped (all correct)" or "skipped (all wrong)".
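The skip condition can be sketched as a simple check on the reward spread within a group. The helper name and the tolerance are assumptions made for illustration.

```python
def is_uninformative(rewards, tol=1e-9):
    """True when every completion in the group has (numerically) the same reward."""
    return not rewards or max(rewards) - min(rewards) <= tol


groups = {
    "all correct": [1.0, 1.0, 1.0, 1.0],
    "all wrong": [0.0, 0.0, 0.0, 0.0],
    "mixed": [1.0, 0.0, 0.0, 1.0],
}
for name, rewards in groups.items():
    status = "skip" if is_uninformative(rewards) else "train"
    print(f"{name}: {status}")
```

Only the mixed group has rewards that differ from the group mean, so only it yields nonzero advantages.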