diff --git a/rfdetr/__init__.py b/rfdetr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef87e732b34cf26576fbab0b98996b7001df0555 --- /dev/null +++ b/rfdetr/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + + +import os +if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None: + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +from rfdetr.detr import RFDETRBase, RFDETRLarge, RFDETRNano, RFDETRSmall, RFDETRMedium diff --git a/rfdetr/__pycache__/__init__.cpython-313.pyc b/rfdetr/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eddea0c4390ecf3f2bf33e3c130a3feca95ba525 Binary files /dev/null and b/rfdetr/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/__pycache__/config.cpython-313.pyc b/rfdetr/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20e4b4cdaccfb98d36220e09e338663d26e0433f Binary files /dev/null and b/rfdetr/__pycache__/config.cpython-313.pyc differ diff --git a/rfdetr/__pycache__/detr.cpython-313.pyc b/rfdetr/__pycache__/detr.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fb8c5e6bc868bde684938d84dd7ba98e5c80a6d Binary files /dev/null and b/rfdetr/__pycache__/detr.cpython-313.pyc differ diff --git a/rfdetr/__pycache__/engine.cpython-313.pyc b/rfdetr/__pycache__/engine.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acf2995e1ea3d0377b714fc0554afea6d9c6f2e3 Binary files /dev/null and b/rfdetr/__pycache__/engine.cpython-313.pyc differ diff --git a/rfdetr/__pycache__/main.cpython-313.pyc b/rfdetr/__pycache__/main.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ba9fbd3f4bc97d8dd4103287e05413e4d595316 Binary files /dev/null and b/rfdetr/__pycache__/main.cpython-313.pyc differ diff --git a/rfdetr/cli/__pycache__/main.cpython-313.pyc b/rfdetr/cli/__pycache__/main.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f35654445da27d471071a46589d5023644ee99 Binary files /dev/null and b/rfdetr/cli/__pycache__/main.cpython-313.pyc differ diff --git a/rfdetr/cli/main.py b/rfdetr/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..8527ba322fa7397c9001d83ed0c9917b4701fc5f --- /dev/null +++ b/rfdetr/cli/main.py @@ -0,0 +1,87 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. 
+# ------------------------------------------------------------------------
+
+import argparse
+from rf100vl import get_rf100vl_projects
+import roboflow
+from rfdetr import RFDETRBase
+import torch
+import os
+
+def download_dataset(rf_project: roboflow.Project, dataset_version: int):
+    versions = rf_project.versions()
+    if dataset_version is not None:
+        versions = [v for v in versions if v.version == str(dataset_version)]
+        if len(versions) == 0:
+            raise ValueError(f"Dataset version {dataset_version} not found")
+        version = versions[0]
+    else:
+        version = max(versions, key=lambda v: v.id)
+    location = os.path.join("datasets/", rf_project.name + "_v" + version.version)
+    if not os.path.exists(location):
+        location = version.download(
+            model_format="coco", location=location, overwrite=False
+        ).location
+
+    return location
+
+
+def train_from_rf_project(rf_project: roboflow.Project, dataset_version: int):
+    location = download_dataset(rf_project, dataset_version)
+    print(location)
+    rf_detr = RFDETRBase()
+    device_supports_cuda = torch.cuda.is_available()
+    rf_detr.train(
+        dataset_dir=location,
+        epochs=1,
+        device="cuda" if device_supports_cuda else "cpu",
+    )
+
+
+def train_from_coco_dir(coco_dir: str):
+    rf_detr = RFDETRBase()
+    rf_detr.train(
+        dataset_dir=coco_dir,
+        epochs=1,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+
+
+def trainer():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--coco_dir", type=str, required=False)
+    parser.add_argument("--api_key", type=str, required=False)
+    parser.add_argument("--workspace", type=str, required=False, default=None)
+    parser.add_argument("--project_name", type=str, required=False, default=None)
+    parser.add_argument("--dataset_version", type=int, required=False, default=None)
+    args = parser.parse_args()
+
+    if args.coco_dir is not None:
+        train_from_coco_dir(args.coco_dir)
+        return
+
+    if (args.workspace is None and args.project_name is not None) or (
+        args.workspace is not None and args.project_name is None
+    ):
+        raise ValueError(
+            "Either both workspace and project_name must be provided or neither of them"
+        )
+
+    if args.workspace is not None:
+        rf = roboflow.Roboflow(api_key=args.api_key)
+        project = rf.workspace(args.workspace).project(args.project_name)
+    else:
+        projects = get_rf100vl_projects(api_key=args.api_key)
+        project = projects[0].rf_project
+
+    train_from_rf_project(project, args.dataset_version)
+
+
+if __name__ == "__main__":
+    trainer()
diff --git a/rfdetr/config.py b/rfdetr/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f7fd7b5bc16e5731f819b04d98f5c3dfa33288
--- /dev/null
+++ b/rfdetr/config.py
@@ -0,0 +1,142 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + + +from pydantic import BaseModel +from typing import List, Optional, Literal, Type +import torch +DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + +class ModelConfig(BaseModel): + encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] + out_feature_indexes: List[int] + dec_layers: int + two_stage: bool = True + projector_scale: List[Literal["P3", "P4", "P5"]] + hidden_dim: int + patch_size: int + num_windows: int + sa_nheads: int + ca_nheads: int + dec_n_points: int + bbox_reparam: bool = True + lite_refpoint_refine: bool = True + layer_norm: bool = True + amp: bool = True + num_classes: int = 90 + pretrain_weights: Optional[str] = None + device: Literal["cpu", "cuda", "mps"] = DEVICE + resolution: int + group_detr: int = 13 + gradient_checkpointing: bool = False + positional_encoding_size: int + +class RFDETRBaseConfig(ModelConfig): + """ + The configuration for an RF-DETR Base model. + """ + encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small" + hidden_dim: int = 256 + patch_size: int = 14 + num_windows: int = 4 + dec_layers: int = 3 + sa_nheads: int = 8 + ca_nheads: int = 16 + dec_n_points: int = 2 + num_queries: int = 300 + num_select: int = 300 + projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"] + out_feature_indexes: List[int] = [2, 5, 8, 11] + pretrain_weights: Optional[str] = "rf-detr-base.pth" + resolution: int = 560 + positional_encoding_size: int = 37 + +class RFDETRLargeConfig(RFDETRBaseConfig): + """ + The configuration for an RF-DETR Large model. + """ + encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base" + hidden_dim: int = 384 + sa_nheads: int = 12 + ca_nheads: int = 24 + dec_n_points: int = 4 + projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"] + pretrain_weights: Optional[str] = "rf-detr-large.pth" + +class RFDETRNanoConfig(RFDETRBaseConfig): + """ + The configuration for an RF-DETR Nano model. + """ + out_feature_indexes: List[int] = [3, 6, 9, 12] + num_windows: int = 2 + dec_layers: int = 2 + patch_size: int = 16 + resolution: int = 384 + positional_encoding_size: int = 24 + pretrain_weights: Optional[str] = "rf-detr-nano.pth" + +class RFDETRSmallConfig(RFDETRBaseConfig): + """ + The configuration for an RF-DETR Small model. + """ + out_feature_indexes: List[int] = [3, 6, 9, 12] + num_windows: int = 2 + dec_layers: int = 3 + patch_size: int = 16 + resolution: int = 512 + positional_encoding_size: int = 32 + pretrain_weights: Optional[str] = "rf-detr-small.pth" + +class RFDETRMediumConfig(RFDETRBaseConfig): + """ + The configuration for an RF-DETR Medium model. 
+ """ + out_feature_indexes: List[int] = [3, 6, 9, 12] + num_windows: int = 2 + dec_layers: int = 4 + patch_size: int = 16 + resolution: int = 576 + positional_encoding_size: int = 36 + pretrain_weights: Optional[str] = "rf-detr-medium.pth" + +class TrainConfig(BaseModel): + lr: float = 1e-4 + lr_encoder: float = 1.5e-4 + batch_size: int = 4 + grad_accum_steps: int = 4 + epochs: int = 100 + ema_decay: float = 0.993 + ema_tau: int = 100 + lr_drop: int = 100 + checkpoint_interval: int = 10 + warmup_epochs: int = 0 + lr_vit_layer_decay: float = 0.8 + lr_component_decay: float = 0.7 + drop_path: float = 0.0 + group_detr: int = 13 + ia_bce_loss: bool = True + cls_loss_coef: float = 1.0 + num_select: int = 300 + dataset_file: Literal["coco", "o365", "roboflow"] = "roboflow" + square_resize_div_64: bool = True + dataset_dir: str + output_dir: str = "output" + multi_scale: bool = True + expanded_scales: bool = True + do_random_resize_via_padding: bool = False + use_ema: bool = True + num_workers: int = 2 + weight_decay: float = 1e-4 + early_stopping: bool = False + early_stopping_patience: int = 10 + early_stopping_min_delta: float = 0.001 + early_stopping_use_ema: bool = False + tensorboard: bool = True + wandb: bool = False + project: Optional[str] = None + run: Optional[str] = None + class_names: List[str] = None + run_test: bool = True diff --git a/rfdetr/datasets/__init__.py b/rfdetr/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0d6443839a5837e4e699afd1e6a9f9a8bc68fbd --- /dev/null +++ b/rfdetr/datasets/__init__.py @@ -0,0 +1,36 @@ +# ------------------------------------------------------------------------ +# LW-DETR +# Copyright (c) 2024 Baidu. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +import torch.utils.data +import torchvision + +from .coco import build as build_coco +from .o365 import build_o365 +from .coco import build_roboflow + + +def get_coco_api_from_dataset(dataset): + for _ in range(10): + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, torchvision.datasets.CocoDetection): + return dataset.coco + + +def build_dataset(image_set, args, resolution): + if args.dataset_file == 'coco': + return build_coco(image_set, args, resolution) + if args.dataset_file == 'o365': + return build_o365(image_set, args, resolution) + if args.dataset_file == 'roboflow': + return build_roboflow(image_set, args, resolution) + raise ValueError(f'dataset {args.dataset_file} not supported') diff --git a/rfdetr/datasets/__pycache__/__init__.cpython-313.pyc b/rfdetr/datasets/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82f2584b84f71b7e3da01a544853a98bec497fad Binary files /dev/null and b/rfdetr/datasets/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/datasets/__pycache__/coco.cpython-313.pyc b/rfdetr/datasets/__pycache__/coco.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d007c1cbbe8bb72ffcea0a3329397e51cfb819c1 Binary files /dev/null and b/rfdetr/datasets/__pycache__/coco.cpython-313.pyc differ diff --git a/rfdetr/datasets/__pycache__/coco_eval.cpython-313.pyc b/rfdetr/datasets/__pycache__/coco_eval.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb5f79274937b31d78a166c08e557d48ea3b5b72 Binary files /dev/null and b/rfdetr/datasets/__pycache__/coco_eval.cpython-313.pyc differ diff --git a/rfdetr/datasets/__pycache__/o365.cpython-313.pyc b/rfdetr/datasets/__pycache__/o365.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9742348d692eca78829ab598b309bbfc76022810 Binary files /dev/null and b/rfdetr/datasets/__pycache__/o365.cpython-313.pyc differ diff --git a/rfdetr/datasets/__pycache__/transforms.cpython-313.pyc b/rfdetr/datasets/__pycache__/transforms.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed7cb1951065ecca5b5a6c8b5ef62d776835084f Binary files /dev/null and b/rfdetr/datasets/__pycache__/transforms.cpython-313.pyc differ diff --git a/rfdetr/datasets/coco.py b/rfdetr/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..ef47a4bd96e8dd8eb8613a5f24a0cd3e290bc650 --- /dev/null +++ b/rfdetr/datasets/coco.py @@ -0,0 +1,280 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +""" +COCO dataset which returns image_id for evaluation. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py +""" +from pathlib import Path + +import torch +import torch.utils.data +import torchvision + +import rfdetr.datasets.transforms as T + + +def compute_multi_scale_scales(resolution, expanded_scales=False, patch_size=16, num_windows=4): + # round to the nearest multiple of 4*patch_size to enable both patching and windowing + base_num_patches_per_window = resolution // (patch_size * num_windows) + offsets = [-3, -2, -1, 0, 1, 2, 3, 4] if not expanded_scales else [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5] + scales = [base_num_patches_per_window + offset for offset in offsets] + proposed_scales = [scale * patch_size * num_windows for scale in scales] + proposed_scales = [scale for scale in proposed_scales if scale >= patch_size * num_windows * 2] # ensure minimum image size + return proposed_scales + + +class CocoDetection(torchvision.datasets.CocoDetection): + def __init__(self, img_folder, ann_file, transforms): + super(CocoDetection, self).__init__(img_folder, ann_file) + self._transforms = transforms + self.prepare = ConvertCoco() + + def __getitem__(self, idx): + img, target = super(CocoDetection, self).__getitem__(idx) + image_id = self.ids[idx] + target = {'image_id': image_id, 'annotations': target} + img, target = self.prepare(img, target) + if self._transforms is not None: + img, target = self._transforms(img, target) + return img, target + + +class ConvertCoco(object): + + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + target["image_id"] = image_id + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + return image, target + + +def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scales=False, skip_random_resize=False, patch_size=16, num_windows=4): + + normalize = T.Compose([ + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + scales = [resolution] + if multi_scale: + # scales = [448, 512, 576, 640, 704, 768, 832, 896] + scales = compute_multi_scale_scales(resolution, expanded_scales, patch_size, num_windows) + if skip_random_resize: + scales = [scales[-1]] + print(scales) + + if image_set == 'train': + return T.Compose([ + T.RandomHorizontalFlip(), + T.RandomSelect( + T.RandomResize(scales, max_size=1333), + T.Compose([ + T.RandomResize([400, 500, 600]), + 
T.RandomSizeCrop(384, 600), + T.RandomResize(scales, max_size=1333), + ]) + ), + normalize, + ]) + + if image_set == 'val': + return T.Compose([ + T.RandomResize([resolution], max_size=1333), + normalize, + ]) + if image_set == 'val_speed': + return T.Compose([ + T.SquareResize([resolution]), + normalize, + ]) + + raise ValueError(f'unknown {image_set}') + + +def make_coco_transforms_square_div_64(image_set, resolution, multi_scale=False, expanded_scales=False, skip_random_resize=False, patch_size=16, num_windows=4): + """ + """ + + normalize = T.Compose([ + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + + scales = [resolution] + if multi_scale: + # scales = [448, 512, 576, 640, 704, 768, 832, 896] + scales = compute_multi_scale_scales(resolution, expanded_scales, patch_size, num_windows) + if skip_random_resize: + scales = [scales[-1]] + print(scales) + + if image_set == 'train': + return T.Compose([ + T.RandomHorizontalFlip(), + T.RandomSelect( + T.SquareResize(scales), + T.Compose([ + T.RandomResize([400, 500, 600]), + T.RandomSizeCrop(384, 600), + T.SquareResize(scales), + ]), + ), + normalize, + ]) + + if image_set == 'val': + return T.Compose([ + T.SquareResize([resolution]), + normalize, + ]) + if image_set == 'test': + return T.Compose([ + T.SquareResize([resolution]), + normalize, + ]) + if image_set == 'val_speed': + return T.Compose([ + T.SquareResize([resolution]), + normalize, + ]) + + raise ValueError(f'unknown {image_set}') + +def build(image_set, args, resolution): + root = Path(args.coco_path) + assert root.exists(), f'provided COCO path {root} does not exist' + mode = 'instances' + PATHS = { + "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), + "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), + "test": (root / "test2017", root / "annotations" / f'image_info_test-dev2017.json'), + } + + img_folder, ann_file = PATHS[image_set.split("_")[0]] + + try: + square_resize = args.square_resize + except: + square_resize = False + + try: + square_resize_div_64 = args.square_resize_div_64 + except: + square_resize_div_64 = False + + + if square_resize_div_64: + dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64( + image_set, + resolution, + multi_scale=args.multi_scale, + expanded_scales=args.expanded_scales, + skip_random_resize=not args.do_random_resize_via_padding, + patch_size=args.patch_size, + num_windows=args.num_windows + )) + else: + dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms( + image_set, + resolution, + multi_scale=args.multi_scale, + expanded_scales=args.expanded_scales, + skip_random_resize=not args.do_random_resize_via_padding, + patch_size=args.patch_size, + num_windows=args.num_windows + )) + return dataset + +def build_roboflow(image_set, args, resolution): + root = Path(args.dataset_dir) + assert root.exists(), f'provided Roboflow path {root} does not exist' + mode = 'instances' + PATHS = { + "train": (root / "train", root / "train" / "_annotations.coco.json"), + "val": (root / "valid", root / "valid" / "_annotations.coco.json"), + "test": (root / "test", root / "test" / "_annotations.coco.json"), + } + + img_folder, ann_file = PATHS[image_set.split("_")[0]] + + try: + square_resize = args.square_resize + except: + square_resize = False + + try: + square_resize_div_64 = args.square_resize_div_64 + except: + square_resize_div_64 = False + + + if square_resize_div_64: + dataset = CocoDetection(img_folder, 
ann_file, transforms=make_coco_transforms_square_div_64( + image_set, + resolution, + multi_scale=args.multi_scale, + expanded_scales=args.expanded_scales, + skip_random_resize=not args.do_random_resize_via_padding, + patch_size=args.patch_size, + num_windows=args.num_windows + )) + else: + dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms( + image_set, + resolution, + multi_scale=args.multi_scale, + expanded_scales=args.expanded_scales, + skip_random_resize=not args.do_random_resize_via_padding, + patch_size=args.patch_size, + num_windows=args.num_windows + )) + return dataset diff --git a/rfdetr/datasets/coco_eval.py b/rfdetr/datasets/coco_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd00a5af700beb43e72e16024e6361001d29e89 --- /dev/null +++ b/rfdetr/datasets/coco_eval.py @@ -0,0 +1,271 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +COCO evaluator that works in distributed mode. 
+ +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" +import os +import contextlib +import copy +import numpy as np +import torch + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +from rfdetr.util.misc import all_gather + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, 'w') as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, 
rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + 'keypoints': keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = all_gather(img_ids) + all_eval_imgs = all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +def evaluate(self): + ''' + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + ''' + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/rfdetr/datasets/o365.py b/rfdetr/datasets/o365.py new file mode 100644 index 0000000000000000000000000000000000000000..376da739cc5537a8eeea5c964104816f68f6e759 --- /dev/null +++ b/rfdetr/datasets/o365.py @@ -0,0 +1,53 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +"""Dataset file for Object365.""" +from pathlib import Path + +from .coco import ( + CocoDetection, make_coco_transforms, make_coco_transforms_square_div_64 +) + +from PIL import Image +Image.MAX_IMAGE_PIXELS = None + + +def build_o365_raw(image_set, args, resolution): + root = Path(args.coco_path) + PATHS = { + "train": (root, root / 'zhiyuan_objv2_train_val_wo_5k.json'), + "val": (root, root / 'zhiyuan_objv2_minival5k.json'), + } + img_folder, ann_file = PATHS[image_set] + + try: + square_resize = args.square_resize + except: + square_resize = False + + try: + square_resize_div_64 = args.square_resize_div_64 + except: + square_resize_div_64 = False + + if square_resize_div_64: + dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales)) + else: + dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales)) + return dataset + + +def build_o365(image_set, args, resolution): + if image_set == 'train': + train_ds = build_o365_raw('train', args, resolution=resolution) + return train_ds + if image_set == 'val': + val_ds = build_o365_raw('val', args, resolution=resolution) + return val_ds + raise ValueError('Unknown image_set: {}'.format(image_set)) \ No newline at end of file diff --git a/rfdetr/datasets/transforms.py b/rfdetr/datasets/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..b5da93c187c57fc6a2c47001ca48f968d9826bf9 --- /dev/null +++ b/rfdetr/datasets/transforms.py @@ -0,0 +1,475 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Transforms and data augmentation for both image + bbox. +""" +import random + +import PIL +import numpy as np +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence +from numbers import Number +import torch +import torchvision.transforms as T +# from detectron2.data import transforms as DT +import torchvision.transforms.functional as F + +from rfdetr.util.box_ops import box_xyxy_to_cxcywh +from rfdetr.util.misc import interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? 
+ target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target["boxes"] = boxes + + if "masks" in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple( + float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? 
+ target["size"] = torch.tensor(padded_image.size[::-1]) + if "masks" in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class SquareResize(object): + def __init__(self, sizes): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + rescaled_img=F.resize(img, (size, size)) + w, h = rescaled_img.size + if target is None: + return rescaled_img, None + ratios = tuple( + float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + target["size"] = torch.tensor([h, w]) + + return rescaled_img, target + + +class RandomPad(object): + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class PILtoNdArray(object): + + def __call__(self, img, target): + return np.asarray(img), target + + +class NdArraytoPIL(object): + + def __call__(self, img, target): + return F.to_pil_image(img.astype('uint8')), target + + +class Pad(object): + def __init__(self, + size=None, + size_divisor=32, + pad_mode=0, + offsets=None, + fill_value=(127.5, 127.5, 127.5)): + """ + Pad image to a specified size or multiple of size_divisor. + Args: + size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None + size_divisor (int): size divisor, default 32 + pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. 
if -1, use specified offsets + if 0, only pad to right and bottom. if 1, pad according to center. if 2, only pad left and top + offsets (list): [offset_x, offset_y], specify offset while padding, only supported pad_mode=-1 + fill_value (bool): rgb value of pad area, default (127.5, 127.5, 127.5) + """ + + if not isinstance(size, (int, Sequence)): + raise TypeError( + "Type of target_size is invalid when random_size is True. \ + Must be List, now is {}".format(type(size))) + + if isinstance(size, int): + size = [size, size] + + assert pad_mode in [ + -1, 0, 1, 2 + ], 'currently only supports four modes [-1, 0, 1, 2]' + if pad_mode == -1: + assert offsets, 'if pad_mode is -1, offsets should not be None' + + self.size = size + self.size_divisor = size_divisor + self.pad_mode = pad_mode + self.fill_value = fill_value + self.offsets = offsets + + def apply_bbox(self, bbox, offsets): + return bbox + np.array(offsets * 2, dtype=np.float32) + + def apply_image(self, image, offsets, im_size, size): + x, y = offsets + im_h, im_w = im_size + h, w = size + canvas = np.ones((h, w, 3), dtype=np.float32) + canvas *= np.array(self.fill_value, dtype=np.float32) + canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32) + return canvas + + def __call__(self, im, target): + im_h, im_w = im.shape[:2] + if self.size: + h, w = self.size + assert ( + im_h <= h and im_w <= w + ), '(h, w) of target size should be greater than (im_h, im_w)' + else: + h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor) + w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor) + + if h == im_h and w == im_w: + return im.astype(np.float32), target + + if self.pad_mode == -1: + offset_x, offset_y = self.offsets + elif self.pad_mode == 0: + offset_y, offset_x = 0, 0 + elif self.pad_mode == 1: + offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2 + else: + offset_y, offset_x = h - im_h, w - im_w + + offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w] + + im = self.apply_image(im, offsets, im_size, size) + + if self.pad_mode == 0: + target["size"] = torch.tensor([h, w]) + return im, target + if 'boxes' in target and len(target['boxes']) > 0: + boxes = np.asarray(target["boxes"]) + target["boxes"] = torch.from_numpy(self.apply_bbox(boxes, offsets)) + target["size"] = torch.tensor([h, w]) + + return im, target + + +class RandomExpand(object): + """Random expand the canvas. + Args: + ratio (float): maximum expansion ratio. + prob (float): probability to expand. + fill_value (list): color value used to fill the canvas. in RGB order. + """ + + def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)): + assert ratio > 1.01, "expand ratio must be larger than 1.01" + self.ratio = ratio + self.prob = prob + assert isinstance(fill_value, (Number, Sequence)), \ + "fill value must be either float or sequence" + if isinstance(fill_value, Number): + fill_value = (fill_value, ) * 3 + if not isinstance(fill_value, tuple): + fill_value = tuple(fill_value) + self.fill_value = fill_value + + def __call__(self, img, target): + if np.random.uniform(0., 1.) 
< self.prob: + return img, target + + height, width = img.shape[:2] + ratio = np.random.uniform(1., self.ratio) + h = int(height * ratio) + w = int(width * ratio) + if not h > height or not w > width: + return img, target + y = np.random.randint(0, h - height) + x = np.random.randint(0, w - width) + offsets, size = [x, y], [h, w] + + pad = Pad(size, + pad_mode=-1, + offsets=offsets, + fill_value=self.fill_value) + + return pad(img, target) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string \ No newline at end of file diff --git a/rfdetr/deploy/__init__.py b/rfdetr/deploy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rfdetr/deploy/__pycache__/__init__.cpython-313.pyc b/rfdetr/deploy/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d2e30cf1ba7d9129c143b73077c92ee80e18e46 Binary files /dev/null and b/rfdetr/deploy/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/deploy/__pycache__/benchmark.cpython-313.pyc b/rfdetr/deploy/__pycache__/benchmark.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3acec4cd60fe73de3d5c67fb6ff0d4f9d94070e7 Binary files /dev/null and b/rfdetr/deploy/__pycache__/benchmark.cpython-313.pyc differ diff --git a/rfdetr/deploy/__pycache__/export.cpython-313.pyc b/rfdetr/deploy/__pycache__/export.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b7c42e336b3f44c51e2a399caa942ec976de54e Binary files /dev/null and b/rfdetr/deploy/__pycache__/export.cpython-313.pyc differ diff --git a/rfdetr/deploy/_onnx/__init__.py b/rfdetr/deploy/_onnx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc86979645338a451d1200268629d6e236471221 --- /dev/null +++ b/rfdetr/deploy/_onnx/__init__.py @@ -0,0 +1,13 @@ +# 
------------------------------------------------------------------------ +# LW-DETR +# Copyright (c) 2024 Baidu. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +""" +onnx optimizer and symbolic registry +""" +from . import optimizer +from . import symbolic + +from .optimizer import OnnxOptimizer +from .symbolic import CustomOpSymbolicRegistry \ No newline at end of file diff --git a/rfdetr/deploy/_onnx/__pycache__/__init__.cpython-313.pyc b/rfdetr/deploy/_onnx/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1431770651f49a1f7043f47ae682894d87ce73ca Binary files /dev/null and b/rfdetr/deploy/_onnx/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/deploy/_onnx/__pycache__/optimizer.cpython-313.pyc b/rfdetr/deploy/_onnx/__pycache__/optimizer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74b6bcdc608106138d29ef4be06aa3918dc187c7 Binary files /dev/null and b/rfdetr/deploy/_onnx/__pycache__/optimizer.cpython-313.pyc differ diff --git a/rfdetr/deploy/_onnx/__pycache__/symbolic.cpython-313.pyc b/rfdetr/deploy/_onnx/__pycache__/symbolic.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bade6ef5eeb4955ee8d8ed6a1e959b2e1d6a273 Binary files /dev/null and b/rfdetr/deploy/_onnx/__pycache__/symbolic.cpython-313.pyc differ diff --git a/rfdetr/deploy/_onnx/optimizer.py b/rfdetr/deploy/_onnx/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..47831eefc1a104830957249811d18e75b10ef2e1 --- /dev/null +++ b/rfdetr/deploy/_onnx/optimizer.py @@ -0,0 +1,579 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +OnnxOptimizer +""" +import os +from collections import OrderedDict +from copy import deepcopy + +import numpy as np +import onnx +import torch +from onnx import shape_inference +import onnx_graphsurgeon as gs +from polygraphy.backend.onnx.loader import fold_constants +from onnx_graphsurgeon.logger.logger import G_LOGGER + +from .symbolic import CustomOpSymbolicRegistry + + +class OnnxOptimizer(): + def __init__( + self, + input, + severity=G_LOGGER.INFO + ): + if isinstance(input, str): + onnx_graph = self.load_onnx(input) + else: + onnx_graph = input + self.graph = gs.import_onnx(onnx_graph) + self.severity = severity + self.set_severity(severity) + + def set_severity(self, severity): + G_LOGGER.severity = severity + + def load_onnx(self, onnx_path:str): + """Load onnx from file + """ + assert os.path.isfile(onnx_path), f"not found onnx file: {onnx_path}" + onnx_graph = onnx.load(onnx_path) + G_LOGGER.info(f"load onnx file: {onnx_path}") + return onnx_graph + + def save_onnx(self, onnx_path:str): + onnx_graph = gs.export_onnx(self.graph) + G_LOGGER.info(f"save onnx file: {onnx_path}") + onnx.save(onnx_graph, onnx_path) + + def info(self, prefix=''): + G_LOGGER.verbose(f"{prefix} .. 
{len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs") + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def find_node_input(self, node, name:str=None, value=None) -> int: + for i, inp in enumerate(node.inputs): + if isinstance(name, str) and inp.name == name: + index = i + elif inp == value: + index = i + assert index >= 0, f"not found {name}({value}) in node.inputs" + return index + + def find_node_output(self, node, name:str=None, value=None) -> int: + for i, inp in enumerate(node.outputs): + if isinstance(name, str) and inp.name == name: + index = i + elif inp == value: + index = i + assert index >= 0, f"not found {name}({value}) in node.outputs" + return index + + def common_opt(self, return_onnx=False): + for fn in CustomOpSymbolicRegistry._OPTIMIZER: + fn(self) + self.cleanup() + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=False) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + self.graph = gs.import_onnx(onnx_graph) + self.cleanup() + if return_onnx: + return onnx_graph + + def resize_fix(self): + ''' + This function loops through the graph looking for Resize nodes that uses scales for resize (has 3 inputs). + It substitutes found Resize with Resize that takes the size of the output tensor instead of scales. + It adds Shape->Slice->Concat + Shape->Slice----^ subgraph to the graph to extract the shape of the output tensor. + This fix is required for the dynamic shape support. 
+ ''' + mResizeNodes = 0 + for node in self.graph.nodes: + if node.op == "Resize" and len(node.inputs) == 3: + name = node.name + "/" + + add_node = node.o().o().i(1) + div_node = node.i() + + shape_hw_out = gs.Variable(name=name + "shape_hw_out", dtype=np.int64, shape=[4]) + shape_hw = gs.Node(op="Shape", name=name+"shape_hw", inputs=[add_node.outputs[0]], outputs=[shape_hw_out]) + + const_zero = gs.Constant(name=name + "const_zero", values=np.array([0], dtype=np.int64)) + const_two = gs.Constant(name=name + "const_two", values=np.array([2], dtype=np.int64)) + const_four = gs.Constant(name=name + "const_four", values=np.array([4], dtype=np.int64)) + + slice_hw_out = gs.Variable(name=name + "slice_hw_out", dtype=np.int64, shape=[2]) + slice_hw = gs.Node(op="Slice", name=name+"slice_hw", inputs=[shape_hw_out, const_two, const_four, const_zero], outputs=[slice_hw_out]) + + shape_bc_out = gs.Variable(name=name + "shape_bc_out", dtype=np.int64, shape=[2]) + shape_bc = gs.Node(op="Shape", name=name+"shape_bc", inputs=[div_node.outputs[0]], outputs=[shape_bc_out]) + + slice_bc_out = gs.Variable(name=name + "slice_bc_out", dtype=np.int64, shape=[2]) + slice_bc = gs.Node(op="Slice", name=name+"slice_bc", inputs=[shape_bc_out, const_zero, const_two, const_zero], outputs=[slice_bc_out]) + + concat_bchw_out = gs.Variable(name=name + "concat_bchw_out", dtype=np.int64, shape=[4]) + concat_bchw = gs.Node(op="Concat", name=name+"concat_bchw", attrs={"axis": 0}, inputs=[slice_bc_out, slice_hw_out], outputs=[concat_bchw_out]) + + none_var = gs.Variable.empty() + + resize_bchw = gs.Node(op="Resize", name=name+"resize_bchw", attrs=node.attrs, inputs=[node.inputs[0], none_var, none_var, concat_bchw_out], outputs=[node.outputs[0]]) + + self.graph.nodes.extend([shape_hw, slice_hw, shape_bc, slice_bc, concat_bchw, resize_bchw]) + + node.inputs = [] + node.outputs = [] + + mResizeNodes += 1 + + self.cleanup() + return mResizeNodes + + def adjustAddNode(self): + nAdjustAddNode = 0 + for node in self.graph.nodes: + # Change the bias const to the second input to allow Gemm+BiasAdd fusion in TRT. 
+ if node.op in ["Add"] and isinstance(node.inputs[0], gs.ir.tensor.Constant): + tensor = node.inputs[1] + bias = node.inputs[0] + node.inputs = [tensor, bias] + nAdjustAddNode += 1 + + self.cleanup() + return nAdjustAddNode + + def decompose_instancenorms(self): + nRemoveInstanceNorm = 0 + for node in self.graph.nodes: + if node.op == "InstanceNormalization": + name = node.name + "/" + input_tensor = node.inputs[0] + output_tensor = node.outputs[0] + mean_out = gs.Variable(name=name + "mean_out") + mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out]) + sub_out = gs.Variable(name=name + "sub_out") + sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out]) + pow_out = gs.Variable(name=name + "pow_out") + pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32)) + pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out]) + mean2_out = gs.Variable(name=name + "mean2_out") + mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out]) + epsilon_out = gs.Variable(name=name + "epsilon_out") + epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([node.attrs["epsilon"]], dtype=np.float32)) + epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out]) + sqrt_out = gs.Variable(name=name + "sqrt_out") + sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out]) + div_out = gs.Variable(name=name + "div_out") + div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out]) + constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[1].inputs[0].attrs["value"].values.reshape(1, 32, 1))) + constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[2].inputs[0].attrs["value"].values.reshape(1, 32, 1))) + mul_out = gs.Variable(name=name + "mul_out") + mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out]) + add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor]) + self.graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node]) + node.inputs = [] + node.outputs = [] + nRemoveInstanceNorm += 1 + + self.cleanup() + return nRemoveInstanceNorm + + def insert_groupnorm_plugin(self): + nGroupNormPlugin = 0 + for node in self.graph.nodes: + if node.op == "Reshape" and node.outputs != [] and \ + node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1) and \ + node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul" and \ + node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add" and \ + len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3: + # "node.outputs != []" is added for VAE + + inputTensor = node.inputs[0] + + gammaNode = node.o().o().o().o().o().o().o().o().o().o().o() + index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True) + gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32) + constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), 
np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!! + + betaNode = gammaNode.o() + index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True) + beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32) + constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1))) + + epsilon = node.o().o().o().o().o().inputs[1].values.tolist()[0] + + if betaNode.o().op == "Sigmoid": # need Swish + bSwish = True + lastNode = betaNode.o().o() # Mul node of Swish + else: + bSwish = False + lastNode = betaNode # Cast node after Group Norm + + if lastNode.o().op == "Cast": + lastNode = lastNode.o() + inputList = [inputTensor, constantGamma, constantBeta] + groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape) + groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([('epsilon', epsilon), ('bSwish', int(bSwish))])) + self.graph.nodes.append(groupNormN) + + for subNode in self.graph.nodes: + if lastNode.outputs[0] in subNode.inputs: + index = subNode.inputs.index(lastNode.outputs[0]) + subNode.inputs[index] = groupNormV + node.inputs = [] + lastNode.outputs = [] + nGroupNormPlugin += 1 + + self.cleanup() + return nGroupNormPlugin + + def insert_layernorm_plugin(self): + nLayerNormPlugin = 0 + for node in self.graph.nodes: + if node.op == 'ReduceMean' and \ + node.o().op == 'Sub' and node.o().inputs[0] == node.inputs[0] and \ + node.o().o(0).op =='Pow' and node.o().o(1).op =='Div' and \ + node.o().o(0).o().op == 'ReduceMean' and \ + node.o().o(0).o().o().op == 'Add' and \ + node.o().o(0).o().o().o().op == 'Sqrt' and \ + node.o().o(0).o().o().o().o().op == 'Div' and node.o().o(0).o().o().o().o() == node.o().o(1) and \ + node.o().o(0).o().o().o().o().o().op == 'Mul' and \ + node.o().o(0).o().o().o().o().o().o().op == 'Add' and \ + len(node.o().o(0).o().o().o().o().o().inputs[1].values.shape) == 1: + + if node.i().op == "Add": + inputTensor = node.inputs[0] # CLIP + else: + inputTensor = node.i().inputs[0] # UNet and VAE + + gammaNode = node.o().o().o().o().o().o().o() + index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True) + gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32) + constantGamma = gs.Constant("LayerNormGamma-" + str(nLayerNormPlugin), np.ascontiguousarray(gamma.reshape(-1))) # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!! 
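+ # The matching beta constant lives on the Add node that consumes this Mul; it is
+ # extracted the same way below, and both are flattened to 1-D for the LayerNorm plugin.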
+ + betaNode = gammaNode.o() + index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True) + beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32) + constantBeta = gs.Constant("LayerNormBeta-" + str(nLayerNormPlugin), np.ascontiguousarray(beta.reshape(-1))) + + inputList = [inputTensor, constantGamma, constantBeta] + layerNormV = gs.Variable("LayerNormV-" + str(nLayerNormPlugin), np.dtype(np.float32), inputTensor.shape) + layerNormN = gs.Node("LayerNorm", "LayerNormN-" + str(nLayerNormPlugin), inputs=inputList, attrs=OrderedDict([('epsilon', 1.e-5)]), outputs=[layerNormV]) + self.graph.nodes.append(layerNormN) + nLayerNormPlugin += 1 + + if betaNode.outputs[0] in self.graph.outputs: + index = self.graph.outputs.index(betaNode.outputs[0]) + self.graph.outputs[index] = layerNormV + else: + if betaNode.o().op == "Cast": + lastNode = betaNode.o() + else: + lastNode = betaNode + for subNode in self.graph.nodes: + if lastNode.outputs[0] in subNode.inputs: + index = subNode.inputs.index(lastNode.outputs[0]) + subNode.inputs[index] = layerNormV + lastNode.outputs = [] + + self.cleanup() + return nLayerNormPlugin + + def fuse_kv(self, node_k, node_v, fused_kv_idx, heads, num_dynamic=0): + # Get weights of K + weights_k = node_k.inputs[1].values + # Get weights of V + weights_v = node_v.inputs[1].values + # Input number of channels to K and V + C = weights_k.shape[0] + # Number of heads + H = heads + # Dimension per head + D = weights_k.shape[1] // H + + # Concat and interleave weights such that the output of fused KV GEMM has [b, s_kv, h, 2, d] shape + weights_kv = np.dstack([weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 2 * H * D) + + # K and V have the same input + input_tensor = node_k.inputs[0] + # K and V must have the same output which we feed into fmha plugin + output_tensor_k = node_k.outputs[0] + # Create tensor + constant_weights_kv = gs.Constant("Weights_KV_{}".format(fused_kv_idx), np.ascontiguousarray(weights_kv)) + + # Create fused KV node + fused_kv_node = gs.Node(op="MatMul", name="MatMul_KV_{}".format(fused_kv_idx), inputs=[input_tensor, constant_weights_kv], outputs=[output_tensor_k]) + self.graph.nodes.append(fused_kv_node) + + # Connect the output of fused node to the inputs of the nodes after K and V + node_v.o(num_dynamic).inputs[0] = output_tensor_k + node_k.o(num_dynamic).inputs[0] = output_tensor_k + for i in range(0,num_dynamic): + node_v.o().inputs.clear() + node_k.o().inputs.clear() + + # Clear inputs and outputs of K and V to ge these nodes cleared + node_k.outputs.clear() + node_v.outputs.clear() + node_k.inputs.clear() + node_v.inputs.clear() + + self.cleanup() + return fused_kv_node + + def insert_fmhca(self, node_q, node_kv, final_tranpose, mhca_idx, heads, num_dynamic=0): + # Get inputs and outputs for the fMHCA plugin + # We take an output of reshape that follows the Q GEMM + output_q = node_q.o(num_dynamic).o().inputs[0] + output_kv = node_kv.o().inputs[0] + output_final_tranpose = final_tranpose.outputs[0] + + # Clear the inputs of the nodes that follow the Q and KV GEMM + # to delete these subgraphs (it will be substituted by fMHCA plugin) + node_kv.outputs[0].outputs[0].inputs.clear() + node_kv.outputs[0].outputs[0].inputs.clear() + node_q.o(num_dynamic).o().inputs.clear() + for i in range(0,num_dynamic): + node_q.o(i).o().o(1).inputs.clear() + + weights_kv = node_kv.inputs[1].values + dims_per_head = weights_kv.shape[1] // (heads * 2) + + # Reshape dims + shape = 
gs.Constant("Shape_KV_{}".format(mhca_idx), np.ascontiguousarray(np.array([0, 0, heads, 2, dims_per_head], dtype=np.int64))) + + # Reshape output tensor + output_reshape = gs.Variable("ReshapeKV_{}".format(mhca_idx), np.dtype(np.float16), None) + # Create fMHA plugin + reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mhca_idx), inputs=[output_kv, shape], outputs=[output_reshape]) + # Insert node + self.graph.nodes.append(reshape) + + # Create fMHCA plugin + fmhca = gs.Node(op="fMHCA", name="fMHCA_{}".format(mhca_idx), inputs=[output_q, output_reshape], outputs=[output_final_tranpose]) + # Insert node + self.graph.nodes.append(fmhca) + + # Connect input of fMHCA to output of Q GEMM + node_q.o(num_dynamic).outputs[0] = output_q + + if num_dynamic > 0: + reshape2_input1_out = gs.Variable("Reshape2_fmhca{}_out".format(mhca_idx), np.dtype(np.int64), None) + reshape2_input1_shape = gs.Node("Shape", "Reshape2_fmhca{}_shape".format(mhca_idx), inputs=[node_q.inputs[0]], outputs=[reshape2_input1_out]) + self.graph.nodes.append(reshape2_input1_shape) + final_tranpose.o().inputs[1] = reshape2_input1_out + + # Clear outputs of transpose to get this subgraph cleared + final_tranpose.outputs.clear() + + self.cleanup() + + def fuse_qkv(self, node_q, node_k, node_v, fused_qkv_idx, heads, num_dynamic=0): + # Get weights of Q + weights_q = node_q.inputs[1].values + # Get weights of K + weights_k = node_k.inputs[1].values + # Get weights of V + weights_v = node_v.inputs[1].values + + # Input number of channels to Q, K and V + C = weights_k.shape[0] + # Number of heads + H = heads + # Hidden dimension per head + D = weights_k.shape[1] // H + + # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape + weights_qkv = np.dstack([weights_q.reshape(C, H, D), weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 3 * H * D) + + input_tensor = node_k.inputs[0] # K and V have the same input + # Q, K and V must have the same output which we feed into fmha plugin + output_tensor_k = node_k.outputs[0] + # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape + constant_weights_qkv = gs.Constant("Weights_QKV_{}".format(fused_qkv_idx), np.ascontiguousarray(weights_qkv)) + + # Created a fused node + fused_qkv_node = gs.Node(op="MatMul", name="MatMul_QKV_{}".format(fused_qkv_idx), inputs=[input_tensor, constant_weights_qkv], outputs=[output_tensor_k]) + self.graph.nodes.append(fused_qkv_node) + + # Connect the output of the fused node to the inputs of the nodes after Q, K and V + node_q.o(num_dynamic).inputs[0] = output_tensor_k + node_k.o(num_dynamic).inputs[0] = output_tensor_k + node_v.o(num_dynamic).inputs[0] = output_tensor_k + for i in range(0,num_dynamic): + node_q.o().inputs.clear() + node_k.o().inputs.clear() + node_v.o().inputs.clear() + + # Clear inputs and outputs of Q, K and V to ge these nodes cleared + node_q.outputs.clear() + node_k.outputs.clear() + node_v.outputs.clear() + + node_q.inputs.clear() + node_k.inputs.clear() + node_v.inputs.clear() + + self.cleanup() + return fused_qkv_node + + def insert_fmha(self, node_qkv, final_tranpose, mha_idx, heads, num_dynamic=0): + # Get inputs and outputs for the fMHA plugin + output_qkv = node_qkv.o().inputs[0] + output_final_tranpose = final_tranpose.outputs[0] + + # Clear the inputs of the nodes that follow the QKV GEMM + # to delete these subgraphs (it will be substituted by fMHA plugin) + node_qkv.outputs[0].outputs[2].inputs.clear() + 
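+ # (after fuse_qkv the same fused GEMM output feeds the former Q, K and V reshape
+ # branches, so all three consumers are detached here)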
node_qkv.outputs[0].outputs[1].inputs.clear() + node_qkv.outputs[0].outputs[0].inputs.clear() + + weights_qkv = node_qkv.inputs[1].values + dims_per_head = weights_qkv.shape[1] // (heads * 3) + + # Reshape dims + shape = gs.Constant("Shape_QKV_{}".format(mha_idx), np.ascontiguousarray(np.array([0, 0, heads, 3, dims_per_head], dtype=np.int64))) + + # Reshape output tensor + output_shape = gs.Variable("ReshapeQKV_{}".format(mha_idx), np.dtype(np.float16), None) + # Create fMHA plugin + reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mha_idx), inputs=[output_qkv, shape], outputs=[output_shape]) + # Insert node + self.graph.nodes.append(reshape) + + # Create fMHA plugin + fmha = gs.Node(op="fMHA_V2", name="fMHA_{}".format(mha_idx), inputs=[output_shape], outputs=[output_final_tranpose]) + # Insert node + self.graph.nodes.append(fmha) + + if num_dynamic > 0: + reshape2_input1_out = gs.Variable("Reshape2_{}_out".format(mha_idx), np.dtype(np.int64), None) + reshape2_input1_shape = gs.Node("Shape", "Reshape2_{}_shape".format(mha_idx), inputs=[node_qkv.inputs[0]], outputs=[reshape2_input1_out]) + self.graph.nodes.append(reshape2_input1_shape) + final_tranpose.o().inputs[1] = reshape2_input1_out + + # Clear outputs of transpose to get this subgraph cleared + final_tranpose.outputs.clear() + + self.cleanup() + + def mha_mhca_detected(self, node, mha): + # Go from V GEMM down to the S*V MatMul and all way up to K GEMM + # If we are looking for MHCA inputs of two matmuls (K and V) must be equal. + # If we are looking for MHA inputs (K and V) must be not equal. + if node.op == "MatMul" and len(node.outputs) == 1 and \ + ((mha and len(node.inputs[0].inputs) > 0 and node.i().op == "Add") or \ + (not mha and len(node.inputs[0].inputs) == 0)): + + if node.o().op == 'Shape': + if node.o(1).op == 'Shape': + num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2 + else: + num_dynamic_kv = 1 + # For Cross-Attention, if batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well + num_dynamic_q = num_dynamic_kv if mha else num_dynamic_kv + 1 + else: + num_dynamic_kv = 0 + num_dynamic_q = 0 + + o = node.o(num_dynamic_kv) + if o.op == "Reshape" and \ + o.o().op == "Transpose" and \ + o.o().o().op == "Reshape" and \ + o.o().o().o().op == "MatMul" and \ + o.o().o().o().i(0).op == "Softmax" and \ + o.o().o().o().i(1).op == "Reshape" and \ + o.o().o().o().i(0).i().op == "Mul" and \ + o.o().o().o().i(0).i().i().op == "MatMul" and \ + o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \ + o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \ + o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \ + o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \ + o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \ + o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \ + node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name: + # "len(node.outputs) == 1" to make sure we are not in the already fused node + node_q = o.o().o().o().i(0).i().i().i(0).i().i().i() + node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i() + node_v = node + final_tranpose = o.o().o().o().o(num_dynamic_q).o() + # Sanity check to make sure that the graph looks like expected + if node_q.op == "MatMul" and final_tranpose.op == "Transpose": + return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose + return False, 0, 0, None, None, None, None + + def fuse_kv_insert_fmhca(self, heads, mhca_index, sm): + nodes = self.graph.nodes + # Iterate over graph and search 
for MHCA pattern + for idx, _ in enumerate(nodes): + # fMHCA can't be at the 2 last layers of the network. It is a guard from OOB + if idx + 1 > len(nodes) or idx + 2 > len(nodes): + continue + + # Get anchor nodes for fusion and fMHCA plugin insertion if the MHCA is detected + detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \ + self.mha_mhca_detected(nodes[idx], mha=False) + if detected: + assert num_dynamic_q == 0 or num_dynamic_q == num_dynamic_kv + 1 + # Skip the FMHCA plugin for SM75 except for when the dim per head is 40. + if sm == 75 and node_q.inputs[1].shape[1] // heads == 160: + continue + # Fuse K and V GEMMS + node_kv = self.fuse_kv(node_k, node_v, mhca_index, heads, num_dynamic_kv) + # Insert fMHCA plugin + self.insert_fmhca(node_q, node_kv, final_tranpose, mhca_index, heads, num_dynamic_q) + return True + return False + + def fuse_qkv_insert_fmha(self, heads, mha_index): + nodes = self.graph.nodes + # Iterate over graph and search for MHA pattern + for idx, _ in enumerate(nodes): + # fMHA can't be at the 2 last layers of the network. It is a guard from OOB + if idx + 1 > len(nodes) or idx + 2 > len(nodes): + continue + + # Get anchor nodes for fusion and fMHA plugin insertion if the MHA is detected + detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \ + self.mha_mhca_detected(nodes[idx], mha=True) + if detected: + assert num_dynamic_q == num_dynamic_kv + # Fuse Q, K and V GEMMS + node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv) + # Insert fMHA plugin + self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv) + return True + return False + + def insert_fmhca_plugin(self, num_heads, sm): + mhca_index = 0 + while self.fuse_kv_insert_fmhca(num_heads, mhca_index, sm): + mhca_index += 1 + return mhca_index + + def insert_fmha_plugin(self, num_heads): + mha_index = 0 + while self.fuse_qkv_insert_fmha(num_heads, mha_index): + mha_index += 1 + return mha_index diff --git a/rfdetr/deploy/_onnx/symbolic.py b/rfdetr/deploy/_onnx/symbolic.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa8c6f4086af07f32d56ba4373c7d7b44671686 --- /dev/null +++ b/rfdetr/deploy/_onnx/symbolic.py @@ -0,0 +1,37 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. 
+# ------------------------------------------------------------------------ +""" +CustomOpSymbolicRegistry class +""" +from copy import deepcopy + +import onnx +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.onnx import register_custom_op_symbolic +from torch.onnx.symbolic_helper import parse_args +from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes +from torch.autograd import Function + + +class CustomOpSymbolicRegistry: + # _SYMBOLICS = {} + _OPTIMIZER = [] + + @classmethod + def optimizer(cls, fn): + cls._OPTIMIZER.append(fn) + + +def register_optimizer(): + def optimizer_wrapper(fn): + CustomOpSymbolicRegistry.optimizer(fn) + return fn + return optimizer_wrapper \ No newline at end of file diff --git a/rfdetr/deploy/benchmark.py b/rfdetr/deploy/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bc27be8ec278028d0db4d40fe8ba122cd7f6e78d --- /dev/null +++ b/rfdetr/deploy/benchmark.py @@ -0,0 +1,590 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +This tool provides performance benchmarks by using ONNX Runtime and TensorRT +to run inference on a given model with the COCO validation set. It offers +reliable measurements of inference latency using ONNX Runtime or TensorRT +on the device. +""" +import argparse +import copy +import contextlib +import datetime +import json +import os +import os.path as osp +import random +import time +import ast +from pathlib import Path +from collections import namedtuple, OrderedDict + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +import numpy as np +from PIL import Image +import torch +from torch.utils.data import DataLoader, DistributedSampler +import torchvision.transforms as T +import torchvision.transforms.functional as F +import tqdm + +import pycuda.driver as cuda +import pycuda.autoinit +import onnxruntime as nxrun +import tensorrt as trt + + +def parser_args(): + parser = argparse.ArgumentParser('performance benchmark tool for onnx/trt model') + parser.add_argument('--path', type=str, help='engine file path') + parser.add_argument('--coco_path', type=str, default="data/coco", help='coco dataset path') + parser.add_argument('--device', default=0, type=int) + parser.add_argument('--run_benchmark', action='store_true', help='repeat the inference to benchmark the latency') + parser.add_argument('--disable_eval', action='store_true', help='disable evaluation') + return parser.parse_args() + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = COCO(coco_gt) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + 
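+ # prepare() converts the predictions into COCO-format result dicts for the given IoU type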
results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, 'w') as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + +def evaluate(self): + ''' + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + ''' + # Running per image evaluation... + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + return p.imgIds, evalImgs + +def convert_to_xywh(boxes): + boxes[:, 2:] -= boxes[:, :2] + return boxes + + +def get_image_list(ann_file): + with open(ann_file, 'r') as fin: + data = json.load(fin) + return data['images'] + + +def load_image(file_path): + return Image.open(file_path).convert("RGB") + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class SquareResize(object): + def __init__(self, sizes): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + rescaled_img=F.resize(img, (size, size)) + w, h = rescaled_img.size + if target is None: + return rescaled_img, None + ratios = tuple( + float(s) / float(s_orig) for s, s_orig in zip(rescaled_img.size, img.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + target["size"] = torch.tensor([h, w]) + + return rescaled_img, target + + +def infer_transforms(): + normalize = Compose([ + ToTensor(), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + return Compose([ + SquareResize([640]), + normalize, + ]) + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w.clamp(min=0.0)), (y_c - 0.5 * h.clamp(min=0.0)), + (x_c + 0.5 * 
w.clamp(min=0.0)), (y_c + 0.5 * h.clamp(min=0.0))] + return torch.stack(b, dim=-1) + + +def post_process(outputs, target_sizes): + out_logits, out_bbox = outputs['labels'], outputs['dets'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + + return results + + +def infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device, repeats=1): + time_list = [] + for img_dict in tqdm.tqdm(img_list): + image = load_image(os.path.join(prefix, img_dict['file_name'])) + width, height = image.size + orig_target_sizes = torch.Tensor([height, width]) + image_tensor, _ = infer_transforms()(image, None) # target is None + + samples = image_tensor[None].numpy() + + time_profile.reset() + with time_profile: + for _ in range(repeats): + res = sess.run(None, {"input": samples}) + time_list.append(time_profile.total / repeats) + outputs = {} + outputs['labels'] = torch.Tensor(res[1]).to(device) + outputs['dets'] = torch.Tensor(res[0]).to(device) + + orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device) + results = post_process(outputs, orig_target_sizes) + res = {img_dict['id']: results[0]} + if coco_evaluator is not None: + coco_evaluator.update(res) + + print("Model latency with ONNX Runtime: {}ms".format(1000 * sum(time_list) / len(img_list))) + + # accumulate predictions from all images + stats = {} + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + coco_evaluator.accumulate() + coco_evaluator.summarize() + stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() + print(stats) + + +def infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device, repeats=1): + time_list = [] + for img_dict in tqdm.tqdm(img_list): + image = load_image(os.path.join(prefix, img_dict['file_name'])) + width, height = image.size + orig_target_sizes = torch.Tensor([height, width]) + image_tensor, _ = infer_transforms()(image, None) # target is None + + samples = image_tensor[None].to(device) + _, _, h, w = samples.shape + im_shape = torch.Tensor(np.array([h, w]).reshape((1, 2)).astype(np.float32)).to(device) + scale_factor = torch.Tensor(np.array([h / height, w / width]).reshape((1, 2)).astype(np.float32)).to(device) + + time_profile.reset() + with time_profile: + for _ in range(repeats): + outputs = model({"input": samples}) + + time_list.append(time_profile.total / repeats) + orig_target_sizes = torch.stack([orig_target_sizes], dim=0).to(device) + if coco_evaluator is not None: + results = post_process(outputs, orig_target_sizes) + res = {img_dict['id']: results[0]} + coco_evaluator.update(res) + + print("Model latency with TensorRT: {}ms".format(1000 * sum(time_list) / len(img_list))) + + # accumulate predictions from all images + stats = {} + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + coco_evaluator.accumulate() + 
coco_evaluator.summarize() + stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() + print(stats) + + +class TRTInference(object): + """TensorRT inference engine + """ + def __init__(self, engine_path='dino.engine', device='cuda:0', sync_mode:bool=False, max_batch_size=32, verbose=False): + self.engine_path = engine_path + self.device = device + self.sync_mode = sync_mode + self.max_batch_size = max_batch_size + + self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO) + + self.engine = self.load_engine(engine_path) + + self.context = self.engine.create_execution_context() + + self.bindings = self.get_bindings(self.engine, self.context, self.max_batch_size, self.device) + self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self.bindings.items()) + + self.input_names = self.get_input_names() + self.output_names = self.get_output_names() + + if not self.sync_mode: + self.stream = cuda.Stream() + + # self.time_profile = TimeProfiler() + self.time_profile = None + + def get_dummy_input(self, batch_size:int): + blob = {} + for name, binding in self.bindings.items(): + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + print(f"make dummy input {name} with shape {binding.shape}") + blob[name] = torch.rand(batch_size, *binding.shape[1:]).float().to('cuda:0') + return blob + + def load_engine(self, path): + '''load engine + ''' + trt.init_libnvinfer_plugins(self.logger, '') + with open(path, 'rb') as f, trt.Runtime(self.logger) as runtime: + return runtime.deserialize_cuda_engine(f.read()) + + def get_input_names(self, ): + names = [] + for _, name in enumerate(self.engine): + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + names.append(name) + return names + + def get_output_names(self, ): + names = [] + for _, name in enumerate(self.engine): + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: + names.append(name) + return names + + def get_bindings(self, engine, context, max_batch_size=32, device=None): + '''build binddings + ''' + Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) + bindings = OrderedDict() + + for i, name in enumerate(engine): + shape = engine.get_tensor_shape(name) + dtype = trt.nptype(engine.get_tensor_dtype(name)) + + if shape[0] == -1: + raise NotImplementedError + + if False: + if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + data = np.random.randn(*shape).astype(dtype) + ptr = cuda.mem_alloc(data.nbytes) + bindings[name] = Binding(name, dtype, shape, data, ptr) + else: + data = cuda.pagelocked_empty(trt.volume(shape), dtype) + ptr = cuda.mem_alloc(data.nbytes) + bindings[name] = Binding(name, dtype, shape, data, ptr) + + else: + data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, data, data.data_ptr()) + + return bindings + + def run_sync(self, blob): + self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names}) + self.context.execute_v2(list(self.bindings_addr.values())) + outputs = {n: self.bindings[n].data for n in self.output_names} + return outputs + + def run_async(self, blob): + self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names}) + bindings_addr = [int(v) for _, v in self.bindings_addr.items()] + self.context.execute_async_v2(bindings=bindings_addr, stream_handle=self.stream.handle) + outputs = {n: self.bindings[n].data for n in self.output_names} + self.stream.synchronize() + return outputs + + def __call__(self, blob): + if 
self.sync_mode: + return self.run_sync(blob) + else: + return self.run_async(blob) + + def synchronize(self, ): + if not self.sync_mode and torch.cuda.is_available(): + torch.cuda.synchronize() + elif self.sync_mode: + self.stream.synchronize() + + def speed(self, blob, n): + self.time_profile.reset() + with self.time_profile: + for _ in range(n): + _ = self(blob) + return self.time_profile.total / n + + + def build_engine(self, onnx_file_path, engine_file_path, max_batch_size=32): + '''Takes an ONNX file and creates a TensorRT engine to run inference with + http://gitlab.baidu.com/paddle-inference/benchmark/blob/main/backend_trt.py#L57 + ''' + EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + with trt.Builder(self.logger) as builder, \ + builder.create_network(EXPLICIT_BATCH) as network, \ + trt.OnnxParser(network, self.logger) as parser, \ + builder.create_builder_config() as config: + + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1024 MiB + config.set_flag(trt.BuilderFlag.FP16) + + with open(onnx_file_path, 'rb') as model: + if not parser.parse(model.read()): + print('ERROR: Failed to parse the ONNX file.') + for error in range(parser.num_errors): + print(parser.get_error(error)) + return None + + serialized_engine = builder.build_serialized_network(network, config) + with open(engine_file_path, 'wb') as f: + f.write(serialized_engine) + + return serialized_engine + + +class TimeProfiler(contextlib.ContextDecorator): + def __init__(self, ): + self.total = 0 + + def __enter__(self, ): + self.start = self.time() + return self + + def __exit__(self, type, value, traceback): + self.total += self.time() - self.start + + def reset(self, ): + self.total = 0 + + def time(self, ): + if torch.cuda.is_available(): + torch.cuda.synchronize() + return time.perf_counter() + + +def main(args): + print(args) + + coco_gt = osp.join(args.coco_path, 'annotations/instances_val2017.json') + img_list = get_image_list(coco_gt) + prefix = osp.join(args.coco_path, 'val2017') + if args.run_benchmark: + repeats = 10 + print('Inference for each image will be repeated 10 times to obtain ' + 'a reliable measurement of inference latency.') + else: + repeats = 1 + + if args.disable_eval: + coco_evaluator = None + else: + coco_evaluator = CocoEvaluator(coco_gt, ('bbox',)) + + time_profile = TimeProfiler() + + if args.path.endswith(".onnx"): + sess = nxrun.InferenceSession(args.path, providers=['CUDAExecutionProvider']) + infer_onnx(sess, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats) + elif args.path.endswith(".engine"): + model = TRTInference(args.path, sync_mode=True, device=f'cuda:{args.device}') + infer_engine(model, coco_evaluator, time_profile, prefix, img_list, device=f'cuda:{args.device}', repeats=repeats) + else: + raise NotImplementedError('Only model file names ending with ".onnx" and ".engine" are supported.') + + +if __name__ == '__main__': + args = parser_args() + main(args) diff --git a/rfdetr/deploy/export.py b/rfdetr/deploy/export.py new file mode 100644 index 0000000000000000000000000000000000000000..9a501887e95a3a9c9c428c1960285e542268a1a9 --- /dev/null +++ b/rfdetr/deploy/export.py @@ -0,0 +1,276 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +export ONNX model and TensorRT engine for deployment +""" +import os +import ast +import random +import argparse +import subprocess +import torch.nn as nn +from pathlib import Path +import time +from collections import defaultdict + +import onnx +import torch +import onnxsim +import numpy as np +from PIL import Image + +import rfdetr.util.misc as utils +import rfdetr.datasets.transforms as T +from rfdetr.models import build_model +from rfdetr.deploy._onnx import OnnxOptimizer +import re +import sys + + +def run_command_shell(command, dry_run:bool = False) -> int: + if dry_run: + print("") + print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']} {command}") + print("") + try: + result = subprocess.run(command, shell=True, capture_output=True, text=True) + return result + except subprocess.CalledProcessError as e: + print(f"Command failed with exit code {e.returncode}") + print(f"Error output:\n{e.stderr.decode('utf-8')}") + raise + + +def make_infer_image(infer_dir, shape, batch_size, device="cuda"): + if infer_dir is None: + dummy = np.random.randint(0, 256, (shape[0], shape[1], 3), dtype=np.uint8) + image = Image.fromarray(dummy, mode="RGB") + else: + image = Image.open(infer_dir).convert("RGB") + + transforms = T.Compose([ + T.SquareResize([shape[0]]), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + inps, _ = transforms(image, None) + inps = inps.to(device) + # inps = utils.nested_tensor_from_tensor_list([inps for _ in range(args.batch_size)]) + inps = torch.stack([inps for _ in range(batch_size)]) + return inps + +def export_onnx(output_dir, model, input_names, input_tensors, output_names, dynamic_axes, backbone_only=False, verbose=True, opset_version=17): + export_name = "backbone_model" if backbone_only else "inference_model" + output_file = os.path.join(output_dir, f"{export_name}.onnx") + + # Prepare model for export + if hasattr(model, "export"): + model.export() + + torch.onnx.export( + model, + input_tensors, + output_file, + input_names=input_names, + output_names=output_names, + export_params=True, + keep_initializers_as_inputs=False, + do_constant_folding=True, + verbose=verbose, + opset_version=opset_version, + dynamic_axes=dynamic_axes) + + print(f'\nSuccessfully exported ONNX model: {output_file}') + return output_file + + +def onnx_simplify(onnx_dir:str, input_names, input_tensors, force=False): + sim_onnx_dir = onnx_dir.replace(".onnx", ".sim.onnx") + if os.path.isfile(sim_onnx_dir) and not force: + return sim_onnx_dir + + if isinstance(input_tensors, torch.Tensor): + input_tensors = [input_tensors] + + print(f'start simplify ONNX model: {onnx_dir}') + opt = OnnxOptimizer(onnx_dir) + opt.info('Model: original') + opt.common_opt() + opt.info('Model: optimized') + opt.save_onnx(sim_onnx_dir) + input_dict = {name: tensor.detach().cpu().numpy() for name, tensor in zip(input_names, input_tensors)} + model_opt, check_ok = onnxsim.simplify( + onnx_dir, + check_n = 3, + input_data=input_dict, + dynamic_input_shape=False) + if check_ok: + onnx.save(model_opt, sim_onnx_dir) + else: + raise RuntimeError("Failed to simplify ONNX model.") + print(f'Successfully simplified ONNX model: 
{sim_onnx_dir}') + return sim_onnx_dir + + +def trtexec(onnx_dir:str, args) -> None: + engine_dir = onnx_dir.replace(".onnx", f".engine") + + # Base trtexec command + trt_command = " ".join([ + "trtexec", + f"--onnx={onnx_dir}", + f"--saveEngine={engine_dir}", + f"--memPoolSize=workspace:4096 --fp16", + f"--useCudaGraph --useSpinWait --warmUp=500 --avgRuns=1000 --duration=10", + f"{'--verbose' if args.verbose else ''}"]) + + if args.profile: + profile_dir = onnx_dir.replace(".onnx", f".nsys-rep") + # Wrap with nsys profile command + command = " ".join([ + "nsys profile", + f"--output={profile_dir}", + "--trace=cuda,nvtx", + "--force-overwrite true", + trt_command + ]) + print(f'Profile data will be saved to: {profile_dir}') + else: + command = trt_command + + output = run_command_shell(command, args.dry_run) + stats = parse_trtexec_output(output.stdout) + +def parse_trtexec_output(output_text): + print(output_text) + # Common patterns in trtexec output + gpu_compute_pattern = r"GPU Compute Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms, median = (\d+\.\d+) ms" + h2d_pattern = r"Host to Device Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms" + d2h_pattern = r"Device to Host Transfer Time: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms" + latency_pattern = r"Latency: min = (\d+\.\d+) ms, max = (\d+\.\d+) ms, mean = (\d+\.\d+) ms" + throughput_pattern = r"Throughput: (\d+\.\d+) qps" + + stats = {} + + # Extract compute times + if match := re.search(gpu_compute_pattern, output_text): + stats.update({ + 'compute_min_ms': float(match.group(1)), + 'compute_max_ms': float(match.group(2)), + 'compute_mean_ms': float(match.group(3)), + 'compute_median_ms': float(match.group(4)) + }) + + # Extract H2D times + if match := re.search(h2d_pattern, output_text): + stats.update({ + 'h2d_min_ms': float(match.group(1)), + 'h2d_max_ms': float(match.group(2)), + 'h2d_mean_ms': float(match.group(3)) + }) + + # Extract D2H times + if match := re.search(d2h_pattern, output_text): + stats.update({ + 'd2h_min_ms': float(match.group(1)), + 'd2h_max_ms': float(match.group(2)), + 'd2h_mean_ms': float(match.group(3)) + }) + + if match := re.search(latency_pattern, output_text): + stats.update({ + 'latency_min_ms': float(match.group(1)), + 'latency_max_ms': float(match.group(2)), + 'latency_mean_ms': float(match.group(3)) + }) + + # Extract throughput + if match := re.search(throughput_pattern, output_text): + stats['throughput_qps'] = float(match.group(1)) + + return stats + +def no_batch_norm(model): + for module in model.modules(): + if isinstance(module, nn.BatchNorm2d): + raise ValueError("BatchNorm2d found in the model. 
Please remove it.") + +def main(args): + print("git:\n {}\n".format(utils.get_sha())) + print(args) + # convert device to device_id + if args.device == 'cuda': + device_id = "0" + elif args.device == 'cpu': + device_id = "" + else: + device_id = str(int(args.device)) + args.device = f"cuda:{device_id}" + + # device for export onnx + # TODO: export onnx with cuda failed with onnx error + device = torch.device("cpu") + os.environ["CUDA_VISIBLE_DEVICES"] = device_id + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + model, criterion, postprocessors = build_model(args) + n_parameters = sum(p.numel() for p in model.parameters()) + print(f"number of parameters: {n_parameters}") + n_backbone_parameters = sum(p.numel() for p in model.backbone.parameters()) + print(f"number of backbone parameters: {n_backbone_parameters}") + n_projector_parameters = sum(p.numel() for p in model.backbone[0].projector.parameters()) + print(f"number of projector parameters: {n_projector_parameters}") + n_backbone_encoder_parameters = sum(p.numel() for p in model.backbone[0].encoder.parameters()) + print(f"number of backbone encoder parameters: {n_backbone_encoder_parameters}") + n_transformer_parameters = sum(p.numel() for p in model.transformer.parameters()) + print(f"number of transformer parameters: {n_transformer_parameters}") + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint['model'], strict=True) + print(f"load checkpoints {args.resume}") + + if args.layer_norm: + no_batch_norm(model) + + model.to(device) + + input_tensors = make_infer_image(args, device) + input_names = ['input'] + output_names = ['features'] if args.backbone_only else ['dets', 'labels'] + dynamic_axes = None + # Run model inference in pytorch mode + model.eval().to("cuda") + input_tensors = input_tensors.to("cuda") + with torch.no_grad(): + if args.backbone_only: + features = model(input_tensors) + print(f"PyTorch inference output shape: {features.shape}") + else: + outputs = model(input_tensors) + dets = outputs['pred_boxes'] + labels = outputs['pred_logits'] + print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}") + model.cpu() + input_tensors = input_tensors.cpu() + + + output_file = export_onnx(model, args, input_names, input_tensors, output_names, dynamic_axes) + + if args.simplify: + output_file = onnx_simplify(output_file, input_names, input_tensors, args) + + if args.tensorrt: + output_file = trtexec(output_file, args) diff --git a/rfdetr/detr.py b/rfdetr/detr.py new file mode 100644 index 0000000000000000000000000000000000000000..55c514c782973c3dc7525e8a63f99edc64fda4e6 --- /dev/null +++ b/rfdetr/detr.py @@ -0,0 +1,451 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + + +import json +import os +from collections import defaultdict +from logging import getLogger +from typing import Union, List +from copy import deepcopy + +import numpy as np +import supervision as sv +import torch +import torchvision.transforms.functional as F +from PIL import Image + +try: + torch.set_float32_matmul_precision('high') +except: + pass + +from rfdetr.config import ( + RFDETRBaseConfig, + RFDETRLargeConfig, + RFDETRNanoConfig, + RFDETRSmallConfig, + RFDETRMediumConfig, + TrainConfig, + ModelConfig +) +from rfdetr.main import Model, download_pretrain_weights +from rfdetr.util.metrics import MetricsPlotSink, MetricsTensorBoardSink, MetricsWandBSink +from rfdetr.util.coco_classes import COCO_CLASSES + +logger = getLogger(__name__) +class RFDETR: + """ + The base RF-DETR class implements the core methods for training RF-DETR models, + running inference on the models, optimising models, and uploading trained + models for deployment. + """ + means = [0.485, 0.456, 0.406] + stds = [0.229, 0.224, 0.225] + size = None + + def __init__(self, **kwargs): + self.model_config = self.get_model_config(**kwargs) + self.maybe_download_pretrain_weights() + self.model = self.get_model(self.model_config) + self.callbacks = defaultdict(list) + + self.model.inference_model = None + self._is_optimized_for_inference = False + self._has_warned_about_not_being_optimized_for_inference = False + self._optimized_has_been_compiled = False + self._optimized_batch_size = None + self._optimized_resolution = None + self._optimized_dtype = None + + def maybe_download_pretrain_weights(self): + """ + Download pre-trained weights if they are not already downloaded. + """ + download_pretrain_weights(self.model_config.pretrain_weights) + + def get_model_config(self, **kwargs): + """ + Retrieve the configuration parameters used by the model. + """ + return ModelConfig(**kwargs) + + def train(self, **kwargs): + """ + Train an RF-DETR model. + """ + config = self.get_train_config(**kwargs) + self.train_from_config(config, **kwargs) + + def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32): + self.remove_optimized_model() + + self.model.inference_model = deepcopy(self.model.model) + self.model.inference_model.eval() + self.model.inference_model.export() + + self._optimized_resolution = self.model.resolution + self._is_optimized_for_inference = True + + self.model.inference_model = self.model.inference_model.to(dtype=dtype) + self._optimized_dtype = dtype + + if compile: + self.model.inference_model = torch.jit.trace( + self.model.inference_model, + torch.randn( + batch_size, 3, self.model.resolution, self.model.resolution, + device=self.model.device, + dtype=dtype + ) + ) + self._optimized_has_been_compiled = True + self._optimized_batch_size = batch_size + + def remove_optimized_model(self): + self.model.inference_model = None + self._is_optimized_for_inference = False + self._optimized_has_been_compiled = False + self._optimized_batch_size = None + self._optimized_resolution = None + self._optimized_half = False + + def export(self, **kwargs): + """ + Export your model to an ONNX file. + + See [the ONNX export documentation](https://rfdetr.roboflow.com/learn/train/#onnx-export) for more information. 
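+ 
+ Example (an illustrative sketch; keyword arguments, if any, are forwarded unchanged
+ to the underlying Model.export, whose defaults are assumed to be acceptable here):
+ 
+     from rfdetr.detr import RFDETRBase
+ 
+     model = RFDETRBase()
+     model.export()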
+ """ + self.model.export(**kwargs) + + def train_from_config(self, config: TrainConfig, **kwargs): + with open( + os.path.join(config.dataset_dir, "train", "_annotations.coco.json"), "r" + ) as f: + anns = json.load(f) + num_classes = len(anns["categories"]) + class_names = [c["name"] for c in anns["categories"] if c["supercategory"] != "none"] + self.model.class_names = class_names + + if self.model_config.num_classes != num_classes: + logger.warning( + f"num_classes mismatch: model has {self.model_config.num_classes} classes, but your dataset has {num_classes} classes\n" + f"reinitializing your detection head with {num_classes} classes." + ) + self.model.reinitialize_detection_head(num_classes) + + + train_config = config.dict() + model_config = self.model_config.dict() + model_config.pop("num_classes") + if "class_names" in model_config: + model_config.pop("class_names") + + if "class_names" in train_config and train_config["class_names"] is None: + train_config["class_names"] = class_names + + for k, v in train_config.items(): + if k in model_config: + model_config.pop(k) + if k in kwargs: + kwargs.pop(k) + + all_kwargs = {**model_config, **train_config, **kwargs, "num_classes": num_classes} + + metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir) + self.callbacks["on_fit_epoch_end"].append(metrics_plot_sink.update) + self.callbacks["on_train_end"].append(metrics_plot_sink.save) + + if config.tensorboard: + metrics_tensor_board_sink = MetricsTensorBoardSink(output_dir=config.output_dir) + self.callbacks["on_fit_epoch_end"].append(metrics_tensor_board_sink.update) + self.callbacks["on_train_end"].append(metrics_tensor_board_sink.close) + + if config.wandb: + metrics_wandb_sink = MetricsWandBSink( + output_dir=config.output_dir, + project=config.project, + run=config.run, + config=config.model_dump() + ) + self.callbacks["on_fit_epoch_end"].append(metrics_wandb_sink.update) + self.callbacks["on_train_end"].append(metrics_wandb_sink.close) + + if config.early_stopping: + from rfdetr.util.early_stopping import EarlyStoppingCallback + early_stopping_callback = EarlyStoppingCallback( + model=self.model, + patience=config.early_stopping_patience, + min_delta=config.early_stopping_min_delta, + use_ema=config.early_stopping_use_ema + ) + self.callbacks["on_fit_epoch_end"].append(early_stopping_callback.update) + + self.model.train( + **all_kwargs, + callbacks=self.callbacks, + ) + + def get_train_config(self, **kwargs): + """ + Retrieve the configuration parameters that will be used for training. + """ + return TrainConfig(**kwargs) + + def get_model(self, config: ModelConfig): + """ + Retrieve a model instance based on the provided configuration. + """ + return Model(**config.dict()) + + # Get class_names from the model + @property + def class_names(self): + """ + Retrieve the class names supported by the loaded model. + + Returns: + dict: A dictionary mapping class IDs to class names. The keys are integers starting from + """ + if hasattr(self.model, 'class_names') and self.model.class_names: + return {i+1: name for i, name in enumerate(self.model.class_names)} + + return COCO_CLASSES + + def predict( + self, + images: Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]], + threshold: float = 0.5, + **kwargs, + ) -> Union[sv.Detections, List[sv.Detections]]: + """Performs object detection on the input images and returns bounding box + predictions. 
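+ 
+ Example (an illustrative sketch; "image.jpg" is a placeholder path to a local RGB image):
+ 
+     from rfdetr.detr import RFDETRBase
+ 
+     model = RFDETRBase()
+     model.optimize_for_inference()  # optional, lowers per-call latency
+     detections = model.predict("image.jpg", threshold=0.5)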
+ + This method accepts a single image or a list of images in various formats + (file path, PIL Image, NumPy array, or torch.Tensor). The images should be in + RGB channel order. If a torch.Tensor is provided, it must already be normalized + to values in the [0, 1] range and have the shape (C, H, W). + + Args: + images (Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]]): + A single image or a list of images to process. Images can be provided + as file paths, PIL Images, NumPy arrays, or torch.Tensors. + threshold (float, optional): + The minimum confidence score needed to consider a detected bounding box valid. + **kwargs: + Additional keyword arguments. + + Returns: + Union[sv.Detections, List[sv.Detections]]: A single or multiple Detections + objects, each containing bounding box coordinates, confidence scores, + and class IDs. + """ + if not self._is_optimized_for_inference and not self._has_warned_about_not_being_optimized_for_inference: + logger.warning( + "Model is not optimized for inference. " + "Latency may be higher than expected. " + "You can optimize the model for inference by calling model.optimize_for_inference()." + ) + self._has_warned_about_not_being_optimized_for_inference = True + + self.model.model.eval() + + if not isinstance(images, list): + images = [images] + + orig_sizes = [] + processed_images = [] + + for img in images: + + if isinstance(img, str): + img = Image.open(img) + + if not isinstance(img, torch.Tensor): + img = F.to_tensor(img) + + if (img > 1).any(): + raise ValueError( + "Image has pixel values above 1. Please ensure the image is " + "normalized (scaled to [0, 1])." + ) + if img.shape[0] != 3: + raise ValueError( + f"Invalid image shape. Expected 3 channels (RGB), but got " + f"{img.shape[0]} channels." + ) + img_tensor = img + + h, w = img_tensor.shape[1:] + orig_sizes.append((h, w)) + + img_tensor = img_tensor.to(self.model.device) + img_tensor = F.normalize(img_tensor, self.means, self.stds) + img_tensor = F.resize(img_tensor, (self.model.resolution, self.model.resolution)) + + processed_images.append(img_tensor) + + batch_tensor = torch.stack(processed_images) + + if self._is_optimized_for_inference: + if self._optimized_resolution != batch_tensor.shape[2]: + # this could happen if someone manually changes self.model.resolution after optimizing the model + raise ValueError(f"Resolution mismatch. " + f"Model was optimized for resolution {self._optimized_resolution}, " + f"but got {batch_tensor.shape[2]}. " + "You can explicitly remove the optimized model by calling model.remove_optimized_model().") + if self._optimized_has_been_compiled: + if self._optimized_batch_size != batch_tensor.shape[0]: + raise ValueError(f"Batch size mismatch. " + f"Optimized model was compiled for batch size {self._optimized_batch_size}, " + f"but got {batch_tensor.shape[0]}. " + "You can explicitly remove the optimized model by calling model.remove_optimized_model(). 
" + "Alternatively, you can recompile the optimized model for a different batch size " + "by calling model.optimize_for_inference(batch_size=).") + + with torch.inference_mode(): + if self._is_optimized_for_inference: + predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype)) + else: + predictions = self.model.model(batch_tensor) + if isinstance(predictions, tuple): + predictions = { + "pred_logits": predictions[1], + "pred_boxes": predictions[0] + } + target_sizes = torch.tensor(orig_sizes, device=self.model.device) + results = self.model.postprocessors["bbox"](predictions, target_sizes=target_sizes) + + detections_list = [] + for result in results: + scores = result["scores"] + labels = result["labels"] + boxes = result["boxes"] + + keep = scores > threshold + scores = scores[keep] + labels = labels[keep] + boxes = boxes[keep] + + detections = sv.Detections( + xyxy=boxes.float().cpu().numpy(), + confidence=scores.float().cpu().numpy(), + class_id=labels.cpu().numpy(), + ) + detections_list.append(detections) + + return detections_list if len(detections_list) > 1 else detections_list[0] + + def deploy_to_roboflow(self, workspace: str, project_id: str, version: str, api_key: str = None, size: str = None): + """ + Deploy the trained RF-DETR model to Roboflow. + + Deploying with Roboflow will create a Serverless API to which you can make requests. + + You can also download weights into a Roboflow Inference deployment for use in Roboflow Workflows and on-device deployment. + + Args: + workspace (str): The name of the Roboflow workspace to deploy to. + project_ids (List[str]): A list of project IDs to which the model will be deployed + api_key (str, optional): Your Roboflow API key. If not provided, + it will be read from the environment variable `ROBOFLOW_API_KEY`. + size (str, optional): The size of the model to deploy. If not provided, + it will default to the size of the model being trained (e.g., "rfdetr-base", "rfdetr-large", etc.). + model_name (str, optional): The name you want to give the uploaded model. + If not provided, it will default to "-uploaded". + Raises: + ValueError: If the `api_key` is not provided and not found in the environment + variable `ROBOFLOW_API_KEY`, or if the `size` is not set for custom architectures. + """ + from roboflow import Roboflow + import shutil + if api_key is None: + api_key = os.getenv("ROBOFLOW_API_KEY") + if api_key is None: + raise ValueError("Set api_key= in deploy_to_roboflow or export ROBOFLOW_API_KEY=") + + + rf = Roboflow(api_key=api_key) + workspace = rf.workspace(workspace) + + if self.size is None and size is None: + raise ValueError("Must set size for custom architectures") + + size = self.size or size + tmp_out_dir = ".roboflow_temp_upload" + os.makedirs(tmp_out_dir, exist_ok=True) + outpath = os.path.join(tmp_out_dir, "weights.pt") + torch.save( + { + "model": self.model.model.state_dict(), + "args": self.model.args + }, outpath + ) + project = workspace.project(project_id) + version = project.version(version) + version.deploy( + model_type=size, + model_path=tmp_out_dir, + filename="weights.pt" + ) + shutil.rmtree(tmp_out_dir) + + + +class RFDETRBase(RFDETR): + """ + Train an RF-DETR Base model (29M parameters). + """ + size = "rfdetr-base" + def get_model_config(self, **kwargs): + return RFDETRBaseConfig(**kwargs) + + def get_train_config(self, **kwargs): + return TrainConfig(**kwargs) + +class RFDETRLarge(RFDETR): + """ + Train an RF-DETR Large model. 
+ """ + size = "rfdetr-large" + def get_model_config(self, **kwargs): + return RFDETRLargeConfig(**kwargs) + + def get_train_config(self, **kwargs): + return TrainConfig(**kwargs) + +class RFDETRNano(RFDETR): + """ + Train an RF-DETR Nano model. + """ + size = "rfdetr-nano" + def get_model_config(self, **kwargs): + return RFDETRNanoConfig(**kwargs) + + def get_train_config(self, **kwargs): + return TrainConfig(**kwargs) + +class RFDETRSmall(RFDETR): + """ + Train an RF-DETR Small model. + """ + size = "rfdetr-small" + def get_model_config(self, **kwargs): + return RFDETRSmallConfig(**kwargs) + + def get_train_config(self, **kwargs): + return TrainConfig(**kwargs) + +class RFDETRMedium(RFDETR): + """ + Train an RF-DETR Medium model. + """ + size = "rfdetr-medium" + def get_model_config(self, **kwargs): + return RFDETRMediumConfig(**kwargs) + + def get_train_config(self, **kwargs): + return TrainConfig(**kwargs) \ No newline at end of file diff --git a/rfdetr/engine.py b/rfdetr/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..31e68cae3d5802f6d2a6e01a9480cd3153c02e2b --- /dev/null +++ b/rfdetr/engine.py @@ -0,0 +1,340 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Conditional DETR +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +""" +Train and eval functions used in main.py +""" +import math +import sys +from typing import Iterable +import random + +import torch +import torch.nn.functional as F + +import rfdetr.util.misc as utils +from rfdetr.datasets.coco_eval import CocoEvaluator +from rfdetr.datasets.coco import compute_multi_scale_scales + +try: + from torch.amp import autocast, GradScaler + DEPRECATED_AMP = False +except ImportError: + from torch.cuda.amp import autocast, GradScaler + DEPRECATED_AMP = True +from typing import DefaultDict, List, Callable +from rfdetr.util.misc import NestedTensor +import numpy as np + +def get_autocast_args(args): + if DEPRECATED_AMP: + return {'enabled': args.amp, 'dtype': torch.bfloat16} + else: + return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16} + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + batch_size: int, + max_norm: float = 0, + ema_m: torch.nn.Module = None, + schedules: dict = {}, + num_training_steps_per_epoch=None, + vit_encoder_num_layers=None, + args=None, + callbacks: DefaultDict[str, List[Callable]] = None, +): + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter( + "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}") + ) + header = "Epoch: [{}]".format(epoch) + print_freq = 10 + start_steps = epoch * num_training_steps_per_epoch + + print("Grad accum steps: ", args.grad_accum_steps) + print("Total batch size: ", batch_size * utils.get_world_size()) + + # Add gradient scaler for AMP + if DEPRECATED_AMP: + scaler = GradScaler(enabled=args.amp) + else: + scaler = GradScaler('cuda', enabled=args.amp) + + optimizer.zero_grad() + assert batch_size % args.grad_accum_steps == 0 + sub_batch_size = batch_size // args.grad_accum_steps + print("LENGTH OF DATA LOADER:", len(data_loader)) + for data_iter_step, (samples, targets) in enumerate( + metric_logger.log_every(data_loader, print_freq, header) + ): + it = start_steps + data_iter_step + callback_dict = { + "step": it, + "model": model, + "epoch": epoch, + } + for callback in callbacks["on_train_batch_start"]: + callback(callback_dict) + if "dp" in schedules: + if args.distributed: + model.module.update_drop_path( + schedules["dp"][it], vit_encoder_num_layers + ) + else: + model.update_drop_path(schedules["dp"][it], vit_encoder_num_layers) + if "do" in schedules: + if args.distributed: + model.module.update_dropout(schedules["do"][it]) + else: + model.update_dropout(schedules["do"][it]) + + if args.multi_scale and not args.do_random_resize_via_padding: + scales = compute_multi_scale_scales(args.resolution, args.expanded_scales, args.patch_size, args.num_windows) + random.seed(it) + scale = random.choice(scales) + with torch.inference_mode(): + samples.tensors = F.interpolate(samples.tensors, size=scale, mode='bilinear', align_corners=False) + samples.mask = F.interpolate(samples.mask.unsqueeze(1).float(), size=scale, mode='nearest').squeeze(1).bool() + + for i in range(args.grad_accum_steps): + start_idx = i * sub_batch_size + final_idx = start_idx + sub_batch_size + new_samples_tensors = samples.tensors[start_idx:final_idx] + new_samples = NestedTensor(new_samples_tensors, 
samples.mask[start_idx:final_idx]) + new_samples = new_samples.to(device) + new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]] + + with autocast(**get_autocast_args(args)): + outputs = model(new_samples, new_targets) + loss_dict = criterion(outputs, new_targets) + weight_dict = criterion.weight_dict + losses = sum( + (1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k] + for k in loss_dict.keys() + if k in weight_dict + ) + + + scaler.scale(losses).backward() + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_unscaled = { + f"{k}_unscaled": v for k, v in loss_dict_reduced.items() + } + loss_dict_reduced_scaled = { + k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() + if k in weight_dict + } + losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) + + loss_value = losses_reduced_scaled.item() + + if not math.isfinite(loss_value): + print(loss_dict_reduced) + raise ValueError("Loss is {}, stopping training".format(loss_value)) + + if max_norm > 0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + + scaler.step(optimizer) + scaler.update() + lr_scheduler.step() + optimizer.zero_grad() + if ema_m is not None: + if epoch >= 0: + ema_m.update(model) + metric_logger.update( + loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled + ) + metric_logger.update(class_error=loss_dict_reduced["class_error"]) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def coco_extended_metrics(coco_eval): + """ + Safe version: ignores the –1 sentinel entries so precision/F1 never explode. 
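+
+    In outline (mirroring the code below): per-class precision is read from the COCO
+    precision array at IoU=0.50, entries reported as -1 (undefined) are masked to NaN,
+    and a per-class F1 is formed at every recall threshold r:
+
+        F1(r) = 2 * P(r) * r / (P(r) + r)
+
+    The recall threshold whose class-averaged F1 is highest is then used to report the
+    macro precision, macro recall, and the mean confidence score at that operating point.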
+ """ + + iou_thrs, rec_thrs = coco_eval.params.iouThrs, coco_eval.params.recThrs + iou50_idx, area_idx, maxdet_idx = ( + int(np.argwhere(np.isclose(iou_thrs, 0.50))), 0, 2) + + P = coco_eval.eval["precision"] + S = coco_eval.eval["scores"] + + prec_raw = P[iou50_idx, :, :, area_idx, maxdet_idx] + + prec = prec_raw.copy().astype(float) + prec[prec < 0] = np.nan + + f1_cls = 2 * prec * rec_thrs[:, None] / (prec + rec_thrs[:, None]) + f1_macro = np.nanmean(f1_cls, axis=1) + + best_j = int(f1_macro.argmax()) + + macro_precision = float(np.nanmean(prec[best_j])) + macro_recall = float(rec_thrs[best_j]) + macro_f1 = float(f1_macro[best_j]) + + score_vec = S[iou50_idx, best_j, :, area_idx, maxdet_idx].astype(float) + score_vec[prec_raw[best_j] < 0] = np.nan + score_thr = float(np.nanmean(score_vec)) + + map_50_95, map_50 = float(coco_eval.stats[0]), float(coco_eval.stats[1]) + + per_class = [] + cat_ids = coco_eval.params.catIds + cat_id_to_name = {c["id"]: c["name"] for c in coco_eval.cocoGt.loadCats(cat_ids)} + for k, cid in enumerate(cat_ids): + p_slice = P[:, :, k, area_idx, maxdet_idx] + valid = p_slice > -1 + ap_50_95 = float(p_slice[valid].mean()) if valid.any() else float("nan") + ap_50 = float(p_slice[iou50_idx][p_slice[iou50_idx] > -1].mean()) if (p_slice[iou50_idx] > -1).any() else float("nan") + + pc = float(prec[best_j, k]) if prec_raw[best_j, k] > -1 else float("nan") + rc = macro_recall + + #Doing to this to filter out dataset class + if np.isnan(ap_50_95) or np.isnan(ap_50) or np.isnan(pc) or np.isnan(rc): + continue + + per_class.append({ + "class" : cat_id_to_name[int(cid)], + "map@50:95" : ap_50_95, + "map@50" : ap_50, + "precision" : pc, + "recall" : rc, + }) + + per_class.append({ + "class" : "all", + "map@50:95" : map_50_95, + "map@50" : map_50, + "precision" : macro_precision, + "recall" : macro_recall, + }) + + return { + "class_map": per_class, + "map" : map_50, + "precision": macro_precision, + "recall" : macro_recall + } + +def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None): + model.eval() + if args.fp16_eval: + model.half() + criterion.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter( + "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}") + ) + header = "Test:" + + iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys()) + coco_evaluator = CocoEvaluator(base_ds, iou_types) + + for samples, targets in metric_logger.log_every(data_loader, 10, header): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + if args.fp16_eval: + samples.tensors = samples.tensors.half() + + # Add autocast for evaluation + with autocast(**get_autocast_args(args)): + outputs = model(samples) + + if args.fp16_eval: + for key in outputs.keys(): + if key == "enc_outputs": + for sub_key in outputs[key].keys(): + outputs[key][sub_key] = outputs[key][sub_key].float() + elif key == "aux_outputs": + for idx in range(len(outputs[key])): + for sub_key in outputs[key][idx].keys(): + outputs[key][idx][sub_key] = outputs[key][idx][ + sub_key + ].float() + else: + outputs[key] = outputs[key].float() + + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_scaled = { + k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() + if k in weight_dict + } + loss_dict_reduced_unscaled = { + 
f"{k}_unscaled": v for k, v in loss_dict_reduced.items() + } + metric_logger.update( + loss=sum(loss_dict_reduced_scaled.values()), + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled, + ) + metric_logger.update(class_error=loss_dict_reduced["class_error"]) + + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + results = postprocessors["bbox"](outputs, orig_target_sizes) + res = { + target["image_id"].item(): output + for target, output in zip(targets, results) + } + if coco_evaluator is not None: + coco_evaluator.update(res) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + results_json = coco_extended_metrics(coco_evaluator.coco_eval["bbox"]) + stats["results_json"] = results_json + if "bbox" in postprocessors.keys(): + stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist() + + if "segm" in postprocessors.keys(): + stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist() + return stats, coco_evaluator \ No newline at end of file diff --git a/rfdetr/main.py b/rfdetr/main.py new file mode 100644 index 0000000000000000000000000000000000000000..14fe89cd52ccd9989e945fe436b9da7d33c019aa --- /dev/null +++ b/rfdetr/main.py @@ -0,0 +1,1062 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +""" +cleaned main file +""" +import argparse +import ast +import copy +import datetime +import json +import math +import os +import random +import shutil +import time +from copy import deepcopy +from logging import getLogger +from pathlib import Path +from typing import DefaultDict, List, Callable + +import numpy as np +import torch +from peft import LoraConfig, get_peft_model +from torch.utils.data import DataLoader, DistributedSampler + +import rfdetr.util.misc as utils +from rfdetr.datasets import build_dataset, get_coco_api_from_dataset +from rfdetr.engine import evaluate, train_one_epoch +from rfdetr.models import build_model, build_criterion_and_postprocessors +from rfdetr.util.benchmark import benchmark +from rfdetr.util.drop_scheduler import drop_scheduler +from rfdetr.util.files import download_file +from rfdetr.util.get_param_dicts import get_param_dict +from rfdetr.util.utils import ModelEma, BestMetricHolder, clean_state_dict + +if str(os.environ.get("USE_FILE_SYSTEM_SHARING", "False")).lower() in ["true", "1"]: + import torch.multiprocessing + torch.multiprocessing.set_sharing_strategy('file_system') + +logger = getLogger(__name__) + +HOSTED_MODELS = { + "rf-detr-base.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth", + # below is a less converged model that may be better for finetuning but worse for inference + "rf-detr-base-2.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth", + "rf-detr-large.pth": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth", + "rf-detr-nano.pth": "https://storage.googleapis.com/rfdetr/nano_coco/checkpoint_best_regular.pth", + "rf-detr-small.pth": "https://storage.googleapis.com/rfdetr/small_coco/checkpoint_best_regular.pth", + "rf-detr-medium.pth": "https://storage.googleapis.com/rfdetr/medium_coco/checkpoint_best_regular.pth", +} + +def download_pretrain_weights(pretrain_weights: str, redownload=False): + if pretrain_weights in HOSTED_MODELS: + if redownload or not os.path.exists(pretrain_weights): + logger.info( + f"Downloading pretrained weights for {pretrain_weights}" + ) + download_file( + HOSTED_MODELS[pretrain_weights], + pretrain_weights, + ) + +class Model: + def __init__(self, **kwargs): + args = populate_args(**kwargs) + self.args = args + self.resolution = args.resolution + self.model = build_model(args) + self.device = torch.device(args.device) + if args.pretrain_weights is not None: + print("Loading pretrain weights") + try: + checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False) + except Exception as e: + print(f"Failed to load pretrain weights: {e}") + # re-download weights if they are corrupted + print("Failed to load pretrain weights, re-downloading") + download_pretrain_weights(args.pretrain_weights, redownload=True) + checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False) + + # Extract class_names from checkpoint if available + if 'args' in checkpoint and hasattr(checkpoint['args'], 'class_names'): + self.args.class_names = checkpoint['args'].class_names + self.class_names = checkpoint['args'].class_names + + checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0] + if checkpoint_num_classes != args.num_classes + 1: + logger.warning( + f"num_classes mismatch: pretrain weights has {checkpoint_num_classes - 1} classes, but your model has {args.num_classes} classes\n" + f"reinitializing detection head with {checkpoint_num_classes - 1} 
classes" + ) + self.reinitialize_detection_head(checkpoint_num_classes) + # add support to exclude_keys + # e.g., when load object365 pretrain, do not load `class_embed.[weight, bias]` + if args.pretrain_exclude_keys is not None: + assert isinstance(args.pretrain_exclude_keys, list) + for exclude_key in args.pretrain_exclude_keys: + checkpoint['model'].pop(exclude_key) + if args.pretrain_keys_modify_to_load is not None: + from util.obj365_to_coco_model import get_coco_pretrain_from_obj365 + assert isinstance(args.pretrain_keys_modify_to_load, list) + for modify_key_to_load in args.pretrain_keys_modify_to_load: + try: + checkpoint['model'][modify_key_to_load] = get_coco_pretrain_from_obj365( + model_without_ddp.state_dict()[modify_key_to_load], + checkpoint['model'][modify_key_to_load] + ) + except: + print(f"Failed to load {modify_key_to_load}, deleting from checkpoint") + checkpoint['model'].pop(modify_key_to_load) + + # we may want to resume training with a smaller number of groups for group detr + num_desired_queries = args.num_queries * args.group_detr + query_param_names = ["refpoint_embed.weight", "query_feat.weight"] + for name, state in checkpoint['model'].items(): + if any(name.endswith(x) for x in query_param_names): + checkpoint['model'][name] = state[:num_desired_queries] + + self.model.load_state_dict(checkpoint['model'], strict=False) + + if args.backbone_lora: + print("Applying LORA to backbone") + lora_config = LoraConfig( + r=16, + lora_alpha=16, + use_dora=True, + target_modules=[ + "q_proj", "v_proj", "k_proj", # covers OWL-ViT + "qkv", # covers open_clip ie Siglip2 + "query", "key", "value", "cls_token", "register_tokens", # covers Dinov2 with windowed attn + ] + ) + self.model.backbone[0].encoder = get_peft_model(self.model.backbone[0].encoder, lora_config) + self.model = self.model.to(self.device) + self.criterion, self.postprocessors = build_criterion_and_postprocessors(args) + self.stop_early = False + + def reinitialize_detection_head(self, num_classes): + self.model.reinitialize_detection_head(num_classes) + + def request_early_stop(self): + self.stop_early = True + print("Early stopping requested, will complete current epoch and stop") + + def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs): + currently_supported_callbacks = ["on_fit_epoch_end", "on_train_batch_start", "on_train_end"] + for key in callbacks.keys(): + if key not in currently_supported_callbacks: + raise ValueError( + f"Callback {key} is not currently supported, please file an issue if you need it!\n" + f"Currently supported callbacks: {currently_supported_callbacks}" + ) + args = populate_args(**kwargs) + if getattr(args, 'class_names') is not None: + self.args.class_names = args.class_names + self.args.num_classes = args.num_classes + + utils.init_distributed_mode(args) + print("git:\n {}\n".format(utils.get_sha())) + print(args) + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + criterion, postprocessors = build_criterion_and_postprocessors(args) + model = self.model + model.to(device) + + model_without_ddp = model + if args.distributed: + if args.sync_bn: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) + model_without_ddp = model.module + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + 
print('number of params:', n_parameters) + param_dicts = get_param_dict(args, model_without_ddp) + + param_dicts = [p for p in param_dicts if p['params'].requires_grad] + + optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, + weight_decay=args.weight_decay) + # Choose the learning rate scheduler based on the new argument + + dataset_train = build_dataset(image_set='train', args=args, resolution=args.resolution) + dataset_val = build_dataset(image_set='val', args=args, resolution=args.resolution) + dataset_test = build_dataset(image_set='test', args=args, resolution=args.resolution) + + # for cosine annealing, calculate total training steps and warmup steps + total_batch_size_for_lr = args.batch_size * utils.get_world_size() * args.grad_accum_steps + num_training_steps_per_epoch_lr = (len(dataset_train) + total_batch_size_for_lr - 1) // total_batch_size_for_lr + total_training_steps_lr = num_training_steps_per_epoch_lr * args.epochs + warmup_steps_lr = num_training_steps_per_epoch_lr * args.warmup_epochs + def lr_lambda(current_step: int): + if current_step < warmup_steps_lr: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps_lr)) + else: + # Cosine annealing from multiplier 1.0 down to lr_min_factor + if args.lr_scheduler == 'cosine': + progress = float(current_step - warmup_steps_lr) / float(max(1, total_training_steps_lr - warmup_steps_lr)) + return args.lr_min_factor + (1 - args.lr_min_factor) * 0.5 * (1 + math.cos(math.pi * progress)) + elif args.lr_scheduler == 'step': + if current_step < args.lr_drop * num_training_steps_per_epoch_lr: + return 1.0 + else: + return 0.1 + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + + if args.distributed: + sampler_train = DistributedSampler(dataset_train) + sampler_val = DistributedSampler(dataset_val, shuffle=False) + sampler_test = DistributedSampler(dataset_test, shuffle=False) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + effective_batch_size = args.batch_size * args.grad_accum_steps + min_batches = kwargs.get('min_batches', 5) + if len(dataset_train) < effective_batch_size * min_batches: + logger.info( + f"Training with uniform sampler because dataset is too small: {len(dataset_train)} < {effective_batch_size * min_batches}" + ) + sampler = torch.utils.data.RandomSampler( + dataset_train, + replacement=True, + num_samples=effective_batch_size * min_batches, + ) + data_loader_train = DataLoader( + dataset_train, + batch_size=effective_batch_size, + collate_fn=utils.collate_fn, + num_workers=args.num_workers, + sampler=sampler, + ) + else: + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, effective_batch_size, drop_last=True) + data_loader_train = DataLoader( + dataset_train, + batch_sampler=batch_sampler_train, + collate_fn=utils.collate_fn, + num_workers=args.num_workers + ) + + data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, + drop_last=False, collate_fn=utils.collate_fn, + num_workers=args.num_workers) + data_loader_test = DataLoader(dataset_test, args.batch_size, sampler=sampler_test, + drop_last=False, collate_fn=utils.collate_fn, + num_workers=args.num_workers) + + base_ds = get_coco_api_from_dataset(dataset_val) + base_ds_test = get_coco_api_from_dataset(dataset_test) + if args.use_ema: + self.ema_m = ModelEma(model_without_ddp, decay=args.ema_decay, 
tau=args.ema_tau) + else: + self.ema_m = None + + + output_dir = Path(args.output_dir) + + if utils.is_main_process(): + print("Get benchmark") + if args.do_benchmark: + benchmark_model = copy.deepcopy(model_without_ddp) + bm = benchmark(benchmark_model.float(), dataset_val, output_dir) + print(json.dumps(bm, indent=2)) + del benchmark_model + + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False) + model_without_ddp.load_state_dict(checkpoint['model'], strict=True) + if args.use_ema: + if 'ema_model' in checkpoint: + self.ema_m.module.load_state_dict(clean_state_dict(checkpoint['ema_model'])) + else: + del self.ema_m + self.ema_m = ModelEma(model, decay=args.ema_decay, tau=args.ema_tau) + if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + + if args.eval: + test_stats, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, base_ds, device, args) + if args.output_dir: + utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") + return + + # for drop + total_batch_size = effective_batch_size * utils.get_world_size() + num_training_steps_per_epoch = (len(dataset_train) + total_batch_size - 1) // total_batch_size + schedules = {} + if args.dropout > 0: + schedules['do'] = drop_scheduler( + args.dropout, args.epochs, num_training_steps_per_epoch, + args.cutoff_epoch, args.drop_mode, args.drop_schedule) + print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do']))) + + if args.drop_path > 0: + schedules['dp'] = drop_scheduler( + args.drop_path, args.epochs, num_training_steps_per_epoch, + args.cutoff_epoch, args.drop_mode, args.drop_schedule) + print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp']))) + + print("Start training") + start_time = time.time() + best_map_holder = BestMetricHolder(use_ema=args.use_ema) + best_map_5095 = 0 + best_map_50 = 0 + best_map_ema_5095 = 0 + best_map_ema_50 = 0 + for epoch in range(args.start_epoch, args.epochs): + epoch_start_time = time.time() + if args.distributed: + sampler_train.set_epoch(epoch) + + model.train() + criterion.train() + train_stats = train_one_epoch( + model, criterion, lr_scheduler, data_loader_train, optimizer, device, epoch, + effective_batch_size, args.clip_max_norm, ema_m=self.ema_m, schedules=schedules, + num_training_steps_per_epoch=num_training_steps_per_epoch, + vit_encoder_num_layers=args.vit_encoder_num_layers, args=args, callbacks=callbacks) + train_epoch_time = time.time() - epoch_start_time + train_epoch_time_str = str(datetime.timedelta(seconds=int(train_epoch_time))) + if args.output_dir: + checkpoint_paths = [output_dir / 'checkpoint.pth'] + # extra checkpoint before LR drop and every `checkpoint_interval` epochs + if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % args.checkpoint_interval == 0: + checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') + for checkpoint_path in checkpoint_paths: + weights = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args, + } + if args.use_ema: + weights.update({ + 'ema_model': self.ema_m.module.state_dict(), + }) + if not args.dont_save_weights: + # create checkpoint dir + checkpoint_path.parent.mkdir(parents=True, 
exist_ok=True) + + utils.save_on_master(weights, checkpoint_path) + + with torch.inference_mode(): + test_stats, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, base_ds, device, args=args + ) + map_regular = test_stats["coco_eval_bbox"][0] + _isbest = best_map_holder.update(map_regular, epoch, is_ema=False) + if _isbest: + best_map_5095 = max(best_map_5095, map_regular) + best_map_50 = max(best_map_50, test_stats["coco_eval_bbox"][1]) + checkpoint_path = output_dir / 'checkpoint_best_regular.pth' + if not args.dont_save_weights: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args, + }, checkpoint_path) + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + if args.use_ema: + ema_test_stats, _ = evaluate( + self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args + ) + log_stats.update({f'ema_test_{k}': v for k,v in ema_test_stats.items()}) + map_ema = ema_test_stats["coco_eval_bbox"][0] + best_map_ema_5095 = max(best_map_ema_5095, map_ema) + _isbest = best_map_holder.update(map_ema, epoch, is_ema=True) + if _isbest: + best_map_ema_50 = max(best_map_ema_50, ema_test_stats["coco_eval_bbox"][1]) + checkpoint_path = output_dir / 'checkpoint_best_ema.pth' + if not args.dont_save_weights: + utils.save_on_master({ + 'model': self.ema_m.module.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args, + }, checkpoint_path) + log_stats.update(best_map_holder.summary()) + + # epoch parameters + ep_paras = { + 'epoch': epoch, + 'n_parameters': n_parameters + } + log_stats.update(ep_paras) + try: + log_stats.update({'now_time': str(datetime.datetime.now())}) + except: + pass + log_stats['train_epoch_time'] = train_epoch_time_str + epoch_time = time.time() - epoch_start_time + epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time))) + log_stats['epoch_time'] = epoch_time_str + if args.output_dir and utils.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + # for evaluation logs + if coco_evaluator is not None: + (output_dir / 'eval').mkdir(exist_ok=True) + if "bbox" in coco_evaluator.coco_eval: + filenames = ['latest.pth'] + if epoch % 50 == 0: + filenames.append(f'{epoch:03}.pth') + for name in filenames: + torch.save(coco_evaluator.coco_eval["bbox"].eval, + output_dir / "eval" / name) + + for callback in callbacks["on_fit_epoch_end"]: + callback(log_stats) + + if self.stop_early: + print(f"Early stopping requested, stopping at epoch {epoch}") + break + + best_is_ema = best_map_ema_5095 > best_map_5095 + + if utils.is_main_process(): + if best_is_ema: + shutil.copy2(output_dir / 'checkpoint_best_ema.pth', output_dir / 'checkpoint_best_total.pth') + else: + shutil.copy2(output_dir / 'checkpoint_best_regular.pth', output_dir / 'checkpoint_best_total.pth') + + utils.strip_checkpoint(output_dir / 'checkpoint_best_total.pth') + + best_map_5095 = max(best_map_5095, best_map_ema_5095) + if best_is_ema: + results = ema_test_stats["results_json"] + else: + results = test_stats["results_json"] + + class_map = results["class_map"] + results["class_map"] = {"valid": class_map} + with open(output_dir / "results.json", "w") as f: + json.dump(results, f) + + 
total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + print('Results saved to {}'.format(output_dir / "results.json")) + + + if best_is_ema: + self.model = self.ema_m.module + self.model.eval() + + + if args.run_test: + best_state_dict = torch.load(output_dir / 'checkpoint_best_total.pth', map_location='cpu', weights_only=False)['model'] + model.load_state_dict(best_state_dict) + model.eval() + + test_stats, _ = evaluate( + model, criterion, postprocessors, data_loader_test, base_ds_test, device, args=args + ) + print(f"Test results: {test_stats}") + with open(output_dir / "results.json", "r") as f: + results = json.load(f) + test_metrics = test_stats["results_json"]["class_map"] + results["class_map"]["test"] = test_metrics + with open(output_dir / "results.json", "w") as f: + json.dump(results, f) + + for callback in callbacks["on_train_end"]: + callback() + + def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs): + """Export the trained model to ONNX format""" + print(f"Exporting model to ONNX format") + try: + from rfdetr.deploy.export import export_onnx, onnx_simplify, make_infer_image + except ImportError: + print("It seems some dependencies for ONNX export are missing. Please run `pip install rfdetr[onnxexport]` and try again.") + raise + + + device = self.device + model = deepcopy(self.model.to("cpu")) + model.to(device) + + os.makedirs(output_dir, exist_ok=True) + output_dir = Path(output_dir) + if shape is None: + shape = (self.resolution, self.resolution) + else: + if shape[0] % 14 != 0 or shape[1] % 14 != 0: + raise ValueError("Shape must be divisible by 14") + + input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device) + input_names = ['input'] + output_names = ['features'] if backbone_only else ['dets', 'labels'] + dynamic_axes = None + self.model.eval() + with torch.no_grad(): + if backbone_only: + features = model(input_tensors) + print(f"PyTorch inference output shape: {features.shape}") + else: + outputs = model(input_tensors) + dets = outputs['pred_boxes'] + labels = outputs['pred_logits'] + print(f"PyTorch inference output shapes - Boxes: {dets.shape}, Labels: {labels.shape}") + model.cpu() + input_tensors = input_tensors.cpu() + + # Export to ONNX + output_file = export_onnx( + output_dir=output_dir, + model=model, + input_names=input_names, + input_tensors=input_tensors, + output_names=output_names, + dynamic_axes=dynamic_axes, + backbone_only=backbone_only, + verbose=verbose, + opset_version=opset_version + ) + + print(f"Successfully exported ONNX model to: {output_file}") + + if simplify: + sim_output_file = onnx_simplify( + onnx_dir=output_file, + input_names=input_names, + input_tensors=input_tensors, + force=force + ) + print(f"Successfully simplified ONNX model to: {sim_output_file}") + + print("ONNX export completed successfully") + self.model = self.model.to(device) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('LWDETR training and evaluation script', parents=[get_args_parser()]) + args = parser.parse_args() + + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + config = vars(args) # Convert Namespace to dictionary + + if args.subcommand == 'distill': + distill(**config) + elif args.subcommand is None: + main(**config) + elif args.subcommand == 'export_model': + 
filter_keys = [ + "num_classes", + "grad_accum_steps", + "lr", + "lr_encoder", + "weight_decay", + "epochs", + "lr_drop", + "clip_max_norm", + "lr_vit_layer_decay", + "lr_component_decay", + "dropout", + "drop_path", + "drop_mode", + "drop_schedule", + "cutoff_epoch", + "pretrained_encoder", + "pretrain_weights", + "pretrain_exclude_keys", + "pretrain_keys_modify_to_load", + "freeze_florence", + "freeze_aimv2", + "decoder_norm", + "set_cost_class", + "set_cost_bbox", + "set_cost_giou", + "cls_loss_coef", + "bbox_loss_coef", + "giou_loss_coef", + "focal_alpha", + "aux_loss", + "sum_group_losses", + "use_varifocal_loss", + "use_position_supervised_loss", + "ia_bce_loss", + "dataset_file", + "coco_path", + "dataset_dir", + "square_resize_div_64", + "output_dir", + "checkpoint_interval", + "seed", + "resume", + "start_epoch", + "eval", + "use_ema", + "ema_decay", + "ema_tau", + "num_workers", + "device", + "world_size", + "dist_url", + "sync_bn", + "fp16_eval", + "infer_dir", + "verbose", + "opset_version", + "dry_run", + "shape", + ] + for key in filter_keys: + config.pop(key, None) # Use pop with None to avoid KeyError + + from deploy.export import main as export_main + if args.batch_size != 1: + config['batch_size'] = 1 + print(f"Only batch_size 1 is supported for onnx export, \ + but got batchsize = {args.batch_size}. batch_size is forcibly set to 1.") + export_main(**config) + +def get_args_parser(): + parser = argparse.ArgumentParser('Set transformer detector', add_help=False) + parser.add_argument('--num_classes', default=2, type=int) + parser.add_argument('--grad_accum_steps', default=1, type=int) + parser.add_argument('--amp', default=False, type=bool) + parser.add_argument('--lr', default=1e-4, type=float) + parser.add_argument('--lr_encoder', default=1.5e-4, type=float) + parser.add_argument('--batch_size', default=2, type=int) + parser.add_argument('--weight_decay', default=1e-4, type=float) + parser.add_argument('--epochs', default=12, type=int) + parser.add_argument('--lr_drop', default=11, type=int) + parser.add_argument('--clip_max_norm', default=0.1, type=float, + help='gradient clipping max norm') + parser.add_argument('--lr_vit_layer_decay', default=0.8, type=float) + parser.add_argument('--lr_component_decay', default=1.0, type=float) + parser.add_argument('--do_benchmark', action='store_true', help='benchmark the model') + + # drop args + # dropout and stochastic depth drop rate; set at most one to non-zero + parser.add_argument('--dropout', type=float, default=0, + help='Drop path rate (default: 0.0)') + parser.add_argument('--drop_path', type=float, default=0, + help='Drop path rate (default: 0.0)') + + # early / late dropout and stochastic depth settings + parser.add_argument('--drop_mode', type=str, default='standard', + choices=['standard', 'early', 'late'], help='drop mode') + parser.add_argument('--drop_schedule', type=str, default='constant', + choices=['constant', 'linear'], + help='drop schedule for early dropout / s.d. 
only') + parser.add_argument('--cutoff_epoch', type=int, default=0, + help='if drop_mode is early / late, this is the epoch where dropout ends / starts') + + # Model parameters + parser.add_argument('--pretrained_encoder', type=str, default=None, + help="Path to the pretrained encoder.") + parser.add_argument('--pretrain_weights', type=str, default=None, + help="Path to the pretrained model.") + parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+', + help="Keys you do not want to load.") + parser.add_argument('--pretrain_keys_modify_to_load', type=str, default=None, nargs='+', + help="Keys you want to modify to load. Only used when loading objects365 pre-trained weights.") + + # * Backbone + parser.add_argument('--encoder', default='vit_tiny', type=str, + help="Name of the transformer or convolutional encoder to use") + parser.add_argument('--vit_encoder_num_layers', default=12, type=int, + help="Number of layers used in ViT encoder") + parser.add_argument('--window_block_indexes', default=None, type=int, nargs='+') + parser.add_argument('--position_embedding', default='sine', type=str, + choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features") + parser.add_argument('--out_feature_indexes', default=[-1], type=int, nargs='+', help='only for vit now') + parser.add_argument("--freeze_encoder", action="store_true", dest="freeze_encoder") + parser.add_argument("--layer_norm", action="store_true", dest="layer_norm") + parser.add_argument("--rms_norm", action="store_true", dest="rms_norm") + parser.add_argument("--backbone_lora", action="store_true", dest="backbone_lora") + parser.add_argument("--force_no_pretrain", action="store_true", dest="force_no_pretrain") + + # * Transformer + parser.add_argument('--dec_layers', default=3, type=int, + help="Number of decoding layers in the transformer") + parser.add_argument('--dim_feedforward', default=2048, type=int, + help="Intermediate size of the feedforward layers in the transformer blocks") + parser.add_argument('--hidden_dim', default=256, type=int, + help="Size of the embeddings (dimension of the transformer)") + parser.add_argument('--sa_nheads', default=8, type=int, + help="Number of attention heads inside the transformer's self-attentions") + parser.add_argument('--ca_nheads', default=8, type=int, + help="Number of attention heads inside the transformer's cross-attentions") + parser.add_argument('--num_queries', default=300, type=int, + help="Number of query slots") + parser.add_argument('--group_detr', default=13, type=int, + help="Number of groups to speed up detr training") + parser.add_argument('--two_stage', action='store_true') + parser.add_argument('--projector_scale', default='P4', type=str, nargs='+', choices=('P3', 'P4', 'P5', 'P6')) + parser.add_argument('--lite_refpoint_refine', action='store_true', help='lite refpoint refine mode for speed-up') + parser.add_argument('--num_select', default=100, type=int, + help='the number of predictions selected for evaluation') + parser.add_argument('--dec_n_points', default=4, type=int, + help='the number of sampling points') + parser.add_argument('--decoder_norm', default='LN', type=str) + parser.add_argument('--bbox_reparam', action='store_true') + parser.add_argument('--freeze_batch_norm', action='store_true') + # * Matcher + parser.add_argument('--set_cost_class', default=2, type=float, + help="Class coefficient in the matching cost") + parser.add_argument('--set_cost_bbox', default=5, type=float, + help="L1 box 
coefficient in the matching cost") + parser.add_argument('--set_cost_giou', default=2, type=float, + help="giou box coefficient in the matching cost") + + # * Loss coefficients + parser.add_argument('--cls_loss_coef', default=2, type=float) + parser.add_argument('--bbox_loss_coef', default=5, type=float) + parser.add_argument('--giou_loss_coef', default=2, type=float) + parser.add_argument('--focal_alpha', default=0.25, type=float) + + # Loss + parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', + help="Disables auxiliary decoding losses (loss at each layer)") + parser.add_argument('--sum_group_losses', action='store_true', + help="To sum losses across groups or mean losses.") + parser.add_argument('--use_varifocal_loss', action='store_true') + parser.add_argument('--use_position_supervised_loss', action='store_true') + parser.add_argument('--ia_bce_loss', action='store_true') + + # dataset parameters + parser.add_argument('--dataset_file', default='coco') + parser.add_argument('--coco_path', type=str) + parser.add_argument('--dataset_dir', type=str) + parser.add_argument('--square_resize_div_64', action='store_true') + + parser.add_argument('--output_dir', default='output', + help='path where to save, empty for no saving') + parser.add_argument('--dont_save_weights', action='store_true') + parser.add_argument('--checkpoint_interval', default=10, type=int, + help='epoch interval to save checkpoint') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true') + parser.add_argument('--use_ema', action='store_true') + parser.add_argument('--ema_decay', default=0.9997, type=float) + parser.add_argument('--ema_tau', default=0, type=float) + + parser.add_argument('--num_workers', default=2, type=int) + + # distributed training parameters + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + parser.add_argument('--sync_bn', default=True, type=bool, + help='setup synchronized BatchNorm for distributed training') + + # fp16 + parser.add_argument('--fp16_eval', default=False, action='store_true', + help='evaluate in fp16 precision.') + + # custom args + parser.add_argument('--encoder_only', action='store_true', help='Export and benchmark encoder only') + parser.add_argument('--backbone_only', action='store_true', help='Export and benchmark backbone only') + parser.add_argument('--resolution', type=int, default=640, help="input resolution") + parser.add_argument('--use_cls_token', action='store_true', help='use cls token') + parser.add_argument('--multi_scale', action='store_true', help='use multi scale') + parser.add_argument('--expanded_scales', action='store_true', help='use expanded scales') + parser.add_argument('--do_random_resize_via_padding', action='store_true', help='use random resize via padding') + parser.add_argument('--warmup_epochs', default=1, type=float, + help='Number of warmup epochs for linear warmup before cosine annealing') + # Add scheduler type argument: 'step' or 'cosine' + parser.add_argument( + '--lr_scheduler', + default='step', + choices=['step', 'cosine'], + help="Type of learning rate 
scheduler to use: 'step' (default) or 'cosine'" + ) + parser.add_argument('--lr_min_factor', default=0.0, type=float, + help='Minimum learning rate factor (as a fraction of initial lr) at the end of cosine annealing') + # Early stopping parameters + parser.add_argument('--early_stopping', action='store_true', + help='Enable early stopping based on mAP improvement') + parser.add_argument('--early_stopping_patience', default=10, type=int, + help='Number of epochs with no improvement after which training will be stopped') + parser.add_argument('--early_stopping_min_delta', default=0.001, type=float, + help='Minimum change in mAP to qualify as an improvement') + parser.add_argument('--early_stopping_use_ema', action='store_true', + help='Use EMA model metrics for early stopping') + # subparsers + subparsers = parser.add_subparsers(title='sub-commands', dest='subcommand', + description='valid subcommands', help='additional help') + + # subparser for export model + parser_export = subparsers.add_parser('export_model', help='LWDETR model export') + parser_export.add_argument('--infer_dir', type=str, default=None) + parser_export.add_argument('--verbose', type=ast.literal_eval, default=False, nargs="?", const=True) + parser_export.add_argument('--opset_version', type=int, default=17) + parser_export.add_argument('--simplify', action='store_true', help="Simplify onnx model") + parser_export.add_argument('--tensorrt', '--trtexec', '--trt', action='store_true', + help="build tensorrt engine") + parser_export.add_argument('--dry-run', '--test', '-t', action='store_true', help="just print command") + parser_export.add_argument('--profile', action='store_true', help='Run nsys profiling during TensorRT export') + parser_export.add_argument('--shape', type=int, nargs=2, default=(640, 640), help="input shape (width, height)") + return parser + +def populate_args( + # Basic training parameters + num_classes=2, + grad_accum_steps=1, + amp=False, + lr=1e-4, + lr_encoder=1.5e-4, + batch_size=2, + weight_decay=1e-4, + epochs=12, + lr_drop=11, + clip_max_norm=0.1, + lr_vit_layer_decay=0.8, + lr_component_decay=1.0, + do_benchmark=False, + + # Drop parameters + dropout=0, + drop_path=0, + drop_mode='standard', + drop_schedule='constant', + cutoff_epoch=0, + + # Model parameters + pretrained_encoder=None, + pretrain_weights=None, + pretrain_exclude_keys=None, + pretrain_keys_modify_to_load=None, + pretrained_distiller=None, + + # Backbone parameters + encoder='vit_tiny', + vit_encoder_num_layers=12, + window_block_indexes=None, + position_embedding='sine', + out_feature_indexes=[-1], + freeze_encoder=False, + layer_norm=False, + rms_norm=False, + backbone_lora=False, + force_no_pretrain=False, + + # Transformer parameters + dec_layers=3, + dim_feedforward=2048, + hidden_dim=256, + sa_nheads=8, + ca_nheads=8, + num_queries=300, + group_detr=13, + two_stage=False, + projector_scale='P4', + lite_refpoint_refine=False, + num_select=100, + dec_n_points=4, + decoder_norm='LN', + bbox_reparam=False, + freeze_batch_norm=False, + + # Matcher parameters + set_cost_class=2, + set_cost_bbox=5, + set_cost_giou=2, + + # Loss coefficients + cls_loss_coef=2, + bbox_loss_coef=5, + giou_loss_coef=2, + focal_alpha=0.25, + aux_loss=True, + sum_group_losses=False, + use_varifocal_loss=False, + use_position_supervised_loss=False, + ia_bce_loss=False, + + # Dataset parameters + dataset_file='coco', + coco_path=None, + dataset_dir=None, + square_resize_div_64=False, + + # Output parameters + output_dir='output', + 
dont_save_weights=False, + checkpoint_interval=10, + seed=42, + resume='', + start_epoch=0, + eval=False, + use_ema=False, + ema_decay=0.9997, + ema_tau=0, + num_workers=2, + + # Distributed training parameters + device='cuda', + world_size=1, + dist_url='env://', + sync_bn=True, + + # FP16 + fp16_eval=False, + + # Custom args + encoder_only=False, + backbone_only=False, + resolution=640, + use_cls_token=False, + multi_scale=False, + expanded_scales=False, + do_random_resize_via_padding=False, + warmup_epochs=1, + lr_scheduler='step', + lr_min_factor=0.0, + # Early stopping parameters + early_stopping=True, + early_stopping_patience=10, + early_stopping_min_delta=0.001, + early_stopping_use_ema=False, + gradient_checkpointing=False, + # Additional + subcommand=None, + **extra_kwargs # To handle any unexpected arguments +): + args = argparse.Namespace( + num_classes=num_classes, + grad_accum_steps=grad_accum_steps, + amp=amp, + lr=lr, + lr_encoder=lr_encoder, + batch_size=batch_size, + weight_decay=weight_decay, + epochs=epochs, + lr_drop=lr_drop, + clip_max_norm=clip_max_norm, + lr_vit_layer_decay=lr_vit_layer_decay, + lr_component_decay=lr_component_decay, + do_benchmark=do_benchmark, + dropout=dropout, + drop_path=drop_path, + drop_mode=drop_mode, + drop_schedule=drop_schedule, + cutoff_epoch=cutoff_epoch, + pretrained_encoder=pretrained_encoder, + pretrain_weights=pretrain_weights, + pretrain_exclude_keys=pretrain_exclude_keys, + pretrain_keys_modify_to_load=pretrain_keys_modify_to_load, + pretrained_distiller=pretrained_distiller, + encoder=encoder, + vit_encoder_num_layers=vit_encoder_num_layers, + window_block_indexes=window_block_indexes, + position_embedding=position_embedding, + out_feature_indexes=out_feature_indexes, + freeze_encoder=freeze_encoder, + layer_norm=layer_norm, + rms_norm=rms_norm, + backbone_lora=backbone_lora, + force_no_pretrain=force_no_pretrain, + dec_layers=dec_layers, + dim_feedforward=dim_feedforward, + hidden_dim=hidden_dim, + sa_nheads=sa_nheads, + ca_nheads=ca_nheads, + num_queries=num_queries, + group_detr=group_detr, + two_stage=two_stage, + projector_scale=projector_scale, + lite_refpoint_refine=lite_refpoint_refine, + num_select=num_select, + dec_n_points=dec_n_points, + decoder_norm=decoder_norm, + bbox_reparam=bbox_reparam, + freeze_batch_norm=freeze_batch_norm, + set_cost_class=set_cost_class, + set_cost_bbox=set_cost_bbox, + set_cost_giou=set_cost_giou, + cls_loss_coef=cls_loss_coef, + bbox_loss_coef=bbox_loss_coef, + giou_loss_coef=giou_loss_coef, + focal_alpha=focal_alpha, + aux_loss=aux_loss, + sum_group_losses=sum_group_losses, + use_varifocal_loss=use_varifocal_loss, + use_position_supervised_loss=use_position_supervised_loss, + ia_bce_loss=ia_bce_loss, + dataset_file=dataset_file, + coco_path=coco_path, + dataset_dir=dataset_dir, + square_resize_div_64=square_resize_div_64, + output_dir=output_dir, + dont_save_weights=dont_save_weights, + checkpoint_interval=checkpoint_interval, + seed=seed, + resume=resume, + start_epoch=start_epoch, + eval=eval, + use_ema=use_ema, + ema_decay=ema_decay, + ema_tau=ema_tau, + num_workers=num_workers, + device=device, + world_size=world_size, + dist_url=dist_url, + sync_bn=sync_bn, + fp16_eval=fp16_eval, + encoder_only=encoder_only, + backbone_only=backbone_only, + resolution=resolution, + use_cls_token=use_cls_token, + multi_scale=multi_scale, + expanded_scales=expanded_scales, + do_random_resize_via_padding=do_random_resize_via_padding, + warmup_epochs=warmup_epochs, + lr_scheduler=lr_scheduler, + 
lr_min_factor=lr_min_factor, + early_stopping=early_stopping, + early_stopping_patience=early_stopping_patience, + early_stopping_min_delta=early_stopping_min_delta, + early_stopping_use_ema=early_stopping_use_ema, + gradient_checkpointing=gradient_checkpointing, + **extra_kwargs + ) + return args \ No newline at end of file diff --git a/rfdetr/models/__init__.py b/rfdetr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba018eb63e1d04e9d83187dea9f07c07c507ba04 --- /dev/null +++ b/rfdetr/models/__init__.py @@ -0,0 +1,16 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copied from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +from .lwdetr import build_model, build_criterion_and_postprocessors diff --git a/rfdetr/models/__pycache__/__init__.cpython-313.pyc b/rfdetr/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18bc260bf0a37920e6281399c00932da82719410 Binary files /dev/null and b/rfdetr/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/models/__pycache__/lwdetr.cpython-313.pyc b/rfdetr/models/__pycache__/lwdetr.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81c05a1d7ae0baa1e9c033def00405c180e0836c Binary files /dev/null and b/rfdetr/models/__pycache__/lwdetr.cpython-313.pyc differ diff --git a/rfdetr/models/__pycache__/matcher.cpython-313.pyc b/rfdetr/models/__pycache__/matcher.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77a8f4a5feed3db9812a1a3af9a822626c3864fa Binary files /dev/null and b/rfdetr/models/__pycache__/matcher.cpython-313.pyc differ diff --git a/rfdetr/models/__pycache__/position_encoding.cpython-313.pyc b/rfdetr/models/__pycache__/position_encoding.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a20a08260195d0ca083fbbc97b12999aa6ecb8a Binary files /dev/null and b/rfdetr/models/__pycache__/position_encoding.cpython-313.pyc differ diff --git a/rfdetr/models/__pycache__/transformer.cpython-313.pyc b/rfdetr/models/__pycache__/transformer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b334f3d56f44739a16a6c2cf35ef4e8e1437523 Binary files /dev/null and b/rfdetr/models/__pycache__/transformer.cpython-313.pyc differ diff --git a/rfdetr/models/backbone/__init__.py b/rfdetr/models/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01f7fd8fdebb3ef05070eff806c0882498bb70d4 --- /dev/null +++ b/rfdetr/models/backbone/__init__.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ + +from typing import Dict, List + +import torch +from torch import nn + +from rfdetr.util.misc import NestedTensor +from rfdetr.models.position_encoding import build_position_encoding +from rfdetr.models.backbone.backbone import * +from typing import Callable + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self._export = False + + def forward(self, tensor_list: NestedTensor): + """ """ + x = self[0](tensor_list) + pos = [] + for x_ in x: + pos.append(self[1](x_, align_dim_orders=False).to(x_.tensors.dtype)) + return x, pos + + def export(self): + self._export = True + self._forward_origin = self.forward + self.forward = self.forward_export + for name, m in self.named_modules(): + if ( + hasattr(m, "export") + and isinstance(m.export, Callable) + and hasattr(m, "_export") + and not m._export + ): + m.export() + + def forward_export(self, inputs: torch.Tensor): + feats, masks = self[0](inputs) + poss = [] + for feat, mask in zip(feats, masks): + poss.append(self[1](mask, align_dim_orders=False).to(feat.dtype)) + return feats, None, poss + + +def build_backbone( + encoder, + vit_encoder_num_layers, + pretrained_encoder, + window_block_indexes, + drop_path, + out_channels, + out_feature_indexes, + projector_scale, + use_cls_token, + hidden_dim, + position_embedding, + freeze_encoder, + layer_norm, + target_shape, + rms_norm, + backbone_lora, + force_no_pretrain, + gradient_checkpointing, + load_dinov2_weights, + patch_size, + num_windows, + positional_encoding_size, +): + """ + Useful args: + - encoder: encoder name + - lr_encoder: + - dilation + - use_checkpoint: for swin only for now + + """ + position_embedding = build_position_encoding(hidden_dim, position_embedding) + + backbone = Backbone( + encoder, + pretrained_encoder, + window_block_indexes=window_block_indexes, + drop_path=drop_path, + out_channels=out_channels, + out_feature_indexes=out_feature_indexes, + projector_scale=projector_scale, + use_cls_token=use_cls_token, + layer_norm=layer_norm, + freeze_encoder=freeze_encoder, + target_shape=target_shape, + rms_norm=rms_norm, + backbone_lora=backbone_lora, + gradient_checkpointing=gradient_checkpointing, + load_dinov2_weights=load_dinov2_weights, + patch_size=patch_size, + num_windows=num_windows, + positional_encoding_size=positional_encoding_size, + ) + + model = Joiner(backbone, position_embedding) + return model diff --git a/rfdetr/models/backbone/__pycache__/__init__.cpython-313.pyc b/rfdetr/models/backbone/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae7250c104e4d52cb5fe66de163f0eed1ce8dc1 Binary files /dev/null and b/rfdetr/models/backbone/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/models/backbone/__pycache__/backbone.cpython-313.pyc b/rfdetr/models/backbone/__pycache__/backbone.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..514ed669a31ce38acc876a1a0704088fb8b9b97a Binary files /dev/null and b/rfdetr/models/backbone/__pycache__/backbone.cpython-313.pyc differ diff --git a/rfdetr/models/backbone/__pycache__/base.cpython-313.pyc 
b/rfdetr/models/backbone/__pycache__/base.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62d9dfe13273acf1a9e2a2e357f62b72aef21f64 Binary files /dev/null and b/rfdetr/models/backbone/__pycache__/base.cpython-313.pyc differ diff --git a/rfdetr/models/backbone/__pycache__/dinov2.cpython-313.pyc b/rfdetr/models/backbone/__pycache__/dinov2.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78dabda4d48f134ac1ee99f62ac81062756971d2 Binary files /dev/null and b/rfdetr/models/backbone/__pycache__/dinov2.cpython-313.pyc differ diff --git a/rfdetr/models/backbone/__pycache__/dinov2_with_windowed_attn.cpython-313.pyc b/rfdetr/models/backbone/__pycache__/dinov2_with_windowed_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c99b769289df404702c1d7b959730a58df927536 Binary files /dev/null and b/rfdetr/models/backbone/__pycache__/dinov2_with_windowed_attn.cpython-313.pyc differ diff --git a/rfdetr/models/backbone/__pycache__/projector.cpython-313.pyc b/rfdetr/models/backbone/__pycache__/projector.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b9d57fcbe122d141de13e45c80e3e3517faae0b Binary files /dev/null and b/rfdetr/models/backbone/__pycache__/projector.cpython-313.pyc differ diff --git a/rfdetr/models/backbone/backbone.py b/rfdetr/models/backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..94b1f774c4945608fe1bfc63b0ce1a206f6538c7 --- /dev/null +++ b/rfdetr/models/backbone/backbone.py @@ -0,0 +1,205 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Backbone modules. 
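+
+Wraps a DINOv2 ViT encoder together with a multi-scale feature projector behind the common
+BackboneBase interface.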
+""" +from functools import partial +import torch +import torch.nn.functional as F +from torch import nn + +from transformers import AutoModel, AutoProcessor, AutoModelForCausalLM, AutoConfig, AutoBackbone +from peft import LoraConfig, get_peft_model, PeftModel + +from rfdetr.util.misc import NestedTensor, is_main_process + +from rfdetr.models.backbone.base import BackboneBase +from rfdetr.models.backbone.projector import MultiScaleProjector +from rfdetr.models.backbone.dinov2 import DinoV2 + +__all__ = ["Backbone"] + + +class Backbone(BackboneBase): + """backbone.""" + def __init__(self, + name: str, + pretrained_encoder: str=None, + window_block_indexes: list=None, + drop_path=0.0, + out_channels=256, + out_feature_indexes: list=None, + projector_scale: list=None, + use_cls_token: bool = False, + freeze_encoder: bool = False, + layer_norm: bool = False, + target_shape: tuple[int, int] = (640, 640), + rms_norm: bool = False, + backbone_lora: bool = False, + gradient_checkpointing: bool = False, + load_dinov2_weights: bool = True, + patch_size: int = 14, + num_windows: int = 4, + positional_encoding_size: bool = False, + ): + super().__init__() + # an example name here would be "dinov2_base" or "dinov2_registers_windowed_base" + # if "registers" is in the name, then use_registers is set to True, otherwise it is set to False + # similarly, if "windowed" is in the name, then use_windowed_attn is set to True, otherwise it is set to False + # the last part of the name should be the size + # and the start should be dinov2 + name_parts = name.split("_") + assert name_parts[0] == "dinov2" + size = name_parts[-1] + use_registers = False + if "registers" in name_parts: + use_registers = True + name_parts.remove("registers") + use_windowed_attn = False + if "windowed" in name_parts: + use_windowed_attn = True + name_parts.remove("windowed") + assert len(name_parts) == 2, "name should be dinov2, then either registers, windowed, both, or none, then the size" + self.encoder = DinoV2( + size=name_parts[-1], + out_feature_indexes=out_feature_indexes, + shape=target_shape, + use_registers=use_registers, + use_windowed_attn=use_windowed_attn, + gradient_checkpointing=gradient_checkpointing, + load_dinov2_weights=load_dinov2_weights, + patch_size=patch_size, + num_windows=num_windows, + positional_encoding_size=positional_encoding_size, + ) + # build encoder + projector as backbone module + if freeze_encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + + self.projector_scale = projector_scale + assert len(self.projector_scale) > 0 + # x[0] + assert ( + sorted(self.projector_scale) == self.projector_scale + ), "only support projector scale P3/P4/P5/P6 in ascending order." 
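+ # Each projector level rescales the encoder feature map relative to its native
+ # patch-stride resolution: P4 keeps it unchanged, P3 upsamples 2x, P5 and P6
+ # downsample 2x and 4x respectively.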
+ level2scalefactor = dict(P3=2.0, P4=1.0, P5=0.5, P6=0.25) + scale_factors = [level2scalefactor[lvl] for lvl in self.projector_scale] + + self.projector = MultiScaleProjector( + in_channels=self.encoder._out_feature_channels, + out_channels=out_channels, + scale_factors=scale_factors, + layer_norm=layer_norm, + rms_norm=rms_norm, + ) + + self._export = False + + def export(self): + self._export = True + self._forward_origin = self.forward + self.forward = self.forward_export + + if isinstance(self.encoder, PeftModel): + print("Merging and unloading LoRA weights") + self.encoder.merge_and_unload() + + def forward(self, tensor_list: NestedTensor): + """ """ + # (H, W, B, C) + feats = self.encoder(tensor_list.tensors) + feats = self.projector(feats) + # x: [(B, C, H, W)] + out = [] + for feat in feats: + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=feat.shape[-2:]).to(torch.bool)[ + 0 + ] + out.append(NestedTensor(feat, mask)) + return out + + def forward_export(self, tensors: torch.Tensor): + feats = self.encoder(tensors) + feats = self.projector(feats) + out_feats = [] + out_masks = [] + for feat in feats: + # x: [(B, C, H, W)] + b, _, h, w = feat.shape + out_masks.append( + torch.zeros((b, h, w), dtype=torch.bool, device=feat.device) + ) + out_feats.append(feat) + return out_feats, out_masks + + def get_named_param_lr_pairs(self, args, prefix: str = "backbone.0"): + num_layers = args.out_feature_indexes[-1] + 1 + backbone_key = "backbone.0.encoder" + named_param_lr_pairs = {} + for n, p in self.named_parameters(): + n = prefix + "." + n + if backbone_key in n and p.requires_grad: + lr = ( + args.lr_encoder + * get_dinov2_lr_decay_rate( + n, + lr_decay_rate=args.lr_vit_layer_decay, + num_layers=num_layers, + ) + * args.lr_component_decay**2 + ) + wd = args.weight_decay * get_dinov2_weight_decay_rate(n) + named_param_lr_pairs[n] = { + "params": p, + "lr": lr, + "weight_decay": wd, + } + return named_param_lr_pairs + + +def get_dinov2_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone"): + if "embeddings" in name: + layer_id = 0 + elif ".layer." in name and ".residual." not in name: + layer_id = int(name[name.find(".layer.") :].split(".")[2]) + 1 + return lr_decay_rate ** (num_layers + 1 - layer_id) + +def get_dinov2_weight_decay_rate(name, weight_decay_rate=1.0): + if ( + ("gamma" in name) + or ("pos_embed" in name) + or ("rel_pos" in name) + or ("bias" in name) + or ("norm" in name) + or ("embeddings" in name) + ): + weight_decay_rate = 0.0 + return weight_decay_rate diff --git a/rfdetr/models/backbone/base.py b/rfdetr/models/backbone/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd6dd5f656f66a45bbcb4b8149a3a735239241c --- /dev/null +++ b/rfdetr/models/backbone/base.py @@ -0,0 +1,20 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +import torch +import torch.nn.functional as F +from torch import nn + + +class BackboneBase(nn.Module): + def __init__(self): + super().__init__() + + def get_named_param_lr_pairs(self, args, prefix:str): + raise NotImplementedError diff --git a/rfdetr/models/backbone/dinov2.py b/rfdetr/models/backbone/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..bb762b5c299be6e02f04fc986f38f6b9021cad7a --- /dev/null +++ b/rfdetr/models/backbone/dinov2.py @@ -0,0 +1,197 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +import torch +import torch.nn as nn +from transformers import AutoBackbone +import torch.nn.functional as F +import types +import math +import json +import os + +from .dinov2_with_windowed_attn import WindowedDinov2WithRegistersConfig, WindowedDinov2WithRegistersBackbone + + +size_to_width = { + "tiny": 192, + "small": 384, + "base": 768, + "large": 1024, +} + +size_to_config = { + "small": "dinov2_small.json", + "base": "dinov2_base.json", + "large": "dinov2_large.json", +} + +size_to_config_with_registers = { + "small": "dinov2_with_registers_small.json", + "base": "dinov2_with_registers_base.json", + "large": "dinov2_with_registers_large.json", +} + +def get_config(size, use_registers): + config_dict = size_to_config_with_registers if use_registers else size_to_config + current_dir = os.path.dirname(os.path.abspath(__file__)) + configs_dir = os.path.join(current_dir, "dinov2_configs") + config_path = os.path.join(configs_dir, config_dict[size]) + with open(config_path, "r") as f: + dino_config = json.load(f) + return dino_config + + +class DinoV2(nn.Module): + def __init__(self, + shape=(640, 640), + out_feature_indexes=[2, 4, 5, 9], + size="base", + use_registers=True, + use_windowed_attn=True, + gradient_checkpointing=False, + load_dinov2_weights=True, + patch_size=14, + num_windows=4, + positional_encoding_size=37, + ): + super().__init__() + + name = f"facebook/dinov2-with-registers-{size}" if use_registers else f"facebook/dinov2-{size}" + + self.shape = shape + self.patch_size = patch_size + self.num_windows = num_windows + + # Create the encoder + + if not use_windowed_attn: + assert not gradient_checkpointing, "Gradient checkpointing is not supported for non-windowed attention" + assert load_dinov2_weights, "Using non-windowed attention requires loading dinov2 weights from hub" + self.encoder = AutoBackbone.from_pretrained( + name, + out_features=[f"stage{i}" for i in out_feature_indexes], + return_dict=False, + ) + else: + window_block_indexes = set(range(out_feature_indexes[-1] + 1)) + window_block_indexes.difference_update(out_feature_indexes) + window_block_indexes = list(window_block_indexes) + + dino_config = get_config(size, use_registers) + + dino_config["return_dict"] = False + dino_config["out_features"] = [f"stage{i}" for i in out_feature_indexes] + + implied_resolution = positional_encoding_size * patch_size + + if implied_resolution != dino_config["image_size"]: + print(f"Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. 
This is not a problem if finetuning a pretrained RF-DETR model.") + dino_config["image_size"] = implied_resolution + load_dinov2_weights = False + + if patch_size != 14: + print(f"Using patch size {patch_size} instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.") + dino_config["patch_size"] = patch_size + load_dinov2_weights = False + + if use_registers: + windowed_dino_config = WindowedDinov2WithRegistersConfig( + **dino_config, + num_windows=num_windows, + window_block_indexes=window_block_indexes, + gradient_checkpointing=gradient_checkpointing, + ) + else: + windowed_dino_config = WindowedDinov2WithRegistersConfig( + **dino_config, + num_windows=num_windows, + window_block_indexes=window_block_indexes, + num_register_tokens=0, + gradient_checkpointing=gradient_checkpointing, + ) + self.encoder = WindowedDinov2WithRegistersBackbone.from_pretrained( + name, + config=windowed_dino_config, + ) if load_dinov2_weights else WindowedDinov2WithRegistersBackbone(windowed_dino_config) + + + self._out_feature_channels = [size_to_width[size]] * len(out_feature_indexes) + self._export = False + + def export(self): + if self._export: + return + self._export = True + shape = self.shape + def make_new_interpolated_pos_encoding( + position_embeddings, patch_size, height, width + ): + + num_positions = position_embeddings.shape[1] - 1 + dim = position_embeddings.shape[-1] + height = height // patch_size + width = width // patch_size + + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + + # Reshape and permute + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + # Use bilinear interpolation without antialias + patch_pos_embed = F.interpolate( + patch_pos_embed, + size=(height, width), + mode="bicubic", + align_corners=False, + antialias=True, + ) + + # Reshape back + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + # If the shape of self.encoder.embeddings.position_embeddings + # matches the shape of your new tensor, use copy_: + with torch.no_grad(): + new_positions = make_new_interpolated_pos_encoding( + self.encoder.embeddings.position_embeddings, + self.encoder.config.patch_size, + shape[0], + shape[1], + ) + # Create a new Parameter with the new size + old_interpolate_pos_encoding = self.encoder.embeddings.interpolate_pos_encoding + def new_interpolate_pos_encoding(self_mod, embeddings, height, width): + num_patches = embeddings.shape[1] - 1 + num_positions = self_mod.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self_mod.position_embeddings + return old_interpolate_pos_encoding(embeddings, height, width) + + self.encoder.embeddings.position_embeddings = nn.Parameter(new_positions) + self.encoder.embeddings.interpolate_pos_encoding = types.MethodType( + new_interpolate_pos_encoding, + self.encoder.embeddings + ) + + def forward(self, x): + block_size = self.patch_size * self.num_windows + assert x.shape[2] % block_size == 0 and x.shape[3] % block_size == 0, f"Backbone requires input shape to be divisible by {block_size}, but got {x.shape}" + x = self.encoder(x) + return list(x[0]) + +if __name__ == "__main__": + model = DinoV2() + model.export() + x = torch.randn(1, 3, 640, 640) + print(model(x)) + 
for j in model(x): + print(j.shape) diff --git a/rfdetr/models/backbone/dinov2_configs/dinov2_base.json b/rfdetr/models/backbone/dinov2_configs/dinov2_base.json new file mode 100644 index 0000000000000000000000000000000000000000..1bcba88ac59b69dfbb2bf48aa3d986918b59e9a8 --- /dev/null +++ b/rfdetr/models/backbone/dinov2_configs/dinov2_base.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "Dinov2Model" + ], + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 768, + "image_size": 518, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2", + "num_attention_heads": 12, + "num_channels": 3, + "num_hidden_layers": 12, + "patch_size": 14, + "qkv_bias": true, + "torch_dtype": "float32", + "transformers_version": "4.31.0.dev0", + "use_swiglu_ffn": false +} \ No newline at end of file diff --git a/rfdetr/models/backbone/dinov2_configs/dinov2_large.json b/rfdetr/models/backbone/dinov2_configs/dinov2_large.json new file mode 100644 index 0000000000000000000000000000000000000000..9a9c7d228e9119ab6898af188815e5236d98a46b --- /dev/null +++ b/rfdetr/models/backbone/dinov2_configs/dinov2_large.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "Dinov2Model" + ], + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 1024, + "image_size": 518, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "qkv_bias": true, + "torch_dtype": "float32", + "transformers_version": "4.31.0.dev0", + "use_swiglu_ffn": false +} \ No newline at end of file diff --git a/rfdetr/models/backbone/dinov2_configs/dinov2_small.json b/rfdetr/models/backbone/dinov2_configs/dinov2_small.json new file mode 100644 index 0000000000000000000000000000000000000000..aa05bd92ba5e00ca4ba1893afd4538c8d4df5bc6 --- /dev/null +++ b/rfdetr/models/backbone/dinov2_configs/dinov2_small.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "Dinov2Model" + ], + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 384, + "image_size": 518, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2", + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "patch_size": 14, + "qkv_bias": true, + "torch_dtype": "float32", + "transformers_version": "4.32.0.dev0", + "use_swiglu_ffn": false +} \ No newline at end of file diff --git a/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_base.json b/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_base.json new file mode 100644 index 0000000000000000000000000000000000000000..b01f45b32ac1c6b650b5a0a3aef493eec9d4b46b --- /dev/null +++ b/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_base.json @@ -0,0 +1,50 @@ +{ + "apply_layernorm": true, + "architectures": [ + "Dinov2WithRegistersModel" + ], + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 768, + "image_size": 518, + "initializer_range": 0.02, + "interpolate_antialias": true, + "interpolate_offset": 0.0, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": 
"dinov2_with_registers", + "num_attention_heads": 12, + "num_channels": 3, + "num_hidden_layers": 12, + "num_register_tokens": 4, + "out_features": [ + "stage12" + ], + "out_indices": [ + 12 + ], + "patch_size": 14, + "qkv_bias": true, + "reshape_hidden_states": true, + "stage_names": [ + "stem", + "stage1", + "stage2", + "stage3", + "stage4", + "stage5", + "stage6", + "stage7", + "stage8", + "stage9", + "stage10", + "stage11", + "stage12" + ], + "torch_dtype": "float32", + "transformers_version": "4.48.0.dev0", + "use_swiglu_ffn": false +} \ No newline at end of file diff --git a/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_large.json b/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_large.json new file mode 100644 index 0000000000000000000000000000000000000000..ea63daa9320337770fb20727e1958da767fd786f --- /dev/null +++ b/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_large.json @@ -0,0 +1,50 @@ +{ + "apply_layernorm": true, + "architectures": [ + "Dinov2WithRegistersModel" + ], + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 1024, + "image_size": 518, + "initializer_range": 0.02, + "interpolate_antialias": true, + "interpolate_offset": 0.0, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2_with_registers", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "num_register_tokens": 4, + "out_features": [ + "stage12" + ], + "out_indices": [ + 12 + ], + "patch_size": 14, + "qkv_bias": true, + "reshape_hidden_states": true, + "stage_names": [ + "stem", + "stage1", + "stage2", + "stage3", + "stage4", + "stage5", + "stage6", + "stage7", + "stage8", + "stage9", + "stage10", + "stage11", + "stage12" + ], + "torch_dtype": "float32", + "transformers_version": "4.48.0.dev0", + "use_swiglu_ffn": false +} \ No newline at end of file diff --git a/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_small.json b/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_small.json new file mode 100644 index 0000000000000000000000000000000000000000..55c93122b935d75f39c7c5cb616d76799f18fa5c --- /dev/null +++ b/rfdetr/models/backbone/dinov2_configs/dinov2_with_registers_small.json @@ -0,0 +1,50 @@ +{ + "apply_layernorm": true, + "architectures": [ + "Dinov2WithRegistersModel" + ], + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 384, + "image_size": 518, + "initializer_range": 0.02, + "interpolate_antialias": true, + "interpolate_offset": 0.0, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2_with_registers", + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "num_register_tokens": 4, + "out_features": [ + "stage12" + ], + "out_indices": [ + 12 + ], + "patch_size": 14, + "qkv_bias": true, + "reshape_hidden_states": true, + "stage_names": [ + "stem", + "stage1", + "stage2", + "stage3", + "stage4", + "stage5", + "stage6", + "stage7", + "stage8", + "stage9", + "stage10", + "stage11", + "stage12" + ], + "torch_dtype": "float32", + "transformers_version": "4.48.0.dev0", + "use_swiglu_ffn": false +} \ No newline at end of file diff --git a/rfdetr/models/backbone/dinov2_with_windowed_attn.py b/rfdetr/models/backbone/dinov2_with_windowed_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..b315c46811d57351135d4eb30646613304149011 
--- /dev/null +++ b/rfdetr/models/backbone/dinov2_with_windowed_attn.py @@ -0,0 +1,1130 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from HuggingFace Dinov2 (https://github.com/huggingface/transformers) +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# ------------------------------------------------------------------------ + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from transformers.utils.backbone_utils import BackboneMixin + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base" + +# General docstring +_CONFIG_FOR_DOC = "WindowedDinov2WithRegistersConfig" + + +class WindowedDinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an + Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv2 with Registers + [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_register_tokens (`int`, *optional*, defaults to 4): + Number of register tokens to use. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. 
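+ num_windows (`int`, *optional*, defaults to 1):
+ Number of windows along each spatial axis used for windowed self-attention. Setting this to 1
+ disables windowing and runs global attention in every block.
+ window_block_indexes (`List[int]`, *optional*):
+ Indices of the encoder blocks that restrict attention to their local window; blocks not listed
+ attend globally across all windows. Defaults to every block.
+ gradient_checkpointing (`bool`, *optional*, defaults to `False`):
+ Whether to checkpoint the encoder blocks during training, trading extra compute for lower memory use.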
+ + Example: + + ```python + >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel + + >>> # Initializing a Dinov2WithRegisters base style configuration + >>> configuration = Dinov2WithRegistersConfig() + + >>> # Initializing a model (with random weights) from the base style configuration + >>> model = Dinov2WithRegistersModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinov2_with_registers" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + layerscale_value=1.0, + drop_path_rate=0.0, + use_swiglu_ffn=False, + num_register_tokens=4, + out_features=None, + out_indices=None, + apply_layernorm=True, + reshape_hidden_states=True, + num_windows=1, + window_block_indexes=None, + gradient_checkpointing=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_swiglu_ffn = use_swiglu_ffn + self.num_register_tokens = num_register_tokens + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + self.num_windows = num_windows + self.window_block_indexes = list(range(num_hidden_layers)) if window_block_indexes is None else window_block_indexes + self.gradient_checkpointing = gradient_checkpointing + + +class Dinov2WithRegistersPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
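+
+ For example, with the DINOv2 defaults (518x518 images, patch size 14) this produces a 37x37 grid,
+ i.e. 1369 patch embeddings per image.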
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class WindowedDinov2WithRegistersEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, register tokens, position and patch embeddings. + """ + + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) if config.num_register_tokens > 0 else None + self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility + with the original implementation. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py + - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # Skip interpolation for matching dimensions (unless tracing) + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + # Handle class token and patch embeddings separately + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + + # Calculate new dimensions + height = height // self.config.patch_size + width = width // self.config.patch_size + + # Reshape for interpolation + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + # Store original dtype for restoration after interpolation + target_dtype = patch_pos_embed.dtype + + # Interpolate at float32 precision + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor + mode="bicubic", + align_corners=False, + antialias=True, + ).to(dtype=target_dtype) + + # Validate output dimensions if not tracing + if not torch.jit.is_tracing(): + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + + # Reshape back to original format + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + # Combine class and patch embeddings + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + if self.config.num_windows > 1: + # reshape for windows + num_h_patches = height // self.config.patch_size + num_w_patches = width // self.config.patch_size + cls_token_with_pos_embed = embeddings[:, :1] + pixel_tokens_with_pos_embed = embeddings[:, 1:] + pixel_tokens_with_pos_embed = pixel_tokens_with_pos_embed.view(batch_size, num_h_patches, num_w_patches, -1) + num_w_patches_per_window = num_w_patches // self.config.num_windows + num_h_patches_per_window = num_h_patches // self.config.num_windows + num_windows = self.config.num_windows + windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(batch_size, num_windows, num_h_patches_per_window, num_windows, num_h_patches_per_window, -1) + windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 1, 3, 2, 4, 5) + windowed_pixel_tokens = windowed_pixel_tokens.reshape(batch_size * num_windows ** 2, 
num_h_patches_per_window * num_w_patches_per_window, -1) + windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows ** 2, 1, 1) + embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1) + + # add register tokens + embeddings = torch.cat( + (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1 + ) if self.config.num_register_tokens > 0 else embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2WithRegistersSelfAttention(nn.Module): + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention): + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions + ) + + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + +class Dinov2WithRegistersSelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
+ """ + + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class Dinov2WithRegistersAttention(nn.Module): + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__() + self.attention = Dinov2WithRegistersSelfAttention(config) + self.output = Dinov2WithRegistersSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention): + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__(config) + self.attention = Dinov2WithRegistersSdpaSelfAttention(config) + + +class Dinov2WithRegistersLayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
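+
+ Concretely, with drop probability ``p`` each sample's residual branch is zeroed with probability ``p``
+ and otherwise scaled by ``1 / (1 - p)``, so the expected output is unchanged.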
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class Dinov2WithRegistersDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2WithRegistersMLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2WithRegistersSwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = { + "eager": Dinov2WithRegistersAttention, + "sdpa": Dinov2WithRegistersSdpaAttention, +} + + +class WindowedDinov2WithRegistersLayer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__() + + self.num_windows = config.num_windows + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config) + self.layer_scale1 = Dinov2WithRegistersLayerScale(config) + self.drop_path = ( + Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2WithRegistersSwiGLUFFN(config) + else: + self.mlp = Dinov2WithRegistersMLP(config) + self.layer_scale2 = Dinov2WithRegistersLayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + run_full_attention: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + assert head_mask is None, "head_mask is not supported for 
windowed attention" + assert not output_attentions, "output_attentions is not supported for windowed attention" + shortcut = hidden_states + if run_full_attention: + # reshape x to remove windows + B, HW, C = hidden_states.shape + num_windows_squared = self.num_windows ** 2 + hidden_states = hidden_states.view(B // num_windows_squared, num_windows_squared * HW, C) + + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Dinov2WithRegisters, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + if run_full_attention: + # reshape x to add windows back + B, HW, C = hidden_states.shape + num_windows_squared = self.num_windows ** 2 + # hidden_states = hidden_states.view(B * num_windows_squared, HW // num_windows_squared, C) + attention_output = attention_output.view(B * num_windows_squared, HW // num_windows_squared, C) + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + shortcut + + # in Dinov2WithRegisters, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class WindowedDinov2WithRegistersEncoder(nn.Module): + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([WindowedDinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = config.gradient_checkpointing + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if i > int(self.config.out_features[-1][5:]): + # early stop if we have reached the last output feature + break + + run_full_attention = i not in self.config.window_block_indexes + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + run_full_attention, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, run_full_attention) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class WindowedDinov2WithRegistersPreTrainedModel(PreTrainedModel): + """ + An abstract class to 
handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = WindowedDinov2WithRegistersConfig + base_model_prefix = "dinov2_with_registers" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"] + _supports_sdpa = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, WindowedDinov2WithRegistersEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + + +DINOV2_WITH_REGISTERS_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class WindowedDinov2WithRegistersModel(WindowedDinov2WithRegistersPreTrainedModel): + def __init__(self, config: WindowedDinov2WithRegistersConfig): + super().__init__(config) + self.config = config + + self.embeddings = WindowedDinov2WithRegistersEmbeddings(config) + self.encoder = WindowedDinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + +DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape 
`(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class WindowedDinov2WithRegistersForImageClassification(WindowedDinov2WithRegistersPreTrainedModel): + def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2_with_registers = WindowedDinov2WithRegistersModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
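+
+ Example (a minimal, untested sketch added for illustration; the keyword names follow the
+ upstream Dinov2WithRegisters config, the values, label count, and random input are
+ arbitrary, and `num_windows=1` is chosen to sidestep window-divisibility constraints):
+
+ ```python
+ >>> import torch
+ >>> config = WindowedDinov2WithRegistersConfig(
+ ... image_size=224, patch_size=16, num_windows=1, num_labels=3
+ ... )
+ >>> model = WindowedDinov2WithRegistersForImageClassification(config)
+ >>> pixel_values = torch.randn(1, 3, 224, 224)
+ >>> with torch.no_grad():
+ ... logits = model(pixel_values).logits # shape (1, 3)
+ ```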
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinov2_with_registers( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer. 
+ """, + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class WindowedDinov2WithRegistersBackbone(WindowedDinov2WithRegistersPreTrainedModel, BackboneMixin): + def __init__(self, config: WindowedDinov2WithRegistersConfig): + super().__init__(config) + super()._init_backbone(config) + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = WindowedDinov2WithRegistersEmbeddings(config) + self.encoder = WindowedDinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.num_register_tokens = config.num_register_tokens + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + Returns: + + Examples: + + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, self.num_register_tokens + 1 :] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + + num_h_patches = height // patch_size + num_w_patches = width // patch_size + + if self.config.num_windows > 1: + # undo windowing + num_windows_squared = self.config.num_windows ** 2 + B, HW, C = hidden_state.shape + num_h_patches_per_window = num_h_patches // self.config.num_windows + num_w_patches_per_window = num_w_patches // self.config.num_windows + hidden_state = hidden_state.reshape(B // num_windows_squared, num_windows_squared * HW, C) + hidden_state = 
hidden_state.view(B // num_windows_squared, self.config.num_windows, self.config.num_windows, num_h_patches_per_window, num_w_patches_per_window, C) + hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5) + + hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +__all__ = [ + "WindowedDinov2WithRegistersPreTrainedModel", + "WindowedDinov2WithRegistersModel", + "WindowedDinov2WithRegistersForImageClassification", + "WindowedDinov2WithRegistersBackbone", +] \ No newline at end of file diff --git a/rfdetr/models/backbone/projector.py b/rfdetr/models/backbone/projector.py new file mode 100644 index 0000000000000000000000000000000000000000..38175574a13ced71cc44376ad476dc28b5784459 --- /dev/null +++ b/rfdetr/models/backbone/projector.py @@ -0,0 +1,293 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from ViTDet (https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Projector +""" +import math +import random +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LayerNorm(nn.Module): + """ + A LayerNorm variant, popularized by Transformers, that performs point-wise mean and + variance normalization over the channel dimension for inputs that have shape + (batch_size, channels, height, width). + https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + """ + LayerNorm forward + TODO: this is a hack to avoid overflow when using fp16 + """ + #if x.dtype == torch.half: + # x = x / (x.max() + self.eps) + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def get_norm(norm, out_channels): + """ + Args: + norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; + or a callable that takes a channel number and returns + the normalization layer as a nn.Module. 
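+
+ Example (illustrative; note that only "LN" is wired up in this module):
+
+ >>> ln = get_norm("LN", 64) # channels-first LayerNorm over 64 channels
+ >>> get_norm(None, 64) is None and get_norm("", 64) is None
+ True
+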
+ Returns: + nn.Module or None: the normalization layer + """ + if norm is None: + return None + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + "LN": lambda channels: LayerNorm(channels), + }[norm] + return norm(out_channels) + + +def get_activation(name, inplace=False): + """ get activation """ + if name == "silu": + module = nn.SiLU(inplace=inplace) + elif name == "relu": + module = nn.ReLU(inplace=inplace) + elif name in ["LeakyReLU", 'leakyrelu', 'lrelu']: + module = nn.LeakyReLU(0.1, inplace=inplace) + elif name is None: + module = nn.Identity() + else: + raise AttributeError("Unsupported act type: {}".format(name)) + return module + + +class ConvX(nn.Module): + """ Conv-bn module""" + def __init__(self, in_planes, out_planes, kernel=3, stride=1, groups=1, dilation=1, act='relu', layer_norm=False, rms_norm=False): + super(ConvX, self).__init__() + if not isinstance(kernel, tuple): + kernel = (kernel, kernel) + padding = (kernel[0] // 2, kernel[1] // 2) + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, + stride=stride, padding=padding, groups=groups, + dilation=dilation, bias=False) + if rms_norm: + self.bn = nn.RMSNorm(out_planes) + else: + self.bn = get_norm('LN', out_planes) if layer_norm else nn.BatchNorm2d(out_planes) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + """ forward """ + out = self.act(self.bn(self.conv(x))) + return out + + +class Bottleneck(nn.Module): + """Standard bottleneck.""" + + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, act='silu', layer_norm=False, rms_norm=False): + """ ch_in, ch_out, shortcut, groups, kernels, expand """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = ConvX(c1, c_, k[0], 1, act=act, layer_norm=layer_norm, rms_norm=rms_norm) + self.cv2 = ConvX(c_, c2, k[1], 1, groups=g, act=act, layer_norm=layer_norm, rms_norm=rms_norm) + self.add = shortcut and c1 == c2 + + def forward(self, x): + """'forward()' applies the YOLOv5 FPN to input data.""" + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class C2f(nn.Module): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, act='silu', layer_norm=False, rms_norm=False): + """ ch_in, ch_out, number, shortcut, groups, expansion """ + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = ConvX(c1, 2 * self.c, 1, 1, act=act, layer_norm=layer_norm, rms_norm=rms_norm) + self.cv2 = ConvX((2 + n) * self.c, c2, 1, act=act, layer_norm=layer_norm, rms_norm=rms_norm) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=(3, 3), e=1.0, act=act, layer_norm=layer_norm, rms_norm=rms_norm) for _ in range(n)) + + def forward(self, x): + """Forward pass using split() instead of chunk().""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in self.m) + return self.cv2(torch.cat(y, 1)) + + +class MultiScaleProjector(nn.Module): + """ + This module implements MultiScaleProjector in :paper:`lwdetr`. + It creates pyramid features built on top of the input feature map. + """ + + def __init__( + self, + in_channels, + out_channels, + scale_factors, + num_blocks=3, + layer_norm=False, + rms_norm=False, + survival_prob=1.0, + force_drop_last_n_features=0, + ): + """ + Args: + net (Backbone): module representing the subnetwork backbone. + Must be a subclass of :class:`Backbone`. 
+ out_channels (int): number of channels in the output feature maps. + scale_factors (list[float]): list of scaling factors to upsample or downsample + the input features for creating pyramid features. + """ + super(MultiScaleProjector, self).__init__() + + self.scale_factors = scale_factors + self.survival_prob = survival_prob + self.force_drop_last_n_features = force_drop_last_n_features + + stages_sampling = [] + stages = [] + # use_bias = norm == "" + use_bias = False + self.use_extra_pool = False + for scale in scale_factors: + stages_sampling.append([]) + for in_dim in in_channels: + out_dim = in_dim + layers = [] + + # if in_dim > 512: + # layers.append(ConvX(in_dim, in_dim // 2, kernel=1)) + # in_dim = in_dim // 2 + + if scale == 4.0: + layers.extend([ + nn.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2), + get_norm('LN', in_dim // 2), + nn.GELU(), + nn.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2), + ]) + out_dim = in_dim // 4 + elif scale == 2.0: + # a hack to reduce the FLOPs and Params when the dimention of output feature is too large + # if in_dim > 512: + # layers = [ + # ConvX(in_dim, in_dim // 2, kernel=1), + # nn.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2), + # ] + # out_dim = in_dim // 4 + # else: + layers.extend([ + nn.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2), + ]) + out_dim = in_dim // 2 + elif scale == 1.0: + pass + elif scale == 0.5: + layers.extend([ + ConvX(in_dim, in_dim, 3, 2, layer_norm=layer_norm), + ]) + elif scale == 0.25: + self.use_extra_pool = True + continue + else: + raise NotImplementedError("Unsupported scale_factor:{}".format(scale)) + layers = nn.Sequential(*layers) + stages_sampling[-1].append(layers) + stages_sampling[-1] = nn.ModuleList(stages_sampling[-1]) + + in_dim = int(sum(in_channel // max(1, scale) for in_channel in in_channels)) + layers = [ + C2f(in_dim, out_channels, num_blocks, layer_norm=layer_norm), + get_norm('LN', out_channels), + ] + layers = nn.Sequential(*layers) + stages.append(layers) + + self.stages_sampling = nn.ModuleList(stages_sampling) + self.stages = nn.ModuleList(stages) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: + mapping from feature map name to pyramid feature map tensor + in high to low resolution order. Returned feature names follow the FPN + convention: "p", where stage has stride = 2 ** stage e.g., + ["p2", "p3", ..., "p6"]. 
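+
+ Example (a minimal sketch added for illustration; note that ``x`` is a list of per-level
+ feature maps, and the channel counts, scales, and spatial sizes below are arbitrary):
+
+ >>> import torch
+ >>> proj = MultiScaleProjector(in_channels=[384, 384], out_channels=256,
+ ... scale_factors=[2.0, 1.0])
+ >>> feats = [torch.randn(1, 384, 40, 40), torch.randn(1, 384, 40, 40)]
+ >>> outs = proj(feats) # [(1, 256, 80, 80), (1, 256, 40, 40)]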
+ """ + num_features = len(x) + if self.survival_prob < 1.0 and self.training: + final_drop_prob = 1 - self.survival_prob + drop_p = np.random.uniform() + for i in range(1, num_features): + critical_drop_prob = i * (final_drop_prob / (num_features - 1)) + if drop_p < critical_drop_prob: + x[i][:] = 0 + elif self.force_drop_last_n_features > 0: + for i in range(self.force_drop_last_n_features): + # don't do it inplace to ensure the compiler can optimize out the backbone layers + x[-(i+1)] = torch.zeros_like(x[-(i+1)]) + + results = [] + # x list of len(out_features_indexes) + for i, stage in enumerate(self.stages): + feat_fuse = [] + for j, stage_sampling in enumerate(self.stages_sampling[i]): + feat_fuse.append(stage_sampling(x[j])) + if len(feat_fuse) > 1: + feat_fuse = torch.cat(feat_fuse, dim=1) + else: + feat_fuse = feat_fuse[0] + results.append(stage(feat_fuse)) + if self.use_extra_pool: + results.append( + F.max_pool2d(results[-1], kernel_size=1, stride=2, padding=0) + ) + return results + + +class SimpleProjector(nn.Module): + def __init__(self, in_dim, out_dim, factor_kernel=False): + super(SimpleProjector, self).__init__() + if not factor_kernel: + self.convx1 = ConvX(in_dim, in_dim*2, layer_norm=True, act='silu') + self.convx2 = ConvX(in_dim*2, out_dim, layer_norm=True, act='silu') + else: + self.convx1 = ConvX(in_dim, out_dim, kernel=(3, 1), layer_norm=True, act='silu') + self.convx2 = ConvX(out_dim, out_dim, kernel=(1, 3), layer_norm=True, act='silu') + self.ln = get_norm('LN', out_dim) + + def forward(self, x): + """ forward """ + out = self.ln(self.convx2(self.convx1(x[0]))) + return [out] diff --git a/rfdetr/models/lwdetr.py b/rfdetr/models/lwdetr.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4b2aea40a5e17bb66c161f6bc38f9e7f21a9ce --- /dev/null +++ b/rfdetr/models/lwdetr.py @@ -0,0 +1,684 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +""" +LW-DETR model and criterion classes +""" +import copy +import math +from typing import Callable +import torch +import torch.nn.functional as F +from torch import nn + +from rfdetr.util import box_ops +from rfdetr.util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, + is_dist_avail_and_initialized) + +from rfdetr.models.backbone import build_backbone +from rfdetr.models.matcher import build_matcher +from rfdetr.models.transformer import build_transformer + +class LWDETR(nn.Module): + """ This is the Group DETR v3 module that performs object detection """ + def __init__(self, + backbone, + transformer, + num_classes, + num_queries, + aux_loss=False, + group_detr=1, + two_stage=False, + lite_refpoint_refine=False, + bbox_reparam=False): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + Conditional DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + group_detr: Number of groups to speed detr training. Default is 1. + lite_refpoint_refine: TODO + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + + query_dim=4 + self.refpoint_embed = nn.Embedding(num_queries * group_detr, query_dim) + self.query_feat = nn.Embedding(num_queries * group_detr, hidden_dim) + nn.init.constant_(self.refpoint_embed.weight.data, 0) + + self.backbone = backbone + self.aux_loss = aux_loss + self.group_detr = group_detr + + # iter update + self.lite_refpoint_refine = lite_refpoint_refine + if not self.lite_refpoint_refine: + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + self.transformer.decoder.bbox_embed = None + + self.bbox_reparam = bbox_reparam + + # init prior_prob setting for focal loss + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + + # init bbox_mebed + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + + # two_stage + self.two_stage = two_stage + if self.two_stage: + self.transformer.enc_out_bbox_embed = nn.ModuleList( + [copy.deepcopy(self.bbox_embed) for _ in range(group_detr)]) + self.transformer.enc_out_class_embed = nn.ModuleList( + [copy.deepcopy(self.class_embed) for _ in range(group_detr)]) + + self._export = False + + def reinitialize_detection_head(self, num_classes): + # Create new classification head + del self.class_embed + self.add_module("class_embed", nn.Linear(self.transformer.d_model, num_classes)) + + # Initialize with focal loss bias adjustment + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + + if self.two_stage: + del self.transformer.enc_out_class_embed + self.transformer.add_module("enc_out_class_embed", nn.ModuleList( + [copy.deepcopy(self.class_embed) for _ in range(self.group_detr)])) + + + def export(self): + 
self._export = True + self._forward_origin = self.forward + self.forward = self.forward_export + for name, m in self.named_modules(): + if hasattr(m, "export") and isinstance(m.export, Callable) and hasattr(m, "_export") and not m._export: + m.export() + + def forward(self, samples: NestedTensor, targets=None): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x num_classes] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, width, height). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, poss = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(src) + masks.append(mask) + assert mask is not None + + if self.training: + refpoint_embed_weight = self.refpoint_embed.weight + query_feat_weight = self.query_feat.weight + else: + # only use one group in inference + refpoint_embed_weight = self.refpoint_embed.weight[:self.num_queries] + query_feat_weight = self.query_feat.weight[:self.num_queries] + + hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer( + srcs, masks, poss, refpoint_embed_weight, query_feat_weight) + + if hs is not None: + if self.bbox_reparam: + outputs_coord_delta = self.bbox_embed(hs) + outputs_coord_cxcy = outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:] + ref_unsigmoid[..., :2] + outputs_coord_wh = outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:] + outputs_coord = torch.concat( + [outputs_coord_cxcy, outputs_coord_wh], dim=-1 + ) + else: + outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid() + + outputs_class = self.class_embed(hs) + + out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) + + if self.two_stage: + group_detr = self.group_detr if self.training else 1 + hs_enc_list = hs_enc.chunk(group_detr, dim=1) + cls_enc = [] + for g_idx in range(group_detr): + cls_enc_gidx = self.transformer.enc_out_class_embed[g_idx](hs_enc_list[g_idx]) + cls_enc.append(cls_enc_gidx) + cls_enc = torch.cat(cls_enc, dim=1) + if hs is not None: + out['enc_outputs'] = {'pred_logits': cls_enc, 'pred_boxes': ref_enc} + else: + out = {'pred_logits': cls_enc, 'pred_boxes': ref_enc} + + return out + + def forward_export(self, tensors): + srcs, _, poss = self.backbone(tensors) + # only use one group in inference + refpoint_embed_weight = self.refpoint_embed.weight[:self.num_queries] + query_feat_weight = self.query_feat.weight[:self.num_queries] + + hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer( + srcs, None, poss, refpoint_embed_weight, query_feat_weight) + + if hs is not None: + if self.bbox_reparam: + outputs_coord_delta = self.bbox_embed(hs) + 
outputs_coord_cxcy = outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:] + ref_unsigmoid[..., :2] + outputs_coord_wh = outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:] + outputs_coord = torch.concat( + [outputs_coord_cxcy, outputs_coord_wh], dim=-1 + ) + else: + outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid() + outputs_class = self.class_embed(hs) + else: + assert self.two_stage, "if not using decoder, two_stage must be True" + outputs_class = self.transformer.enc_out_class_embed[0](hs_enc) + outputs_coord = ref_enc + + return outputs_coord, outputs_class + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + def update_drop_path(self, drop_path_rate, vit_encoder_num_layers): + """ """ + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, vit_encoder_num_layers)] + for i in range(vit_encoder_num_layers): + if hasattr(self.backbone[0].encoder, 'blocks'): # Not aimv2 + if hasattr(self.backbone[0].encoder.blocks[i].drop_path, 'drop_prob'): + self.backbone[0].encoder.blocks[i].drop_path.drop_prob = dp_rates[i] + else: # aimv2 + if hasattr(self.backbone[0].encoder.trunk.blocks[i].drop_path, 'drop_prob'): + self.backbone[0].encoder.trunk.blocks[i].drop_path.drop_prob = dp_rates[i] + + def update_dropout(self, drop_rate): + for module in self.transformer.modules(): + if isinstance(module, nn.Dropout): + module.p = drop_rate + + +class SetCriterion(nn.Module): + """ This class computes the loss for Conditional DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, + num_classes, + matcher, + weight_dict, + focal_alpha, + losses, + group_detr=1, + sum_group_losses=False, + use_varifocal_loss=False, + use_position_supervised_loss=False, + ia_bce_loss=False,): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. + focal_alpha: alpha in Focal Loss + group_detr: Number of groups to speed detr training. Default is 1. 
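+
+ Example (a minimal sketch added for illustration; the weights and class count below are
+ placeholders, not the values derived from args in build_criterion_and_postprocessors):
+
+ >>> from rfdetr.models.matcher import HungarianMatcher
+ >>> matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
+ >>> weight_dict = {'loss_ce': 1.0, 'loss_bbox': 5.0, 'loss_giou': 2.0}
+ >>> criterion = SetCriterion(num_classes=91, matcher=matcher, weight_dict=weight_dict,
+ ... focal_alpha=0.25, losses=['labels', 'boxes', 'cardinality'])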
+ """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_alpha = focal_alpha + self.group_detr = group_detr + self.sum_group_losses = sum_group_losses + self.use_varifocal_loss = use_varifocal_loss + self.use_position_supervised_loss = use_position_supervised_loss + self.ia_bce_loss = ia_bce_loss + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (Binary focal loss) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + + if self.ia_bce_loss: + alpha = self.focal_alpha + gamma = 2 + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + iou_targets=torch.diag(box_ops.box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes.detach()), + box_ops.box_cxcywh_to_xyxy(target_boxes))[0]) + pos_ious = iou_targets.clone().detach() + prob = src_logits.sigmoid() + #init positive weights and negative weights + pos_weights = torch.zeros_like(src_logits) + neg_weights = prob ** gamma + + pos_ind=[id for id in idx] + pos_ind.append(target_classes_o) + + t = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) + t = torch.clamp(t, 0.01).detach() + + pos_weights[pos_ind] = t.to(pos_weights.dtype) + neg_weights[pos_ind] = 1 - t.to(neg_weights.dtype) + # a reformulation of the standard loss_ce = - pos_weights * prob.log() - neg_weights * (1 - prob).log() + # with a focus on statistical stability by using fused logsigmoid + loss_ce = neg_weights * src_logits - F.logsigmoid(src_logits) * (pos_weights + neg_weights) + loss_ce = loss_ce.sum() / num_boxes + + elif self.use_position_supervised_loss: + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + iou_targets=torch.diag(box_ops.box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes.detach()), + box_ops.box_cxcywh_to_xyxy(target_boxes))[0]) + pos_ious = iou_targets.clone().detach() + # pos_ious_func = pos_ious ** 2 + pos_ious_func = pos_ious + + cls_iou_func_targets = torch.zeros((src_logits.shape[0], src_logits.shape[1],self.num_classes), + dtype=src_logits.dtype, device=src_logits.device) + + pos_ind=[id for id in idx] + pos_ind.append(target_classes_o) + cls_iou_func_targets[pos_ind] = pos_ious_func + norm_cls_iou_func_targets = cls_iou_func_targets \ + / (cls_iou_func_targets.view(cls_iou_func_targets.shape[0], -1, 1).amax(1, True) + 1e-8) + loss_ce = position_supervised_loss(src_logits, norm_cls_iou_func_targets, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + + elif self.use_varifocal_loss: + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + iou_targets=torch.diag(box_ops.box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes.detach()), + box_ops.box_cxcywh_to_xyxy(target_boxes))[0]) + pos_ious = iou_targets.clone().detach() + + cls_iou_targets = torch.zeros((src_logits.shape[0], src_logits.shape[1],self.num_classes), + dtype=src_logits.dtype, device=src_logits.device) + + pos_ind=[id for id in idx] + pos_ind.append(target_classes_o) + cls_iou_targets[pos_ind] = pos_ious + loss_ce = sigmoid_varifocal_loss(src_logits, 
cls_iou_targets, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + else: + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1], + dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:,:,:-1] + loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. 
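+ Losses are computed for the final decoder layer, then repeated for each auxiliary decoder
+ layer ('aux_outputs') and, when two-stage is enabled, for the encoder proposals
+ ('enc_outputs'); during training with group_detr > 1, matching is run once per query group.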
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depend on the losses applied, see each loss' doc
+ """
+ group_detr = self.group_detr if self.training else 1
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices = self.matcher(outputs_without_aux, targets, group_detr=group_detr)
+
+ # Compute the average number of target boxes across all nodes, for normalization purposes
+ num_boxes = sum(len(t["labels"]) for t in targets)
+ if not self.sum_group_losses:
+ num_boxes = num_boxes * group_detr
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_boxes)
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if 'aux_outputs' in outputs:
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
+ indices = self.matcher(aux_outputs, targets, group_detr=group_detr)
+ for loss in self.losses:
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs = {'log': False}
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if 'enc_outputs' in outputs:
+ enc_outputs = outputs['enc_outputs']
+ indices = self.matcher(enc_outputs, targets, group_detr=group_detr)
+ for loss in self.losses:
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs['log'] = False
+ l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_enc': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha: (optional) Weighting factor in range (0,1) to balance
+ positive vs negative examples. Default: 0.25.
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
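+
+ Note (added for clarity): per element this computes
+ alpha_t * (1 - p_t)**gamma * BCE_with_logits(inputs, targets), with
+ p_t = p * targets + (1 - p) * (1 - targets) and
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets); the result is averaged over the
+ query dimension, summed, and divided by num_boxes.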
+ Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +def sigmoid_varifocal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + prob = inputs.sigmoid() + focal_weight = targets * (targets > 0.0).float() + \ + (1 - alpha) * (prob - targets).abs().pow(gamma) * \ + (targets <= 0.0).float() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + loss = ce_loss * focal_weight + + return loss.mean(1).sum() / num_boxes + + +def position_supervised_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + loss = ce_loss * (torch.abs(targets - prob) ** gamma) + + if alpha >= 0: + alpha_t = alpha * (targets > 0.0).float() + (1 - alpha) * (targets <= 0.0).float() + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + def __init__(self, num_select=300) -> None: + super().__init__() + self.num_select = num_select + + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), self.num_select, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + + return results + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build_model(args): + # the `num_classes` naming here is somewhat misleading. + # it indeed corresponds to `max_obj_id + 1`, where max_obj_id + # is the maximum id for a class in your dataset. For example, + # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. 
+ # As another example, for a dataset that has a single class with id 1, + # you should pass `num_classes` to be 2 (max_obj_id + 1). + # For more details on this, check the following discussion + # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 + num_classes = args.num_classes + 1 + device = torch.device(args.device) + + + backbone = build_backbone( + encoder=args.encoder, + vit_encoder_num_layers=args.vit_encoder_num_layers, + pretrained_encoder=args.pretrained_encoder, + window_block_indexes=args.window_block_indexes, + drop_path=args.drop_path, + out_channels=args.hidden_dim, + out_feature_indexes=args.out_feature_indexes, + projector_scale=args.projector_scale, + use_cls_token=args.use_cls_token, + hidden_dim=args.hidden_dim, + position_embedding=args.position_embedding, + freeze_encoder=args.freeze_encoder, + layer_norm=args.layer_norm, + target_shape=args.shape if hasattr(args, 'shape') else (args.resolution, args.resolution) if hasattr(args, 'resolution') else (640, 640), + rms_norm=args.rms_norm, + backbone_lora=args.backbone_lora, + force_no_pretrain=args.force_no_pretrain, + gradient_checkpointing=args.gradient_checkpointing, + load_dinov2_weights=args.pretrain_weights is None, + patch_size=args.patch_size, + num_windows=args.num_windows, + positional_encoding_size=args.positional_encoding_size, + ) + if args.encoder_only: + return backbone[0].encoder, None, None + if args.backbone_only: + return backbone, None, None + + args.num_feature_levels = len(args.projector_scale) + transformer = build_transformer(args) + + model = LWDETR( + backbone, + transformer, + num_classes=num_classes, + num_queries=args.num_queries, + aux_loss=args.aux_loss, + group_detr=args.group_detr, + two_stage=args.two_stage, + lite_refpoint_refine=args.lite_refpoint_refine, + bbox_reparam=args.bbox_reparam, + ) + return model + +def build_criterion_and_postprocessors(args): + device = torch.device(args.device) + matcher = build_matcher(args) + weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef} + weight_dict['loss_giou'] = args.giou_loss_coef + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + if args.two_stage: + aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ['labels', 'boxes', 'cardinality'] + + try: + sum_group_losses = args.sum_group_losses + except: + sum_group_losses = False + criterion = SetCriterion(args.num_classes + 1, matcher=matcher, weight_dict=weight_dict, + focal_alpha=args.focal_alpha, losses=losses, + group_detr=args.group_detr, sum_group_losses=sum_group_losses, + use_varifocal_loss = args.use_varifocal_loss, + use_position_supervised_loss=args.use_position_supervised_loss, + ia_bce_loss=args.ia_bce_loss) + criterion.to(device) + postprocessors = {'bbox': PostProcess(num_select=args.num_select)} + + return criterion, postprocessors diff --git a/rfdetr/models/matcher.py b/rfdetr/models/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9afb7910308fc58b7d599d8ade6604492cb864 --- /dev/null +++ b/rfdetr/models/matcher.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import numpy as np +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from rfdetr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, focal_alpha: float = 0.25, use_pos_only: bool = False, + use_position_modulated_cost: bool = False): + """Creates the matcher + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + self.focal_alpha = focal_alpha + + @torch.no_grad() + def forward(self, outputs, targets, group_detr=1): + """ Performs the matching + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + group_detr: Number of groups used for matching. 
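+
+ Note (added for clarity): when group_detr > 1, the query axis is split into group_detr
+ equal chunks, each chunk is matched against the full target set independently, and the
+ per-group query indices are offset by g * (num_queries // group_detr) before being
+ concatenated.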
+ Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the giou cost betwen boxes + giou = generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + cost_giou = -giou + + # Compute the classification cost. + alpha = 0.25 + gamma = 2.0 + + neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Final cost matrix + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [] + g_num_queries = num_queries // group_detr + C_list = C.split(g_num_queries, dim=1) + for g_i in range(group_detr): + C_g = C_list[g_i] + indices_g = [linear_sum_assignment(c[i]) for i, c in enumerate(C_g.split(sizes, -1))] + if g_i == 0: + indices = indices_g + else: + indices = [ + (np.concatenate([indice1[0], indice2[0] + g_num_queries * g_i]), np.concatenate([indice1[1], indice2[1]])) + for indice1, indice2 in zip(indices, indices_g) + ] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +def build_matcher(args): + return HungarianMatcher( + cost_class=args.set_cost_class, + cost_bbox=args.set_cost_bbox, + cost_giou=args.set_cost_giou, + focal_alpha=args.focal_alpha,) \ No newline at end of file diff --git a/rfdetr/models/ops/__init__.py b/rfdetr/models/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rfdetr/models/ops/__pycache__/__init__.cpython-313.pyc b/rfdetr/models/ops/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e218f88948156444ae1b8c528ca3da6821c7d006 Binary files /dev/null and b/rfdetr/models/ops/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/models/ops/functions/__init__.py b/rfdetr/models/ops/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af6de029cd305ba5a17206b54ad70caab6d22a37 --- /dev/null +++ b/rfdetr/models/ops/functions/__init__.py @@ -0,0 +1,17 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. 
+# ------------------------------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ +""" +ms_deform_attn_func +""" +from .ms_deform_attn_func import ms_deform_attn_core_pytorch diff --git a/rfdetr/models/ops/functions/__pycache__/__init__.cpython-313.pyc b/rfdetr/models/ops/functions/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8593466250c9c90e808d219754e73aba4eed230 Binary files /dev/null and b/rfdetr/models/ops/functions/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-313.pyc b/rfdetr/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d14d4f129e3a991183c8c6911e1f08b7cefa97f Binary files /dev/null and b/rfdetr/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-313.pyc differ diff --git a/rfdetr/models/ops/functions/ms_deform_attn_func.py b/rfdetr/models/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000000000000000000000000000000000000..00518f0f4a982187a393e21b229df181b3538195 --- /dev/null +++ b/rfdetr/models/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,50 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ +""" +ms_deform_attn_func +""" +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + """"for debug and test only, need to use cuda version instead + """ + # B, n_heads, head_dim, N + B, n_heads, head_dim, _ = value.shape + _, Len_q, n_heads, L, P, _ = sampling_locations.shape + value_list = value.split([H * W for H, W in value_spatial_shapes], dim=3) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H, W) in enumerate(value_spatial_shapes): + # B, n_heads, head_dim, H, W + value_l_ = value_list[lid_].view(B * n_heads, head_dim, H, W) + # B, Len_q, n_heads, P, 2 -> B, n_heads, Len_q, P, 2 -> B*n_heads, Len_q, P, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # B*n_heads, head_dim, Len_q, P + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (B, Len_q, n_heads, L * P) -> (B, n_heads, Len_q, L, P) -> (B*n_heads, 1, Len_q, L*P) + attention_weights = attention_weights.transpose(1, 2).reshape(B * n_heads, 1, Len_q, L * P) + # B*n_heads, head_dim, Len_q, L*P + sampling_value_list = torch.stack(sampling_value_list, dim=-2).flatten(-2) + output = (sampling_value_list * attention_weights).sum(-1).view(B, n_heads * head_dim, Len_q) + return output.transpose(1, 2).contiguous() diff --git a/rfdetr/models/ops/modules/__init__.py b/rfdetr/models/ops/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f82cb1ad9d634a87b54ba6a71b58a230bcade5fe --- /dev/null +++ b/rfdetr/models/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/rfdetr/models/ops/modules/__pycache__/__init__.cpython-313.pyc b/rfdetr/models/ops/modules/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3bb81a1952532c63bfad0f0a1dd6a22e3dc138 Binary files /dev/null and b/rfdetr/models/ops/modules/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/models/ops/modules/__pycache__/ms_deform_attn.cpython-313.pyc b/rfdetr/models/ops/modules/__pycache__/ms_deform_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecd1840490928930d0dfaa7bf98ab060162f3b62 Binary files /dev/null and b/rfdetr/models/ops/modules/__pycache__/ms_deform_attn.cpython-313.pyc differ diff --git a/rfdetr/models/ops/modules/ms_deform_attn.py b/rfdetr/models/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..79f37273fcd947619ceccfd9d100d5bce5044093 --- /dev/null +++ b/rfdetr/models/ops/modules/ms_deform_attn.py @@ -0,0 +1,140 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ +""" +Multi-Scale Deformable Attention Module +""" + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import ms_deform_attn_core_pytorch + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + """Multi-Scale Deformable Attention Module + """ + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the " + "dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + self._export = False + + def export(self): + """export mode + """ + self._export = True + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True) + [0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) 
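# --- Editorial note (illustrative sketch, not part of the patch) ---------------------
# _reset_parameters above zeroes the sampling-offset weights and seeds the bias so
# that, at initialization, each head samples along its own direction (cos/sin of
# evenly spaced angles, normalized to the unit L-inf ball), with the k-th point pushed
# k steps out along that direction at every level. A minimal reproduction of that
# pattern with the module defaults (n_heads=8, n_levels=4, n_points=4):
import math
import torch

n_heads, n_levels, n_points = 8, 4, 4
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid = torch.stack([thetas.cos(), thetas.sin()], -1)
grid = (grid / grid.abs().max(-1, keepdim=True)[0]).view(n_heads, 1, 1, 2)
grid = grid.repeat(1, n_levels, n_points, 1)
for i in range(n_points):
    grid[:, :, i, :] *= i + 1
print(grid[0, 0])  # head 0, level 0: its unit direction scaled by 1, 2, 3, 4
# -------------------------------------------------------------------------------------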
+ + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, + input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + attention_weights = F.softmax(attention_weights, -1) + + value = value.transpose(1, 2).contiguous().view(N, self.n_heads, self.d_model // self.n_heads, Len_in) + output = ms_deform_attn_core_pytorch( + value, input_spatial_shapes, sampling_locations, attention_weights) + output = self.output_proj(output) + return output diff --git a/rfdetr/models/position_encoding.py b/rfdetr/models/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..3847f0f9c70f0f27db681b87d4c88eb53cf8d78d --- /dev/null +++ b/rfdetr/models/position_encoding.py @@ -0,0 +1,144 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
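# --- Editorial note (illustrative sketch, not part of the patch) ---------------------
# The MSDeformAttn module above always routes through the pure-PyTorch core, so a CPU
# smoke test with random inputs documents the interface; all concrete sizes below are
# arbitrary choices for the example.
import torch
from rfdetr.models.ops.modules import MSDeformAttn

d_model, n_levels, n_heads, n_points = 256, 4, 8, 4
attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)

N, Len_q = 2, 100
spatial_shapes = torch.as_tensor([[64, 64], [32, 32], [16, 16], [8, 8]], dtype=torch.long)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
Len_in = int(spatial_shapes.prod(1).sum())

query = torch.rand(N, Len_q, d_model)
reference_points = torch.rand(N, Len_q, n_levels, 2)   # normalized (cx, cy) in [0, 1]
input_flatten = torch.rand(N, Len_in, d_model)

out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(out.shape)  # torch.Size([2, 100, 256]) == (N, Len_q, d_model)
# -------------------------------------------------------------------------------------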
+# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from rfdetr.util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + self._export = False + + def export(self): + self._export = True + self._forward_origin = self.forward + self.forward = self.forward_export + + def forward(self, tensor_list: NestedTensor, align_dim_orders = True): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + if align_dim_orders: + pos = torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3) + # return: (H, W, bs, C) + else: + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + # return: (bs, C, H, W) + return pos + + def forward_export(self, mask:torch.Tensor, align_dim_orders = True): + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + if align_dim_orders: + pos = torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3) + # return: (H, W, bs, C) + else: + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + # return: (bs, C, H, W) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
+ """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + self._export = False + + def export(self): + raise NotImplementedError + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[:2] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).unsqueeze(2).repeat(1, 1, x.shape[2], 1) + # return: (H, W, bs, C) + return pos + + +def build_position_encoding(hidden_dim, position_embedding): + N_steps = hidden_dim // 2 + if position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + + return position_embedding diff --git a/rfdetr/models/transformer.py b/rfdetr/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2343b772a7d5126608222d7104273110336b34 --- /dev/null +++ b/rfdetr/models/transformer.py @@ -0,0 +1,591 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
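# --- Editorial note (illustrative sketch, not part of the patch) ---------------------
# build_position_encoding above maps 'v2'/'sine' to PositionEmbeddingSine and
# 'v3'/'learned' to PositionEmbeddingLearned. The sketch below assumes the DETR-style
# NestedTensor(tensors, mask) constructor exposed by rfdetr.util.misc; sizes are
# arbitrary choices for the example.
import torch
from rfdetr.models.position_encoding import build_position_encoding
from rfdetr.util.misc import NestedTensor

hidden_dim = 256
pos_enc = build_position_encoding(hidden_dim, "sine")  # PositionEmbeddingSine(128, normalize=True)

bs, H, W = 2, 32, 40
images = torch.rand(bs, 3, H, W)
mask = torch.zeros(bs, H, W, dtype=torch.bool)         # False everywhere: no padding
pos = pos_enc(NestedTensor(images, mask))              # default align_dim_orders=True
print(pos.shape)  # torch.Size([32, 40, 2, 256]) == (H, W, bs, hidden_dim)
# -------------------------------------------------------------------------------------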
+# ------------------------------------------------------------------------ +""" +Transformer class +""" +import math +import copy +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from rfdetr.models.ops.modules import MSDeformAttn + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def gen_sineembed_for_position(pos_tensor, dim=128): + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(dim, dtype=pos_tensor.dtype, device=pos_tensor.device) + dim_t = 10000 ** (2 * (dim_t // 2) / dim) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos + + +def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, unsigmoid=True): + """ + Input: + - memory: bs, \sum{hw}, d_model + - memory_padding_mask: bs, \sum{hw} + - spatial_shapes: nlevel, 2 + Output: + - output_memory: bs, \sum{hw}, d_model + - output_proposals: bs, \sum{hw}, 4 + """ + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + if memory_padding_mask is not None: + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + else: + valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device) + valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device) + + grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2 + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 
0.99)).all(-1, keepdim=True) + + if unsigmoid: + output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid + if memory_padding_mask is not None: + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) + else: + if memory_padding_mask is not None: + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float(0)) + + output_memory = memory + if memory_padding_mask is not None: + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + + return output_memory.to(memory.dtype), output_proposals.to(memory.dtype) + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.0, + activation="relu", normalize_before=False, + return_intermediate_dec=False, group_detr=1, + two_stage=False, + num_feature_levels=4, dec_n_points=4, + lite_refpoint_refine=False, + decoder_norm_type='LN', + bbox_reparam=False): + super().__init__() + self.encoder = None + + decoder_layer = TransformerDecoderLayer(d_model, sa_nhead, ca_nhead, dim_feedforward, + dropout, activation, normalize_before, + group_detr=group_detr, + num_feature_levels=num_feature_levels, + dec_n_points=dec_n_points, + skip_self_attn=False,) + assert decoder_norm_type in ['LN', 'Identity'] + norm = { + "LN": lambda channels: nn.LayerNorm(channels), + "Identity": lambda channels: nn.Identity(), + } + decoder_norm = norm[decoder_norm_type](d_model) + + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec, + d_model=d_model, + lite_refpoint_refine=lite_refpoint_refine, + bbox_reparam=bbox_reparam) + + + self.two_stage = two_stage + if two_stage: + self.enc_output = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(group_detr)]) + self.enc_output_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(group_detr)]) + + self._reset_parameters() + + self.num_queries = num_queries + self.d_model = d_model + self.dec_layers = num_decoder_layers + self.group_detr = group_detr + self.num_feature_levels = num_feature_levels + self.bbox_reparam = bbox_reparam + + self._export = False + + def export(self): + self._export = True + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat): + src_flatten = [] + mask_flatten = [] if masks is not None else None + lvl_pos_embed_flatten = [] + spatial_shapes = [] + valid_ratios = [] if masks is not None else None + for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + + src = src.flatten(2).transpose(1, 2) # bs, hw, c + pos_embed = 
pos_embed.flatten(2).transpose(1, 2) # bs, hw, c + lvl_pos_embed_flatten.append(pos_embed) + src_flatten.append(src) + if masks is not None: + mask = masks[lvl].flatten(1) # bs, hw + mask_flatten.append(mask) + memory = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c + if masks is not None: + mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + + if self.two_stage: + output_memory, output_proposals = gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes, unsigmoid=not self.bbox_reparam) + # group detr for first stage + refpoint_embed_ts, memory_ts, boxes_ts = [], [], [] + group_detr = self.group_detr if self.training else 1 + for g_idx in range(group_detr): + output_memory_gidx = self.enc_output_norm[g_idx](self.enc_output[g_idx](output_memory)) + + enc_outputs_class_unselected_gidx = self.enc_out_class_embed[g_idx](output_memory_gidx) + if self.bbox_reparam: + enc_outputs_coord_delta_gidx = self.enc_out_bbox_embed[g_idx](output_memory_gidx) + enc_outputs_coord_cxcy_gidx = enc_outputs_coord_delta_gidx[..., + :2] * output_proposals[..., 2:] + output_proposals[..., :2] + enc_outputs_coord_wh_gidx = enc_outputs_coord_delta_gidx[..., 2:].exp() * output_proposals[..., 2:] + enc_outputs_coord_unselected_gidx = torch.concat( + [enc_outputs_coord_cxcy_gidx, enc_outputs_coord_wh_gidx], dim=-1) + else: + enc_outputs_coord_unselected_gidx = self.enc_out_bbox_embed[g_idx]( + output_memory_gidx) + output_proposals # (bs, \sum{hw}, 4) unsigmoid + + topk = min(self.num_queries, enc_outputs_class_unselected_gidx.shape[-2]) + topk_proposals_gidx = torch.topk(enc_outputs_class_unselected_gidx.max(-1)[0], topk, dim=1)[1] # bs, nq + + refpoint_embed_gidx_undetach = torch.gather( + enc_outputs_coord_unselected_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid + # for decoder layer, detached as initial ones, (bs, nq, 4) + refpoint_embed_gidx = refpoint_embed_gidx_undetach.detach() + + # get memory tgt + tgt_undetach_gidx = torch.gather( + output_memory_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, self.d_model)) + + refpoint_embed_ts.append(refpoint_embed_gidx) + memory_ts.append(tgt_undetach_gidx) + boxes_ts.append(refpoint_embed_gidx_undetach) + # concat on dim=1, the nq dimension, (bs, nq, d) --> (bs, nq, d) + refpoint_embed_ts = torch.cat(refpoint_embed_ts, dim=1) + # (bs, nq, d) + memory_ts = torch.cat(memory_ts, dim=1)#.transpose(0, 1) + boxes_ts = torch.cat(boxes_ts, dim=1)#.transpose(0, 1) + + if self.dec_layers > 0: + tgt = query_feat.unsqueeze(0).repeat(bs, 1, 1) + refpoint_embed = refpoint_embed.unsqueeze(0).repeat(bs, 1, 1) + if self.two_stage: + ts_len = refpoint_embed_ts.shape[-2] + refpoint_embed_ts_subset = refpoint_embed[..., :ts_len, :] + refpoint_embed_subset = refpoint_embed[..., ts_len:, :] + + if self.bbox_reparam: + refpoint_embed_cxcy = refpoint_embed_ts_subset[..., :2] * refpoint_embed_ts[..., 2:] + refpoint_embed_cxcy = refpoint_embed_cxcy + refpoint_embed_ts[..., :2] + refpoint_embed_wh = refpoint_embed_ts_subset[..., 2:].exp() * refpoint_embed_ts[..., 2:] + refpoint_embed_ts_subset = torch.concat( + [refpoint_embed_cxcy, refpoint_embed_wh], dim=-1 + ) + else: + refpoint_embed_ts_subset = 
refpoint_embed_ts_subset + refpoint_embed_ts + + refpoint_embed = torch.concat( + [refpoint_embed_ts_subset, refpoint_embed_subset], dim=-2) + + hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask_flatten, + pos=lvl_pos_embed_flatten, refpoints_unsigmoid=refpoint_embed, + level_start_index=level_start_index, + spatial_shapes=spatial_shapes, + valid_ratios=valid_ratios.to(memory.dtype) if valid_ratios is not None else valid_ratios) + else: + assert self.two_stage, "if not using decoder, two_stage must be True" + hs = None + references = None + + if self.two_stage: + if self.bbox_reparam: + return hs, references, memory_ts, boxes_ts + else: + return hs, references, memory_ts, boxes_ts.sigmoid() + return hs, references, None, None + + +class TransformerDecoder(nn.Module): + + def __init__(self, + decoder_layer, + num_layers, + norm=None, + return_intermediate=False, + d_model=256, + lite_refpoint_refine=False, + bbox_reparam=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.d_model = d_model + self.norm = norm + self.return_intermediate = return_intermediate + self.lite_refpoint_refine = lite_refpoint_refine + self.bbox_reparam = bbox_reparam + + self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2) + + self._export = False + + def export(self): + self._export = True + + def refpoints_refine(self, refpoints_unsigmoid, new_refpoints_delta): + if self.bbox_reparam: + new_refpoints_cxcy = new_refpoints_delta[..., :2] * refpoints_unsigmoid[..., 2:] + refpoints_unsigmoid[..., :2] + new_refpoints_wh = new_refpoints_delta[..., 2:].exp() * refpoints_unsigmoid[..., 2:] + new_refpoints_unsigmoid = torch.concat( + [new_refpoints_cxcy, new_refpoints_wh], dim=-1 + ) + else: + new_refpoints_unsigmoid = refpoints_unsigmoid + new_refpoints_delta + return new_refpoints_unsigmoid + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + refpoints_unsigmoid: Optional[Tensor] = None, + # for memory + level_start_index: Optional[Tensor] = None, # num_levels + spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 + valid_ratios: Optional[Tensor] = None): + output = tgt + + intermediate = [] + hs_refpoints_unsigmoid = [refpoints_unsigmoid] + + def get_reference(refpoints): + # [num_queries, batch_size, 4] + obj_center = refpoints[..., :4] + + if self._export: + query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model / 2) # bs, nq, 256*2 + refpoints_input = obj_center[:, :, None] # bs, nq, 1, 4 + else: + refpoints_input = obj_center[:, :, None] \ + * torch.cat([valid_ratios, valid_ratios], -1)[:, None] # bs, nq, nlevel, 4 + query_sine_embed = gen_sineembed_for_position( + refpoints_input[:, :, 0, :], self.d_model / 2) # bs, nq, 256*2 + query_pos = self.ref_point_head(query_sine_embed) + return obj_center, refpoints_input, query_pos, query_sine_embed + + # always use init refpoints + if self.lite_refpoint_refine: + if self.bbox_reparam: + obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid) + else: + obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid()) + + for layer_id, layer in enumerate(self.layers): + # iter refine each layer + if not self.lite_refpoint_refine: + if self.bbox_reparam: + obj_center, refpoints_input, 
query_pos, query_sine_embed = get_reference(refpoints_unsigmoid) + else: + obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid()) + + # For the first decoder layer, we do not apply transformation over p_s + pos_transformation = 1 + + query_pos = query_pos * pos_transformation + + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed, + is_first=(layer_id == 0), + reference_points=refpoints_input, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index) + + if not self.lite_refpoint_refine: + # box iterative update + new_refpoints_delta = self.bbox_embed(output) + new_refpoints_unsigmoid = self.refpoints_refine(refpoints_unsigmoid, new_refpoints_delta) + if layer_id != self.num_layers - 1: + hs_refpoints_unsigmoid.append(new_refpoints_unsigmoid) + refpoints_unsigmoid = new_refpoints_unsigmoid.detach() + + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + if self._export: + # to shape: B, N, C + hs = intermediate[-1] + if self.bbox_embed is not None: + ref = hs_refpoints_unsigmoid[-1] + else: + ref = refpoints_unsigmoid + return hs, ref + # box iterative update + if self.bbox_embed is not None: + return [ + torch.stack(intermediate), + torch.stack(hs_refpoints_unsigmoid), + ] + else: + return [ + torch.stack(intermediate), + refpoints_unsigmoid.unsqueeze(0) + ] + + return output.unsqueeze(0) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, group_detr=1, + num_feature_levels=4, dec_n_points=4, + skip_self_attn=False): + super().__init__() + # Decoder Self-Attention + self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # Decoder Cross-Attention + self.cross_attn = MSDeformAttn( + d_model, n_levels=num_feature_levels, n_heads=ca_nhead, n_points=dec_n_points) + + self.nhead = ca_nhead + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + self.group_detr = group_detr + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + query_sine_embed = None, + is_first = False, + reference_points = None, + spatial_shapes=None, + level_start_index=None, + ): + bs, num_queries, _ = tgt.shape + + # ========== Begin of Self-Attention ============= + # Apply projections here + # shape: batch_size x 
num_queries x 256 + q = k = tgt + query_pos + v = tgt + if self.training: + q = torch.cat(q.split(num_queries // self.group_detr, dim=1), dim=0) + k = torch.cat(k.split(num_queries // self.group_detr, dim=1), dim=0) + v = torch.cat(v.split(num_queries // self.group_detr, dim=1), dim=0) + + tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + need_weights=False)[0] + + if self.training: + tgt2 = torch.cat(tgt2.split(bs, dim=0), dim=1) + # ========== End of Self-Attention ============= + + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ========== Begin of Cross-Attention ============= + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos), + reference_points, + memory, + spatial_shapes, + level_start_index, + memory_key_padding_mask + ) + # ========== End of Cross-Attention ============= + + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + query_sine_embed = None, + is_first = False, + reference_points = None, + spatial_shapes=None, + level_start_index=None): + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, + query_sine_embed, is_first, + reference_points, spatial_shapes, level_start_index) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + + try: + two_stage = args.two_stage + except: + two_stage = False + + return Transformer( + d_model=args.hidden_dim, + sa_nhead=args.sa_nheads, + ca_nhead=args.ca_nheads, + num_queries=args.num_queries, + dropout=args.dropout, + dim_feedforward=args.dim_feedforward, + num_decoder_layers=args.dec_layers, + return_intermediate_dec=True, + group_detr=args.group_detr, + two_stage=two_stage, + num_feature_levels=args.num_feature_levels, + dec_n_points=args.dec_n_points, + lite_refpoint_refine=args.lite_refpoint_refine, + decoder_norm_type=args.decoder_norm, + bbox_reparam=args.bbox_reparam, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/rfdetr/py.typed b/rfdetr/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rfdetr/util/__init__.py b/rfdetr/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c299cbb02686a10e8c5961571b36aa2dd5547f46 --- /dev/null +++ b/rfdetr/util/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------ +# LW-DETR +# Copyright (c) 2024 Baidu. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Conditional DETR +# Copyright (c) 2021 Microsoft. All Rights Reserved. 
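# --- Editorial note (illustrative sketch, not part of the patch) ---------------------
# gen_sineembed_for_position above is what the decoder uses to turn reference points
# into query positional features (called with dim = d_model / 2 = 128 for d_model=256):
# a (cx, cy, w, h) reference yields a 4 * 128 = 512-dim embedding, a (cx, cy) point a
# 256-dim one. Note also that TransformerDecoder references self.bbox_embed (and the
# two-stage branch references enc_out_class_embed / enc_out_bbox_embed), which are
# expected to be attached externally by the detection head. Sizes below are arbitrary.
import torch
from rfdetr.models.transformer import gen_sineembed_for_position

bs, nq = 2, 300
boxes = torch.rand(bs, nq, 4)    # normalized cx, cy, w, h
points = torch.rand(bs, nq, 2)   # normalized cx, cy

print(gen_sineembed_for_position(boxes, dim=128).shape)   # torch.Size([2, 300, 512])
print(gen_sineembed_for_position(points, dim=128).shape)  # torch.Size([2, 300, 256])
# -------------------------------------------------------------------------------------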
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ diff --git a/rfdetr/util/__pycache__/__init__.cpython-313.pyc b/rfdetr/util/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e703cc1df956c44aa1c1271d125401eff9ed3ef Binary files /dev/null and b/rfdetr/util/__pycache__/__init__.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/benchmark.cpython-313.pyc b/rfdetr/util/__pycache__/benchmark.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8315d382bb4e32c43c2307ae8370133d88d7e5fb Binary files /dev/null and b/rfdetr/util/__pycache__/benchmark.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/box_ops.cpython-313.pyc b/rfdetr/util/__pycache__/box_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f73714178597ac5dad7755fe16b5acd4edfe3a88 Binary files /dev/null and b/rfdetr/util/__pycache__/box_ops.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/coco_classes.cpython-313.pyc b/rfdetr/util/__pycache__/coco_classes.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a295be082ff7dd789a29a2ba377ccdf91ef936f Binary files /dev/null and b/rfdetr/util/__pycache__/coco_classes.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/drop_scheduler.cpython-313.pyc b/rfdetr/util/__pycache__/drop_scheduler.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e24a0e1b1e6630a654d2bc0442953487edc57036 Binary files /dev/null and b/rfdetr/util/__pycache__/drop_scheduler.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/early_stopping.cpython-313.pyc b/rfdetr/util/__pycache__/early_stopping.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e83df9d258d2ef27c57f0532bd339e72aec2b1 Binary files /dev/null and b/rfdetr/util/__pycache__/early_stopping.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/files.cpython-313.pyc b/rfdetr/util/__pycache__/files.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e01305007e98ec3f20292185dcaec1feef7753 Binary files /dev/null and b/rfdetr/util/__pycache__/files.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/get_param_dicts.cpython-313.pyc b/rfdetr/util/__pycache__/get_param_dicts.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1568220e3ed15f2179bdf8b773ecfea76faf2afa Binary files /dev/null and b/rfdetr/util/__pycache__/get_param_dicts.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/metrics.cpython-313.pyc b/rfdetr/util/__pycache__/metrics.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..557fa5c8886e728496ce8f976fa988c36a1d4015 Binary files /dev/null and b/rfdetr/util/__pycache__/metrics.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/misc.cpython-313.pyc b/rfdetr/util/__pycache__/misc.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..115a0976455ce3a4cb41d5d5c7c37298cb820dcf Binary files /dev/null and b/rfdetr/util/__pycache__/misc.cpython-313.pyc differ diff --git 
a/rfdetr/util/__pycache__/obj365_to_coco_model.cpython-313.pyc b/rfdetr/util/__pycache__/obj365_to_coco_model.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..011b84c12dab11e49426cc8efd1a1a25abce4604 Binary files /dev/null and b/rfdetr/util/__pycache__/obj365_to_coco_model.cpython-313.pyc differ diff --git a/rfdetr/util/__pycache__/utils.cpython-313.pyc b/rfdetr/util/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0053be8992ef0e7655dfe91c8cfad1711a752748 Binary files /dev/null and b/rfdetr/util/__pycache__/utils.cpython-313.pyc differ diff --git a/rfdetr/util/benchmark.py b/rfdetr/util/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..3d78ee42bbf563df608103b9562e5b1bd9665ac4 --- /dev/null +++ b/rfdetr/util/benchmark.py @@ -0,0 +1,634 @@ +# ------------------------------------------------------------------------ +# LW-DETR +# Copyright (c) 2024 Baidu. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# taken from https://gist.github.com/fmassa/c0fbb9fe7bf53b533b5cc241f5c8234c with a few modifications +# ------------------------------------------------------------------------ +# taken from detectron2 / fvcore with a few modifications +# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/analysis.py +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +from collections import OrderedDict, Counter, defaultdict +import json +import os +import pdb +from posixpath import join +import sys + + +sys.path.append(os.path.dirname(sys.path[0])) + +import numpy as np +from numpy import prod +from itertools import zip_longest +import tqdm +import logging +import typing +import torch +import torch.nn as nn +from functools import partial +import time + + +from typing import Any, Callable, List, Optional, Union +from numbers import Number + +Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], Number]] + + +def get_shape(val: object) -> typing.List[int]: + """ + Get the shapes from a jit value object. + Args: + val (torch._C.Value): jit value object. + Returns: + list(int): return a list of ints. + """ + if val.isCompleteTensor(): # pyre-ignore + r = val.type().sizes() # pyre-ignore + if not r: + r = [1] + return r + elif val.type().kind() in ("IntType", "FloatType"): + return [1] + elif val.type().kind() in ("StringType",): + return [0] + elif val.type().kind() in ("ListType",): + return [1] + elif val.type().kind() in ("BoolType", "NoneType"): + return [0] + else: + raise ValueError() + + +def addmm_flop_jit( + inputs: typing.List[object], outputs: typing.List[object] +) -> typing.Counter[str]: + """ + This method counts the flops for fully connected layers with torch script. + Args: + inputs (list(torch._C.Value)): The input shape in the form of a list of + jit object. + outputs (list(torch._C.Value)): The output shape in the form of a list + of jit object. + Returns: + Counter: A Counter dictionary that records the number of flops for each + operation. + """ + # Count flop for nn.Linear + # inputs is a list of length 3. 
+ input_shapes = [get_shape(v) for v in inputs[1:3]] + # input_shapes[0]: [batch size, input feature dimension] + # input_shapes[1]: [batch size, output feature dimension] + assert len(input_shapes[0]) == 2 + assert len(input_shapes[1]) == 2 + batch_size, input_dim = input_shapes[0] + output_dim = input_shapes[1][1] + flop = batch_size * input_dim * output_dim + flop_counter = Counter({"addmm": flop}) + return flop_counter + + +def bmm_flop_jit(inputs, outputs): + # Count flop for nn.Linear + # inputs is a list of length 3. + input_shapes = [get_shape(v) for v in inputs] + # input_shapes[0]: [batch size, input feature dimension] + # input_shapes[1]: [batch size, output feature dimension] + assert len(input_shapes[0]) == 3 + assert len(input_shapes[1]) == 3 + T, batch_size, input_dim = input_shapes[0] + output_dim = input_shapes[1][2] + flop = T * batch_size * input_dim * output_dim + flop_counter = Counter({"bmm": flop}) + return flop_counter + + +def basic_binary_op_flop_jit(inputs, outputs, name): + input_shapes = [get_shape(v) for v in inputs] + # for broadcasting + input_shapes = [s[::-1] for s in input_shapes] + max_shape = np.array(list(zip_longest(*input_shapes, fillvalue=1))).max(1) + flop = prod(max_shape) + flop_counter = Counter({name: flop}) + return flop_counter + + +def rsqrt_flop_jit(inputs, outputs): + input_shapes = [get_shape(v) for v in inputs] + flop = prod(input_shapes[0]) * 2 + flop_counter = Counter({"rsqrt": flop}) + return flop_counter + + +def dropout_flop_jit(inputs, outputs): + input_shapes = [get_shape(v) for v in inputs[:1]] + flop = prod(input_shapes[0]) + flop_counter = Counter({"dropout": flop}) + return flop_counter + + +def softmax_flop_jit(inputs, outputs): + # from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/profiler/internal/flops_registry.py + input_shapes = [get_shape(v) for v in inputs[:1]] + flop = prod(input_shapes[0]) * 5 + flop_counter = Counter({"softmax": flop}) + return flop_counter + + +def _reduction_op_flop_jit(inputs, outputs, reduce_flops=1, finalize_flops=0): + input_shapes = [get_shape(v) for v in inputs] + output_shapes = [get_shape(v) for v in outputs] + + in_elements = prod(input_shapes[0]) + out_elements = prod(output_shapes[0]) + + num_flops = in_elements * reduce_flops + out_elements * ( + finalize_flops - reduce_flops + ) + + return num_flops + + +def conv_flop_count( + x_shape: typing.List[int], + w_shape: typing.List[int], + out_shape: typing.List[int], +) -> typing.Counter[str]: + """ + This method counts the flops for convolution. Note only multiplication is + counted. Computation for addition and bias is ignored. + Args: + x_shape (list(int)): The input shape before convolution. + w_shape (list(int)): The filter shape. + out_shape (list(int)): The output shape after convolution. + Returns: + Counter: A Counter dictionary that records the number of flops for each + operation. + """ + batch_size, Cin_dim, Cout_dim = x_shape[0], w_shape[1], out_shape[1] + out_size = prod(out_shape[2:]) + kernel_size = prod(w_shape[2:]) + flop = batch_size * out_size * Cout_dim * Cin_dim * kernel_size + flop_counter = Counter({"conv": flop}) + return flop_counter + + +def conv_flop_jit( + inputs: typing.List[object], outputs: typing.List[object] +) -> typing.Counter[str]: + """ + This method counts the flops for convolution using torch script. + Args: + inputs (list(torch._C.Value)): The input shape in the form of a list of + jit object before convolution. 
+ outputs (list(torch._C.Value)): The output shape in the form of a list + of jit object after convolution. + Returns: + Counter: A Counter dictionary that records the number of flops for each + operation. + """ + # Inputs of Convolution should be a list of length 12. They represent: + # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding, + # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn, + # 10) deterministic_cudnn and 11) user_enabled_cudnn. + # import ipdb; ipdb.set_trace() + # assert len(inputs) == 12 + x, w = inputs[:2] + x_shape, w_shape, out_shape = ( + get_shape(x), + get_shape(w), + get_shape(outputs[0]), + ) + return conv_flop_count(x_shape, w_shape, out_shape) + + +def einsum_flop_jit( + inputs: typing.List[object], outputs: typing.List[object] +) -> typing.Counter[str]: + """ + This method counts the flops for the einsum operation. We currently support + two einsum operations: "nct,ncp->ntp" and "ntg,ncg->nct". + Args: + inputs (list(torch._C.Value)): The input shape in the form of a list of + jit object before einsum. + outputs (list(torch._C.Value)): The output shape in the form of a list + of jit object after einsum. + Returns: + Counter: A Counter dictionary that records the number of flops for each + operation. + """ + # Inputs of einsum should be a list of length 2. + # Inputs[0] stores the equation used for einsum. + # Inputs[1] stores the list of input shapes. + assert len(inputs) == 2 + equation = inputs[0].toIValue() # pyre-ignore + # Get rid of white space in the equation string. + equation = equation.replace(" ", "") + # Re-map equation so that same equation with different alphabet + # representations will look the same. + letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys() + mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)} + equation = equation.translate(mapping) + input_shapes_jit = inputs[1].node().inputs() # pyre-ignore + input_shapes = [get_shape(v) for v in input_shapes_jit] + + if equation == "abc,abd->acd": + n, c, t = input_shapes[0] + p = input_shapes[-1][-1] + flop = n * c * t * p + flop_counter = Counter({"einsum": flop}) + return flop_counter + + elif equation == "abc,adc->adb": + n, t, g = input_shapes[0] + c = input_shapes[-1][1] + flop = n * t * g * c + flop_counter = Counter({"einsum": flop}) + return flop_counter + + else: + raise NotImplementedError("Unsupported einsum operation.") + + +def matmul_flop_jit( + inputs: typing.List[object], outputs: typing.List[object] +) -> typing.Counter[str]: + """ + This method counts the flops for matmul. + Args: + inputs (list(torch._C.Value)): The input shape in the form of a list of + jit object before matmul. + outputs (list(torch._C.Value)): The output shape in the form of a list + of jit object after matmul. + Returns: + Counter: A Counter dictionary that records the number of flops for each + operation. + """ + + # Inputs contains the shapes of two matrices. 
+ input_shapes = [get_shape(v) for v in inputs] + assert len(input_shapes) == 2 + assert input_shapes[0][-1] == input_shapes[1][-2] + + dim_len = len(input_shapes[1]) + assert dim_len >= 2 + batch = 1 + for i in range(dim_len - 2): + assert input_shapes[0][i] == input_shapes[1][i] + batch *= input_shapes[0][i] + + # (b,m,c) x (b,c,n), flop = bmnc + flop = batch * input_shapes[0][-2] * input_shapes[0][-1] * input_shapes[1][-1] + flop_counter = Counter({"matmul": flop}) + return flop_counter + + +def batchnorm_flop_jit( + inputs: typing.List[object], outputs: typing.List[object] +) -> typing.Counter[str]: + """ + This method counts the flops for batch norm. + Args: + inputs (list(torch._C.Value)): The input shape in the form of a list of + jit object before batch norm. + outputs (list(torch._C.Value)): The output shape in the form of a list + of jit object after batch norm. + Returns: + Counter: A Counter dictionary that records the number of flops for each + operation. + """ + # Inputs[0] contains the shape of the input. + input_shape = get_shape(inputs[0]) + assert 2 <= len(input_shape) <= 5 + flop = prod(input_shape) * 4 + flop_counter = Counter({"batchnorm": flop}) + return flop_counter + + +def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the aten::linear operator. + """ + # Inputs is a list of length 3; unlike aten::addmm, it is the first + # two elements that are relevant. + input_shapes = [get_shape(v) for v in inputs[0:2]] + # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] + # input_shapes[1]: [output_feature_dim, input_feature_dim] + assert input_shapes[0][-1] == input_shapes[1][-1] + flops = prod(input_shapes[0]) * input_shapes[1][0] + flop_counter = Counter({"linear": flops}) + return flop_counter + + +def norm_flop_counter(affine_arg_index: int) -> Handle: + """ + Args: + affine_arg_index: index of the affine argument in inputs + """ + + def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for norm layers. + """ + # Inputs[0] contains the shape of the input. + input_shape = get_shape(inputs[0]) + has_affine = get_shape(inputs[affine_arg_index]) is not None + assert 2 <= len(input_shape) <= 5, input_shape + # 5 is just a rough estimate + flop = prod(input_shape) * (5 if has_affine else 4) + flop_counter = Counter({"norm": flop}) + return flop_counter + + return norm_flop_jit + + +def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Handle: + """ + Count flops by + input_tensor.numel() * input_scale + output_tensor.numel() * output_scale + + Args: + input_scale: scale of the input tensor (first argument) + output_scale: scale of the output tensor (first element in outputs) + """ + + def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: + ret = 0 + if input_scale != 0: + shape = get_shape(inputs[0]) + ret += input_scale * prod(shape) + if output_scale != 0: + shape = get_shape(outputs[0]) + ret += output_scale * prod(shape) + flop_counter = Counter({"elementwise": ret}) + return flop_counter + + return elementwise_flop + + +# A dictionary that maps supported operations to their flop count jit handles. 
+_SUPPORTED_OPS: typing.Dict[str, typing.Callable] = { + "aten::addmm": addmm_flop_jit, + "aten::_convolution": conv_flop_jit, + "aten::einsum": einsum_flop_jit, + "aten::matmul": matmul_flop_jit, + "aten::batch_norm": batchnorm_flop_jit, + "aten::bmm": bmm_flop_jit, + "aten::add": partial(basic_binary_op_flop_jit, name="aten::add"), + "aten::add_": partial(basic_binary_op_flop_jit, name="aten::add_"), + "aten::mul": partial(basic_binary_op_flop_jit, name="aten::mul"), + "aten::sub": partial(basic_binary_op_flop_jit, name="aten::sub"), + "aten::div": partial(basic_binary_op_flop_jit, name="aten::div"), + "aten::floor_divide": partial(basic_binary_op_flop_jit, name="aten::floor_divide"), + "aten::relu": partial(basic_binary_op_flop_jit, name="aten::relu"), + "aten::relu_": partial(basic_binary_op_flop_jit, name="aten::relu_"), + "aten::sigmoid": partial(basic_binary_op_flop_jit, name="aten::sigmoid"), + "aten::log": partial(basic_binary_op_flop_jit, name="aten::log"), + "aten::sum": partial(basic_binary_op_flop_jit, name="aten::sum"), + "aten::sin": partial(basic_binary_op_flop_jit, name="aten::sin"), + "aten::cos": partial(basic_binary_op_flop_jit, name="aten::cos"), + "aten::pow": partial(basic_binary_op_flop_jit, name="aten::pow"), + "aten::cumsum": partial(basic_binary_op_flop_jit, name="aten::cumsum"), + "aten::rsqrt": rsqrt_flop_jit, + "aten::softmax": softmax_flop_jit, + "aten::dropout": dropout_flop_jit, + "aten::linear": linear_flop_jit, + "aten::group_norm": norm_flop_counter(2), + "aten::layer_norm": norm_flop_counter(2), + "aten::instance_norm": norm_flop_counter(1), + "aten::upsample_nearest2d": elementwise_flop_counter(0, 1), + "aten::upsample_bilinear2d": elementwise_flop_counter(0, 4), + "aten::adaptive_avg_pool2d": elementwise_flop_counter(1, 0), + "aten::max_pool2d": elementwise_flop_counter(1, 0), + "aten::mm": matmul_flop_jit, +} + + +# A list that contains ignored operations. +_IGNORED_OPS: typing.List[str] = [ + "aten::Int", + "aten::__and__", + "aten::arange", + "aten::cat", + "aten::clamp", + "aten::clamp_", + "aten::contiguous", + "aten::copy_", + "aten::detach", + "aten::empty", + "aten::eq", + "aten::expand", + "aten::flatten", + "aten::floor", + "aten::full", + "aten::gt", + "aten::index", + "aten::index_put_", + "aten::max", + "aten::nonzero", + "aten::permute", + "aten::remainder", + "aten::reshape", + "aten::select", + "aten::gather", + "aten::topk", + "aten::meshgrid", + "aten::masked_fill", + "aten::linspace", + "aten::size", + "aten::slice", + "aten::split_with_sizes", + "aten::squeeze", + "aten::t", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "aten::zeros", + "aten::zeros_like", + "aten::ones_like", + "aten::new_zeros", + "aten::all", + "prim::Constant", + "prim::Int", + "prim::ListConstruct", + "prim::ListUnpack", + "prim::NumToTensor", + "prim::TupleConstruct", + "aten::stack", + "aten::chunk", + "aten::repeat", + "aten::grid_sampler", + "aten::constant_pad_nd", +] + +_HAS_ALREADY_SKIPPED = False + + +def flop_count( + model: nn.Module, + inputs: typing.Tuple[object, ...], + whitelist: typing.Union[typing.List[str], None] = None, + customized_ops: typing.Union[typing.Dict[str, typing.Callable], None] = None, +) -> typing.DefaultDict[str, float]: + """ + Given a model and an input to the model, compute the Gflops of the given + model. Note the input should have a batch size of 1. + Args: + model (nn.Module): The model to compute flop counts. + inputs (tuple): Inputs that are passed to `model` to count flops. 
+ Inputs need to be in a tuple. + whitelist (list(str)): Whitelist of operations that will be counted. It + needs to be a subset of _SUPPORTED_OPS. By default, the function + computes flops for all supported operations. + customized_ops (dict(str,Callable)) : A dictionary contains customized + operations and their flop handles. If customized_ops contains an + operation in _SUPPORTED_OPS, then the default handle in + _SUPPORTED_OPS will be overwritten. + Returns: + defaultdict: A dictionary that records the number of gflops for each + operation. + """ + # Copy _SUPPORTED_OPS to flop_count_ops. + # If customized_ops is provided, update _SUPPORTED_OPS. + flop_count_ops = _SUPPORTED_OPS.copy() + if customized_ops: + flop_count_ops.update(customized_ops) + + # If whitelist is None, count flops for all suported operations. + if whitelist is None: + whitelist_set = set(flop_count_ops.keys()) + else: + whitelist_set = set(whitelist) + + # Torch script does not support parallell torch models. + if isinstance( + model, + (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel), + ): + model = model.module # pyre-ignore + + assert set(whitelist_set).issubset( + flop_count_ops + ), "whitelist needs to be a subset of _SUPPORTED_OPS and customized_ops." + assert isinstance(inputs, tuple), "Inputs need to be in a tuple." + + # Compatibility with torch.jit. + if hasattr(torch.jit, "get_trace_graph"): + trace, _ = torch.jit.get_trace_graph(model, inputs) + trace_nodes = trace.graph().nodes() + else: + trace, _ = torch.jit._get_trace_graph(model, inputs) + trace_nodes = trace.nodes() + + skipped_ops = Counter() + total_flop_counter = Counter() + + for node in trace_nodes: + kind = node.kind() + if kind not in whitelist_set: + # If the operation is not in _IGNORED_OPS, count skipped operations. + if kind not in _IGNORED_OPS: + skipped_ops[kind] += 1 + continue + + handle_count = flop_count_ops.get(kind, None) + if handle_count is None: + continue + + inputs, outputs = list(node.inputs()), list(node.outputs()) + flops_counter = handle_count(inputs, outputs) + total_flop_counter += flops_counter + + global _HAS_ALREADY_SKIPPED + if len(skipped_ops) > 0 and not _HAS_ALREADY_SKIPPED: + _HAS_ALREADY_SKIPPED = True + for op, freq in skipped_ops.items(): + logging.warning("Skipped operation {} {} time(s)".format(op, freq)) + + # Convert flop count to gigaflops. 
+    final_count = defaultdict(float)
+    for op in total_flop_counter:
+        final_count[op] = total_flop_counter[op] / 1e9
+
+    return final_count
+
+
+def warmup(model, inputs, N=10):
+    for i in range(N):
+        out = model(inputs)
+    torch.cuda.synchronize()
+
+
+def measure_time(model, inputs, N=10):
+    warmup(model, inputs)
+    s = time.time()
+    for i in range(N):
+        out = model(inputs)
+    torch.cuda.synchronize()
+    t = (time.time() - s) / N
+    return t
+
+
+def fmt_res(data):
+    # return data.mean(), data.std(), data.min(), data.max()
+    return {
+        "mean": data.mean(),
+        "std": data.std(),
+        "min": data.min(),
+        "max": data.max(),
+    }
+
+
+def benchmark(model, dataset, output_dir):
+    print("Get model size, FLOPs, and FPS")
+    # import pdb; pdb.set_trace()
+    _outputs = {}
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    _outputs.update({"nparam": n_parameters})
+
+    model.cuda()
+    model.eval()
+
+    warmup_step = 5
+    total_step = 20
+
+    images = []
+    for idx in range(total_step):
+        img, t = dataset[idx]
+        images.append(img)
+    # import pdb; pdb.set_trace()
+    with torch.no_grad():
+        tmp = []
+        tmp2 = []
+        for imgid, img in enumerate(tqdm.tqdm(images)):
+            inputs = [img.to("cuda")]
+            res = flop_count(model, (inputs,))
+            t = measure_time(model, inputs)
+            tmp.append(sum(res.values()))
+            if imgid >= warmup_step:
+                tmp2.append(t)
+    _outputs.update({"detailed_flops": res})
+    _outputs.update({"flops": fmt_res(np.array(tmp)), "time": fmt_res(np.array(tmp2))})
+
+    mean_infer_time = float(fmt_res(np.array(tmp2))["mean"])
+    _outputs.update({"fps": 1 / mean_infer_time})
+
+    res = {"flops": fmt_res(np.array(tmp)), "time": fmt_res(np.array(tmp2))}
+    # print(res)
+
+    output_file = os.path.join(output_dir, "flops", "log.txt")
+    os.makedirs(os.path.dirname(output_file), exist_ok=True)
+    with open(output_file, "a") as f:
+        f.write("Test benchmark on Val Dataset" + "\n")
+        f.write(json.dumps(_outputs, indent=2) + "\n")
+
+    return _outputs
+
+
+# if __name__ == "__main__":
+#     res = benchmark()
+#     print(json.dumps(res, indent=2))
diff --git a/rfdetr/util/box_ops.py b/rfdetr/util/box_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae3c267896b6c30a1dcf0e0e100eaae3e55f0bc5
--- /dev/null
+++ b/rfdetr/util/box_ops.py
@@ -0,0 +1,98 @@
+# ------------------------------------------------------------------------
+# LW-DETR
+# Copyright (c) 2024 Baidu. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+Utilities for bounding box manipulation and GIoU.
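+Boxes are handled in (x0, y0, x1, y1) or (center_x, center_y, w, h) form; the
+cxcywh helpers convert between the two, and the IoU helpers expect xyxy input.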
+""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w.clamp(min=0.0)), (y_c - 0.5 * h.clamp(min=0.0)), + (x_c + 0.5 * w.clamp(min=0.0)), (y_c + 0.5 * h.clamp(min=0.0))] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/rfdetr/util/coco_classes.py b/rfdetr/util/coco_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..b09dd9718df9d3bd9be7631caef55d01c71416ee --- /dev/null +++ b/rfdetr/util/coco_classes.py @@ -0,0 +1,82 @@ +COCO_CLASSES = { + 1: "person", + 2: "bicycle", + 3: "car", + 4: "motorcycle", + 5: "airplane", + 6: "bus", + 7: "train", + 8: "truck", + 9: "boat", + 10: "traffic light", + 11: "fire hydrant", + 13: "stop sign", + 14: "parking meter", + 15: "bench", + 16: "bird", + 17: "cat", + 18: "dog", + 19: "horse", + 20: "sheep", + 21: "cow", + 22: "elephant", + 23: "bear", + 24: "zebra", + 25: "giraffe", + 27: "backpack", + 28: "umbrella", + 31: "handbag", + 32: "tie", + 33: "suitcase", + 34: "frisbee", + 35: "skis", + 36: "snowboard", + 37: "sports ball", + 38: "kite", + 39: "baseball bat", + 40: "baseball glove", + 41: "skateboard", + 42: "surfboard", + 43: "tennis racket", + 44: "bottle", + 46: "wine glass", + 47: "cup", + 48: "fork", + 49: "knife", + 50: "spoon", + 51: "bowl", + 52: "banana", + 53: "apple", + 54: "sandwich", + 55: "orange", + 56: "broccoli", + 57: "carrot", + 58: "hot dog", + 59: "pizza", + 60: "donut", + 61: "cake", + 62: "chair", + 63: "couch", + 64: "potted 
plant", + 65: "bed", + 67: "dining table", + 70: "toilet", + 72: "tv", + 73: "laptop", + 74: "mouse", + 75: "remote", + 76: "keyboard", + 77: "cell phone", + 78: "microwave", + 79: "oven", + 80: "toaster", + 81: "sink", + 82: "refrigerator", + 84: "book", + 85: "clock", + 86: "vase", + 87: "scissors", + 88: "teddy bear", + 89: "hair drier", + 90: "toothbrush", +} diff --git a/rfdetr/util/drop_scheduler.py b/rfdetr/util/drop_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..0f65665c5c0df59a906f8644c50c174e9f2deb6a --- /dev/null +++ b/rfdetr/util/drop_scheduler.py @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------------ +# LW-DETR +# Copyright (c) 2024 Baidu. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +"""util for drop scheduler.""" +import numpy as np + + +def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode='standard', schedule='constant'): + """drop scheduler""" + assert mode in ['standard', 'early', 'late'] + if mode == 'standard': + return np.full(epochs * niter_per_ep, drop_rate) + + early_iters = cutoff_epoch * niter_per_ep + late_iters = (epochs - cutoff_epoch) * niter_per_ep + + if mode == 'early': + assert schedule in ['constant', 'linear'] + if schedule == 'constant': + early_schedule = np.full(early_iters, drop_rate) + elif schedule == 'linear': + early_schedule = np.linspace(drop_rate, 0, early_iters) + final_schedule = np.concatenate((early_schedule, np.full(late_iters, 0))) + elif mode == 'late': + assert schedule in ['constant'] + early_schedule = np.full(early_iters, 0) + final_schedule = np.concatenate((early_schedule, np.full(late_iters, drop_rate))) + + assert len(final_schedule) == epochs * niter_per_ep + return final_schedule diff --git a/rfdetr/util/early_stopping.py b/rfdetr/util/early_stopping.py new file mode 100644 index 0000000000000000000000000000000000000000..30bf8885fab5516bf0296423a9986edf5b7f6dec --- /dev/null +++ b/rfdetr/util/early_stopping.py @@ -0,0 +1,75 @@ +""" +Early stopping callback for RF-DETR training +""" + +from logging import getLogger + +logger = getLogger(__name__) + +class EarlyStoppingCallback: + """ + Early stopping callback that monitors mAP and stops training if no improvement + over a threshold is observed for a specified number of epochs. 
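+    Once no improvement greater than min_delta has been observed for `patience`
+    consecutive epochs, the callback requests an early stop from the model it was
+    constructed with.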
+
+    Args:
+        patience (int): Number of epochs with no improvement to wait before stopping
+        min_delta (float): Minimum change in mAP to qualify as improvement
+        use_ema (bool): Whether to use EMA model metrics for early stopping
+        verbose (bool): Whether to print early stopping messages
+    """
+
+    def __init__(self, model, patience=5, min_delta=0.001, use_ema=False, verbose=True):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.use_ema = use_ema
+        self.verbose = verbose
+        self.best_map = 0.0
+        self.counter = 0
+        self.model = model
+
+    def update(self, log_stats):
+        """Update early stopping state based on epoch validation metrics"""
+        regular_map = None
+        ema_map = None
+
+        if 'test_coco_eval_bbox' in log_stats:
+            regular_map = log_stats['test_coco_eval_bbox'][0]
+
+        if 'ema_test_coco_eval_bbox' in log_stats:
+            ema_map = log_stats['ema_test_coco_eval_bbox'][0]
+
+        current_map = None
+        if regular_map is not None and ema_map is not None:
+            if self.use_ema:
+                current_map = ema_map
+                metric_source = "EMA"
+            else:
+                current_map = max(regular_map, ema_map)
+                metric_source = "max(regular, EMA)"
+        elif ema_map is not None:
+            current_map = ema_map
+            metric_source = "EMA"
+        elif regular_map is not None:
+            current_map = regular_map
+            metric_source = "regular"
+        else:
+            if self.verbose:
+                print("Early stopping: no valid mAP metric found in log stats; skipping this epoch")
+            return
+
+        if self.verbose:
+            print(f"Early stopping: Current mAP ({metric_source}): {current_map:.4f}, Best: {self.best_map:.4f}, Diff: {current_map - self.best_map:.4f}, Min delta: {self.min_delta}")
+
+        if current_map > self.best_map + self.min_delta:
+            self.best_map = current_map
+            self.counter = 0
+            logger.info(f"Early stopping: mAP improved to {current_map:.4f} using {metric_source} metric")
+        else:
+            self.counter += 1
+            if self.verbose:
+                print(f"Early stopping: No improvement in mAP for {self.counter} epochs (best: {self.best_map:.4f}, current: {current_map:.4f})")
+
+        if self.counter >= self.patience:
+            print(f"Early stopping triggered: No improvement above {self.min_delta} threshold for {self.patience} epochs")
+            if self.model:
+                self.model.request_early_stop()
\ No newline at end of file
diff --git a/rfdetr/util/files.py b/rfdetr/util/files.py
new file mode 100644
index 0000000000000000000000000000000000000000..554bf3cbee60d6dc687926ba4ec666f4c8aa3184
--- /dev/null
+++ b/rfdetr/util/files.py
@@ -0,0 +1,17 @@
+import requests
+from tqdm import tqdm
+from logging import getLogger
+
+def download_file(url, filename):
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers['content-length'])
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit='iB',
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as pbar:
+        for data in response.iter_content(chunk_size=1024):
+            size = f.write(data)
+            pbar.update(size)
diff --git a/rfdetr/util/get_param_dicts.py b/rfdetr/util/get_param_dicts.py
new file mode 100644
index 0000000000000000000000000000000000000000..aac01499b09ccb7c6e928301e5dcf168d8a9ed97
--- /dev/null
+++ b/rfdetr/util/get_param_dicts.py
@@ -0,0 +1,72 @@
+# ------------------------------------------------------------------------
+# LW-DETR
+# Copyright (c) 2024 Baidu. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Functions to get params dict""" +import torch.nn as nn + +from rfdetr.models.backbone import Joiner + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone"): + if ".pos_embed" in name or ".patch_embed" in name: + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + print("name: {}, lr_decay: {}".format(name, lr_decay_rate ** (num_layers + 1 - layer_id))) + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_vit_weight_decay_rate(name, weight_decay_rate=1.0): + if ('gamma' in name) or ('pos_embed' in name) or ('rel_pos' in name) or ('bias' in name) or ('norm' in name): + weight_decay_rate = 0. + print("name: {}, weight_decay rate: {}".format(name, weight_decay_rate)) + return weight_decay_rate + + +def get_param_dict(args, model_without_ddp: nn.Module): + assert isinstance(model_without_ddp.backbone, Joiner) + backbone = model_without_ddp.backbone[0] + backbone_named_param_lr_pairs = backbone.get_named_param_lr_pairs(args, prefix="backbone.0") + backbone_param_lr_pairs = [param_dict for _, param_dict in backbone_named_param_lr_pairs.items()] + + decoder_key = 'transformer.decoder' + decoder_params = [ + p + for n, p in model_without_ddp.named_parameters() if decoder_key in n and p.requires_grad + ] + + decoder_param_lr_pairs = [ + {"params": param, "lr": args.lr * args.lr_component_decay} + for param in decoder_params + ] + + other_params = [ + p + for n, p in model_without_ddp.named_parameters() if ( + n not in backbone_named_param_lr_pairs and decoder_key not in n and p.requires_grad) + ] + other_param_dicts = [ + {"params": param, "lr": args.lr} + for param in other_params + ] + + final_param_dicts = ( + other_param_dicts + backbone_param_lr_pairs + decoder_param_lr_pairs + ) + + return final_param_dicts diff --git a/rfdetr/util/metrics.py b/rfdetr/util/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..bf9cf49adf8aba6395533db401787e7e94140c03 --- /dev/null +++ b/rfdetr/util/metrics.py @@ -0,0 +1,243 @@ +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np + +try: + from torch.utils.tensorboard import SummaryWriter +except ModuleNotFoundError: + SummaryWriter = None + +try: + import wandb +except ModuleNotFoundError: + wandb = None + +plt.ioff() + +PLOT_FILE_NAME = "metrics_plot.png" + + +def safe_index(arr, idx): + return arr[idx] if 0 <= idx < len(arr) else None + + +class MetricsPlotSink: + """ + The MetricsPlotSink class records training metrics and saves them to a plot. + + Args: + output_dir (str): Directory where the plot will be saved. 
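+
+    Call update() once per epoch with the logged metrics dict, then save() to
+    write metrics_plot.png into output_dir.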
+ """ + + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.history = [] + + def update(self, values: dict): + self.history.append(values) + + def save(self): + if not self.history: + print("No data to plot.") + return + + def get_array(key): + return np.array([h[key] for h in self.history if key in h]) + + epochs = get_array('epoch') + train_loss = get_array('train_loss') + test_loss = get_array('test_loss') + test_coco_eval = [h['test_coco_eval_bbox'] for h in self.history if 'test_coco_eval_bbox' in h] + ap50_90 = np.array([safe_index(x, 0) for x in test_coco_eval if x is not None], dtype=np.float32) + ap50 = np.array([safe_index(x, 1) for x in test_coco_eval if x is not None], dtype=np.float32) + ar50_90 = np.array([safe_index(x, 8) for x in test_coco_eval if x is not None], dtype=np.float32) + + ema_coco_eval = [h['ema_test_coco_eval_bbox'] for h in self.history if 'ema_test_coco_eval_bbox' in h] + ema_ap50_90 = np.array([safe_index(x, 0) for x in ema_coco_eval if x is not None], dtype=np.float32) + ema_ap50 = np.array([safe_index(x, 1) for x in ema_coco_eval if x is not None], dtype=np.float32) + ema_ar50_90 = np.array([safe_index(x, 8) for x in ema_coco_eval if x is not None], dtype=np.float32) + + fig, axes = plt.subplots(2, 2, figsize=(18, 12)) + + # Subplot (0,0): Training and Validation Loss + if len(epochs) > 0: + if len(train_loss): + axes[0][0].plot(epochs, train_loss, label='Training Loss', marker='o', linestyle='-') + if len(test_loss): + axes[0][0].plot(epochs, test_loss, label='Validation Loss', marker='o', linestyle='--') + axes[0][0].set_title('Training and Validation Loss') + axes[0][0].set_xlabel('Epoch Number') + axes[0][0].set_ylabel('Loss Value') + axes[0][0].legend() + axes[0][0].grid(True) + + # Subplot (0,1): Average Precision @0.50 + if ap50.size > 0 or ema_ap50.size > 0: + if ap50.size > 0: + axes[0][1].plot(epochs[:len(ap50)], ap50, marker='o', linestyle='-', label='Base Model') + if ema_ap50.size > 0: + axes[0][1].plot(epochs[:len(ema_ap50)], ema_ap50, marker='o', linestyle='--', label='EMA Model') + axes[0][1].set_title('Average Precision @0.50') + axes[0][1].set_xlabel('Epoch Number') + axes[0][1].set_ylabel('AP50') + axes[0][1].legend() + axes[0][1].grid(True) + + # Subplot (1,0): Average Precision @0.50:0.95 + if ap50_90.size > 0 or ema_ap50_90.size > 0: + if ap50_90.size > 0: + axes[1][0].plot(epochs[:len(ap50_90)], ap50_90, marker='o', linestyle='-', label='Base Model') + if ema_ap50_90.size > 0: + axes[1][0].plot(epochs[:len(ema_ap50_90)], ema_ap50_90, marker='o', linestyle='--', label='EMA Model') + axes[1][0].set_title('Average Precision @0.50:0.95') + axes[1][0].set_xlabel('Epoch Number') + axes[1][0].set_ylabel('AP') + axes[1][0].legend() + axes[1][0].grid(True) + + # Subplot (1,1): Average Recall @0.50:0.95 + if ar50_90.size > 0 or ema_ar50_90.size > 0: + if ar50_90.size > 0: + axes[1][1].plot(epochs[:len(ar50_90)], ar50_90, marker='o', linestyle='-', label='Base Model') + if ema_ar50_90.size > 0: + axes[1][1].plot(epochs[:len(ema_ar50_90)], ema_ar50_90, marker='o', linestyle='--', label='EMA Model') + axes[1][1].set_title('Average Recall @0.50:0.95') + axes[1][1].set_xlabel('Epoch Number') + axes[1][1].set_ylabel('AR') + axes[1][1].legend() + axes[1][1].grid(True) + + plt.tight_layout() + plt.savefig(f"{self.output_dir}/{PLOT_FILE_NAME}") + plt.close(fig) + print(f"Results saved to {self.output_dir}/{PLOT_FILE_NAME}") + + +class MetricsTensorBoardSink: + """ + Training metrics via TensorBoard. 
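+    Metrics are written as TensorBoard scalars; if tensorboard is not installed
+    the sink becomes a no-op.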
+ + Args: + output_dir (str): Directory where TensorBoard logs will be written. + """ + + def __init__(self, output_dir: str): + if SummaryWriter: + self.writer = SummaryWriter(log_dir=output_dir) + print(f"TensorBoard logging initialized. To monitor logs, use 'tensorboard --logdir {output_dir}' and open http://localhost:6006/ in browser.") + else: + self.writer = None + print("Unable to initialize TensorBoard. Logging is turned off for this session. Run 'pip install tensorboard' to enable logging.") + + def update(self, values: dict): + if not self.writer: + return + + epoch = values['epoch'] + + if 'train_loss' in values: + self.writer.add_scalar("Loss/Train", values['train_loss'], epoch) + if 'test_loss' in values: + self.writer.add_scalar("Loss/Test", values['test_loss'], epoch) + + if 'test_coco_eval_bbox' in values: + coco_eval = values['test_coco_eval_bbox'] + ap50_90 = safe_index(coco_eval, 0) + ap50 = safe_index(coco_eval, 1) + ar50_90 = safe_index(coco_eval, 8) + if ap50_90 is not None: + self.writer.add_scalar("Metrics/Base/AP50_90", ap50_90, epoch) + if ap50 is not None: + self.writer.add_scalar("Metrics/Base/AP50", ap50, epoch) + if ar50_90 is not None: + self.writer.add_scalar("Metrics/Base/AR50_90", ar50_90, epoch) + + if 'ema_test_coco_eval_bbox' in values: + ema_coco_eval = values['ema_test_coco_eval_bbox'] + ema_ap50_90 = safe_index(ema_coco_eval, 0) + ema_ap50 = safe_index(ema_coco_eval, 1) + ema_ar50_90 = safe_index(ema_coco_eval, 8) + if ema_ap50_90 is not None: + self.writer.add_scalar("Metrics/EMA/AP50_90", ema_ap50_90, epoch) + if ema_ap50 is not None: + self.writer.add_scalar("Metrics/EMA/AP50", ema_ap50, epoch) + if ema_ar50_90 is not None: + self.writer.add_scalar("Metrics/EMA/AR50_90", ema_ar50_90, epoch) + + self.writer.flush() + + def close(self): + if not self.writer: + return + + self.writer.close() + +class MetricsWandBSink: + """ + Training metrics via W&B. + + Args: + output_dir (str): Directory where W&B logs will be written locally. + project (str, optional): Associate this training run with a W&B project. If None, W&B will generate a name based on the git repo name. + run (str, optional): W&B run name. If None, W&B will generate a random name. + config (dict, optional): Input parameters, like hyperparameters or data preprocessing settings for the run for later comparison. + """ + + def __init__(self, output_dir: str, project: Optional[str] = None, run: Optional[str] = None, config: Optional[dict] = None): + self.output_dir = output_dir + if wandb: + self.run = wandb.init( + project=project, + name=run, + config=config, + dir=output_dir + ) + print(f"W&B logging initialized. To monitor logs, open {wandb.run.url}.") + else: + self.run = None + print("Unable to initialize W&B. Logging is turned off for this session. 
Run 'pip install wandb' to enable logging.") + + def update(self, values: dict): + if not wandb or not self.run: + return + + epoch = values['epoch'] + log_dict = {"epoch": epoch} + + if 'train_loss' in values: + log_dict["Loss/Train"] = values['train_loss'] + if 'test_loss' in values: + log_dict["Loss/Test"] = values['test_loss'] + + if 'test_coco_eval_bbox' in values: + coco_eval = values['test_coco_eval_bbox'] + ap50_90 = safe_index(coco_eval, 0) + ap50 = safe_index(coco_eval, 1) + ar50_90 = safe_index(coco_eval, 8) + if ap50_90 is not None: + log_dict["Metrics/Base/AP50_90"] = ap50_90 + if ap50 is not None: + log_dict["Metrics/Base/AP50"] = ap50 + if ar50_90 is not None: + log_dict["Metrics/Base/AR50_90"] = ar50_90 + + if 'ema_test_coco_eval_bbox' in values: + ema_coco_eval = values['ema_test_coco_eval_bbox'] + ema_ap50_90 = safe_index(ema_coco_eval, 0) + ema_ap50 = safe_index(ema_coco_eval, 1) + ema_ar50_90 = safe_index(ema_coco_eval, 8) + if ema_ap50_90 is not None: + log_dict["Metrics/EMA/AP50_90"] = ema_ap50_90 + if ema_ap50 is not None: + log_dict["Metrics/EMA/AP50"] = ema_ap50 + if ema_ar50_90 is not None: + log_dict["Metrics/EMA/AR50_90"] = ema_ar50_90 + + wandb.log(log_dict) + + def close(self): + if not wandb or not self.run: + return + + self.run.finish() \ No newline at end of file diff --git a/rfdetr/util/misc.py b/rfdetr/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..237cb5ef08472906129c921843013ae2135d3312 --- /dev/null +++ b/rfdetr/util/misc.py @@ -0,0 +1,506 @@ +# ------------------------------------------------------------------------ +# LW-DETR +# Copyright (c) 2024 Baidu. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Conditional DETR +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import datetime +import os +import pickle +import subprocess +import time +from collections import defaultdict, deque +from typing import Optional, List + +import torch +import torch.distributed as dist +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from torch import Tensor + +if float(torchvision.__version__.split(".")[1]) < 7.0: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
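+        Only count and total are all-reduced; median, avg, and max stay local to each process.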
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
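+    Values are stacked into a single tensor before the all_reduce, so they must
+    share a shape (typically scalar loss tensors).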
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t", wandb_logging=False): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + if wandb_logging: + import wandb + self.wandb = wandb + else: + self.wandb = None + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if self.wandb: + if is_main_process(): + log_dict = {k: v.value for k, v in self.meters.items()} + self.wandb.log(log_dict) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" 
if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
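+# It pads every image (and an explicit boolean mask) to the per-batch max size
+# with torch.nn.functional.pad instead of the in-place slice copies used above.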
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(obj, f, *args, **kwargs): + """ + Safely save objects, removing any callbacks that can't be pickled + """ + if is_main_process(): + torch.save(obj, f, *args, **kwargs) + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, 
Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__.split(".")[1]) < 7.0: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + + +def strip_checkpoint(checkpoint): + state_dict = torch.load(checkpoint, map_location="cpu", weights_only=False) + new_state_dict = { + 'model': state_dict['model'], + 'args': state_dict['args'], + } + torch.save(new_state_dict, checkpoint) \ No newline at end of file diff --git a/rfdetr/util/obj365_to_coco_model.py b/rfdetr/util/obj365_to_coco_model.py new file mode 100644 index 0000000000000000000000000000000000000000..26fc78b97b8f90796c5cbeeb469a5085147c9f96 --- /dev/null +++ b/rfdetr/util/obj365_to_coco_model.py @@ -0,0 +1,102 @@ +# ------------------------------------------------------------------------ +# LW-DETR +# Copyright (c) 2024 Baidu. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Utils to load object365 pretrain.""" + +# obj365_classes = [ +# 'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', 'Glasses', +# 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf', 'Handbag/Satchel', +# 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', 'Book', 'Gloves', 'Storage box', +# 'Boat', 'Leather Shoes', 'Flower', 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', +# 'Pillow', 'Boots', 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', +# 'Belt', 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', 'Watch', +# 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool', 'Barrel/bucket', 'Van', +# 'Couch', 'Sandals', 'Bakset', 'Drum', 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', +# 'Motorcycle', 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck', +# 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy', 'Candle', 'Sailboat', +# 'Laptop', 'Awning', 'Bed', 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', +# 'Sink', 'Apple', 'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', +# 'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock', 'Pot', 'Cow', +# 'Cake', 'Dinning Table', 'Sheep', 'Hanger', 'Blackboard/Whiteboard', 'Napkin', +# 'Other Fish', 'Orange/Tangerine', 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', +# 'Machinery Vehicle', 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', +# 'Airplane', 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage', 'Nightstand', +# 'Tea pot', 'Telephone', 'Trolley', 'Head Phone', 'Sports Car', 'Stop Sign', 'Dessert', +# 'Scooter', 'Stroller', 'Crane', 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', +# 'Baseball Bat', 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', +# 'Elephant', 'Skateboard', 'Surfboard', 'Gun', 'Skating and Skiing shoes', 'Gas stove', +# 'Donut', 'Bow Tie', 'Carrot', 'Toilet', 'Kite', 
'Strawberry', 'Other Balls', 'Shovel', +# 'Pepper', 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks', 'Microwave', +# 'Pigeon', 'Baseball', 'Cutting/chopping Board', 'Coffee Table', 'Side Table', 'Scissors', +# 'Marker', 'Pie', 'Ladder', 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', +# 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', 'Egg', +# 'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards', 'Converter', 'Bathtub', +# 'Wheelchair', 'Golf Club', 'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', +# 'Pear', 'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', +# 'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', +# 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion', +# 'Green beans', 'Projector', 'Frisbee', 'Washing Machine/Drying Machine', 'Chicken', +# 'Printer', 'Watermelon', 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', +# 'Hotair ballon', 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', +# 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer', 'Goose', 'Tape', +# 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', 'Golf Ball', 'Ambulance', 'Parking meter', +# 'Mango', 'Key', 'Hurdle', 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', +# 'Megaphone', 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', +# 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom', 'Trombone', +# 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', 'Router/modem', 'Poker Card', 'Toaster', +# 'Shrimp', 'Sushi', 'Cheese', 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', +# 'Cue', 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap', 'Recorder', +# 'Bear', 'Eggplant', 'Board Eraser', 'Coconut', 'Tape Measur/ Ruler', 'Pig', +# 'Showerhead', 'Globe', 'Chips', 'Steak', 'Crosswalk Sign', 'Stapler', 'Campel', +# 'Formula 1 ', 'Pomegranate', 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', +# 'Rice Cooker', 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', +# 'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill', +# 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', 'Lighter', 'Grapefruit', +# 'Game board', 'Mop', 'Radish', 'Baozi', 'Target', 'French', 'Spring Rolls', 'Monkey', +# 'Rabbit', 'Pencil Case', 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', +# 'Scallop', 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle', +# 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster', 'Durian', 'Okra', +# 'Lipstick', 'Cosmetics Mirror', 'Curling', 'Table Tennis ' +# ] + +# coco_classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', +# 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', +# 'stop sign', 'parking meter', 'bench', 'wild bird', 'cat', 'dog', +# 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', +# 'backpack', 'umbrella', 'handbag/satchel', 'tie', 'luggage', 'frisbee', +# 'skating and skiing shoes', 'snowboard', 'baseball', 'kite', 'baseball bat', +# 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', +# 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl/basin', +# 'banana', 'apple', 'sandwich', 'orange/tangerine', 'broccoli', 'carrot', +# 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', +# 'potted plant', 'bed', 'dinning table', 'toilet', 'moniter/tv', 'laptop', +# 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', +# 'oven', 'toaster', 'sink', 'refrigerator', 
'book', 'clock', +# 'vase', 'scissors', 'stuffed toy', 'hair dryer', 'toothbrush'] + + +def get_coco_pretrain_from_obj365(cur_tensor, pretrain_tensor): + """Get coco weights from obj365 pretrained model.""" + if pretrain_tensor.size() == cur_tensor.size(): + return pretrain_tensor + cur_tensor.requires_grad = False + coco_ids = [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, + 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90 + ] + obj365_ids = [ + 0, 46, 5, 58, 114, 55, 116, 65, 21, 40, 176, 127, 249, 24, 56, 139, 92, 78, 99, 96, + 144, 295, 178, 180, 38, 39, 13, 43, 120, 219, 148, 173, 165, 154, 137, 113, 145, 146, + 204, 8, 35, 10, 88, 84, 93, 26, 112, 82, 265, 104, 141, 152, 234, 143, 150, 97, 2, + 50, 25, 75, 98, 153, 37, 73, 115, 132, 106, 61, 163, 134, 277, 81, 133, 18, 94, 30, + 169, 70, 328, 226 + ] + + for coco_id, obj_id in zip(coco_ids, obj365_ids): + cur_tensor[coco_id] = pretrain_tensor[obj_id + 1] + return cur_tensor diff --git a/rfdetr/util/utils.py b/rfdetr/util/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..758544b17d3d4f844a760b3feff0b3ef2a33d8aa --- /dev/null +++ b/rfdetr/util/utils.py @@ -0,0 +1,127 @@ +from copy import deepcopy +import torch +import json +from collections import OrderedDict +import math + + +class ModelEma(torch.nn.Module): + """EMA Model""" + def __init__(self, model, decay=0.9997, tau=0, device=None): + super(ModelEma, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + + self.decay = decay + self.tau = tau + self.updates = 1 + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _get_decay(self): + if self.tau == 0: + decay = self.decay + else: + decay = self.decay * (1 - math.exp(-self.updates / self.tau)) + return decay + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip( + self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + decay = self._get_decay() + self._update(model, update_fn=lambda e, m: decay * e + (1. 
- decay) * m) + self.updates += 1 + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) + + +class BestMetricSingle(): + def __init__(self, init_res=0.0, better='large') -> None: + self.init_res = init_res + self.best_res = init_res + self.best_ep = -1 + + self.better = better + assert better in ['large', 'small'] + + def isbetter(self, new_res, old_res): + if self.better == 'large': + return new_res > old_res + if self.better == 'small': + return new_res < old_res + + def update(self, new_res, ep): + if self.isbetter(new_res, self.best_res): + self.best_res = new_res + self.best_ep = ep + return True + return False + + def __str__(self) -> str: + return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep) + + def __repr__(self) -> str: + return self.__str__() + + def summary(self) -> dict: + return { + 'best_res': self.best_res, + 'best_ep': self.best_ep, + } + + +class BestMetricHolder(): + def __init__(self, init_res=0.0, better='large', use_ema=False) -> None: + self.best_all = BestMetricSingle(init_res, better) + self.use_ema = use_ema + if use_ema: + self.best_ema = BestMetricSingle(init_res, better) + self.best_regular = BestMetricSingle(init_res, better) + + def update(self, new_res, epoch, is_ema=False): + """ + return if the results is the best. + """ + if not self.use_ema: + return self.best_all.update(new_res, epoch) + else: + if is_ema: + self.best_ema.update(new_res, epoch) + return self.best_all.update(new_res, epoch) + else: + self.best_regular.update(new_res, epoch) + return self.best_all.update(new_res, epoch) + + def summary(self): + if not self.use_ema: + return self.best_all.summary() + + res = {} + res.update({f'all_{k}':v for k,v in self.best_all.summary().items()}) + res.update({f'regular_{k}':v for k,v in self.best_regular.summary().items()}) + res.update({f'ema_{k}':v for k,v in self.best_ema.summary().items()}) + return res + + def __repr__(self) -> str: + return json.dumps(self.summary(), indent=2) + + def __str__(self) -> str: + return self.__repr__() + + +def clean_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k[:7] == 'module.': + k = k[7:] # remove `module.` + new_state_dict[k] = v + return new_state_dict
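+
+
+# Illustrative usage (assumption, not called elsewhere in this module): strip the
+# 'module.' prefix left by DistributedDataParallel before loading a checkpoint that,
+# as in misc.strip_checkpoint, stores weights under the 'model' key:
+#   state = torch.load("checkpoint.pth", map_location="cpu")
+#   model.load_state_dict(clean_state_dict(state["model"]))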