import os
import time
from copy import deepcopy
from typing import Any, Dict, Optional, Union

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from absl import logging
from overrides import overrides
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.optim import Optimizer


class Queue:
    """Fixed-length FIFO buffer of recent gradient norms with running mean/std."""

    def __init__(self, max_len=50):
        # Seeded with a single value so mean()/std() are defined before any
        # gradient norm has been recorded.
        self.items = [1]
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        self.items.insert(0, item)
        if len(self) > self.max_len:
            self.items.pop()

    def mean(self):
        return np.mean(self.items)

    def std(self):
        return np.std(self.items)


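# Illustrative sketch of the adaptive threshold computed by GradientClip below:
# recent gradient norms are pushed into a Queue, and the clipping value is
# 1.5 * mean + 2 * std of that history (the numbers here are made up):
#
#     q = Queue(max_len=3000)
#     q.add(2.0)
#     q.add(4.0)
#     threshold = 1.5 * q.mean() + 2 * q.std()

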
class GradientClip(Callback):
    """Clip gradients before each optimizer step.

    With ``max_grad_norm='Q'`` the clipping threshold adapts to the history of
    recent gradient norms (1.5 * mean + 2 * std); otherwise a fixed threshold
    is used.
    """

    def __init__(self, max_grad_norm='Q', Q=None) -> None:
        super().__init__()
        # Avoid a mutable default argument: build the queue per instance.
        self.gradnorm_queue = Q if Q is not None else Queue(max_len=3000)
        if max_grad_norm == 'Q':
            self.max_grad_norm = max_grad_norm
        else:
            self.max_grad_norm = float(max_grad_norm)

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        if self.max_grad_norm == 'Q':
            max_grad_norm = 1.5 * self.gradnorm_queue.mean() + 2 * self.gradnorm_queue.std()
            max_grad_norm = max_grad_norm.item()
        else:
            max_grad_norm = self.max_grad_norm

        grad_norm = torch.nn.utils.clip_grad_norm_(
            pl_module.parameters(), max_norm=max_grad_norm, norm_type=2.0
        )

        if self.max_grad_norm == 'Q':
            # Record the (possibly clipped) norm so outliers do not inflate the
            # adaptive threshold.
            if float(grad_norm) > max_grad_norm:
                self.gradnorm_queue.add(float(max_grad_norm))
            else:
                self.gradnorm_queue.add(float(grad_norm))

        if float(grad_norm) > max_grad_norm:
            logging.info(
                f"Clipped gradient with value {grad_norm:.1f} "
                f"while allowed {max_grad_norm:.1f}"
            )
        pl_module.log_dict(
            {
                "grad_norm": grad_norm.item(),
                "max_grad_norm": max_grad_norm,
            },
            on_step=True,
            prog_bar=False,
            logger=True,
            batch_size=pl_module.cfg.train.batch_size,
        )


class DebugCallback(Callback):
    """Check gradient health and log per-phase timings of each training batch."""

    def __init__(self) -> None:
        super().__init__()

    def on_train_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
    ) -> None:
        super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
        self._start_time = time.time()

    def on_before_backward(
        self, trainer: Trainer, pl_module: LightningModule, loss: Tensor
    ) -> None:
        super().on_before_backward(trainer, pl_module, loss)
        _cur_time = time.time()
        logging.info(
            f"from train batch start to before backward took {_cur_time - self._start_time} secs"
        )

    def on_after_backward(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_after_backward(trainer, pl_module)
        _cur_time = time.time()
        logging.info(
            f"from train batch start to after backward took {_cur_time - self._start_time} secs"
        )

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        super().on_before_optimizer_step(trainer, pl_module, optimizer)
        # Fail fast if any gradient contains NaN/Inf before the optimizer step.
        grads = [p.grad for p in pl_module.parameters() if p.grad is not None]
        if not all(torch.isfinite(g).all() for g in grads):
            for name, p in pl_module.named_parameters():
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    print(name, p.grad)
            raise ValueError("gradient is not a finite number")
        _cur_time = time.time()
        logging.info(
            f"from train batch start to before optimizer step took {_cur_time - self._start_time} secs"
        )

    def on_before_zero_grad(
        self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
    ) -> None:
        super().on_before_zero_grad(trainer, pl_module, optimizer)
        _cur_time = time.time()
        logging.info(
            f"from train batch start to before zero grad took {_cur_time - self._start_time} secs"
        )

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        _cur_time = time.time()
        logging.info(f"train batch took {_cur_time - self._start_time} secs")


class NormalizerCallback(Callback):
    """Scale protein/ligand coordinates by a fixed positional normalizer before each batch."""

    def __init__(self, normalizer_dict) -> None:
        super().__init__()
        self.normalizer_dict = normalizer_dict
        self.pos_normalizer = torch.tensor(self.normalizer_dict.pos, dtype=torch.float32)
        self.device = None

    def quantize(self, pos, h):
        # Convert soft atom-type scores to a hard one-hot encoding; positions pass through.
        h = F.one_hot(torch.argmax(h, dim=-1), num_classes=h.shape[-1])
        return pos, h

    def _normalize(self, batch) -> None:
        # Lazily move the normalizer to the batch device, then rescale coordinates in place.
        if self.device is None:
            self.device = batch.protein_pos.device
            self.pos_normalizer = self.pos_normalizer.to(self.device)
        batch.protein_pos = batch.protein_pos / self.pos_normalizer
        batch.ligand_pos = batch.ligand_pos / self.pos_normalizer

    def on_train_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
    ) -> None:
        super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
        self._normalize(batch)

    def on_validation_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
    ) -> None:
        super().on_validation_batch_start(trainer, pl_module, batch, batch_idx)
        self._normalize(batch)

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
    ) -> None:
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx)
        self._normalize(batch)


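# A minimal construction sketch for NormalizerCallback above. `SimpleNamespace`
# stands in for whatever config object the project actually uses; the only
# requirement this callback places on it is a `.pos` attribute convertible to a
# tensor (the value below is illustrative, not prescriptive):
#
#     from types import SimpleNamespace
#     normalizer_cb = NormalizerCallback(SimpleNamespace(pos=[2.0, 2.0, 2.0]))

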
class RecoverCallback(Callback):
    """Reload model weights from the latest checkpoint when resuming or when the loss blows up."""

    def __init__(self, latest_ckpt, recover_trigger_loss=1e3, resume=False) -> None:
        super().__init__()
        self.latest_ckpt = latest_ckpt
        self.recover_trigger_loss = recover_trigger_loss
        self.resume = resume

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        super().setup(trainer, pl_module, stage)
        if os.path.exists(self.latest_ckpt) and self.resume:
            print(f"recovering from checkpoint: {self.latest_ckpt}")
            checkpoint = torch.load(self.latest_ckpt)
            pl_module.load_state_dict(checkpoint["state_dict"])
        elif not os.path.exists(self.latest_ckpt) and self.resume:
            print(f"checkpoint {self.latest_ckpt} not found, training from scratch")

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        if "loss" not in outputs:
            return

        if outputs["loss"] > self.recover_trigger_loss:
            logging.warning(
                f"loss too large: {outputs}\n recovering from checkpoint: {self.latest_ckpt}"
            )
            if os.path.exists(self.latest_ckpt):
                checkpoint = torch.load(self.latest_ckpt)
                pl_module.load_state_dict(checkpoint["state_dict"])
            else:
                # No checkpoint to fall back on: re-initialize every layer that supports it.
                for layer in pl_module.children():
                    if hasattr(layer, "reset_parameters"):
                        layer.reset_parameters()
                logging.warning(
                    f"checkpoint {self.latest_ckpt} not found, training from scratch"
                )


class EMACallback(pl.Callback):
    """Implements EMA (exponential moving average) for any kind of model.

    EMA weights are used during validation and stored separately from the original model weights.

    How to use EMA:
        - Sometimes the last EMA checkpoint isn't the best, as EMA weight metrics can show long
          oscillations over time. See https://github.com/rwightman/pytorch-image-models/issues/102
        - Batch Norm layers (and likely any other type of norm layer) don't need to be updated at the end.
          See the discussions in https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088
          and https://github.com/rwightman/pytorch-image-models/issues/224
        - For object detection, SWA usually works better. See https://github.com/timgaripov/swa/issues/16

    Implementation details:
        - See EMA in PyTorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914
        - With multiple GPUs, the EMA weights and the original weights are broadcast so that only one copy
          is held in memory. This is especially relevant when storing EMA weights on CPU + pinned memory,
          as pinned memory is a limited resource. In addition, duplicated operations on ranks != 0 are
          avoided to reduce jitter and improve performance.
    """

    def __init__(
        self,
        decay: float = 0.9999,
        ema_device: Optional[Union[torch.device, str]] = None,
        pin_memory=True,
    ):
        super().__init__()
        self.decay = decay
        # Normalize the device to a string, or None to keep EMA weights on the model's device.
        self.ema_device: Optional[str] = f"{ema_device}" if ema_device else None
        # Pinned memory only makes sense when CUDA is available.
        self.ema_pin_memory = pin_memory if torch.cuda.is_available() else False
        self.ema_state_dict: Dict[str, torch.Tensor] = {}
        self.original_state_dict = {}
        self._ema_state_dict_ready = False

    @staticmethod
    def get_state_dict(pl_module: pl.LightningModule):
        """Returns the state dictionary from pl_module. Override this if you want to filter some parameters
        and/or buffers out. For example, if pl_module has metrics, you probably don't want to return their
        parameters.

        Example::

            # Only consider modules that can be seen by optimizers. Lightning modules can have other
            # nn.Modules attached, such as losses or metrics.
            patterns_to_ignore = ("metrics1", "metrics2")
            return dict(
                filter(lambda i: not i[0].startswith(patterns_to_ignore), pl_module.state_dict().items())
            )
        """
        return pl_module.state_dict()

    @overrides
    def on_train_start(
        self, trainer: "pl.Trainer", pl_module: pl.LightningModule
    ) -> None:
        # Only rank 0 keeps an EMA copy; other ranks receive it via broadcast at validation time.
        if not self._ema_state_dict_ready and pl_module.global_rank == 0:
            self.ema_state_dict = deepcopy(self.get_state_dict(pl_module))
            if self.ema_device:
                self.ema_state_dict = {
                    k: tensor.to(device=self.ema_device)
                    for k, tensor in self.ema_state_dict.items()
                }

            if self.ema_device == "cpu" and self.ema_pin_memory:
                self.ema_state_dict = {
                    k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()
                }

        self._ema_state_dict_ready = True

    @rank_zero_only
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs
    ) -> None:
        with torch.no_grad():
            for key, value in self.get_state_dict(pl_module).items():
                ema_value = self.ema_state_dict[key]
                ema_value.copy_(
                    self.decay * ema_value + (1.0 - self.decay) * value,
                    non_blocking=True,
                )

    @overrides
    def on_validation_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        if not self._ema_state_dict_ready:
            return

        self.original_state_dict = deepcopy(self.get_state_dict(pl_module))

        # Broadcast returns the rank-0 EMA weights on every rank; keep the result.
        self.ema_state_dict = trainer.strategy.broadcast(self.ema_state_dict, 0)

        assert self.ema_state_dict.keys() == self.original_state_dict.keys(), (
            f"There are some keys missing in the broadcast EMA state dictionary. "
            f"They are: {self.original_state_dict.keys() - self.ema_state_dict.keys()}"
        )
        pl_module.load_state_dict(self.ema_state_dict, strict=False)

        if pl_module.global_rank > 0:
            # Drop the EMA copy on non-zero ranks so only one copy stays in memory.
            self.ema_state_dict = {}

    @overrides
    def on_validation_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        if not self._ema_state_dict_ready:
            return

        pl_module.load_state_dict(self.original_state_dict, strict=False)

    @overrides
    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_validation_start(trainer, pl_module)

    @overrides
    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_validation_end(trainer, pl_module)

    @overrides
    def on_save_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        checkpoint["ema_state_dict"] = self.ema_state_dict
        checkpoint["_ema_state_dict_ready"] = self._ema_state_dict_ready

    @overrides
    def on_load_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        if checkpoint is None:
            self._ema_state_dict_ready = False
        else:
            self._ema_state_dict_ready = checkpoint["_ema_state_dict_ready"]
            self.ema_state_dict = checkpoint["ema_state_dict"]
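

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way these callbacks might be combined
# in a Trainer. `MyLightningModule`, `my_datamodule`, and `cfg` are hypothetical
# placeholders, not part of this module. Note that GradientClip assumes the
# LightningModule exposes `cfg.train.batch_size` for logging.
#
#     callbacks = [
#         RecoverCallback(latest_ckpt="ckpts/latest.ckpt", recover_trigger_loss=1e3, resume=True),
#         GradientClip(max_grad_norm="Q"),  # adaptive clipping threshold
#         NormalizerCallback(normalizer_dict=cfg.data.normalizer_dict),
#         EMACallback(decay=0.9999, ema_device="cpu"),
#     ]
#     trainer = pl.Trainer(max_epochs=cfg.train.epochs, callbacks=callbacks)
#     trainer.fit(MyLightningModule(cfg), datamodule=my_datamodule)
# ---------------------------------------------------------------------------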