Inference Engines
Local backend only
This page applies to the local backend only. When using the Tinker backend, the Tinker service handles all inference internally -- the [inference] config section is ignored. See Backends for Tinker setup.
retrain separates inference (sampling completions) from training (gradient updates). The inference engine controls how completions are generated, while PyTorch/PEFT always handles LoRA training.
Architecture
retrain
└── LocalTrainHelper
    ├── InferenceEngine (ABC)
    │   ├── PyTorchEngine     ← same model, shared VRAM
    │   ├── MAXLocalEngine    ← in-process MAX pipeline
    │   ├── MAXServeEngine    ← HTTP to max serve
    │   └── OpenAIEngine      ← HTTP to vLLM / SGLang / MLX-LM / any server
    └── PyTorch/PEFT training (unchanged)
All engines implement the same interface: generate() returns token IDs + per-token logprobs. The training side never knows which engine produced the samples.
Engine options
| Engine | TOML value | What it does |
|---|---|---|
| PyTorch | pytorch | Shares the training model for inference. 1x VRAM. Default. |
| MAX (auto) | max | In-process if no URL is set; HTTP to max serve if url is set |
| vLLM | vllm | HTTP client to a vLLM server |
| SGLang | sglang | HTTP client to an SGLang server |
| MLX-LM | mlx | HTTP client to a local mlx_lm.server endpoint |
| OpenAI | openai | HTTP client to any OpenAI-compatible endpoint |
For mlx, retrain sends the active LoRA adapter path in each completion request using the MLX-LM adapters field.
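As a rough illustration of the mlx case, the sketch below builds a completion request body that carries the adapter path. The helper name `build_mlx_completion_request`, the adapter path, and the choice of sampling fields are assumptions made for this example; only the `adapters` field comes from the description above, and the exact payload retrain sends may differ.

```python
import json


def build_mlx_completion_request(prompt, adapter_path, max_tokens=64,
                                 temperature=1.0, top_p=1.0):
    """Build a JSON body for an OpenAI-style completions call.

    Hypothetical helper: only the `adapters` field is taken from the text
    above; the other keys are ordinary OpenAI-style sampling parameters.
    """
    return {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "adapters": adapter_path,  # path to the current LoRA adapter on disk
    }


# Example path is made up for illustration.
body = build_mlx_completion_request("2 + 2 =", "/tmp/retrain/adapter")
print(json.dumps(body, indent=2))
```

Sending the path with every request means the server always sees the adapter that matches the latest training step, at the cost of the server reloading it when the path contents change.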
Why PyTorch is best on 1 GPU
With LoRA training, only the adapter weights change -- the base model is frozen. The PyTorch engine exploits this: the same model object serves both training and inference. There is no weight duplication.
Every other engine loads a separate copy of the base model -- either in a different framework (MAX) or a different process (vLLM, SGLang, MLX-LM). On 1 GPU, that means 2x base model VRAM for no benefit.
PyTorch (1 GPU):  [base model + LoRA]                        ← shared, 1x VRAM
MAX (1 GPU):      [base model + LoRA] + [base model (MAX)]   ← 2x base VRAM
vLLM (1 GPU):     [base model + LoRA] + [base model (vLLM)]  ← 2x base VRAM
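To put rough numbers on the 1x versus 2x comparison, here is a back-of-the-envelope estimate. It assumes a model of about 4 billion parameters stored in a 2-byte dtype such as bf16, and it counts weights only: LoRA adapters, KV cache, activations, and optimizer state are ignored. The function name is illustrative.

```python
def base_model_gib(num_params, bytes_per_param=2):
    """Approximate weight memory in GiB (weights only)."""
    return num_params * bytes_per_param / 2**30


params = 4e9  # assumed parameter count for a ~4B model

shared = base_model_gib(params)          # PyTorch engine: one copy
duplicated = 2 * base_model_gib(params)  # separate engine: two copies

print(f"shared: {shared:.1f} GiB, duplicated: {duplicated:.1f} GiB")
```

Under these assumptions the shared setup needs about 7.5 GiB for weights and the duplicated setup about 14.9 GiB, before any inference or training overhead.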
Multi-GPU: when to use MAX / vLLM
With multiple GPUs, inference and training run on separate devices. Base model duplication is expected and desirable -- each device has its own copy.
8x H100 example:
GPUs 0-6: max serve (tensor parallel inference, continuous batching)
GPU 7: PyTorch/PEFT training
Here MAX or vLLM provide real benefits: tensor parallelism across inference GPUs, continuous batching for high throughput, and optimized kernels.
Quick start
# 1 GPU -- PyTorch (default, no extra setup)
retrain --devices gpu:0
# 1 GPU -- explicit PyTorch
retrain --devices gpu:0 --inference-engine pytorch
# 8 GPUs -- MAX serve on GPUs 0-6, training on GPU 7
max serve --model Qwen/Qwen3-4B-Instruct-2507 # manages its own GPUs
retrain --devices gpu:7 \
--inference-engine max --inference-url http://localhost:8000
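# 8 GPUs -- vLLM server on GPUs 0-3, training on GPU 7
# (vLLM requires the model's attention-head count to be divisible by
#  --tensor-parallel-size, so 7 does not work for this model; 4 does)
CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve Qwen/Qwen3-4B-Instruct-2507 --tensor-parallel-size 4
retrain --devices gpu:7 \
    --inference-engine vllm --inference-url http://localhost:8000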
# Apple Silicon -- MLX-LM local server
pip install -e ".[mlx]"
python -m mlx_lm.server --model mlx-community/Qwen2.5-3B-Instruct-4bit
retrain --devices cpu \
--inference-engine mlx --inference-url http://localhost:8080
LoRA weight sync
After each training step, updated LoRA weights must reach the inference engine:
| Engine | Sync mechanism | Latency |
|---|---|---|
| PyTorch (1 GPU) | Same model object, no sync needed | 0 |
| PyTorch (split mode) | In-memory sync_from_state_dict() via snapshot | ~1 ms |
| MAX / vLLM / SGLang / MLX-LM | save_pretrained() to disk, then reload_weights() | ~1-2 s |
The _weights_dirty flag avoids redundant saves. In split mode, a weight snapshot is taken after each optimizer step for safe cross-thread access.
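The dirty-flag idea can be sketched as follows. This is a minimal illustration, not retrain's actual code: the `AdapterSyncer` class, its method names, and the callbacks are hypothetical stand-ins for `save_pretrained()` and `reload_weights()`.

```python
class AdapterSyncer:
    """Hypothetical helper: save the adapter only if it changed since the last save."""

    def __init__(self, save_fn, reload_fn):
        self._save_fn = save_fn      # stand-in for a save_pretrained() wrapper
        self._reload_fn = reload_fn  # stand-in for engine.reload_weights
        self._weights_dirty = False
        self.saves = 0

    def mark_dirty(self):
        """Call after each optimizer step."""
        self._weights_dirty = True

    def sync(self, adapter_path):
        """Save and reload only when weights changed; return True if a sync ran."""
        if not self._weights_dirty:
            return False
        self._save_fn(adapter_path)
        self._reload_fn(adapter_path)
        self._weights_dirty = False
        self.saves += 1
        return True


log = []
syncer = AdapterSyncer(lambda p: log.append(("save", p)),
                       lambda p: log.append(("reload", p)))
syncer.mark_dirty()
first = syncer.sync("/tmp/adapter")   # weights changed: save + reload
second = syncer.sync("/tmp/adapter")  # nothing changed: skipped
print(first, second, syncer.saves)
```

With disk-based sync costing on the order of a second, skipping the save when no optimizer step has run since the last one avoids paying that cost twice for the same weights.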
Device allocation
| Config | Training | Inference |
|---|---|---|
| engine = "pytorch", devices = "gpu:0" | GPU 0 | GPU 0 (same model) |
| engine = "pytorch", devices = "gpu:0,gpu:1" | GPU 1 | GPU 0 (split mode) |
| engine = "max", devices = "gpu:7" | GPU 7 | MAX-managed |
| engine = "vllm", devices = "gpu:7" | GPU 7 | Server-managed |
| engine = "mlx", devices = "cpu" | CPU | MLX-LM server-managed |
With external engines, devices controls only the training device. The engine manages its own device allocation independently.
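The allocation rules in the table can be expressed as a small function. This is a sketch under stated assumptions: `allocate_devices` is a hypothetical name, the "engine-managed" label is invented for the example, and cases the table does not cover (more than two PyTorch devices) are rejected instead of guessed.

```python
def allocate_devices(engine, devices):
    """Return (training_device, inference_device) following the table above.

    Hypothetical helper, not retrain's actual implementation.
    """
    device_list = [d.strip() for d in devices.split(",") if d.strip()]
    if not device_list:
        raise ValueError("no devices given")
    if engine != "pytorch":
        # External engines manage their own devices.
        # Assumption: the first listed device is the training device.
        return device_list[0], "engine-managed"
    if len(device_list) == 1:
        # Same model object serves training and inference.
        return device_list[0], device_list[0]
    if len(device_list) == 2:
        # Split mode: first device for inference, second for training.
        return device_list[1], device_list[0]
    raise ValueError("allocation for more than two devices is not covered by the table")


print(allocate_devices("pytorch", "gpu:0"))
print(allocate_devices("pytorch", "gpu:0,gpu:1"))
print(allocate_devices("vllm", "gpu:7"))
```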
TOML configuration
[inference]
engine = "pytorch" # pytorch | max | vllm | sglang | mlx | openai
url = "" # server URL for non-PyTorch engines
attention_kernel = "default"
dtype = "auto"
kv_cache_dtype = "auto"
prefix_caching = true
InferenceEngine ABC
All engines implement three methods:
class InferenceEngine(ABC):
    def generate(self, prompt_ids_list, num_samples, max_tokens,
                 temperature, top_p) -> list[list[SampleResult]]:
        """Return [num_prompts][num_samples] of (token_ids, logprobs)."""

    def reload_weights(self, adapter_path: str) -> None:
        """Reload LoRA adapter from disk."""

    def shutdown(self) -> None:
        """Release resources."""
SampleResult is a dataclass with token_ids: list[int] and logprobs: list[float].
PyTorchEngine adds sync_from_state_dict(lora_dict) for fast in-memory weight sync in split mode.
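To make the contract concrete, here is a self-contained toy implementation. `ConstantEngine` is a hypothetical engine invented for this example; it emits a fixed token so the nested [num_prompts][num_samples] result shape is easy to see. The interface is restated so the block runs on its own, and marking the methods with `@abstractmethod` is an assumption about how the real base class is written.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class SampleResult:
    token_ids: list[int]
    logprobs: list[float]


class InferenceEngine(ABC):
    @abstractmethod
    def generate(self, prompt_ids_list, num_samples, max_tokens,
                 temperature, top_p) -> list[list[SampleResult]]: ...

    @abstractmethod
    def reload_weights(self, adapter_path: str) -> None: ...

    @abstractmethod
    def shutdown(self) -> None: ...


class ConstantEngine(InferenceEngine):
    """Toy engine (hypothetical): always emits token 0 with logprob 0.0."""

    def __init__(self):
        self.adapter_path = None

    def generate(self, prompt_ids_list, num_samples, max_tokens,
                 temperature, top_p):
        # One inner list per prompt, one SampleResult per requested sample.
        return [
            [SampleResult([0] * max_tokens, [0.0] * max_tokens)
             for _ in range(num_samples)]
            for _ in prompt_ids_list
        ]

    def reload_weights(self, adapter_path):
        self.adapter_path = adapter_path  # a real engine would load from disk

    def shutdown(self):
        pass  # nothing to release in this toy engine


engine = ConstantEngine()
results = engine.generate([[1, 2], [3]], num_samples=2, max_tokens=3,
                          temperature=1.0, top_p=1.0)
print(len(results), len(results[0]), results[0][0].token_ids)
```

Because the training side only depends on this shape, a stub like this can stand in for a real engine when testing the training loop without a GPU.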
Files
| File | Role |
|---|---|
| retrain/inference_engine/__init__.py | Exports + create_engine() factory |
| retrain/inference_engine/base.py | InferenceEngine ABC + SampleResult dataclass |
| retrain/inference_engine/pytorch_engine.py | Local PyTorch engine |
| retrain/inference_engine/max_engine.py | MAX engine (in-process vs serve) |
| retrain/inference_engine/openai_engine.py | HTTP client for vLLM / SGLang / MLX-LM / any server |
| retrain/local_train_helper.py | Orchestrates engine + training, weight sync |