""" MiniMind Max2 Model Configuration Inspired by MiniMax M2's efficient activated parameters design """ from dataclasses import dataclass from typing import Optional, Dict, Any @dataclass class Max2Config: """Configuration for MiniMind Max2 models.""" # Model identification model_name: str = "max2-lite" model_version: str = "1.0.0" # Architecture dimensions hidden_size: int = 1536 intermediate_size: int = 4096 num_hidden_layers: int = 24 num_attention_heads: int = 12 num_key_value_heads: int = 3 # GQA ratio 4:1 # Vocabulary and embeddings vocab_size: int = 32000 max_position_embeddings: int = 8192 rope_theta: float = 10000.0 # MoE (Mixture of Experts) configuration use_moe: bool = True num_experts: int = 8 num_experts_per_tok: int = 2 # Only 25% activation expert_hidden_size: int = 1024 router_aux_loss_coef: float = 0.01 # Normalization and activation rms_norm_eps: float = 1e-6 hidden_act: str = "silu" # Regularization hidden_dropout: float = 0.0 attention_dropout: float = 0.0 # Special tokens pad_token_id: int = 0 bos_token_id: int = 1 eos_token_id: int = 2 # Initialization initializer_range: float = 0.02 # Memory optimization use_cache: bool = True use_flash_attention: bool = True gradient_checkpointing: bool = False def to_dict(self) -> Dict[str, Any]: return {k: v for k, v in self.__dict__.items()} @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> "Max2Config": return cls(**{k: v for k, v in config_dict.items() if k in cls.__dataclass_fields__}) # Predefined model configurations MAX2_CONFIGS = { "max2-nano": Max2Config( model_name="max2-nano", hidden_size=768, intermediate_size=2048, num_hidden_layers=12, num_attention_heads=12, num_key_value_heads=3, num_experts=4, num_experts_per_tok=1, expert_hidden_size=512, max_position_embeddings=4096, ), "max2-lite": Max2Config( model_name="max2-lite", hidden_size=1536, intermediate_size=4096, num_hidden_layers=24, num_attention_heads=12, num_key_value_heads=3, num_experts=8, num_experts_per_tok=2, expert_hidden_size=1024, max_position_embeddings=8192, ), "max2-pro": Max2Config( model_name="max2-pro", hidden_size=2560, intermediate_size=6912, num_hidden_layers=32, num_attention_heads=20, num_key_value_heads=4, num_experts=8, num_experts_per_tok=2, expert_hidden_size=1728, max_position_embeddings=16384, ), } # Aliases for backward compatibility Mind2Config = Max2Config MIND2_CONFIGS = MAX2_CONFIGS def get_config(model_name: str) -> Max2Config: """Get predefined configuration by name.""" if model_name not in MAX2_CONFIGS: raise ValueError(f"Unknown model: {model_name}. 
Available: {list(MAX2_CONFIGS.keys())}") return MAX2_CONFIGS[model_name] def estimate_params(config: Max2Config) -> dict: """Estimate parameter counts for a configuration.""" embed_params = config.vocab_size * config.hidden_size head_dim = config.hidden_size // config.num_attention_heads # Attention parameters per layer (GQA) q_params = config.hidden_size * config.hidden_size kv_params = 2 * config.hidden_size * (config.num_key_value_heads * head_dim) o_params = config.hidden_size * config.hidden_size attn_params_per_layer = q_params + kv_params + o_params # MoE FFN parameters per layer if config.use_moe: router_params = config.hidden_size * config.num_experts expert_params = 3 * config.hidden_size * config.expert_hidden_size ffn_params_per_layer = router_params + (config.num_experts * expert_params) active_ffn_params = router_params + (config.num_experts_per_tok * expert_params) else: ffn_params_per_layer = 3 * config.hidden_size * config.intermediate_size active_ffn_params = ffn_params_per_layer norm_params_per_layer = 2 * config.hidden_size layer_params = attn_params_per_layer + ffn_params_per_layer + norm_params_per_layer active_layer_params = attn_params_per_layer + active_ffn_params + norm_params_per_layer total_params = embed_params + (config.num_hidden_layers * layer_params) + embed_params active_params = embed_params + (config.num_hidden_layers * active_layer_params) + embed_params return { "total_params": total_params, "active_params": active_params, "activation_ratio": active_params / total_params, "total_params_b": total_params / 1e9, "active_params_b": active_params / 1e9, "estimated_size_fp16_gb": (total_params * 2) / (1024**3), "estimated_size_int4_gb": (total_params * 0.5) / (1024**3), }
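

# Illustrative usage sketch (an assumption, not part of the original module's API):
# running the file directly prints estimated total vs. active parameter counts for
# each predefined configuration, which is one way to inspect the MoE
# "efficient activated parameters" trade-off this module is built around.
if __name__ == "__main__":
    for name in MAX2_CONFIGS:
        cfg = get_config(name)
        stats = estimate_params(cfg)
        print(
            f"{name}: total={stats['total_params_b']:.2f}B, "
            f"active={stats['active_params_b']:.2f}B "
            f"({stats['activation_ratio']:.1%} activated), "
            f"fp16~{stats['estimated_size_fp16_gb']:.2f} GB"
        )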