# -*- coding: utf-8 -*-
"""
Author: Xu Yan
File: kitti_dataset.py
@time: 2020/8/12 22:03
"""
import os
import numpy as np
from utils import laserscan
import yaml
from torch.utils.data import Dataset
import torch
import spconv
import math

config_file = os.path.join('opt', 'semantic-kitti.yaml')
with open(config_file, 'r') as stream:
    kitti_config = yaml.safe_load(stream)
remapdict = kitti_config["learning_map"]
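# NOTE: "learning_map" collapses the raw SemanticKITTI label ids into the
# 20 training classes used below (0 = unlabeled).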


SPLIT_SEQUENCES = {
    "train": ["00", "01", "02", "03", "04", "05", "06", "07", "09", "10"],
    "valid": ["08"],
    "test": ["11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21"]
}

SPLIT_FILES = {
    "train": [".bin", ".label", ".invalid", ".occluded"],
    "valid": [".bin", ".label", ".invalid", ".occluded"],
    "test": [".bin"]
}

EXT_TO_NAME = {".bin": "input", ".label": "label", ".invalid": "invalid", ".occluded": "occluded"}
scan = laserscan.SemLaserScan(nclasses=20, sem_color_dict=kitti_config['color_map'])


def unpack(compressed):
    ''' given a bit encoded voxel grid, make a normal voxel grid out of it. '''
    # SemanticKITTI stores the voxel masks bit-packed: every uint8 byte holds
    # eight voxels, most-significant bit first.
    uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8)
    uncompressed[::8] = compressed[:] >> 7 & 1
    uncompressed[1::8] = compressed[:] >> 6 & 1
    uncompressed[2::8] = compressed[:] >> 5 & 1
    uncompressed[3::8] = compressed[:] >> 4 & 1
    uncompressed[4::8] = compressed[:] >> 3 & 1
    uncompressed[5::8] = compressed[:] >> 2 & 1
    uncompressed[6::8] = compressed[:] >> 1 & 1
    uncompressed[7::8] = compressed[:] & 1

    return uncompressed

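# Example: unpack(np.array([0b10100000], dtype=np.uint8)) returns
# array([1, 0, 1, 0, 0, 0, 0, 0], dtype=uint8), i.e. the same bit order as
# np.unpackbits.
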
class get_dataset(Dataset):
    def __init__(self, config, split="train", augment=False):
        """ Load data from given dataset directory. """

        self.config = config
        self.augment = augment
        self.files = {}
        self.filenames = []
        self.seg_path = config['GENERAL']['dataset_dir']
        for ext in SPLIT_FILES[split]:
            self.files[EXT_TO_NAME[ext]] = []
        self.label_to_names = {0: 'car', 1: 'bicycle', 2: 'motorcycle', 3: 'truck',
                               4: 'other-vehicle', 5: 'person', 6: 'bicyclist', 7: 'motorcyclist',
                               8: 'road', 9: 'parking', 10: 'sidewalk', 11: 'other-ground', 12: 'building',
                               13: 'fence', 14: 'vegetation', 15: 'trunk', 16: 'terrain', 17: 'pole',
                               18: 'traffic-sign'}

        for sequence in SPLIT_SEQUENCES[split]:
            complete_path = os.path.join(config['GENERAL']['dataset_dir'], "sequences", sequence, "voxels")
            if not os.path.exists(complete_path):
                raise RuntimeError("Voxel directory missing: " + complete_path)

            files = os.listdir(complete_path)
            for ext in SPLIT_FILES[split]:
                completion_data = sorted([os.path.join(complete_path, f) for f in files if f.endswith(ext)])
                if len(completion_data) == 0:
                    raise RuntimeError("Missing data for " + EXT_TO_NAME[ext])
                self.files[EXT_TO_NAME[ext]].extend(completion_data)

            self.filenames.extend(
                sorted([(sequence, os.path.splitext(f)[0]) for f in files if f.endswith(SPLIT_FILES[split][0])]))

        self.num_files = len(self.filenames)
        remapdict = kitti_config["learning_map"]
        # make lookup table for mapping
        maxkey = max(remapdict.keys())

        # +100 hack making lut bigger just in case there are unknown labels
        remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
        remap_lut[list(remapdict.keys())] = list(remapdict.values())
        # segmentation labels are shifted down to 0-18; "unlabeled" becomes -100
        seg_remap_lut = remap_lut - 1
        seg_remap_lut[seg_remap_lut == -1] = -100

        # in completion we have to distinguish empty and invalid voxels.
        # Important: For voxels 0 corresponds to "empty" and not "unlabeled".
        remap_lut[remap_lut == 0] = 255  # map 0 to 'invalid'
        remap_lut[0] = 0  # only 'empty' stays 'empty'.
        self.comletion_remap_lut = remap_lut
        self.seg_remap_lut = seg_remap_lut

        # sanity check:
        for k, v in self.files.items():
            # print(k, len(v))
            assert (len(v) == self.num_files)

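        # Class-balanced loss weights: for each class, w = (max_freq / freq) ** (1/3),
        # computed from the per-class counts supplied in the config; for the
        # valid/test splits, flat placeholder weights are used instead.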
        if split == 'train':
            seg_num_per_class = np.array(config['TRAIN']['seg_num_per_class'])
            complt_num_per_class = np.array(config['TRAIN']['complt_num_per_class'])

            seg_labelweights = seg_num_per_class / np.sum(seg_num_per_class)
            self.seg_labelweights = np.power(np.amax(seg_labelweights) / seg_labelweights, 1 / 3.0)
            compl_labelweights = complt_num_per_class / np.sum(complt_num_per_class)
            self.compl_labelweights = np.power(np.amax(compl_labelweights) / compl_labelweights, 1 / 3.0)
        else:
            self.compl_labelweights = torch.Tensor(np.ones(20) * 3)
            self.seg_labelweights = torch.Tensor(np.ones(19))
            self.compl_labelweights[0] = 1

        self.voxel_generator = spconv.utils.VoxelGenerator(
            voxel_size=[config['Completion']['voxel_size']] * 3,
            point_cloud_range=config['Completion']['point_cloud_range'],
            max_num_points=20,
            max_voxels=256 * 256 * 32
        )
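        # NOTE: this is the spconv 1.x VoxelGenerator interface; it returns voxel
        # coordinates in (z, y, x) order, which is why they are reversed before
        # computing voxel centers in __getitem__ below. spconv 2.x replaced this
        # class, so this module appears to assume spconv 1.x.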

    def __len__(self):
        return self.num_files

    def __getitem__(self, t):
        """ fill dictionary with available data for given index. """
        '''Load Completion Data'''
        completion_collection = {}
        if self.augment:
            # stat = np.random.randint(0,6)
            stat = np.random.randint(0, 4)
        else:
            stat = 0  # no augmentation
        completion_collection['stat'] = stat

        # read raw data and unpack (if necessary)
        for typ in self.files.keys():
            if typ == "label":
                scan_data = np.fromfile(self.files[typ][t], dtype=np.uint16)
                scan_data = self.comletion_remap_lut[scan_data]
            else:
                scan_data = unpack(np.fromfile(self.files[typ][t], dtype=np.uint8))
            scan_data = scan_data.reshape(self.config['Completion']['full_scale'])
            scan_data = data_augmentation(torch.Tensor(scan_data).unsqueeze(0), stat)
            # turn into the actual voxel grid representation
            completion_collection[typ] = scan_data

        '''Load Segmentation Data'''
        seg_point_name = self.seg_path + self.files['input'][t][self.files['input'][t].find('sequences'):].replace('voxels', 'velodyne')
        seg_label_name = self.seg_path + self.files['label'][t][self.files['label'][t].find('sequences'):].replace('voxels', 'labels')

        scan.open_scan(seg_point_name)
        scan.open_label(seg_label_name)
        remissions = scan.remissions
        xyz = scan.points
        label = scan.sem_label
        # remap raw point labels to the 19 training classes (unlabeled -> -100)
        label = self.seg_remap_lut[label]

        if self.config['Segmentation']['use_coords']:
            feature = np.concatenate([xyz, remissions.reshape(-1, 1)], 1)
        else:
            feature = remissions.reshape(-1, 1)

        '''Process Segmentation Data'''
        segmentation_collection = {}
        coords, label, feature, idxs = self.process_seg_data(xyz, label, feature)
        segmentation_collection.update({
            'coords': coords,
            'label': label,
            'feature': feature,
        })

        '''Generate Alignment Data'''
        alignment_collection = {}
        xyz = xyz[idxs]
        voxels, coords, num_points_per_voxel = self.voxel_generator.generate(
            np.concatenate([xyz, np.arange(len(xyz)).reshape(-1, 1)], -1))
        # voxel coordinates come back in (z, y, x) order, hence the reversal
        voxel_centers = (coords[:, ::-1] + 0.5) * self.voxel_generator.voxel_size + self.voxel_generator.point_cloud_range[0:3]
        alignment_collection.update({
            'voxels': voxels,
            'coords': coords,
            'voxel_centers': voxel_centers,
            'num_points_per_voxel': num_points_per_voxel,
        })

        return self.filenames[t], completion_collection, alignment_collection, segmentation_collection

    def process_seg_data(self, xyz, label, feature):
        # center the cloud, then apply a jittered, scaled rotation about the z axis
        coords = np.ascontiguousarray(xyz - xyz.mean(0))
        m = np.eye(3) + np.random.randn(3, 3) * 0.1
        m[0][0] *= np.random.randint(0, 2) * 2 - 1  # random flip of the x axis
        m *= self.config['Segmentation']['scale']
        theta = np.random.rand() * 2 * math.pi
        m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0], [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])
        coords = np.matmul(coords, m)

        # shift randomly inside the valid grid and drop points outside full_scale
        m = coords.min(0)
        M = coords.max(0)
        offset = - m + np.clip(self.config['Segmentation']['full_scale'][1] - M + m - 0.001, 0, None) * np.random.rand(3) + np.clip(
            self.config['Segmentation']['full_scale'][1] - M + m + 0.001, None, 0) * np.random.rand(3)
        coords += offset
        idxs = (coords.min(1) >= 0) * (coords.max(1) < self.config['Segmentation']['full_scale'][1])
        coords = coords[idxs]
        feature = feature[idxs]
        label = label[idxs]

        coords = torch.Tensor(coords).long()
        feature = torch.Tensor(feature)
        label = torch.Tensor(label)

        return coords, label, feature, idxs


def data_augmentation(t, state, inverse=False):
    assert t.dim() == 4, 'input dimension should be 4!'
    if state == 1:
        aug_t = t.flip([1])
    elif state == 2:
        aug_t = t.flip([2])
    # elif state == 3:
    #     k = 1 if not inverse else 3
    #     aug_t = t.rot90(k, [1, 2])
    # elif state == 4:
    #     aug_t = t.rot90(2, [1, 2])
    # elif state == 5:
    #     k = 3 if not inverse else 1
    #     aug_t = t.rot90(k, [1, 2])
    else:
        aug_t = t

    return aug_t

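# Note: states 1 and 2 flip the two horizontal spatial axes of a (1, X, Y, Z)
# grid and are their own inverses; the rotation states (3-5) are currently
# disabled, so the `inverse` flag only matters if they are re-enabled.
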
def sparse_tensor_augmentation(st, states):
    spatial_shape = st.spatial_shape
    batch_size = st.batch_size
    t = st.dense()
    channels = t.shape[1]
    for b in range(batch_size):
        t[b] = data_augmentation(t[b], states[b])
    coords = torch.sum(torch.abs(t), dim=1).nonzero().type(torch.int32)
    features = t.permute(0, 2, 3, 4, 1).reshape(-1, channels)
    features = features[torch.sum(torch.abs(features), dim=1).nonzero(), :]
    features = features.squeeze(1)
    nst = spconv.SparseConvTensor(features.float(), coords.int(), spatial_shape, batch_size)

    return nst

def tensor_augmentation(st, states):
    batch_size = st.shape[0]
    for b in range(batch_size):
        st[b] = data_augmentation(st[b], states[b])

    return st

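# Minimal usage sketch, not part of the training pipeline. It assumes the
# SemanticKITTI voxel/velodyne data is on disk and that the config dict has the
# keys read above; the concrete values below (voxel size, ranges, scales, class
# counts) are illustrative placeholders, not the project's real settings.
if __name__ == '__main__':
    example_config = {
        'GENERAL': {'dataset_dir': '/path/to/SemanticKITTI/dataset'},
        'TRAIN': {
            'seg_num_per_class': [1] * 19,     # placeholder per-class point counts
            'complt_num_per_class': [1] * 20,  # placeholder per-class voxel counts
        },
        'Completion': {
            'voxel_size': 0.2,                                     # assumed voxel edge length (m)
            'point_cloud_range': [0, -25.6, -2, 51.2, 25.6, 4.4],  # assumed x/y/z bounds
            'full_scale': [256, 256, 32],                          # assumed completion grid size
        },
        'Segmentation': {
            'use_coords': True,
            'scale': 20,                 # assumed voxelization scale
            'full_scale': [128, 4096],   # assumed lower/upper grid bounds
        },
    }
    dataset = get_dataset(example_config, split='train', augment=False)
    filename, completion, alignment, segmentation = dataset[0]
    print(filename, completion['input'].shape, segmentation['coords'].shape)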