-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathstructure_index.py
651 lines (546 loc) · 25 KB
/
structure_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
import warnings, copy #,sys
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
#from scipy import sparse, linalg
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import distance_matrix
# faiss is an optional accelerator for the euclidean k-NN search in
# cloud_overlap_neighbors; when it cannot be imported, scikit-learn's
# NearestNeighbors is used instead.
try:
    import faiss
    use_fast = True
except ImportError:
    use_fast = False
from sklearn.metrics import pairwise_distances
from sklearn.manifold import Isomap
from decorator import decorator
from tqdm.auto import tqdm
import networkx as nx
# Commented-out option lists; not referenced by the code in this file.
# overlap_options = ['one_third','continuity']
# graph_options = ['binary', 'weighted']
# Distance metrics accepted by compute_structure_index's 'distance_metric'
# argument; cloud_overlap_radius / cloud_overlap_neighbors branch on these.
distance_options = ['euclidean','geodesic']
def validate_args_types(**decls):
    """Decorator factory that checks the runtime types of selected arguments.

    Usage:
        @validate_args_types(name=str, text=(int, str))
        def parse_rule(name, text): ...

    Parameters
    ----------
    **decls : type or tuple of types
        Mapping from argument name to the type(s) accepted by ``isinstance``.
        A name may be a named parameter of the decorated function or a key it
        receives through ``**kwargs``. Arguments that were not provided are
        not checked.

    Raises
    ------
    TypeError
        When a provided argument is not an instance of its declared type(s).
    """
    @decorator
    def wrapper(func, *args, **kwargs):
        code = func.__code__
        fname = func.__name__
        # Names of the function's positional-or-keyword parameters, in order.
        names = code.co_varnames[:code.co_argcount]
        for argname, argtype in decls.items():
            # Prefer the positional slot, but fall back to kwargs when the
            # value was not forwarded positionally (the bounds check avoids
            # an IndexError in that case).
            pos = names.index(argname) if argname in names else None
            if pos is not None and pos < len(args):
                argval = args[pos]
            elif argname in kwargs:
                argval = kwargs[argname]
            else:
                continue  # argument not provided: nothing to check
            if not isinstance(argval, argtype):
                raise TypeError(f"{fname}(...): arg '{argname}': type is"
                                f" {type(argval)}, must be {argtype}")
        return func(*args, **kwargs)
    return wrapper
def filter_noisy_outliers(data):
    """Return the row indices of `data` that look like isolated points.

    For every point, count how many other points lie closer than the 1st
    percentile of all pairwise distances (the count is shifted by -1, a
    constant offset that does not change which points are selected). Points
    whose count is strictly below the 20th percentile of all counts are
    returned as a 1d integer index array.
    """
    dist = pairwise_distances(data)
    # Exclude self-distances from both the percentile and the counts.
    np.fill_diagonal(dist, np.nan)
    close_radius = np.nanpercentile(dist, 1)
    close_counts = (dist < close_radius).sum(axis=1) - 1
    count_threshold = np.percentile(close_counts, 20)
    return np.flatnonzero(close_counts < count_threshold)
def meshgrid2(arrs):
    """Build an 'ij'-indexed N-dimensional grid from 1-D coordinate arrays.

    Parameters
    ----------
    arrs : sequence of 1-D array-likes
        Coordinates along each dimension (e.g. ``np.arange(n)`` per axis).

    Returns
    -------
    tuple of numpy.ndarray
        One array per input, each of shape ``tuple(len(a) for a in arrs)``;
        the i-th array varies along axis i only, so raveling the arrays in C
        order enumerates the grid cells in row-major order.
    """
    # indexing='ij' gives matrix indexing (axis i <-> arrs[i]) rather than the
    # default Cartesian 'xy' ordering that swaps the first two axes.
    return tuple(np.meshgrid(*arrs, indexing='ij'))
def create_ndim_grid(label, n_bins, min_label, max_label, discrete_label):
    """Partition label space into an N-dimensional grid of bin-groups.

    Parameters
    ----------
    label : numpy 2d array of shape [n_samples, n_features]
    n_bins : sequence of int, one per feature
        Number of equal-width bins for each continuous feature.
    min_label, max_label : sequence of float, one per feature
        Range spanned by the bins of each continuous feature.
    discrete_label : sequence of bool, one per feature
        If True, one bin is created per unique value of that feature and
        n_bins/min_label/max_label are not used for it.

    Returns
    -------
    grid : numpy 1d object array of length n_cells
        grid[c] is the list of sample indices falling inside cell c. Cells are
        enumerated in C (row-major) order over the per-feature bins. Bin
        limits are inclusive on both ends, so a sample lying exactly on an
        interior edge is listed in both neighbouring cells.
    coords : numpy 3d array of shape [n_cells, n_features, 3]
        For every cell and feature: [bin_start, bin_center, bin_end].
    """
    n_samples, ndims = label.shape
    grid_edges = []
    for nd in range(ndims):
        if discrete_label[nd]:
            # Degenerate [v, v] bins: one per unique value of the feature.
            values = np.unique(label[:, nd]).reshape(-1, 1)
            grid_edges.append(np.tile(values, (1, 2)))
        else:
            edges = np.linspace(min_label[nd], max_label[nd],
                                n_bins[nd] + 1).reshape(-1, 1)
            grid_edges.append(np.concatenate((edges[:-1], edges[1:]), axis=1))

    shape = tuple(e.shape[0] for e in grid_edges)
    n_cells = int(np.prod(shape))
    grid = np.empty(n_cells, object)
    coords = np.zeros((n_cells, ndims, 3))
    # np.ndindex walks the cell multi-indices in C order.
    for elem, idx in enumerate(np.ndindex(*shape)):
        inside = np.ones(n_samples, dtype=bool)
        for dim, bin_idx in enumerate(idx):
            min_edge, max_edge = grid_edges[dim][bin_idx]
            inside &= (label[:, dim] >= min_edge) & (label[:, dim] <= max_edge)
            coords[elem, dim] = (min_edge, 0.5 * (min_edge + max_edge), max_edge)
        grid[elem] = list(np.flatnonzero(inside))
    return grid, coords
def cloud_overlap_radius(cloud1, cloud2, r, distance_metric):
    """Compute the overlap between two clouds of points using a fixed radius.

    Both clouds are pooled; the neighbours of a point are all *other* pooled
    points whose distance to it is at most ``r``.

    Parameters:
    ----------
    cloud1: numpy 2d array of shape [n_samples_1,n_features]
        Array containing the cloud of points 1
    cloud2: numpy 2d array of shape [n_samples_2,n_features]
        Array containing the cloud of points 2
    r: float
        Neighbourhood radius (inclusive).
    distance_metric: str
        'euclidean' or 'geodesic' (distances taken from
        ``Isomap.dist_matrix_``).

    Returns:
    -------
    overlap_1_2: float
        Fraction of the neighbours of cloud1's points that belong to cloud2.
    overlap_2_1: float
        Fraction of the neighbours of cloud2's points that belong to cloud1.
        Either value is nan when the corresponding cloud has no neighbours
        within ``r``.

    Raises:
    ------
    ValueError
        If ``distance_metric`` is not supported.
    """
    cloud_all = np.vstack((cloud1, cloud2)).astype('float32')
    idx_sep = cloud1.shape[0]
    if distance_metric == 'euclidean':
        D = distance_matrix(cloud_all, cloud_all, p=2)
    elif distance_metric == 'geodesic':
        model_iso = Isomap(n_components=1)
        model_iso.fit(cloud_all)
        D = model_iso.dist_matrix_
    else:
        raise ValueError(f"Invalid 'distance_metric' {distance_metric!r}; valid "
                         "options are 'euclidean' and 'geodesic'.")
    # Neighbour mask: within radius, excluding each point itself. Masking the
    # diagonal (instead of dropping the first column of each sorted distance
    # row) stays correct when duplicate points tie with the point itself.
    neigh = D <= r
    np.fill_diagonal(neigh, False)
    num_neigh_1 = np.sum(neigh[:idx_sep])
    num_neigh_2 = np.sum(neigh[idx_sep:])
    cross_1_2 = np.sum(neigh[:idx_sep, idx_sep:])
    cross_2_1 = np.sum(neigh[idx_sep:, :idx_sep])
    overlap_1_2 = cross_1_2 / num_neigh_1 if num_neigh_1 else np.nan
    overlap_2_1 = cross_2_1 / num_neigh_2 if num_neigh_2 else np.nan
    return overlap_1_2, overlap_2_1
def cloud_overlap_neighbors(cloud1, cloud2, k, distance_metric):
    """Compute the overlap between two clouds of points from their k-NN graph.

    Both clouds are pooled and, for every point, its k nearest *other* points
    are found. The overlap of cloud1 over cloud2 is the fraction of the
    neighbours of cloud1's points that belong to cloud2, and vice versa.

    Parameters:
    ----------
    cloud1: numpy 2d array of shape [n_samples_1,n_features]
        Array containing the cloud of points 1
    cloud2: numpy 2d array of shape [n_samples_2,n_features]
        Array containing the cloud of points 2
    k: int
        Number of neighbors used to compute the overlapping between
        bin-groups. This parameter controls the tradeoff between local
        and global structure.
    distance_metric: str
        'euclidean' (faiss when available, otherwise scikit-learn's
        NearestNeighbors) or 'geodesic' (distances taken from
        ``Isomap.dist_matrix_``).

    Returns:
    -------
    overlap_1_2: float
        Degree of overlapping of cloud1 over cloud2
    overlap_2_1: float
        Degree of overlapping of cloud2 over cloud1

    Raises:
    ------
    ValueError
        If ``distance_metric`` is not supported.
    """
    cloud_all = np.vstack((cloud1, cloud2)).astype('float32')
    n_all = cloud_all.shape[0]
    idx_sep = cloud1.shape[0]
    if distance_metric == 'euclidean':
        if use_fast:
            index = faiss.IndexFlatL2(cloud_all.shape[1])  # exact L2 index
            index.add(cloud_all)
            _, I = index.search(cloud_all, k + 1)
            # Each query normally returns itself as the first hit, but with
            # duplicate points a tie can place it elsewhere or push it out of
            # the k+1 results. Remove it wherever it is (or else the last,
            # farthest hit) so exactly k neighbours remain per row.
            is_self = I == np.arange(n_all)[:, None]
            is_self[~is_self.any(axis=1), -1] = True
            I = I[~is_self].reshape(n_all, k)
        else:
            # kneighbors() without X excludes each point from its own list.
            knn = NearestNeighbors(n_neighbors=k, metric="minkowski", p=2).fit(cloud_all)
            I = knn.kneighbors(return_distance=False)
    elif distance_metric == 'geodesic':
        model_iso = Isomap(n_components=1)
        model_iso.fit(cloud_all)
        knn = NearestNeighbors(n_neighbors=k, metric="precomputed").fit(model_iso.dist_matrix_)
        I = knn.kneighbors(return_distance=False)
    else:
        raise ValueError(f"Invalid 'distance_metric' {distance_metric!r}; valid "
                         "options are 'euclidean' and 'geodesic'.")
    # faiss pads rows with -1 when fewer than k neighbours exist; such entries
    # must count neither as neighbours nor as cloud-1 members.
    valid = I >= 0
    cross_1_2 = np.sum(I[:idx_sep] >= idx_sep)
    cross_2_1 = np.sum(valid[idx_sep:] & (I[idx_sep:] < idx_sep))
    # Total fraction of neighbours that belong to the other cloud.
    overlap_1_2 = cross_1_2 / np.sum(valid[:idx_sep])
    overlap_2_1 = cross_2_1 / np.sum(valid[idx_sep:])
    return overlap_1_2, overlap_2_1
def _pairwise_overlap_matrix(data, bin_label, unique_bin_label, cloud_overlap,
                             neighborhood_size, distance_metric, on_pair=None):
    """Return the [num_bins, num_bins] matrix of directed bin-group overlaps.

    Entry [a, b] is the overlap of bin-group a over bin-group b as returned by
    ``cloud_overlap``; the diagonal stays nan. ``on_pair`` (optional) is
    called once per unordered pair, e.g. to advance a progress bar.
    """
    num_bins = len(unique_bin_label)
    overlap_mat = np.full((num_bins, num_bins), np.nan)
    for a in range(num_bins):
        A = data[bin_label == unique_bin_label[a]]
        for b in range(a + 1, num_bins):
            B = data[bin_label == unique_bin_label[b]]
            overlap_mat[a, b], overlap_mat[b, a] = cloud_overlap(
                A, B, neighborhood_size, distance_metric)
            if on_pair is not None:
                on_pair()
    return overlap_mat


def _structure_index_from_overlap(overlap_mat):
    """Map an overlap matrix to the structure index.

    1 minus the mean node out-degree (row sums, ignoring nan) normalised by
    (num_bins - 1), linearly rescaled so 0.5 maps to 0 and 1 maps to 1, then
    clipped below at 0 (a nan result propagates).
    """
    num_bins = overlap_mat.shape[0]
    degree_nodes = np.nansum(overlap_mat, axis=1)
    SI = 1 - np.mean(degree_nodes) / (num_bins - 1)
    SI = 2 * (SI - 0.5)
    return np.max([SI, 0])


@validate_args_types(data=np.ndarray, label=np.ndarray, n_bins=(int,np.integer,list),
                     dims=(type(None),list), distance_metric=str, n_neighbors=(int,np.integer),
                     num_shuffles=(int,np.integer), discrete_label=(list,bool), verbose=bool)
def compute_structure_index(data, label, n_bins=10, dims=None, **kwargs):
    '''compute structure index main function

    The label space is split into bin-groups; each bin-group is a node of a
    directed graph whose edge weights are the neighbourhood overlap between
    the data points of two bin-groups. SI is 1 minus the mean node out-degree
    normalised by (number_bin_groups - 1), rescaled so that 0.5 maps to 0 and
    1 maps to 1, and clipped at 0.

    Parameters:
    ----------
    data: numpy 2d array of shape [n_samples,n_dimensions]
        Array containing the signal
    label: numpy 1d or 2d array of shape [n_samples,] or [n_samples,n_features]
        Array containing the labels of the data. It can either be a
        column vector (scalar feature) or a 2D array (vectorial feature)

    Optional parameters:
    --------------------
    n_bins: integer or list of integers (default: 10)
        number of bin-groups the label will be divided into (they will
        become nodes on the graph). For vectorial features, if one wants
        different number of bins for each entry then specify n_bins as a
        list (i.e. [10,20,5]). Ignored for discrete label columns. The
        caller's list is not modified.
    dims: list of integers or None (default: None)
        list of integers containing the dimensions of data along which the
        structure index will be computed. Provide None to compute it along
        all dimensions of data.
    distance_metric: str (default: 'euclidean')
        Type of distance used to compute the closest n_neighbors. See
        'distance_options' for currently supported distances.
    n_neighbors: int (default: 15)
        Number of neighbors used to compute the overlapping between
        bin-groups. This parameter controls the tradeoff between local and
        global structure. Must be larger than 2.
    radius: float (optional)
        If given, neighbourhoods are all points within this distance
        (cloud_overlap_radius) instead of the n_neighbors nearest ones.
        Cannot be combined with n_neighbors.
    discrete_label: boolean or list of booleans (default: False)
        If the label is discrete, then one bin-group will be created for
        each discrete value it takes. Note that if set to True, 'n_bins'
        parameter will be ignored. The caller's list is not modified.
    min_label, max_label: scalar or list (optional)
        Range of each label column used for binning. Defaults to the 5th /
        95th percentile of each column (min / max for discrete columns).
        Label values outside the range are clipped to just inside it.
    filter_noise: boolean (default: False)
        If True, points flagged by filter_noisy_outliers inside each
        bin-group are excluded from the analysis.
    num_shuffles: int (default: 100)
        Number of shuffles to be computed. Note it must fall within the
        interval [0, np.inf).
    verbose: boolean (default: False)
        Boolean controling whether or not to print internal process.

    Returns:
    -------
    SI: float
        structure index
    bin_label: tuple
        Tuple containing:
            [0] Array indicating the bin-group to which each data point has
                been assigned (nan for unassigned/discarded points). Rows
                with nan in data or label are removed beforehand.
            [1] Array indicating feature limits of each bin-group. Size is
            [number_bin_groups, n_features, 3] where the last dimension
            contains [bin_st, bin_center, bin_en]
    overlap_mat: numpy 2d array of shape [n_valid_bins, n_valid_bins]
        Array containing the overlapping between each pair of valid
        bin-groups (diagonal is nan).
    shuf_SI: numpy 1d array of shape [num_shuffles,]
        Array containing the structure index computed for each shuffling
        iteration.
    If fewer than two valid bin-groups remain, (nan, (nan, nan), nan, nan)
    is returned.
    '''
    #TODO:
    #distance_metric: cosyne
    #re-evaluate which arguments to put outside which ones in kwargs
    #include plot-function
    #maybe plot neighbors inside radius distribution when radius provided
    #if radius selected, and no point has neighbors prompt error

    # 0. CHECK INPUT VALIDITY. Types are enforced by the decorator; the
    # values themselves are checked here.
    #i) data input
    assert data.ndim == 2, ("Input 'data' must be a 2D numpy ndarray with shape"
                            " (n,m) where n is the number of samples and m the"
                            " number of dimensions.")
    #ii) label input
    if label.ndim == 1:
        label = label.reshape(-1, 1)
    assert label.ndim == 2, "label must be a 1D or 2D array."
    n_label_dims = label.shape[1]
    #iii) n_bins input (copied: entries are overwritten further below)
    if isinstance(n_bins, (int, np.integer)):
        assert n_bins > 1, \
            "Input 'n_bins' must be an int or list of int larger than 1."
        n_bins = [n_bins for _ in range(n_label_dims)]
    elif isinstance(n_bins, list):
        assert np.all([nb > 1 for nb in n_bins]), \
            "Input 'n_bins' must be an int or list of int larger than 1."
        n_bins = list(n_bins)
    #iv) dims input: None means all dimensions
    if dims is None:
        dims = list(range(data.shape[1]))
    #v) distance_metric
    distance_metric = kwargs.get('distance_metric', 'euclidean')
    assert distance_metric in distance_options, ("Invalid input "
        f"'distance_metric'. Valid options are {distance_options}.")
    #vi) neighbourhood definition: n_neighbors (default) or radius
    if ('n_neighbors' in kwargs) and ('radius' in kwargs):
        raise ValueError("Both n_neighbors and radius provided. Please only"
                         " specify one")
    if 'radius' in kwargs:
        neighborhood_size = kwargs['radius']
        assert neighborhood_size > 0, "Input 'radius' must be larger than 0"
        cloud_overlap = cloud_overlap_radius
    else:
        neighborhood_size = kwargs.get('n_neighbors', 15)
        assert neighborhood_size > 2, "Input 'n_neighbors' must be larger than 2."
        cloud_overlap = cloud_overlap_neighbors
    #vii) discrete_label input (copied: entries may be overwritten below)
    discrete_label = kwargs.get('discrete_label', False)
    if isinstance(discrete_label, bool):
        discrete_label = [discrete_label for _ in range(n_label_dims)]
    else:
        assert np.all([isinstance(idx, bool) for idx in discrete_label]), \
            "Input 'discrete_label' must be boolean or list of booleans."
        discrete_label = list(discrete_label)
    #viii) num_shuffles input
    num_shuffles = kwargs.get('num_shuffles', 100)
    assert num_shuffles >= 0, \
        "Input 'num_shuffles' must fall within the interval [0, np.inf)"
    #ix) verbose input
    verbose = kwargs.get('verbose', False)

    # 1. PREPROCESS DATA
    #i) Keep only desired dims
    data = data[:, dims]
    if data.ndim == 1:  # if left with 1 dim, keep the 2D shape
        data = data.reshape(-1, 1)
    #ii) Delete rows with nan in label or data (boolean indexing copies, so
    # the in-place clipping below never touches the caller's arrays)
    keep = ~(np.any(np.isnan(data), axis=1) | np.any(np.isnan(label), axis=1))
    data = data[keep]
    label = label[keep]
    # Work in floating point so the clipping below is not truncated.
    if not np.issubdtype(label.dtype, np.floating):
        label = label.astype(float)
    #iii) Binarize label
    if verbose: print('Computing bin-groups...', sep='', end='')
    #a) Check bin-num vs num unique label
    for dim in range(n_label_dims):
        num_unique_label = len(np.unique(label[:, dim]))
        if discrete_label[dim]:
            n_bins[dim] = num_unique_label
        elif n_bins[dim] >= num_unique_label:
            warnings.warn(f"Along column {dim}, input 'label' has less unique "
                          f"values ({num_unique_label}) than specified in "
                          f"'n_bins' ({n_bins[dim]}). Changing 'n_bins' to "
                          f"{num_unique_label} and setting it to discrete.")
            n_bins[dim] = num_unique_label
            discrete_label[dim] = True
    #b) Label range of each column (per-column min/max for discrete ones)
    is_discrete = np.array(discrete_label, dtype=bool)
    if 'min_label' in kwargs:
        min_label = kwargs['min_label']
        if not isinstance(min_label, list):
            min_label = [min_label for _ in range(n_label_dims)]
    else:
        min_label = np.percentile(label, 5, axis=0)
        if is_discrete.any():
            min_label[is_discrete] = np.min(label[:, is_discrete], axis=0)
    if 'max_label' in kwargs:
        max_label = kwargs['max_label']
        if not isinstance(max_label, list):
            max_label = [max_label for _ in range(n_label_dims)]
    else:
        max_label = np.percentile(label, 95, axis=0)
        if is_discrete.any():
            max_label[is_discrete] = np.max(label[:, is_discrete], axis=0)
    # Clip out-of-range values to just inside the range to prevent rounding
    # problems at the outer bin edges.
    for ld in range(n_label_dims):
        label[label[:, ld] < min_label[ld], ld] = min_label[ld] + 0.00001
        label[label[:, ld] > max_label[ld], ld] = max_label[ld] - 0.00001
    grid, coords = create_ndim_grid(label, n_bins, min_label, max_label, discrete_label)
    bin_label = np.full(label.shape[0], np.nan)
    for b, members in enumerate(grid):
        bin_label[members] = b
    #iv) Clean outliers from each bin-group if specified in kwargs
    if kwargs.get('filter_noise', False):
        for b in range(len(grid)):
            members = np.flatnonzero(bin_label == b)
            if members.size < 2:
                continue  # nothing to filter (and avoids empty-input errors)
            noise_idx = members[filter_noisy_outliers(data[members])]
            bin_label[noise_idx] = np.nan  # exclude, do not reassign
    #v) Discard bin-groups with too few points
    unique_bin_label = np.unique(bin_label[~np.isnan(bin_label)])
    min_points_per_bin = 0.1 * data.shape[0] / np.prod(n_bins)
    for val in unique_bin_label:
        if np.sum(bin_label == val) < min_points_per_bin:
            bin_label[bin_label == val] = np.nan
    unique_bin_label = np.unique(bin_label[~np.isnan(bin_label)])
    if verbose:
        print('\b\b\b: Done')
    if len(unique_bin_label) < 2:
        return np.nan, (np.nan, np.nan), np.nan, np.nan

    # 2. COMPUTE STRUCTURE INDEX
    #i) overlap between bin-groups pairwise
    num_bins = len(unique_bin_label)
    bar = None
    if verbose:
        bar = tqdm(total=int((num_bins**2 - num_bins) / 2), desc='Computing overlap')
    overlap_mat = _pairwise_overlap_matrix(
        data, bin_label, unique_bin_label, cloud_overlap, neighborhood_size,
        distance_metric, on_pair=bar.update if verbose else None)
    if verbose: bar.close()
    #ii) structure index (SI)
    if verbose: print('Computing structure index...', sep='', end='')
    SI = _structure_index_from_overlap(overlap_mat)
    if verbose: print(f"\b\b\b: {SI:.2f}")
    #iii) Shuffling: same pipeline with randomly permuted bin labels
    shuf_SI = np.full(num_shuffles, np.nan)
    if verbose: bar = tqdm(total=num_shuffles, desc='Computing shuffling')
    for s_idx in range(num_shuffles):
        shuf_bin_label = bin_label.copy()
        np.random.shuffle(shuf_bin_label)
        shuf_overlap_mat = _pairwise_overlap_matrix(
            data, shuf_bin_label, unique_bin_label, cloud_overlap,
            neighborhood_size, distance_metric)
        shuf_SI[s_idx] = _structure_index_from_overlap(shuf_overlap_mat)
        if verbose: bar.update(1)
    if verbose: bar.close()
    if verbose and num_shuffles > 0:
        print(f"Shuffling 99th percentile: {np.percentile(shuf_SI,99):.2f}")
    return SI, (bin_label, coords), overlap_mat, shuf_SI
def draw_graph(overlap_mat, ax, node_cmap = plt.cm.tab10, edge_cmap = plt.cm.Greys, **kwargs):
    """Draw weighted directed graph from overlap matrix.

    Parameters:
    ----------
    overlap_mat: numpy 2d array of shape [n_bins, n_bins]
        Array containing the overlapping between each pair of bin-groups.
        Entry [i, j] is the weight of edge i -> j. nan entries (such as the
        nan diagonal built by compute_structure_index) are replaced by 0
        before building the graph, so they do not become edges.
    ax: matplotlib pyplot axis object.

    Optional parameters:
    --------------------
    node_cmap: pyplot colormap (default: plt.cm.tab10)
        colormap used to color the nodes when 'node_color' is not given.
    edge_cmap: pyplot colormap (default: plt.cm.Greys)
        colormap for mapping intensities of edges.
    node_size: scalar or array (default: 800)
        size of nodes. If an array is specified it must be the same length
        as the number of nodes.
    scale_edges: scalar (default: 5)
        number used to scale the width of the edges.
    edge_vmin: scalar (default: 0)
        minimum for edge colormap scaling
    edge_vmax: scalar (default: 0.5)
        maximum for edge colormap scaling
    node_names: list (default: not given)
        list containing the name of each node, drawn as labels. If the
        entries are numerical, node colors are scaled according to them.
    node_color: list or array of colors (default: False)
        A list of node colors to be used instead of a colormap. It must be
        the same length as the number of nodes. If not specified (or
        empty/False) `node_cmap` is used instead.
    arrow_size: scalar (default: 20)
        size of the arrow heads.
    layout_type: callable (default: nx.circular_layout)
        function mapping the graph to node positions.

    Returns:
    -------
    The value returned by nx.draw_networkx.
    """
    weights = np.asarray(overlap_mat, dtype=float)
    weights = np.where(np.isnan(weights), 0.0, weights)
    # from_numpy_matrix was removed in networkx 3.0; use from_numpy_array when
    # available and fall back only for old releases that lack it. (Parsing the
    # first character of nx.__version__ breaks for versions >= 10.)
    from_numpy = getattr(nx, 'from_numpy_array', None)
    if from_numpy is None:
        from_numpy = nx.from_numpy_matrix
    g = from_numpy(weights, create_using=nx.DiGraph)
    number_nodes = g.number_of_nodes()

    node_size = kwargs.get('node_size', 800)
    scale_edges = kwargs.get('scale_edges', 5)
    edge_vmin = kwargs.get('edge_vmin', 0)
    edge_vmax = kwargs.get('edge_vmax', 0.5)
    node_color = kwargs.get('node_color', False)
    arrow_size = kwargs.get('arrow_size', 20)
    layout_type = kwargs.get('layout_type', nx.circular_layout)

    if 'node_names' in kwargs:
        node_names = kwargs['node_names']
        nodes_info = list(g.nodes(data=True))
        names_dict = {val[0]: node_names[i] for i, val in enumerate(nodes_info)}
        with_labels = True
        # Numerical names drive the node colormap; string names do not.
        node_val = node_names if not isinstance(node_names[0], str) else range(number_nodes)
    else:
        names_dict = dict()
        node_val = range(number_nodes)
        with_labels = False

    # `not array` is ambiguous for numpy arrays, so check their size instead.
    if isinstance(node_color, np.ndarray):
        use_cmap = node_color.size == 0
    else:
        use_cmap = not node_color
    if use_cmap:  # obtain list of colors from cmap
        norm_cmap = matplotlib.colors.Normalize(vmin=np.min(node_val), vmax=np.max(node_val))
        node_color = [np.array(node_cmap(norm_cmap(node_val[ii]), bytes=True)) / 255
                      for ii in range(number_nodes)]

    widths = nx.get_edge_attributes(g, 'weight')
    edge_weights = np.array(list(widths.values()))
    wdg = nx.draw_networkx(g, pos=layout_type(g), node_size=node_size,
                           node_color=node_color, width=edge_weights * scale_edges,
                           edge_color=edge_weights, edge_cmap=edge_cmap,
                           arrowsize=arrow_size, edge_vmin=edge_vmin, edge_vmax=edge_vmax,
                           labels=names_dict, arrows=True, connectionstyle="arc3,rad=0.15",
                           with_labels=with_labels, ax=ax)
    return wdg