Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from typing import Optional, List, Dict, Union, Any | |
| import warnings | |
| from torch.utils.data import Dataset | |
| from .conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder | |
| from .conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder | |
class Annotated3DObjectsDataset(Dataset):
    """Dataset base holding category names and per-image object settings.

    The instance stores the constructor settings, the list of category names
    (optionally filtered by a blacklist) and a lazily created pair of
    conditional builders keyed by ``'center'`` and ``'bbox'``.
    """

    def __init__(self, min_objects_per_image: int,
                 max_objects_per_image: int, no_tokens: int, num_beams: int, cats: List[str],
                 cat_blacklist: Optional[List[str]] = None, **kwargs: Any):
        """Store the dataset settings and compute the usable categories.

        Args:
            min_objects_per_image: Stored as-is on the instance.
            max_objects_per_image: Stored as-is; forwarded to both builders.
            no_tokens: Stored as-is; forwarded to both builders.
            num_beams: Stored as-is; forwarded to both builders.
            cats: Category names; list position is the category id used by
                ``get_textual_label_for_category_id``.
            cat_blacklist: Category names to drop from ``cats``. When
                ``None``, ``cats`` is used unchanged (same list object).
            **kwargs: Accepted and ignored by this class.
        """
        self.min_objects_per_image = min_objects_per_image
        self.max_objects_per_image = max_objects_per_image
        self.no_tokens = no_tokens
        self.num_beams = num_beams
        if cat_blacklist is None:
            self.categories = cats
        else:
            # Build the lookup once so filtering is a single pass over cats;
            # the relative order of the surviving categories is preserved.
            blocked = frozenset(cat_blacklist)
            self.categories = [c for c in cats if c not in blocked]
        # Created on first call to conditional_builders().
        self._conditional_builders: Optional[Dict[str, Any]] = None

    def no_classes(self) -> int:
        """Return the number of categories left after blacklist filtering."""
        return len(self.categories)

    def conditional_builders(
        self,
    ) -> "Dict[str, Union[ObjectsCenterPointsConditionalBuilder, ObjectsBoundingBoxConditionalBuilder]]":
        """Return the conditional builders, creating them on first use.

        Returns:
            A dict with a ``'center'`` entry (center-points builder) and a
            ``'bbox'`` entry (bounding-box builder). The same dict is
            returned on every subsequent call.
        """
        # Built lazily rather than in __init__: per the original author's
        # note, the class count may only be settled once data loading in the
        # class hierarchy's __init__ has finished.
        if self._conditional_builders is None:
            # no_classes is a method, so it must be *called* here; passing the
            # bound method itself would hand the builders a callable instead
            # of the integer class count.
            builder_args = (
                self.no_classes(),
                self.max_objects_per_image,
                self.no_tokens,
                self.num_beams,
            )
            self._conditional_builders = {
                'center': ObjectsCenterPointsConditionalBuilder(*builder_args),
                'bbox': ObjectsBoundingBoxConditionalBuilder(*builder_args),
            }
        return self._conditional_builders

    def get_textual_label_for_category_id(self, category_id: int) -> str:
        """Return the category name stored at index ``category_id``.

        Raises:
            IndexError: If ``category_id`` is outside the category list.
        """
        return self.categories[category_id]