= commited on
Commit
d5e2b5a
·
1 Parent(s): c2fb0d7

update swinface source code

Browse files
Files changed (2) hide show
  1. source/run_demo.sh +0 -0
  2. source/swin_b_test_lfw.py +5 -5
source/run_demo.sh CHANGED
File without changes
source/swin_b_test_lfw.py CHANGED
@@ -172,10 +172,10 @@ def evaluate_test_vs_template(model, test_loader, template_loader, device):
172
 
173
 
174
  if __name__ == "__main__":
175
- dataset_root = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_cropped'
176
- test_pth = "../../../parameters/classification/swin_face/swin_face.pth"
177
- test_pth_infected = "../../../parametersProcess/swin_face/swin_evilfiles_16.pth"
178
- test_pth_flipped = "../../../parametersProcess/swin_face/swin_flip_16.pth"
179
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
180
 
181
  transform = transforms.Compose([
@@ -196,7 +196,7 @@ if __name__ == "__main__":
196
 
197
  print(model) # 输出模型结构
198
 
199
- model.load_state_dict(torch.load(test_pth_flipped, map_location=device))
200
  model.to(device)
201
 
202
  evaluate_test_vs_template(model, test_loader, template_loader, device)
 
172
 
173
 
174
  if __name__ == "__main__":
175
+ dataset_root = '../datasets/LFWPairs/lfw-py/lfw_test_template_50_cropped'
176
+ test_pth = "../checkpoints/swin_face.pth"
177
+ test_pth_infected = "../parametersProcess/swin_face/swin_evilfiles_16.pth"
178
+ test_pth_flipped = "../parametersProcess/swin_face/swin_flip_16.pth"
179
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
180
 
181
  transform = transforms.Compose([
 
196
 
197
  print(model) # 输出模型结构
198
 
199
+ model.load_state_dict(torch.load(test_pth, map_location=device))
200
  model.to(device)
201
 
202
  evaluate_test_vs_template(model, test_loader, template_loader, device)