Dynamic weight loading
Checkpoints are often serialized in a format that does not match what a model expects at runtime. Common scenarios include:
- Fused weights: Checkpoints store separate gate_proj and up_proj weights, but the model uses a fused gate_up_proj for efficiency.
- MoE expert consolidation: Individual expert weights (experts.0.weight, experts.1.weight, …) need to be stacked into a single 3D tensor.
- Legacy naming: Old checkpoints use different naming conventions (e.g., LayerNorm.gamma vs LayerNorm.weight).
- Composite models: A vision-language model contains two PreTrainedModel sub-modules, each with its own checkpoint convention.
- Quantization: Weights may be stored in quantized formats that need deserialization.
Dynamic weight loading addresses this by applying scheduled, reversible operations to checkpoint tensors as they are loaded. Transformers exposes this through WeightConverter and WeightRenaming, which describe how one or more checkpoint keys map to one or more model parameters and which composable ConversionOps should run on the matched tensors. This approach adapts to new weight layouts, supports quantized mixture-of-experts (MoEs), and integrates with tensor parallelism.
This guide demonstrates how to use WeightConverter to convert tensors. Conversion mappings live in conversion_mapping.py; a registered mapping is keyed by either a model_type string (e.g. "mixtral") or a class name (e.g. "LlavaModel").
Full loading pipeline
All models go through the dynamic weight loading system. Conversion mapping is an optional step within that system that only activates when the model has entries registered for its class name or model_type.
Checkpoint File → from_pretrained() → convert_and_load_state_dict_in_model()
↓
┌───────────────────────────────────────────────────────────┐
│ For each weight in checkpoint: │
│ 1. Match renamed/processed source key to model parameter │
│ 2. Shard the weight and send to device (async) │
│ 3. Collect tensors with the same source_pattern together │
│ (e.g. MoE experts, gate_up_proj) │
│ 4. Apply dequantization/deserialization (if pre-quant) │
│ 5. Apply conversion (if defined) │
│ 6. Apply quantization (if enabled and step 4 not used) │
│ 7. Set parameter on model │
└───────────────────────────────────────────────────────────┘
| Step | When it activates |
|---|---|
| Dynamic loading | Always, for all models |
| Conversion mapping | Only when the model’s class or model_type is registered in _MODEL_TO_CONVERSION_PATTERN |
| TP sharding | Only when tp_plan="auto" and model has base_model_tp_plan |
| Dequantization/deserialization | Only when loading a pre-quantized checkpoint |
| Quantization | Only when a quantization config is provided and weights are not pre-quantized |
Dense models (e.g., Llama)
For most dense models, the checkpoint format matches the model format directly, so no conversion mapping is needed. Some models may still require renaming (e.g., legacy naming conventions). TP sharding still applies when enabled.
Checkpoint: Model:
model.layers.0.self_attn.q_proj.weight → model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight → model.layers.0.self_attn.k_proj.weight
model.layers.0.mlp.gate_proj.weight → model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight → model.layers.0.mlp.up_proj.weight
model.layers.0.mlp.down_proj.weight → model.layers.0.mlp.down_proj.weight
Legacy checkpoints may use older naming conventions that are handled by built-in renamings applied to all models:
Checkpoint: Model:
LayerNorm.gamma → LayerNorm.weight
LayerNorm.beta → LayerNorm.bias
MoE models (e.g., Mixtral)
For MoE models, the checkpoint format differs from the model format. Conversion mapping transforms separate expert weights into fused 3D tensors, and TP sharding applies after conversion.
Checkpoint: Model:
experts.0.w1.weight ─┐
experts.1.w1.weight │ MergeModulelist
... ├───────────────→ experts.gate_up_proj (8, 2*intermediate, hidden)
experts.0.w3.weight │ + Concatenate
experts.1.w3.weight ─┘
Composite models (e.g., vision-language)
A PreTrainedModel may contain other PreTrainedModel sub-modules. Each sub-model can have its own conversion mapping, registered against either its class name or its model_type. When the parent model is loaded, get_model_conversion_mapping walks the sub-models in depth-first order, collects their mappings, and automatically scopes each transform to the path of its sub-module via scope_prefix.
Composite model: Per-submodel mappings (auto-scoped):
LlavaForConditionalGeneration
├── vision_model: SiglipVisionModel → SiglipVisionModel mapping (scope="vision_model")
└── language_model: LlamaForCausalLM → LlamaForCausalLM mapping (scope="language_model")
scope_prefix is the dotted path of the sub-module ("vision_model", "language_model.model", etc.). A scoped transform only fires on keys that start with f"{scope_prefix}."; the prefix is stripped before pattern matching and re-attached after substitution, so each sub-model’s mapping is written relative to the sub-model, exactly as if it were the root.
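The strip, match, and re-attach rule for scope_prefix can be sketched in a few lines. This is a minimal illustration, not the library's implementation: scoped_rename is a hypothetical helper, and it uses plain re.sub on a single key.

```python
import re


def scoped_rename(key: str, pattern: str, replacement: str, scope_prefix: str = "") -> str:
    """Hypothetical helper: apply a regex rename only to keys under `scope_prefix`."""
    if scope_prefix:
        if not key.startswith(f"{scope_prefix}."):
            return key  # key belongs to another sub-model: the transform does not fire
        relative = key[len(scope_prefix) + 1:]  # strip "<scope_prefix>."
    else:
        relative = key
    # The pattern is written relative to the sub-model, so `^` anchors to the scoped key
    renamed = re.sub(pattern, replacement, relative)
    # Re-attach the prefix after substitution
    return f"{scope_prefix}.{renamed}" if scope_prefix else renamed


print(scoped_rename("vision_model.encoder.layers.0.w", r"^encoder", "blocks", "vision_model"))
```

Because the prefix is removed before matching, a `^encoder` pattern written for the vision sub-model fires on `vision_model.encoder...` but leaves `language_model.encoder...` untouched.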
Architecture
The system is built around several key components defined in core_model_loading.py:
Phase 1 — Per-key processing (iterates over checkpoint keys):
- Walk the transform list once. Every WeightRenaming that matches fires (e.g. block_sparse_moe → mlp), and at most one WeightConverter may claim the key (e.g. experts.*.w1.weight).
- Shard (TP) and send to device asynchronously via ThreadPoolExecutor.
- Collect tensors with the same source_pattern together (e.g. all MoE expert weights, gate + up projections).
Phase 2 — Per-mapping processing (iterates over collected mappings):
- Dequantize/deserialize (pre-quantized checkpoints only).
- Apply the ConversionOps chain: Chunk, Concatenate, MergeModulelist, Transpose, etc.
- Quantize on-the-fly (if not pre-quantized).
- Set parameter on model.
WeightTransform
The base class that handles pattern matching and tensor collection:
- Pattern compilation: Source patterns are full regular expressions matched with re.search(). The * wildcard matches any index component (e.g. an expert number) and groups all matches together for batch operations.
- Capturing groups & backreferences: Capturing groups in source patterns can be referenced as \1, \2, … in target patterns to preserve substrings (e.g. layer indices) across the rename.
- Scoping: scope_prefix (set automatically per sub-model by get_model_conversion_mapping) restricts the transform to keys under that path. The prefix is stripped before matching and re-attached after substitution.
- Key renaming: rename_source_key() applies the regex (with scope handling) and returns (renamed_key, source_pattern) so the loader knows which converter, if any, claimed the key.
- Tensor collection: add_tensor() accumulates resolved tensors (or Futures) under their source_pattern so that all tensors needed by a single conversion (e.g. all MoE expert weights) are batched together before the operation chain runs.
- Reversibility: reverse_transform() swaps source ↔ target patterns and inverts each operation, so the same list, reversed, drives saving.
- Tracking: was_used() reports whether the transform actually matched any key during loading. This is required so that non-bijective renames (e.g. PrefixChange adding a prefix that may already be present) can be re-applied symmetrically on save.
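The wildcard-and-collect behaviour described above can be approximated with the standard re module. This sketch is hypothetical and simplified: collect_by_source_pattern is not a library function, `*` is assumed to mean "one or more digits", and it groups only by source pattern, whereas the real loader also keeps different layers' tensors apart.

```python
import re
from collections import defaultdict


def collect_by_source_pattern(keys, source_patterns):
    """Group checkpoint keys under the first source pattern they match (hypothetical sketch)."""
    # `*` becomes a digit wildcard; the rest is used as a regex, as with re.search()
    compiled = {p: re.compile(p.replace("*", r"\d+")) for p in source_patterns}
    groups = defaultdict(list)
    for key in keys:  # iterating in checkpoint order keeps experts in order
        for pattern, regex in compiled.items():
            if regex.search(key):
                groups[pattern].append(key)
                break  # one source pattern per key
    return dict(groups)
```

Keys that match no source pattern, such as attention projections, are simply not collected here. They would be loaded without a conversion chain.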
WeightRenaming
WeightRenaming is a specialized WeightTransform for pure key renames without tensor operations. Unlike WeightConverter, a WeightRenaming does not claim the key (it does not occupy the “at most one converter per key” slot), so multiple renames may chain freely both before and after a WeightConverter has fired.
# Legacy checkpoint compatibility
WeightRenaming("LayerNorm.gamma", "LayerNorm.weight")
# Module path changes
WeightRenaming(".block_sparse_moe.", ".mlp.")
PrefixChange is a higher-level wrapper around WeightRenaming for the common case of stripping or adding an entire path component. The optional model_prefix scopes the operation to keys under that namespace:
# "model.layers.bad_prefix.weight" → "model.layers.weight"
PrefixChange(prefix_to_remove="bad_prefix", model_prefix="model.layers")
# "layers.0.weight" → "model.layers.0.weight"
PrefixChange(prefix_to_add="model")
WeightConverter
WeightConverter extends WeightTransform with a chain of ConversionOps that act on the collected tensors. The four supported cardinalities are:
| Cardinality | Source patterns | Target patterns | Typical operation |
|---|---|---|---|
| one-to-one | 1 | 1 | Transpose, PermuteForRope |
| one-to-many | 1 | >1 | Chunk (e.g. unpack qkv_proj) |
| many-to-one | >1 | 1 | Concatenate, MergeModulelist (e.g. fuse experts) |
| many-to-many | >1 | >1 | only with operations that explicitly support it (e.g. ErnieFuseAndSplitTextVisionExperts) |
WeightConverter(
source_patterns=[".experts.*.w1.weight", ".experts.*.w3.weight"],
target_patterns=".experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
)
A WeightConverter is also reversible: reverse_transform() swaps source ↔ target and replaces each operation with its reverse_op, so a registered conversion mapping is bidirectional by construction (loading uses the list as-is, saving uses it reversed).
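The reversal rule can be illustrated with toy classes: swap the source and target patterns, then invert the operation chain so the last operation is undone first. ToyChunk, ToyConcatenate, and ToyConverter are hypothetical stand-ins that work on plain lists of rows. They are not the library's Chunk, Concatenate, or WeightConverter.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyChunk:
    """Hypothetical op: split a list of rows into `n` equal parts."""
    n: int

    def convert(self, rows):
        size = len(rows) // self.n
        return [rows[i * size:(i + 1) * size] for i in range(self.n)]

    @property
    def reverse_op(self):
        return ToyConcatenate(n=self.n)


@dataclass(frozen=True)
class ToyConcatenate:
    """Hypothetical op: join `n` lists of rows back into one list."""
    n: int

    def convert(self, parts):
        assert len(parts) == self.n
        return [row for part in parts for row in part]

    @property
    def reverse_op(self):
        return ToyChunk(n=self.n)


@dataclass(frozen=True)
class ToyConverter:
    """Hypothetical mini converter: source/target patterns plus an op chain."""
    source_patterns: tuple
    target_patterns: tuple
    operations: tuple

    def convert(self, value):
        for op in self.operations:  # each op's output feeds the next op
            value = op.convert(value)
        return value

    def reverse_transform(self):
        # Swap source <-> target and invert the chain (last op is undone first)
        return ToyConverter(
            source_patterns=self.target_patterns,
            target_patterns=self.source_patterns,
            operations=tuple(op.reverse_op for op in reversed(self.operations)),
        )
```

Loading with a ToyChunk(3) converter splits a fused QKV list into three parts. Calling `reverse_transform().convert(...)` on that result rebuilds the original fused list.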
Ordering renames and converters
List WeightRenaming entries before WeightConverter entries, and keep their leaves disjoint: renames normalise the keys converters consume, but never target a leaf a converter produces. The save path relies on this to invert the mapping in two phases (reverse converters, then reverse renames).
The transform list is walked in order, once per checkpoint key. For each key:
- Every WeightRenaming that matches fires; multiple renames can chain.
- The first WeightConverter that matches claims the key, and subsequent converters are skipped. This guarantees that the tensor is routed to a single converter for the merge/split step. It is also why two converters with overlapping intent are a misconfiguration.
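The single ordered pass can be sketched as follows. Rename and Converter are hypothetical stand-ins for WeightRenaming and WeightConverter, and walk_transforms is an illustrative helper, not a library function. The sketch only records which converter claimed the key. It does not apply the converter's own target rename.

```python
import re
from dataclasses import dataclass


@dataclass
class Rename:  # hypothetical stand-in for WeightRenaming
    pattern: str
    replacement: str


@dataclass
class Converter:  # hypothetical stand-in for WeightConverter
    source_pattern: str


def walk_transforms(key, transforms):
    """Single ordered pass: every matching rename fires, the first matching converter claims the key."""
    claimed_by = None
    for transform in transforms:
        if isinstance(transform, Rename):
            key = re.sub(transform.pattern, transform.replacement, key)  # renames chain
        elif claimed_by is None and re.search(transform.source_pattern, key):
            claimed_by = transform  # any later converter is skipped for this key
    return key, claimed_by
```

Converters are checked against the key as renamed so far. This is one reason the guide asks for renames to be listed before converters.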
weight_mapping = [
WeightRenaming("^old_prefix", "encoder"), # renames always run
WeightConverter( # converter claims the key
"attn.qkv_proj.weight",
["attn.q_proj.weight", "attn.k_proj.weight", "attn.v_proj.weight"],
operations=[Chunk(dim=0)],
),
]
# Load: "old_prefix.attn.qkv_proj.weight"
# → WeightRenaming → "encoder.attn.qkv_proj.weight"
# → WeightConverter → "encoder.attn.{q,k,v}_proj.weight"
#
# Save: list reversed, each transform inverted:
# → rev(WeightConverter) repacks QKV → "encoder.attn.qkv_proj.weight"
# → rev(WeightRenaming) fixes prefix → "old_prefix.attn.qkv_proj.weight"
Conversion operations
WeightConverter supports several operations that run during from_pretrained() to transform checkpoint source tensors into model target tensors.
Operations are fully reversible. Saving applies the reverse conversions and restores the original checkpoint layout, so you can easily work across different frameworks.
| Operation | Reverse |
|---|---|
| Chunk(dim) | Concatenate(dim) |
| Concatenate(dim) | Chunk(dim) |
| MergeModulelist(dim) | SplitModulelist(dim) |
| SplitModulelist(dim) | MergeModulelist(dim) |
| Transpose(d0, d1) | Transpose(d1, d0) |
| PermuteForRope() | PermuteForRope() |
| Conv3dToLinear(...) | LinearToConv3d(...) |
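To see why each pair in the table undoes the other, and why the dim argument must match on both sides, here is a sketch using 2D nested lists instead of torch tensors. concatenate, chunk, and transpose are hypothetical helpers that mimic the semantics of the named operations. They are not the library's classes.

```python
def concatenate(tensors, dim):
    """Join 2D nested lists along dim 0 (rows) or dim 1 (columns)."""
    if dim == 0:
        return [list(row) for t in tensors for row in t]
    if dim == 1:
        return [[x for t in tensors for x in t[i]] for i in range(len(tensors[0]))]
    raise ValueError("this sketch only handles 2D inputs")


def chunk(tensor, chunks, dim):
    """Split a 2D nested list into `chunks` equal parts along `dim`."""
    if dim == 0:
        size = len(tensor) // chunks
        return [tensor[k * size:(k + 1) * size] for k in range(chunks)]
    if dim == 1:
        size = len(tensor[0]) // chunks
        return [[row[k * size:(k + 1) * size] for row in tensor] for k in range(chunks)]
    raise ValueError("this sketch only handles 2D inputs")


def transpose(tensor):
    """Swap dims 0 and 1 of a 2D nested list."""
    return [list(col) for col in zip(*tensor)]
```

Chunk(dim=1) undoes Concatenate(dim=1), but chunking the same result along dim=0 produces different pieces. The reverse op therefore has to carry over the same dim.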
Chunk
The Chunk operation splits a tensor into equal parts along a dimension. For example, use it when the checkpoint stores a single fused QKV tensor but the model expects Q, K, and V as three separate tensors.
WeightConverter(
"self_attn.qkv_proj",
["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
operations=[Chunk(dim=0)],
)
Concatenate
The Concatenate operation fuses separate tensors into a single tensor. For example, use it when the checkpoint stores Q, K, and V separately but the model expects them as a single fused tensor.
WeightConverter(
["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
"self_attn.qkv_proj",
operations=[Concatenate(dim=0)],
)
MergeModulelist
MergeModulelist merges a list of 2D tensors into a single 3D tensor. For example, you can compose MergeModulelist with Concatenate to stack the experts in a MoE and pack them into one tensor.
WeightConverter(
["block_sparse_moe.experts.*.w1.weight", "block_sparse_moe.experts.*.w3.weight"],
"mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
)
SplitModulelist
SplitModulelist splits a 3D tensor back into a list of 2D tensors. For example, you can split a stack of experts back into individual experts.
WeightConverter(
"mlp.experts.down_proj",
"block_sparse_moe.experts.*.w2.weight",
operations=[SplitModulelist(dim=0)],
)
PermuteForRope
PermuteForRope converts weights from the interleaved format to use the sin/cos format. For example, you can compose Chunk with PermuteForRope to split a fused QKV tensor and apply the sin/cos RoPE permutation to Q and K.
WeightConverter(
["model.layers.*.self_attn.qkv_proj.weight"],
[
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
],
operations=[Chunk(dim=0), PermuteForRope()],
)
Transpose
Transpose swaps two dimensions of a tensor. It is useful for converting weight layouts between different conventions.
WeightConverter(
source_patterns="mlp.gate.weight",
target_patterns="mlp.text_moe.gate.weight",
operations=[Transpose(dim0=0, dim1=1)],
)
Operation chaining
Operations can be chained to perform complex transformations. The operations execute in order, with each operation’s output becoming the next operation’s input.
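A chain can be checked by tracking shapes only. In the sketch below, stack_shape, merge_modulelist, concatenate, and run_chain are hypothetical helpers that reproduce the shape rules of stacking and concatenation. The sizes assume Mixtral-8x7B (hidden 4096, intermediate 14336, 8 experts) and nn.Linear's (out_features, in_features) weight layout.

```python
from functools import partial


def stack_shape(shapes, dim=0):
    """Stacking N equal shapes inserts a new axis of size N at `dim`."""
    first = shapes[0]
    if any(s != first for s in shapes):
        raise ValueError("stacked tensors must share one shape")
    return first[:dim] + (len(shapes),) + first[dim:]


def merge_modulelist(groups, dim=0):
    """One stacked shape per source pattern (e.g. all w1 experts, all w3 experts)."""
    return [stack_shape(shapes, dim) for shapes in groups]


def concatenate(shapes, dim):
    """Concatenation sums sizes along `dim`; every other dim must match."""
    first = shapes[0]
    for s in shapes[1:]:
        if s[:dim] + s[dim + 1:] != first[:dim] + first[dim + 1:]:
            raise ValueError("non-concatenated dims must match")
    return first[:dim] + (sum(s[dim] for s in shapes),) + first[dim + 1:]


def run_chain(value, operations):
    for op in operations:
        value = op(value)  # each op's output is the next op's input
    return value


# Assumed Mixtral-8x7B sizes; nn.Linear weights are (out_features, in_features)
w1 = [(14336, 4096)] * 8  # gate_proj per expert
w3 = [(14336, 4096)] * 8  # up_proj per expert
gate_up = run_chain([w1, w3], [partial(merge_modulelist, dim=0), partial(concatenate, dim=1)])
```

With these inputs, MergeModulelist(dim=0) yields two (8, 14336, 4096) stacks. Concatenate(dim=1) then fuses them into (8, 28672, 4096), so the fused axis is dim 1.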
Example: Mixtral MoE conversion
WeightConverter(
source_patterns=[
".experts.*.w1.weight", # gate_proj per expert
".experts.*.w3.weight", # up_proj per expert
],
target_patterns=".experts.gate_up_proj",
operations=[
MergeModulelist(dim=0), # Stack all experts: (n_experts, out, in)
Concatenate(dim=1), # Fuse gate+up: (n_experts, 2*out, in)
],
)
Data flow:
Input:
".experts.*.w1.weight": [tensor_0, tensor_1, ..., tensor_7] # 8 experts
".experts.*.w3.weight": [tensor_0, tensor_1, ..., tensor_7] # 8 experts
After MergeModulelist(dim=0):
".experts.*.w1.weight": (8, 14336, 4096) # stacked gate
".experts.*.w3.weight": (8, 14336, 4096) # stacked up
After Concatenate(dim=1):
".experts.gate_up_proj": (8, 28672, 4096) # fused gate_up
Pattern matching rules
- Source patterns are full regexes evaluated with re.search(). ^ anchors to the start of the scoped key (after scope_prefix stripping).
- * is a per-index wildcard that collects all matching tensors under the same source pattern (preserving checkpoint order for correct concatenation).
- Capturing groups in source patterns can be referenced as \1, \2, … in target patterns to keep parts of the original key (e.g. layer indices).
- Scoped transforms (scope_prefix set) only match keys that start with f"{scope_prefix}.".
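The backreference rule maps directly onto Python's re module. Match.expand fills `\1`, `\2`, … in a template from the match's groups. rename_with_backrefs is a hypothetical helper, and the `wqkv` key names are made up for illustration.

```python
import re


def rename_with_backrefs(key, source_pattern, target_pattern):
    """Rewrite the first match of `source_pattern` in `key` (hypothetical helper).

    Group references in `target_pattern` are filled from the match's capturing groups.
    """
    match = re.search(source_pattern, key)
    if match is None:
        return key  # no match: the key is left untouched
    return key[: match.start()] + match.expand(target_pattern) + key[match.end():]


print(rename_with_backrefs(
    "model.layers.3.attn.wqkv.weight",
    r"layers\.(\d+)\.attn\.wqkv",
    r"layers.\1.self_attn.qkv_proj",
))
```

The layer index `3` survives the rename because it is captured by `(\d+)` and re-inserted through `\1`. A `^`-anchored pattern, by contrast, does not fire on a key that has extra leading components.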
Registering a conversion mapping
A conversion list is registered against a string key — either a model_type (e.g. "mixtral") or a class name (e.g. "LlavaModel"):
from transformers.conversion_mapping import register_checkpoint_conversion_mapping
from transformers.core_model_loading import MergeModulelist, WeightConverter, WeightRenaming
register_checkpoint_conversion_mapping(
"my_model_type",
[
WeightRenaming(".old.", ".new."),
WeightConverter(
".experts.*.w1.weight",
".experts.gate_proj",
operations=[MergeModulelist(dim=0)],
),
],
)
Lookup rules
When get_model_conversion_mapping processes a PreTrainedModel, every sub-PreTrainedModel is visited in DFS order (nn.Module.named_modules() filtered to PreTrainedModel instances). For each one:
- Class-name lookup is tried first, then model_type. If both are registered for the same module, the class-name mapping wins and the model_type one is ignored for that module. This lets a task head (e.g. LlavaForConditionalGeneration) override the shared model_type baseline ("llava").
- The selected mapping has scope_prefix set to the sub-module’s dotted path ("" for the root).
- Ancestor-based deduplication decides whether to keep the mapping:
  - If an ancestor path has already claimed the same identifier (class name or model_type), the sub-module is skipped, because the ancestor’s unscoped or higher-scoped mapping already covers this subtree.
  - If only a sibling has claimed it, the sub-module is kept with its own scope_prefix. Each sibling gets its own scoped mapping.
The class-name and model_type seen-lists are tracked separately, with one subtlety: when a module is matched via its class name, its model_type is not added to the seen-list. This is so that other modules sharing the same model_type but without a class-specific mapping (e.g. DetrModel under DetrForSegmentation) remain reachable through the model_type lookup.
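The lookup rules can be modelled over a flat list of modules in DFS order. select_mappings and is_ancestor are hypothetical helpers. They return only which identifier is selected and its scope_prefix, and they ignore the transform contents and alias resolution.

```python
def is_ancestor(ancestor_path: str, path: str) -> bool:
    """The root ("") is an ancestor of every other module."""
    return ancestor_path == "" or path.startswith(ancestor_path + ".")


def select_mappings(modules, registry):
    """Hypothetical sketch of the lookup rules.

    modules: (dotted_path, class_name, model_type) tuples in DFS order, root path "".
    registry: identifiers (class names or model_types) that have a mapping.
    Returns (identifier, scope_prefix) pairs for the mappings that are kept.
    """
    claimed_classes, claimed_types = [], []  # tracked separately: (identifier, path)
    selected = []
    for path, class_name, model_type in modules:
        if class_name in registry:  # class-name lookup wins
            identifier, claims = class_name, claimed_classes
        elif model_type in registry:  # model_type is the fallback
            identifier, claims = model_type, claimed_types
        else:
            continue  # nothing registered: just walk through
        # An ancestor that already claimed this identifier covers the subtree
        if any(ident == identifier and is_ancestor(p, path) for ident, p in claims):
            continue
        claims.append((identifier, path))  # a class match does not claim its model_type
        selected.append((identifier, path))
    return selected
```

For the Gemma3 and DETR trees described later in this guide, this sketch keeps the model_type mapping at the root and the class mapping scoped to the inner model. Siblings that share a class each keep their own scoped entry.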
Class-based mappings vs model_type aliases
Both styles coexist in _MODEL_TO_CONVERSION_PATTERN:
_MODEL_TO_CONVERSION_PATTERN = {
# model_type aliases (lookup by config.model_type)
"minimax": "mixtral",
"qwen3_moe": "qwen2_moe",
"mistral3": "llava",
# class-name aliases (lookup by type(submodule).__name__)
"PaliGemmaModel": "LlavaModel",
"MaskFormerDetrDecoder": "DetrModel",
...
}
Class-name keys are preferred when the model_type is shared but a specific class needs different behaviour.
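One way to read these aliases is as a lookup into a second table that holds the actual transform lists. This would explain self-mapping entries such as "mixtral": "mixtral" in the alias table later in this guide. The two-table layout below is an assumption for illustration: ALIASES, CONVERSIONS, and lookup are hypothetical names, and strings stand in for transform objects.

```python
# Hypothetical tables: ALIASES plays the role of _MODEL_TO_CONVERSION_PATTERN and
# CONVERSIONS holds the transform lists (strings stand in for WeightTransform objects).
ALIASES = {
    "mixtral": "mixtral",
    "minimax": "mixtral",
    "qwen2_moe": "qwen2_moe",
    "qwen3_moe": "qwen2_moe",
    "LlavaModel": "LlavaModel",
    "PaliGemmaModel": "LlavaModel",
}
CONVERSIONS = {
    "mixtral": ["mixtral transforms"],
    "qwen2_moe": ["qwen2_moe transforms"],
    "LlavaModel": ["LlavaModel transforms"],
}


def lookup(identifier):
    """Resolve a class name or model_type to its shared transform list, or None."""
    canonical = ALIASES.get(identifier)
    return CONVERSIONS.get(canonical) if canonical is not None else None
```

Under this reading, aliased identifiers share the same list object, so fixing a mapping in one place fixes it for every alias.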
Tensor parallelism integration
The dynamic loading system integrates with tensor parallelism (TP) through the TensorParallelLayer hierarchy defined in src/transformers/integrations/tensor_parallel.py.
When TP is enabled, tensors are sharded during materialization, not after. This means each rank only loads the portion of the tensor it needs.
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, device, dtype):
    def _job():
        return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)

    return thread_pool.submit(_job)
Available parallel styles
| Style | Weight Shard Dim | Description |
|---|---|---|
| colwise | -2 | Column-wise: output features sharded |
| rowwise | -1 | Row-wise: input features sharded |
| packed_colwise | -2 | For fused weights (gate_up_proj) |
| packed_rowwise | -1 | For fused weights |
| embedding_rowwise | 0 | Vocabulary parallelism |
| grouped_gemm | 0 | Expert parallelism for MoE |
| sequence_parallel | None | No weight sharding |
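The difference between the plain and packed styles in the table above can be shown with a list of rows standing in for an (out, in) weight. shard_colwise, shard_rowwise, and shard_packed_colwise are hypothetical helpers, not TensorParallelLayer methods. They assume every sharded size divides evenly by the world size.

```python
def shard_colwise(weight, rank, world_size):
    """colwise-style: keep this rank's block of rows (the output dim, -2)."""
    size = len(weight) // world_size
    return weight[rank * size:(rank + 1) * size]


def shard_rowwise(weight, rank, world_size):
    """rowwise-style: keep this rank's block of columns (the input dim, -1)."""
    size = len(weight[0]) // world_size
    return [row[rank * size:(rank + 1) * size] for row in weight]


def shard_packed_colwise(weight, rank, world_size, num_packed=2):
    """packed_colwise-style: shard each packed block (e.g. gate, then up) separately."""
    block = len(weight) // num_packed
    shard = []
    for b in range(num_packed):
        shard.extend(shard_colwise(weight[b * block:(b + 1) * block], rank, world_size))
    return shard


packed = ["G0", "G1", "G2", "G3", "U0", "U1", "U2", "U3"]  # rows of a fused gate_up weight
print(shard_packed_colwise(packed, 0, 2))
```

A naive colwise split of the packed rows would give rank 0 only gate rows and rank 1 only up rows. The packed variant gives each rank matching halves of both.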
Packed weight handling
For fused weights like gate_up_proj, special care is needed to shard correctly:
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
    """
    Interleaves gate and up shards correctly.
    Packed tensor: [G0 G1 G2 G3 | U0 U1 U2 U3]
    With TP=2:
    - Rank 0 gets: [G0 G1 | U0 U1]
    - Rank 1 gets: [G2 G3 | U2 U3]
    """
The TP operation is stored in the WeightTransform and applied after conversion operations:
if matched_tp_pattern := tp_plan_alt.search(renamed_key):
    tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]]
    mapping.distributed_operation = tp_layer(
        device_mesh=device_mesh,
        rank=device_mesh.get_local_rank(),
        empty_param=empty_param.clone(),
    )
Quantization integration
Quantization hooks into the loading pipeline in two ways, depending on whether the checkpoint is already quantized:
- Pre-quantized checkpoints: The quantizer provides WeightConverter instances (via get_weight_conversions()) that deserialize quantized tensors. Checkpoint dtypes are preserved to avoid unwanted casts.
- On-the-fly quantization: The quantizer provides a quantization operation that is applied after conversion ops, quantizing weights as they are loaded.
The quantizer can also rewrite the entire conversion list at the end of get_model_conversion_mapping via update_weight_conversions(...) — for example, the FP8 dequantizer prepends a Fp8Dequantize op to every existing converter so per-block scales are applied before any expert-merge / concat ops flatten the per-expert structure.
Fast and efficient model loading
Loading a model is faster and uses less memory because the loader knows which tensors are required for operations and schedules their materialization lazily.
The loader scans the checkpoint once to discover pattern matches and collect tensors. It stores them as Future objects and submits them to a thread pool for asynchronous loading, so the main thread is not blocked while tensors are read. A parameter starts loading as soon as a thread becomes available to it.
If your system runs other heavy processes, multiple threads may slow down loading instead of accelerating it. In this case, set the environment variable HF_DEACTIVATE_ASYNC_LOAD=1 to load weights sequentially.
The default is 4 threads for asynchronous parameter loading. This provides the best trade-off across loading scenarios and hardware. The work is mostly I/O bound, but depending on the accelerator hardware and the dtype required at loading, it can become CPU/GPU-bound if the dtype differs from the serialized one (this requires an additional copy operation).
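The async and sync modes can be mimicked with concurrent.futures. Async mode returns Futures, and sync mode returns deferred callables. load_tensor, spawn, realize, and load_all are hypothetical names, and load_tensor only fakes the read. The sketch treats HF_DEACTIVATE_ASYNC_LOAD as set only when it equals "1", which may be stricter than the library's parsing.

```python
import os
from concurrent.futures import Future, ThreadPoolExecutor


def load_tensor(name):
    """Hypothetical stand-in for reading and copying one tensor."""
    return f"tensor<{name}>"


def spawn(pool, name):
    """Submit to the pool (returns a Future) or defer the work (returns a callable)."""
    if pool is not None:
        return pool.submit(load_tensor, name)
    return lambda: load_tensor(name)


def realize(handle):
    """Wait for a Future, or run a deferred callable now."""
    return handle.result() if isinstance(handle, Future) else handle()


def load_all(names, max_workers=4):
    use_async = os.environ.get("HF_DEACTIVATE_ASYNC_LOAD") != "1"
    pool = ThreadPoolExecutor(max_workers=max_workers) if use_async else None
    try:
        handles = {n: spawn(pool, n) for n in names}  # scheduling does not block
        return {n: realize(h) for n, h in handles.items()}  # block per tensor, in order
    finally:
        if pool is not None:
            pool.shutdown(wait=True, cancel_futures=True)  # drop queued work on interrupt
```

Both modes produce the same results. The only difference is whether reads can overlap with the code that consumes them.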
Async vs sync loading
def spawn_materialize(thread_pool, tensor, device, dtype) -> Future | Callable:
    def _job():
        return _materialize_copy(tensor, device, dtype)

    if thread_pool is not None:
        return thread_pool.submit(_job)  # Async: returns Future
    else:
        return _job  # Sync: returns Callable (deferred execution)
Sync loading is used when:
- The HF_DEACTIVATE_ASYNC_LOAD=1 environment variable is set.
- Disk offloading is enabled (memory constraints require sequential loading).
- On-the-fly quantization is enabled (avoids worker threads racing ahead of the quantization step).
Materialization flow
1. Checkpoint iteration (Phase 1):
- For each key, walk the transform list once
- Submit materialization job to ThreadPoolExecutor
- Job returns Future (async) or Callable (sync)
- Collect into the matching WeightConverter / WeightRenaming
2. Per-mapping processing (Phase 2, one mapping at a time):
- materialize_tensors() waits for this mapping's Futures only
- Apply conversion operations chain (self.operations)
- Apply quantization operation (if on-the-fly)
- Set parameters on model
- Delete realized tensors immediately
3. Cleanup:
- Thread pool shutdown (with cancel_futures=True for interrupts)
Memory efficiency
When converting a weight, the converter waits for all required tensors to materialize if they haven’t loaded yet. For example, the MergeModulelist operation requires all weights in the ModuleList to be loaded before merging.
Concatenating tensors requires a temporary copy, so operations like MergeModulelist and Concatenate need 2x the memory of the underlying tensors during conversion. Once merged, only the resulting tensor stays in memory. The theoretical worst-case memory peak is the model size plus the tensors required for the largest MergeModulelist or Concatenate operation.
This worst case only occurs when all other parameters have loaded before the demanding conversion runs. Two scenarios trigger this.
- All parameters loaded asynchronously before entering the demanding conversion (the thread pool was faster than the conversion queue).
- The demanding conversion is the last one.
For example, for a MoE model that uses MergeModulelist for the experts in each layer, the theoretical worst-case memory peak is the model size plus the experts of one layer.
These worst-case scenarios are uncommon. The actual memory peak tends to stay close to the model size.
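To put the worst case in numbers, the sketch below computes the extra memory held by one layer's experts. It assumes Mixtral-8x7B sizes (hidden 4096, intermediate 14336, 8 experts, three expert matrices w1/w2/w3) and bf16 weights at 2 bytes per parameter. expert_bytes is a hypothetical helper.

```python
def expert_bytes(num_experts, hidden, intermediate, matrices, bytes_per_param=2):
    """Bytes held by `matrices` (hidden x intermediate) weights per expert, for all experts."""
    return num_experts * matrices * hidden * intermediate * bytes_per_param


GIB = 2**30
gate_up = expert_bytes(8, 4096, 14336, matrices=2)  # w1 + w3 feed one converter
down = expert_bytes(8, 4096, 14336, matrices=1)     # w2 feeds another converter
whole_layer = gate_up + down                        # all experts of one layer

print(gate_up / GIB)      # prints 1.75
print(whole_layer / GIB)  # prints 2.625
```

Under these assumptions, the largest single conversion (gate_up_proj) needs about 1.75 GiB of source tensors. All experts of one layer together, the bound quoted above, come to about 2.6 GiB on top of the model size.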
Reversibility
The system supports saving models with the inverse transformations, enabling round-trip save/load. Saving runs in two phases: reversed converters first (each tensor matches at most one), then reversed renames, per the ordering rule above.
def revert_weight_conversion(model, state_dict):
    """Applies reverse conversions for saving (simplified excerpt)."""
    weight_conversions = getattr(model, "_weight_conversions", None)
    # When this is None, the mapping is recomputed instead (see below)

    # Reverse all transforms
    reverse_weight_conversion = [
        conversion.reverse_transform() for conversion in weight_conversions
    ]
    # (elided) match each state_dict key against reverse_weight_conversion and
    # group the tensors of each reversed converter into `conversion_mapping`

    # Apply in reverse
    for first_param_name, reversed_converter in conversion_mapping.items():
        realized_value = reversed_converter.convert(first_param_name, model=model)
The list of transforms used at load time is cached on the model as _weight_conversions. Only entries that actually fired are kept, so non-bijective renames such as PrefixChange are correctly re-applied symmetrically. When the model was instantiated without from_pretrained (and hence has no _weight_conversions), revert_weight_conversion falls back to recomputing the mapping via get_model_conversion_mapping and drops any PrefixChange from it, because it cannot tell whether the original checkpoint had the prefix.
Target patterns may contain regex elements that need processing for the reverse direction:
def process_target_pattern(pattern: str) -> tuple[str, str | None]:
    r"""
    - Removes `^` and `$` anchors
    - Removes negative lookahead/lookbehind
    - Detects capturing groups, replaces with \1
    """
Real examples
Mixtral-style MoE
Checkpoint format:
model.layers.0.block_sparse_moe.experts.0.w1.weight # gate per expert
model.layers.0.block_sparse_moe.experts.0.w2.weight # down per expert
model.layers.0.block_sparse_moe.experts.0.w3.weight # up per expert
...
model.layers.0.block_sparse_moe.experts.7.w1.weight
Model format:
model.layers.0.mlp.experts.gate_up_proj # (8, 28672, 4096)
model.layers.0.mlp.experts.down_proj # (8, 4096, 14336)
Conversion mapping (from conversion_mapping.py):
"mixtral": [
WeightRenaming(".block_sparse_moe.", ".mlp."),
WeightConverter(
source_patterns=[".experts.*.w1.weight", ".experts.*.w3.weight"],
target_patterns=".experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns=[".experts.*.w2.weight"],
target_patterns=".experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
Composite vision-language model (Gemma3)
Gemma3 has the canonical “different prefix logic for the head model than for the base model” setup. Two entries are registered (here via aliases pointing at the shared "llava" / "LlavaModel" lists):
# model_type: applies to Gemma3ForConditionalGeneration / Gemma3ForSequenceClassification
"gemma3": "llava" # → adds "model." in front of language_model / vision_tower / ...
# class: applies to the inner Gemma3Model only
"Gemma3Model": "LlavaModel" # → minimal rename inside the already-prefixed namespace
DFS walk for Gemma3ForConditionalGeneration:
- Root: class lookup misses (Gemma3ForConditionalGeneration is not registered), model_type lookup hits ("gemma3") → the "llava" prefix-rewriting transforms are added unscoped, putting language_model, vision_tower, multi_modal_projector under model.*.
- Inner model (Gemma3Model): class lookup hits (Gemma3Model → LlavaModel) → only the LlavaModel mapping is applied, scoped to "model". The much broader "gemma3" (= "llava") prefix renames are not re-applied here, which is exactly what you want: the inner model lives in a namespace where the prefix is already correct.
The same pair shape is used by every Llava-family VLM (PaliGemma, InternVL, Mistral3, …). The model_type entry handles the head model’s prefix surgery; the class entry keeps the inner base model’s mapping minimal.
Deeper nesting with a head-specific override (DETR)
DetrForSegmentation shows a head-specific override layered on top of two nested levels:
DetrForSegmentation (class registered: segmentation-only renames)
├── detr: DetrForObjectDetection (no mapping; just walked through)
│ └── model: DetrModel (class registered: shared base transforms)
│ └── backbone, encoder, ...
└── mask_head, bbox_attention (head-specific weights)
Two mappings are involved (one per registered key):
"DetrModel": [WeightRenaming("backbone.conv_encoder", "backbone"), ...] # shared base
"DetrForSegmentation": [WeightRenaming("mask_head.lay1", "mask_head.conv1.conv"), ...] # head-specific
DFS walk:
- Root DetrForSegmentation matches by class → segmentation renames added unscoped.
- detr: DetrForObjectDetection is not registered → no transforms added; DFS continues into it.
- detr.model: DetrModel matches by class → base transforms added with scope_prefix="detr.model".
This is why DetrForObjectDetection does not need its own mapping: the only registered mapping in its subtree (DetrModel) is automatically scoped to the right path.
Class-keyed aliases reuse the base mapping without extra registration: "MaskFormerDetrDecoder": "DetrModel" makes a MaskFormer decoder pick up the same transforms under its own class name.
Custom operations (ERNIE 4.5 VL MoE)
When the built-in operations aren’t sufficient, you can create a custom ConversionOps subclass. For example, ERNIE 4.5 VL MoE needs to split a shared expert list between text and vision modalities — something no single built-in op handles. The custom ErnieFuseAndSplitTextVisionExperts operation splits and re-stacks experts across two target keys:
"ernie4_5_vl_moe": [
WeightRenaming("vision_model", "vision_tower"),
WeightConverter(
source_patterns=["experts.*.down_proj.weight"],
target_patterns=[
"text_moe.experts.down_proj",
"vision_moe.experts.down_proj",
],
operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)],
),
],
Custom ops must implement convert() and the reverse_op property to support round-trip save/load.
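The convert()/reverse_op contract can be sketched without the library. ToyConversionOp is a hypothetical base class, not the library's ConversionOps, whose exact method signatures are not shown in this guide. The ops below work on plain lists and assume, purely for illustration, that the first num_text experts belong to the text tower. The real ErnieFuseAndSplitTextVisionExperts works on tensors and takes stack_dim/concat_dim.

```python
from abc import ABC, abstractmethod


class ToyConversionOp(ABC):
    """Hypothetical base class with the two members the text requires."""

    @abstractmethod
    def convert(self, value):
        """Transform the collected value for loading."""

    @property
    @abstractmethod
    def reverse_op(self) -> "ToyConversionOp":
        """The op that undoes `convert`, used when saving."""


TEXT_KEY = "text_moe.experts.down_proj"
VISION_KEY = "vision_moe.experts.down_proj"


class SplitExpertsByModality(ToyConversionOp):
    """Hypothetical: the first `num_text` experts go to text, the rest to vision."""

    def __init__(self, num_text: int):
        self.num_text = num_text

    def convert(self, experts):
        return {TEXT_KEY: experts[: self.num_text], VISION_KEY: experts[self.num_text:]}

    @property
    def reverse_op(self):
        return MergeExpertsByModality(self.num_text)


class MergeExpertsByModality(ToyConversionOp):
    """Reverse of SplitExpertsByModality: rebuild the single shared expert list."""

    def __init__(self, num_text: int):
        self.num_text = num_text

    def convert(self, by_target):
        return by_target[TEXT_KEY] + by_target[VISION_KEY]

    @property
    def reverse_op(self):
        return SplitExpertsByModality(self.num_text)
```

The reverse op has to carry the same configuration (num_text here). Otherwise saving could not rebuild the original expert list from the two target keys.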
Model type aliases
Many models share conversion patterns:
_MODEL_TO_CONVERSION_PATTERN = {
"mixtral": "mixtral",
"minimax": "mixtral",
"qwen2_moe": "qwen2_moe",
"deepseek_v2": "qwen2_moe",
"deepseek_v3": "qwen2_moe",
"qwen3_moe": "qwen2_moe",
"olmoe": "qwen2_moe",
...
}
Reusing the dynamic loading building blocks
Dynamic weight loading is not limited to full model checkpoints. The same building blocks let you load any set of weights as long as you can describe how checkpoint keys map to parameters and ensure the target modules exist.
At a high level, the contract looks like this:
- Prepare the model namespace. Make sure the modules/parameters you want to load are present and named the way your mapping will target them. For adapters, that means calling inject_adapter_in_model(...) so adapter modules exist before loading. For custom heads or extra modules, instantiate them on the model first.
- Describe how to map weights. Build a conversion/renaming list (for example, in a helper like _build_peft_weight_mapping(...)) using WeightConverter or WeightRenaming. This is where you express how checkpoint keys should be converted, split, merged, or renamed to match your model namespace. There are three main ways to customize it:
  - Add entries to the list of converters. These are applied to all weights except the ones collected by a WeightConverter, so they should generally be WeightRenaming entries.
  - Add operations to the operation list of each converter. This is what quantization does: it appends a quantization operation after the operations of every WeightConverter.
  - Replace or map operations to your own custom operations. This is what peft does: it replaces the Concatenate operation of, say, mixtral with PeftConcatenate. This way, when the adapter checkpoint is read, the weights to be concatenated are collected and properly formatted for peft.
- Load, finalize, and report. Use the core loader to perform the conversion and populate tensors, then finalize and log results. Concretely, this flow is:
  - LoadStateDictConfig(...) + _load_pretrained_model(...) to load and convert.
  - _finalize_load_state_dict(...) to move any missing/mismatched tensors off meta, initialize them, and tie weights.
  - log_state_dict_report(...) to report missing/unexpected/mismatched keys (and conversion errors).
These APIs are exposed so you can handle custom code and custom weight formats while still benefiting from the fast, memory-efficient weight loading, sharding, and quality-of-life features of the Transformers API.
Key files reference
| File | Purpose |
|---|---|
| src/transformers/core_model_loading.py | Core loading logic, WeightConverter, WeightRenaming, ConversionOps |
| src/transformers/conversion_mapping.py | Built-in mappings and per-submodel composition (get_model_conversion_mapping) |
| src/transformers/integrations/tensor_parallel.py | TP sharding classes and utilities |
| src/transformers/quantizers/base.py | Quantization hooks and base class |