import copy
import warnings

import numpy as np  # only used by the commented-out 'expmin_milestone' scheduler below
import torch
from torch_geometric.data import Data, Batch

from core.utils.warmup import GradualWarmupScheduler  # only used by the commented-out 'warmup_plateau' scheduler below


# Exponential LR scheduler that decays by `gamma` per step but never drops below `min_lr`.
class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
    def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
        self.gamma = gamma
        self.min_lr = min_lr
        super().__init__(optimizer, gamma, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return self.base_lrs
        # Decay the current lr by gamma each step, but never below min_lr.
        return [max(group['lr'] * self.gamma, self.min_lr)
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
                for base_lr in self.base_lrs]
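
# Minimal usage sketch (the hyperparameter values below are illustrative, not
# taken from this repo's configs): the lr is multiplied by `gamma` on every
# step() call and clamped so it never falls below `min_lr`.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = ExponentialLR_with_minLr(optimizer, gamma=0.998, min_lr=1e-5)
#   for _ in range(num_steps):
#       ...                      # forward / backward / optimizer.step()
#       scheduler.step()         # lr_t = max(lr_0 * 0.998 ** t, 1e-5)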


def repeat_data(data: Data, num_repeat) -> Batch:
    # Collate `num_repeat` deep copies of a single graph into one Batch.
    datas = [copy.deepcopy(data) for _ in range(num_repeat)]
    return Batch.from_data_list(datas)
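
# Example (sketch): build a Batch containing 4 identical copies of one graph.
#   g = Data(x=torch.randn(5, 3), edge_index=torch.randint(0, 5, (2, 8)))
#   batch = repeat_data(g, 4)    # batch.num_graphs == 4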


def repeat_batch(batch: Batch, num_repeat) -> Batch:
    # Unpack the batch into individual graphs and tile the list `num_repeat` times.
    datas = batch.to_data_list()
    new_data = []
    for _ in range(num_repeat):
        new_data += copy.deepcopy(datas)
    return Batch.from_data_list(new_data)
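
# Example (sketch): tile an existing batch, e.g. to draw several samples per
# input graph in a single forward pass.
#   big_batch = repeat_batch(batch, 3)   # num_graphs grows by a factor of 3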


def inf_iterator(iterable):
    # Yield items from `iterable` forever, restarting a fresh iterator whenever
    # it is exhausted (useful for step-based training loops over a DataLoader).
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
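
# Example (sketch, assuming a torch DataLoader named `loader`): iterate by
# global step instead of by epoch.
#   data_iter = inf_iterator(loader)
#   for step in range(max_iters):
#       batch = next(data_iter)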


def get_optimizer(cfg, model):
    if cfg.type == 'adam':
        return torch.optim.Adam(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay,
            betas=(cfg.beta1, cfg.beta2,)
        )
    else:
        raise NotImplementedError('Optimizer not supported: %s' % cfg.type)
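
# Example (sketch): `cfg` is the optimizer sub-config of the training config;
# for illustration, any attribute-style object with the expected fields works.
#   from types import SimpleNamespace
#   opt_cfg = SimpleNamespace(type='adam', lr=1e-3, weight_decay=0.0,
#                             beta1=0.9, beta2=0.999)
#   optimizer = get_optimizer(opt_cfg, model)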


def get_scheduler(train_cfg, optimizer):
    cfg = train_cfg.scheduler
    if cfg.type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=cfg.max_iters,
            eta_min=cfg.min_lr,
        )
        def get_last_lr():
            return scheduler.get_last_lr()[0]
        return {
            'scheduler': scheduler,
            'interval': 'step',
        }, get_last_lr
    elif cfg.type == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.factor,
            patience=cfg.patience,
            min_lr=cfg.min_lr
        )

        def get_last_lr():
            return optimizer.param_groups[0]['lr']
        return {
            'scheduler': scheduler,
            'monitor': 'val/recon_loss',
            'interval': 'step',
            'frequency': train_cfg.val_freq,
        }, get_last_lr
    # elif cfg.type == 'warmup_plateau':
    #     return GradualWarmupScheduler(
    #         optimizer,
    #         multiplier=cfg.multiplier,
    #         total_epoch=cfg.total_epoch,
    #         after_scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(
    #             optimizer,
    #             factor=cfg.factor,
    #             patience=cfg.patience,
    #             min_lr=cfg.min_lr
    #         )
    #     )
    # elif cfg.type == 'expmin':
    #     return ExponentialLR_with_minLr(
    #         optimizer,
    #         gamma=cfg.factor,
    #         min_lr=cfg.min_lr,
    #     )
    # elif cfg.type == 'expmin_milestone':
    #     gamma = np.exp(np.log(cfg.factor) / cfg.milestone)
    #     return ExponentialLR_with_minLr(
    #         optimizer,
    #         gamma=gamma,
    #         min_lr=cfg.min_lr,
    #     )
    else:
        raise NotImplementedError('Scheduler not supported: %s' % cfg.type)
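

# Example (sketch): the returned dict matches PyTorch Lightning's lr scheduler
# dict convention ('scheduler' / 'interval' / 'monitor' / 'frequency'), so a
# typical `configure_optimizers` could look like the following, assuming the
# optimizer config lives at `train_cfg.optimizer`:
#   optimizer = get_optimizer(train_cfg.optimizer, model)
#   sched_dict, get_last_lr = get_scheduler(train_cfg, optimizer)
#   return {'optimizer': optimizer, 'lr_scheduler': sched_dict}
# `get_last_lr()` can then be called during training to log the current lr.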