"""
MiniMind Max2 Model Configuration
Inspired by MiniMax M2's efficient activated parameters design
"""

from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class Max2Config:
    """Configuration for MiniMind Max2 models."""

    # Model identification
    model_name: str = "max2-lite"
    model_version: str = "1.0.0"

    # Architecture dimensions
    hidden_size: int = 1536
    intermediate_size: int = 4096
    num_hidden_layers: int = 24
    num_attention_heads: int = 12
    num_key_value_heads: int = 3  # GQA ratio 4:1

    # Vocabulary and embeddings
    vocab_size: int = 32000
    max_position_embeddings: int = 8192
    rope_theta: float = 10000.0

    # MoE (Mixture of Experts) configuration
    use_moe: bool = True
    num_experts: int = 8
    num_experts_per_tok: int = 2  # 2 of 8 experts active per token (25%)
    expert_hidden_size: int = 1024
    router_aux_loss_coef: float = 0.01

    # Normalization and activation
    rms_norm_eps: float = 1e-6
    hidden_act: str = "silu"

    # Regularization
    hidden_dropout: float = 0.0
    attention_dropout: float = 0.0

    # Special tokens
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2

    # Initialization
    initializer_range: float = 0.02

    # Memory optimization
    use_cache: bool = True
    use_flash_attention: bool = True
    gradient_checkpointing: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain dict (e.g. for JSON serialization)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "Max2Config":
        """Build a config from a dict, silently dropping unknown keys."""
        known = {k: v for k, v in config_dict.items() if k in cls.__dataclass_fields__}
        return cls(**known)
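
    # Round-trip sketch (illustrative, not part of the API): to_dict/from_dict
    # survive JSON serialization, and unknown keys are dropped on load, so
    # configs saved with extra fields still deserialize.
    #
    #     import json
    #     cfg = Max2Config()
    #     blob = json.dumps(cfg.to_dict())
    #     assert Max2Config.from_dict(json.loads(blob)) == cfg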


# Predefined model configurations
MAX2_CONFIGS = {
    "max2-nano": Max2Config(
        model_name="max2-nano",
        hidden_size=768,
        intermediate_size=2048,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=3,
        num_experts=4,
        num_experts_per_tok=1,
        expert_hidden_size=512,
        max_position_embeddings=4096,
    ),
    "max2-lite": Max2Config(
        model_name="max2-lite",
        hidden_size=1536,
        intermediate_size=4096,
        num_hidden_layers=24,
        num_attention_heads=12,
        num_key_value_heads=3,
        num_experts=8,
        num_experts_per_tok=2,
        expert_hidden_size=1024,
        max_position_embeddings=8192,
    ),
    "max2-pro": Max2Config(
        model_name="max2-pro",
        hidden_size=2560,
        intermediate_size=6912,
        num_hidden_layers=32,
        num_attention_heads=20,
        num_key_value_heads=4,
        num_experts=8,
        num_experts_per_tok=2,
        expert_hidden_size=1728,
        max_position_embeddings=16384,
    ),
}

# Aliases for backward compatibility
Mind2Config = Max2Config
MIND2_CONFIGS = MAX2_CONFIGS


def get_config(model_name: str) -> Max2Config:
    """Get predefined configuration by name."""
    if model_name not in MAX2_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(MAX2_CONFIGS.keys())}")
    return MAX2_CONFIGS[model_name]
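
# Example (illustrative): predefined configs can be tweaked without mutating
# the shared MAX2_CONFIGS entries by copying through dataclasses.replace.
#
#     from dataclasses import replace
#     long_ctx = replace(get_config("max2-lite"), max_position_embeddings=16384)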


def estimate_params(config: Max2Config) -> Dict[str, float]:
    """Estimate total and per-token activated parameter counts for a configuration."""
    embed_params = config.vocab_size * config.hidden_size
    head_dim = config.hidden_size // config.num_attention_heads

    # Attention parameters per layer (GQA)
    q_params = config.hidden_size * config.hidden_size
    kv_params = 2 * config.hidden_size * (config.num_key_value_heads * head_dim)  # K and V use fewer heads (GQA)
    o_params = config.hidden_size * config.hidden_size
    attn_params_per_layer = q_params + kv_params + o_params

    # MoE FFN parameters per layer
    if config.use_moe:
        router_params = config.hidden_size * config.num_experts
        expert_params = 3 * config.hidden_size * config.expert_hidden_size  # gate/up/down projections
        ffn_params_per_layer = router_params + (config.num_experts * expert_params)
        active_ffn_params = router_params + (config.num_experts_per_tok * expert_params)
    else:
        ffn_params_per_layer = 3 * config.hidden_size * config.intermediate_size  # gate/up/down projections
        active_ffn_params = ffn_params_per_layer

    norm_params_per_layer = 2 * config.hidden_size
    layer_params = attn_params_per_layer + ffn_params_per_layer + norm_params_per_layer
    active_layer_params = attn_params_per_layer + active_ffn_params + norm_params_per_layer

    # Embeddings are counted twice: input embeddings plus an untied LM head.
    total_params = 2 * embed_params + (config.num_hidden_layers * layer_params)
    active_params = 2 * embed_params + (config.num_hidden_layers * active_layer_params)

    return {
        "total_params": total_params,
        "active_params": active_params,
        "activation_ratio": active_params / total_params,
        "total_params_b": total_params / 1e9,
        "active_params_b": active_params / 1e9,
        "estimated_size_fp16_gb": (total_params * 2) / (1024**3),
        "estimated_size_int4_gb": (total_params * 0.5) / (1024**3),
    }
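

if __name__ == "__main__":
    # Illustrative smoke test: print the estimated size of each predefined
    # configuration. The numbers are rough analytic estimates, not measured
    # checkpoint sizes.
    for name in MAX2_CONFIGS:
        stats = estimate_params(get_config(name))
        print(
            f"{name}: {stats['total_params_b']:.2f}B total, "
            f"{stats['active_params_b']:.2f}B active "
            f"({stats['activation_ratio']:.0%} activated), "
            f"~{stats['estimated_size_fp16_gb']:.1f} GiB in fp16"
        )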