AndyRaoTHU commited on
Commit
b5f2ae4
·
1 Parent(s): e7f69aa
Files changed (1) hide show
  1. revq/models/revq.py +4 -4
revq/models/revq.py CHANGED
@@ -36,10 +36,10 @@ class Viewer:
36
 
37
  class ReVQ(PyTorchModelHubMixin, nn.Module):
38
  @classmethod
39
- def _from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
40
- print(f"Loading ReVQ model from {pretrained_model_name_or_path}...")
41
- config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="512T_NC=16384.yaml")
42
- ckpt_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="ckpt.pth")
43
 
44
  full_cfg = OmegaConf.load(config_path)
45
  model_cfg = full_cfg.get("model", {})
 
36
 
37
  class ReVQ(PyTorchModelHubMixin, nn.Module):
38
  @classmethod
39
+ def _from_pretrained(cls, model_id: str, **kwargs):
40
+ print(f"Loading ReVQ model from {model_id}...")
41
+ config_path = hf_hub_download(repo_id=model_id, filename="512T_NC=16384.yaml")
42
+ ckpt_path = hf_hub_download(repo_id=model_id, filename="ckpt.pth")
43
 
44
  full_cfg = OmegaConf.load(config_path)
45
  model_cfg = full_cfg.get("model", {})