=
commited on
Commit
·
d5e2b5a
1
Parent(s):
c2fb0d7
update swinface source code
Browse files- source/run_demo.sh +0 -0
- source/swin_b_test_lfw.py +5 -5
source/run_demo.sh
CHANGED
|
File without changes
|
source/swin_b_test_lfw.py
CHANGED
|
@@ -172,10 +172,10 @@ def evaluate_test_vs_template(model, test_loader, template_loader, device):
|
|
| 172 |
|
| 173 |
|
| 174 |
if __name__ == "__main__":
|
| 175 |
-
dataset_root = '
|
| 176 |
-
test_pth = "
|
| 177 |
-
test_pth_infected = "
|
| 178 |
-
test_pth_flipped = "
|
| 179 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 180 |
|
| 181 |
transform = transforms.Compose([
|
|
@@ -196,7 +196,7 @@ if __name__ == "__main__":
|
|
| 196 |
|
| 197 |
print(model) # 输出模型结构
|
| 198 |
|
| 199 |
-
model.load_state_dict(torch.load(
|
| 200 |
model.to(device)
|
| 201 |
|
| 202 |
evaluate_test_vs_template(model, test_loader, template_loader, device)
|
|
|
|
| 172 |
|
| 173 |
|
| 174 |
if __name__ == "__main__":
|
| 175 |
+
dataset_root = '../datasets/LFWPairs/lfw-py/lfw_test_template_50_cropped'
|
| 176 |
+
test_pth = "../checkpoints/swin_face.pth"
|
| 177 |
+
test_pth_infected = "../parametersProcess/swin_face/swin_evilfiles_16.pth"
|
| 178 |
+
test_pth_flipped = "../parametersProcess/swin_face/swin_flip_16.pth"
|
| 179 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 180 |
|
| 181 |
transform = transforms.Compose([
|
|
|
|
| 196 |
|
| 197 |
print(model) # 输出模型结构
|
| 198 |
|
| 199 |
+
model.load_state_dict(torch.load(test_pth, map_location=device))
|
| 200 |
model.to(device)
|
| 201 |
|
| 202 |
evaluate_test_vs_template(model, test_loader, template_loader, device)
|