from typing import Any, Optional, Union, Dict
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import Trainer, LightningModule
import numpy as np
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities import rank_zero_only
from overrides import overrides
import torch
from torch import Tensor
import torch.nn.functional as F
from absl import logging
import time
import os
from torch.optim import Optimizer
from copy import deepcopy
class Queue:
    """Fixed-length FIFO buffer of recent values (used here to track gradient norms)."""

    def __init__(self, max_len=50):
        # Seed with a single value so mean()/std() are well defined before any add().
        self.items = [1]
        self.max_len = max_len
def __len__(self):
return len(self.items)
def add(self, item):
self.items.insert(0, item)
if len(self) > self.max_len:
self.items.pop()
def mean(self):
return np.mean(self.items)
def std(self):
return np.std(self.items)
class GradientClip(Callback):
    """Clip gradients to a fixed norm, or to an adaptive threshold derived from a
    running queue of recent gradient norms when max_grad_norm == 'Q'."""

    def __init__(self, max_grad_norm='Q', Q=None) -> None:
        super().__init__()
        # Avoid a shared mutable default argument: build the queue per instance.
        self.gradnorm_queue = Q if Q is not None else Queue(3000)
        if max_grad_norm == 'Q':
            self.max_grad_norm = max_grad_norm
        else:
            self.max_grad_norm = float(max_grad_norm)
def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        # zero gradients if they are not finite
# if not all([torch.isfinite(t.grad).all() for t in pl_module.parameters()]):
# logging.warning("Gradients are not finite number")
# pl_module.zero_grad()
# return None
        if self.max_grad_norm == 'Q':
            # Adaptive threshold: 1.5 * mean + 2 * std of recently observed gradient norms.
            max_grad_norm = 1.5 * self.gradnorm_queue.mean() + 2 * self.gradnorm_queue.std()
            max_grad_norm = max_grad_norm.item()
        else:
            max_grad_norm = self.max_grad_norm
        grad_norm = torch.nn.utils.clip_grad_norm_(
            pl_module.parameters(), max_norm=max_grad_norm, norm_type=2.0
        )
        if self.max_grad_norm == 'Q':
            # Record the clipped norm so outliers do not inflate the running statistics.
            if float(grad_norm) > max_grad_norm:
                self.gradnorm_queue.add(float(max_grad_norm))
            else:
                self.gradnorm_queue.add(float(grad_norm))
if float(grad_norm) > max_grad_norm:
logging.info(
f"Clipped gradient with value {grad_norm:.1f} "
f"while allowed {max_grad_norm:.1f}",
)
pl_module.log_dict(
{
"grad_norm": grad_norm.item(),
'max_grad_norm': max_grad_norm,
},
on_step=True,
prog_bar=False,
logger=True,
batch_size=pl_module.cfg.train.batch_size,
)
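# Illustrative usage of the adaptive clipping mode -- a minimal sketch, not part of this
# module; `MyLitModule` and `train_loader` are placeholder names. Note that the
# LightningModule is assumed to expose cfg.train.batch_size, which log_dict above uses:
#
#     clip_cb = GradientClip(max_grad_norm='Q')   # threshold follows recent gradient norms
#     trainer = Trainer(callbacks=[clip_cb])
#     trainer.fit(MyLitModule(), train_loader)
#
# Passing a number instead, e.g. GradientClip(max_grad_norm=1.0), clips to that fixed norm.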
class DebugCallback(Callback):
    # Sanity checks (finite gradients) and per-stage timing logs for debugging.
    def __init__(self) -> None:
        super().__init__()
def on_train_batch_start(
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
) -> None:
super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
self._start_time = time.time()
def on_before_backward(
self, trainer: Trainer, pl_module: LightningModule, loss: Tensor
) -> None:
super().on_before_backward(trainer, pl_module, loss)
_cur_time = time.time()
logging.info(
f"from trainbatch start to before backward took {_cur_time - self._start_time} secs"
)
def on_after_backward(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_after_backward(trainer, pl_module)
_cur_time = time.time()
logging.info(
f"from trainbatch start to after backward took {_cur_time - self._start_time} secs"
)
    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        super().on_before_optimizer_step(trainer, pl_module, optimizer)
        # Fail fast if any gradient is non-finite (NaN/Inf), reporting the offending parameter.
        for name, param in pl_module.named_parameters():
            if param.grad is not None and not torch.isfinite(param.grad).all():
                print(name, param.grad)
                raise ValueError("gradient is not finite number")
        _cur_time = time.time()
        logging.info(
            f"from trainbatch start to before optimizer step took {_cur_time - self._start_time} secs"
        )
def on_before_zero_grad(
self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
) -> None:
super().on_before_zero_grad(trainer, pl_module, optimizer)
_cur_time = time.time()
logging.info(
f"from trainbatch start to before zero grad took {_cur_time - self._start_time} secs"
)
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
_cur_time = time.time()
logging.info(f"train batch took {_cur_time - self._start_time} secs")
class NormalizerCallback(Callback):
    # Normalize (scale down) positional inputs before each batch; data outputs produced in
    # this normalized space should be scaled back correspondingly downstream.
def __init__(self, normalizer_dict) -> None:
super().__init__()
self.normalizer_dict = normalizer_dict
self.pos_normalizer = torch.tensor(self.normalizer_dict.pos, dtype=torch.float32)
self.device = None
def quantize(self, pos, h):
# quantize the latent space
h = F.one_hot(torch.argmax(h, dim=-1), num_classes=h.shape[-1])
return pos, h
def on_train_batch_start(
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
) -> None:
super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
if self.device is None:
self.device = batch.protein_pos.device
self.pos_normalizer = self.pos_normalizer.to(self.device)
batch.protein_pos = batch.protein_pos / self.pos_normalizer
batch.ligand_pos = batch.ligand_pos / self.pos_normalizer
# batch.x = batch.x / self.normalizer_dict.one_hot
# #batch.charges = batch.charges / self.normalizer_dict.charges - 1
# # print(batch.charges)
# batch.charges = (2*batch.charges - 1)/self.normalizer_dict.charges - 1 #normalizer as k_c
# print(batch.charges)
def on_validation_batch_start(
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
) -> None:
super().on_validation_batch_start(trainer, pl_module, batch, batch_idx)
if self.device is None:
self.device = batch.protein_pos.device
self.pos_normalizer = self.pos_normalizer.to(self.device)
batch.protein_pos = batch.protein_pos / self.pos_normalizer
batch.ligand_pos = batch.ligand_pos / self.pos_normalizer
def on_test_batch_start(
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
) -> None:
super().on_test_batch_start(trainer, pl_module, batch, batch_idx)
if self.device is None:
self.device = batch.protein_pos.device
self.pos_normalizer = self.pos_normalizer.to(self.device)
batch.protein_pos = batch.protein_pos / self.pos_normalizer
batch.ligand_pos = batch.ligand_pos / self.pos_normalizer
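# Illustrative construction of the normalizer callback -- a minimal sketch; the OmegaConf
# import and the value 2.0 are placeholders (any object exposing a `.pos` attribute works):
#
#     from omegaconf import OmegaConf
#     norm_cb = NormalizerCallback(OmegaConf.create({"pos": 2.0}))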
class RecoverCallback(Callback):
def __init__(self, latest_ckpt, recover_trigger_loss=1e3, resume=False) -> None:
super().__init__()
self.latest_ckpt = latest_ckpt
self.recover_trigger_loss = recover_trigger_loss
self.resume = resume
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
super().setup(trainer, pl_module, stage)
if os.path.exists(self.latest_ckpt) and self.resume:
print(f"recover from checkpoint: {self.latest_ckpt}")
checkpoint = torch.load(self.latest_ckpt)
pl_module.load_state_dict(checkpoint["state_dict"])
# pl_module.load_from_checkpoint(self.latest_ckpt)
elif not os.path.exists(self.latest_ckpt) and self.resume:
print(
f"checkpoint {self.latest_ckpt} not found, training from scratch"
)
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
if "loss" not in outputs:
return None
if outputs["loss"] > self.recover_trigger_loss:
logging.warning(
f"loss too large: {outputs}\n recovering from checkpoint: {self.latest_ckpt}"
)
if os.path.exists(self.latest_ckpt):
checkpoint = torch.load(self.latest_ckpt)
pl_module.load_state_dict(checkpoint["state_dict"])
else:
for layer in pl_module.children():
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
logging.warning(
f"checkpoint {self.latest_ckpt} not found, training from scratch"
)
class EMACallback(pl.Callback):
"""Implements EMA (exponential moving average) to any kind of model.
EMA weights will be used during validation and stored separately from original model weights.
How to use EMA:
- Sometimes, last EMA checkpoint isn't the best as EMA weights metrics can show long oscillations in time. See
https://github.com/rwightman/pytorch-image-models/issues/102
- Batch Norm layers and likely any other type of norm layers doesn't need to be updated at the end. See
discussions in: https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 and
https://github.com/rwightman/pytorch-image-models/issues/224
- For object detection, SWA usually works better. See https://github.com/timgaripov/swa/issues/16
Implementation detail:
- See EMA in Pytorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914
- When multi gpu, we broadcast ema weights and the original weights in order to only hold 1 copy in memory.
This is specially relevant when storing EMA weights on CPU + pinned memory as pinned memory is a limited
resource. In addition, we want to avoid duplicated operations in ranks != 0 to reduce jitter and improve
performance.
"""
def __init__(
self,
decay: float = 0.9999,
ema_device: Optional[Union[torch.device, str]] = None,
pin_memory=True,
):
super().__init__()
self.decay = decay
        self.ema_device: Optional[str] = (
            f"{ema_device}" if ema_device else None
        )  # perform EMA on a different device from the model
self.ema_pin_memory = (
pin_memory if torch.cuda.is_available() else False
) # Only works if CUDA is available
self.ema_state_dict: Dict[str, torch.Tensor] = {}
self.original_state_dict = {}
self._ema_state_dict_ready = False
@staticmethod
def get_state_dict(pl_module: pl.LightningModule):
"""Returns state dictionary from pl_module. Override if you want filter some parameters and/or buffers out.
For example, in pl_module has metrics, you don't want to return their parameters.
code:
# Only consider modules that can be seen by optimizers. Lightning modules can have others nn.Module attached
# like losses, metrics, etc.
patterns_to_ignore = ("metrics1", "metrics2")
return dict(filter(lambda i: i[0].startswith(patterns), pl_module.state_dict().items()))
"""
return pl_module.state_dict()
@overrides
def on_train_start(
self, trainer: "pl.Trainer", pl_module: pl.LightningModule
) -> None:
# Only keep track of EMA weights in rank zero.
if not self._ema_state_dict_ready and pl_module.global_rank == 0:
self.ema_state_dict = deepcopy(self.get_state_dict(pl_module))
if self.ema_device:
self.ema_state_dict = {
k: tensor.to(device=self.ema_device)
for k, tensor in self.ema_state_dict.items()
}
if self.ema_device == "cpu" and self.ema_pin_memory:
self.ema_state_dict = {
k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()
}
self._ema_state_dict_ready = True
@rank_zero_only
def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs
) -> None:
# Update EMA weights
with torch.no_grad():
for key, value in self.get_state_dict(pl_module).items():
ema_value = self.ema_state_dict[key]
ema_value.copy_(
self.decay * ema_value + (1.0 - self.decay) * value,
non_blocking=True,
)
@overrides
def on_validation_start(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> None:
if not self._ema_state_dict_ready:
            return  # Skip Lightning's sanity validation check if no EMA weights have been loaded from a checkpoint.
self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
trainer.strategy.broadcast(self.ema_state_dict, 0)
        assert self.ema_state_dict.keys() == self.original_state_dict.keys(), (
            f"Some keys are missing in the broadcast EMA state dictionary. "
            f"They are: {self.original_state_dict.keys() - self.ema_state_dict.keys()}"
        )
pl_module.load_state_dict(self.ema_state_dict, strict=False)
if pl_module.global_rank > 0:
            # Free the EMA state dict on non-zero ranks; on rank 0 it may live in pinned CPU memory.
self.ema_state_dict = {}
@overrides
def on_validation_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
if not self._ema_state_dict_ready:
            return  # Skip Lightning's sanity validation check if no EMA weights have been loaded from a checkpoint.
# Replace EMA weights with training weights
pl_module.load_state_dict(self.original_state_dict, strict=False)
@overrides
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self.on_validation_start(trainer, pl_module)
@overrides
def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
self.on_validation_end(trainer, pl_module)
@overrides
def on_save_checkpoint(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
checkpoint: Dict[str, Any],
) -> None:
checkpoint["ema_state_dict"] = self.ema_state_dict
checkpoint["_ema_state_dict_ready"] = self._ema_state_dict_ready
# return {"ema_state_dict": self.ema_state_dict, "_ema_state_dict_ready": self._ema_state_dict_ready}
@overrides
def on_load_checkpoint(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
checkpoint: Dict[str, Any],
) -> None:
if checkpoint is None:
self._ema_state_dict_ready = False
else:
self._ema_state_dict_ready = checkpoint["_ema_state_dict_ready"]
self.ema_state_dict = checkpoint["ema_state_dict"]
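# Putting it together -- an illustrative sketch of composing these callbacks into a single
# Trainer; `cfg`, `MyLitModule`, and `train_loader` are placeholders, not defined here:
#
#     callbacks = [
#         RecoverCallback(latest_ckpt="last.ckpt", recover_trigger_loss=1e3, resume=True),
#         NormalizerCallback(cfg.data.normalizer_dict),
#         GradientClip(max_grad_norm='Q'),
#         EMACallback(decay=0.9999, ema_device="cpu"),
#     ]
#     trainer = pl.Trainer(callbacks=callbacks)
#     trainer.fit(MyLitModule(), train_loader)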