Spaces:
Sleeping
Sleeping
"""
Simple training loop; boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""
| import time | |
| from collections import defaultdict | |
| import torch | |
| from torch.utils.data.dataloader import DataLoader | |
| from mingpt.utils import CfgNode as CN | |
class Trainer:
    """Generic training loop that drives a model over a dataset.

    The model handed to the trainer must provide ``configure_optimizers(config)``
    and must return a ``(logits, loss)`` pair when called as ``model(x, y)``;
    both are used by :meth:`run`. Each item of the dataset must collate into a
    batch of exactly two tensors ``(x, y)``.

    Callbacks can be registered per event name. The only event fired by this
    class is ``'on_batch_end'``, triggered after every optimizer step with the
    trainer itself as the single argument, so callbacks can read ``iter_num``,
    ``iter_dt``, ``loss``, ``model`` and ``optimizer`` off of it.
    """

    @staticmethod
    def get_default_config():
        """Return a new ``CfgNode`` populated with the default training settings."""
        C = CN()
        # device to train on
        C.device = 'auto'
        # dataloader parameters
        C.num_workers = 4
        # optimizer parameters
        C.max_iters = None
        C.batch_size = 64
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1  # only applied on matmul weights
        C.grad_norm_clip = 1.0
        return C

    def __init__(self, config, model, train_dataset):
        """Store the configuration and move ``model`` to the training device.

        Args:
            config: settings object; ``device`` is read here, the remaining
                fields (``batch_size``, ``num_workers``, ``grad_norm_clip``,
                ``max_iters``) are read in :meth:`run`.
            model: the network to train; it is moved to the selected device.
            train_dataset: dataset that :meth:`run` samples batches from.
        """
        self.config = config
        self.model = model
        self.optimizer = None  # created in run() via model.configure_optimizers
        self.train_dataset = train_dataset
        self.callbacks = defaultdict(list)

        # determine the device we'll train on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # variables that will be assigned to trainer class later for logging and etc
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0
        self.loss = None  # set on every iteration of run()

    def add_callback(self, onevent: str, callback):
        """Append ``callback`` to the list of callbacks for event ``onevent``."""
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        """Replace every callback for event ``onevent`` with just ``callback``."""
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        """Call each callback registered for ``onevent``, passing the trainer."""
        # .get() rather than indexing so that an unknown event does not
        # insert an empty entry into the defaultdict
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    def run(self):
        """Train the model until ``config.max_iters`` iterations have completed.

        When ``config.max_iters`` is ``None`` the loop has no termination
        condition and runs until interrupted. ``iter_num`` is reset to zero at
        the start of every call.
        """
        model, config = self.model, self.config

        # setup the optimizer
        self.optimizer = model.configure_optimizers(config)

        # setup the dataloader; sampling is with replacement and the sample
        # count is very large, so the loader is effectively an endless stream
        train_loader = DataLoader(
            self.train_dataset,
            sampler=torch.utils.data.RandomSampler(
                self.train_dataset, replacement=True, num_samples=int(1e10)
            ),
            shuffle=False,
            pin_memory=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
        )

        model.train()
        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)
        while True:

            # fetch the next batch (x, y) and re-init iterator if needed
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = [t.to(self.device) for t in batch]
            x, y = batch

            # forward the model
            logits, self.loss = model(x, y)

            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            self.optimizer.step()

            # callbacks run before iter_num is incremented, so they observe
            # the zero-based index of the batch that just finished
            self.trigger_callbacks('on_batch_end')
            self.iter_num += 1
            tnow = time.time()
            self.iter_dt = tnow - self.iter_time
            self.iter_time = tnow

            # termination conditions
            if config.max_iters is not None and self.iter_num >= config.max_iters:
                break