Spaces:
Sleeping
Sleeping
File size: 2,382 Bytes
199c8cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
# Convert a Pytorch model to a Hugging Face model
from typing import Optional

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from revq.losses.aeloss_disc import AELossWithDisc
from revq.models.backbone.diffusion import Encoder, Decoder
from revq.models.quantizer import VectorQuantizer, VectorQuantizerSinkhorn
from revq.models.vqgan import VQModel
class VQModelHF(nn.Module, PyTorchModelHubMixin):
    """Hugging Face Hub wrapper around :class:`VQModel`.

    Assembles the encoder, decoder, quantizer, and adversarial loss from
    their keyword-argument config dicts, and mixes in
    ``PyTorchModelHubMixin`` so the resulting model can be pushed to and
    loaded from the Hugging Face Hub.
    """

    def __init__(self,
                 encoder: Optional[dict] = None,
                 decoder: Optional[dict] = None,
                 loss: Optional[dict] = None,
                 quantize: Optional[dict] = None,
                 quantize_type: str = "optvq",
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Optional[list] = None,
                 image_key: str = "image",
                 colorize_nlabels=None,
                 monitor=None,
                 use_connector: bool = True,
                 ):
        """Build the sub-modules from config dicts and wrap them in a VQModel.

        Args:
            encoder: kwargs forwarded to ``Encoder`` (empty dict if omitted).
            decoder: kwargs forwarded to ``Decoder`` (empty dict if omitted).
            loss: kwargs forwarded to ``AELossWithDisc`` (empty dict if omitted).
            quantize: kwargs forwarded to the quantizer constructor selected
                by ``quantize_type`` (empty dict if omitted).
            quantize_type: ``"optvq"`` (Sinkhorn / optimal-transport VQ) or
                ``"basevq"`` (plain vector quantizer).
            ckpt_path: optional checkpoint path; loading is handled by
                ``VQModel``.
            ignore_keys: state-dict keys to skip when loading ``ckpt_path``
                (empty list if omitted).
            image_key: key under which ``VQModel`` looks up images in a batch.
            colorize_nlabels: passed through to ``VQModel`` unchanged.
            monitor: passed through to ``VQModel`` unchanged.
            use_connector: passed through to ``VQModel`` unchanged.

        Raises:
            ValueError: if ``quantize_type`` is not recognized.
        """
        super().__init__()
        # The original signature used mutable defaults ({} / []), which are
        # shared across all calls; normalize None -> fresh containers instead.
        encoder = {} if encoder is None else encoder
        decoder = {} if decoder is None else decoder
        loss = {} if loss is None else loss
        quantize = {} if quantize is None else quantize
        ignore_keys = [] if ignore_keys is None else ignore_keys

        encoder = Encoder(**encoder)
        decoder = Decoder(**decoder)
        quantizer = self.setup_quantizer(quantize, quantize_type)
        loss = AELossWithDisc(**loss)
        self.model = VQModel(
            encoder=encoder,
            decoder=decoder,
            loss=loss,
            quantize=quantizer,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            image_key=image_key,
            colorize_nlabels=colorize_nlabels,
            monitor=monitor,
            use_connector=use_connector,
        )

    def setup_quantizer(self, quantizer_config: dict, quantize_type: str):
        """Instantiate the quantizer named by ``quantize_type``.

        Args:
            quantizer_config: kwargs for the quantizer constructor.
            quantize_type: ``"optvq"`` or ``"basevq"``.

        Returns:
            The constructed quantizer module.

        Raises:
            ValueError: for an unknown ``quantize_type``.
        """
        if quantize_type == "optvq":
            quantizer = VectorQuantizerSinkhorn(**quantizer_config)
        elif quantize_type == "basevq":
            quantizer = VectorQuantizer(**quantizer_config)
        else:
            raise ValueError(f"Unknown quantizer type: {quantize_type}")
        return quantizer

    def encode(self, x):
        """Delegate encoding (and quantization) to the wrapped ``VQModel``."""
        return self.model.encode(x)

    def decode(self, x):
        """Delegate decoding to the wrapped ``VQModel``."""
        return self.model.decode(x)

    def forward(self, x):
        """Run encode -> decode; return the quantized latent and reconstruction.

        ``encode`` is assumed to return the quantized latent first, possibly
        followed by auxiliary values (losses/indices), which are discarded here.
        """
        quant, *_ = self.encode(x)
        rec = self.decode(quant)
        return quant, rec