File size: 2,382 Bytes
199c8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

# Convert a Pytorch model to a Hugging Face model

import torch.nn as nn

from huggingface_hub import PyTorchModelHubMixin

from revq.models.backbone.diffusion import Encoder, Decoder
from revq.models.quantizer import VectorQuantizer, VectorQuantizerSinkhorn
from revq.losses.aeloss_disc import AELossWithDisc
from revq.models.vqgan import VQModel

class VQModelHF(nn.Module, PyTorchModelHubMixin):
    """Hugging Face Hub wrapper around :class:`VQModel`.

    Assembles the encoder/decoder backbones, a vector quantizer, and the
    adversarial autoencoder loss from plain config dicts, then delegates
    ``encode``/``decode``/``forward`` to the wrapped :class:`VQModel`.
    Inheriting from ``PyTorchModelHubMixin`` enables
    ``save_pretrained`` / ``from_pretrained`` round-trips with the Hub.
    """

    def __init__(self,
            encoder: dict = None,
            decoder: dict = None,
            loss: dict = None,
            quantize: dict = None,
            quantize_type: str = "optvq",
            ckpt_path: str = None,
            ignore_keys=None,
            image_key="image",
            colorize_nlabels=None,
            monitor=None,
            use_connector: bool = True,
        ):
        """Build the wrapped ``VQModel`` from config dictionaries.

        Args:
            encoder: kwargs for the diffusion ``Encoder`` backbone.
            decoder: kwargs for the diffusion ``Decoder`` backbone.
            loss: kwargs for ``AELossWithDisc``.
            quantize: kwargs for the quantizer chosen by ``quantize_type``.
            quantize_type: ``"optvq"`` (``VectorQuantizerSinkhorn``) or
                ``"basevq"`` (``VectorQuantizer``).
            ckpt_path: optional checkpoint path forwarded to ``VQModel``.
            ignore_keys: state-dict keys to skip when loading the checkpoint
                (forwarded to ``VQModel``).
            image_key: batch key of the input image (forwarded to ``VQModel``).
            colorize_nlabels: forwarded to ``VQModel`` — semantics defined there.
            monitor: forwarded to ``VQModel`` — presumably a metric name; see
                ``VQModel`` for details.
            use_connector: forwarded to ``VQModel``.

        Raises:
            ValueError: if ``quantize_type`` is not recognized.
        """
        super(VQModelHF, self).__init__()
        # NOTE(review): the original used mutable defaults ({} and []),
        # which are shared across calls; use None sentinels instead and
        # create fresh containers here. Callers are unaffected.
        encoder = {} if encoder is None else encoder
        decoder = {} if decoder is None else decoder
        loss = {} if loss is None else loss
        quantize = {} if quantize is None else quantize
        ignore_keys = [] if ignore_keys is None else ignore_keys

        encoder = Encoder(**encoder)
        decoder = Decoder(**decoder)
        quantizer = self.setup_quantizer(quantize, quantize_type)
        loss = AELossWithDisc(**loss)

        self.model = VQModel(
            encoder=encoder,
            decoder=decoder,
            loss=loss,
            quantize=quantizer,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            image_key=image_key,
            colorize_nlabels=colorize_nlabels,
            monitor=monitor,
            use_connector=use_connector,
        )

    def setup_quantizer(self, quantizer_config: dict, quantize_type: str):
        """Instantiate the quantizer selected by ``quantize_type``.

        Args:
            quantizer_config: kwargs for the quantizer constructor.
            quantize_type: ``"optvq"`` or ``"basevq"``.

        Returns:
            The constructed quantizer module.

        Raises:
            ValueError: for any other ``quantize_type``.
        """
        if quantize_type == "optvq":
            quantizer = VectorQuantizerSinkhorn(**quantizer_config)
        elif quantize_type == "basevq":
            quantizer = VectorQuantizer(**quantizer_config)
        else:
            raise ValueError(f"Unknown quantizer type: {quantize_type}")
        return quantizer

    def encode(self, x):
        """Delegate to the wrapped ``VQModel.encode``."""
        return self.model.encode(x)

    def decode(self, x):
        """Delegate to the wrapped ``VQModel.decode``."""
        return self.model.decode(x)

    def forward(self, x):
        """Encode ``x``, then reconstruct from the quantized code.

        Returns:
            Tuple ``(quant, rec)`` — the quantized latent (first element
            of ``encode``'s return tuple) and the decoded reconstruction.
        """
        quant, *_ = self.encode(x)
        rec = self.decode(quant)
        return quant, rec