Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

includes warping that works in 2d setting #134

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.*
.DS_Store
*.sw*
.ipynb_checkpoints
Expand All @@ -13,4 +14,5 @@ _*/
*.egg-info
build
imgui.ini
wandb/
wandb/
**/data/**
652 changes: 652 additions & 0 deletions Preprocessing.ipynb

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions app/nerf/configs/nerf_base2d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

logging:
exp_name: 'test-nerf'

dataset:
multiview_dataset_format: 'standard'
num_rays_sampled_per_img: 4096
# mip: 2
bg_color: 'white'

nef:
hidden_dim: 64
num_layers: 1
warp: 'mlp'
pos_embedder: 'none' #'positional' #
warp_pos_embedder: 'none' #'positional'
position_input: False
view_embedder: 'positional'
view_multires: 4
pos_multires: 8

tracer:
raymarch_type: '2d'
num_steps: 1

optimizer:
optimizer_type: 'rmsprop'
lr: 0.00001

trainer:
epochs: 750
batch_size: 1
model_format: 'full' # If loading a pretrained pipeline, 'full' = torch.load instead of torch.load_state_dict
valid_every: 50
save_every: 10
render_tb_every: 2 #10

# NOTE: These are OfflineRenderer definitions, used for validation. See WispState for interactive app definitions.
renderer:
render_batch: -1 #4000
camera_origin:
- -3.0
- 0.65
- -3.0
render_res:
- 800 #64 #600 #1024 #
- 800 #64 #800 #1024 #
26 changes: 26 additions & 0 deletions app/nerf/configs/nerf_hash2d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

parent: 'nerf_base2d.yaml'

grid:
grid_type: 'HashEmbedder' #'HashGrid' #
interpolation_type: 'linear'
multiscale_type: 'cat'
feature_dim: 2
feature_std: 0.01
feature_bias: 0.0
num_lods: 16
codebook_bitwidth: 19
tree_type: 'geometric'
min_grid_res: 16
max_grid_res: 2048
blas_level: 7

trainer:
prune_every: -1 #100 # iterations
98 changes: 82 additions & 16 deletions app/nerf/main_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import sys
sys.path.append('/scratch/soft/anaconda3/envs/wisp/lib/python3.9/site-packages/torch/lib')

import time
start_time = time.time()

import os
import argparse
Expand All @@ -17,9 +22,9 @@
import wisp.config_parser as config_parser
from wisp.framework import WispState
from wisp.datasets import MultiviewDataset, SampleRays
from wisp.models.grids import BLASGrid, OctreeGrid, CodebookOctreeGrid, TriplanarGrid, HashGrid
from wisp.models.grids import BLASGrid, OctreeGrid, CodebookOctreeGrid, TriplanarGrid, HashGrid, HashEmbedder
from wisp.tracers import BaseTracer, PackedRFTracer
from wisp.models.nefs import BaseNeuralField, NeuralRadianceField
from wisp.models.nefs import BaseNeuralField, NeuralRadianceField, NeuralRadianceField2d, DeformationField
from wisp.models.pipeline import Pipeline
from wisp.trainers import BaseTrainer, MultiviewTrainer

Expand Down Expand Up @@ -105,10 +110,18 @@ def parse_args():
'used to track the occupancy status (bottom level acceleration structure).')

nef_group = parser.add_argument_group('nef')
nef_group.add_argument('--warp', type=str, help='deformation field architecture.')
# nef_group.add_argument('--warp-arch', type=str, choices=['none', 'grid', 'mlp'],
# default='none', help='deformation field architecture.')
# nef_group.add_argument('--warp-type', type=str, choices=['transformation', 'se2', 'se3'],
# default='transformation', help='deformation field type.')
nef_group.add_argument('--pos-embedder', type=str, choices=['none', 'identity', 'positional'],
default='positional',
help='MLP Decoder of neural field: Positional embedder used to encode input coordinates'
'or view directions.')
nef_group.add_argument('--warp-pos-embedder', type=str, choices=['none', 'identity', 'positional'],
default='positional',
help='MLP Decoder of deformation field: Positional embedder used to encode input coordinates')
nef_group.add_argument('--view-embedder', type=str, choices=['none', 'identity', 'positional'],
default='positional',
help='MLP Decoder of neural field: Positional embedder used to encode view direction')
Expand Down Expand Up @@ -254,6 +267,7 @@ def load_dataset(args) -> MultiviewDataset:
split='train',
mip=args.mip,
bg_color=args.bg_color,
warp=args.warp,
dataset_num_workers=args.dataset_num_workers,
transform=transform)
validation_dataset = None
Expand Down Expand Up @@ -356,6 +370,9 @@ def load_grid(args, dataset: MultiviewDataset) -> BLASGrid:
codebook_bitwidth=args.codebook_bitwidth,
blas_level=args.blas_level
)
elif args.grid_type == "HashEmbedder":
bounding_box = torch.tensor([[-1,-1,-1],[1, 1, 1]]).to(device)
grid = HashEmbedder(bounding_box) #TODO: add args later !!
else:
raise ValueError(f"Unknown grid_type argument: {args.grid_type}")
return grid
Expand All @@ -367,23 +384,69 @@ def load_neural_field(args, dataset: MultiviewDataset) -> BaseNeuralField:
The NeuralRadianceField uses spatial feature grids internally for faster feature interpolation and raymarching.
"""
grid = load_grid(args=args, dataset=dataset)
nef = NeuralRadianceField(
grid=grid,
pos_embedder=args.pos_embedder,
view_embedder=args.view_embedder,
position_input=args.position_input,
pos_multires=args.pos_multires,
view_multires=args.view_multires,
activation_type=args.activation_type,
layer_type=args.layer_type,
hidden_dim=args.hidden_dim,
num_layers=args.num_layers,
prune_density_decay=args.prune_density_decay, # Used only for grid types which support pruning
prune_min_density=args.prune_min_density # Used only for grid types which support pruning
)
warpgrid = load_grid(args=args, dataset=dataset)
if args.raymarch_type == '2d':
nef = NeuralRadianceField2d(
grid=grid,
warpgrid=warpgrid,
warp=args.warp,
pos_embedder=args.pos_embedder,
pos_multires=args.pos_multires,
activation_type=args.activation_type,
layer_type=args.layer_type,
hidden_dim=args.hidden_dim,
num_layers=args.num_layers,
prune_density_decay=args.prune_density_decay, # Used only for grid types which support pruning
prune_min_density=args.prune_min_density # Used only for grid types which support pruning
)
else:

nef = NeuralRadianceField(
grid=grid,
pos_embedder=args.pos_embedder,
view_embedder=args.view_embedder,
position_input=args.position_input,
pos_multires=args.pos_multires,
view_multires=args.view_multires,
activation_type=args.activation_type,
layer_type=args.layer_type,
hidden_dim=args.hidden_dim,
num_layers=args.num_layers,
prune_density_decay=args.prune_density_decay, # Used only for grid types which support pruning
prune_min_density=args.prune_min_density # Used only for grid types which support pruning
)
return nef


# def load_warp_field(args, dataset: MultiviewDataset) -> BaseNeuralField:
# """ Creates a "Neural Field" instance which warps input coordinates.
# Here a DeformationField is created, which maps coordinates to coordinates in a canonical space
# for scenes with deforamation.
# The DeformationField can use spatial feature grids internally for faster feature interpolation and raymarching.
# """
# dim = 3
# grid = None
# if args.raymarch_type == '2d':
# dim = 2
# if args.warp_arch == 'grid':
# grid = load_grid(args=args, dataset=dataset)
# dnef = DeformationField(
# input_dim = dim,
# warp_arch = args.warp_arch,
# warp_type = args.warp_type,
# grid=grid,
# pos_embedder=args.warp_pos_embedder,
# pos_multires=args.pos_multires,
# activation_type=args.activation_type,
# layer_type=args.layer_type,
# hidden_dim=args.hidden_dim,
# num_layers=args.num_layers,
# prune_density_decay=args.prune_density_decay, # Used only for grid types which support pruning
# prune_min_density=args.prune_min_density # Used only for grid types which support pruning
# )
# return dnef


def load_tracer(args) -> BaseTracer:
""" Wisp "Tracers" are responsible for taking input rays, marching them through the neural field to render
an output RenderBuffer.
Expand All @@ -406,8 +469,10 @@ def load_neural_pipeline(args, dataset, device) -> Pipeline:
""" In Wisp, a Pipeline comprises of a neural field + a tracer (the latter is optional in some cases).
Together, they form the complete pipeline required to render a neural primitive from input rays / coordinates.
"""
# dnef = load_warp_field(args=args, dataset=dataset)
nef = load_neural_field(args=args, dataset=dataset)
tracer = load_tracer(args=args)
# pipeline = Pipeline(dnef=dnef, nef=nef, tracer=tracer)
pipeline = Pipeline(nef=nef, tracer=tracer)
if args.pretrained:
if args.model_format == "full":
Expand Down Expand Up @@ -497,3 +562,4 @@ def is_interactive() -> bool:
trainer.validate()
else:
trainer.train()
print("--- %s seconds ---" % (time.time() - start_time))
Loading