Skip to content

Commit

Permalink
added some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
adityak77 committed Dec 8, 2021
1 parent 8b3c271 commit 870debe
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deperceiver/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def forward(self, x):

class BackboneBase(pl.LightningModule):

# layer name and channel size based on downsample factor
DOWNSAMPLE_DICT = {
32: ('layer4', 2048),
16: ('layer3', 1024),
Expand All @@ -78,6 +79,8 @@ def __init__(
for name, parameter in backbone.named_parameters():
if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
parameter.requires_grad_(False)

# choose layers of backbone to return based on downsampling factor and multiscale version
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
elif multiscale:
Expand All @@ -87,6 +90,7 @@ def __init__(

self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

# record number of channels based on downsample factor and multiscale version
if multiscale:
self.num_channels = [self.DOWNSAMPLE_DICT[downsample_factor][1] for downsample_factor in [8, 16, 32]]
else:
Expand Down Expand Up @@ -122,6 +126,7 @@ def __init__(
super().__init__(backbone, num_channels, train_backbone=train_backbone, return_interm_layers=return_interm_layers, downsample_factor=downsample_factor, multiscale=multiscale)


# Merge backbone features and position encodings
class Joiner(pl.LightningModule):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions deperceiver/models/naive_deperceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ def forward(self, samples: DETRInput) -> Dict[str, Any]:
else:
srcs = []
masks = []
# generate inputs for all scales
for feature in features:
src, mask_scale = feature.decompose()
assert mask_scale is not None
srcs.append(src)
masks.append(mask_scale)

# combine projections and positional encodings inputs for all the scales
multiscale_inputs = []
mask_all_scales = []
for i, model in enumerate(self.input_proj):
Expand Down
2 changes: 2 additions & 0 deletions main_naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def main(args):

losses = ['labels', 'boxes', 'cardinality']

# Use DETR set criterion
criterion = SetCriterion(91, matcher=matcher, weight_dict=weight_dict, eos_coef=args.eos_coef, losses=losses)

model = NaiveDePerceiver(
Expand All @@ -167,6 +168,7 @@ def main(args):

datamodule = CocoDataModule(args)

# logging for wandb visualizations
lr_monitor = LearningRateMonitor()
wandb_logger = WandbLogger(
name=args.run_name,
Expand Down

0 comments on commit 870debe

Please sign in to comment.