Commit 84e45c2
Merge branch 'fabiofelix:main' into main
fabiofelix authored Sep 19, 2024
2 parents: 47e8095 + ae6656c
Showing 4 changed files with 127 additions and 20 deletions.
30 changes: 28 additions & 2 deletions act_recog/config/defaults.py
@@ -2,6 +2,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 
 """Configs."""
+import os
+import pathlib
 from fvcore.common.config import CfgNode
 
 _C = CfgNode()
@@ -46,7 +48,7 @@ def get_cfg():
     """
     Get a copy of the default config.
     """
-    return _C
+    return _C.clone()
 
 def load_config(args):
     """
@@ -58,10 +60,34 @@ def load_config(args):
     # Setup cfg.
     cfg = get_cfg()
     # Load config from cfg.
+    if isinstance(args, (str, pathlib.Path)):
+        args = args_hook(args)
     if args.cfg_file is not None:
-        cfg.merge_from_file(args.cfg_file)
+        cfg.merge_from_file(find_config_file(args.cfg_file))
     # Load config from command line, overwrite config from opts.
     if args.opts is not None:
         cfg.merge_from_list(args.opts)
 
     return cfg
+
+# get built-in configs from the repo-level config directory
+CONFIG_DIR = pathlib.Path(__file__).parent.parent.parent / 'config'
+
+def find_config_file(cfg_file):
+    cfg_files = [
+        cfg_file,                         # you passed a valid config file path
+        CONFIG_DIR / cfg_file,            # a path relative to the config directory
+        CONFIG_DIR / f'{cfg_file}.yaml',  # the name without the extension
+        CONFIG_DIR / f'{cfg_file}.yml',
+    ]
+    for f in cfg_files:
+        if os.path.isfile(f):
+            return f
+    raise FileNotFoundError(cfg_file)
+
+
+def args_hook(cfg_file):
+    args = lambda: None
+    args.cfg_file = cfg_file
+    args.opts = None
+    return args
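
With these helpers in place, load_config now accepts a bare config name or pathlib.Path in addition to the usual argparse-style object. A minimal usage sketch (the name 'OMNIVORE' and the explicit path are illustrative assumptions, not part of the commit):

    from act_recog.config.defaults import load_config

    cfg = load_config('OMNIVORE')                 # resolved by find_config_file to <repo>/config/OMNIVORE.yaml
    cfg = load_config('/tmp/my_experiment.yaml')  # an explicit file path is used as-is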
39 changes: 36 additions & 3 deletions step_recog/config/defaults.py
@@ -2,8 +2,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 
 """Configs."""
+import os
+import pathlib
 from fvcore.common.config import CfgNode
 
+
 _C = CfgNode()
 
 # -----------------------------------------------------------------------------
@@ -51,9 +54,12 @@
 
 _C.MODEL.YOLO_CHECKPOINT_URL = ''
 _C.MODEL.OMNIGRU_CHECKPOINT_URL = ''
+_C.MODEL.PRETRAINED_CHECKPOINT_URL = ''
 _C.MODEL.OMNIVORE_CONFIG = 'OMNIVORE'
 _C.MODEL.SLOWFAST_CONFIG = 'SLOWFAST'
 
+_C.MODEL.VARIANTS = CfgNode(new_allowed=True)
+
 # -----------------------------------------------------------------------------
 # Dataset options
 # -----------------------------------------------------------------------------
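
Side note: MODEL.VARIANTS is declared with new_allowed=True, so merged YAML files can introduce arbitrary sub-keys under it without tripping the unknown-key check. A small sketch of that behavior (the key names are made up, and this assumes fvcore's yacs-style merge semantics):

    from fvcore.common.config import CfgNode

    variants = CfgNode(new_allowed=True)
    # a brand-new key is accepted during the merge because new_allowed=True
    variants.merge_from_other_cfg(CfgNode({'small': {'HIDDEN_SIZE': 128}}))
    print(variants.small.HIDDEN_SIZE)  # -> 128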
@@ -94,7 +100,7 @@ def get_cfg():
     """
     Get a copy of the default config.
     """
-    return _C
+    return _C.clone()
 
 def load_config(args):
     """
@@ -106,10 +112,37 @@ def load_config(args):
     # Setup cfg.
     cfg = get_cfg()
     # Load config from cfg.
+    if isinstance(args, (str, pathlib.Path)):
+        args = args_hook(args)
     if args.cfg_file is not None:
-        cfg.merge_from_file(args.cfg_file)
+        #cfg.merge_from_file(args.cfg_file)
+        cfg.merge_from_file(find_config_file(args.cfg_file))
     # Load config from command line, overwrite config from opts.
     if args.opts is not None:
         cfg.merge_from_list(args.opts)
 
-    return cfg
\ No newline at end of file
+    return cfg
+
+# get built-in configs from the repo-level config directory
+CONFIG_DIR = pathlib.Path(__file__).parent.parent.parent / 'config'
+
+
+def find_config_file(cfg_file):
+    cfg_files = [
+        cfg_file,                         # you passed a valid config file path
+        CONFIG_DIR / cfg_file,            # a path relative to the config directory
+        CONFIG_DIR / f'{cfg_file}.yaml',  # the name without the extension
+        CONFIG_DIR / f'{cfg_file}.yml',
+    ]
+    for f in cfg_files:
+        if os.path.isfile(f):
+            return f
+    raise FileNotFoundError(cfg_file)
+
+
+def args_hook(cfg_file):
+    args = lambda: None
+    args.cfg_file = cfg_file
+    args.opts = None
+    return args
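
One note on args_hook: args = lambda: None is a terse way to get an object that tolerates attribute assignment, mimicking an argparse.Namespace. An equivalent, arguably clearer spelling (a sketch only, not part of the commit):

    import types

    def args_hook(cfg_file):
        # same duck type load_config expects: an object with .cfg_file and .opts
        return types.SimpleNamespace(cfg_file=cfg_file, opts=None)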

74 changes: 60 additions & 14 deletions step_recog/full/model.py
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+import functools
 from torch import nn
 from collections import deque
 from ultralytics import YOLO
@@ -31,14 +32,22 @@ def build_model(cfg_file, fps):
 
     return MODEL_CLASS(cfg_file, fps).to("cuda")
 
+
+@functools.lru_cache(1)
+def get_omnivore(cfg_fname):
+    omni_cfg = act_load_config(args_hook(cfg_fname))
+    omnivore = Omnivore(omni_cfg, resize = False)
+    return omnivore, omni_cfg
+
+
 class StepPredictor(nn.Module):
     """Step prediction model that takes in frames and outputs step probabilities.
     """
     def __init__(self, cfg_file, video_fps = 30):
         super().__init__()
-        # load config
         self._device = nn.Parameter(torch.empty(0))
-        self.cfg = load_config(args_hook(cfg_file))
+        # load config
+        self.cfg = load_config(args_hook(cfg_file)).clone() # clone prob not necessary but tinfoil
 
         # assign vocabulary
         self.STEPS = np.array([
@@ -74,34 +83,60 @@ def forward(self, image, queue_frame = True):
 class StepPredictor_GRU(StepPredictor):
     def __init__(self, cfg_file, video_fps = 30):
         super().__init__(cfg_file, video_fps)
-        self.omni_cfg = act_load_config(args_hook(self.cfg.MODEL.OMNIVORE_CONFIG))
+        # self.omni_cfg = act_load_config(args_hook(self.cfg.MODEL.OMNIVORE_CONFIG))
 
         self.MAX_OBJECTS = 25
-        self.transform = transforms.Compose([
-            transforms.Resize(self.omni_cfg.MODEL.IN_SIZE),
-            transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE)
-        ])
+        # self.transform = transforms.Compose([
+        #     transforms.Resize(self.omni_cfg.MODEL.IN_SIZE),
+        #     transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE)
+        # ])
 
         # build model
         self.head = OmniGRU(self.cfg, load=True)
         self.head.eval()
+        frame_queue_len = 1
         if self.cfg.MODEL.USE_ACTION:
-            self.omnivore = Omnivore(self.omni_cfg, resize = False)
+            omnivore, omni_cfg = get_omnivore(self.cfg.MODEL.OMNIVORE_CONFIG)
+            self.omnivore = omnivore
+            self.omni_cfg = omni_cfg
-            frame_queue_len = self.omni_cfg.DATASET.FPS * self.omni_cfg.MODEL.WIN_LENGTH
+            frame_queue_len = video_fps * self.omni_cfg.MODEL.WIN_LENGTH #default: 2seconds
+            self.transform = transforms.Compose([
+                transforms.Resize(self.omni_cfg.MODEL.IN_SIZE),
+                transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE)
+            ])
+            #self.omnivore = Omnivore(self.omni_cfg, resize = False)
         if self.cfg.MODEL.USE_OBJECTS:
             yolo_checkpoint = cached_download_file(self.cfg.MODEL.YOLO_CHECKPOINT_URL)
            self.yolo = YOLO(yolo_checkpoint)
             self.yolo.eval = lambda *a: None
             self.clip_patches = ClipPatches(utils.clip_download_root)
             self.clip_patches.eval()
+            names = self.yolo.names
+            self.OBJECT_LABELS = np.array([str(names.get(i, i)) for i in range(len(names))])
+        else:
+            self.OBJECT_LABELS = np.array([], dtype=str)
         if self.cfg.MODEL.USE_AUDIO:
             raise NotImplementedError()
 
         # frame buffers and model state
-        self.create_queue(video_fps * self.omni_cfg.MODEL.WIN_LENGTH) #default: 2seconds
+        self.frame_queue_len = frame_queue_len
+        self.create_queue(frame_queue_len) #default: 2seconds
         self.h = None
 
+
+    def eval(self):
+        y = self.yolo
+        self.yolo = None
+        super().eval()
+        self.head.eval()
+        self.omnivore.eval()
+        self.yolo = y
+        return self
+
     def reset(self):
-        super().__init__()
+        #super().__init__()
+        super().reset()
         self.h = None
 
     def queue_frame(self, image):
@@ -115,7 +150,7 @@ def queue_frame(self, image):
     def prepare(self, im):
         return self.transform(Image.fromarray(im))
 
-    def forward(self, image, queue_frame = True):
+    def forward(self, image, queue_frame = True, return_objects=False):
         # compute yolo
         Z_objects, Z_frame = torch.zeros((1, 1, 25, 0)).float(), torch.zeros((1, 1, 1, 0)).float()
         if self.cfg.MODEL.USE_OBJECTS:
@@ -145,6 +180,7 @@ def forward(self, image, queue_frame = True):
             self.queue_frame(image)
 
         # compute omnivore embeddings
+        # [1, 32, 3, H, W]
         X_omnivore = torch.stack(list(self.input_queue), dim=1)[None]
         frame_idx = np.linspace(0, self.input_queue.maxlen - 1, self.omni_cfg.MODEL.NFRAMES).astype('long') #same as act_recog.dataset.milly.py:pack_frames_to_video_clip
         X_omnivore = X_omnivore[:, :, frame_idx, :, :]
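
For reference, the np.linspace line above subsamples the frame queue down to the Omnivore input length using evenly spaced indices, endpoints included. A small sketch with assumed values (a 60-frame queue, i.e. 2 s at 30 fps, and NFRAMES = 32):

    import numpy as np

    maxlen, nframes = 60, 32  # assumed queue length and Omnivore frame count
    frame_idx = np.linspace(0, maxlen - 1, nframes).astype('long')
    # 32 evenly spaced indices in [0, 59]; the first and last frames are always kept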
@@ -154,9 +190,19 @@ def forward(self, image, queue_frame = True):
         # mix it all together
         if self.h is None:
             self.h = self.head.init_hidden(Z_action.shape[0])
-        prob_step, self.h = self.head(Z_action.to(self._device.device), self.h.float(), Z_audio.to(self._device.device), Z_objects.to(self._device.device), Z_frame.to(self._device.device))
+
+        device = self._device.device
+        prob_step, self.h = self.head(
+            Z_action.to(device),
+            self.h.float(),
+            Z_audio.to(device),
+            Z_objects.to(device),
+            Z_frame.to(device))
+
         prob_step = torch.softmax(prob_step[..., :-2].detach(), dim=-1) #prob_step has <n class positions> <1 no-step position> <2 begin-end frame identifiers>
+
+        if return_objects:
+            return prob_step, results
         return prob_step
 
 class StepPredictor_Transformer(StepPredictor):
@@ -179,4 +225,4 @@ def forward(self, image, queue_frame = True):
         prob_step = self.head(image.to(self._device.device), self.steps_feat.to(self._device.device))
         prob_step = torch.softmax(prob_step.detach(), dim = -1)
 
-        return prob_step
\ No newline at end of file
+        return prob_step
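
On the output layout: the GRU head emits one logit per step class, one no-step logit, and two begin/end frame identifiers, and the softmax above drops the last two before normalizing. A hedged sketch of that slicing, with an assumed vocabulary of N = 10 steps:

    import torch

    N = 10                                  # assumed number of step classes
    logits = torch.randn(1, N + 3)          # N steps + no-step + begin/end identifiers
    probs = torch.softmax(logits[..., :-2], dim=-1)  # keep N steps + no-step only
    assert probs.shape == (1, N + 1)
    assert torch.isclose(probs.sum(), torch.tensor(1.0))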
4 changes: 3 additions & 1 deletion step_recog/models.py
@@ -4,6 +4,7 @@
 
 import torch
 from collections import OrderedDict
+from step_recog.full.download import cached_download_file
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -59,7 +60,8 @@ def __init__(self, cfg, load = False):
         self.relu = torch.nn.ReLU()
 
         if load:
-            self.load_state_dict( self.update_version(torch.load( cfg.MODEL.OMNIGRU_CHECKPOINT_URL )))
+            f = cfg.MODEL.OMNIGRU_CHECKPOINT_URL or cached_download_file(cfg.MODEL.PRETRAINED_CHECKPOINT_URL)
+            self.load_state_dict(self.update_version(torch.load(f)))
         else:
             self.apply(custom_weights)
