forked from ZzZZCHS/Chat-Scene
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_scannet_mask3d_attributes.py
55 lines (46 loc) · 1.92 KB
/
prepare_scannet_mask3d_attributes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import json
import os
import glob
import numpy as np
import argparse
from tqdm import tqdm
# Command-line interface. --scan_dir is an INPUT location: point clouds are
# read from <scan_dir>/pcd_all/<scan_id>.pth. Results are written under the
# fixed "annotations" directory.
parser = argparse.ArgumentParser(
    description='Compute per-instance class labels and axis-aligned bounding '
                'boxes for the ScanNet train/val splits from preprocessed '
                'point clouds.')
parser.add_argument('--scan_dir', required=True, type=str,
                    help='directory holding the preprocessed scans; point '
                         'clouds are read from <scan_dir>/pcd_all/<scan_id>.pth')
parser.add_argument('--segmentor', required=True, type=str,
                    help='segmentor name; only used to build the output '
                         'file name')
parser.add_argument('--max_inst_num', required=True, type=int,
                    help='maximum number of instances per scan for which a '
                         'bounding box is computed')
parser.add_argument('--version', type=str, default='',
                    help='optional suffix appended to the output file name')
args = parser.parse_args()
def _read_scan_ids(split_path):
    """Return the sorted, non-blank scan ids listed one per line in *split_path*."""
    # `with` closes the handle; the previous bare open() leaked it.
    with open(split_path) as f:
        return sorted(line.strip() for line in f if line.strip())


def _compute_inst_locs(points, instance_segids, num_locs, scan_id):
    """Return a (num_locs, 6) float32 tensor of per-instance bounding boxes.

    Each row is the box center followed by its extent (max - min), taken
    per column over the points selected by ``instance_segids[i]``.
    Instances that select no points get an all-zero row and a log line.
    With ``num_locs <= 0`` an empty (0, 6) tensor is returned instead of
    crashing in ``np.stack``.

    Assumes *points* is a NumPy array, presumably (N, 3) xyz, and that each
    ``instance_segids[i]`` is a boolean mask or index array into it.
    TODO confirm against the preprocessing script.
    """
    inst_locs = []
    for i in range(num_locs):
        pc = points[instance_segids[i]]
        if len(pc) == 0:
            print(scan_id, i, 'empty bbox')
            inst_locs.append(np.zeros(6, dtype=np.float32))
            continue
        pc_max = pc.max(0)
        pc_min = pc.min(0)
        center = (pc_max + pc_min) / 2
        size = pc_max - pc_min
        inst_locs.append(np.concatenate([center, size], 0))
    if not inst_locs:
        # np.stack raises ValueError on an empty sequence.
        return torch.zeros((0, 6), dtype=torch.float32)
    return torch.tensor(np.stack(inst_locs, 0), dtype=torch.float32)


# For each split, collect {scan_id: {'objects', 'locs'}} for every scan whose
# preprocessed point cloud exists, then save the dict under `output_dir`.
scan_dir = os.path.join(args.scan_dir, 'pcd_all')
output_dir = "annotations"
for split in ["train", "val"]:
    split_path = f"annotations/scannet/scannetv2_{split}.txt"
    scan_ids = _read_scan_ids(split_path)
    scans = {}
    for scan_id in tqdm(scan_ids):
        pcd_path = os.path.join(scan_dir, f"{scan_id}.pth")
        if not os.path.exists(pcd_path):
            print('skip', scan_id)
            continue
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # data. On torch >= 2.6 the default weights_only=True may reject
        # NumPy arrays stored here; verify against the torch version in use.
        points, _colors, instance_class_labels, instance_segids = torch.load(pcd_path)
        num_locs = min(len(instance_class_labels), args.max_inst_num)
        scans[scan_id] = {
            # NOTE(review): 'objects' keeps ALL labels while 'locs' is
            # truncated to --max_inst_num rows, so their lengths differ when
            # a scan has more instances. Presumably consumers truncate; verify.
            'objects': instance_class_labels,
            # (num_locs, 6): center xyz, then size whl
            'locs': _compute_inst_locs(points, instance_segids, num_locs, scan_id),
        }
    torch.save(scans, os.path.join(output_dir, f"scannet_{args.segmentor}_{split}_attributes{args.version}.pt"))