Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from abc import ABCMeta, abstractmethod | |
| from typing import List, Tuple, Union | |
| from mmengine.model import BaseModule | |
| from torch import Tensor | |
| from mmdet.structures import SampleList | |
| from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig | |
| from ..utils import unpack_gt_instances | |
class BaseMaskHead(BaseModule, metaclass=ABCMeta):
    """Base class for mask heads used in One-Stage Instance Segmentation.

    The concrete :meth:`loss` and :meth:`predict` methods run the head via
    ``self(x)`` (or ``self(x, extra_infos)``). They then pass the unpacked
    outputs to the two hooks that subclasses must implement:
    :meth:`loss_by_feat` and :meth:`predict_by_feat`.

    Args:
        init_cfg (dict or list[dict], optional): Initialization config,
            forwarded unchanged to ``BaseModule``. Defaults to None.
    """

    def __init__(self, init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

    # Both hooks are declared abstract: ``abstractmethod`` is imported and the
    # class uses ``ABCMeta`` for exactly this purpose. Without the decorators,
    # a subclass that forgets to override a hook inherits a stub returning
    # ``None``. ``loss()`` would then return ``None`` instead of a loss dict.
    @abstractmethod
    def loss_by_feat(self, *args, **kwargs):
        """Calculate the loss based on the features extracted by the mask
        head."""

    @abstractmethod
    def predict_by_feat(self, *args, **kwargs):
        """Transform a batch of output features extracted from the head into
        mask results."""

    def loss(self,
             x: Union[List[Tensor], Tuple[Tensor]],
             batch_data_samples: SampleList,
             positive_infos: OptInstanceList = None,
             **kwargs) -> dict:
        """Perform forward propagation and loss calculation of the mask head on
        the features of the upstream network.

        Args:
            x (list[Tensor] | tuple[Tensor]): Features from FPN.
                Each has a shape (B, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
                the meta information of each image and corresponding
                annotations.
            positive_infos (list[:obj:`InstanceData`], optional): Information
                of positive samples. Used when the label assignment is
                done outside the MaskHead, e.g., BboxHead in
                YOLACT or CondInst, etc. When the label assignment is done in
                MaskHead, it would be None, like SOLO or SOLOv2. All values
                in it should have shape (num_positive_samples, *).
            **kwargs: Extra keyword arguments forwarded to
                :meth:`loss_by_feat`.

        Returns:
            dict: A dictionary of loss components, as returned by
            :meth:`loss_by_feat`.
        """
        # Heads whose label assignment happens elsewhere receive the positive
        # samples as a second forward argument.
        if positive_infos is None:
            outs = self(x)
        else:
            outs = self(x, positive_infos)

        # The outputs are star-unpacked into ``loss_by_feat`` below. A bare
        # tensor would be split along its first dimension instead, so insist
        # on a tuple.
        assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \
                                        'even if only one item is returned'

        batch_gt_instances, batch_gt_instances_ignore, batch_img_metas = \
            unpack_gt_instances(batch_data_samples)

        # Pad every image's GT masks to the batch input shape recorded in its
        # meta info. Presumably this makes the targets match the padded batch
        # the features were computed from (confirm against the subclasses).
        # NOTE: this reassigns ``masks`` on the unpacked instances in place.
        for gt_instances, img_metas in zip(batch_gt_instances,
                                           batch_img_metas):
            img_shape = img_metas['batch_input_shape']
            gt_instances.masks = gt_instances.masks.pad(img_shape)

        losses = self.loss_by_feat(
            *outs,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            positive_infos=positive_infos,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            **kwargs)
        return losses

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False,
                results_list: OptInstanceList = None,
                **kwargs) -> InstanceList:
        """Test function without test-time augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            results_list (list[obj:`InstanceData`], optional): Detection
                results of each image after the post process. Only exist
                if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc.
            **kwargs: Extra keyword arguments forwarded to
                :meth:`predict_by_feat`.

        Returns:
            list[obj:`InstanceData`]: Instance segmentation
            results of each image after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance,)
                - labels (Tensor): Has a shape (num_instances,).
                - masks (Tensor): Processed mask results, has a
                  shape (num_instances, h, w).
        """
        # Only the meta information is needed at inference time.
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # Mirrors ``loss``: upstream detection results, when present, are
        # passed to the forward call as a second argument.
        if results_list is None:
            outs = self(x)
        else:
            outs = self(x, results_list)

        results_list = self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            rescale=rescale,
            results_list=results_list,
            **kwargs)
        return results_list