|
|
import os |
|
|
import argparse |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
import torchvision.transforms as transforms |
|
|
|
|
|
from lib.core.config import cfg, update_config |
|
|
from lib.models.model import HACO |
|
|
from lib.utils.contact_utils import get_contact_thres |
|
|
from lib.utils.train_utils import worker_init_fn, set_seed |
|
|
from lib.utils.eval_utils import evaluation |
|
|
|
|
|
|
|
|
# Command-line interface for the evaluation script.
BACKBONE_CHOICES = [
    'hamer', 'vit-l-16', 'vit-b-16', 'vit-s-16', 'handoccnet',
    'hrnet-w48', 'hrnet-w32', 'resnet-152', 'resnet-101',
    'resnet-50', 'resnet-34', 'resnet-18',
]

parser = argparse.ArgumentParser(description='Test HACO')
parser.add_argument(
    '--backbone',
    type=str,
    default='hamer',
    choices=BACKBONE_CHOICES,
    help='backbone model',
)
parser.add_argument(
    '--checkpoint',
    type=str,
    default='',
    help='model path for evaluation',
)
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
exec(f'from data.{cfg.DATASET.test_name}.dataset import {cfg.DATASET.test_name}') |
|
|
|
|
|
|
|
|
|
|
|
# Select the compute device; fall back to CPU when CUDA is unavailable.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Bound torch's intra-op CPU parallelism to the configured worker count.
torch.set_num_threads(cfg.DATASET.workers)

# NOTE(review): these env vars are set after numpy/torch have been imported;
# some BLAS/OpenMP backends only read them at import time — confirm they
# still take effect here.
for _thread_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ[_thread_var] = "4"
|
|
|
|
|
|
|
|
|
|
|
# Output layout: experiments_test_<dataset>/full/checkpoints
experiment_dir = f'experiments_test_{cfg.DATASET.test_name.lower()}'
_subdirs = ('full', 'checkpoints')
checkpoint_dir = os.path.join(experiment_dir, *_subdirs)
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
update_config(backbone_type=args.backbone, exp_dir=experiment_dir) |
|
|
|
|
|
|
|
|
|
|
|
# Deliberately imported *after* update_config() has run (see above) so the
# logger reflects the configured experiment directory — keep this import here.
from lib.core.config import logger

# Fix all RNG seeds for reproducible evaluation.
set_seed(cfg.MODEL.seed)
logger.info(f"Using random seed: {cfg.MODEL.seed}")
|
|
|
|
|
|
|
|
|
|
|
# Build the test split of the configured dataset. The dataset class was
# placed into module globals under its own name by the dynamic import above,
# so look it up directly instead of going through eval() on a built string.
transform = transforms.ToTensor()
dataset_cls = globals()[cfg.DATASET.test_name]
test_dataset = dataset_cls(transform, 'test')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Deterministic, order-preserving loader for evaluation: no shuffling and no
# dropped last batch, so every test sample is seen exactly once.
_loader_kwargs = dict(
    batch_size=cfg.TEST.batch,
    shuffle=False,
    num_workers=cfg.DATASET.workers,
    pin_memory=True,
    drop_last=False,
    worker_init_fn=worker_init_fn,
)
test_dataloader = DataLoader(test_dataset, **_loader_kwargs)

logger.info(f"# of test samples: {len(test_dataset)}")
|
|
|
|
|
|
|
|
|
|
|
# Instantiate the HACO model on the selected device and switch it to
# inference mode (disables dropout, uses running batch-norm statistics).
model = HACO()
model = model.to(device)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Restore model weights for evaluation. Without --checkpoint the script would
# silently evaluate randomly initialized weights, so warn loudly in that case.
if args.checkpoint:
    # NOTE(review): torch.load performs full unpickling by default — only load
    # checkpoints from trusted sources (weights_only=True would be safer if
    # the checkpoint format allows it).
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    logger.info(f"Loaded checkpoint from {args.checkpoint}")
else:
    logger.warning("No --checkpoint given; evaluating randomly initialized weights.")
|
|
|
|
|
|
|
|
|
|
|
# Per-batch metric buffers, filled by dataloader batch index in the test loop.
# Sized by the number of batches, not samples: results are stored one entry
# per iteration, so sizing by len(test_dataset) would leave dead None slots
# whenever batch_size > 1.
num_batches = len(test_dataloader)
eval_result = {
    metric: [None] * num_batches
    for metric in ('cont_pre', 'cont_rec', 'cont_f1')
}
|
|
|
|
|
# The contact threshold depends only on the backbone, so it is loop-invariant:
# compute it once instead of every batch.
eval_thres = get_contact_thres(args.backbone)

test_iterator = tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False)
# (model.eval() was already called right after model construction above.)

for idx, data in test_iterator:

    # Forward pass without gradient tracking — evaluation only.
    with torch.no_grad():
        outputs = model(
            {'input': data['input_data'], 'target': data['targets_data'], 'meta_info': data['meta_info']},
            mode="test",
        )

    # Per-batch contact precision / recall / F1 against the targets.
    eval_out = evaluation(outputs, data['targets_data'], data['meta_info'], mode='test', thres=eval_thres)
    for key in eval_out:
        eval_result[key][idx] = eval_out[key]

    # Running means over the batches evaluated so far; any unfilled (None)
    # entry counts as 0.0, matching the original accounting.
    total_cont_pre = np.mean([x if x is not None else 0.0 for x in eval_result['cont_pre'][:idx + 1]])
    total_cont_rec = np.mean([x if x is not None else 0.0 for x in eval_result['cont_rec'][:idx + 1]])
    total_cont_f1 = np.mean([x if x is not None else 0.0 for x in eval_result['cont_f1'][:idx + 1]])

    logger.info(f"C-Pre: {total_cont_pre:.3f} | C-Rec: {total_cont_rec:.3f} | C-F1: {total_cont_f1:.3f}")
|
|
|
|
|
|
|
|
|
|
|
# Final summary: total_cont_* hold the running means left by the last loop
# iteration, i.e. the averages over all evaluated batches.
logger.info('Test finished!!!!')


logger.info(f"Final Results --- C-Pre: {total_cont_pre:.3f} | C-Rec: {total_cont_rec:.3f} | C-F1: {total_cont_f1:.3f}")