| import math |
| from typing import Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .config import InitFnType, ModelConfig |
| from .util import StrEnum |
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "init_weights",
    "ModuleType",
]
|
|
|
|
class ModuleType(StrEnum):
    """
    Role tag for a submodule.

    :func:`init_weights` consults this only when ``config.init_fn`` is
    ``InitFnType.full_megatron``, where it selects the standard deviation
    used for the truncated-normal initialization.
    """

    # full_megatron std: ``config.init_std``.
    in_module = "in"
    # full_megatron std: ``config.init_std / sqrt(2 * config.n_layers)``.
    out_module = "out"
    # full_megatron std: ``config.init_std``.
    emb = "emb"
    # full_megatron std: ``config.d_model ** -0.5``.
    final_out = "final_out"
|
|
|
|
def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialize, in place, the parameters of a linear or embedding module.

    ``module.weight`` is filled according to ``config.init_fn``:

    * ``InitFnType.normal``: normal with ``std = config.init_std * std_factor``,
      truncated at ``+/- config.init_cutoff_factor * std`` when that factor is set.
    * ``InitFnType.mitchell``: truncated normal (cut at ``+/- 3 * std``) with
      ``std = std_factor / sqrt(d)``, optionally scaled down by ``layer_id``.
    * ``InitFnType.kaiming_normal``: ``nn.init.kaiming_normal_`` with
      ``nonlinearity="relu"``.
    * ``InitFnType.fan_in``: normal with ``std = std_factor / sqrt(d)``.
    * ``InitFnType.full_megatron``: truncated normal whose ``std`` depends on
      ``type_of_module``; the cut-off is ``config.init_cutoff_factor`` (3 when unset)
      times ``std``.

    Afterwards, if ``module`` is an ``nn.Linear``, its bias (when present) is zeroed,
    and under ``InitFnType.normal`` a module carrying a truthy ``_is_residual``
    attribute has its weight divided by ``sqrt(2 * config.n_layers)``.

    :param config: The model config.
    :param module: The linear or embedding submodule to initialize.
    :param d: The effective input dimensionality of the weights. This could be smaller
        than the actual dimensions for fused layers. Defaults to ``config.d_model``.
        Only the "mitchell" and "fan_in" methods use it.
    :param layer_id: When set, the standard deviation for the "mitchell" method will be
        adjusted by ``1 / sqrt(2 * (layer_id + 1))``.
    :param std_factor: Multiplier applied to the standard deviation by the "normal",
        "mitchell" and "fan_in" methods; the other methods do not read it.
    :param type_of_module: Role of ``module``; required by, and only read by, the
        "full_megatron" method.
    :raises RuntimeError: If the "full_megatron" method is selected and
        ``type_of_module`` is ``None`` or not a known :class:`ModuleType`.
    :raises NotImplementedError: If ``config.init_fn`` is not one of the methods above.
    """
    init_fn = config.init_fn
    fan_in = config.d_model if d is None else d

    if init_fn == InitFnType.normal:
        std = config.init_std * std_factor
        if config.init_cutoff_factor is None:
            nn.init.normal_(module.weight, mean=0.0, std=std)
        else:
            bound = config.init_cutoff_factor * std
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-bound, b=bound)

    elif init_fn == InitFnType.mitchell:
        std = std_factor / math.sqrt(fan_in)
        if layer_id is not None:
            # Deeper layers get a proportionally smaller spread.
            std /= math.sqrt(2 * (layer_id + 1))
        bound = 3 * std
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-bound, b=bound)

    elif init_fn == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")

    elif init_fn == InitFnType.fan_in:
        nn.init.normal_(module.weight, mean=0.0, std=std_factor / math.sqrt(fan_in))

    elif init_fn == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")

        # Fall back to a 3-sigma cut-off when the config does not specify one.
        cutoff = 3 if config.init_cutoff_factor is None else config.init_cutoff_factor

        if type_of_module in (ModuleType.in_module, ModuleType.emb):
            std = config.init_std
        elif type_of_module == ModuleType.out_module:
            std = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == ModuleType.final_out:
            std = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")

        bound = cutoff * std
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-bound, b=bound)

    else:
        raise NotImplementedError(config.init_fn)

    # Everything below applies to linear layers only.
    if not isinstance(module, nn.Linear):
        return

    if module.bias is not None:
        nn.init.zeros_(module.bias)

    # Under the plain "normal" scheme, modules flagged `_is_residual` get an extra
    # 1 / sqrt(2 * n_layers) scaling applied to the freshly drawn weights.
    if init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
        with torch.no_grad():
            module.weight.div_(math.sqrt(2 * config.n_layers))
|
|