chegde committed
Commit 063be31 · verified · 1 Parent(s): bbbee58

Fix model registration for AutoModel compatibility

Files changed (3)
  1. __init__.py +31 -0
  2. configuration_nanogpt.py +5 -0
  3. modeling_nanogpt.py +41 -4
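For context, the commit wires the custom `nanogpt` model type into the standard Auto* loading path. A minimal sketch of the loading this is meant to support, assuming the checkpoint ships this modeling code and the usual `auto_map` entries in its config.json (the repo id below is a placeholder):

    # Hedged sketch; "chegde/nanogpt" is a placeholder repo id, not a confirmed checkpoint.
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("chegde/nanogpt", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("chegde/nanogpt", trust_remote_code=True)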
__init__.py ADDED
@@ -0,0 +1,31 @@
+ """NanoGPT HuggingFace Integration"""
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ # Import our classes
+ try:
+     from .configuration_nanogpt import NanoGPTConfig
+     from .modeling_nanogpt import (
+         NanoGPTModel,
+         NanoGPTForCausalLM,
+         NanoGPTPreTrainedModel
+     )
+ except ImportError:
+     from configuration_nanogpt import NanoGPTConfig
+     from modeling_nanogpt import (
+         NanoGPTModel,
+         NanoGPTForCausalLM,
+         NanoGPTPreTrainedModel
+     )
+
+ # Register the model with Auto* classes
+ AutoConfig.register("nanogpt", NanoGPTConfig)
+ AutoModel.register(NanoGPTConfig, NanoGPTModel)
+ AutoModelForCausalLM.register(NanoGPTConfig, NanoGPTForCausalLM)
+
+ __all__ = [
+     "NanoGPTConfig",
+     "NanoGPTModel",
+     "NanoGPTForCausalLM",
+     "NanoGPTPreTrainedModel"
+ ]
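Once this `__init__.py` has been imported, the Auto* factories can resolve the `nanogpt` model type locally without `trust_remote_code`. A minimal sketch, assuming the repository is importable as a package (the name `nanogpt_hf` is illustrative) and that `NanoGPTConfig` provides usable defaults:

    import nanogpt_hf  # illustrative package name; importing runs the register() calls above
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.for_model("nanogpt")          # resolves to NanoGPTConfig
    model = AutoModelForCausalLM.from_config(config)  # resolves to NanoGPTForCausalLM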
configuration_nanogpt.py CHANGED
@@ -1,8 +1,13 @@
  """NanoGPT model configuration"""

  from transformers import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)

  class NanoGPTConfig(PretrainedConfig):
+     """Configuration for NanoGPT model"""
+
      model_type = "nanogpt"

      def __init__(
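The `model_type = "nanogpt"` attribute is the key that `AutoConfig.register("nanogpt", NanoGPTConfig)` in `__init__.py` binds to. For Hub-hosted checkpoints there is an alternative route, not used in this commit, that records the custom classes in the checkpoint's config.json `auto_map`; a hedged sketch:

    # Hedged sketch of the remote-code registration route. Calling these before
    # save_pretrained()/push_to_hub() writes an "auto_map" entry into config.json so
    # that from_pretrained(..., trust_remote_code=True) can locate the custom classes.
    NanoGPTConfig.register_for_auto_class()
    NanoGPTModel.register_for_auto_class("AutoModel")
    NanoGPTForCausalLM.register_for_auto_class("AutoModelForCausalLM")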
modeling_nanogpt.py CHANGED
@@ -1,4 +1,4 @@
- """NanoGPT model implementation"""
+ """NanoGPT model implementation for HuggingFace"""

  import torch
  import torch.nn as nn
@@ -6,7 +6,15 @@ import torch.nn.functional as F
  import math
  from transformers import PreTrainedModel
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
- from .configuration_nanogpt import NanoGPTConfig
+ from transformers.utils import logging
+
+ # Import configuration
+ try:
+     from .configuration_nanogpt import NanoGPTConfig
+ except ImportError:
+     from configuration_nanogpt import NanoGPTConfig
+
+ logger = logging.get_logger(__name__)

  class ExactNanoGPTAttention(nn.Module):
      def __init__(self, config):
@@ -80,8 +88,26 @@ class ExactNanoGPTBlock(nn.Module):
          x = x + self.mlp(self.ln_2(x))
          return x

- class NanoGPTModel(PreTrainedModel):
+ class NanoGPTPreTrainedModel(PreTrainedModel):
+     """Base class for NanoGPT models"""
      config_class = NanoGPTConfig
+     base_model_prefix = "transformer"
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["ExactNanoGPTBlock"]
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, nn.LayerNorm):
+             torch.nn.init.zeros_(module.bias)
+             torch.nn.init.ones_(module.weight)
+
+ class NanoGPTModel(NanoGPTPreTrainedModel):
+     """The main NanoGPT model"""

      def __init__(self, config):
          super().__init__(config)
@@ -96,8 +122,15 @@ class NanoGPTModel(PreTrainedModel):
          ))
          self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

+         # Initialize weights
          self.post_init()

+     def get_input_embeddings(self):
+         return self.transformer.wte
+
+     def set_input_embeddings(self, new_embeddings):
+         self.transformer.wte = new_embeddings
+
      def forward(self, input_ids, attention_mask=None, **kwargs):
          device = input_ids.device
          b, t = input_ids.size()
@@ -115,7 +148,8 @@

          return CausalLMOutputWithCrossAttentions(logits=logits)

-     def generate(self, input_ids, max_length=None, max_new_tokens=None, temperature=1.0, top_k=None, do_sample=True, top_p=None, **kwargs):
+     def generate(self, input_ids, max_length=None, max_new_tokens=None, temperature=1.0,
+                  top_k=None, do_sample=True, top_p=None, pad_token_id=None, eos_token_id=None, **kwargs):
          if max_new_tokens is None:
              max_new_tokens = max_length - input_ids.shape[1] if max_length else 50

@@ -153,3 +187,6 @@
              input_ids = torch.cat((input_ids, idx_next), dim=1)

          return input_ids
+
+ # For backward compatibility
+ NanoGPTForCausalLM = NanoGPTModel
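Given the widened `generate` signature above, a short end-to-end sketch of sampling from the model; the repo id and token ids are placeholders, and the extra `pad_token_id`/`eos_token_id` arguments are simply accepted by the signature:

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholder repo id; assumes the checkpoint ships this modeling code.
    model = AutoModelForCausalLM.from_pretrained("chegde/nanogpt", trust_remote_code=True)
    model.eval()

    input_ids = torch.tensor([[0, 1, 2]])  # placeholder token ids
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=20, temperature=0.8, top_k=50)
    print(output_ids.shape)  # (1, prompt length + newly sampled tokens)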