Spaces:
Runtime error
Runtime error
| import pdb | |
| from abc import abstractmethod | |
| from functools import partial | |
| import PIL | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms.functional as TF | |
| from torch.utils.data import Dataset, IterableDataset | |
| from ..utils.aug_utils import get_lidar_transform, get_camera_transform, get_anno_transform | |
class DatasetBase(Dataset):
    """Base map-style dataset that normalizes range images and wires up augmentations.

    The constructor copies the relevant fields from ``dataset_config``, builds the
    optional degradation (down-scaling) transform and the augmentation transforms,
    and finally calls :meth:`prepare_data`, which subclasses must implement
    (it is presumably expected to fill ``self.data`` -- confirm against subclasses).
    """

    def __init__(self, data_root, split, dataset_config, aug_config, return_pcd=False, condition_key=None,
                 scale_factors=None, degradation=None, **kwargs):
        """Store the configuration and build the transforms.

        Args:
            data_root: Stored as ``self.data_root``.
            split: Stored as ``self.split`` and forwarded to the transform factories.
            dataset_config: Object exposing ``size``, ``fov``, ``depth_range``,
                ``filtered_map_cats``, ``depth_scale`` and ``log_scale`` attributes.
            aug_config: Augmentation configuration forwarded to the transform factories.
            return_pcd: Stored as ``self.return_pcd``.
            condition_key: Selects which optional transforms are built: the annotation
                transform for ``'bbox'``/``'center'``, the camera transform for ``'camera'``.
            scale_factors: Per-axis divisors applied to ``dataset_config.size`` to get the
                degraded size. Only used together with ``degradation``.
            degradation: Name of the PIL resampling filter (``"pil_nearest"``,
                ``"pil_bilinear"``, ``"pil_bicubic"``, ``"pil_box"``, ``"pil_hamming"``,
                ``"pil_lanczos"``). Only used together with ``scale_factors``.
            **kwargs: Accepted and ignored.

        Raises:
            KeyError: If ``degradation`` is not one of the supported names.
            NotImplementedError: From :meth:`prepare_data` unless a subclass overrides it.
        """
        self.data_root = data_root
        self.split = split
        self.data = []

        self.aug_config = aug_config

        self.img_size = dataset_config.size
        self.fov = dataset_config.fov
        self.depth_range = dataset_config.depth_range
        self.filtered_map_cats = dataset_config.filtered_map_cats
        self.depth_scale = dataset_config.depth_scale
        self.log_scale = dataset_config.log_scale

        # Normalized value derived from a raw depth of 1/255 (plus a small epsilon);
        # process_scan marks every pixel below it as invalid in the mask.
        if self.log_scale:
            self.depth_thresh = (np.log2(1./255. + 1) / self.depth_scale) * 2. - 1 + 1e-6
        else:
            self.depth_thresh = (1./255. / self.depth_scale) * 2. - 1 + 1e-6
        self.return_pcd = return_pcd

        # The degradation transform is only built when both pieces of information are given.
        if degradation is not None and scale_factors is not None:
            scaled_img_size = (int(self.img_size[0] / scale_factors[0]), int(self.img_size[1] / scale_factors[1]))
            interpolation_modes = {
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }
            if degradation not in interpolation_modes:
                # Same exception type as a plain dict lookup, but with an actionable message.
                raise KeyError(
                    f"Unknown degradation {degradation!r}; expected one of {sorted(interpolation_modes)}"
                )
            self.degradation_transform = partial(TF.resize, size=scaled_img_size,
                                                 interpolation=interpolation_modes[degradation])
        else:
            self.degradation_transform = None

        self.condition_key = condition_key
        self.lidar_transform = get_lidar_transform(aug_config, split)
        self.anno_transform = get_anno_transform(aug_config, split) if condition_key in ['bbox', 'center'] else None
        self.view_transform = get_camera_transform(aug_config, split) if condition_key in ['camera'] else None

        self.prepare_data()

    def prepare_data(self):
        """Hook called at the end of ``__init__``; must be overridden by subclasses."""
        raise NotImplementedError

    def process_scan(self, range_img):
        """Normalize a range image to ``[-1, 1]`` and build its validity mask.

        Negative values are zeroed, the image is optionally log2-compressed, divided by
        ``self.depth_scale``, mapped with ``x * 2 - 1`` and clipped to ``[-1, 1]``.

        Args:
            range_img: Array-like of raw range values.

        Returns:
            Tuple ``(range_img, range_mask)``, both with a new leading axis. The mask is
            ``1`` everywhere except where the normalized image is below
            ``self.depth_thresh``, where it is ``-1``.
        """
        range_img = np.where(range_img < 0, 0, range_img)

        if self.log_scale:
            # log scale
            range_img = np.log2(range_img + 0.0001 + 1)

        range_img = range_img / self.depth_scale
        range_img = range_img * 2. - 1.

        range_img = np.clip(range_img, -1, 1)
        range_img = np.expand_dims(range_img, axis=0)

        # mask
        range_mask = np.ones_like(range_img)
        range_mask[range_img < self.depth_thresh] = -1

        return range_img, range_mask

    def load_lidar_sweep(self, *args, **kwargs):
        """Loader stub; must be overridden by subclasses that use it."""
        raise NotImplementedError

    def load_semantic_map(self, *args, **kwargs):
        """Loader stub; must be overridden by subclasses that use it."""
        raise NotImplementedError

    def load_camera(self, *args, **kwargs):
        """Loader stub; must be overridden by subclasses that use it."""
        raise NotImplementedError

    def load_annotation(self, *args, **kwargs):
        """Loader stub; must be overridden by subclasses that use it."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of entries in ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return a fresh, empty example dict (``idx`` is not used by the base class)."""
        example = dict()
        return example
class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        """Store the dataset bookkeeping fields and print a one-line summary.

        Args:
            num_records: Value reported by ``__len__``.
            valid_ids: Stored as both ``self.valid_ids`` and ``self.sample_ids``
                (the same object, not a copy).
            size: Stored as ``self.size``.
        """
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        """Return the ``num_records`` value given at construction."""
        return self.num_records

    def __iter__(self):
        """Iteration hook; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class. Returning ``None`` here
                would make ``iter()`` fail with an unrelated-looking ``TypeError``.
        """
        raise NotImplementedError