Skip to content

Commit 20e805d

Browse files
committed
Update gt_graph_to_wkt.py
1 parent b357172 commit 20e805d

File tree

1 file changed

+77
-16
lines changed

1 file changed

+77
-16
lines changed

apls/gt_graph_to_wkt.py

+77-16
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import apls
1111
import argparse
12+
import osmnx_funcs
1213
import pandas as pd
1314
# import shapely.wkt
1415

@@ -18,9 +19,11 @@ def gt_geojson_to_wkt(geojson_path, im_path,
1819
weight_keys=['length', 'travel_time_s'],
1920
subgraph_filter_weight='length', min_subgraph_length=5,
2021
travel_time_key='travel_time_s',
21-
speed_key='speed_m/s',
22+
speed_key='inferred_speed_mps',
2223
use_pix_coords=False,
2324
verbose=False,
25+
simplify=False,
26+
refine=True,
2427
super_verbose=False):
2528
'''
2629
Create wkt list of pixel coords in ground truth graph, for use in SpaceNet
@@ -32,26 +35,51 @@ def gt_geojson_to_wkt(geojson_path, im_path,
3235

3336
# linestring = "LINESTRING {}"
3437
im_name = os.path.basename(im_path)
35-
AOI_root = 'AOI' + im_name.split('AOI')[-1]
36-
name_root = AOI_root.split('.')[0].replace('PS-RGB_', '')
37-
print("name_root:", name_root)
3838

39+
# get name_root of image file
40+
name_root = im_name.split('.')[0].replace('PS-RGB_', '').replace('PS-MS_', '')
41+
# # v0
42+
# AOI_root = 'AOI' + im_name.split('AOI')[-1]
43+
# name_root = AOI_root.split('.')[0].replace('PS-RGB_', '').replace('PS-MS_', '')
44+
print("name_root:", name_root)
45+
print("im_path:", im_path)
46+
3947
G_gt, _ = apls._create_gt_graph(geojson_path, im_path,
4048
subgraph_filter_weight=subgraph_filter_weight,
4149
min_subgraph_length=min_subgraph_length,
4250
travel_time_key=travel_time_key,
4351
speed_key=speed_key,
4452
use_pix_coords=use_pix_coords,
53+
refine_graph=refine,
54+
simplify_graph=simplify,
4555
verbose=verbose,
4656
super_verbose=super_verbose)
4757

58+
# simplify and turn to undirected
59+
if simplify:
60+
try:
61+
G_gt = osmnx_funcs.simplify_graph(G_gt).to_undirected()
62+
except:
63+
G_gt = G_gt.to_undirected()
64+
else:
65+
G_gt = G_gt.to_undirected()
66+
67+
# return [name_root, "LINESTRING EMPTY"] if no edges
68+
if (len(G_gt.nodes()) == 0) or (len(G_gt.edges()) == 0):
69+
print(" Empty graph")
70+
if len(weight_keys) > 0:
71+
return [[name_root, "LINESTRING EMPTY"] + [0] * len(weight_keys)]
72+
else:
73+
return [[name_root, "LINESTRING EMPTY"]]
74+
4875
# extract geometry pix wkt, save to list
4976
wkt_list = []
5077
for i, (u, v, attr_dict) in enumerate(G_gt.edges(data=True)):
78+
print("attr_dict:", attr_dict)
5179
geom_pix_wkt = attr_dict['geometry_pix'].wkt
5280
if verbose:
5381
print(i, "/", len(G_gt.edges()), "u, v:", u, v)
54-
print(" attr_dict:", attr_dict)
82+
# print(" attr_dict:", attr_dict)
5583
print(" geom_pix_wkt:", geom_pix_wkt)
5684

5785
wkt_item_root = [name_root, geom_pix_wkt]
@@ -74,32 +102,41 @@ def gt_geojson_dir_to_wkt(geojson_dir, im_dir, output_csv_path,
74102
weight_keys=['length', 'travel_time_s'],
75103
subgraph_filter_weight='length', min_subgraph_length=5,
76104
travel_time_key='travel_time_s',
77-
speed_key='speed_m/s',
105+
speed_key='inferred_speed_mps',
78106
use_pix_coords=False,
107+
simplify=False,
79108
verbose=False,
80109
super_verbose=False):
81110

82111
# make dict of image chip id to file name
112+
chipper = ''
83113
im_chip_dict = {}
84114
for im_name in [z for z in os.listdir(im_dir) if z.endswith('.tif')]:
85115
chip_id = im_name.split('chip')[-1].split('.')[0]
116+
if 'chip' in im_name:
117+
chipper = 'chip'
118+
chip_id = im_name.split(chipper)[-1].split('.')[0]
119+
elif 'img' in im_name:
120+
chipper = 'img'
121+
chip_id = im_name.split(chipper)[-1].split('.')[0]
86122
im_chip_dict[chip_id] = im_name
87123
if verbose:
88124
print("im_chip_dict:", im_chip_dict)
89125

90126
# iterate through geojsons
91127
wkt_list_tot = []
92-
geojson_paths = [z for z in os.listdir(geojson_dir)
93-
if z.endswith('.geojson')]
128+
geojson_paths = sorted([z for z in os.listdir(geojson_dir)
129+
if z.endswith('.geojson')])
94130
for i, geojson_name in enumerate(geojson_paths):
131+
95132
# get image name
96-
chip_id = geojson_name.split('chip')[-1].split('.')[0]
133+
chip_id = geojson_name.split(chipper)[-1].split('.')[0]
97134
try:
98135
im_name = im_chip_dict[chip_id]
99136
except:
100137
print("im_name not in im_chip_dict:", im_name)
101138
return
102-
continue
139+
# continue
103140

104141
geojson_path = os.path.join(geojson_dir, geojson_name)
105142
im_path = os.path.join(im_dir, im_name)
@@ -115,6 +152,7 @@ def gt_geojson_dir_to_wkt(geojson_dir, im_dir, output_csv_path,
115152
travel_time_key=travel_time_key,
116153
speed_key=speed_key,
117154
use_pix_coords=use_pix_coords,
155+
simplify=simplify,
118156
verbose=verbose,
119157
super_verbose=super_verbose)
120158

@@ -126,10 +164,15 @@ def gt_geojson_dir_to_wkt(geojson_dir, im_dir, output_csv_path,
126164
else:
127165
cols = ['ImageId', 'WKT_Pix']
128166

129-
print("cols:", cols)
130167
# use 'length_m' instead?
131168
cols = [z.replace('length', 'length_m') for z in cols]
132169

170+
# print("wkt_list_tot:", wkt_list_tot)
171+
# print("\n")
172+
# for i in wkt_list_tot:
173+
# print(len(i))
174+
print("cols:", cols)
175+
133176
df = pd.DataFrame(wkt_list_tot, columns=cols)
134177
print("df:", df)
135178
# save
@@ -165,6 +208,16 @@ def gt_geojson_dir_to_wkt(geojson_dir, im_dir, output_csv_path,
165208
help='Root directory of geojson data')
166209
parser.add_argument('--PSRGB_dir', default='', type=str,
167210
help='PS-RGB dir, if '', assume in root_dir')
211+
parser.add_argument('--travel_time_key', default='travel_time_s', type=str,
212+
help='key for travel time')
213+
parser.add_argument('--speed_key', default='inferred_speed_mps', type=str,
214+
help='key for road speed')
215+
parser.add_argument('--out_file_name',
216+
default='geojson_roads_speed_wkt_weighted_v0.csv',
217+
type=str,
218+
help='name for output file')
219+
parser.add_argument('--simplify_graph', default=True, type=bool,
220+
help='switch to simplify graph prior to saving')
168221
args = parser.parse_args()
169222

170223
weight_keys = ['length', 'travel_time_s']
@@ -178,14 +231,22 @@ def gt_geojson_dir_to_wkt(geojson_dir, im_dir, output_csv_path,
178231
geojson_dir = os.path.join(root_dir, 'geojson_roads_speed')
179232
# get name
180233
out_prefix = '_'.join(root_dir.split('/')[-3:])
181-
output_csv_path = os.path.join(
182-
root_dir, out_prefix + 'geojson_roads_speed_wkt_weighted.csv')
183-
234+
# if args.simplify_graph:
235+
# out_name = out_prefix + 'geojson_roads_speed_wkt_weighted_simp.csv'
236+
# else:
237+
# out_name = out_prefix + 'geojson_roads_speed_wkt_weighted.csv'
238+
# output_csv_path = os.path.join(root_dir, out_name)
239+
output_csv_path = os.path.join(root_dir, out_prefix + args.out_file_name)
240+
241+
print("output_csv_path:", output_csv_path)
184242
df = gt_geojson_dir_to_wkt(geojson_dir, im_dir,
185243
output_csv_path=output_csv_path,
186-
weight_keys=weight_keys, verbose=verbose)
244+
travel_time_key=args.travel_time_key,
245+
speed_key=args.speed_key,
246+
simplify=args.simplify_graph,
247+
weight_keys=weight_keys,
248+
verbose=verbose)
187249

188-
189250
'''
190251
Execute
191252

0 commit comments

Comments
 (0)