# Global runner for all NeRF methods.
# For convenience, we want all methods using NeRF to use this one file.
import argparse
import random
import json
import math
import time
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as TVF
import torch.nn as nn
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import trange, tqdm
from itertools import chain
import src.loaders as loaders
import src.nerf as nerf
import src.utils as utils
import src.sdf as sdf
import src.refl as refl
import src.lights as lights
import src.cameras as cameras
import src.hyper_config as hyper_config
import src.renderers as renderers
from src.lights import light_kinds
from src.opt import UniformAdam
from src.utils import ( save_image, save_plot, load_image, dir_to_elev_azim, git_hash )
from src.neural_blocks import ( Upsampler, SpatialEncoder, StyleTransfer, FourierEncoder )
import os
def arguments():
a = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
ST="store_true"
a.add_argument("-d", "--data", help="path to data", required=True)
a.add_argument(
"--data-kind", help="Kind of data to load", default="original", choices=list(loaders.kinds)
)
a.add_argument(
"--derive-kind", help="Attempt to derive the kind if a single file is given", action="store_false",
)
a.add_argument("--outdir", help="path to output directory", type=str, default="outputs/")
a.add_argument("--timed-outdir", help="Create new output dir with date+time of run", action=ST)
# various size arguments
a.add_argument("--size", help="post-upsampling size of output", type=int, default=32)
a.add_argument("--render-size", help="pre-upsampling size of output image", type=int, default=16)
a.add_argument("--epochs", help="number of epochs to train for", type=int, default=30000)
a.add_argument("--batch-size", help="# views pet training batch", type=int, default=8)
a.add_argument("--neural-upsample", help="Add neural upsampling", action=ST)
a.add_argument("--crop-size",help="what size to use while cropping",type=int, default=16)
a.add_argument("--test-crop-size",help="what size to use while cropping at test time",type=int, default=0)
a.add_argument("--steps", help="Number of depth steps", type=int, default=64)
a.add_argument(
"--mip", help="Use MipNeRF with different sampling", type=str, choices=["cone", "cylinder"],
)
a.add_argument(
"--sigmoid-kind", help="What activation to use with the reflectance model.",
default="upshifted", choices=list(utils.sigmoid_kinds.keys()),
)
a.add_argument(
"--feature-space", help="The feature space size when neural upsampling.",
type=int, default=32,
)
a.add_argument(
"--model", help="which shape model to use", type=str,
choices=list(nerf.model_kinds.keys()) + ["sdf"], default="plain",
)
a.add_argument(
"--dyn-model", help="Which dynamic model to use", type=str,
choices=list(nerf.dyn_model_kinds.keys()),
)
a.add_argument(
"--bg", help="What background to use for NeRF.", type=str,
choices=list(nerf.sky_kinds.keys()), default="black",
)
# this default for LR seems to work pretty well?
a.add_argument("-lr", "--learning-rate", help="learning rate", type=float, default=5e-4)
a.add_argument("--seed", help="Random seed to use, -1 is no seed", type=int, default=1337)
a.add_argument("--decay", help="Weight decay value", type=float, default=0)
a.add_argument("--notest", help="Do not run test set", action=ST)
a.add_argument("--data-parallel", help="Use data parallel for the model", action=ST)
a.add_argument(
"--omit-bg", action=ST, help="Omit black bg with some probability. Only used for faster training",
)
a.add_argument(
"--train-parts", help="Which parts of the model should be trained",
choices=["all", "refl", "occ", "path-tf", "camera"], default=["all"], nargs="+",
)
a.add_argument(
"--loss-fns", help="Loss functions to use", nargs="+", type=str, choices=list(loss_map.keys()), default=["l2"],
)
a.add_argument(
"--color-spaces", help="Color spaces to compare on", nargs="+", type=str,
choices=["rgb", "hsv", "luminance", "xyz"], default=["rgb"],
)
a.add_argument(
"--tone-map", help="Add tone mapping (1/(1+x)) before loss function", action=ST,
)
a.add_argument("--bendy", help="[WIP] Allow bendy rays!", action=ST)
a.add_argument(
"--gamma-correct-loss", type=float, default=1., help="Gamma correct by x in training",
)
a.add_argument(
"--autogamma-correct-loss", action=ST, help="Automatically infer a weight for gamma correction",
)
a.add_argument("--has-multi-light", help="For NeRV, use multi point light dataset", action=ST)
a.add_argument("--style-img", help="Image to use for style transfer", default=None)
a.add_argument("--no-sched", help="Do not use a scheduler", action=ST)
a.add_argument(
"--sched-min", default=5e-5, type=float, help="Minimum value for the scheduled learning rate.",
)
a.add_argument("--serial-idxs", help="Train on images in serial", action=ST)
# TODO really fix MPIs
a.add_argument(
"--replace", nargs="*", choices=["refl", "occ", "bg", "sigmoid", "light", "dyn", "al_occ"],
default=[], type=str, help="Modules to replace on this run, if any. Use caution: replaced modules overwrite existing parts.",
)
a.add_argument(
"--all-learned-occ-kind", help="What parameters the Learned Ambient Occlusion should take",
default="pos", type=str, choices=list(renderers.all_learned_occ_kinds.keys()),
)
a.add_argument(
"--volsdf-direct-to-path", action=ST,
help="Convert an existing direct volsdf model to a path tracing model",
)
a.add_argument(
"--volsdf-alternate", help="Use alternating volume rendering/SDF training volsdf", action=ST,
)
a.add_argument(
"--shape-to-refl-size", type=int, default=64, help="Size of vector passed from density to reflectance model",
)
a.add_argument(
"--refl-order", default=2, type=int, help="Order for classical Spherical Harmonics & Fourier Basis BSDFs/Reflectance models",
)
a.add_argument(
"--inc-fourier-freqs", action=ST, help="Multiplicatively increase the fourier frequency standard deviation on each run",
)
a.add_argument("--rig-points", type=int, default=128, help="Number of rigs points to use in RigNeRF")
refla = a.add_argument_group("reflectance")
refla.add_argument(
"--refl-kind", help="What kind of reflectance model to use", choices=list(refl.refl_kinds.keys()), default="view",
)
refla.add_argument(
"--weighted-subrefl-kinds",
help="What subreflectances should be used with --refl-kind weighted. \
They will not take a spatial component, and only rely on view direction, normal, \
and light direction.",
choices=[r for r in refl.refl_kinds if r != "weighted"], nargs="+", default=[],
)
refla.add_argument(
"--normal-kind", choices=[None, "elaz", "raw"], default=None,
help="How to include normals in reflectance model. Not all surface models support normals",
)
refla.add_argument(
"--space-kind", choices=["identity", "surface", "none"], default="identity",
help="Space to encode texture: surface builds a map from 3D (identity) to 2D",
)
refla.add_argument(
"--alt-train", choices=["analytic", "learned"], default="learned",
help="Whether to train the analytic or the learned model, set per run.",
)
refla.add_argument(
"--refl-bidirectional", action=ST,
help="Allow normals to be flipped for the reflectance (just Diffuse for now)",
)
refla.add_argument(
"--view-variance-decay", type=float, default=0, help="Regularize reflectance across view directions",
)
rdra = a.add_argument_group("integrator")
rdra.add_argument(
"--integrator-kind", choices=[None, "direct", "path"], default=None,
help="Integrator to use for surface rendering",
)
rdra.add_argument(
"--occ-kind", choices=list(renderers.occ_kinds.keys()), default=None,
help="Occlusion method for shadows to use in integration.",
)
rdra.add_argument("--smooth-occ", default=0, type=float, help="Weight to smooth occlusion by.")
rdra.add_argument(
"--decay-all-learned-occ", type=float, default=0,
help="Weight to decay all learned occ by, attempting to minimize it",
)
rdra.add_argument(
"--all-learned-to-joint", action=ST,
help="Convert a fully learned occlusion model into one with an additional raycasting check"
)
lighta = a.add_argument_group("light")
lighta.add_argument(
"--light-kind", choices=list(light_kinds.keys()), default=None,
help="Kind of light to use while rendering. Dataset indicates light is in dataset",
)
lighta.add_argument(
"--light-intensity", type=int, default=100, help="Intensity of light to use with loaded dataset",
)
lighta.add_argument(
"--point-light-position", type=float, nargs="+", default=[0, 0, -3], help="Position of point light",
)
sdfa = a.add_argument_group("sdf")
sdfa.add_argument("--sdf-eikonal", help="Weight of SDF eikonal loss", type=float, default=0)
sdfa.add_argument("--surface-eikonal", help="Weight of SDF eikonal loss on surface", type=float, default=0)
# normal smoothing arguments
sdfa.add_argument("--smooth-normals", help="Amount to attempt to smooth normals", type=float, default=0)
sdfa.add_argument("--smooth-surface", help="Amount to attempt to smooth surface normals", type=float, default=0)
sdfa.add_argument(
"--smooth-eps", help="size of random uniform perturbation for smooth normals regularization",
type=float, default=1e-3,
)
sdfa.add_argument(
"--smooth-eps-rng", action=ST, help="Smooth by random amount instead of smoothing by a fixed distance",
)
sdfa.add_argument(
"--smooth-n-ord", nargs="+", default=[2], choices=[1,2], type=int,
help="Order of vector to use when smoothing normals",
)
sdfa.add_argument(
"--sdf-kind", help="Which SDF model to use", type=str,
choices=list(sdf.sdf_kinds.keys()), default="mlp",
)
sdfa.add_argument("--sphere-init", help="Initialize SDF to a sphere", action=ST)
sdfa.add_argument(
"--bound-sphere-rad", type=float, default=-1,
help="Intersect the learned SDF with a bounding sphere at the origin, < 0 is no sphere",
)
sdfa.add_argument(
"--sdf-isect-kind", choices=["sphere", "secant", "bisect"], default="bisect",
help="Marching kind to use when computing SDF intersection.",
)
sdfa.add_argument("--volsdf-scale-decay", type=float, default=0, help="Decay weight for volsdf scale")
dnerfa = a.add_argument_group("dnerf")
dnerfa.add_argument(
"--spline", type=int, default=0, help="Use spline estimator w/ given number of poitns for dynamic nerf delta prediction",
)
dnerfa.add_argument("--time-gamma", help="Apply a gamma based on time", action=ST)
dnerfa.add_argument("--with-canon", help="Preload a canonical NeRF", type=str, default=None)
dnerfa.add_argument("--fix-canon", help="Do not train canonical NeRF", action=ST)
dnerfa.add_argument(
"--render-over-time", default=-1, type=int,
help="Fix camera to i, and render over a time frame. < 0 is no camera",
)
dnerfa.add_argument(
"--render-over-time-steps", default=100, type=int, help="How many steps to render over time",
)
dnerfa.add_argument(
"--render-over-time-end-sec", default=1, type=float, help="Second to stop rendering"
)
cama = a.add_argument_group("camera parameters")
cama.add_argument("--near", help="near plane for camera", type=float, default=2)
cama.add_argument("--far", help="far plane for camera", type=float, default=6)
cama.add_argument("--cam-save-load", help="Location to save/load camera to", default=None)
vida = a.add_argument_group("Video parameters")
vida.add_argument("--start-sec", type=float, default=0, help="Start load time of video")
vida.add_argument("--end-sec", type=float, default=None, help="End load time of video")
vida.add_argument("--dyn-diverge-decay", type=float, default=0, help="Decay divergence of movement field")
vida.add_argument("--ffjord-div-decay", type=float, default=0, help="FFJORD divergence of movement field")
vida.add_argument(
"--delta-x-decay", type=float, default=0, help="How much decay for change in position for dyn",
)
vida.add_argument(
"--spline-len-decay", type=float, default=0, help="Weight for length of spline regularization"
)
vida.add_argument("--spline-pt0-decay", type=float, default=0, help="Add regularization to first point of spline")
vida.add_argument(
"--voxel-random-spline-len-decay", type=float, default=0,
help="Decay for length, randomly sampling a chunk of the grid instead of visible portions",
)
vida.add_argument(
"--random-spline-len-decay", type=float, default=0,
help="Decay for length, randomly sampling a bezier spline",
)
vida.add_argument(
"--voxel-tv-sigma",
type=float, default=0, help="Weight of total variation regularization for densitiy",
)
vida.add_argument(
"--voxel-tv-rgb",
type=float, default=0, help="Weight of total variation regularization for rgb",
)
vida.add_argument(
"--voxel-tv-bezier", type=float, default=0,
help="Weight of total variation regularization for bezier control points",
)
vida.add_argument(
"--voxel-tv-rigidity", type=float, default=0,
help="Weight of total variation regularization for rigidity",
)
vida.add_argument(
"--offset-decay", type=float, default=0,
help="Weight of total variation regularization for rigidity",
)
vida.add_argument(
"--dyn-refl-latent", type=int, default=0,
help="Size of latent vector to pass from the delta for reflectance",
)
vida.add_argument(
"--render-bezier-keyframes", action=ST, help="Render bezier control points for reference",
)
vida.add_argument(
"--cluster-movement", type=int, default=0,
help="attempts to visualize clusters of movement into k groups, 0 is off",
)
# Long videos
vida.add_argument(
"--long-vid-progressive-train", type=int, default=0,
help="Divide dataset into <N> chunks based on time, train each segment separately",
)
vida.add_argument(
"--long-vid-chunk-len-sec", type=float, default=3,
help="For a long video, how long should each chunk be in seconds",
)
vida.add_argument(
"--static-vid-cam-angle-deg", type=float, default=40,
help="Camera angle FOV in degrees, needed for static camera",
)
rprt = a.add_argument_group("reporting parameters")
rprt.add_argument("--name", help="Display name for convenience in log file", type=str, default="")
rprt.add_argument("-q", "--quiet", help="Silence tqdm", action=ST)
rprt.add_argument("--save", help="Where to save the model", type=str, default="models/model.pt")
rprt.add_argument("--save-load-opt", help="Save opt as well as model", action=ST)
rprt.add_argument("--log", help="Where to save log of arguments", type=str, default="log.json")
rprt.add_argument("--save-freq", help="# of epochs between saves", type=int, default=5000)
rprt.add_argument(
"--valid-freq", help="how often validation images are generated", type=int, default=500,
)
rprt.add_argument("--display-smoothness", action=ST, help="Display smoothness regularization")
rprt.add_argument("--nosave", help="do not save", action=ST)
rprt.add_argument("--load", help="model to load from", type=str)
rprt.add_argument("--loss-window", help="# epochs to smooth loss over", type=int, default=250)
rprt.add_argument("--notraintest", help="Do not test on training set", action=ST)
rprt.add_argument(
"--duration-sec", help="Max number of seconds to run this for, s <= 0 implies None",
type=float, default=0,
)
rprt.add_argument(
"--param-file", type=str, default=None, help="Path to JSON file to use for hyper-parameters",
)
rprt.add_argument("--skip-loss", type=int, default=0, help="Number of epochs to skip reporting loss for")
rprt.add_argument("--msssim-loss", action=ST, help="Report ms-ssim loss during testing")
rprt.add_argument("--depth-images", action=ST, help="Whether to render depth images")
rprt.add_argument("--normals-from-depth", action=ST, help="Render extra normal images from depth")
rprt.add_argument("--depth-query-normal", action=ST, help="Render extra normal images from depth")
rprt.add_argument("--plt-cmap-kind", type=str, choices=plt.colormaps(), default="magma")
rprt.add_argument("--gamma-correct", action=ST, help="Gamma correct final images")
rprt.add_argument("--render-frame", type=int, default=-1, help="Render 1 frame only, < 0 means none")
rprt.add_argument("--exp-bg", action=ST, help="Use mask of labels while rendering. For vis only")
rprt.add_argument("--test-white-bg", action=ST, help="Use white background while testing")
rprt.add_argument("--flow-map", action=ST, help="Render a flow map for a dynamic nerf scene")
rprt.add_argument("--rigidity-map", action=ST, help="Render a flow map for a dynamic nerf scene")
# TODO actually implement below
rprt.add_argument(
"--visualize", type=str, nargs="*", default=[],
choices=list(visualizations.keys()),
#["flow", "rigidity", "depth", "normals", "normals-at-depth"],
help="""\
Extra visualizations that can be rendered:
flow: 3d movement; color visualizes direction, intensity the amount of movement
rigidity: how rigid a given region is, higher-intensity is less rigid
depth: how far something is from the camera, darker is closer
normals: surface normals found by raymarching
normals-at-depth: surface normals queried one time at the termination depth of a ray
"""
)
rprt.add_argument(
"--display-regularization", action=ST,
help="Display regularization in addition to reconstruction loss",
)
rprt.add_argument(
"--y-scale", choices=["linear", "log", "symlog", "logit"], type=str,
default="linear", help="Scale kind for y-axis",
)
rprt.add_argument("--with-alpha", action=ST, help="Render images with an alpha channel")
# TODO add ability to show all regularization terms and how they change over time.
meta = a.add_argument_group("meta runner parameters")
# TODO when using torch jit has problems saving?
meta.add_argument("--torchjit", help="Use torch jit for model", action=ST)
meta.add_argument("--train-imgs", help="# training examples", type=int, default=-1)
meta.add_argument("--draw-colormap", help="Draw a colormap for each view", action=ST)
meta.add_argument(
"--convert-analytic-to-alt", action=ST,
help="Combine a model with an analytic BRDF with a learned BRDF for alternating optimization",
)
meta.add_argument("--clip-gradients", type=float, default=0, help="If > 0, clip gradients")
meta.add_argument("--versioned-save", action=ST, help="Save with versions")
meta.add_argument(
"--higher-end-chance", type=int, default=0,
help="Increase chance of training on either the start or the end",
)
meta.add_argument("--opt-step", type=int, default=1, help="Number of steps take before optimizing")
ae = a.add_argument_group("auto encoder parameters")
# TODO are these two items still used?
ae.add_argument("--latent-l2-weight", help="L2 regularize latent codes", type=float, default=0)
ae.add_argument("--normalize-latent", help="L2 normalize latent space", action=ST)
ae.add_argument("--encoding-size",help="Intermediate encoding size for AE",type=int,default=32)
opt = a.add_argument_group("optimization parameters")
opt.add_argument("--opt-kind", default="adam", choices=list(opt_kinds.keys()), help="What optimizer to use for training")
args = a.parse_args()
# runtime checks
hyper_config.load(args)
if args.timed_outdir:
now = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
args.outdir = os.path.join(args.outdir, f"{args.name}{'@' if args.name != '' else ''}{now}")
if not os.path.exists(args.outdir): os.mkdir(args.outdir)
if not args.neural_upsample:
args.render_size = args.size
args.feature_space = 3
plt.set_cmap(args.plt_cmap_kind)
assert(args.valid_freq > 0), "Must pass a valid frequency > 0"
if (args.test_crop_size <= 0): args.test_crop_size = args.crop_size
return args
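# Example invocation (illustrative only; the dataset path and flag values are placeholders):
#   python runner.py -d data/lego --data-kind original --size 64 --crop-size 16 \
#     --epochs 50000 --refl-kind view --save models/lego.pt --outdir outputs/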
opt_kinds = {
"adam": optim.Adam,
"sgd": optim.SGD,
"adamw": optim.AdamW,
"rmsprop": optim.RMSprop,
"uniform_adam": UniformAdam,
}
def load_optim(args, params):
cons = opt_kinds.get(args.opt_kind, None)
if cons is None: raise NotImplementedError(f"unknown opt kind {args.opt_kind}")
hyperparams = {
"lr": args.learning_rate,
# eps = 1e-7 was in the original paper.
"eps": 1e-7,
}
if args.opt_kind == "adam": hyperparams["weight_decay"] = args.decay
if args.opt_kind == "sgd": del hyperparams["eps"]
return cons(params, **hyperparams)
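# Example usage (sketch): opt = load_optim(args, model.parameters())
# Note --decay is only forwarded for plain adam; other optimizers keep their torch defaults.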
# Computes the difference of the fft of two images
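# Inputs are [..., H, W, C]; rfft2 over dims (-3, -2) transforms H and W, and
# norm="ortho" keeps the magnitude independent of image resolution.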
def fft_loss(x, ref):
got = torch.fft.rfft2(x, dim=(-3, -2), norm="ortho")
exp = torch.fft.rfft2(ref, dim=(-3, -2), norm="ortho")
return (got - exp).abs().mean()
# TODO add LPIPS?
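# ("rmse" below clamps the MSE at 1e-10 so the sqrt has a finite gradient at zero.)
# A possible sketch for the LPIPS TODO, assuming the third-party `lpips` package
# and channels-last images in [0, 1] (untested here):
#   import lpips
#   _lpips = lpips.LPIPS(net="alex")
#   loss_map["lpips"] = lambda x, ref: _lpips(
#     x.permute(0, 3, 1, 2) * 2 - 1, ref.permute(0, 3, 1, 2) * 2 - 1,
#   ).mean()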
loss_map = {
"l2": F.mse_loss,
"l1": F.l1_loss,
"rmse": lambda x, ref: F.mse_loss(x, ref).clamp(min=1e-10).sqrt(),
"fft": fft_loss,
"ssim": utils.ssim_loss,
}
color_fns = {
"hsv": utils.rgb2hsv,
"luminance": utils.rgb2luminance,
"xyz": utils.rgb2xyz,
}
# TODO better automatic device discovery here
device = "cpu"
if torch.cuda.is_available():
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# DEBUG
#torch.autograd.set_detect_anomaly(True); print("HAS DEBUG")
def render(
model, cam, crop,
# how big should the image be
size, args, times=None, with_noise=0.1,
):
ii, jj = torch.meshgrid(
torch.arange(size, device=device, dtype=torch.float),
torch.arange(size, device=device, dtype=torch.float),
indexing="ij",
)
positions = torch.stack([ii.transpose(-1, -2), jj.transpose(-1, -2)], dim=-1)
t,l,h,w = crop
positions = positions[t:t+h,l:l+w,:]
rays = cam.sample_positions(positions, size=size, with_noise=with_noise)
if times is not None: return model((rays, times)), rays
elif args.data_kind == "pixel-single": return model((rays, positions)), rays
return model(rays), rays
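# Example call (sketch; mirrors its use in train() below):
#   out, rays = render(model, cam[idxs], crop, size=args.render_size, times=ts, args=args)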
def depth_vis(model, args):
if not hasattr(model, "nerf"): return []
raw_depth = nerf.volumetric_integrate(model.nerf.weights, model.nerf.ts[:, None, None, None, None])
depth = ((raw_depth[0]-args.near)/(args.far - args.near)).clamp(min=0, max=1)
items = [depth]
if args.normals_from_depth:
depth_normal = (50*utils.depth_to_normals(depth)+1)/2
items.append(depth_normal.clamp(min=0, max=1))
return items
def flow_vis(model, args):
if not hasattr(model, "rigid_dp"): return []
flow_map = nerf.volumetric_integrate(model.nerf.weights, model.rigid_dp)[0]
flow_map /= flow_map.norm(dim=-1).max()
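# signed square root: compress flow magnitudes for display while keeping direction sign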
flow_map = flow_map.abs().sqrt().copysign(flow_map)
return [(flow_map +1)/2]
def rigidity_vis(model, args):
if not hasattr(model, "rigidity"): return []
rigidity_map = nerf.volumetric_integrate(model.nerf.weights, model.rigidity)[0]
return [rigidity_map]
# list of possible visualizations
visualizations = {
"depth": depth_vis,
"flow": flow_vis,
"rigidity": rigidity_vis,
}
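# e.g. --visualize depth flow appends those images to each validation save in train().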
def save_losses(args, losses):
outdir = args.outdir
window = args.loss_window
window = min(window, len(losses))
losses = np.convolve(losses, np.ones(window)/window, mode='valid')
losses = losses[args.skip_loss:]
plt.plot(range(len(losses)), losses)
plt.yscale(args.y_scale)
plt.savefig(os.path.join(outdir, "training_loss.png"), bbox_inches='tight')
plt.close()
def load_loss_fn(args, model):
if args.style_img is not None:
return StyleTransfer(load_image(args.style_img, resize=(args.size, args.size)))
# different losses like l1 or l2
loss_fns = [loss_map[lfn] for lfn in args.loss_fns]
assert(len(loss_fns) > 0), "must provide at least 1 loss function"
if len(loss_fns) == 1: loss_fn = loss_fns[0]
else:
def loss_fn(x, ref):
loss = 0
for fn in loss_fns: loss = loss + fn(x, ref)
return loss/len(loss_fns)
assert(len(args.color_spaces) > 0), "must provide at least 1 color space"
num_color_spaces = len(args.color_spaces)
# different colors like rgb, hsv
if num_color_spaces == 1 and args.color_spaces[0] == "rgb":
# do nothing since this is the default return value
...
elif num_color_spaces == 1:
cfn = color_fns[args.color_spaces[0]]
prev_loss_fn = loss_fn
loss_fn = lambda x, ref: prev_loss_fn(cfn(x), cfn(ref))
elif "rgb" in args.color_spaces:
prev_loss_fn = loss_fn
cfns = [color_fns[cs] for cs in args.color_spaces if cs != "rgb"]
def loss_fn(x, ref):
loss = prev_loss_fn(x, ref)
for cfn in cfns: loss = loss + prev_loss_fn(cfn(x), cfn(ref))
return loss/num_color_spaces
else:
prev_loss_fn = loss_fn
cfns = [color_fns[cs] for cs in args.color_spaces]
def loss_fn(x, ref):
loss = 0
for cfn in cfns: loss = loss + prev_loss_fn(cfn(x), cfn(ref))
return loss/num_color_spaces
if args.tone_map: loss_fn = utils.tone_map(loss_fn)
if args.gamma_correct_loss != 1.:
loss_fn = utils.gamma_correct_loss(loss_fn, args.gamma_correct_loss)
if args.volsdf_alternate:
return nerf.alternating_volsdf_loss(model, loss_fn, sdf.masked_loss(loss_fn))
if args.model == "sdf": loss_fn = sdf.masked_loss(loss_fn)
# if using a coarse-fine model, it is necessary to compute the loss on both the coarse and fine components.
if args.model == "coarse_fine":
prev_loss_fn = loss_fn
loss_fn = lambda x, ref: prev_loss_fn(model.coarse, ref) + prev_loss_fn(x, ref)
return loss_fn
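# e.g. --loss-fns l2 ssim --color-spaces rgb hsv averages both losses over both color spaces.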
def sqr(x): return x * x
# train the model with a given camera and some labels (imgs or imgs+times)
# light is a per instance light.
def train(model, cam, labels, opt, args, sched=None):
if args.epochs == 0: return
loss_fn = load_loss_fn(args, model)
iters = range(args.epochs) if args.quiet else trange(args.epochs)
update = lambda kwargs: iters.set_postfix(**kwargs)
if args.quiet: update = lambda _: None
times = None
if type(labels) is tuple:
times = labels[-1].to(device) # oops maybe pass this down from somewhere?
labels = labels[0]
batch_size = min(args.batch_size, labels.shape[0])
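# a crop is (top, left, height, width) within the render-size image (see render()).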
get_crop = lambda: (0,0, args.size, args.size)
cs = args.crop_size
if cs != 0:
get_crop = lambda: (
random.randint(0, args.render_size-cs), random.randint(0, args.render_size-cs), cs, cs,
)
train_choices = range(labels.shape[0])
if args.higher_end_chance > 0:
train_choices = list(train_choices)
train_choices += [0] * args.higher_end_chance
train_choices += [labels.shape[0]-1] * args.higher_end_chance
next_idxs = lambda _: random.sample(train_choices, batch_size)
if args.serial_idxs: next_idxs = lambda i: [i%len(cam)] * batch_size
#next_idxs = lambda i: [i%10] * batch_size # DEBUG
losses = []
start = time.time()
should_end = lambda: False
if args.duration_sec > 0: should_end = lambda: time.time() - start > args.duration_sec
train_percent = 1/args.epochs
opt.zero_grad()
for i in iters:
if should_end():
print("Training timed out")
break
curr_percent = train_percent * i
# exp_ratio rises exponentially from 1/100 at epoch 0 to 1 at the final epoch
exp_ratio = (1/100) ** (1-curr_percent)
idxs = next_idxs(i)
ts = None if times is None else times[idxs]
c0,c1,c2,c3 = crop = get_crop()
ref = labels[idxs][:, c0:c0+c2,c1:c1+c3, :3].to(device)
if getattr(model.refl, "light", None) is not None:
model.refl.light.set_idx(torch.tensor(idxs, device=device))
# Omit mostly-dark items with some likelihood. This is mainly used when
# attempting to focus on learning the refl and not the shape.
if args.omit_bg and (i % args.save_freq) != 0 and (i % args.valid_freq) != 0 and \
ref.mean() + 0.3 < sqr(random.random()): continue
out, rays = render(model, cam[idxs], crop, size=args.render_size, times=ts, args=args)
loss = loss_fn(out, ref)
assert(loss.isfinite()), f"Got {loss.item()} loss"
l2_loss = loss.item()
display = {
"l2": f"{l2_loss:.04f}",
"refresh": False,
}
if sched is not None: display["lr"] = f"{sched.get_last_lr()[0]:.1e}"
if args.latent_l2_weight > 0: loss = loss + model.nerf.latent_l2_loss * args.latent_l2_weight
pts = None
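# ((1<<13) * 5)//4 = 10240 points, drawn from a scaled normal to roughly cover the scene volume.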
get_pts = lambda: 5*(torch.randn(((1<<13) * 5)//4 , 3, device=device))
# prepare one set of points for either smoothing normals or eikonal.
if args.sdf_eikonal > 0 or args.smooth_normals > 0:
# NOTE the number of points just fits in memory, can modify it at will
pts = get_pts()
n = model.sdf.normals(pts)
# E[d sdf(x)/dx] = 1, enforces that the SDF is valid.
if args.sdf_eikonal > 0: loss = loss + args.sdf_eikonal * utils.eikonal_loss(n)
# E[div(change in x)] = 0, enforcing the change in motion does not compress space.
if args.dyn_diverge_decay > 0:
# TODO maybe this is wrong? Unsure
loss=loss+args.dyn_diverge_decay*utils.divergence(model.pts, model.dp).mean()
# approximation of divergence using ffjord algorithm as in NR-NeRF
if args.ffjord_div_decay:
div_approx = utils.div_approx(model.pts, model.rigid_dp).abs().square()
loss = loss + exp_ratio * args.ffjord_div_decay * (model.canonical.alpha.detach() * div_approx).mean()
if args.view_variance_decay > 0:
pts = pts if pts is not None else get_pts()
views = torch.randn(2, *pts.shape, device=device)
refl = model.refl(pts[None].repeat_interleave(2,dim=0), views)
loss = loss + args.view_variance_decay * F.mse_loss(refl[0], refl[1])
if args.volsdf_scale_decay > 0: loss = loss + args.volsdf_scale_decay * model.scale_post_act
# dn/dx -> 0, hopefully smooths out the local normals of the surface.
if args.smooth_normals > 0:
s_eps = args.smooth_eps
if s_eps > 0:
if args.smooth_eps_rng: s_eps = random.random() * s_eps
# epsilon-perturbation implementation from unisurf
perturb = F.normalize(torch.randn_like(pts), dim=-1) * s_eps
delta_n = n - model.sdf.normals(pts + perturb)
else:
delta_n = torch.autograd.grad(
inputs=pts, outputs=F.normalize(n, dim=-1), create_graph=True,
grad_outputs=torch.ones_like(n),
)[0]
smoothness = 0
for o in args.smooth_n_ord:
smoothness = smoothness + torch.linalg.norm(delta_n, ord=o, dim=-1).sum()
if args.display_smoothness: display["n-*"] = smoothness.item()
loss = loss + args.smooth_normals * smoothness
# smooth both the occlusion and the normals on the surface
if args.smooth_surface > 0:
model_ts = model.nerf.ts[:, None, None, None, None]
depth_region = nerf.volumetric_integrate(model.nerf.weights, model_ts)[0,...]
r_o, r_d = rays.split([3,3], dim=-1)
isect = r_o + r_d * depth_region
perturb = F.normalize(torch.randn_like(isect), dim=-1) * 1e-3
surface_normals = model.sdf.normals(isect)
delta_n = surface_normals - model.sdf.normals(isect + perturb)
smoothness = 0
for o in args.smooth_n_ord:
smoothness = smoothness + torch.linalg.norm(delta_n, ord=o, dim=-1).sum()
if args.display_smoothness: display["n-s"] = smoothness.item()
loss = loss + args.smooth_surface * smoothness
if args.surface_eikonal > 0: loss = loss + args.surface_eikonal * utils.eikonal_loss(surface_normals)
# smooth occ on the surface
if args.smooth_occ > 0 and args.smooth_surface > 0:
noise = torch.randn([*isect.shape[:-1], model.total_latent_size()], device=device)
elaz = dir_to_elev_azim(torch.randn_like(isect, requires_grad=False))
isect_elaz = torch.cat([isect, elaz], dim=-1)
att = model.occ.attenuation(isect_elaz, noise).sigmoid()
perturb = F.normalize(torch.randn_like(isect_elaz), dim=-1) * 5e-2
att_shifted = model.occ.attenuation(isect_elaz + perturb, noise)
loss = loss + args.smooth_surface * (att - att_shifted).abs().mean()
# smoothing the shadow, randomly over points and directions.
if args.smooth_occ > 0:
if pts is None:
pts = 5*(torch.randn(((1<<13) * 5)//4 , 3, device=device, requires_grad=True))
elaz = dir_to_elev_azim(torch.randn_like(pts, requires_grad=True))
pts_elaz = torch.cat([pts, elaz], dim=-1)
noise = torch.randn(pts.shape[0], model.total_latent_size(),device=device)
att = model.occ.attenuation(pts_elaz, noise).sigmoid()
perturb = F.normalize(torch.randn_like(pts_elaz), dim=-1) * 1e-2
att_shifted = model.occ.attenuation(pts_elaz + perturb, noise)
loss = loss + args.smooth_occ * (att - att_shifted).abs().mean()
if args.decay_all_learned_occ > 0:
loss = loss + args.decay_all_learned_occ * model.occ.all_learned_occ.raw_att.neg().mean()
if args.delta_x_decay > 0: loss = loss + args.delta_x_decay * model.dp.norm(dim=-1).mean()
# Apply voxel total variation losses
if args.voxel_tv_sigma > 0: loss = loss + args.voxel_tv_sigma * nerf.total_variation(model.densities)
if args.voxel_tv_rgb > 0: loss = loss + args.voxel_tv_rgb * nerf.total_variation(model.rgb)
if args.voxel_tv_bezier > 0: loss = loss + args.voxel_tv_bezier * nerf.total_variation(model.ctrl_pts_grid)
if args.voxel_tv_rigidity > 0: loss = loss + args.voxel_tv_rigidity * nerf.total_variation(model.rigidity_grid)
# apply offset loss as described in NR-NeRF
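# norm_dp = ||dp||^(2 - rigidity): the exponent interpolates between a quadratic penalty for
# non-rigid points and a (harsher near zero) linear one for rigid points; the 3e-3 * rigidity
# term discourages marking everything rigid.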
if args.offset_decay > 0:
norm_dp = torch.linalg.vector_norm(model.dp, dim=-1, keepdim=True)\
.pow(2 - model.rigidity)
reg = model.canonical.weights.detach()[None,...,None] * (norm_dp + 3e-3 * model.rigidity)
loss = loss + exp_ratio * args.offset_decay * reg.mean()
# apply regularization on spline length, to get smallest spline that fits.
# only apply this to visible points though
if args.spline_len_decay > 0:
arc_lens = nerf.arc_len(model.ctrl_pts)
w = model.canonical.weights.detach().squeeze(1)
loss = loss + args.spline_len_decay * (w * arc_lens).mean()
if args.spline_pt0_decay > 0 and hasattr(model, "init_p"):
loss = loss + args.spline_pt0_decay * torch.linalg.norm(model.init_p, dim=-1).mean()
# TODO is there any way to unify these two? Maybe provide a method to get a random sample on
# the class?
if args.voxel_random_spline_len_decay > 0:
x,y,z = nerf.random_sample_grid(model.ctrl_pts_grid, samples=16**3)
data = torch.stack(model.ctrl_pts_grid[x,y,z].split(3, dim=-1), dim=0)\
[:, :, None, None, None, :]
loss = loss + args.voxel_random_spline_len_decay * nerf.arc_len(data).mean()
if args.random_spline_len_decay > 0:
if pts is None: pts = 5*(torch.randn(((1<<13) * 5)//4 , 3, device=device, requires_grad=True))
pts = torch.stack(model.delta_estim(pts)[..., 1:].split(3,dim=-1), dim=0)\
[:, :, None, None, None, :]
loss = loss + args.random_spline_len_decay * nerf.arc_len(pts).mean()
# apply eikonal loss for rendered DynamicNeRF
if args.sdf_eikonal > 0 and isinstance(model, nerf.DynamicNeRF):
t = torch.rand(*pts.shape[:-1], 1, device=device)
dp = model.time_estim(pts, t)
n_dyn = model.sdf.normals(pts + dp)
loss = loss + args.sdf_eikonal * utils.eikonal_loss(n_dyn)
# --- Finished with applying any sort of regularization
if args.display_regularization: display["reg"] = f"{loss.item():.03f}"
update(display)
losses.append(l2_loss)
assert(loss.isfinite().item()), "Got NaN loss"
if args.opt_step != 1: loss = loss / args.opt_step
loss.backward()
if args.clip_gradients > 0: nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
if i % args.opt_step == 0:
opt.step()
opt.zero_grad()
if sched is not None: sched.step()
if args.inc_fourier_freqs:
for module in model.modules():
if not isinstance(module, FourierEncoder): continue
module.scale_freqs()
# Save outputs within the cropped region.
if i % args.valid_freq == 0:
with torch.no_grad():
ref0 = ref[0,...,:3]
items = [ref0, out[0,...,:3].clamp(min=0, max=1)]
if out.shape[-1] == 4:
items.append(ref[0,...,-1,None].expand_as(ref0))
items.append(out[0,...,-1,None].expand_as(ref0).sigmoid())
items = items + [img for vis in args.visualize for img in visualizations[vis](model, args)]
save_plot(os.path.join(args.outdir, f"valid_{i:05}.png"), *items)
if i % args.save_freq == 0 and i != 0:
version = (i // args.save_freq) if args.versioned_save else None
save(model, cam, args, opt, version)
save_losses(args, losses)
# final save does not have a version and will write to original file
save(model, cam, args, opt)
save_losses(args, losses)
def test(model, cam, labels, args, training: bool = True):
times = None
if type(labels) is tuple:
times = labels[-1].to(device)
labels = labels[0]
ls = []
gots = []
loss_strings = []
def render_test_set(model, cam, labels, offset=0):
with torch.no_grad():
for i in range(labels.shape[0]):
ts = None if times is None else times[i:i+1, ...]
exp = labels[i,...,:3]
got = torch.zeros_like(exp)
normals = torch.zeros_like(got)
depth = torch.zeros(*got.shape[:-1], 1, device=device, dtype=torch.float)
# RigNeRF visualization
if isinstance(model, nerf.RigNeRF): proximity_map = torch.zeros_like(exp)
# dynamic nerf visualization tools
flow_map = torch.zeros_like(normals)
rigidity_map = torch.zeros_like(depth)
if getattr(model.refl, "light", None) is not None:
model.refl.light.set_idx(torch.tensor([i], device=device))
if args.test_crop_size <= 0: args.test_crop_size = args.render_size
cs = args.test_crop_size
N = math.ceil(args.render_size/cs)
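# tile the full render into cs x cs crops so each forward pass stays within memory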
for x in range(N):
c0 = x * cs
for y in range(N):
c1 = y * cs
out, rays = render(
model, cam[i:i+1, ...], (c0,c1,cs,cs), size=args.render_size,
with_noise=False, times=ts, args=args,
)
out = out.squeeze(0)
got[c0:c0+cs, c1:c1+cs, :] = out
if hasattr(model, "nerf") and args.depth_images:
model_ts = model.nerf.ts[:, None, None, None, None]
depth[c0:c0+cs, c1:c1+cs, :] = \
nerf.volumetric_integrate(model.nerf.weights, model_ts)[0,...]
if hasattr(model, "n") and hasattr(model, "nerf") :
if args.depth_query_normal and args.depth_images:
r_o, r_d = rays.squeeze(0).split([3,3], dim=-1)
depth_region = depth[c0:c0+cs, c1:c1+cs]
isect = r_o + r_d * depth_region
normals[c0:c0+cs, c1:c1+cs] = (F.normalize(model.sdf.normals(isect), dim=-1)+1)/2
too_far_mask = depth_region > (args.far - 1e-1)
normals[c0:c0+cs, c1:c1+cs][too_far_mask[...,0]] = 0
else:
render_n = nerf.volumetric_integrate(model.nerf.weights, model.n)
normals[c0:c0+cs, c1:c1+cs, :] = (render_n[0]+1)/2
elif hasattr(model, "n") and hasattr(model, "sdf"):
...
if args.flow_map and hasattr(model, "rigid_dp"):
flow_map[c0:c0+cs,c1:c1+cs] = \
nerf.volumetric_integrate(model.nerf.weights, model.rigid_dp)
if args.rigidity_map and hasattr(model, "rigidity"):
rigidity_map[c0:c0+cs,c1:c1+cs] = \
nerf.volumetric_integrate(model.nerf.weights, model.rigidity)
if hasattr(model, "displace"):
proximity_map[c0:c0+cs,c1:c1+cs] = nerf.volumetric_integrate(
model.nerf.weights, model.displace.max(dim=-2)[0],
).clamp(min=0, max=1)
gots.append(got)
loss = F.mse_loss(got, exp)
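# PSNR for [0,1] images is -10 * log10(MSE); utils.mse2psnr is assumed to implement this standard conversion.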
psnr = utils.mse2psnr(loss).item()
ts = "" if ts is None else f",t={ts.item():.02f}"
o = i + offset
loss_string = f"[{o:03}{ts}]: L2 {loss.item():.03f} PSNR {psnr:.03f}"
print(loss_string)
loss_strings.append(loss_string)
name = f"train_{o:03}.png" if training else f"test_{o:03}.png"
if args.gamma_correct:
exp = exp.clamp(min=1e-10)**(1/2.2)
got = got.clamp(min=1e-10)**(1/2.2)
items = [exp, got.clamp(min=0, max=1)]
if hasattr(model, "n") and hasattr(model, "nerf"): items.append(normals.clamp(min=0, max=1))
if (depth != 0).any() and args.normals_from_depth:
depth_normals = (utils.depth_to_normals(depth * 100)+1)/2
items.append(depth_normals)
if hasattr(model, "nerf") and args.depth_images:
depth = (depth-args.near)/(args.far - args.near)
items.append(depth.clamp(min=0, max=1))
if args.flow_map and hasattr(model, "rigid_dp"):
flow_map /= flow_map.norm(keepdim=True, dim=-1).max()
flow_map = flow_map.abs().sqrt().copysign(flow_map)
items.append(flow_map.add(1).div(2))
if args.rigidity_map and hasattr(model, "rigidity"): items.append(rigidity_map)
if hasattr(model, "displace"): items.append(proximity_map)
if args.draw_colormap:
colormap = utils.color_map(cam[i:i+1])
items.append(colormap)
if args.exp_bg:
new_items = []
for item in items:
if item.shape[:-1] != labels.shape[1:-1]: new_items.append(item)
elif item.shape[-1] == 1: new_items.append(item * labels[i,...,3:])
else: new_items.append(torch.cat([item, labels[i,...,3:]], dim=-1))
items = new_items
save_plot(os.path.join(args.outdir, name), *items)
ls.append(psnr)
rf = args.render_frame
if args.render_frame >= 0:
if hasattr(model.refl, "light"): model.refl.light.set_idx(rf)
return render_test_set(model, cam[rf:rf+1], labels[rf:rf+1], offset=rf)
render_test_set(model, cam, labels)
# also render the multi-point-light dataset; it has to be loaded separately because it
# uses a slightly different light formulation.
if args.data_kind == "nerv_point" and args.has_multi_light:
multi_labels, multi_cams, multi_lights = loaders.nerv_point(
args.data, training=False, size=args.size,
light_intensity=args.light_intensity,
with_mask=False, multi_point=True, device=device,
)
model.refl.lights = multi_lights
render_test_set(model, multi_cams, multi_labels, offset=100)
labels = torch.cat([labels, multi_labels], dim=0)
summary_string = f"""[Summary {args.name} ({"training" if training else "test"}) @ {git_hash()}]:
\tmean {np.mean(ls):.03f}
\tmedian {np.median(ls):.03f}
\tmin {min(ls):.03f}
\tmax {max(ls):.03f}
\tvar {np.var(ls):.03f}"""
if args.msssim_loss:
try:
with torch.no_grad():
msssim = utils.msssim_loss(gots, labels)
summary_string += f"\nms-ssim {msssim:.03f}"
except Exception as e: print(f"msssim failed: {e}")
print(summary_string)
with open(os.path.join(args.outdir, "results.txt"), 'w') as f:
f.write(summary_string)
for ls in loss_strings:
f.write("\n")
f.write(ls)
def render_over_time(args, model, cam):
cam = cam[args.render_over_time:args.render_over_time+1]
ts = torch.linspace(0, args.render_over_time_end_sec, steps=args.render_over_time_steps, device=device)