"""
MiniMind Max2 Model Configuration
Inspired by MiniMax M2's efficient activated parameters design
"""
from dataclasses import dataclass
from typing import Dict, Any
@dataclass
class Max2Config:
"""Configuration for MiniMind Max2 models."""
# Model identification
model_name: str = "max2-lite"
model_version: str = "1.0.0"
# Architecture dimensions
hidden_size: int = 1536
intermediate_size: int = 4096
num_hidden_layers: int = 24
num_attention_heads: int = 12
num_key_value_heads: int = 3 # GQA ratio 4:1
# Vocabulary and embeddings
vocab_size: int = 32000
max_position_embeddings: int = 8192
rope_theta: float = 10000.0
# MoE (Mixture of Experts) configuration
use_moe: bool = True
num_experts: int = 8
num_experts_per_tok: int = 2 # Only 25% activation
expert_hidden_size: int = 1024
router_aux_loss_coef: float = 0.01
# Normalization and activation
rms_norm_eps: float = 1e-6
hidden_act: str = "silu"
# Regularization
hidden_dropout: float = 0.0
attention_dropout: float = 0.0
# Special tokens
pad_token_id: int = 0
bos_token_id: int = 1
eos_token_id: int = 2
# Initialization
initializer_range: float = 0.02
# Memory optimization
use_cache: bool = True
use_flash_attention: bool = True
gradient_checkpointing: bool = False
    def to_dict(self) -> Dict[str, Any]:
        """Serialize the configuration to a plain dictionary."""
        return {k: v for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "Max2Config":
        """Build a configuration from a dictionary, ignoring unknown keys."""
        return cls(**{k: v for k, v in config_dict.items() if k in cls.__dataclass_fields__})
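# Illustrative usage (sketch, not in the original module): a Max2Config can be
# round-tripped through to_dict()/from_dict(); keys that from_dict() does not
# recognize are silently dropped, so loading a dict with extra entries is safe.
#
#   cfg = Max2Config(model_name="max2-custom", hidden_size=2048)
#   restored = Max2Config.from_dict(cfg.to_dict())
#   assert restored == cfg  # dataclass equality compares every field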
# Predefined model configurations
MAX2_CONFIGS = {
"max2-nano": Max2Config(
model_name="max2-nano",
hidden_size=768,
intermediate_size=2048,
num_hidden_layers=12,
num_attention_heads=12,
num_key_value_heads=3,
num_experts=4,
num_experts_per_tok=1,
expert_hidden_size=512,
max_position_embeddings=4096,
),
"max2-lite": Max2Config(
model_name="max2-lite",
hidden_size=1536,
intermediate_size=4096,
num_hidden_layers=24,
num_attention_heads=12,
num_key_value_heads=3,
num_experts=8,
num_experts_per_tok=2,
expert_hidden_size=1024,
max_position_embeddings=8192,
),
"max2-pro": Max2Config(
model_name="max2-pro",
hidden_size=2560,
intermediate_size=6912,
num_hidden_layers=32,
num_attention_heads=20,
num_key_value_heads=4,
num_experts=8,
num_experts_per_tok=2,
expert_hidden_size=1728,
max_position_embeddings=16384,
),
}
# Aliases for backward compatibility
Mind2Config = Max2Config
MIND2_CONFIGS = MAX2_CONFIGS
def get_config(model_name: str) -> Max2Config:
"""Get predefined configuration by name."""
if model_name not in MAX2_CONFIGS:
raise ValueError(f"Unknown model: {model_name}. Available: {list(MAX2_CONFIGS.keys())}")
return MAX2_CONFIGS[model_name]
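# Illustrative usage (sketch, not in the original module):
#   config = get_config("max2-pro")
#   print(config.hidden_size, config.num_experts)  # 2560 8
# Requesting an unknown name raises a ValueError listing the available models.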
def estimate_params(config: Max2Config) -> dict:
"""Estimate parameter counts for a configuration."""
embed_params = config.vocab_size * config.hidden_size
head_dim = config.hidden_size // config.num_attention_heads
# Attention parameters per layer (GQA)
q_params = config.hidden_size * config.hidden_size
kv_params = 2 * config.hidden_size * (config.num_key_value_heads * head_dim)
o_params = config.hidden_size * config.hidden_size
attn_params_per_layer = q_params + kv_params + o_params
# MoE FFN parameters per layer
if config.use_moe:
router_params = config.hidden_size * config.num_experts
expert_params = 3 * config.hidden_size * config.expert_hidden_size
ffn_params_per_layer = router_params + (config.num_experts * expert_params)
active_ffn_params = router_params + (config.num_experts_per_tok * expert_params)
else:
ffn_params_per_layer = 3 * config.hidden_size * config.intermediate_size
active_ffn_params = ffn_params_per_layer
norm_params_per_layer = 2 * config.hidden_size
layer_params = attn_params_per_layer + ffn_params_per_layer + norm_params_per_layer
active_layer_params = attn_params_per_layer + active_ffn_params + norm_params_per_layer
total_params = embed_params + (config.num_hidden_layers * layer_params) + embed_params
active_params = embed_params + (config.num_hidden_layers * active_layer_params) + embed_params
return {
"total_params": total_params,
"active_params": active_params,
"activation_ratio": active_params / total_params,
"total_params_b": total_params / 1e9,
"active_params_b": active_params / 1e9,
"estimated_size_fp16_gb": (total_params * 2) / (1024**3),
"estimated_size_int4_gb": (total_params * 0.5) / (1024**3),
}