diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2defece --- /dev/null +++ b/.gitignore @@ -0,0 +1,179 @@ +.vscode +.hydra +inputs +outputs + +# All file or folders start with tmp will be ignored +tmp* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# +.DS_Store/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# torchsparse +torchsparse + +# tensorboard +tensorboard + +# glove +glove \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..901e92c --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third-party/DPVO"] + path = third-party/DPVO + url = https://github.com/princeton-vl/DPVO.git diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f12f3de --- /dev/null +++ b/LICENSE @@ -0,0 +1,16 @@ +Copyright 2022-2023 3D Vision Group at the State Key Lab of CAD&CG, +Zhejiang University. All Rights Reserved. + +For more information see +If you use this software, please cite the corresponding publications +listed on the above website. + +Permission to use, copy, modify and distribute this software and its +documentation for educational, research and non-profit purposes only. +Any modification based on this work must be open-source and prohibited +for commercial use. +You must retain, in the source form of any derivative works that you +distribute, all copyright, patent, trademark, and attribution notices +from the source form of this work. + +For commercial uses of this software, please send email to xwzhou@zju.edu.cn \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..d185783 --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# World-Grounded Human Motion Recovery via Gravity-View Coordinates +### [Project Page](https://zju3dv.github.io/gvhmr) | [Paper](https://arxiv.org/pdf/xxxx.xxxxx.pdf) + +> World-Grounded Human Motion Recovery via Gravity-View Coordinates +> [Zehong Shen](https://zehongs.github.io/)\*, +[Huaijin Pi](https://phj128.github.io/)\*, +[Yan Xia](https://isshikihugh.github.io/scholar), +[Zhi Cen](https://scholar.google.com/citations?user=Xyy-uFMAAAAJ), +[Sida Peng](https://pengsida.net/), +[Zechen Hu](https://zju3dv.github.io/gvhmr), +[Hujun Bao](http://www.cad.zju.edu.cn/home/bao/), +[Ruizhen Hu](https://csse.szu.edu.cn/staff/ruizhenhu/), +[Xiaowei Zhou](https://xzhou.me/) +> Siggraph Asia 2024 + +

+  <img src="docs/example_video/project_teaser.gif" alt="animated" />

+ +## TODO List and ETA +- [x] Code for reproducing the train and test results (2024-8-5) +- [x] Demo code (2024-8-5) +- [x] Project page, Repository README (2024-9-4) +- [ ] Arxiv paper link (~2024-9) + + +## Setup + +Please see [installation](docs/INSTALL.md) for details. + +## Quick Start + +### Demo +Demo entries are provided in `tools/demo`. Use `-s` to skip visual odometry if you know the camera is static, otherwise the camera will be estimated by DPVO. +We also provide a script `demo_folder.py` to inference a entire folder. +```shell +python tools/demo/demo.py --video=docs/example_video/tennis.mp4 -s +python tools/demo/demo_folder.py -f inputs/demo/folder_in -d outputs/demo/folder_out -s +``` + +### Reproduce +1. **Test**: +To reproduce the 3DPW, RICH, and EMDB results in a single run, use the following command: + ```shell + python tools/train.py global/task=gvhmr/test_3dpw_emdb_rich exp=gvhmr/mixed/mixed ckpt_path=inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt + ``` + To test individual datasets, change `global/task` to `gvhmr/test_3dpw`, `gvhmr/test_rich`, or `gvhmr/test_emdb`. + +2. **Train**: +To train the model, use the following command: + ```shell + # The gvhmr_siga24_release.ckpt is trained with 2x4090 for 420 epochs, note that different GPU settings may lead to different results. + python tools/train.py exp=gvhmr/mixed/mixed + ``` + During training, note that we do not employ post-processing as in the test script, so the global metrics results will differ (but should still be good for comparison with baseline methods). + + +# Acknowledgement + +We thank the authors of +[WHAM](https://github.com/yohanshin/WHAM), +[4D-Humans](https://github.com/shubham-goel/4D-Humans), +and [ViTPose-Pytorch](https://github.com/gpastal24/ViTPose-Pytorch) for their great works, without which our project/code would not be possible. diff --git a/docs/INSTALL.md b/docs/INSTALL.md new file mode 100644 index 0000000..02e8bd0 --- /dev/null +++ b/docs/INSTALL.md @@ -0,0 +1,88 @@ +# Install + +## Environment + +```bash +git clone https://github.com/zju3dv/GVHMR --recursive +cd GVHMR + +conda create -y -n gvhmr python=3.10 +conda activate gvhmr +pip install -r requirements.txt +pip install -e . +# to install gvhmr in other repo as editable, try adding "python.analysis.extraPaths": ["path/to/your/package"] to settings.json + +# DPVO +cd third-party/DPVO +wget https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.zip +unzip eigen-3.4.0.zip -d thirdparty && rm -rf eigen-3.4.0.zip +pip install torch-scatter -f "https://data.pyg.org/whl/torch-2.3.0+cu121.html" +export CUDA_HOME=/usr/local/cuda-12.1/ +export PATH=$PATH:/usr/local/cuda-12.1/bin/ +pip install . +``` + +## Inputs & Outputs + +```bash +mkdir inputs +mkdir outputs +``` + +**Weights** + +```bash +mkdir -p inputs/checkpoints + +# 1. You need to sign up for downloading [SMPL](https://smpl.is.tue.mpg.de/) and [SMPLX](https://smpl-x.is.tue.mpg.de/). And the checkpoints should be placed in the following structure: + +inputs/checkpoints/ +├── body_models/smplx/ +│ └── SMPLX_{GENDER}.npz # SMPLX (We predict SMPLX params + evaluation) +└── body_models/smpl/ + └── SMPL_{GENDER}.pkl # SMPL (rendering and evaluation) + +# 2. 
Download other pretrained models from Google-Drive (By downloading, you agree to the corresponding licences): https://drive.google.com/drive/folders/1eebJ13FUEXrKBawHpJroW0sNSxLjh9xD?usp=drive_link + +inputs/checkpoints/ +├── dpvo/ +│ └── dpvo.pth +├── gvhmr/ +│ └── gvhmr_siga24_release.ckpt +├── hmr2/ +│ └── epoch=10-step=25000.ckpt +├── vitpose/ +│ └── vitpose-h-multi-coco.pth +└── yolo/ + └── yolov8x.pt +``` + +**Data** + +We provide preprocessed data for training and evaluation. +Note that we do not intend to distribute the original datasets, and you need to download them (annotation, videos, etc.) from the original websites. +*We're unable to provide the original data due to the license restrictions.* +By downloading the preprocessed data, you agree to the original dataset's terms of use and use the data for research purposes only. + +You can download them from [Google-Drive](https://drive.google.com/drive/folders/10sEef1V_tULzddFxzCmDUpsIqfv7eP-P?usp=drive_link). Please place them in the "inputs" folder and execute the following commands: + +```bash +cd inputs +# Train +tar -xzvf AMASS_hmr4d_support.tar.gz +tar -xzvf BEDLAM_hmr4d_support.tar.gz +tar -xzvf H36M_hmr4d_support.tar.gz +# Test +tar -xzvf 3DPW_hmr4d_support.tar.gz +tar -xzvf EMDB_hmr4d_support.tar.gz +tar -xzvf RICH_hmr4d_support.tar.gz + +# The folder structure should be like this: +inputs/ +├── AMASS/hmr4d_support/ +├── BEDLAM/hmr4d_support/ +├── H36M/hmr4d_support/ +├── 3DPW/hmr4d_support/ +├── EMDB/hmr4d_support/ +└── RICH/hmr4d_support/ +``` diff --git a/docs/example_video/project_teaser.gif b/docs/example_video/project_teaser.gif new file mode 100644 index 0000000..b4013fb Binary files /dev/null and b/docs/example_video/project_teaser.gif differ diff --git a/docs/example_video/tennis.mp4 b/docs/example_video/tennis.mp4 new file mode 100644 index 0000000..fc89ea2 Binary files /dev/null and b/docs/example_video/tennis.mp4 differ diff --git a/hmr4d/__init__.py b/hmr4d/__init__.py new file mode 100644 index 0000000..ca7a686 --- /dev/null +++ b/hmr4d/__init__.py @@ -0,0 +1,9 @@ +import os +from pathlib import Path + +PROJ_ROOT = Path(__file__).resolve().parents[1] + + +def os_chdir_to_proj_root(): + """useful for running notebooks in different directories.""" + os.chdir(PROJ_ROOT) diff --git a/hmr4d/build_gvhmr.py b/hmr4d/build_gvhmr.py new file mode 100644 index 0000000..669e3af --- /dev/null +++ b/hmr4d/build_gvhmr.py @@ -0,0 +1,11 @@ +from omegaconf import OmegaConf +from hmr4d import PROJ_ROOT +from hydra.utils import instantiate +from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL + + +def build_gvhmr_demo(): + cfg = OmegaConf.load(PROJ_ROOT / "hmr4d/configs/demo_gvhmr_model/siga24_release.yaml") + gvhmr_demo_pl: DemoPL = instantiate(cfg.model, _recursive_=False) + gvhmr_demo_pl.load_pretrained_model(PROJ_ROOT / "inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt") + return gvhmr_demo_pl.eval() diff --git a/hmr4d/configs/__init__.py b/hmr4d/configs/__init__.py new file mode 100644 index 0000000..3079c6e --- /dev/null +++ b/hmr4d/configs/__init__.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from hydra.core.config_store import ConfigStore +from hydra_zen import builds + +import argparse +from hydra import compose, initialize_config_module +import os + +os.environ["HYDRA_FULL_ERROR"] = "1" + +MainStore = ConfigStore.instance() + + +def register_store_gvhmr(): + """Register group options to MainStore""" + from . 
import store_gvhmr + + +def parse_args_to_cfg(): + """ + Use minimal Hydra API to parse args and return cfg. + This function don't do _run_hydra which create log file hierarchy. + """ + parser = argparse.ArgumentParser() + parser.add_argument("--config-name", "-cn", default="train") + parser.add_argument( + "overrides", + nargs="*", + help="Any key=value arguments to override config values (use dots for.nested=overrides)", + ) + args = parser.parse_args() + + # Cfg + with initialize_config_module(version_base="1.3", config_module=f"hmr4d.configs"): + cfg = compose(config_name=args.config_name, overrides=args.overrides) + + return cfg diff --git a/hmr4d/configs/data/mocap/testY.yaml b/hmr4d/configs/data/mocap/testY.yaml new file mode 100644 index 0000000..8dbf3c3 --- /dev/null +++ b/hmr4d/configs/data/mocap/testY.yaml @@ -0,0 +1,10 @@ +# definition of lightning datamodule (dataset + dataloader) +_target_: hmr4d.datamodule.mocap_trainX_testY.DataModule + +dataset_opts: + test: ${test_datasets} + +loader_opts: + test: + batch_size: 1 + num_workers: 0 diff --git a/hmr4d/configs/data/mocap/trainX_testY.yaml b/hmr4d/configs/data/mocap/trainX_testY.yaml new file mode 100644 index 0000000..d512a78 --- /dev/null +++ b/hmr4d/configs/data/mocap/trainX_testY.yaml @@ -0,0 +1,16 @@ +# definition of lightning datamodule (dataset + dataloader) +_target_: hmr4d.datamodule.mocap_trainX_testY.DataModule + +dataset_opts: + train: ${train_datasets} + val: ${test_datasets} + +loader_opts: + train: + batch_size: 32 + num_workers: 8 + val: + batch_size: 1 + num_workers: 1 + +limit_each_trainset: null \ No newline at end of file diff --git a/hmr4d/configs/demo.yaml b/hmr4d/configs/demo.yaml new file mode 100644 index 0000000..0f4a780 --- /dev/null +++ b/hmr4d/configs/demo.yaml @@ -0,0 +1,42 @@ +defaults: + - _self_ + - model: gvhmr/gvhmr_pl_demo + - network: gvhmr/relative_transformer + - endecoder: gvhmr/v1_amass_local_bedlam_cam + +pipeline: + _target_: hmr4d.model.gvhmr.pipeline.gvhmr_pipeline.Pipeline + args_denoiser3d: ${network} + args: + endecoder_opt: ${endecoder} + normalize_cam_angvel: True + weights: null + static_conf: null + +ckpt_path: inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt + +# ================================ # +# global setting # +# ================================ # + +video_name: ??? 
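+# A minimal sketch of how the interpolations below resolve, assuming the demo is
+# run on the example video so that video_name=tennis (an assumed value):
+#   output_dir        -> outputs/demo/tennis
+#   preprocess_dir    -> outputs/demo/tennis/preprocess
+#   paths.bbx         -> outputs/demo/tennis/preprocess/bbx.pt
+#   paths.incam_video -> outputs/demo/tennis/1_incam.mp4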
+output_root: outputs/demo +output_dir: "${output_root}/${video_name}" +preprocess_dir: ${output_dir}/preprocess +video_path: "${output_dir}/0_input_video.mp4" + +# Options +static_cam: False +verbose: False + +paths: + bbx: ${preprocess_dir}/bbx.pt + bbx_xyxy_video_overlay: ${preprocess_dir}/bbx_xyxy_video_overlay.mp4 + vit_features: ${preprocess_dir}/vit_features.pt + vitpose: ${preprocess_dir}/vitpose.pt + vitpose_video_overlay: ${preprocess_dir}/vitpose_video_overlay.mp4 + hmr4d_results: ${output_dir}/hmr4d_results.pt + incam_video: ${output_dir}/1_incam.mp4 + global_video: ${output_dir}/2_global.mp4 + incam_global_horiz_video: ${output_dir}/${video_name}_3_incam_global_horiz.mp4 + slam: ${preprocess_dir}/slam_results.pt diff --git a/hmr4d/configs/exp/gvhmr/mixed/mixed.yaml b/hmr4d/configs/exp/gvhmr/mixed/mixed.yaml new file mode 100644 index 0000000..9a6d36a --- /dev/null +++ b/hmr4d/configs/exp/gvhmr/mixed/mixed.yaml @@ -0,0 +1,71 @@ +# @package _global_ +defaults: + - override /data: mocap/trainX_testY + - override /model: gvhmr/gvhmr_pl + - override /endecoder: gvhmr/v1_amass_local_bedlam_cam + - override /optimizer: adamw_2e-4 + - override /scheduler_cfg: epoch_half_200_350 + - override /train_datasets: + - pure_motion_amass/v11 + - imgfeat_bedlam/v2 + - imgfeat_h36m/v1 + - imgfeat_3dpw/v1 + - override /test_datasets: + - emdb1/v1_fliptest + - emdb2/v1_fliptest + - rich/all + - 3dpw/fliptest + - override /callbacks: + - simple_ckpt_saver/every10e_top100 + - prog_bar/prog_reporter_every0.1 + - train_speed_timer/base + - lr_monitor/pl + - metric_emdb1 + - metric_emdb2 + - metric_rich + - metric_3dpw + - override /network: gvhmr/relative_transformer + +exp_name_base: mixed +exp_name_var: "" +exp_name: ${exp_name_base}${exp_name_var} +data_name: mocap_mixed_v1 + +pipeline: + _target_: hmr4d.model.gvhmr.pipeline.gvhmr_pipeline.Pipeline + args_denoiser3d: ${network} + args: + endecoder_opt: ${endecoder} + normalize_cam_angvel: True + weights: + cr_j3d: 500. + transl_c: 1. + cr_verts: 500. + j2d: 1000. + verts2d: 1000. + + transl_w: 1. + static_conf_bce: 1. 
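+  # A sketch of tweaking one of the loss weights above from the command line,
+  # using Hydra's dotted override syntax (the value 500. is an arbitrary example):
+  #   python tools/train.py exp=gvhmr/mixed/mixed pipeline.args.weights.j2d=500.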
+ + static_conf: + vel_thr: 0.15 + +data: + loader_opts: + train: + batch_size: 128 + num_workers: 12 + +pl_trainer: + precision: 16-mixed + log_every_n_steps: 50 + gradient_clip_val: 0.5 + max_epochs: 500 + check_val_every_n_epoch: 10 + devices: 2 + +logger: + _target_: pytorch_lightning.loggers.TensorBoardLogger + save_dir: ${output_dir} # /save_dir/name/version/sub_dir + name: "" + version: "tb" # merge name and version diff --git a/hmr4d/configs/global/debug/debug_train.yaml b/hmr4d/configs/global/debug/debug_train.yaml new file mode 100644 index 0000000..1545817 --- /dev/null +++ b/hmr4d/configs/global/debug/debug_train.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +data_name: debug +exp_name: debug + +# data: +# limit_each_trainset: 40 +# loader_opts: +# train: +# batch_size: 4 +# num_workers: 0 +# val: +# batch_size: 1 +# num_workers: 0 + +pl_trainer: + limit_train_batches: 32 + limit_val_batches: 2 + check_val_every_n_epoch: 3 + enable_checkpointing: False + devices: 1 + +callbacks: + model_checkpoint: null diff --git a/hmr4d/configs/global/debug/debug_train_limit_data.yaml b/hmr4d/configs/global/debug/debug_train_limit_data.yaml new file mode 100644 index 0000000..727cf23 --- /dev/null +++ b/hmr4d/configs/global/debug/debug_train_limit_data.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +data_name: debug +exp_name: debug + +data: + limit_each_trainset: 40 + loader_opts: + train: + batch_size: 4 + num_workers: 0 + val: + batch_size: 1 + num_workers: 0 + +pl_trainer: + limit_val_batches: 2 + check_val_every_n_epoch: 3 + enable_checkpointing: False + devices: 1 + +callbacks: + model_checkpoint: null diff --git a/hmr4d/configs/global/task/gvhmr/test_3dpw.yaml b/hmr4d/configs/global/task/gvhmr/test_3dpw.yaml new file mode 100644 index 0000000..f820f9d --- /dev/null +++ b/hmr4d/configs/global/task/gvhmr/test_3dpw.yaml @@ -0,0 +1,17 @@ +# @package _global_ +defaults: + - override /data: mocap/testY + - override /test_datasets: + - 3dpw/fliptest + - override /callbacks: + - metric_3dpw + - _self_ + +task: test +data_name: test_mocap +ckpt_path: ??? # will not override previous setting if already set + +# lightning utilities +pl_trainer: + devices: 1 +logger: null diff --git a/hmr4d/configs/global/task/gvhmr/test_3dpw_emdb_rich.yaml b/hmr4d/configs/global/task/gvhmr/test_3dpw_emdb_rich.yaml new file mode 100644 index 0000000..5002168 --- /dev/null +++ b/hmr4d/configs/global/task/gvhmr/test_3dpw_emdb_rich.yaml @@ -0,0 +1,23 @@ +# @package _global_ +defaults: + - override /data: mocap/testY + - override /test_datasets: + - rich/all + - emdb1/v1_fliptest + - emdb2/v1_fliptest + - 3dpw/fliptest + - override /callbacks: + - metric_rich + - metric_emdb1 + - metric_emdb2 + - metric_3dpw + - _self_ + +task: test +data_name: test_mocap +ckpt_path: ??? # will not override previous setting if already set + +# lightning utilities +pl_trainer: + devices: 1 +logger: null diff --git a/hmr4d/configs/global/task/gvhmr/test_emdb.yaml b/hmr4d/configs/global/task/gvhmr/test_emdb.yaml new file mode 100644 index 0000000..ff1e1c9 --- /dev/null +++ b/hmr4d/configs/global/task/gvhmr/test_emdb.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - override /data: mocap/testY + - override /test_datasets: + - emdb1/v1_fliptest + - emdb2/v1_fliptest + - override /callbacks: + - metric_emdb1 + - metric_emdb2 + - _self_ + +task: test +data_name: test_mocap +ckpt_path: ??? 
# will not override previous setting if already set + +# lightning utilities +pl_trainer: + devices: 1 +logger: null diff --git a/hmr4d/configs/global/task/gvhmr/test_rich.yaml b/hmr4d/configs/global/task/gvhmr/test_rich.yaml new file mode 100644 index 0000000..923511b --- /dev/null +++ b/hmr4d/configs/global/task/gvhmr/test_rich.yaml @@ -0,0 +1,17 @@ +# @package _global_ +defaults: + - override /data: mocap/testY + - override /test_datasets: + - rich/all + - override /callbacks: + - metric_rich + - _self_ + +task: test +data_name: test_mocap +ckpt_path: ??? # will not override previous setting if already set + +# lightning utilities +pl_trainer: + devices: 1 +logger: null diff --git a/hmr4d/configs/hydra/default.yaml b/hmr4d/configs/hydra/default.yaml new file mode 100644 index 0000000..619a4b4 --- /dev/null +++ b/hmr4d/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +job_logging: + formatters: + simple: + datefmt: '%m/%d %H:%M:%S' + format: '[%(asctime)s][%(levelname)s] %(message)s' + colorlog: + datefmt: '%m/%d %H:%M:%S' + format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] %(message)s' + handlers: + file: + filename: ${output_dir}/${hydra.job.name}.log + +run: + dir: ${output_dir} \ No newline at end of file diff --git a/hmr4d/configs/siga24_release.yaml b/hmr4d/configs/siga24_release.yaml new file mode 100644 index 0000000..d664542 --- /dev/null +++ b/hmr4d/configs/siga24_release.yaml @@ -0,0 +1,35 @@ +pipeline: + _target_: hmr4d.model.gvhmr.pipeline.gvhmr_pipeline.Pipeline + args_denoiser3d: ${network} + args: + endecoder_opt: ${endecoder} + normalize_cam_angvel: true + weights: null + static_conf: null +model: + _target_: hmr4d.model.gvhmr.gvhmr_pl_demo.DemoPL + pipeline: ${pipeline} +network: + _target_: hmr4d.network.gvhmr.relative_transformer.NetworkEncoderRoPEV2 + output_dim: 151 + max_len: 120 + kp2d_mapping: linear_v2 + cliffcam_dim: 3 + cam_angvel_dim: 6 + imgseq_dim: 1024 + f_imgseq_filter: null + cond_ver: v1 + latent_dim: 512 + num_layers: 12 + num_heads: 8 + mlp_ratio: 4.0 + pred_cam_ver: v2 + pred_cam_dim: 3 + static_conf_dim: 6 + pred_coco17_dim: 0 + dropout: 0.1 + avgbeta: true +endecoder: + _target_: hmr4d.model.gvhmr.utils.endecoder.EnDecoder + stats_name: MM_V1_AMASS_LOCAL_BEDLAM_CAM + noise_pose_k: 10 diff --git a/hmr4d/configs/store_gvhmr.py b/hmr4d/configs/store_gvhmr.py new file mode 100644 index 0000000..b038911 --- /dev/null +++ b/hmr4d/configs/store_gvhmr.py @@ -0,0 +1,29 @@ +# Dataset +import hmr4d.dataset.pure_motion.amass +import hmr4d.dataset.emdb.emdb_motion_test +import hmr4d.dataset.rich.rich_motion_test +import hmr4d.dataset.threedpw.threedpw_motion_test +import hmr4d.dataset.threedpw.threedpw_motion_train +import hmr4d.dataset.bedlam.bedlam +import hmr4d.dataset.h36m.h36m + +# Trainer: Model Optimizer Loss +import hmr4d.model.gvhmr.gvhmr_pl +import hmr4d.model.gvhmr.utils.endecoder +import hmr4d.model.common_utils.optimizer +import hmr4d.model.common_utils.scheduler_cfg + +# Metric +import hmr4d.model.gvhmr.callbacks.metric_emdb +import hmr4d.model.gvhmr.callbacks.metric_rich +import hmr4d.model.gvhmr.callbacks.metric_3dpw + + +# PL Callbacks +import hmr4d.utils.callbacks.simple_ckpt_saver +import hmr4d.utils.callbacks.train_speed_timer +import hmr4d.utils.callbacks.prog_bar +import hmr4d.utils.callbacks.lr_monitor + +# Networks +import hmr4d.network.gvhmr.relative_transformer diff --git a/hmr4d/configs/train.yaml 
b/hmr4d/configs/train.yaml new file mode 100644 index 0000000..ee1ff15 --- /dev/null +++ b/hmr4d/configs/train.yaml @@ -0,0 +1,52 @@ +# ================================ # +# override # +# ================================ # +# specify default configuration; the order determines the override order +defaults: + - _self_ + # pytorch-lightning + - data: ??? + - model: ??? + - callbacks: null + + # system + - hydra: default + + # utility groups that changes a lot + - pipeline: null + - network: null + - optimizer: null + - scheduler_cfg: default + - train_datasets: null + - test_datasets: null + - endecoder: null # normalize/unnormalize data + - refiner: null + + # global-override + - exp: ??? # set "data, model and callbacks" in yaml + - global/task: null # dump/test + - global/hsearch: null # hyper-param search + - global/debug: null # debug mode + +# ================================ # +# global setting # +# ================================ # +# expirement information +task: fit # [fit, predict] +exp_name: ??? +data_name: ??? + +# utilities in the entry file +output_dir: "outputs/${data_name}/${exp_name}" +ckpt_path: null +resume_mode: null +seed: 42 + +# lightning default settings +pl_trainer: + devices: 1 + num_sanity_val_steps: 0 # disable sanity check + precision: 32 + inference_mode: False + +logger: null diff --git a/hmr4d/datamodule/mocap_trainX_testY.py b/hmr4d/datamodule/mocap_trainX_testY.py new file mode 100644 index 0000000..af04744 --- /dev/null +++ b/hmr4d/datamodule/mocap_trainX_testY.py @@ -0,0 +1,130 @@ +import pytorch_lightning as pl +from pytorch_lightning.utilities.combined_loader import CombinedLoader +from hydra.utils import instantiate +from torch.utils.data import DataLoader, ConcatDataset, Subset +from omegaconf import ListConfig, DictConfig +from hmr4d.utils.pylogger import Log +from numpy.random import choice +from torch.utils.data import default_collate + + +import resource + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + + +def collate_fn(batch): + """Handle meta and Add batch size to the return dict + Args: + batch: list of dict, each dict is a data point + """ + # Assume all keys in the batch are the same + return_dict = {} + for k in batch[0].keys(): + if k.startswith("meta"): # data information, do not batch + return_dict[k] = [d[k] for d in batch] + else: + return_dict[k] = default_collate([d[k] for d in batch]) + return_dict["B"] = len(batch) + return return_dict + + +class DataModule(pl.LightningDataModule): + def __init__(self, dataset_opts: DictConfig, loader_opts: DictConfig, limit_each_trainset=None): + """This is a general datamodule that can be used for any dataset. + Train uses ConcatDataset + Val and Test use CombinedLoader, sequential, completely consumes ecah iterable sequentially, and returns a triplet (data, idx, iterable_idx) + + Args: + dataset_opts: the target of the dataset. e.g. 
dataset_opts.train = {_target_: ..., limit_size: None} + loader_opts: the options for the dataset + limit_each_trainset: limit the size of each dataset, None means no limit, useful for debugging + """ + super().__init__() + self.loader_opts = loader_opts + self.limit_each_trainset = limit_each_trainset + + # Train uses concat dataset + if "train" in dataset_opts: + assert "train" in self.loader_opts, "train not in loader_opts" + split_opts = dataset_opts.get("train") + assert isinstance(split_opts, DictConfig), "split_opts should be a dict for each dataset" + dataset = [] + dataset_num = len(split_opts) + for idx, (k, v) in enumerate(split_opts.items()): + dataset_i = instantiate(v) + if self.limit_each_trainset: + dataset_i = Subset(dataset_i, choice(len(dataset_i), self.limit_each_trainset)) + dataset.append(dataset_i) + Log.info(f"[Train Dataset][{idx+1}/{dataset_num}]: name={k}, size={len(dataset[-1])}, {v._target_}") + dataset = ConcatDataset(dataset) + self.trainset = dataset + Log.info(f"[Train Dataset][All]: ConcatDataset size={len(dataset)}") + Log.info(f"") + + # Val and Test use sequential dataset + for split in ("val", "test"): + if split not in dataset_opts: + continue + assert split in self.loader_opts, f"split={split} not in loader_opts" + split_opts = dataset_opts.get(split) + assert isinstance(split_opts, DictConfig), "split_opts should be a dict for each dataset" + dataset = [] + dataset_num = len(split_opts) + for idx, (k, v) in enumerate(split_opts.items()): + dataset.append(instantiate(v)) + dataset_type = "Val Dataset" if split == "val" else "Test Dataset" + Log.info(f"[{dataset_type}][{idx+1}/{dataset_num}]: name={k}, size={len(dataset[-1])}, {v._target_}") + setattr(self, f"{split}sets", dataset) + Log.info(f"") + + def train_dataloader(self): + if hasattr(self, "trainset"): + return DataLoader( + self.trainset, + shuffle=True, + num_workers=self.loader_opts.train.num_workers, + persistent_workers=True and self.loader_opts.train.num_workers > 0, + batch_size=self.loader_opts.train.batch_size, + drop_last=True, + collate_fn=collate_fn, + ) + else: + return super().train_dataloader() + + def val_dataloader(self): + if hasattr(self, "valsets"): + loaders = [] + for valset in self.valsets: + loaders.append( + DataLoader( + valset, + shuffle=False, + num_workers=self.loader_opts.val.num_workers, + persistent_workers=True and self.loader_opts.val.num_workers > 0, + batch_size=self.loader_opts.val.batch_size, + collate_fn=collate_fn, + ) + ) + return CombinedLoader(loaders, mode="sequential") + else: + return None + + def test_dataloader(self): + if hasattr(self, "testsets"): + loaders = [] + for testset in self.testsets: + loaders.append( + DataLoader( + testset, + shuffle=False, + num_workers=self.loader_opts.test.num_workers, + persistent_workers=False, + batch_size=self.loader_opts.test.batch_size, + collate_fn=collate_fn, + ) + ) + return CombinedLoader(loaders, mode="sequential") + else: + return super().test_dataloader() diff --git a/hmr4d/dataset/bedlam/bedlam.py b/hmr4d/dataset/bedlam/bedlam.py new file mode 100644 index 0000000..52b75ef --- /dev/null +++ b/hmr4d/dataset/bedlam/bedlam.py @@ -0,0 +1,251 @@ +from pathlib import Path +import numpy as np +import torch +from hmr4d.utils.pylogger import Log +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle +from time import time + +from hmr4d.configs import MainStore, builds +from hmr4d.utils.smplx_utils import make_smplx +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines 
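+
+# A minimal usage sketch of the collate_fn defined above in
+# hmr4d/datamodule/mocap_trainX_testY.py: keys starting with "meta" are kept as
+# plain python lists, every other key goes through torch's default_collate, and
+# "B" records the batch size. The helper name and sample values here are made
+# up for illustration only.
+def _collate_fn_sketch():
+    from hmr4d.datamodule.mocap_trainX_testY import collate_fn
+
+    samples = [
+        {"meta": {"idx": 0}, "length": torch.tensor(60)},
+        {"meta": {"idx": 1}, "length": torch.tensor(90)},
+    ]
+    batch = collate_fn(samples)
+    assert batch["B"] == 2
+    assert isinstance(batch["meta"], list)  # meta entries stay a list of dicts
+    assert batch["length"].shape == (2,)  # 0-dim tensors stacked by default_collate
+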
+from hmr4d.utils.vis.renderer_utils import simple_render_mesh_background +from hmr4d.utils.video_io_utils import read_video_np, save_video + +import hmr4d.utils.matrix as matrix +from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict +from hmr4d.dataset.imgfeat_motion.base_dataset import ImgfeatMotionDatasetBase +from hmr4d.dataset.bedlam.utils import mid2featname, mid2vname +from hmr4d.utils.geo_transform import compute_cam_angvel, apply_T_on_points +from hmr4d.utils.geo.hmr_global import get_T_w2c_from_wcparams, get_c_rootparam, get_R_c2gv + + +class BedlamDatasetV2(ImgfeatMotionDatasetBase): + """mid_to_valid_range and features are newly generated.""" + + MIDINDEX_TO_LOAD = { + "all60": ("mid_to_valid_range_all60.pt", "imgfeats/bedlam_all60"), + "maxspan60": ("mid_to_valid_range_maxspan60.pt", "imgfeats/bedlam_maxspan60"), + } + + def __init__( + self, + mid_indices=["all60", "maxspan60"], + lazy_load=True, # Load from disk when needed + random1024=False, # Faster loading for debugging + ): + self.root = Path("inputs/BEDLAM/hmr4d_support") + self.min_motion_frames = 60 + self.max_motion_frames = 120 + self.lazy_load = lazy_load + self.random1024 = random1024 + + # speficify mid_index to handle + if not isinstance(mid_indices, list): + mid_indices = [mid_indices] + self.mid_indices = mid_indices + assert all([m in self.MIDINDEX_TO_LOAD for m in mid_indices]) + + super().__init__() + + def _load_dataset(self): + Log.info(f"[BEDLAM] Loading from {self.root}") + tic = time() + # Load mid to valid range + self.mid_to_valid_range = {} + self.mid_to_imgfeat_dir = {} + for m in self.mid_indices: + fn, feat_dir = self.MIDINDEX_TO_LOAD[m] + mid_to_valid_range_ = torch.load(self.root / fn) + self.mid_to_valid_range.update(mid_to_valid_range_) + self.mid_to_imgfeat_dir.update({mid: self.root / feat_dir for mid in mid_to_valid_range_}) + + # Load motionfiles + Log.info(f"[BEDLAM] Start loading motion files") + if self.random1024: # Debug, faster loading + try: + Log.info(f"[BEDLAM] Loading 1024 samples for debugging ...") + self.motion_files = torch.load(self.root / "smplpose_v2_random1024.pth") + except: + Log.info(f"[BEDLAM] Not found, saving 1024 samples to disk ...") + self.motion_files = torch.load(self.root / "smplpose_v2.pth") + keys = list(self.motion_files.keys()) + keys = np.random.choice(keys, 1024, replace=False) + self.motion_files = {k: self.motion_files[k] for k in keys} + torch.save(self.motion_files, self.root / "smplpose_v2_random1024.pth") + self.mid_to_valid_range = {k: v for k, v in self.mid_to_valid_range.items() if k in self.motion_files} + else: + self.motion_files = torch.load(self.root / "smplpose_v2.pth") + Log.info(f"[BEDLAM] Motion files loaded. Elapsed: {time() - tic:.2f}s") + + def _get_idx2meta(self): + # sum_frame = sum([e-s for s, e in self.mid_to_valid_range.values()]) + self.idx2meta = list(self.mid_to_valid_range.keys()) + Log.info(f"[BEDLAM] {len(self.idx2meta)} sequences. 
") + + def _load_data(self, idx): + mid = self.idx2meta[idx] + # neutral smplx : "pose": (F, 63), "trans": (F, 3), "beta": (10), + # and : "skeleton": (J, 3) + data = self.motion_files[mid].copy() + + # Random select a subset + range1, range2 = self.mid_to_valid_range[mid] # [range1, range2) + mlength = range2 - range1 + min_motion_len = self.min_motion_frames + max_motion_len = self.max_motion_frames + + if mlength < min_motion_len: # the minimal mlength is 30 when generating data + start = range1 + length = mlength + else: + effect_max_motion_len = min(max_motion_len, mlength) + length = np.random.randint(min_motion_len, effect_max_motion_len + 1) # [low, high) + start = np.random.randint(range1, range2 - length + 1) + end = start + length + data["start_end"] = (start, end) + data["length"] = length + + # Update data to a subset + for k, v in data.items(): + if isinstance(v, torch.Tensor) and len(v.shape) > 1 and k != "skeleton": + data[k] = v[start:end] + + # Load img(as feature) : {mid -> 'features', 'bbx_xys', 'img_wh', 'start_end'} + imgfeat_dir = self.mid_to_imgfeat_dir[mid] + f_img_dict = torch.load(imgfeat_dir / mid2featname(mid)) + + # remap (start, end) + start_mapped = start - f_img_dict["start_end"][0] + end_mapped = end - f_img_dict["start_end"][0] + + data["f_imgseq"] = f_img_dict["features"][start_mapped:end_mapped].float() # (L, 1024) + data["bbx_xys"] = f_img_dict["bbx_xys"][start_mapped:end_mapped].float() # (L, 4) + data["img_wh"] = f_img_dict["img_wh"] # (2) + data["kp2d"] = torch.zeros((end - start), 17, 3) # (L, 17, 3) # do not provide kp2d + + return data + + def _process_data(self, data, idx): + length = data["length"] + + # SMPL params in cam + body_pose = data["pose"][:, 3:] # (F, 63) + betas = data["beta"].repeat(length, 1) # (F, 10) + global_orient = data["global_orient_incam"] # (F, 3) + transl = data["trans_incam"] + data["cam_ext"][:, :3, 3] # (F, 3), bedlam convention + smpl_params_c = {"body_pose": body_pose, "betas": betas, "transl": transl, "global_orient": global_orient} + + # SMPL params in world + global_orient_w = data["pose"][:, :3] # (F, 3) + transl_w = data["trans"] # (F, 3) + smpl_params_w = {"body_pose": body_pose, "betas": betas, "transl": transl_w, "global_orient": global_orient_w} + + gravity_vec = torch.tensor([0, -1, 0], dtype=torch.float32) # (3), BEDLAM is ay + T_w2c = get_T_w2c_from_wcparams( + global_orient_w=global_orient_w, + transl_w=transl_w, + global_orient_c=global_orient, + transl_c=transl, + offset=data["skeleton"][0], + ) # (F, 4, 4) + R_c2gv = get_R_c2gv(T_w2c[:, :3, :3], gravity_vec) # (F, 3, 3) + + # cam_angvel (slightly different from WHAM) + cam_angvel = compute_cam_angvel(T_w2c[:, :3, :3]) # (F, 6) + + # Returns: do not forget to make it batchable! 
(last lines) + max_len = self.max_motion_frames + return_data = { + "meta": {"data_name": "bedlam", "idx": idx}, + "length": length, + "smpl_params_c": smpl_params_c, + "smpl_params_w": smpl_params_w, + "R_c2gv": R_c2gv, # (F, 3, 3) + "gravity_vec": gravity_vec, # (3) + "bbx_xys": data["bbx_xys"], # (F, 3) + "K_fullimg": data["cam_int"], # (F, 3, 3) + "f_imgseq": data["f_imgseq"], # (F, D) + "kp2d": data["kp2d"], # (F, 17, 3) + "cam_angvel": cam_angvel, # (F, 6) + "mask": { + "valid": get_valid_mask(max_len, length), + "vitpose": False, + "bbx_xys": True, + "f_imgseq": True, + "spv_incam_only": False, + }, + } + + if False: # check transformation, wis3d: sampled motion (global, incam) + wis3d = make_wis3d(name="debug-data-bedlam") + smplx = make_smplx("supermotion") + + # global + smplx_out = smplx(**smpl_params_w) + w_gt_joints = smplx_out.joints + add_motion_as_lines(w_gt_joints, wis3d, name="w-gt_joints") + + # incam + smplx_out = smplx(**smpl_params_c) + c_gt_joints = smplx_out.joints + add_motion_as_lines(c_gt_joints, wis3d, name="c-gt_joints") + + # Check transformation works correctly + print("T_w2c", (apply_T_on_points(w_gt_joints, T_w2c) - c_gt_joints).abs().max()) + R_c, t_c = get_c_rootparam( + smpl_params_w["global_orient"], smpl_params_w["transl"], T_w2c, data["skeleton"][0] + ) + print("transl_c", (t_c - smpl_params_c["transl"]).abs().max()) + R_diff = matrix_to_axis_angle( + (axis_angle_to_matrix(R_c) @ axis_angle_to_matrix(smpl_params_c["global_orient"]).transpose(-1, -2)) + ).norm(dim=-1) + print("global_orient_c", R_diff.abs().max()) # < 1e-6 + + skeleton_beta = smplx.get_skeleton(smpl_params_c["betas"]) + print("Skeleton", (skeleton_beta[0] - data["skeleton"]).abs().max()) # (1.2e-7) + + if False: # cam-overlay + smplx = make_smplx("supermotion") + + # *. 
original bedlam param + # mid = self.idx2meta[idx] + # video_path = "-".join(mid.replace("bedlam_data/", "inputs/bedlam/").split("-")[:-1]) + # npz_file = "inputs/bedlam/processed_labels/20221024_3-10_100_batch01handhair_static_highSchoolGym.npz" + # params = np.load(npz_file, allow_pickle=True) + # mid2index = {} + # for j in tqdm(range(len(params["video_name"]))): + # k = params["video_name"][j] + "-" + params["sub"][j] + # mid2index[k] = j + # betas = params['shape'][mid2index[mid]][:length] + # global_orient_incam = torch.from_numpy(params['pose_cam'][121][:, :3]) + # body_pose = torch.from_numpy(params['pose_cam'][121][:, 3:66]) + # transl_incam = torch.from_numpy(params["trans_cam"][121]) + smplx_out = smplx(**smpl_params_c) + + # ----- Render Overlay ----- # + mid = self.idx2meta[idx] + images = read_video_np(self.root / "videos" / mid2vname(mid), data["start_end"][0], data["start_end"][1]) + render_dict = { + "K": data["cam_int"][:1], # only support batch-size 1 + "faces": smplx.faces, + "verts": smplx_out.vertices, + "background": images, + } + img_overlay = simple_render_mesh_background(render_dict) + save_video(img_overlay, "tmp.mp4", crf=23) + + # Batchable + return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len) + return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len) + return_data["R_c2gv"] = repeat_to_max_len(return_data["R_c2gv"], max_len) + return_data["bbx_xys"] = repeat_to_max_len(return_data["bbx_xys"], max_len) + return_data["K_fullimg"] = repeat_to_max_len(return_data["K_fullimg"], max_len) + return_data["f_imgseq"] = repeat_to_max_len(return_data["f_imgseq"], max_len) + return_data["kp2d"] = repeat_to_max_len(return_data["kp2d"], max_len) + return_data["cam_angvel"] = repeat_to_max_len(return_data["cam_angvel"], max_len) + return return_data + + +group_name = "train_datasets/imgfeat_bedlam" +MainStore.store(name="v2", node=builds(BedlamDatasetV2), group=group_name) +MainStore.store(name="v2_random1024", node=builds(BedlamDatasetV2, random1024=True), group=group_name) diff --git a/hmr4d/dataset/bedlam/resource/vname2lwh.pt b/hmr4d/dataset/bedlam/resource/vname2lwh.pt new file mode 100644 index 0000000..c1f7b2f Binary files /dev/null and b/hmr4d/dataset/bedlam/resource/vname2lwh.pt differ diff --git a/hmr4d/dataset/bedlam/utils.py b/hmr4d/dataset/bedlam/utils.py new file mode 100644 index 0000000..ca6b8d5 --- /dev/null +++ b/hmr4d/dataset/bedlam/utils.py @@ -0,0 +1,39 @@ +import torch +import numpy as np +from pathlib import Path + +resource_dir = Path(__file__).parent / "resource" + + +def mid2vname(mid): + """vname = {scene}/{seq}, Note that it ends with .mp4""" + # mid example: "inputs/bedlam/bedlam_download/20221011_1_250_batch01hand_closeup_suburb_a/mp4/seq_000001.mp4-rp_emma_posed_008" + # -> vname: 20221011_1_250_batch01hand_closeup_suburb_a/seq_000001.mp4 + scene = mid.split("/")[-3] + seq = mid.split("/")[-1].split("-")[0] + vname = f"{scene}/{seq}" + return vname + + +def mid2featname(mid): + """featname = {scene}/{seqsubj}, Note that it ends with .pt (extra)""" + # mid example: "inputs/bedlam/bedlam_download/20221011_1_250_batch01hand_closeup_suburb_a/mp4/seq_000001.mp4-rp_emma_posed_008" + # -> featname: 20221011_1_250_batch01hand_closeup_suburb_a/seq_000001.mp4-rp_emma_posed_008.pt + scene = mid.split("/")[-3] + seqsubj = mid.split("/")[-1] + featname = f"{scene}/{seqsubj}.pt" + return featname + + +def featname2mid(featname): + """reverse func of mid2featname, Note that it removes 
.pt (extra)""" + # featname example: 20221011_1_250_batch01hand_closeup_suburb_a/seq_000001.mp4-rp_emma_posed_008.pt + # -> mid: inputs/bedlam/bedlam_download/20221011_1_250_batch01hand_closeup_suburb_a/mp4/seq_000001.mp4-rp_emma_posed_008 + scene = featname.split("/")[0] + seqsubj = featname.split("/")[1].strip(".pt") + mid = f"inputs/bedlam/bedlam_download/{scene}/mp4/{seqsubj}" + return mid + + +def load_vname2lwh(): + return torch.load(resource_dir / "vname2lwh.pt") diff --git a/hmr4d/dataset/emdb/emdb_motion_test.py b/hmr4d/dataset/emdb/emdb_motion_test.py new file mode 100644 index 0000000..1a2e81c --- /dev/null +++ b/hmr4d/dataset/emdb/emdb_motion_test.py @@ -0,0 +1,167 @@ +from pathlib import Path +import numpy as np +import torch +from torch.utils import data +from hmr4d.utils.pylogger import Log +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines + +from hmr4d.utils.geo_transform import compute_cam_angvel +from pytorch3d.transforms import quaternion_to_matrix +from hmr4d.utils.geo.hmr_cam import estimate_K, resize_K +from hmr4d.utils.geo.flip_utils import flip_kp2d_coco17 + +from .utils import EMDB1_NAMES, EMDB2_NAMES + +VID_PRESETS = {1: EMDB1_NAMES, 2: EMDB2_NAMES} + + +from hmr4d.configs import MainStore, builds + + +class EmdbSmplFullSeqDataset(data.Dataset): + def __init__(self, split=1, flip_test=False): + """ + split: 1 for EMDB-1, 2 for EMDB-2 + flip_test: if True, extra flip data will be returned + """ + super().__init__() + self.dataset_name = "EMDB" + self.split = split + self.dataset_id = f"EMDB_{split}" + Log.info(f"[{self.dataset_name}] Full sequence, split={split}") + + # Load evaluation protocol from WHAM labels + tic = Log.time() + self.emdb_dir = Path("inputs/EMDB/hmr4d_support") + # 'name', 'gender', 'smpl_params', 'mask', 'K_fullimg', 'T_w2c', 'bbx_xys', 'kp2d', 'features' + self.labels = torch.load(self.emdb_dir / "emdb_vit_v4.pt") + self.cam_traj = torch.load(self.emdb_dir / "emdb_dpvo_traj.pt") # estimated with DPVO + + # Setup dataset index + self.idx2meta = [] + for vid in VID_PRESETS[split]: + seq_length = len(self.labels[vid]["mask"]) + self.idx2meta.append((vid, 0, seq_length)) # start=0, end=seq_length + Log.info(f"[{self.dataset_name}] {len(self.idx2meta)} sequences. 
Elapsed: {Log.time() - tic:.2f}s") + + # If flip_test is enabled, we will return extra data for flipped test + self.flip_test = flip_test + if self.flip_test: + Log.info(f"[{self.dataset_name}] Flip test enabled") + + def __len__(self): + return len(self.idx2meta) + + def _load_data(self, idx): + data = {} + + # [vid, start, end] + vid, start, end = self.idx2meta[idx] + length = end - start + meta = {"dataset_id": self.dataset_id, "vid": vid, "vid-start-end": (start, end)} + data.update({"meta": meta, "length": length}) + + label = self.labels[vid] + + # smpl_params in world + gender = label["gender"] + smpl_params = label["smpl_params"] + mask = label["mask"] + data.update({"smpl_params": smpl_params, "gender": gender, "mask": mask}) + + # camera + # K_fullimg = label["K_fullimg"] # We use estimated K + width_height = (1440, 1920) if vid != "P0_09_outdoor_walk" else (720, 960) + K_fullimg = estimate_K(*width_height) + T_w2c = label["T_w2c"] + data.update({"K_fullimg": K_fullimg, "T_w2c": T_w2c}) + + # R_w2c -> cam_angvel + use_DPVO = False + if use_DPVO: + traj = self.cam_traj[data["meta"]["vid"]] # (L, 7) + R_w2c = quaternion_to_matrix(traj[:, [6, 3, 4, 5]]).mT # (L, 3, 3) + else: # GT + R_w2c = data["T_w2c"][:, :3, :3] # (L, 3, 3) + data["cam_angvel"] = compute_cam_angvel(R_w2c) # (L, 6) + + # image bbx, features + bbx_xys = label["bbx_xys"] + f_imgseq = label["features"] + kp2d = label["kp2d"] + data.update({"bbx_xys": bbx_xys, "f_imgseq": f_imgseq, "kp2d": kp2d}) + + # to render a video + video_path = self.emdb_dir / f"videos/{vid}.mp4" + frame_id = torch.where(mask)[0].long() + resize_factor = 0.5 + width_height_render = torch.tensor(width_height) * resize_factor + K_render = resize_K(K_fullimg, resize_factor) + bbx_xys_render = bbx_xys * resize_factor + data["meta_render"] = { + "split": self.split, + "name": vid, + "video_path": str(video_path), + "resize_factor": resize_factor, + "frame_id": frame_id, + "width_height": width_height_render.int(), + "K": K_render, + "bbx_xys": bbx_xys_render, + "R_cam_type": "DPVO" if use_DPVO else "GtGyro", + } + + # if enable flip_test + if self.flip_test: + imgfeat_dir = self.emdb_dir / "imgfeats/emdb_flip" + f_img_dict = torch.load(imgfeat_dir / f"{vid}.pt") + + flipped_bbx_xys = f_img_dict["bbx_xys"].float() # (L, 3) + flipped_features = f_img_dict["features"].float() # (L, 1024) + width = width_height[0] + flipped_kp2d = flip_kp2d_coco17(kp2d, width) # (L, 17, 3) + + R_flip_x = torch.tensor([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]).float() + flipped_R_w2c = R_flip_x @ R_w2c.clone() + + data_flip = { + "bbx_xys": flipped_bbx_xys, + "f_imgseq": flipped_features, + "kp2d": flipped_kp2d, + "cam_angvel": compute_cam_angvel(flipped_R_w2c), + } + data["flip_test"] = data_flip + + return data + + def _process_data(self, data): + length = data["length"] + data["K_fullimg"] = data["K_fullimg"][None].repeat(length, 1, 1) + return data + + def __getitem__(self, idx): + data = self._load_data(idx) + data = self._process_data(data) + return data + + +# EMDB-1 and EMDB-2 +MainStore.store( + name="v1", + node=builds(EmdbSmplFullSeqDataset, populate_full_signature=True), + group="test_datasets/emdb1", +) +MainStore.store( + name="v1_fliptest", + node=builds(EmdbSmplFullSeqDataset, flip_test=True, populate_full_signature=True), + group="test_datasets/emdb1", +) +MainStore.store( + name="v1", + node=builds(EmdbSmplFullSeqDataset, split=2, populate_full_signature=True), + group="test_datasets/emdb2", +) +MainStore.store( + name="v1_fliptest", + 
node=builds(EmdbSmplFullSeqDataset, split=2, flip_test=True, populate_full_signature=True), + group="test_datasets/emdb2", +) diff --git a/hmr4d/dataset/emdb/utils.py b/hmr4d/dataset/emdb/utils.py new file mode 100644 index 0000000..24320e9 --- /dev/null +++ b/hmr4d/dataset/emdb/utils.py @@ -0,0 +1,120 @@ +import torch +import pickle +import numpy as np +from pathlib import Path +from tqdm import tqdm +from hmr4d.utils.geo_transform import convert_lurb_to_bbx_xys +from hmr4d.utils.video_io_utils import get_video_lwh + + +def name_to_subfolder(name): + return f"{name[:2]}/{name[3:]}" + + +def name_to_local_pkl_path(name): + return f"{name_to_subfolder(name)}/{name}_data.pkl" + + +def load_raw_pkl(fp): + annot = pickle.load(open(fp, "rb")) + annot["subfolder"] = name_to_subfolder(annot["name"]) + return annot + + +def load_pkl(fp): + annot = pickle.load(open(fp, "rb")) + # ['gender', 'name', 'emdb1', 'emdb2', 'n_frames', 'good_frames_mask', 'camera', 'smpl', 'kp2d', 'bboxes', 'subfolder'] + data = {} + + F = annot["n_frames"] + smpl_params = { + "body_pose": annot["smpl"]["poses_body"], # (F, 69) + "betas": annot["smpl"]["betas"][None].repeat(F, axis=0), # (F, 10) + "global_orient": annot["smpl"]["poses_root"], # (F, 3) + "transl": annot["smpl"]["trans"], # (F, 3) + } + smpl_params = {k: torch.from_numpy(v).float() for k, v in smpl_params.items()} + + data["name"] = annot["name"] + data["gender"] = annot["gender"] + data["smpl_params"] = smpl_params + data["mask"] = torch.from_numpy(annot["good_frames_mask"]).bool() # (L,) + data["K_fullimg"] = torch.from_numpy(annot["camera"]["intrinsics"]).float() # (3, 3) + data["T_w2c"] = torch.from_numpy(annot["camera"]["extrinsics"]).float() # (L, 4, 4) + bbx_lurb = torch.from_numpy(annot["bboxes"]["bboxes"]).float() + data["bbx_xys"] = convert_lurb_to_bbx_xys(bbx_lurb) # (L, 3) + + return data + + +EMDB1_LIST = [ + "P1/14_outdoor_climb/P1_14_outdoor_climb_data.pkl", + "P2/23_outdoor_hug_tree/P2_23_outdoor_hug_tree_data.pkl", + "P3/31_outdoor_workout/P3_31_outdoor_workout_data.pkl", + "P3/32_outdoor_soccer_warmup_a/P3_32_outdoor_soccer_warmup_a_data.pkl", + "P3/33_outdoor_soccer_warmup_b/P3_33_outdoor_soccer_warmup_b_data.pkl", + "P5/42_indoor_dancing/P5_42_indoor_dancing_data.pkl", + "P5/44_indoor_rom/P5_44_indoor_rom_data.pkl", + "P6/49_outdoor_big_stairs_down/P6_49_outdoor_big_stairs_down_data.pkl", # DUPLICATE + "P6/50_outdoor_workout/P6_50_outdoor_workout_data.pkl", + "P6/51_outdoor_dancing/P6_51_outdoor_dancing_data.pkl", + "P7/57_outdoor_rock_chair/P7_57_outdoor_rock_chair_data.pkl", # DUPLICATE + "P7/59_outdoor_rom/P7_59_outdoor_rom_data.pkl", + "P7/60_outdoor_workout/P7_60_outdoor_workout_data.pkl", + "P8/64_outdoor_skateboard/P8_64_outdoor_skateboard_data.pkl", # DUPLICATE + "P8/68_outdoor_handstand/P8_68_outdoor_handstand_data.pkl", + "P8/69_outdoor_cartwheel/P8_69_outdoor_cartwheel_data.pkl", + "P9/76_outdoor_sitting/P9_76_outdoor_sitting_data.pkl", +] +EMDB1_NAMES = ["_".join(p.split("/")[:2]) for p in EMDB1_LIST] + + +EMDB2_LIST = [ + "P0/09_outdoor_walk/P0_09_outdoor_walk_data.pkl", + "P2/19_indoor_walk_off_mvs/P2_19_indoor_walk_off_mvs_data.pkl", + "P2/20_outdoor_walk/P2_20_outdoor_walk_data.pkl", + "P2/24_outdoor_long_walk/P2_24_outdoor_long_walk_data.pkl", + "P3/27_indoor_walk_off_mvs/P3_27_indoor_walk_off_mvs_data.pkl", + "P3/28_outdoor_walk_lunges/P3_28_outdoor_walk_lunges_data.pkl", + "P3/29_outdoor_stairs_up/P3_29_outdoor_stairs_up_data.pkl", + "P3/30_outdoor_stairs_down/P3_30_outdoor_stairs_down_data.pkl", + 
"P4/35_indoor_walk/P4_35_indoor_walk_data.pkl", + "P4/36_outdoor_long_walk/P4_36_outdoor_long_walk_data.pkl", + "P4/37_outdoor_run_circle/P4_37_outdoor_run_circle_data.pkl", + "P5/40_indoor_walk_big_circle/P5_40_indoor_walk_big_circle_data.pkl", + "P6/48_outdoor_walk_downhill/P6_48_outdoor_walk_downhill_data.pkl", + "P6/49_outdoor_big_stairs_down/P6_49_outdoor_big_stairs_down_data.pkl", # DUPLICATE + "P7/55_outdoor_walk/P7_55_outdoor_walk_data.pkl", + "P7/56_outdoor_stairs_up_down/P7_56_outdoor_stairs_up_down_data.pkl", + "P7/57_outdoor_rock_chair/P7_57_outdoor_rock_chair_data.pkl", # DUPLICATE + "P7/58_outdoor_parcours/P7_58_outdoor_parcours_data.pkl", + "P7/61_outdoor_sit_lie_walk/P7_61_outdoor_sit_lie_walk_data.pkl", + "P8/64_outdoor_skateboard/P8_64_outdoor_skateboard_data.pkl", # DUPLICATE + "P8/65_outdoor_walk_straight/P8_65_outdoor_walk_straight_data.pkl", + "P9/77_outdoor_stairs_up/P9_77_outdoor_stairs_up_data.pkl", + "P9/78_outdoor_stairs_up_down/P9_78_outdoor_stairs_up_down_data.pkl", + "P9/79_outdoor_walk_rectangle/P9_79_outdoor_walk_rectangle_data.pkl", + "P9/80_outdoor_walk_big_circle/P9_80_outdoor_walk_big_circle_data.pkl", +] +EMDB2_NAMES = ["_".join(p.split("/")[:2]) for p in EMDB2_LIST] +EMDB_NAMES = list(sorted(set(EMDB1_NAMES + EMDB2_NAMES))) + + +def _check_annot(emdb_raw_dir=Path("inputs/EMDB/EMDB")): + for pkl_local_path in set(EMDB1_LIST + EMDB2_LIST): + annot = load_raw_pkl(emdb_raw_dir / pkl_local_path) + if any((annot["bboxes"]["invalid_idxs"] != np.where(~annot["good_frames_mask"])[0])): + print(annot["name"]) + + +def _check_length(emdb_raw_dir=Path("inputs/EMDB/EMDB"), emdb_hmr4d_support_dir=Path("inputs/EMDB/hmr4d_support")): + lengths = [] + for local_pkl_path in tqdm(set(EMDB1_LIST + EMDB2_LIST)): + data = load_pkl(emdb_raw_dir / local_pkl_path) + video_path = emdb_hmr4d_support_dir / "videos" / f"{data['name']}.mp4" + length, width, height = get_video_lwh(video_path) + lengths.append(length) + print(sorted(lengths)) + + video_ram = length[-1] * (width / 4) * (height / 4) * 3 / 1e6 + print(f"Video RAM for {lengths[-1]} x {width} x {height}: {video_ram:.2f} MB") diff --git a/hmr4d/dataset/h36m/camera-parameters.json b/hmr4d/dataset/h36m/camera-parameters.json new file mode 100644 index 0000000..6e14797 --- /dev/null +++ b/hmr4d/dataset/h36m/camera-parameters.json @@ -0,0 +1,1452 @@ +{ + "intrinsics": { + "54138969": { + "calibration_matrix": [ + [ + 1145.04940458804, + 0.0, + 512.541504956548 + ], + [ + 0.0, + 1143.78109572365, + 515.4514869776 + ], + [ + 0.0, + 0.0, + 1.0 + ] + ], + "distortion": [ + -0.207098910824901, + 0.247775183068982, + -0.00142447157470321, + -0.000975698859470499, + -0.00307515035078854 + ] + }, + "55011271": { + "calibration_matrix": [ + [ + 1149.67569986785, + 0.0, + 508.848621645943 + ], + [ + 0.0, + 1147.59161666764, + 508.064917088557 + ], + [ + 0.0, + 0.0, + 1.0 + ] + ], + "distortion": [ + -0.194213629607385, + 0.240408539138292, + -0.0027408943961907, + -0.001619026613787, + 0.00681997559022603 + ] + }, + "58860488": { + "calibration_matrix": [ + [ + 1149.14071676148, + 0.0, + 519.815837182153 + ], + [ + 0.0, + 1148.7989685676, + 501.402658888552 + ], + [ + 0.0, + 0.0, + 1.0 + ] + ], + "distortion": [ + -0.208338188251856, + 0.255488007488945, + -0.000759999321030303, + 0.00148438698385668, + -0.00246049749891915 + ] + }, + "60457274": { + "calibration_matrix": [ + [ + 1145.51133842318, + 0.0, + 514.968197319863 + ], + [ + 0.0, + 1144.77392807652, + 501.882018537695 + ], + [ + 0.0, + 0.0, + 1.0 + ] + ], + "distortion": [ + 
-0.198384093827848, + 0.218323676298049, + -0.00181336200488089, + -0.000587205583421232, + -0.00894780704152122 + ] + } + }, + "extrinsics": { + "S1": { + "54138969": { + "R": [ + [ + -0.9153617321513369, + 0.40180836633680234, + 0.02574754463350265 + ], + [ + 0.051548117060134555, + 0.1803735689384521, + -0.9822464900705729 + ], + [ + -0.399319034032262, + -0.8977836111057917, + -0.185819527201491 + ] + ], + "t": [ + [ + -346.05078140028075 + ], + [ + 546.9807793144001 + ], + [ + 5474.481087434061 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9281683400814921, + 0.3721538354721445, + 0.002248380248018696 + ], + [ + 0.08166409428175585, + -0.1977722953267526, + -0.976840363061605 + ], + [ + -0.3630902204349604, + 0.9068559102440475, + -0.21395758897485287 + ] + ], + "t": [ + [ + 251.42516271750836 + ], + [ + 420.9422103702068 + ], + [ + 5588.195881837821 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9141549520542256, + -0.4027780222811878, + -0.045722952682337906 + ], + [ + -0.04562341383935875, + 0.21430849526487267, + -0.9756999400261069 + ], + [ + 0.40278930937200774, + -0.889854894701693, + -0.214287280609606 + ] + ], + "t": [ + [ + 480.482559565337 + ], + [ + 253.83237471361554 + ], + [ + 5704.2076793704555 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9141562410494211, + -0.40060705854636447, + 0.061905989962380774 + ], + [ + -0.05641000739510571, + -0.2769531972942539, + -0.9592261660183036 + ], + [ + 0.40141783470104664, + 0.8733904688919611, + -0.2757767409202658 + ] + ], + "t": [ + [ + 51.88347637559197 + ], + [ + 378.4208425426766 + ], + [ + 4406.149140878431 + ] + ] + } + }, + "S2": { + "54138969": { + "R": [ + [ + -0.9072826056858586, + 0.4200536513985309, + 0.019829356183203237 + ], + [ + 0.06404223092375372, + 0.18462275321422528, + -0.9807206695353717 + ], + [ + -0.4156162485733534, + -0.8885208882982778, + -0.1944061855483302 + ] + ], + "t": [ + [ + -253.9473271477662 + ], + [ + 543.369692173605 + ], + [ + 5522.981999493327 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9195695689704942, + 0.3926824530407384, + 0.013867187794489123 + ], + [ + 0.09616327770610274, + -0.190692439252443, + -0.9769283584955307 + ], + [ + -0.38097825639298405, + 0.8996871037676718, + -0.21311659595137136 + ] + ], + "t": [ + [ + 123.3506735789221 + ], + [ + 401.02404156275884 + ], + [ + 5743.522551411228 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9231022562305128, + -0.3793547679556717, + -0.06302526930870815 + ], + [ + -0.023520852900409527, + 0.21928184512961552, + -0.9753779994829639 + ], + [ + 0.3838345920067314, + -0.898891223911909, + -0.2113423136836923 + ] + ], + "t": [ + [ + 498.7689000990772 + ], + [ + 278.0695777621727 + ], + [ + 5618.721192968872 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9239917699501332, + -0.37272063182115767, + 0.08554846392108466 + ], + [ + -0.01857104155727153, + -0.2671779087245581, + -0.9634682566151569 + ], + [ + 0.38196115703026423, + 0.8886480156419687, + -0.2537919991167828 + ] + ], + "t": [ + [ + -55.1478742462578 + ], + [ + 424.8747833741909 + ], + [ + 4452.137526291175 + ] + ] + } + }, + "S3": { + "54138969": { + "R": [ + [ + -0.909926063968229, + 0.4142842734534348, + 0.020077322541766036 + ], + [ + 0.06112258570603725, + 0.18181129378483157, + -0.9814319553432596 + ], + [ + -0.41024210855042, + -0.891803338310328, + -0.19075696094942407 + ] + ], + "t": [ + [ + -144.30406670344493 + ], + [ + 546.2767112872957 + ], + [ + 5569.530692348755 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9248703521034336, + 0.3800681977315835, + 
0.012767022876799783 + ], + [ + 0.093795468138089, + -0.1954524371286302, + -0.9762175756342618 + ], + [ + -0.3685339088290622, + 0.9040721817938792, + -0.21641683887726407 + ] + ], + "t": [ + [ + -38.93379836342622 + ], + [ + 375.57502666735104 + ], + [ + 5759.402838804998 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9218827889823751, + -0.38260686272952316, + -0.061189149122614306 + ], + [ + -0.02577019492115, + 0.21811470471458455, + -0.9755828374059251 + ], + [ + 0.3866109419452632, + -0.897796170731164, + -0.2109360457310579 + ] + ], + "t": [ + [ + 596.8162203909545 + ], + [ + 282.123966506171 + ], + [ + 5575.726600786697 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9244960445794738, + -0.37161308683612865, + 0.08491629554468147 + ], + [ + -0.018795005038688972, + -0.26693214570791374, + -0.963532032354589 + ], + [ + 0.3807280017840865, + 0.88918555053481, + -0.2537621827176058 + ] + ], + "t": [ + [ + -158.57266932864025 + ], + [ + 433.1881250816 + ], + [ + 4413.555688648984 + ] + ] + } + }, + "S4": { + "54138969": { + "R": [ + [ + -0.906169211683753, + 0.422346184383899, + 0.021933087625945674 + ], + [ + 0.06180306305120707, + 0.18355044391174938, + -0.9810655512947585 + ], + [ + -0.4183751201899252, + -0.8876558652294037, + -0.19243004892662768 + ] + ], + "t": [ + [ + -201.25197932223173 + ], + [ + 537.4605027947064 + ], + [ + 5553.966756732112 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9205073288493492, + 0.39058428754662783, + 0.010496278213208041 + ], + [ + 0.0923916650578188, + -0.1914846595009468, + -0.9771373523735801 + ], + [ + -0.3796446203523497, + 0.900431862773358, + -0.21234976510469855 + ] + ], + "t": [ + [ + 63.12322044876507 + ], + [ + 396.6138950755392 + ], + [ + 5760.7235858284985 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9244800422436603, + -0.37641653359695837, + -0.060392422769829215 + ], + [ + -0.02551533125481826, + 0.2191523935220463, + -0.9753569583924211 + ], + [ + 0.3803756292983513, + -0.9001571094250204, + -0.21220640656557746 + ] + ], + "t": [ + [ + 559.4298619884164 + ], + [ + 278.041710381495 + ], + [ + 5601.2846874450925 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9241606780958346, + -0.3729066880542538, + 0.0828712439019392 + ], + [ + -0.021270464031387784, + -0.2668345784720987, + -0.9635075895349796 + ], + [ + 0.38141133756265905, + 0.8886731174824772, + -0.2545299232755129 + ] + ], + "t": [ + [ + -98.61477305435534 + ], + [ + 432.68486951797627 + ], + [ + 4419.390974448715 + ] + ] + } + }, + "S5": { + "54138969": { + "R": [ + [ + -0.9042074184788829, + 0.42657831374650107, + 0.020973473936051274 + ], + [ + 0.06390493744399675, + 0.18368565260974637, + -0.9809055713959477 + ], + [ + -0.4222855708380685, + -0.8856017859436166, + -0.1933503902128034 + ] + ], + "t": [ + [ + -219.3059666108619 + ], + [ + 544.4787497640639 + ], + [ + 5518.740477016156 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9222116004775194, + 0.38649075753002626, + 0.012274293810989732 + ], + [ + 0.09333184463870337, + -0.19167233853095322, + -0.9770111982052265 + ], + [ + -0.3752531555110883, + 0.902156643264318, + -0.21283434941998647 + ] + ], + "t": [ + [ + 103.90282067751986 + ], + [ + 395.67169468951965 + ], + [ + 5767.97265758172 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9258288614330635, + -0.3728674116124112, + -0.06173178026768599 + ], + [ + -0.023578112500148365, + 0.220000562347259, + -0.9752147584905696 + ], + [ + 0.3772068291381898, + -0.9014264506460582, + -0.21247437993123308 + ] + ], + "t": [ + [ + 520.3272318446208 + ], + [ + 283.3690958234795 + 
], + [ + 5591.123958858676 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9222815489764817, + -0.3772688722588351, + 0.0840532119677073 + ], + [ + -0.021177649402562934, + -0.26645871124348197, + -0.9636136478735888 + ], + [ + 0.3859381447632816, + 0.88694303832152, + -0.25373962085111357 + ] + ], + "t": [ + [ + -79.116431351199 + ], + [ + 425.59047114848386 + ], + [ + 4454.481629705836 + ] + ] + } + }, + "S6": { + "54138969": { + "R": [ + [ + -0.9149503344107554, + 0.4034864343564006, + 0.008036345687245266 + ], + [ + 0.07174776353922047, + 0.1822275975157708, + -0.9806351824867137 + ], + [ + -0.3971374371533952, + -0.896655898321083, + -0.19567845056940925 + ] + ], + "t": [ + [ + -239.5182864132218 + ], + [ + 545.8141831785044 + ], + [ + 5523.931578633363 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9197364689900042, + 0.39209901596964664, + 0.018525368698999664 + ], + [ + 0.101478073351267, + -0.19191459963948, + -0.9761511087296542 + ], + [ + -0.37919260045353465, + 0.899681692667386, + -0.21630030892357308 + ] + ], + "t": [ + [ + 169.02510061389722 + ], + [ + 409.6671223380997 + ], + [ + 5714.338002825065 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.916577698818659, + -0.39393483656788014, + -0.06856140726771254 + ], + [ + -0.01984531630322392, + 0.21607069980297702, + -0.9761760169700323 + ], + [ + 0.3993638509543854, + -0.8933805444629346, + -0.20586334624209834 + ] + ], + "t": [ + [ + 521.9864793089763 + ], + [ + 286.28272817103516 + ], + [ + 5643.2724406159 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9182950552949388, + -0.3850769011116475, + 0.09192372735651859 + ], + [ + -0.015534985886560007, + -0.26706146429979655, + -0.9635542737695438 + ], + [ + 0.3955917790277871, + 0.8833990913037544, + -0.25122338635033875 + ] + ], + "t": [ + [ + -56.29675276801464 + ], + [ + 420.29579722027506 + ], + [ + 4499.322693551688 + ] + ] + } + }, + "S7": { + "54138969": { + "R": [ + [ + -0.9055764231419416, + 0.42392653746206904, + 0.014752378956221508 + ], + [ + 0.06862812683752326, + 0.18074371881263407, + -0.9811329615890764 + ], + [ + -0.41859469903024304, + -0.8874784498483331, + -0.19277053457045695 + ] + ], + "t": [ + [ + -323.9118424584857 + ], + [ + 541.7715234126381 + ], + [ + 5506.569132699328 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9212640765077017, + 0.3886011826562522, + 0.01617473877914905 + ], + [ + 0.09922277503271489, + -0.1946115441987536, + -0.9758489574618522 + ], + [ + -0.3760682680727248, + 0.9006194910741931, + -0.21784671226815075 + ] + ], + "t": [ + [ + 178.6238708832376 + ], + [ + 403.59193467821774 + ], + [ + 5694.8801003668095 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9245069728829368, + -0.37555597339631824, + -0.06515034871105972 + ], + [ + -0.018955014220249332, + 0.21601110989507338, + -0.9762068980691586 + ], + [ + 0.38069353097569036, + -0.9012751584550871, + -0.20682244613440448 + ] + ], + "t": [ + [ + 441.1064712697594 + ], + [ + 271.91614362573955 + ], + [ + 5660.120611352617 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9228353966173104, + -0.3744001545228767, + 0.09055029013436408 + ], + [ + -0.014982084363704698, + -0.269786590656035, + -0.9628035794752281 + ], + [ + 0.3849030629889691, + 0.8871525910436372, + -0.25457791009093983 + ] + ], + "t": [ + [ + 25.768533743836343 + ], + [ + 431.05581759025813 + ], + [ + 4461.872981411145 + ] + ] + } + }, + "S8": { + "54138969": { + "R": [ + [ + -0.9115694669712032, + 0.4106494283805017, + 0.020202818036194434 + ], + [ + 0.060907749548984036, + 0.1834736632003901, + -0.9811359034082424 + ], + [ + 
-0.40660958293025334, + -0.8931430243150293, + -0.19226072190306673 + ] + ], + "t": [ + [ + -82.70216069652597 + ], + [ + 552.1896311377282 + ], + [ + 5557.353609418419 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.931016282525616, + 0.3647626932499711, + 0.01252434769597448 + ], + [ + 0.08939715221301257, + -0.19463753190599434, + -0.9767929055586687 + ], + [ + -0.35385990285476776, + 0.9105297407479727, + -0.2138194574051759 + ] + ], + "t": [ + [ + -209.06289992510443 + ], + [ + 375.0691429434037 + ], + [ + 5818.276676972416 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9209075762929309, + -0.3847355178017309, + -0.0625125368875214 + ], + [ + -0.02568138180824641, + 0.21992027027623712, + -0.9751797482259595 + ], + [ + 0.38893405939143305, + -0.8964450100611084, + -0.21240678280563546 + ] + ], + "t": [ + [ + 623.0985110132146 + ], + [ + 290.9053651845054 + ], + [ + 5534.379001592981 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.927667052235436, + -0.3636062759574404, + 0.08499597802942535 + ], + [ + -0.01666268768012713, + -0.26770413351564454, + -0.9633570738505596 + ], + [ + 0.37303645269074087, + 0.8922583555131325, + -0.2543989622245125 + ] + ], + "t": [ + [ + -178.36705625795474 + ], + [ + 423.4669232560848 + ], + [ + 4421.6448791590965 + ] + ] + } + }, + "S9": { + "54138969": { + "R": [ + [ + -0.9033486204435297, + 0.4269119782787646, + 0.04132109321984796 + ], + [ + 0.04153061098352977, + 0.182951140059007, + -0.9822444139329296 + ], + [ + -0.4268916470184284, + -0.8855930460167476, + -0.18299857527497945 + ] + ], + "t": [ + [ + -321.2078335720134 + ], + [ + 467.13452033013084 + ], + [ + 5514.330338522134 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9315720471487059, + 0.36348288012373176, + -0.007329176497134756 + ], + [ + 0.06810069482701912, + -0.19426747906725159, + -0.9785818524481906 + ], + [ + -0.35712157080642226, + 0.911120377575769, + -0.20572758986325015 + ] + ], + "t": [ + [ + 19.193095609487138 + ], + [ + 404.22842728571936 + ], + [ + 5702.169280033924 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9269344193869241, + -0.3732303525241731, + -0.03862235247246717 + ], + [ + -0.04725991098820678, + 0.218240494552814, + -0.9747500127472326 + ], + [ + 0.37223525218497616, + -0.901704048173249, + -0.21993345934341726 + ] + ], + "t": [ + [ + 455.40107288876885 + ], + [ + 273.3589338272866 + ], + [ + 5657.814488280711 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.915460708083783, + -0.39734606500700814, + 0.06362229623477154 + ], + [ + -0.04940628468469528, + -0.26789167566119776, + -0.9621814117644814 + ], + [ + 0.39936288133525055, + 0.8776959352388969, + -0.26487569589663096 + ] + ], + "t": [ + [ + -69.271255294384 + ], + [ + 422.1843366088847 + ], + [ + 4457.893374979773 + ] + ] + } + }, + "S10": { + "54138969": { + "R": [ + [ + -0.9199955359932982, + 0.39133749168985454, + 0.021521648410310328 + ], + [ + 0.0555185840851712, + 0.18448351869097226, + -0.9812662829999691 + ], + [ + -0.3879766752957989, + -0.9015657485337887, + -0.19145051709809383 + ] + ], + "t": [ + [ + -181.4625993368258 + ], + [ + 543.5199110634021 + ], + [ + 5582.194377534298 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9152587269115653, + 0.40266346194010966, + 0.01279059148853104 + ], + [ + 0.09843295287698457, + -0.1927270179143742, + -0.9763028476624197 + ], + [ + -0.3906563919867918, + 0.8948287171209015, + -0.21603043863220686 + ] + ], + "t": [ + [ + -22.5707386911355 + ], + [ + 383.7773845053516 + ], + [ + 5727.149101385447 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9117691356172892, + 
-0.4060874893594546, + -0.06140027948988781 + ], + [ + -0.03165845257462336, + 0.21854812554171174, + -0.9753124931029975 + ], + [ + 0.4094811176553588, + -0.887315990956964, + -0.21212153703897946 + ] + ], + "t": [ + [ + 579.9870562891809 + ], + [ + 276.09388439709664 + ], + [ + 5616.656671116378 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.9374925472123639, + -0.3377263929586908, + 0.08395598501825681 + ], + [ + -0.009787543064644189, + -0.2667415901136707, + -0.9637183863060765 + ], + [ + 0.3478676873784507, + 0.9026570819545717, + -0.25337376437829157 + ] + ], + "t": [ + [ + -72.59483557976097 + ], + [ + 445.63607020105314 + ], + [ + 4402.73689876101 + ] + ] + } + }, + "S11": { + "54138969": { + "R": [ + [ + -0.9059013006181885, + 0.4217144115102914, + 0.038727105014486805 + ], + [ + 0.044493184429779696, + 0.1857199061874203, + -0.9815948619389944 + ], + [ + -0.4211450938543295, + -0.8875049698848251, + -0.1870073216538954 + ] + ], + "t": [ + [ + -234.7208032216618 + ], + [ + 464.34018262882194 + ], + [ + 5536.652631113797 + ] + ] + }, + "55011271": { + "R": [ + [ + 0.9216646531492915, + 0.3879848687925067, + -0.0014172943441045224 + ], + [ + 0.07721054863099915, + -0.18699239961454955, + -0.979322405373477 + ], + [ + -0.3802272982247548, + 0.9024974149959955, + -0.20230080971229314 + ] + ], + "t": [ + [ + -11.934348472090557 + ], + [ + 449.4165893644565 + ], + [ + 5541.113551868937 + ] + ] + }, + "58860488": { + "R": [ + [ + -0.9063540572469627, + -0.42053101768163204, + -0.04093880896680188 + ], + [ + -0.0603212197838846, + 0.22468715090881142, + -0.9725620980997899 + ], + [ + 0.4181909532208387, + -0.8790161246439863, + -0.2290130547809762 + ] + ], + "t": [ + [ + 781.127357651581 + ], + [ + 235.3131620173424 + ], + [ + 5576.37044019807 + ] + ] + }, + "60457274": { + "R": [ + [ + 0.91754082476548, + -0.39226322025776267, + 0.06517975852741943 + ], + [ + -0.04531905395586976, + -0.26600517028098103, + -0.9629057236990188 + ], + [ + 0.395050652748768, + 0.8805514269006645, + -0.2618476013752581 + ] + ], + "t": [ + [ + -155.13650339749012 + ], + [ + 422.16256306729633 + ], + [ + 4435.416222660868 + ] + ] + } + } + } + } \ No newline at end of file diff --git a/hmr4d/dataset/h36m/h36m.py b/hmr4d/dataset/h36m/h36m.py new file mode 100644 index 0000000..5370b0d --- /dev/null +++ b/hmr4d/dataset/h36m/h36m.py @@ -0,0 +1,205 @@ +import torch +import numpy as np +from pathlib import Path +from hmr4d.configs import MainStore, builds + +from hmr4d.utils.pylogger import Log +from hmr4d.dataset.imgfeat_motion.base_dataset import ImgfeatMotionDatasetBase +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle +from hmr4d.utils import matrix +from hmr4d.utils.smplx_utils import make_smplx +from tqdm import tqdm + +from hmr4d.utils.geo_transform import compute_cam_angvel, apply_T_on_points +from hmr4d.utils.geo.hmr_global import get_tgtcoord_rootparam, get_T_w2c_from_wcparams, get_c_rootparam, get_R_c2gv + +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines +from hmr4d.utils.vis.renderer import Renderer +import imageio +from hmr4d.utils.video_io_utils import read_video_np +from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict + + +class H36mSmplDataset(ImgfeatMotionDatasetBase): + def __init__( + self, + root="inputs/H36M/hmr4d_support", + original_coord="az", + motion_frames=120, # H36M's videos are 25fps and very long + lazy_load=False, + ): + # Path + self.root = Path(root) + + # Coord + self.original_coord = 
original_coord + + # Setting + self.motion_frames = motion_frames + self.lazy_load = lazy_load + + super().__init__() + + def _load_dataset(self): + # smplpose + tic = Log.time() + fn = self.root / "smplxpose_v1.pt" + self.smpl_model = make_smplx("supermotion") + Log.info(f"[H36M] Loading from {fn} ...") + self.motion_files = torch.load(fn) + # Dict of { + # "smpl_params_glob": {'body_pose', 'global_orient', 'transl', 'betas'}, FxC + # "cam_Rt": tensor(F, 3), + # "cam_K": tensor(1, 10), + # } + self.seqs = list(self.motion_files.keys()) + Log.info(f"[H36M] {len(self.seqs)} sequences. Elapsed: {Log.time() - tic:.2f}s") + + # img(as feature) + # vid -> (features, vid, meta {bbx_xys, K_fullimg}) + if not self.lazy_load: + tic = Log.time() + fn = self.root / "vitfeat_h36m.pt" + Log.info(f"[H36M] Fully Loading to RAM ViT-Feat: {fn}") + self.f_img_dicts = torch.load(fn) + Log.info(f"[H36M] Finished. Elapsed: {Log.time() - tic:.2f}s") + else: + raise NotImplementedError # "Check BEDLAM-SMPL for lazy_load" + + def _get_idx2meta(self): + # We expect to see the entire sequence during one epoch, + # so each sequence will be sampled max(SeqLength // MotionFrames, 1) times + seq_lengths = [] + self.idx2meta = [] + for vid in self.f_img_dicts: + seq_length = self.f_img_dicts[vid]["bbx_xys"].shape[0] + num_samples = max(seq_length // self.motion_frames, 1) + seq_lengths.append(seq_length) + self.idx2meta.extend([vid] * num_samples) + hours = sum(seq_lengths) / 25 / 3600 + Log.info(f"[H36M] has {hours:.1f} hours motion -> Resampled to {len(self.idx2meta)} samples.") + + def _load_data(self, idx): + sampled_motion = {} + vid = self.idx2meta[idx] + motion = self.motion_files[vid] + seq_length = self.f_img_dicts[vid]["bbx_xys"].shape[0] # this is a better choice + sampled_motion["vid"] = vid + + # Random select a subset + target_length = self.motion_frames + if target_length > seq_length: # this should not happen + start = 0 + length = seq_length + Log.info(f"[H36M] ({idx}) target length < sequence length: {target_length} <= {seq_length}") + else: + start = np.random.randint(0, seq_length - target_length) + length = target_length + end = start + length + sampled_motion["length"] = length + sampled_motion["start_end"] = (start, end) + + # Select motion subset + # body_pose, global_orient, transl, betas + sampled_motion["smpl_params_global"] = {k: v[start:end] for k, v in motion["smpl_params_glob"].items()} + + # Image as feature + f_img_dict = self.f_img_dicts[vid] + sampled_motion["f_imgseq"] = f_img_dict["features"][start:end].float() # (L, 1024) + sampled_motion["bbx_xys"] = f_img_dict["bbx_xys"][start:end] + sampled_motion["K_fullimg"] = f_img_dict["K_fullimg"] + # sampled_motion["kp2d"] = self.vitpose[vid][start:end].float() # (L, 17, 3) + sampled_motion["kp2d"] = torch.zeros((end - start), 17, 3) # (L, 17, 3) + + # Camera + sampled_motion["T_w2c"] = motion["cam_Rt"] # (4, 4) + + return sampled_motion + + def _process_data(self, data, idx): + length = data["length"] + + # SMPL params in world + smpl_params_w = data["smpl_params_global"].copy() # in az + + # SMPL params in cam + T_w2c = data["T_w2c"] # (4, 4) + offset = self.smpl_model.get_skeleton(smpl_params_w["betas"][0])[0] # (3) + global_orient_c, transl_c = get_c_rootparam( + smpl_params_w["global_orient"], + smpl_params_w["transl"], + T_w2c, + offset, + ) + smpl_params_c = { + "body_pose": smpl_params_w["body_pose"].clone(), # (F, 63) + "betas": smpl_params_w["betas"].clone(), # (F, 10) + "global_orient": global_orient_c, # (F, 3) + "transl": 
transl_c, # (F, 3) + } + + # World params + gravity_vec = torch.tensor([0, 0, -1]).float() # (3), H36M is az + T_w2c = T_w2c.repeat(length, 1, 1) # (F, 4, 4) + R_c2gv = get_R_c2gv(T_w2c[..., :3, :3], axis_gravity_in_w=gravity_vec) # (F, 3, 3) + + # Image + bbx_xys = data["bbx_xys"] # (F, 3) + K_fullimg = data["K_fullimg"].repeat(length, 1, 1) # (F, 3, 3) + f_imgseq = data["f_imgseq"] # (F, 1024) + cam_angvel = compute_cam_angvel(T_w2c[:, :3, :3]) # (F, 6) slightly different from WHAM + + # Returns: do not forget to make it batchable! (last lines) + max_len = self.motion_frames + return_data = { + "meta": {"data_name": "h36m", "idx": idx, "vid": data["vid"]}, + "length": length, + "smpl_params_c": smpl_params_c, + "smpl_params_w": smpl_params_w, + "R_c2gv": R_c2gv, # (F, 3, 3) + "gravity_vec": gravity_vec, # (3) + "bbx_xys": bbx_xys, # (F, 3) + "K_fullimg": K_fullimg, # (F, 3, 3) + "f_imgseq": f_imgseq, # (F, D) + "kp2d": data["kp2d"], # (F, 17, 3) + "cam_angvel": cam_angvel, # (F, 6) + "mask": { + "valid": get_valid_mask(max_len, length), + "vitpose": False, + "bbx_xys": True, + "f_imgseq": True, + "spv_incam_only": False, + }, + } + + if False: # Render to image to check + smplx_out = self.smplx(**smpl_params_c) + # ----- Overlay ----- # + mid = return_data["meta"]["mid"] + video_path = self.root / f"videos/{mid}.mp4" + images = read_video_np(video_path, data["start_end"][0], data["start_end"][1]) + render_dict = { + "K": K_fullimg[:1], # only support batch size 1 + "faces": self.smplx.faces, + "verts": smplx_out.vertices, + "background": images, + } + img_overlay = simple_render_mesh_background(render_dict) + save_video(img_overlay, f"tmp.mp4") + + # Batchable + return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len) + return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len) + return_data["R_c2gv"] = repeat_to_max_len(return_data["R_c2gv"], max_len) + return_data["bbx_xys"] = repeat_to_max_len(return_data["bbx_xys"], max_len) + return_data["K_fullimg"] = repeat_to_max_len(return_data["K_fullimg"], max_len) + return_data["f_imgseq"] = repeat_to_max_len(return_data["f_imgseq"], max_len) + return_data["kp2d"] = repeat_to_max_len(return_data["kp2d"], max_len) + return_data["cam_angvel"] = repeat_to_max_len(return_data["cam_angvel"], max_len) + + return return_data + + +group_name = "train_datasets/imgfeat_h36m" +node_v1 = builds(H36mSmplDataset) +MainStore.store(name="v1", node=node_v1, group=group_name) diff --git a/hmr4d/dataset/h36m/utils.py b/hmr4d/dataset/h36m/utils.py new file mode 100644 index 0000000..8709126 --- /dev/null +++ b/hmr4d/dataset/h36m/utils.py @@ -0,0 +1,82 @@ +import json +import numpy as np +from pathlib import Path +from collections import defaultdict +import pickle +import torch + +RESOURCE_FOLDER = Path(__file__).resolve().parent / "resource" + +camera_idx_to_name = {0: "54138969", 1: "55011271", 2: "58860488", 3: "60457274"} + + +def get_vid(pkl_path, cam_id): + """.../S6/Posing 1.pkl, 54138969 -> S6@Posing_1@54138969""" + sub_id, fn = pkl_path.split("/")[-2:] + vid = f"{sub_id}@{fn.split('.')[0].replace(' ', '_')}@{cam_id}" + return vid + + +def get_raw_pkl_paths(h36m_raw_root): + smpl_param_dir = h36m_raw_root / "neutrSMPL_H3.6" + pkl_paths = [] + for train_sub in ["S1", "S5", "S6", "S7", "S8"]: + for pth in (smpl_param_dir / train_sub).glob("*.pkl"): + if "aligned" not in str(pth): # Use world sequence only + pkl_paths.append(str(pth)) + + return pkl_paths + + +def get_cam_KRts(): + """ + 
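+ Load the Human3.6M camera calibration stored in resource/camera-parameters.json:
+ intrinsics per camera id and extrinsics per (subject, camera), with t divided by 1000 (mm -> m).
+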
Returns: + Ks (torch.Tensor): {cam_id: 3x3} + Rts (torch.Tensor): {subj_id: {cam_id: 4x4}} + """ + # this file is copied from https://github.com/karfly/human36m-camera-parameters + cameras_path = RESOURCE_FOLDER / "camera-parameters.json" + with open(cameras_path, "r") as f: + cameras = json.load(f) + + # 4 camera ids: '54138969', '55011271', '58860488', '60457274' + Ks = {} + for cam in cameras["intrinsics"]: + Ks[cam] = torch.tensor(cameras["intrinsics"][cam]["calibration_matrix"]).float() + + # extrinsics + extrinsics = cameras["extrinsics"] + Rts = defaultdict(dict) + for subj in extrinsics: + for cam in extrinsics[subj]: + Rt = torch.eye(4) + Rt[:3, :3] = torch.tensor(extrinsics[subj][cam]["R"]) + Rt[:3, [3]] = torch.tensor(extrinsics[subj][cam]["t"]) / 1000 + Rts[subj][cam] = Rt.float() + + return Ks, Rts + + +def parse_raw_pkl(pkl_path, to_50hz=True): + """ + raw_pkl @ 200Hz, where video @ 50Hz. + the frames should be divided by 4, and mannually align with the video. + """ + with open(str(pkl_path), "rb") as f: + data = pickle.load(f, encoding="bytes") + poses = torch.from_numpy(data[b"poses"]).float() + betas = torch.from_numpy(data[b"betas"]).float() + trans = torch.from_numpy(data[b"trans"]).float() + assert poses.shape[0] == trans.shape[0] + if to_50hz: + poses = poses[::4] + trans = trans[::4] + + seq_length = poses.shape[0] # 50FPS + smpl_params = { + "body_pose": poses[:, 3:], + "betas": betas[None].expand(seq_length, -1), + "global_orient": poses[:, :3], + "transl": trans, + } + return smpl_params diff --git a/hmr4d/dataset/imgfeat_motion/base_dataset.py b/hmr4d/dataset/imgfeat_motion/base_dataset.py new file mode 100644 index 0000000..4df66df --- /dev/null +++ b/hmr4d/dataset/imgfeat_motion/base_dataset.py @@ -0,0 +1,32 @@ +import torch +from torch.utils import data +import numpy as np +from pathlib import Path +from hmr4d.utils.pylogger import Log + + +class ImgfeatMotionDatasetBase(data.Dataset): + def __init__(self): + super().__init__() + self._load_dataset() + self._get_idx2meta() # -> Set self.idx2meta + + def __len__(self): + return len(self.idx2meta) + + def _load_dataset(self): + raise NotImplemented + + def _get_idx2meta(self): + raise NotImplemented + + def _load_data(self, idx): + raise NotImplemented + + def _process_data(self, data, idx): + raise NotImplemented + + def __getitem__(self, idx): + data = self._load_data(idx) + data = self._process_data(data, idx) + return data diff --git a/hmr4d/dataset/pure_motion/amass.py b/hmr4d/dataset/pure_motion/amass.py new file mode 100644 index 0000000..5070bcc --- /dev/null +++ b/hmr4d/dataset/pure_motion/amass.py @@ -0,0 +1,119 @@ +import torch +import torch.nn.functional as F +import numpy as np + +from tqdm import tqdm +from pathlib import Path +from hmr4d.utils.pylogger import Log +from hmr4d.configs import MainStore, builds + +from .base_dataset import BaseDataset +from .utils import * +from hmr4d.utils.geo.hmr_global import get_tgtcoord_rootparam +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines, convert_motion_as_line_mesh + + +class AmassDataset(BaseDataset): + def __init__( + self, + motion_frames=120, + l_factor=1.5, # speed augmentation + skip_moyo=True, # not contained in the ICCV19 released version + cam_augmentation="v11", + random1024=False, # DEBUG + limit_size=None, + ): + self.root = Path("inputs/AMASS/hmr4d_support") + self.motion_frames = motion_frames + self.l_factor = l_factor + self.random1024 = random1024 + self.skip_moyo = skip_moyo + self.dataset_name = "AMASS" + 
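+ # BaseDataset.__init__ (pure_motion/base_dataset.py) stores cam_augmentation and limit_size,
+ # builds the SMPL-X body models, and then calls self._load_dataset() and self._get_idx2meta()
+ # implemented below.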
super().__init__(cam_augmentation, limit_size) + + def _load_dataset(self): + filename = self.root / "smplxpose_v2.pth" + Log.info(f"[{self.dataset_name}] Loading from {filename} ...") + tic = Log.time() + if self.random1024: # Debug, faster loading + try: + Log.info(f"[{self.dataset_name}] Loading 1024 samples for debugging ...") + self.motion_files = torch.load(self.root / "smplxpose_v2_random1024.pth") + except: + Log.info(f"[{self.dataset_name}] Not found! Saving 1024 samples for debugging ...") + self.motion_files = torch.load(filename) + keys = list(self.motion_files.keys()) + keys = np.random.choice(keys, 1024, replace=False) + self.motion_files = {k: self.motion_files[k] for k in keys} + torch.save(self.motion_files, self.root / "smplxpose_v2_random1024.pth") + else: + self.motion_files = torch.load(filename) + self.seqs = list(self.motion_files.keys()) + Log.info(f"[{self.dataset_name}] {len(self.seqs)} sequences. Elapsed: {Log.time() - tic:.2f}s") + + def _get_idx2meta(self): + # We expect to see the entire sequence during one epoch, + # so each sequence will be sampled max(SeqLength // MotionFrames, 1) times + seq_lengths = [] + self.idx2meta = [] + + # Skip too-long idle-prefix + motion_start_id = {} + for vid in self.motion_files: + if self.skip_moyo and "moyo_smplxn" in vid: + continue + seq_length = self.motion_files[vid]["pose"].shape[0] + start_id = motion_start_id[vid] if vid in motion_start_id else 0 + seq_length = seq_length - start_id + if seq_length < 25: # Skip clips that are too short + continue + num_samples = max(seq_length // self.motion_frames, 1) + seq_lengths.append(seq_length) + self.idx2meta.extend([(vid, start_id)] * num_samples) + hours = sum(seq_lengths) / 30 / 3600 + Log.info(f"[{self.dataset_name}] has {hours:.1f} hours motion -> Resampled to {len(self.idx2meta)} samples.") + + def _load_data(self, idx): + """ + - Load original data + - Augmentation: speed-augmentation to L frames + """ + # Load original data + mid, start_id = self.idx2meta[idx] + raw_data = self.motion_files[mid] + raw_len = raw_data["pose"].shape[0] - start_id + data = { + "body_pose": raw_data["pose"][start_id:, 3:], # (F, 63) + "betas": raw_data["beta"].repeat(raw_len, 1), # (10) + "global_orient": raw_data["pose"][start_id:, :3], # (F, 3) + "transl": raw_data["trans"][start_id:], # (F, 3) + } + + # Get {tgt_len} frames from data + # Random select a subset with speed augmentation [start, end) + tgt_len = self.motion_frames + raw_subset_len = np.random.randint(int(tgt_len / self.l_factor), int(tgt_len * self.l_factor)) + if raw_subset_len <= raw_len: + start = np.random.randint(0, raw_len - raw_subset_len + 1) + end = start + raw_subset_len + else: # interpolation will use all possible frames (results in a slow motion) + start = 0 + end = raw_len + data = {k: v[start:end] for k, v in data.items()} + + # Interpolation (vec + r6d) + data_interpolated = interpolate_smpl_params(data, tgt_len) + + # AZ -> AY + data_interpolated["global_orient"], data_interpolated["transl"], _ = get_tgtcoord_rootparam( + data_interpolated["global_orient"], + data_interpolated["transl"], + tsf="az->ay", + ) + + data_interpolated["data_name"] = "amass" + return data_interpolated + + +group_name = "train_datasets/pure_motion_amass" +MainStore.store(name="v11", node=builds(AmassDataset, cam_augmentation="v11"), group=group_name) diff --git a/hmr4d/dataset/pure_motion/base_dataset.py b/hmr4d/dataset/pure_motion/base_dataset.py new file mode 100644 index 0000000..29def08 --- /dev/null +++ 
b/hmr4d/dataset/pure_motion/base_dataset.py @@ -0,0 +1,182 @@ +import torch +from torch.utils.data import Dataset +from pathlib import Path + +from .utils import * +from .cam_traj_utils import CameraAugmentorV11 +from hmr4d.utils.geo.hmr_cam import create_camera_sensor +from hmr4d.utils.geo.hmr_global import get_c_rootparam, get_R_c2gv +from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict +from hmr4d.utils.geo_transform import compute_cam_angvel, apply_T_on_points, project_p2d, cvt_p2d_from_i_to_c + +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines, convert_motion_as_line_mesh +from hmr4d.utils.smplx_utils import make_smplx + + +class BaseDataset(Dataset): + def __init__(self, cam_augmentation, limit_size=None): + super().__init__() + self.cam_augmentation = cam_augmentation + self.limit_size = limit_size + self.smplx = make_smplx("supermotion") + self.smplx_lite = make_smplx("supermotion_smpl24") + + self._load_dataset() + self._get_idx2meta() + + def _load_dataset(self): + NotImplementedError("_load_dataset is not implemented") + + def _get_idx2meta(self): + self.idx2meta = None + NotImplementedError("_get_idx2meta is not implemented") + + def __len__(self): + if self.limit_size is not None: + return min(self.limit_size, len(self.idx2meta)) + return len(self.idx2meta) + + def _load_data(self, idx): + NotImplementedError("_load_data is not implemented") + + def _process_data(self, data, idx): + """ + Args: + data: dict { + "body_pose": (F, 63), + "betas": (F, 10), + "global_orient": (F, 3), in the AY coordinates + "transl": (F, 3), in the AY coordinates + } + """ + data_name = data["data_name"] + length = data["body_pose"].shape[0] + # Augmentation: betas, SMPL (gravity-axis) + body_pose = data["body_pose"] + betas = augment_betas(data["betas"], std=0.1) + global_orient_w, transl_w = rotate_around_axis(data["global_orient"], data["transl"], axis="y") + del data + + # SMPL_params in world + smpl_params_w = { + "body_pose": body_pose, # (F, 63) + "betas": betas, # (F, 10) + "global_orient": global_orient_w, # (F, 3) + "transl": transl_w, # (F, 3) + } + + # Camera trajectory augmentation + if self.cam_augmentation == "v11": + # interleave repeat to original length (faster) + N = 10 + w_j3d = self.smplx_lite( + smpl_params_w["body_pose"][::N], + smpl_params_w["betas"][::N], + smpl_params_w["global_orient"][::N], + None, + ) + w_j3d = w_j3d.repeat_interleave(N, dim=0) + smpl_params_w["transl"][:, None] # (F, 24, 3) + + if False: + wis3d = make_wis3d(name="debug_amass") + add_motion_as_lines(w_j3d, wis3d, "w_j3d") + + width, height, K_fullimg = create_camera_sensor(1000, 1000, 43.3) # WHAM + focal_length = K_fullimg[0, 0] + wham_cam_augmentor = CameraAugmentorV11() + T_w2c = wham_cam_augmentor(w_j3d, length) # (F, 4, 4) + + else: + raise NotImplementedError + + if False: # render + for idx_render in range(10): + T_w2c = wham_cam_augmentor(smpl_params_w["transl"]) + + # targets + w_j3d = self.smplx(**smpl_params_w).joints[:, :22] + c_j3d = apply_T_on_points(w_j3d, T_w2c) + verts, faces, vertex_colors = convert_motion_as_line_mesh(c_j3d) + vertex_colors = vertex_colors[None] / 255.0 + bg = np.ones((height, width, 3), dtype=np.uint8) * 255 + + # render + renderer = Renderer(width, height, device="cuda", faces=faces, K=K_fullimg) + vname = f"{idx_render:02d}" + out_fn = Path(f"outputs/dump_render_wham_cam/{vname}.mp4") + out_fn.parent.mkdir(exist_ok=True, parents=True) + writer = imageio.get_writer(out_fn, fps=30, mode="I", format="FFMPEG", 
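+ # macro_block_size=1 keeps the exact frame size (imageio's FFMPEG writer otherwise
+ # resizes width/height up to a multiple of 16)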
macro_block_size=1) + for i in tqdm(range(len(verts)), desc=f"Rendering {vname}"): + # incam + # img_overlay_pred = renderer.render_mesh(verts[i].cuda(), bg, [0.8, 0.8, 0.8], VI=1) + img_overlay_pred = renderer.render_mesh(verts[i].cuda(), bg, vertex_colors, VI=1) + # if batch["meta_render"][0].get("bbx_xys", None) is not None: # draw bbox lines + # bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy() + # lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int) + # rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int) + # img_overlay_pred = cv2.rectangle(img_overlay_pred, lu_point, rd_point, (255, 178, 102), 2) + + # write + writer.append_data(img_overlay_pred) + writer.close() + pass + + # SMPL params in cam + offset = self.smplx.get_skeleton(smpl_params_w["betas"][0])[0] # (3) + global_orient_c, transl_c = get_c_rootparam( + smpl_params_w["global_orient"], + smpl_params_w["transl"], + T_w2c, + offset, + ) + smpl_params_c = { + "body_pose": smpl_params_w["body_pose"].clone(), # (F, 63) + "betas": smpl_params_w["betas"].clone(), # (F, 10) + "global_orient": global_orient_c, # (F, 3) + "transl": transl_c, # (F, 3) + } + + # World params + gravity_vec = torch.tensor([0, -1, 0], dtype=torch.float32) # (3), BEDLAM is ay + R_c2gv = get_R_c2gv(T_w2c[:, :3, :3], gravity_vec) # (F, 3, 3) + + # Image + K_fullimg = K_fullimg.repeat(length, 1, 1) # (F, 3, 3) + cam_angvel = compute_cam_angvel(T_w2c[:, :3, :3]) # (F, 6) + + # Returns: do not forget to make it batchable! (last lines) + # NOTE: bbx_xys and f_imgseq will be added later + max_len = length + return_data = { + "meta": {"data_name": data_name, "idx": idx, "T_w2c": T_w2c}, + "length": length, + "smpl_params_c": smpl_params_c, + "smpl_params_w": smpl_params_w, + "R_c2gv": R_c2gv, # (F, 3, 3) + "gravity_vec": gravity_vec, # (3) + "bbx_xys": torch.zeros((length, 3)), # (F, 3) # NOTE: a placeholder + "K_fullimg": K_fullimg, # (F, 3, 3) + "f_imgseq": torch.zeros((length, 1024)), # (F, D) # NOTE: a placeholder + "kp2d": torch.zeros(length, 17, 3), # (F, 17, 3) + "cam_angvel": cam_angvel, # (F, 6) + "mask": { + "valid": get_valid_mask(length, length), + "vitpose": False, + "bbx_xys": False, + "f_imgseq": False, + "spv_incam_only": False, + }, + } + + # Batchable + return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len) + return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len) + return_data["R_c2gv"] = repeat_to_max_len(return_data["R_c2gv"], max_len) + return_data["K_fullimg"] = repeat_to_max_len(return_data["K_fullimg"], max_len) + return_data["cam_angvel"] = repeat_to_max_len(return_data["cam_angvel"], max_len) + return return_data + + def __getitem__(self, idx): + data = self._load_data(idx) + data = self._process_data(data, idx) + return data diff --git a/hmr4d/dataset/pure_motion/cam_traj_utils.py b/hmr4d/dataset/pure_motion/cam_traj_utils.py new file mode 100644 index 0000000..71b0aff --- /dev/null +++ b/hmr4d/dataset/pure_motion/cam_traj_utils.py @@ -0,0 +1,427 @@ +import torch +import torch.nn.functional as F +import numpy as np +from numpy.random import rand, randn +from pytorch3d.transforms import ( + axis_angle_to_matrix, + matrix_to_axis_angle, + matrix_to_rotation_6d, + rotation_6d_to_matrix, +) +from einops import rearrange +from hmr4d.utils.geo.hmr_cam import create_camera_sensor +from hmr4d.utils.geo_transform import transform_mat, apply_T_on_points +from hmr4d.utils.geo.transforms import axis_rotate_to_matrix +import hmr4d.utils.matrix as matrix + 
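+# Rough picture of this module: create_camera() samples a static OpenCV-convention camera pose
+# around the subject, create_rotation_move() / create_translation_move() interpolate it towards a
+# random end pose, and CameraAugmentorV11 mixes static / moving / tracking trajectories with
+# impulse and per-step noise. Minimal usage sketch (the real call site is in
+# pure_motion/base_dataset.py; `augmentor` is only an illustrative name):
+#   augmentor = CameraAugmentorV11()
+#   T_w2c = augmentor(w_j3d, length=120)  # w_j3d: (120, J, 3) world joints -> (120, 4, 4) extrinsics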
+halfpi = np.pi / 2 +R_y_upsidedown = torch.tensor([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]).float() + + +def noisy_interpolation(x, length, step_noise_perc=0.2): + """Non-linear interpolation with noise, although with noise, the jittery is very small + Args: + x: (2, C) + length: scalar + step_noise_perc: [x0, x1 +-(step_noise_perc * step), x2], where step = x1-x0 + """ + assert x.shape[0] == 2 and len(x.shape) == 2 + dim = x.shape[-1] + output = np.zeros((length, dim)) + + # Use linsapce(0, 1) +- noise as reference + linspace = np.repeat(np.linspace(0, 1, length)[None], dim, axis=0) # (D, L) + noise = (linspace[0, 1] - linspace[0, 0]) * step_noise_perc + space_noise = np.random.uniform(-noise, noise, (dim, length - 2)) # (D, L-2) + linspace[:, 1:-1] = linspace[:, 1:-1] + space_noise + + # Do 1d interp + for i in range(dim): + output[:, i] = np.interp(linspace[i], np.array([0.0, 1.0]), x[:, i]) + return output + + +def noisy_impluse_interpolation(data1, data2, step_noise_perc=0.2): + """Non-linear interpolation of impluse with noise""" + + dim = data1.shape[-1] + L = data1.shape[0] + + linspace1 = np.stack([np.linspace(0, 1, L // 2) for _ in range(dim)]) + linspace2 = np.stack([np.linspace(0, 1, L // 2)[::-1] for _ in range(dim)]) + linspace = np.concatenate([linspace1, linspace2], axis=-1) + noise = (linspace[0, 1] - linspace[0, 0]) * step_noise_perc + space_noise = np.stack([np.random.uniform(-noise, noise, L - 2) for _ in range(dim)]) + + linspace[:, 1:-1] = linspace[:, 1:-1] + space_noise + linspace = linspace.T + output = data1 * (1 - linspace) + data2 * linspace + return output + + +def create_camera(w_root, cfg): + """Create static camera pose + Args: + w_root: (3,), y-up coordinates + Returns: + R_w2c: (3, 3) + t_w2c: (3) + """ + # Parse + pitch_std = cfg["pitch_std"] + pitch_mean = cfg["pitch_mean"] + roll_std = cfg["roll_std"] + tz_range1_prob = cfg["tz_range1_prob"] + tz_range1 = cfg["tz_range1"] + tz_range2 = cfg["tz_range2"] + f = cfg["f"] + w = cfg["w"] + + # algo + yaw = rand() * 2 * np.pi # Look at any direction in xz-plane + pitch = np.clip(randn() * pitch_std + pitch_mean, -halfpi, halfpi) + roll = np.clip(randn() * roll_std, -halfpi, halfpi) # Normal-dist + + # Note we use OpenCV's camera system by first applying R_y_upsidedown + yaw_rm = axis_rotate_to_matrix(yaw, axis="y") + pitch_rm = axis_rotate_to_matrix(pitch, axis="x") + roll_rm = axis_rotate_to_matrix(roll, axis="z") + R_w2c = (roll_rm @ pitch_rm @ yaw_rm @ R_y_upsidedown).squeeze(0) # (3, 3) + + # Place people in the scene + if rand() < tz_range1_prob: + tz = rand() * (tz_range1[1] - tz_range1[0]) + tz_range1[0] + max_dist_in_fov = (w / 2) / f * tz + tx = (rand() * 2 - 1) * 0.7 * max_dist_in_fov + ty = (rand() * 2 - 1) * 0.5 * max_dist_in_fov + + else: + tz = rand() * (tz_range2[1] - tz_range2[0]) + tz_range2[0] + max_dist_in_fov = (w / 2) / f * tz + max_dist_in_fov *= 0.9 # add a threshold + tx = torch.randn(1) * 1.6 + tx = torch.clamp(tx, -max_dist_in_fov, max_dist_in_fov) + ty = torch.randn(1) * 0.8 + ty = torch.clamp(ty, -max_dist_in_fov, max_dist_in_fov) + + dist = torch.tensor([tx, ty, tz], dtype=torch.float) + t_w2c = dist - torch.matmul(R_w2c, w_root) + + return R_w2c, t_w2c + + +def create_rotation_move(R, length, r_xyz_w_std=[np.pi / 8, np.pi / 4, np.pi / 8]): + """Create rotational move for the camera + Args: + R: (3, 3) + Return: + R_move: (L, 3, 3) + """ + # Create final camera pose + assert len(R.size()) == 2 + r_xyz = (2 * rand(3) - 1) * r_xyz_w_std + Rf = R @ 
axis_angle_to_matrix(torch.from_numpy(r_xyz).float()) + + # Inbetweening two poses + Rs = torch.stack((R, Rf)) # (2, 3, 3) + rs = matrix_to_rotation_6d(Rs).numpy() # (2, 6) + rs_move = noisy_interpolation(rs, length) # (L, 6) + R_move = rotation_6d_to_matrix(torch.from_numpy(rs_move).float()) + + return R_move + + +def create_translation_move(R_w2c, t_w2c, length, t_xyz_w_std=[1.0, 0.25, 1.0]): + """Create translational move for the camera + Args: + R_w2c: (3, 3), + t_w2c: (3,), + """ + # Create subject final displacement + subj_start_final = np.array([[0, 0, 0], randn(3) * t_xyz_w_std]) + subj_move = noisy_interpolation(subj_start_final, length) + subj_move = torch.from_numpy(subj_move).float() # (L, 3) + + # Equal to camera move + t_move = t_w2c + torch.einsum("ij,lj->li", R_w2c, subj_move) + + return t_move + + +class CameraAugmentorV11: + cfg_create_camera = { + "pitch_mean": np.pi / 36, + "pitch_std": np.pi / 8, + "roll_std": np.pi / 24, + "tz_range1_prob": 0.4, + "tz_range1": [1.0, 6.0], # uniform sample + "tz_range2": [4.0, 12.0], + "tx_scale": 0.7, + "ty_scale": 0.3, + } + + # r_xyz_w_std = [np.pi / 8, np.pi / 4, np.pi / 8] # in world coords + r_xyz_w_std = [np.pi / 6, np.pi / 3, np.pi / 6] # in world coords + t_xyz_w_std = [1.0, 0.25, 1.0] # in world coords + r_xyz_w_std_half = [x / 2 for x in r_xyz_w_std] + t_xyz_w_std_half = [x / 2 for x in t_xyz_w_std] + + t_factor = 1.0 + tz_bias_factor = 1.0 + + rotx_impluse_noise = np.pi / 36 + roty_impluse_noise = np.pi / 36 + rotz_impluse_noise = np.pi / 36 + rot_impluse_n = 1 + + tx_step_noise = 0.0025 + ty_step_noise = 0.0025 + tz_step_noise = 0.0025 + + tx_impluse_noise = 0.15 + ty_impluse_noise = 0.15 + tz_impluse_noise = 0.15 + t_impluse_n = 1 + + # === Postprocess === # + height_max = 4.0 + height_min = -2.0 # -1.5 -> -2.0 allow look upside + tz_post_min = 0.5 + + def __init__(self): + self.w = 1000 + self.f = create_camera_sensor(1000, 1000, 24)[2][0, 0] # use 24mm camera + self.half_fov_tol = (self.w / 2) / self.f + + def create_rotation_track(self, cam_mat, root, rx_factor=1.0, ry_factor=1.0, rz_factor=1.0): + """Create rotational move for the camera with rotating human""" + human_mat = matrix.get_TRS(matrix.identity_mat()[None, :3, :3], root) + cam2human_mat = matrix.get_mat_BtoA(human_mat, cam_mat) + R = matrix.get_rotation(cam2human_mat) + + # Create final camera pose + yaw = np.random.normal(scale=ry_factor) + pitch = np.random.normal(scale=rx_factor) + roll = np.random.normal(scale=rz_factor) + + yaw_rm = axis_angle_to_matrix(torch.tensor([0, yaw, 0]).float()) + pitch_rm = axis_angle_to_matrix(torch.tensor([pitch, 0, 0]).float()) + roll_rm = axis_angle_to_matrix(torch.tensor([0, 0, roll]).float()) + Rf = roll_rm @ pitch_rm @ yaw_rm @ R[0] + + # Inbetweening two poses + Rs = torch.stack((R[0], Rf)) + rs = matrix_to_rotation_6d(Rs).numpy() + rs_move = noisy_interpolation(rs, self.l) + R_move = rotation_6d_to_matrix(torch.from_numpy(rs_move).float()) + R_move = torch.inverse(R_move) + return R_move + + def create_translation_track(self, cam_mat, root, t_factor=1.0, tz_bias_factor=0.0): + """Create translational move for the camera with tracking human""" + delta_T0 = matrix.get_position(cam_mat)[0] - root[0] + T_new = matrix.get_position(cam_mat) + + tz_bias = delta_T0.norm(dim=-1) * tz_bias_factor * np.clip(1 + np.random.normal(scale=0.1), 0.67, 1.5) + + T_new[1:] = root[1:] + delta_T0 + cam_mat = matrix.get_TRS(matrix.get_rotation(cam_mat), T_new) + w2c = torch.inverse(cam_mat) + T_new = matrix.get_position(w2c) + + # Create 
final camera position + tx = np.random.normal(scale=t_factor) + ty = np.random.normal(scale=t_factor) + tz = np.random.normal(scale=t_factor) + tz_bias + Ts = np.array([[0, 0, 0], [tx, ty, tz]]) + + T_move = noisy_interpolation(Ts, self.l) + T_move = torch.from_numpy(T_move).float() + return T_move + T_new + + def add_stepnoise(self, R, T): + w2c = matrix.get_TRS(R, T) + cam_mat = torch.inverse(w2c) + R_new = matrix.get_rotation(cam_mat) + T_new = matrix.get_position(cam_mat) + + L = R_new.shape[0] + window = 10 + + def add_impulse_rot(R_new): + N = np.random.randint(1, self.rot_impluse_n + 1) + rx = np.random.normal(scale=self.rotx_impluse_noise, size=N) + ry = np.random.normal(scale=self.roty_impluse_noise, size=N) + rz = np.random.normal(scale=self.rotz_impluse_noise, size=N) + R_impluse_noise = axis_angle_to_matrix(torch.from_numpy(np.array([rx, ry, rz])).float().transpose(0, 1)) + R_noise = R_new.clone() + last_i = 0 + for i in range(N): + n_i = np.random.randint(last_i + window, L - (N - i) * window * 2) + + # make impluse smooth + window_R = R_noise[n_i - window : n_i + window].clone() + window_r = matrix_to_rotation_6d(window_R).numpy() + impluse_R = R_impluse_noise[i] @ window_R[window] + window_impluse_R = window_R.clone() + window_impluse_R[:] = impluse_R[None] + window_impluse_r = matrix_to_rotation_6d(window_impluse_R).numpy() + + window_new_r = noisy_impluse_interpolation(window_r, window_impluse_r) + window_new_R = rotation_6d_to_matrix(torch.from_numpy(window_new_r).float()) + R_noise[n_i - window : n_i + window] = window_new_R + last_i = n_i + R_new = R_noise + return R_new + + def add_impulse_t(T_new): + N = np.random.randint(1, self.t_impluse_n + 1) + tx = np.random.normal(scale=self.tx_impluse_noise, size=N) + ty = np.random.normal(scale=self.ty_impluse_noise, size=N) + tz = np.random.normal(scale=self.tz_impluse_noise, size=N) + T_impluse_noise = torch.from_numpy(np.array([tx, ty, tz])).float().transpose(0, 1) + T_noise = T_new.clone() + last_i = 0 + for i in range(N): + n_i = np.random.randint(last_i + window, L - N * window * 2) + + # make impluse smooth + window_T = T_noise[n_i - window : n_i + window].clone() + window_impluse_T = window_T.clone() + window_impluse_T += T_impluse_noise[i : i + 1] + window_impluse_T = window_impluse_T.numpy() + window_T = window_T.numpy() + + window_new_T = noisy_impluse_interpolation(window_T, window_impluse_T) + window_new_T = torch.from_numpy(window_new_T).float() + T_noise[n_i - window : n_i + window] = window_new_T + last_i = n_i + T_new = T_noise + return T_new + + impulse_type_prob = { + "t": 0.2, + "r": 0.2, + "both": 0.1, + "pass": 0.5, + } + impulse_type = np.random.choice(list(impulse_type_prob.keys()), p=list(impulse_type_prob.values())) + if impulse_type == "t": + # impluse translation only + T_new = add_impulse_t(T_new) + elif impulse_type == "r": + # impluse rotation only + R_new = add_impulse_rot(R_new) + elif impulse_type == "both": + # impluse rotation and translation + R_new = add_impulse_rot(R_new) + T_new = add_impulse_t(T_new) + else: + assert impulse_type == "pass" + + cam_mat_new = matrix.get_TRS(R_new, T_new) + w2c_new = torch.inverse(cam_mat_new) + R_new = matrix.get_rotation(w2c_new) + T_new = matrix.get_position(w2c_new) + tx = np.random.normal(scale=self.tx_step_noise, size=L) + ty = np.random.normal(scale=self.ty_step_noise, size=L) + tz = np.random.normal(scale=self.tz_step_noise, size=L) + T_new = T_new + torch.from_numpy(np.array([tx, ty, tz])).float().transpose(0, 1) + + return R_new, T_new + + def 
__call__(self, w_j3d, length=120): + """ + Args: + w_j3d: (L, J, 3) + length: scalar + """ + # Check + self.l = length + assert w_j3d.size(0) == self.l, "currently, only support fixed length" + + # Setup + w_j3d = w_j3d.clone() + w_root = w_j3d[:, 0] # (L, 3) + + # Simulate a static camera pose + cfg_camera0 = {**self.cfg_create_camera, "w": self.w, "f": self.f} + R0_w2c, t0_w2c = create_camera(w_root[0], cfg_camera0) # (3, 3) and (3,) + + # Move camera + camera_type_prob = { + "random": 0.25, + "track": 0.15, + "trackrotate": 0.10, + "trackpush": 0.05, + "trackpull": 0.05, + "static": 0.4, + } + camera_type = np.random.choice(list(camera_type_prob.keys()), p=list(camera_type_prob.values())) + if camera_type == "random": # random move + add noise on cam + R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std) + t_w2c = create_translation_move(R0_w2c, t0_w2c, length, self.t_xyz_w_std) + R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c) + + elif camera_type == "track": # track human + R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std_half) + cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1) # (F, 4, 4) + t_w2c = self.create_translation_track(cam_mat, w_root, 0.5) + R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c) + + elif camera_type == "trackrotate": # track human and rotate + cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1) # (F, 4, 4) + t_w2c = self.create_translation_track(cam_mat, w_root, 0.5) + cam_mat = matrix.get_TRS(matrix.get_rotation(cam_mat), t_w2c) + R_w2c = self.create_rotation_track(cam_mat, w_root, np.pi / 16, np.pi, np.pi / 16) + R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c) + + elif camera_type == "trackpush": # track human and push close to human + R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std_half) + # [1/tz_bias_factor, 1] * dist + cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1) # (F, 4, 4) + t_w2c = self.create_translation_track(cam_mat, w_root, 0.5, (1.0 / (1 + self.tz_bias_factor) - 1)) + R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c) + + elif camera_type == "trackpull": # track human and pull far from human + R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std_half) + # [1, (tz_bias_factor + 1)] * dist + cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1) # (F, 4, 4) + t_w2c = self.create_translation_track(cam_mat, w_root, 0.5, self.tz_bias_factor) + R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c) + + else: + assert camera_type == "static" + R_w2c = R0_w2c.repeat(length, 1, 1) # (F, 3, 3) + t_w2c = t0_w2c.repeat(length, 1) # (F, 3) + + # Recompute t_w2c for better camera height + # cam_w = torch.einsum("lji,lj->li", R_w2c, -t_w2c) # (L, 3), camera center in world: cam_w = - R_w2c^t_w2c @ t + # height = cam_w[..., 1] - w_root[:, 1] + # height = torch.clamp(height, self.height_min, self.height_max) + # new_pos = cam_w.clone() + # new_pos[:, 1] = w_root[:, 1] + height + # t_w2c = torch.einsum("lij,lj->li", R_w2c, -new_pos) # (L, 3), new t = -R_w2c @ cam_w + + # Recompute t_w2c for better depth and FoV + c_j3d = torch.einsum("lij,lkj->lki", R_w2c, w_j3d) + t_w2c[:, None] # (L, J, 3) + delta = torch.zeros_like(t_w2c) # (L, 3) this will be later added to t_w2c + # - If the person is too close to the camera, push away the person in the z direction + c_j3d_min = c_j3d[..., 2].min() # scalar + if c_j3d_min < self.tz_post_min: + push_away = self.tz_post_min - c_j3d_min + delta[..., 2] += push_away + c_j3d[..., 2] += push_away + 
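+ # Pinhole geometry: with a centered principal point, a camera-space point (x, y, z) projects
+ # inside the image iff |x/z| <= (w/2)/f and |y/z| <= (h/2)/f; the sensor built in __init__ is
+ # square (w == h), so both bounds reduce to self.half_fov_tol = (w/2)/f, checked below on the root joint.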
# - If the person is not in the FoV, push away the person in the z direction + c_root = c_j3d[:, 0] # (L, 3) + half_fov = torch.div(c_root[:, :2], c_root[:, 2:]).abs() # (L, 2), [x/z, y/z] + if half_fov.max() > self.half_fov_tol: + max_idx1, max_idx2 = torch.where(torch.max(half_fov) == half_fov) + max_idx1, max_idx2 = max_idx1[0], max_idx2[0] + z_trg = c_root[max_idx1, max_idx2].abs() / self.half_fov_tol # extreme fitted z in the fov + push_away = z_trg - c_root[max_idx1, 2] + delta[..., 2] += push_away + t_w2c += delta + + T_w2c = transform_mat(R_w2c, t_w2c) # (F, 4, 4) + return T_w2c diff --git a/hmr4d/dataset/pure_motion/utils.py b/hmr4d/dataset/pure_motion/utils.py new file mode 100644 index 0000000..9466296 --- /dev/null +++ b/hmr4d/dataset/pure_motion/utils.py @@ -0,0 +1,66 @@ +import torch +import torch.nn.functional as F +from pytorch3d.transforms import ( + axis_angle_to_matrix, + matrix_to_axis_angle, + matrix_to_rotation_6d, + rotation_6d_to_matrix, +) +from einops import rearrange + + +def aa_to_r6d(x): + return matrix_to_rotation_6d(axis_angle_to_matrix(x)) + + +def r6d_to_aa(x): + return matrix_to_axis_angle(rotation_6d_to_matrix(x)) + + +def interpolate_smpl_params(smpl_params, tgt_len): + """ + smpl_params['body_pose'] (L, 63) + tgt_len: L->L' + """ + betas = smpl_params["betas"] + body_pose = smpl_params["body_pose"] + global_orient = smpl_params["global_orient"] # (L, 3) + transl = smpl_params["transl"] # (L, 3) + + # Interpolate + body_pose = rearrange(aa_to_r6d(body_pose.reshape(-1, 21, 3)), "l j c -> c j l") + body_pose = F.interpolate(body_pose, tgt_len, mode="linear", align_corners=True) + body_pose = r6d_to_aa(rearrange(body_pose, "c j l -> l j c")).reshape(-1, 63) + + # although this should be the same as above, we do it for consistency + betas = rearrange(betas, "l c -> c 1 l") + betas = F.interpolate(betas, tgt_len, mode="linear", align_corners=True) + betas = rearrange(betas, "c 1 l -> l c") + + global_orient = rearrange(aa_to_r6d(global_orient.reshape(-1, 1, 3)), "l j c -> c j l") + global_orient = F.interpolate(global_orient, tgt_len, mode="linear", align_corners=True) + global_orient = r6d_to_aa(rearrange(global_orient, "c j l -> l j c")).reshape(-1, 3) + + transl = rearrange(transl, "l c -> c 1 l") + transl = F.interpolate(transl, tgt_len, mode="linear", align_corners=True) + transl = rearrange(transl, "c 1 l -> l c") + + return {"body_pose": body_pose, "betas": betas, "global_orient": global_orient, "transl": transl} + + +def rotate_around_axis(global_orient, transl, axis="y"): + """Global coordinate augmentation. 
Random rotation around y-axis""" + angle = torch.rand(1) * 2 * torch.pi + if axis == "y": + aa = torch.tensor([0.0, angle, 0.0]).float().unsqueeze(0) + rmat = axis_angle_to_matrix(aa) + + global_orient = matrix_to_axis_angle(rmat @ axis_angle_to_matrix(global_orient)) + transl = (rmat.squeeze(0) @ transl.T).T + return global_orient, transl + + +def augment_betas(betas, std=0.1): + noise = torch.normal(mean=torch.zeros(10), std=torch.ones(10) * std) + betas_aug = betas + noise[None] + return betas_aug diff --git a/hmr4d/dataset/rich/resource/cam2params.pt b/hmr4d/dataset/rich/resource/cam2params.pt new file mode 100644 index 0000000..f919328 Binary files /dev/null and b/hmr4d/dataset/rich/resource/cam2params.pt differ diff --git a/hmr4d/dataset/rich/resource/seqname2imgrange.json b/hmr4d/dataset/rich/resource/seqname2imgrange.json new file mode 100644 index 0000000..2bff25c --- /dev/null +++ b/hmr4d/dataset/rich/resource/seqname2imgrange.json @@ -0,0 +1 @@ +{"ParkingLot1_002_burpee3": [1, 351], "ParkingLot1_002_overfence1": [1, 268], "ParkingLot1_002_overfence2": [1, 270], "ParkingLot1_002_stretching1": [1, 327], "ParkingLot1_002_pushup1": [1, 220], "ParkingLot1_004_pushup2": [1, 347], "ParkingLot1_004_burpeejump1": [1, 296], "ParkingLot1_004_eating1": [1, 522], "ParkingLot1_004_takingphotos1": [1, 593], "ParkingLot1_004_phonetalk1": [1, 724], "ParkingLot1_005_burpeejump2": [1, 270], "ParkingLot1_005_overfence1": [1, 301], "ParkingLot1_005_pushup2": [1, 262], "ParkingLot1_005_pushup3": [1, 243], "ParkingLot1_004_005_greetingchattingeating1": [275, 849], "ParkingLot1_007_overfence2": [1, 263], "ParkingLot1_007_eating1": [1, 426], "ParkingLot1_007_eating2": [1, 498], "ParkingLot2_008_phonetalk1": [171, 1215], "ParkingLot2_008_burpeejump1": [78, 505], "ParkingLot2_008_overfence1": [161, 459], "ParkingLot2_008_pushup1": [165, 459], "ParkingLot2_008_pushup2": [107, 719], "ParkingLot2_008_overfence2": [138, 632], "ParkingLot2_008_overfence3": [100, 661], "ParkingLot2_008_eating1": [180, 1332], "ParkingLot2_014_pushup2": [80, 420], "ParkingLot2_014_burpeejump1": [50, 348], "ParkingLot2_014_burpeejump2": [50, 248], "ParkingLot2_014_phonetalk2": [121, 1141], "ParkingLot2_014_takingphotos2": [91, 906], "ParkingLot2_014_overfence3": [40, 502], "ParkingLot2_015_overfence1": [170, 692], "ParkingLot2_015_burpeejump2": [344, 678], "ParkingLot2_015_pushup1": [190, 817], "ParkingLot2_015_eating2": [31, 835], "ParkingLot2_016_burpeejump2": [100, 793], "ParkingLot2_016_overfence2": [100, 720], "ParkingLot2_016_pushup1": [61, 680], "ParkingLot2_016_pushup2": [100, 570], "ParkingLot2_016_stretching1": [100, 691], "Pavallion_000_yoga2": [1, 1643], "Pavallion_000_plankjack": [1, 900], "Pavallion_000_phonesiteat": [1, 1157], "Pavallion_000_sidebalancerun": [1, 1091], "Pavallion_002_plankjack": [110, 699], "Pavallion_002_phonesiteat": [1, 1030], "Pavallion_003_plankjack": [1, 764], "Pavallion_003_phonesiteat": [75, 838], "Pavallion_003_sidebalancerun": [1, 942], "Pavallion_006_phonesiteat": [130, 841], "Pavallion_006_sidebalancerun": [1, 798], "Pavallion_006_plankjack": [1, 615], "Pavallion_013_phonesiteat": [1, 1254], "Pavallion_013_plankjack": [1, 641], "Pavallion_013_yoga2": [1, 884], "Pavallion_003_018_tossball": [230, 949], "LectureHall_018_wipingchairs1": [1, 1166], "LectureHall_018_wipingspray1": [1, 904], "LectureHall_020_wipingtable1": [1, 897], "BBQ_001_juggle": [0, 297], "BBQ_001_guitar": [0, 381], "ParkingLot1_002_stretching2": [240, 240], "ParkingLot1_002_burpee1": [1, 286], "ParkingLot1_002_burpee2": 
[1, 203], "ParkingLot1_004_pushup1": [1, 354], "ParkingLot1_004_eating2": [1, 516], "ParkingLot1_004_phonetalk2": [1, 960], "ParkingLot1_004_takingphotos2": [1, 571], "ParkingLot1_004_stretching2": [1, 399], "ParkingLot1_005_overfence2": [1, 298], "ParkingLot1_005_pushup1": [1, 476], "ParkingLot1_005_burpeejump1": [1, 252], "ParkingLot1_007_burpee2": [1, 349], "ParkingLot2_008_eating2": [160, 1100], "ParkingLot2_008_burpeejump2": [129, 492], "ParkingLot2_014_overfence1": [95, 547], "ParkingLot2_014_eating2": [101, 986], "ParkingLot2_016_phonetalk5": [170, 1259], "Pavallion_002_sidebalancerun": [1, 655], "Pavallion_013_sidebalancerun": [1, 810], "Pavallion_018_sidebalancerun": [1, 873], "LectureHall_018_wipingtable1": [1, 1280], "LectureHall_020_wipingchairs1": [1, 1163], "LectureHall_003_wipingchairs1": [1, 724], "Pavallion_000_yoga1": [1, 1757], "Pavallion_002_yoga1": [1, 613], "Pavallion_003_yoga1": [1, 792], "Pavallion_006_yoga1": [1, 930], "Pavallion_018_yoga1": [1, 880], "ParkingLot2_017_burpeejump2": [118, 612], "ParkingLot2_017_burpeejump1": [40, 817], "ParkingLot2_017_overfence1": [110, 661], "ParkingLot2_017_overfence2": [90, 944], "ParkingLot2_017_eating1": [97, 895], "ParkingLot2_017_pushup1": [191, 719], "ParkingLot2_017_pushup2": [74, 811], "ParkingLot2_009_burpeejump1": [200, 1085], "ParkingLot2_009_burpeejump2": [150, 399], "ParkingLot2_009_overfence1": [140, 601], "ParkingLot2_009_overfence2": [150, 559], "LectureHall_009_sidebalancerun1": [1, 673], "LectureHall_010_plankjack1": [1, 532], "LectureHall_010_sidebalancerun1": [1, 919], "LectureHall_021_plankjack1": [1, 507], "LectureHall_021_sidebalancerun1": [1, 855], "LectureHall_019_wipingchairs1": [1, 978], "LectureHall_009_021_reparingprojector1": [1, 499], "ParkingLot2_009_spray1": [145, 1242], "ParkingLot2_009_impro1": [100, 990], "ParkingLot2_009_impro2": [100, 1140], "ParkingLot2_009_impro5": [100, 649], "Gym_010_pushup1": [1, 475], "Gym_010_pushup2": [1, 407], "Gym_011_pushup1": [1, 346], "Gym_011_pushup2": [1, 540], "Gym_011_burpee2": [1, 479], "Gym_012_pushup2": [1, 291], "Gym_010_mountainclimber1": [0, 0], "Gym_010_mountainclimber2": [1, 471], "Gym_013_dips1": [1, 503], "Gym_013_dips2": [1, 333], "Gym_013_dips3": [1, 502], "Gym_013_lunge1": [1, 690], "Gym_013_lunge2": [1, 834], "Gym_013_pushup1": [1, 861], "Gym_013_pushup2": [1, 477], "Gym_013_burpee4": [1, 320], "Gym_010_lunge1": [1, 337], "Gym_010_lunge2": [1, 312], "Gym_010_dips1": [1, 572], "Gym_010_dips2": [1, 603], "Gym_010_cooking1": [1, 779], "Gym_011_cooking1": [1, 1141], "Gym_011_cooking2": [1, 1145], "Gym_011_dips1": [1, 494], "Gym_011_dips4": [1, 495], "Gym_011_dips3": [1, 320], "Gym_011_dips2": [1, 382], "Gym_012_lunge1": [1, 225], "Gym_012_lunge2": [1, 318], "Gym_012_cooking2": [1, 993]} \ No newline at end of file diff --git a/hmr4d/dataset/rich/resource/test.txt b/hmr4d/dataset/rich/resource/test.txt new file mode 100644 index 0000000..69e157e --- /dev/null +++ b/hmr4d/dataset/rich/resource/test.txt @@ -0,0 +1,54 @@ +sequence_name capture_name scan_name id moving_cam gender scene action/scene-interaction subjects view_id +ParkingLot2_017_burpeejump2 ParkingLot2 scan_camcoord 017 V female V V X 0,2,3 +ParkingLot2_017_burpeejump1 ParkingLot2 scan_camcoord 017 V female V V X 0,1,5 +ParkingLot2_017_overfence1 ParkingLot2 scan_camcoord 017 V female V V X 0,3,4 +ParkingLot2_017_overfence2 ParkingLot2 scan_camcoord 017 V female V V X 0,1,4 +ParkingLot2_017_eating1 ParkingLot2 scan_camcoord 017 V female V V X 0,2,4 +ParkingLot2_017_pushup1 ParkingLot2 
scan_camcoord 017 X female V V X 0,1,4,5 +ParkingLot2_017_pushup2 ParkingLot2 scan_camcoord 017 V female V V X 0,4,5 +ParkingLot2_009_burpeejump1 ParkingLot2 scan_camcoord 009 X female V V X 0,1,2,3 +ParkingLot2_009_burpeejump2 ParkingLot2 scan_camcoord 009 X female V V X 0,2,3,4 +ParkingLot2_009_overfence1 ParkingLot2 scan_camcoord 009 X female V V X 0,3,4,5 +ParkingLot2_009_overfence2 ParkingLot2 scan_camcoord 009 X female V V X 0,1,4,5 +LectureHall_009_sidebalancerun1 LectureHall scan_yoga_scene_camcoord 009 X female V V X 0,1,4,5 +LectureHall_010_plankjack1 LectureHall scan_yoga_scene_camcoord 010 X female V V X 0,2,4,6 +LectureHall_010_sidebalancerun1 LectureHall scan_yoga_scene_camcoord 010 X female V V X 0,1,2,4 +LectureHall_021_plankjack1 LectureHall scan_yoga_scene_camcoord 021 X female V V X 0,3,5,6 +LectureHall_021_sidebalancerun1 LectureHall scan_yoga_scene_camcoord 021 X female V V X 0,4,5,6 +LectureHall_019_wipingchairs1 LectureHall scan_chair_scene_camcoord 019 X female V V X 0,1,2,3 +LectureHall_009_021_reparingprojector1 LectureHall scan_yoga_scene_camcoord 009 X female V X X 0,3,4,5 +LectureHall_009_021_reparingprojector1 LectureHall scan_yoga_scene_camcoord 021 X female V X X 0,3,4,5 +ParkingLot2_009_spray1 ParkingLot2 scan_camcoord 009 X female V X X 0,1,2,3 +ParkingLot2_009_impro1 ParkingLot2 scan_camcoord 009 X female V X X 0,2,3,4 +ParkingLot2_009_impro2 ParkingLot2 scan_camcoord 009 X female V X X 0,3,4,5 +ParkingLot2_009_impro5 ParkingLot2 scan_camcoord 009 X female V X X 0,2,4,5 +Gym_010_pushup1 Gym scan_camcoord 010 X female X V X 3,4,5,6 +Gym_010_pushup2 Gym scan_camcoord 010 X female X V X 2,3,4,5 +Gym_011_pushup1 Gym scan_camcoord 011 X male X V X 2,3,4,5 +Gym_011_pushup2 Gym scan_camcoord 011 X male X V X 2,3,4,5 +Gym_011_burpee2 Gym scan_camcoord 011 X male X V X 2,3,4,5 +Gym_012_pushup2 Gym scan_camcoord 012 X female X V X 3,4,5,6 +Gym_010_mountainclimber1 Gym scan_camcoord 010 X female X V X 3,4,5,6 +Gym_010_mountainclimber2 Gym scan_camcoord 010 X female X V X 3,4,5,6 +Gym_013_dips1 Gym scan_camcoord 013 X female X X V 0,3,4,5 +Gym_013_dips2 Gym scan_camcoord 013 X female X X V 1,2,4,5 +Gym_013_dips3 Gym scan_camcoord 013 X female X X V 1,2,4,5 +Gym_013_lunge1 Gym scan_camcoord 013 X female X X V 1,4,5,6 +Gym_013_lunge2 Gym scan_camcoord 013 X female X X V 0,4,5,6 +Gym_013_pushup1 Gym scan_camcoord 013 X female X V V 0,3,4,5 +Gym_013_pushup2 Gym scan_camcoord 013 X female X V V 1,2,4,5 +Gym_013_burpee4 Gym scan_camcoord 013 X female X V V 0,4,5,6 +Gym_010_lunge1 Gym scan_camcoord 010 X female X X X 1,4,5,6 +Gym_010_lunge2 Gym scan_camcoord 010 X female X X X 0,2,4,5 +Gym_010_dips1 Gym scan_camcoord 010 X female X X X 0,4,5,6 +Gym_010_dips2 Gym scan_camcoord 010 X female X X X 1,2,4,5 +Gym_010_cooking1 Gym scan_table_camcoord 010 X female X X X 1,3,4,5 +Gym_011_cooking1 Gym scan_table_camcoord 011 V male X X X 4,5,6 +Gym_011_cooking2 Gym scan_table_camcoord 011 V male X X X 2,4,5 +Gym_011_dips1 Gym scan_camcoord 011 X male X X X 1,3,4,5 +Gym_011_dips4 Gym scan_camcoord 011 X male X X X 0,2,4,5 +Gym_011_dips3 Gym scan_camcoord 011 X male X X X 0,3,4,5 +Gym_011_dips2 Gym scan_camcoord 011 X male X X X 1,3,4,5 +Gym_012_lunge1 Gym scan_camcoord 012 X female X X X 0,3,4,5 +Gym_012_lunge2 Gym scan_camcoord 012 X female X X X 0,4,5,6 +Gym_012_cooking2 Gym scan_table_camcoord 012 V female X X X 3,4,5 \ No newline at end of file diff --git a/hmr4d/dataset/rich/resource/train.txt b/hmr4d/dataset/rich/resource/train.txt new file mode 100644 index 0000000..875c79e --- 
/dev/null +++ b/hmr4d/dataset/rich/resource/train.txt @@ -0,0 +1,65 @@ +sequence_name capture_name scan_name id moving_cam gender view_id +ParkingLot1_002_burpee3 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7 +ParkingLot1_002_overfence1 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7 +ParkingLot1_002_overfence2 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7 +ParkingLot1_002_stretching1 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7 +ParkingLot1_002_pushup1 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7 +ParkingLot1_004_pushup2 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7 +ParkingLot1_004_burpeejump1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7 +ParkingLot1_004_eating1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7 +ParkingLot1_004_takingphotos1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7 +ParkingLot1_004_phonetalk1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7 +ParkingLot1_005_burpeejump2 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7 +ParkingLot1_005_overfence1 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7 +ParkingLot1_005_pushup2 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7 +ParkingLot1_005_pushup3 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7 +ParkingLot1_004_005_greetingchattingeating1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7 +ParkingLot1_004_005_greetingchattingeating1 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7 +ParkingLot1_007_overfence2 ParkingLot1 scan_camcoord 007 X male 0,1,2,3,4,5,6,7 +ParkingLot1_007_eating1 ParkingLot1 scan_camcoord 007 X male 0,1,2,3,4,5,6,7 +ParkingLot1_007_eating2 ParkingLot1 scan_camcoord 007 X male 0,1,2,3,4,5,6,7 +ParkingLot2_008_phonetalk1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_008_burpeejump1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_008_overfence1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_008_pushup1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_008_pushup2 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_008_overfence2 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_008_overfence3 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_008_eating1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5 +ParkingLot2_014_pushup2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5 +ParkingLot2_014_burpeejump1 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5 +ParkingLot2_014_burpeejump2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5 +ParkingLot2_014_phonetalk2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5 +ParkingLot2_014_takingphotos2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5 +ParkingLot2_014_overfence3 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5 +ParkingLot2_015_overfence1 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5 +ParkingLot2_015_burpeejump2 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5 +ParkingLot2_015_pushup1 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5 +ParkingLot2_015_eating2 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5 +ParkingLot2_016_burpeejump2 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5 +ParkingLot2_016_overfence2 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5 +ParkingLot2_016_pushup1 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5 +ParkingLot2_016_pushup2 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5 +ParkingLot2_016_stretching1 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5 +Pavallion_000_yoga2 Pavallion scan_camcoord 000 X male 0,1,2,3,4,5,6 +Pavallion_000_plankjack 
Pavallion scan_camcoord 000 X male 0,1,2,3,4,5,6 +Pavallion_000_phonesiteat Pavallion scan_camcoord 000 X male 0,1,3,4,6 +Pavallion_000_sidebalancerun Pavallion scan_camcoord 000 X male 0,1,2,3,4,5,6 +Pavallion_002_plankjack Pavallion scan_camcoord 002 V male 0,1,2,3,4,5,6 +Pavallion_002_phonesiteat Pavallion scan_camcoord 002 V male 0,1,3,4,6 +Pavallion_003_plankjack Pavallion scan_camcoord 003 V male 0,1,2,3,4,5,6 +Pavallion_003_phonesiteat Pavallion scan_camcoord 003 V male 0,1,3,4,6 +Pavallion_003_sidebalancerun Pavallion scan_camcoord 003 V male 0,1,2,3,4,5,6 +Pavallion_006_phonesiteat Pavallion scan_camcoord 006 V male 0,1,3,4,6 +Pavallion_006_sidebalancerun Pavallion scan_camcoord 006 V male 0,1,2,3,4,5,6 +Pavallion_006_plankjack Pavallion scan_camcoord 006 V male 0,1,2,3,4,5,6 +Pavallion_013_phonesiteat Pavallion scan_camcoord 013 X female 0,1,3,4,6 +Pavallion_013_plankjack Pavallion scan_camcoord 013 X female 0,1,2,3,4,5,6 +Pavallion_013_yoga2 Pavallion scan_camcoord 013 V female 0,1,2,3,4,5,6 +Pavallion_003_018_tossball Pavallion scan_camcoord 003 X male 0,1,2,3,4,5,6 +Pavallion_003_018_tossball Pavallion scan_camcoord 018 X female 0,1,2,3,4,5,6 +LectureHall_018_wipingchairs1 LectureHall scan_chair_scene_camcoord 018 X female 0,1,2,3,4,5,6 +LectureHall_018_wipingspray1 LectureHall scan_chair_scene_camcoord 018 X female 2,3,4 +LectureHall_020_wipingtable1 LectureHall scan_chair_scene_camcoord 020 X male 0,2,4,5,6 +BBQ_001_juggle BBQ scan_camcoord 001 X male 0,1,2,3,4,5,6,7 +BBQ_001_guitar BBQ scan_camcoord 001 X male 0,1,2,3,4,5,6,7 \ No newline at end of file diff --git a/hmr4d/dataset/rich/resource/val.txt b/hmr4d/dataset/rich/resource/val.txt new file mode 100644 index 0000000..714d8ff --- /dev/null +++ b/hmr4d/dataset/rich/resource/val.txt @@ -0,0 +1,29 @@ +sequence_name capture_name scan_name id moving_cam gender scene action/scene-interaction subjects view_id +ParkingLot1_002_stretching2 ParkingLot1 scan_camcoord 002 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_002_burpee1 ParkingLot1 scan_camcoord 002 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_002_burpee2 ParkingLot1 scan_camcoord 002 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_004_pushup1 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_004_eating2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_004_phonetalk2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_004_takingphotos2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_004_stretching2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_005_overfence2 ParkingLot1 scan_camcoord 005 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_005_pushup1 ParkingLot1 scan_camcoord 005 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_005_burpeejump1 ParkingLot1 scan_camcoord 005 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot1_007_burpee2 ParkingLot1 scan_camcoord 007 X male V V V 0,1,2,3,4,5,6,7 +ParkingLot2_008_eating2 ParkingLot2 scan_camcoord 008 V male V V V 0,1,2,3,4,5 +ParkingLot2_008_burpeejump2 ParkingLot2 scan_camcoord 008 V male V V V 0,1,2,3,4,5 +ParkingLot2_014_overfence1 ParkingLot2 scan_camcoord 014 X male V V V 0,1,2,3,4,5 +ParkingLot2_014_eating2 ParkingLot2 scan_camcoord 014 X male V V V 0,1,2,3,4,5 +ParkingLot2_016_phonetalk5 ParkingLot2 scan_camcoord 016 V female V V V 0,1,2,3,4,5 +Pavallion_002_sidebalancerun Pavallion scan_camcoord 002 V male V V V 0,1,2,3,4,5,6 +Pavallion_013_sidebalancerun Pavallion scan_camcoord 013 X female V V V 0,1,2,3,4,5,6 +Pavallion_018_sidebalancerun 
Pavallion scan_camcoord 018 V female V V V 0,1,2,3,4,5,6 +LectureHall_018_wipingtable1 LectureHall scan_chair_scene_camcoord 018 X female V V V 0,2,4,5,6 +LectureHall_020_wipingchairs1 LectureHall scan_chair_scene_camcoord 020 X male V V V 0,1,2,3,4,5,6 +LectureHall_003_wipingchairs1 LectureHall scan_chair_scene_camcoord 003 X male V V V 0,1,2,3,4,5,6 +Pavallion_000_yoga1 Pavallion scan_camcoord 000 X male V X V 0,1,2,3,4,5,6 +Pavallion_002_yoga1 Pavallion scan_camcoord 002 V male V X V 0,1,2,3,4,5,6 +Pavallion_003_yoga1 Pavallion scan_camcoord 003 V male V X V 0,1,2,3,4,5,6 +Pavallion_006_yoga1 Pavallion scan_camcoord 006 V male V X V 0,1,2,3,4,5,6 +Pavallion_018_yoga1 Pavallion scan_camcoord 018 V female V X V 0,1,2,3,4,5,6 \ No newline at end of file diff --git a/hmr4d/dataset/rich/resource/w2az_sahmr.json b/hmr4d/dataset/rich/resource/w2az_sahmr.json new file mode 100755 index 0000000..7dec2ac --- /dev/null +++ b/hmr4d/dataset/rich/resource/w2az_sahmr.json @@ -0,0 +1 @@ +{"BBQ_scan_camcoord": [[0.9989829107564298, 0.03367618890797693, -0.029984301180211045, 0.0008183751635392625], [0.03414262169451401, -0.1305975871406019, 0.9908473906797644, -0.005059823133706893], [0.02945208652127451, -0.9908633531086326, -0.13161455111748036, 1.4054905296083466], [0.0, 0.0, 0.0, 1.0]], "Gym_scan_camcoord": [[0.9932599733260449, -0.07628732032461205, 0.0872632233306122, -0.047601130084306706], [-0.10233962102690007, -0.22374853741942266, 0.9692590953768503, -0.04091804681182174], [-0.05441716049582774, -0.9716567484252654, -0.23004768176013274, 1.537911791136788], [0.0, 0.0, 0.0, 1.0]], "Gym_scan_table_camcoord": [[0.9974451989415423, -0.06250743213795668, 0.03458172980064169, 0.02231858470834599], [-0.04804912583358893, -0.22882402250236075, 0.972281259838159, 0.039081886755815726], [-0.05286167435026744, -0.9714588965331274, -0.2312428501197992, 1.5421821446346522], [0.0, 0.0, 0.0, 1.0]], "LectureHall_scan_chair_scene_camcoord": [[0.9992930513998263, 0.030087515976743376, -0.0225419343977731, 0.001998908749589632], [0.030705594681969043, -0.30721111058653017, 0.9511458878570781, -0.025811963513866963], [0.021692484396004613, -0.9511656401040444, -0.307917783192506, 2.060346184503773], [0.0, 0.0, 0.0, 1.0]], "LectureHall_scan_yoga_scene_camcoord": [[0.9993358324246812, 0.03030060260429296, -0.020242715082476024, -0.003510046042036605], [0.028600729415016745, -0.3079667078507395, 0.9509671419836329, -0.01748548118379142], [0.022580795137075255, -0.9509144968594153, -0.3086287856852993, 2.0424701474796567], [0.0, 0.0, 0.0, 1.0]], "ParkingLot1_scan_camcoord": [[0.9989627324729327, -0.03724260727951709, 0.02620013994738054, 0.0070941466745699025], [-0.03091587075252664, -0.13228243926883107, 0.9907298144280939, -0.0274920377236923], [-0.03343154297742938, -0.9905121627037764, -0.13329661462331338, 1.3859200914120975], [0.0, 0.0, 0.0, 1.0]], "ParkingLot2_scan_camcoord": [[0.9989532636786039, -0.04044665659892979, 0.021364572447267097, 0.01646827411554571], [-0.026687287930043047, -0.13600581518076985, 0.9903485279940424, 0.030197722289598695], [-0.03715058073335097, -0.9898820567153364, -0.13694286452455984, 1.4372015171546513], [0.0, 0.0, 0.0, 1.0]], "Pavallion_scan_camcoord": [[0.9971864096076799, 0.05693557331723671, -0.048760690979605295, 0.0012478238054067193], [0.05746407703876882, -0.16289761936471214, 0.9849681443861059, -0.006002953831755452], [0.04813672552068054, -0.9849988355812122, -0.16571104235928033, 1.7638454838942128], [0.0, 0.0, 0.0, 1.0]]} \ No newline at end of file diff --git 
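The w2az_sahmr.json resource above stores one 4x4 world-to-gravity-aligned (z-up) transform per scan. A minimal sketch of applying such a transform to 3D points in homogeneous coordinates (illustrative; the repo does this through hmr4d.utils.geo_transform.apply_T_on_points, the helper below is a simplified stand-in):

```python
import json
import torch

def apply_T(points: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """points: (N, 3), T: (4, 4) rigid transform -> transformed (N, 3)."""
    points_h = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1)  # (N, 4)
    return (T @ points_h.T).T[:, :3]

with open("hmr4d/dataset/rich/resource/w2az_sahmr.json") as f:
    w2az = {k: torch.tensor(v, dtype=torch.float32) for k, v in json.load(f).items()}

pts_world = torch.randn(10, 3)                          # dummy points in RICH world coords
pts_az = apply_T(pts_world, w2az["Gym_scan_camcoord"])  # gravity-aligned (z-up) frame
```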
a/hmr4d/dataset/rich/rich_motion_test.py b/hmr4d/dataset/rich/rich_motion_test.py new file mode 100644 index 0000000..e21a96b --- /dev/null +++ b/hmr4d/dataset/rich/rich_motion_test.py @@ -0,0 +1,185 @@ +from pathlib import Path +import numpy as np +import torch +from torch.utils import data +from hmr4d.utils.pylogger import Log + +from .rich_utils import ( + get_cam2params, + get_w2az_sahmr, + parse_seqname_info, + get_cam_key_wham_vid, +) +from hmr4d.utils.geo_transform import apply_T_on_points, transform_mat, compute_cam_angvel +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines +from hmr4d.utils.smplx_utils import make_smplx +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle +from hmr4d.utils.geo.hmr_cam import resize_K + + +from hmr4d.configs import MainStore, builds + + +VID_PRESETS = { + "easytohard": [ + "test/Gym_013_burpee4/cam_06", + "test/Gym_011_pushup1/cam_02", + "test/LectureHall_019_wipingchairs1/cam_03", + "test/ParkingLot2_009_overfence1/cam_04", + "test/LectureHall_021_sidebalancerun1/cam_00", + "test/Gym_010_dips2/cam_05", + ], +} + + +class RichSmplFullSeqDataset(data.Dataset): + def __init__(self, vid_presets=None): + """ + Args: + vid_presets is a key in VID_PRESETS + """ + super().__init__() + self.dataset_name = "RICH" + self.dataset_id = "RICH" + Log.info(f"[{self.dataset_name}] Full sequence, Test") + tic = Log.time() + + # Load evaluation protocol from WHAM labels + self.rich_dir = Path("inputs/RICH/hmr4d_support") + self.labels = torch.load(self.rich_dir / "rich_test_labels.pt") + self.preproc_data = torch.load(self.rich_dir / "rich_test_preproc.pt") + vids = select_subset(self.labels, vid_presets) + + # Setup dataset index + self.idx2meta = [] + for vid in vids: + seq_length = len(self.labels[vid]["frame_id"]) + self.idx2meta.append((vid, 0, seq_length)) # start=0, end=seq_length + # print(sum([end - start for _, _, start, end in self.idx2meta])) + + # Prepare ground truth motion in ay-coordinate + self.w2az = get_w2az_sahmr() # scan_name -> T_w2az, w-coordinate refers to cam-1-coordinate + self.cam2params = get_cam2params() # cam_key -> (T_w2c, K) + seqname_info = parse_seqname_info(skip_multi_persons=True) # {k: (scan_name, subject_id, gender, cam_ids)} + self.seqname_to_scanname = {k: v[0] for k, v in seqname_info.items()} + + Log.info(f"[RICH] {len(self.idx2meta)} sequences. 
Elapsed: {Log.time() - tic:.2f}s") + + def __len__(self): + return len(self.idx2meta) + + def _load_data(self, idx): + data = {} + + # [start, end), when loading data from labels + vid, start, end = self.idx2meta[idx] + label = self.labels[vid] + preproc_data = self.preproc_data[vid] + + length = end - start + meta = {"dataset_id": "RICH", "vid": vid, "vid-start-end": (start, end)} + data.update({"meta": meta, "length": length}) + + # SMPLX + data.update({"gt_smpl_params": label["gt_smplx_params"], "gender": label["gender"]}) + + # camera + cam_key = get_cam_key_wham_vid(vid) + scan_name = self.seqname_to_scanname[vid.split("/")[1]] + T_w2c, K = self.cam2params[cam_key] # (4, 4) (3, 3) + T_w2az = self.w2az[scan_name] + data.update({"T_w2c": T_w2c, "T_w2az": T_w2az, "K": K}) + + # image features + data.update( + { + "f_imgseq": preproc_data["f_imgseq"], + "bbx_xys": preproc_data["bbx_xys"], + "img_wh": preproc_data["img_wh"], + "kp2d": preproc_data["kp2d"], + } + ) + + # to render a video + video_path = self.rich_dir / "video" / vid / "video.mp4" + frame_id = label["frame_id"] # (F,) + width, height = data["img_wh"] / 4 # Video saved has been downsampled 1/4 + K_render = resize_K(K, 0.25) + bbx_xys_render = data["bbx_xys"] / 4 + data["meta_render"] = { + "name": vid.replace("/", "@"), + "video_path": str(video_path), + "frame_id": frame_id, + "width_height": (width, height), + "K": K_render, + "bbx_xys": bbx_xys_render, + } + + return data + + def _process_data(self, data): + # T_w2az is pre-computed by using floor clue. az2zy uses a rotation along x-axis. + R_az2ay = axis_angle_to_matrix(torch.tensor([1.0, 0.0, 0.0]) * -torch.pi / 2) # (3, 3) + T_w2ay = transform_mat(R_az2ay, R_az2ay.new([0, 0, 0])) @ data["T_w2az"] # (4, 4) + + if False: # Visualize groundtruth and observation + self.rich_smplx = { + "male": make_smplx("rich-smplx", gender="male"), + "female": make_smplx("rich-smplx", gender="female"), + } + wis3d = make_wis3d(name="debug-rich-smpl_dataset") + rich_smplx = make_smplx("rich-smplx", gender=data["gender"]) + smplx_out = rich_smplx(**data["gt_smpl_params"]) + smplx_verts_ay = apply_T_on_points(smplx_out.vertices, T_w2ay) + for i in range(400): + wis3d.set_scene_id(i) + wis3d.add_mesh(smplx_out.vertices[i], rich_smplx.bm.faces, name=f"gt-smplx") + wis3d.add_mesh(smplx_verts_ay[i], rich_smplx.bm.faces, name=f"gt-smplx-ay") + + # process img feature with xys + length = data["length"] + f_imgseq = data["f_imgseq"] # (F, 1024) + R_w2c = data["T_w2c"][:3, :3].repeat(length, 1, 1) # (L, 4, 4) + cam_angvel = compute_cam_angvel(R_w2c) # (L, 6) + + # Return + data = { + # --- not batched + "task": "CAP-Seq", + "meta": data["meta"], + "meta_render": data["meta_render"], + # --- we test on single sequence, so set kv manually + "length": length, + "f_imgseq": f_imgseq, + "cam_angvel": cam_angvel, + "bbx_xys": data["bbx_xys"], # (F, 3) + "K_fullimg": data["K"][None].expand(length, -1, -1), # (F, 3, 3) + "kp2d": data["kp2d"], # (F, 17, 3) + # --- dataset specific + "model": "smplx", + "gender": data["gender"], + "gt_smpl_params": data["gt_smpl_params"], + "T_w2ay": T_w2ay, # (4, 4) + "T_w2c": data["T_w2c"], # (4, 4) + } + return data + + def __getitem__(self, idx): + data = self._load_data(idx) + data = self._process_data(data) + return data + + +def select_subset(labels, vid_presets): + vids = list(labels.keys()) + if vid_presets != None: # Use a subset of the videos + vids = VID_PRESETS[vid_presets] + return vids + + +# +group_name = "test_datasets/rich" +base_node = 
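In _process_data above, the gravity-aligned z-up frame ("az") is converted to the y-up frame ("ay") with a -90 degree rotation about the x-axis. A standalone check (illustrative, not repo code) that this rotation maps the +z axis onto +y:

```python
import math
import torch

a = -math.pi / 2  # same angle as axis_angle_to_matrix([1, 0, 0] * -pi/2) above
R_az2ay = torch.tensor([
    [1.0, 0.0, 0.0],
    [0.0, math.cos(a), -math.sin(a)],
    [0.0, math.sin(a), math.cos(a)],
])
up_z = torch.tensor([0.0, 0.0, 1.0])
print(R_az2ay @ up_z)  # ~[0, 1, 0]: the old up direction becomes +y
```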
builds(RichSmplFullSeqDataset, vid_presets=None, populate_full_signature=True) +MainStore.store(name="all", node=base_node, group=group_name) +MainStore.store(name="easy_to_hard", node=base_node(vid_presets="easytohard"), group=group_name) +MainStore.store(name="postproc", node=base_node(vid_presets="postproc"), group=group_name) diff --git a/hmr4d/dataset/rich/rich_utils.py b/hmr4d/dataset/rich/rich_utils.py new file mode 100644 index 0000000..868c8b9 --- /dev/null +++ b/hmr4d/dataset/rich/rich_utils.py @@ -0,0 +1,370 @@ +import torch +import cv2 +import numpy as np +from hmr4d.utils.geo_transform import apply_T_on_points, project_p2d +from pathlib import Path +import json +import time + +# ----- Meta sample utils ----- # + + +def sample_idx2meta(idx2meta, sample_interval): + """ + 1. remove frames that < 45 + 2. sample frames by sample_interval + 3. sorted + """ + idx2meta = [ + v + for k, v in idx2meta.items() + if int(v["frame_name"]) > 45 and (int(v["frame_name"]) + int(v["cam_id"])) % sample_interval == 0 + ] + idx2meta = sorted(idx2meta, key=lambda meta: meta["img_key"]) + return idx2meta + + +def remove_bbx_invisible_frame(idx2meta, img2gtbbx): + raw_img_lu = np.array([0.0, 0.0]) + raw_img_rb_type1 = np.array([4112.0, 3008.0]) - 1 # horizontal + raw_img_rb_type2 = np.array([3008.0, 4112.0]) - 1 # vertical + + idx2meta_new = [] + for meta in idx2meta: + gtbbx_center = np.array([img2gtbbx[meta["img_key"]][[0, 2]].mean(), img2gtbbx[meta["img_key"]][[1, 3]].mean()]) + if (gtbbx_center < raw_img_lu).any(): + continue + raw_img_rb = raw_img_rb_type1 if meta["cam_key"] not in ["Pavallion_3", "Pavallion_5"] else raw_img_rb_type2 + if (gtbbx_center > raw_img_rb).any(): + continue + idx2meta_new.append(meta) + return idx2meta_new + + +def remove_extra_rules(idx2meta): + multi_person_seqs = ["LectureHall_009_021_reparingprojector1"] + idx2meta = [meta for meta in idx2meta if meta["seq_name"] not in multi_person_seqs] + return idx2meta + + +# ----- Image utils ----- # + + +def compute_bbx(dataset, data): + """ + Use gt_smplh_params to compute bbx (w.r.t. 
original image resolution) + Args: + dataset: rich_pose.RichPose + data: dict + + # This function need extra scripts to run + from hmr4d.utils.smplx_utils import make_smplx + self.smplh_male = make_smplx("rich-smplh", gender="male") + self.smplh_female = make_smplx("rich-smplh", gender="female") + self.smplh = { + "male": self.smplh_male, + "female": self.smplh_female, + } + """ + gender = data["meta"]["gender"] + smplh_params = {k: v.reshape(1, -1) for k, v in data["gt_smplh_params"].items()} + smplh_opt = dataset.smplh[gender](**smplh_params) + verts_3d_w = smplh_opt.vertices + T_w2c, K = data["T_w2c"], data["K"] + verts_3d_c = apply_T_on_points(verts_3d_w, T_w2c[None]) + verts_2d = project_p2d(verts_3d_c, K[None])[0] + min_2d = verts_2d.T.min(-1)[0] + max_2d = verts_2d.T.max(-1)[0] + bbx = torch.stack([min_2d, max_2d]).reshape(-1).numpy() + return bbx + + +def get_2d(dataset, data): + gender = data["meta"]["gender"] + smplh_params = {k: v.reshape(1, -1) for k, v in data["gt_smplh_params"].items()} + smplh_opt = dataset.smplh[gender](**smplh_params) + joints_3d_w = smplh_opt.joints + T_w2c, K = data["T_w2c"], data["K"] + joints_3d_c = apply_T_on_points(joints_3d_w, T_w2c[None]) + joints_2d = project_p2d(joints_3d_c, K[None])[0] + conf = torch.ones((73, 1)) + keypoints = torch.cat([joints_2d, conf], dim=1) + return keypoints + + +def squared_crop_and_resize(dataset, img, bbx_lurb, dst_size=224, state=None): + if state is not None: + np.random.set_state(state) + center_rand = dataset.BBX_CENTER * (np.random.random(2) * 2 - 1) + center_x = (bbx_lurb[0] + bbx_lurb[2]) / 2 + center_rand[0] + center_y = (bbx_lurb[1] + bbx_lurb[3]) / 2 + center_rand[1] + ori_half_size = max(bbx_lurb[2] - bbx_lurb[0], bbx_lurb[3] - bbx_lurb[1]) / 2 + ori_half_size *= 1 + 0.15 + dataset.BBX_ZOOM * np.random.random() # zoom + + src = np.array( + [ + [center_x - ori_half_size, center_y - ori_half_size], + [center_x + ori_half_size, center_y - ori_half_size], + [center_x, center_y], + ], + dtype=np.float32, + ) + dst = np.array([[0, 0], [dst_size - 1, 0], [dst_size / 2 - 0.5, dst_size / 2 - 0.5]], dtype=np.float32) + + A = cv2.getAffineTransform(src, dst) + img_crop = cv2.warpAffine(img, A, (dst_size, dst_size), flags=cv2.INTER_LINEAR) + bbx_new = np.array( + [center_x - ori_half_size, center_y - ori_half_size, center_x + ori_half_size, center_y + ori_half_size], + dtype=bbx_lurb.dtype, + ) + return img_crop, bbx_new, A + + +# Augment bbx +def get_augmented_square_bbx(bbx_lurb, per_shift=0.1, per_zoomout=0.2, base_zoomout=0.15, state=None): + """ + Args: + per_shift: in percent, maximum random shift + per_zoomout: in percent, maximum random zoom + """ + if state is not None: + np.random.set_state(state) + maxsize_bbx = max(bbx_lurb[2] - bbx_lurb[0], bbx_lurb[3] - bbx_lurb[1]) + # shift of center + shift = maxsize_bbx * per_shift * (np.random.random(2) * 2 - 1) + center_x = (bbx_lurb[0] + bbx_lurb[2]) / 2 + shift[0] + center_y = (bbx_lurb[1] + bbx_lurb[3]) / 2 + shift[1] + # zoomout of half-size + halfsize_bbx = maxsize_bbx / 2 + halfsize_bbx *= 1 + base_zoomout + per_zoomout * np.random.random() + + bbx_lurb = np.array( + [ + center_x - halfsize_bbx, + center_y - halfsize_bbx, + center_x + halfsize_bbx, + center_y + halfsize_bbx, + ] + ) + return bbx_lurb + + +def get_squared_bbx_region_and_resize(frames, bbx_xys, dst_size=224): + """ + Args: + frames: (F, H, W, 3) + bbx_xys: (F, 3), xys + """ + frames_np = frames.numpy() if isinstance(frames, torch.Tensor) else frames + bbx_xys = bbx_xys if isinstance(bbx_xys, 
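The crop utilities here map a square region described by a center and a size onto a fixed-size patch using a three-point affine transform (top-left corner, top-right corner, center). A compact standalone version of that idea (illustrative; the function and parameter names below are mine, not the repo's):

```python
import cv2
import numpy as np

def crop_square(img: np.ndarray, bbx_xys, dst_size: int = 224) -> np.ndarray:
    """Crop the square region bbx_xys = (center_x, center_y, size) and resize to dst_size."""
    cx, cy, s = bbx_xys
    src = np.array([[cx - s / 2, cy - s / 2],         # top-left corner
                    [cx + s / 2, cy - s / 2],         # top-right corner
                    [cx, cy]], dtype=np.float32)      # center
    dst = np.array([[0, 0], [dst_size - 1, 0],
                    [dst_size / 2 - 0.5, dst_size / 2 - 0.5]], dtype=np.float32)
    A = cv2.getAffineTransform(src, dst)
    return cv2.warpAffine(img, A, (dst_size, dst_size), flags=cv2.INTER_LINEAR)

img = np.zeros((720, 1280, 3), dtype=np.uint8)        # dummy frame
crop = crop_square(img, (640.0, 360.0, 300.0))        # -> (224, 224, 3)
```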
torch.Tensor) else torch.tensor(bbx_xys) # use tensor + srcs = torch.stack( + [ + torch.stack([bbx_xys[:, 0] - bbx_xys[:, 2] / 2, bbx_xys[:, 1] - bbx_xys[:, 2] / 2], dim=-1), + torch.stack([bbx_xys[:, 0] + bbx_xys[:, 2] / 2, bbx_xys[:, 1] - bbx_xys[:, 2] / 2], dim=-1), + bbx_xys[:, :2], + ], + dim=1, + ) # (F, 3, 2) + dst = np.array([[0, 0], [dst_size - 1, 0], [dst_size / 2 - 0.5, dst_size / 2 - 0.5]], dtype=np.float32) + As = np.stack([cv2.getAffineTransform(src, dst) for src in srcs.numpy()]) + + img_crops = np.stack( + [cv2.warpAffine(frames_np[i], As[i], (dst_size, dst_size), flags=cv2.INTER_LINEAR) for i in range(len(As))] + ) + img_crops = torch.from_numpy(img_crops) + As = torch.from_numpy(As) + return img_crops, As + + +# ----- Camera utils ----- # + + +def extract_cam_xml(xml_path="", dtype=torch.float32): + import xml.etree.ElementTree as ET + + tree = ET.parse(xml_path) + + extrinsics_mat = [float(s) for s in tree.find("./CameraMatrix/data").text.split()] + intrinsics_mat = [float(s) for s in tree.find("./Intrinsics/data").text.split()] + distortion_vec = [float(s) for s in tree.find("./Distortion/data").text.split()] + + return { + "ext_mat": torch.tensor(extrinsics_mat).float(), + "int_mat": torch.tensor(intrinsics_mat).float(), + "dis_vec": torch.tensor(distortion_vec).float(), + } + + +def get_cam2params(scene_info_root=None): + """ + Args: + scene_info_root: this could be repalced by path to scan_calibration + """ + if scene_info_root is not None: + cam_params = {} + cam_xml_files = Path(scene_info_root).glob("*/calibration/*.xml") + for cam_xml_file in cam_xml_files: + cam_param = extract_cam_xml(cam_xml_file) + T_w2c = cam_param["ext_mat"].reshape(3, 4) + T_w2c = torch.cat([T_w2c, torch.tensor([[0, 0, 0, 1.0]])], dim=0) # (4, 4) + K = cam_param["int_mat"].reshape(3, 3) + cap_name = cam_xml_file.parts[-3] + cam_id = int(cam_xml_file.stem) + cam_key = f"{cap_name}_{cam_id}" + cam_params[cam_key] = (T_w2c, K) + else: + cam_params = torch.load(Path(__file__).parent / "resource/cam2params.pt") + return cam_params + + +# ----- Parse Raw Resource ----- # + + +def get_w2az_sahmr(): + """ + Returns: + w2az_sahmr: dict, {scan_name: Tw2az}, Tw2az is a tensor of (4,4) + """ + fn = Path(__file__).parent / "resource/w2az_sahmr.json" + with open(fn, "r") as f: + kvs = json.load(f).items() + w2az_sahmr = {k: torch.tensor(v) for k, v in kvs} + return w2az_sahmr + + +def has_multi_persons(seq_name): + """ + Args: + seq_name: e.g. LectureHall_009_021_reparingprojector1 + """ + return len(seq_name.split("_")) != 3 + + +def parse_seqname_info(skip_multi_persons=True): + """ + This function will skip multi-person sequences. 
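get_cam2params above pads each camera's 3x4 extrinsic block to a full 4x4 transform and reshapes the 9 intrinsic values into K. A small sketch of that padding step plus a plain pinhole projection (illustrative; all numbers are placeholders, and the repo's project_p2d is the real projection routine):

```python
import torch

# Pad a 3x4 [R|t] extrinsic block to a 4x4 homogeneous transform, as in get_cam2params.
ext_mat = torch.randn(12)                                         # dummy values from the XML
T_w2c = torch.cat([ext_mat.reshape(3, 4),
                   torch.tensor([[0.0, 0.0, 0.0, 1.0]])], dim=0)  # (4, 4)

K = torch.tensor([[1000.0, 0.0, 640.0],
                  [0.0, 1000.0, 360.0],
                  [0.0, 0.0, 1.0]])                               # dummy intrinsics

# Pinhole projection of camera-space points (illustrative stand-in for project_p2d).
p_cam = torch.tensor([[0.1, -0.2, 3.0]])                          # (N, 3), z > 0
uv = (K @ p_cam.T).T
uv = uv[:, :2] / uv[:, 2:3]                                       # (N, 2) pixel coordinates
```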
+ Returns: + sname_to_info: scan_name, subject_id, gender, cam_ids + """ + fns = [Path(__file__).parent / f"resource/{split}.txt" for split in ["train", "val", "test"]] + # Train / Val&Test Header: + # sequence_name capture_name scan_name id moving_cam gender view_id + # sequence_name capture_name scan_name id moving_cam gender scene action/scene-interaction subjects view_id + sname_to_info = {} + for fn in fns: + with open(fn, "r") as f: + for line in f.readlines()[1:]: + raw_values = line.strip().split() + seq_name = raw_values[0] + if skip_multi_persons and has_multi_persons(seq_name): + continue + scan_name = f"{raw_values[1]}_{raw_values[2]}" + subject_id = int(raw_values[3]) + gender = raw_values[5] + cam_ids = [int(c) for c in raw_values[-1].split(",")] + sname_to_info[seq_name] = (scan_name, subject_id, gender, cam_ids) + return sname_to_info + + +def get_seqnames_of_split(splits=["train"], skip_multi_persons=True): + if not isinstance(splits, list): + splits = [splits] + fns = [Path(__file__).parent / f"resource/{split}.txt" for split in splits] + seqnames = [] + for fn in fns: + with open(fn, "r") as f: + for line in f.readlines()[1:]: + seq_name = line.strip().split()[0] + if skip_multi_persons and has_multi_persons(seq_name): + continue + seqnames.append(seq_name) + return seqnames + + +def get_seqname_to_imgrange(): + """Each sequence has a different range of image ids.""" + from tqdm import tqdm + + split_seqnames = {split: get_seqnames_of_split(split) for split in ["train", "val", "test"]} + seqname_to_imgrange = {} + for split in ["train", "val", "test"]: + for seqname in tqdm(split_seqnames[split]): + img_root = Path("inputs/RICH") / "images_ds4" / split # compressed (not original) + img_dir = img_root / seqname + img_names = sorted([n.name for n in img_dir.glob("**/*.jpeg")]) + if len(img_names) == 0: + img_range = (0, 0) + else: + img_range = (int(img_names[0].split("_")[0]), int(img_names[-1].split("_")[0])) + seqname_to_imgrange[seqname] = img_range + return seqname_to_imgrange + + +# ----- Compose keys ----- # + + +def get_img_key(seq_name, cam_id, f_id): + assert len(seq_name.split("_")) == 3 + subject_id = int(seq_name.split("_")[1]) + return f"{seq_name}_{int(cam_id)}_{int(f_id):05d}_{subject_id}" + + +def get_seq_cam_fn(img_root, seq_name, cam_id): + """ + Args: + img_root: "inputs/RICH/images_ds4/train" + """ + img_root = Path(img_root) + cam_id = int(cam_id) + return str(img_root / f"{seq_name}/cam_{cam_id:02d}") + + +def get_img_fn(img_root, seq_name, cam_id, f_id): + """ + Args: + img_root: "inputs/RICH/images_ds4/train" + """ + img_root = Path(img_root) + cam_id = int(cam_id) + f_id = int(f_id) + return str(img_root / f"{seq_name}/cam_{cam_id:02d}" / f"{f_id:05d}_{cam_id:02d}.jpeg") + + +# ----- WHAM ----- # + + +def get_cam_key_wham_vid(vid): + _, sname, cname = vid.split("/") + scene = sname.split("_")[0] + cid = int(cname.split("_")[1]) + cam_key = f"{scene}_{cid}" + return cam_key + + +def get_K_wham_vid(vid): + cam_key = get_cam_key_wham_vid(vid) + cam2params = get_cam2params() + K = cam2params[cam_key][1] + return K + + +class RichVid2Tc2az: + def __init__(self) -> None: + self.w2az = get_w2az_sahmr() # scan_name: tensor 4,4 + seqname_info = parse_seqname_info(skip_multi_persons=True) # {k: (scan_name, subject_id, gender, cam_ids)} + self.seqname_to_scanname = {k: v[0] for k, v in seqname_info.items()} + self.cam2params = get_cam2params() # cam_key -> (T_w2c, K) + + def __call__(self, vid): + cam_key = get_cam_key_wham_vid(vid) + scan_name = 
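parse_seqname_info reduces each row of the split files to (scan_name, subject_id, gender, cam_ids). A worked example on one row of the test split above (column meanings taken from the split-file headers):

```python
row = "ParkingLot2_017_pushup2 ParkingLot2 scan_camcoord 017 V female V V X 0,4,5"
raw = row.split()

seq_name = raw[0]                               # 'ParkingLot2_017_pushup2'
scan_name = f"{raw[1]}_{raw[2]}"                # 'ParkingLot2_scan_camcoord'
subject_id = int(raw[3])                        # 17
gender = raw[5]                                 # 'female'
cam_ids = [int(c) for c in raw[-1].split(",")]  # [0, 4, 5]
```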
self.seqname_to_scanname[vid.split("/")[1]] + T_w2c, K = self.cam2params[cam_key] # (4, 4) (3, 3) + T_w2az = self.w2az[scan_name] + T_c2az = T_w2az @ T_w2c.inverse() + return T_c2az + + def get_T_w2az(self, vid): + cam_key = get_cam_key_wham_vid(vid) + scan_name = self.seqname_to_scanname[vid.split("/")[1]] + T_w2az = self.w2az[scan_name] + return T_w2az diff --git a/hmr4d/dataset/threedpw/threedpw_motion_test.py b/hmr4d/dataset/threedpw/threedpw_motion_test.py new file mode 100644 index 0000000..469648a --- /dev/null +++ b/hmr4d/dataset/threedpw/threedpw_motion_test.py @@ -0,0 +1,153 @@ +import torch +from torch.utils import data +from pathlib import Path + +from hmr4d.utils.pylogger import Log +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines +from hmr4d.utils.geo_transform import compute_cam_angvel +from hmr4d.utils.geo.hmr_cam import estimate_K, resize_K +from hmr4d.utils.geo.flip_utils import flip_kp2d_coco17 + +from hmr4d.configs import MainStore, builds + +VID_HARD = [] +# VID_HARD = ["downtown_bar_00_1"] + + +class ThreedpwSmplFullSeqDataset(data.Dataset): + def __init__(self, flip_test=False, skip_invalid=False): + super().__init__() + self.dataset_name = "3DPW" + self.skip_invalid = skip_invalid + Log.info(f"[{self.dataset_name}] Full sequence") + + # Load evaluation protocol from WHAM labels + self.threedpw_dir = Path("inputs/3DPW/hmr4d_support") + # ['vname', 'K_fullimg', 'T_w2c', 'smpl_params', 'gender', 'mask_raw', 'mask_wham', 'img_wh'] + self.labels = torch.load(self.threedpw_dir / "test_3dpw_gt_labels.pt") + self.vid2bbx = torch.load(self.threedpw_dir / "preproc_test_bbx.pt") + self.vid2kp2d = torch.load(self.threedpw_dir / "preproc_test_kp2d_v0.pt") + + # Setup dataset index + self.idx2meta = list(self.labels) + if len(VID_HARD) > 0: # Pick subsets for fast testing + self.idx2meta = VID_HARD + Log.info(f"[{self.dataset_name}] {len(self.idx2meta)} sequences.") + + # If flip_test is enabled, we will return extra data for flipped test + self.flip_test = flip_test + if self.flip_test: + Log.info(f"[{self.dataset_name}] Flip test enabled") + + def __len__(self): + return len(self.idx2meta) + + def _load_data(self, idx): + data = {} + vid = self.idx2meta[idx] + meta = {"dataset_id": self.dataset_name, "vid": vid} + data.update({"meta": meta}) + + # Add useful data + label = self.labels[vid] + mask = label["mask_wham"] + width_height = label["img_wh"] + data.update( + { + "length": len(mask), # F + "smpl_params": label["smpl_params"], # world + "gender": label["gender"], # str + "T_w2c": label["T_w2c"], # (F, 4, 4) + "mask": mask, # (F) + } + ) + K_fullimg = label["K_fullimg"] # (3, 3) + if False: + K_fullimg = estimate_K(*width_height) + data["K_fullimg"] = K_fullimg + + # Preprocessed: bbx, kp2d, image as feature + bbx_xys = self.vid2bbx[vid]["bbx_xys"] # (F, 3) + kp2d = self.vid2kp2d[vid] # (F, 17, 3) + cam_angvel = compute_cam_angvel(data["T_w2c"][:, :3, :3]) # (L, 6) + data.update({"bbx_xys": bbx_xys, "kp2d": kp2d, "cam_angvel": cam_angvel}) + + imgfeat_dir = self.threedpw_dir / "imgfeats/3dpw_test" + f_img_dict = torch.load(imgfeat_dir / f"{vid}.pt") + f_imgseq = f_img_dict["features"].float() + data["f_imgseq"] = f_imgseq # (F, 1024) + + # to render a video + vname = label["vname"] + video_path = self.threedpw_dir / f"videos/{vname}.mp4" + frame_id = torch.where(mask)[0].long() + ds = 0.5 + K_render = resize_K(K_fullimg, ds) + bbx_xys_render = bbx_xys * ds + kp2d_render = kp2d.clone() + kp2d_render[..., :2] *= ds + data["meta_render"] = { + "name": vid, 
+ "video_path": str(video_path), + "ds": ds, + "frame_id": frame_id, + "K": K_render, + "bbx_xys": bbx_xys_render, + "kp2d": kp2d_render, + } + + if self.flip_test: + imgfeat_dir = self.threedpw_dir / "imgfeats/3dpw_test_flip" + f_img_dict = torch.load(imgfeat_dir / f"{vid}.pt") + flipped_bbx_xys = f_img_dict["bbx_xys"].float() # (L, 3) + flipped_features = f_img_dict["features"].float() # (L, 1024) + flipped_kp2d = flip_kp2d_coco17(kp2d, width_height[0]) # (L, 17, 3) + + R_flip_x = torch.tensor([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]).float() + flipped_R_w2c = R_flip_x @ data["T_w2c"][:, :3, :3].clone() + + data_flip = { + "bbx_xys": flipped_bbx_xys, + "f_imgseq": flipped_features, + "kp2d": flipped_kp2d, + "cam_angvel": compute_cam_angvel(flipped_R_w2c), + } + data["flip_test"] = data_flip + return data + + def _process_data(self, data): + length = data["length"] + data["K_fullimg"] = data["K_fullimg"][None].repeat(length, 1, 1) + + if self.skip_invalid: # Drop all invalid frames + mask = data["mask"].clone() + data["length"] = sum(mask) + data["smpl_params"] = {k: v[mask].clone() for k, v in data["smpl_params"].items()} + data["T_w2c"] = data["T_w2c"][mask].clone() + data["mask"] = data["mask"][mask].clone() + data["K_fullimg"] = data["K_fullimg"][mask].clone() + data["bbx_xys"] = data["bbx_xys"][mask].clone() + data["kp2d"] = data["kp2d"][mask].clone() + data["cam_angvel"] = data["cam_angvel"][mask].clone() + data["f_imgseq"] = data["f_imgseq"][mask].clone() + data["flip_test"] = {k: v[mask].clone() for k, v in data["flip_test"].items()} + + return data + + def __getitem__(self, idx): + data = self._load_data(idx) + data = self._process_data(data) + return data + + +# 3DPW +MainStore.store( + name="fliptest", + node=builds(ThreedpwSmplFullSeqDataset, flip_test=True), + group="test_datasets/3dpw", +) +MainStore.store( + name="v1", + node=builds(ThreedpwSmplFullSeqDataset, flip_test=False), + group="test_datasets/3dpw", +) diff --git a/hmr4d/dataset/threedpw/threedpw_motion_train.py b/hmr4d/dataset/threedpw/threedpw_motion_train.py new file mode 100644 index 0000000..2c803fd --- /dev/null +++ b/hmr4d/dataset/threedpw/threedpw_motion_train.py @@ -0,0 +1,164 @@ +import torch +from torch.utils import data +from pathlib import Path +import numpy as np + +from hmr4d.utils.pylogger import Log +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines +from hmr4d.utils.geo_transform import compute_cam_angvel +from hmr4d.utils.geo.hmr_cam import estimate_K, resize_K +from hmr4d.utils.geo.flip_utils import flip_kp2d_coco17 +from hmr4d.dataset.imgfeat_motion.base_dataset import ImgfeatMotionDatasetBase +from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict +from hmr4d.utils.smplx_utils import make_smplx +from hmr4d.utils.video_io_utils import get_video_lwh, read_video_np, save_video +from hmr4d.utils.vis.renderer_utils import simple_render_mesh_background + +from hmr4d.configs import MainStore, builds + + +class ThreedpwSmplDataset(ImgfeatMotionDatasetBase): + def __init__(self): + # Path + self.hmr4d_support_dir = Path("inputs/3DPW/hmr4d_support") + self.dataset_name = "3DPW" + + # Setting + self.min_motion_frames = 60 + self.max_motion_frames = 120 + super().__init__() + + def _load_dataset(self): + self.train_labels = torch.load(self.hmr4d_support_dir / "train_3dpw_gt_labels.pt") + self.refit_smplx = torch.load(self.hmr4d_support_dir / "train_refit_smplx.pt") + if True: # Remove clips that have obvious error + update_list = { + 
"courtyard_basketball_00_1": [(0, 300), (340, 468)], + "courtyard_laceShoe_00_0": [(0, 620), (780, 931)], + "courtyard_rangeOfMotions_00_1": [(0, 370), (410, 601)], + "courtyard_shakeHands_00_1": [(0, 100), (120, 391)], + } + for k, v in update_list.items(): + self.refit_smplx[k]["valid_range_list"] = v + + self.f_img_folder = self.hmr4d_support_dir / "imgfeats/3dpw_train_smplx_refit" + Log.info(f"[{self.dataset_name}] Train") + + def _get_idx2meta(self): + # We expect to see the entire sequence during one epoch, + # so each sequence will be sampled max(SeqLength // MotionFrames, 1) times + seq_lengths = [] + self.idx2meta = [] + for vid in self.refit_smplx: + valid_range_list = self.refit_smplx[vid]["valid_range_list"] + for start, end in valid_range_list: + seq_length = end - start + num_samples = max(seq_length // self.max_motion_frames, 1) + seq_lengths.append(seq_length) + self.idx2meta.extend([(vid, start, end)] * num_samples) + minutes = sum(seq_lengths) / 25 / 60 + Log.info( + f"[{self.dataset_name}] has {minutes:.1f} minutes motion -> Resampled to {len(self.idx2meta)} samples." + ) + + def _load_data(self, idx): + data = {} + vid, range1, range2 = self.idx2meta[idx] + + # Random select a subset + mlength = range2 - range1 + min_motion_len = self.min_motion_frames + max_motion_len = self.max_motion_frames + + if mlength < min_motion_len: # this may happen, the minimal mlength is around 30 + start = range1 + length = mlength + else: + effect_max_motion_len = min(max_motion_len, mlength) + length = np.random.randint(min_motion_len, effect_max_motion_len + 1) # [low, high) + start = np.random.randint(range1, range2 - length + 1) + end = start + length + data["length"] = length + data["meta"] = {"data_name": self.dataset_name, "idx": idx, "vid": vid, "start_end": (start, end)} + + # Select motion subset + data["smplx_params_incam"] = {k: v[start:end] for k, v in self.refit_smplx[vid]["smplx_params_incam"].items()} + data["K_fullimg"] = self.train_labels[vid]["K_fullimg"] + data["T_w2c"] = self.train_labels[vid]["T_w2c"][start:end] + + # Img (as feature): + f_img_dict = torch.load(self.f_img_folder / f"{vid}.pt") + data["bbx_xys"] = f_img_dict["bbx_xys"][start:end] # (F, 3) + data["f_imgseq"] = f_img_dict["features"][start:end].float() # (F, 3) + data["img_wh"] = f_img_dict["img_wh"] # (2) + data["kp2d"] = torch.zeros((end - start), 17, 3) # (L, 17, 3) # do not provide kp2d + + return data + + def _process_data(self, data, idx): + length = data["length"] + + smpl_params_c = data["smplx_params_incam"] + smpl_params_w_zero = {k: torch.zeros_like(v) for k, v in smpl_params_c.items()} + K_fullimg = data["K_fullimg"][None].repeat(length, 1, 1) + cam_angvel = compute_cam_angvel(data["T_w2c"][:, :3, :3]) + + max_len = self.max_motion_frames + return_data = { + "meta": data["meta"], + "length": length, + "smpl_params_c": smpl_params_c, + "smpl_params_w": smpl_params_w_zero, + "R_c2gv": torch.zeros(length, 3, 3), # (F, 3, 3) + "gravity_vec": torch.zeros(3), # (3) + "bbx_xys": data["bbx_xys"], # (F, 3) + "K_fullimg": K_fullimg, # (F, 3, 3) + "f_imgseq": data["f_imgseq"], # (F, D) + "kp2d": data["kp2d"], # (F, 17, 3) + "cam_angvel": cam_angvel, # (F, 6) + "mask": { + "valid": get_valid_mask(max_len, length), + "vitpose": False, + "bbx_xys": True, + "f_imgseq": True, + "spv_incam_only": True, + }, + } + + if False: # Debug, render incam + start, end = data["meta"]["start_end"] + vid = data["meta"]["vid"] + + ds = 0.5 + faces = smplx.faces + smplx = make_smplx("supermotion") + smplx_c_verts = 
smplx(**return_data["smpl_params_c"]).vertices + K_render = resize_K(K_fullimg, ds) + + video_path = self.hmr4d_support_dir / f"videos/{vid[:-2]}.mp4" + images = read_video_np(video_path, scale=ds, start_frame=start, end_frame=end) + + render_dict = { + "K": K_render[:1], # only support batch size 1 + "faces": faces, + "verts": smplx_c_verts, + "background": images, + } + img_overlay = simple_render_mesh_background(render_dict, VI=10) + save_video(img_overlay, f"tmp.mp4", crf=28) + + # Batchable + return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len) + return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len) + return_data["R_c2gv"] = repeat_to_max_len(return_data["R_c2gv"], max_len) + return_data["bbx_xys"] = repeat_to_max_len(return_data["bbx_xys"], max_len) + return_data["K_fullimg"] = repeat_to_max_len(return_data["K_fullimg"], max_len) + return_data["f_imgseq"] = repeat_to_max_len(return_data["f_imgseq"], max_len) + return_data["kp2d"] = repeat_to_max_len(return_data["kp2d"], max_len) + return_data["cam_angvel"] = repeat_to_max_len(return_data["cam_angvel"], max_len) + + return return_data + + +# 3DPW +MainStore.store(name="v1", node=builds(ThreedpwSmplDataset), group="train_datasets/imgfeat_3dpw") diff --git a/hmr4d/dataset/threedpw/utils.py b/hmr4d/dataset/threedpw/utils.py new file mode 100644 index 0000000..ca0ac36 --- /dev/null +++ b/hmr4d/dataset/threedpw/utils.py @@ -0,0 +1,81 @@ +import json +import numpy as np +from pathlib import Path +from collections import defaultdict +import pickle +import torch +import joblib + +RESOURCE_FOLDER = Path(__file__).resolve().parent / "resource" + + +def read_raw_pkl(pkl_path): + with open(pkl_path, "rb") as f: + data = pickle.load(f, encoding="bytes") + + num_subjects = len(data[b"poses"]) + F = data[b"poses"][0].shape[0] + smpl_params = [] + for i in range(num_subjects): + smpl_params.append( + { + "body_pose": torch.from_numpy(data[b"poses"][i][:, 3:72]).float(), # (F, 69) + "betas": torch.from_numpy(data[b"betas"][i][:10]).repeat(F, 1).float(), # (F, 10) + "global_orient": torch.from_numpy(data[b"poses"][i][:, :3]).float(), # (F, 3) + "transl": torch.from_numpy(data[b"trans"][i]).float(), # (F, 3) + } + ) + genders = ["male" if g == "m" else "female" for g in data[b"genders"]] + campose_valid = [torch.from_numpy(v).bool() for v in data[b"campose_valid"]] + + seq_name = data[b"sequence"] + K_fullimg = torch.from_numpy(data[b"cam_intrinsics"]).float() + T_w2c = torch.from_numpy(data[b"cam_poses"]).float() + + return_data = { + "sequence": seq_name, # 'courtyard_bodyScannerMotions_00' + "K_fullimg": K_fullimg, # (3, 3), not 55FoV + "T_w2c": T_w2c, # (F, 4, 4) + "smpl_params": smpl_params, # list of dict + "genders": genders, # list of str + "campose_valid": campose_valid, # list of bool-array + # "jointPositions": data[b'jointPositions'], # SMPL, 24x3 + # "poses2d": data[b"poses2d"], # COCO, 3x18(?) 
+ } + return return_data + + +def load_and_convert_wham_pth(pth): + """ + Convert to {vid: DataDict} style, Add smpl_params_incam + """ + # load + wham_labels_raw = joblib.load(pth) + # convert it to {vid: DataDict} style + wham_labels = {} + for i, vid in enumerate(wham_labels_raw["vid"]): + wham_labels[vid] = {k: wham_labels_raw[k][i] for k in wham_labels_raw} + + # convert pose and betas as smpl_params_incam (without transl) + for vid in wham_labels: + pose = wham_labels[vid]["pose"] + global_orient = pose[:, :3] # (F, 3) + body_pose = pose[:, 3:] # (F, 69) + betas = wham_labels[vid]["betas"] # (F, 10), all frames are the same + wham_labels[vid]["smpl_params_incam"] = { + "body_pose": body_pose.float(), # (F, 69) + "betas": betas.float(), # (F, 10) + "global_orient": global_orient.float(), # (F, 3) + } + + return wham_labels + + +# Neural-Annot utils + + +def na_cam_param_to_K_fullimg(cam_param): + K = torch.eye(3) + K[[0, 1], [0, 1]] = torch.tensor(cam_param["focal"]) + K[[0, 1], [2, 2]] = torch.tensor(cam_param["princpt"]) + return K diff --git a/hmr4d/model/common_utils/optimizer.py b/hmr4d/model/common_utils/optimizer.py new file mode 100644 index 0000000..b03513f --- /dev/null +++ b/hmr4d/model/common_utils/optimizer.py @@ -0,0 +1,17 @@ +from torch.optim import AdamW, Adam +from hmr4d.configs import MainStore, builds + + +optimizer_cfgs = { + "adam_1e-3": builds(Adam, lr=1e-3, zen_partial=True), + "adam_2e-4": builds(Adam, lr=2e-4, zen_partial=True), + "adamw_2e-4": builds(AdamW, lr=2e-4, zen_partial=True), + "adamw_1e-4": builds(AdamW, lr=1e-4, zen_partial=True), + "adamw_5e-5": builds(AdamW, lr=5e-5, zen_partial=True), + "adamw_1e-5": builds(AdamW, lr=1e-5, zen_partial=True), + # zero-shot text-to-image generation + "adamw_1e-3_dalle": builds(AdamW, lr=1e-3, weight_decay=1e-4, zen_partial=True), +} + +for name, cfg in optimizer_cfgs.items(): + MainStore.store(name=name, node=cfg, group=f"optimizer") diff --git a/hmr4d/model/common_utils/scheduler.py b/hmr4d/model/common_utils/scheduler.py new file mode 100644 index 0000000..8d8bc54 --- /dev/null +++ b/hmr4d/model/common_utils/scheduler.py @@ -0,0 +1,29 @@ +import torch +from bisect import bisect_right + + +class WarmupMultiStepLR(torch.optim.lr_scheduler.LRScheduler): + def __init__(self, optimizer, milestones, warmup=0, gamma=0.1, last_epoch=-1, verbose="deprecated"): + """Assume optimizer does not change lr; Scheduler is called epoch-based""" + self.milestones = milestones + self.warmup = warmup + assert warmup < milestones[0] + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + base_lrs = self.base_lrs # base lr for each groups + n_groups = len(base_lrs) + comming_epoch = self.last_epoch # the lr will be set for the comming epoch, starts from 0 + + # add extra warmup + if comming_epoch < self.warmup: + # e.g. comming_epoch [0, 1, 2] for warmup == 3 + # lr should be base_lr * (last_epoch+1) / (warmup + 1), e.g. 
[0.25, 0.5, 0.75] * base_lr + lr_factor = (self.last_epoch + 1) / (self.warmup + 1) + return [base_lrs[i] * lr_factor for i in range(n_groups)] + else: + # bisect_right([3,5,7], 0) -> 0; bisect_right([3,5,7], 5) -> 2 + p = bisect_right(self.milestones, comming_epoch) + lr_factor = self.gamma**p + return [base_lrs[i] * lr_factor for i in range(n_groups)] diff --git a/hmr4d/model/common_utils/scheduler_cfg.py b/hmr4d/model/common_utils/scheduler_cfg.py new file mode 100644 index 0000000..fd16eee --- /dev/null +++ b/hmr4d/model/common_utils/scheduler_cfg.py @@ -0,0 +1,47 @@ +from omegaconf import DictConfig, ListConfig +from hmr4d.configs import MainStore, builds + +# do not perform scheduling +default = DictConfig({"scheduler": None}) +MainStore.store(name="default", node=default, group=f"scheduler_cfg") + + +# epoch-based +def epoch_half_by(milestones=[100, 200, 300]): + return DictConfig( + { + "scheduler": { + "_target_": "torch.optim.lr_scheduler.MultiStepLR", + "milestones": milestones, + "gamma": 0.5, + }, + "interval": "epoch", + "frequency": 1, + } + ) + + +MainStore.store(name="epoch_half_100_200_300", node=epoch_half_by([100, 200, 300]), group=f"scheduler_cfg") +MainStore.store(name="epoch_half_100_200", node=epoch_half_by([100, 200]), group=f"scheduler_cfg") +MainStore.store(name="epoch_half_200_350", node=epoch_half_by([200, 350]), group=f"scheduler_cfg") +MainStore.store(name="epoch_half_300", node=epoch_half_by([300]), group=f"scheduler_cfg") + + +# epoch-based +def warmup_epoch_half_by(warmup=10, milestones=[100, 200, 300]): + return DictConfig( + { + "scheduler": { + "_target_": "hmr4d.model.common_utils.scheduler.WarmupMultiStepLR", + "milestones": milestones, + "warmup": warmup, + "gamma": 0.5, + }, + "interval": "epoch", + "frequency": 1, + } + ) + + +MainStore.store(name="warmup_5_epoch_half_200_350", node=warmup_epoch_half_by(5, [200, 350]), group=f"scheduler_cfg") +MainStore.store(name="warmup_10_epoch_half_200_350", node=warmup_epoch_half_by(10, [200, 350]), group=f"scheduler_cfg") diff --git a/hmr4d/model/gvhmr/callbacks/metric_3dpw.py b/hmr4d/model/gvhmr/callbacks/metric_3dpw.py new file mode 100644 index 0000000..6af1750 --- /dev/null +++ b/hmr4d/model/gvhmr/callbacks/metric_3dpw.py @@ -0,0 +1,186 @@ +import torch +import pytorch_lightning as pl +import numpy as np +from pathlib import Path +from einops import einsum, rearrange + +from hmr4d.configs import MainStore, builds +from hmr4d.utils.pylogger import Log +from hmr4d.utils.comm.gather import all_gather +from hmr4d.utils.eval.eval_utils import compute_camcoord_metrics, as_np_array +from hmr4d.utils.smplx_utils import make_smplx +from hmr4d.utils.vis.cv2_utils import cv2, draw_bbx_xys_on_image_batch, draw_coco17_skeleton_batch +from hmr4d.utils.vis.renderer_utils import simple_render_mesh_background +from hmr4d.utils.video_io_utils import read_video_np, get_video_lwh, save_video +from hmr4d.utils.geo_transform import apply_T_on_points +from hmr4d.utils.seq_utils import rearrange_by_mask + + +class MetricMocap(pl.Callback): + def __init__(self): + super().__init__() + # vid->result + self.metric_aggregator = { + "pa_mpjpe": {}, + "mpjpe": {}, + "pve": {}, + "accel": {}, + } + + # SMPLX and SMPL + self.smplx = make_smplx("supermotion_EVAL3DPW") + self.smpl = {"male": make_smplx("smpl", gender="male"), "female": make_smplx("smpl", gender="female")} + self.J_regressor = torch.load("hmr4d/utils/body_model/smpl_3dpw14_J_regressor_sparse.pt").to_dense() + self.J_regressor24 = 
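A usage sketch for the WarmupMultiStepLR scheduler defined above (the optimizer, warmup and milestone values are placeholders): with warmup=2 and milestones=[5, 8], the learning-rate factor ramps up as 1/3, 2/3, stays at 1.0, then is multiplied by gamma at each milestone.

```python
import torch
from hmr4d.model.common_utils.scheduler import WarmupMultiStepLR

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
sched = WarmupMultiStepLR(opt, milestones=[5, 8], warmup=2, gamma=0.5)

for epoch in range(10):
    # ... one epoch of training ...
    opt.step()
    print(epoch, sched.get_last_lr()[0])  # 2e-4 * [1/3, 2/3, 1, 1, 1, 0.5, 0.5, 0.5, 0.25, 0.25]
    sched.step()
```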
torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt") + self.smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt") + self.faces_smplx = self.smplx.faces + self.faces_smpl = self.smpl["male"].faces + + # The metrics are calculated similarly for val/test/predict + self.on_test_batch_end = self.on_validation_batch_end = self.on_predict_batch_end + + # Only validation record the metrics with logger + self.on_test_epoch_end = self.on_validation_epoch_end = self.on_predict_epoch_end + + # ================== Batch-based Computation ================== # + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): + """The behaviour is the same for val/test/predict""" + assert batch["B"] == 1 + dataset_id = batch["meta"][0]["dataset_id"] + if dataset_id != "3DPW": + return + + # Move to cuda if not + self.smplx = self.smplx.cuda() + for g in ["male", "female"]: + self.smpl[g] = self.smpl[g].cuda() + self.J_regressor = self.J_regressor.cuda() + self.J_regressor24 = self.J_regressor24.cuda() + self.smplx2smpl = self.smplx2smpl.cuda() + + vid = batch["meta"][0]["vid"] + seq_length = batch["length"][0].item() + gender = batch["gender"][0] + T_w2c = batch["T_w2c"][0] + mask = batch["mask"][0] + + # Groundtruth (cam) + target_w_params = {k: v[0] for k, v in batch["smpl_params"].items()} + target_w_output = self.smpl[gender](**target_w_params) + target_w_verts = target_w_output.vertices + target_c_verts = apply_T_on_points(target_w_verts, T_w2c) + target_c_j3d = torch.matmul(self.J_regressor, target_c_verts) + + # + Prediction -> Metric + smpl_out = self.smplx(**outputs["pred_smpl_params_incam"]) + pred_c_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices]) + pred_c_j3d = einsum(self.J_regressor, pred_c_verts, "j v, l v i -> l j i") + del smpl_out # Prevent OOM + + # Metric of current sequence + batch_eval = { + "pred_j3d": pred_c_j3d, + "target_j3d": target_c_j3d, + "pred_verts": pred_c_verts, + "target_verts": target_c_verts, + } + camcoord_metrics = compute_camcoord_metrics(batch_eval, mask=mask, pelvis_idxs=[2, 3]) + for k in camcoord_metrics: + self.metric_aggregator[k][vid] = as_np_array(camcoord_metrics[k]) + + if False: # Render incam (simple) + meta_render = batch["meta_render"][0] + images = read_video_np(meta_render["video_path"], scale=meta_render["ds"]) + render_dict = { + "K": meta_render["K"][None], # only support batch size 1 + "faces": self.smpl["male"].faces, + "verts": pred_c_verts, + "background": images, + } + img_overlay = simple_render_mesh_background(render_dict) + output_fn = Path("outputs/3DPW_render_pred_flip") / f"{vid}.mp4" + save_video(img_overlay, output_fn, crf=28) + + if False: # Render incam (with details) + meta_render = batch["meta_render"][0] + images = read_video_np(meta_render["video_path"], scale=meta_render["ds"]) + render_dict = { + "K": meta_render["K"][None], # only support batch size 1 + "faces": self.smpl["male"].faces, + "verts": pred_c_verts, + "background": images, + } + img_overlay = simple_render_mesh_background(render_dict) + + # Add COCO17 and bbx to image + bbx_xys_render = meta_render["bbx_xys"] + kp2d_render = meta_render["kp2d"] + img_overlay = draw_coco17_skeleton_batch(img_overlay, kp2d_render, conf_thr=0.5) + img_overlay = draw_bbx_xys_on_image_batch(bbx_xys_render, img_overlay, mask) + + # Add metric + metric_all = rearrange_by_mask(torch.tensor(camcoord_metrics["pa_mpjpe"]), mask) + for i in range(len(img_overlay)): + m = metric_all[i] + if m == 0: 
# a not evaluated frame + continue + text = f"PA-MPJPE: {m:.1f}" + color = (244, 10, 20) if m > 45 else (0, 205, 0) # red or green + cv2.putText(img_overlay[i], text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2) + + output_dir = Path("tmp_pred_details") + output_dir.mkdir(exist_ok=True, parents=True) + save_video(img_overlay, output_dir / f"{vid}.mp4", crf=24) + + # ================== Epoch Summary ================== # + def on_predict_epoch_end(self, trainer, pl_module): + """Without logger""" + local_rank, world_size = trainer.local_rank, trainer.world_size + monitor_metric = "pa_mpjpe" + + # Reduce metric_aggregator across all processes + metric_keys = list(self.metric_aggregator.keys()) + with torch.inference_mode(False): # allow in-place operation of all_gather + metric_aggregator_gathered = all_gather(self.metric_aggregator) # list of dict + for metric_key in metric_keys: + for d in metric_aggregator_gathered: + self.metric_aggregator[metric_key].update(d[metric_key]) + + if False: # debug to make sure the all_gather is correct + print(f"[RANK {local_rank}/{world_size}]: {self.metric_aggregator[monitor_metric].keys()}") + + total = len(self.metric_aggregator[monitor_metric]) + Log.info(f"{total} sequences evaluated in {self.__class__.__name__}") + if total == 0: + return + + # print monitored metric per sequence + mm_per_seq = {k: v.mean() for k, v in self.metric_aggregator[monitor_metric].items()} + if len(mm_per_seq) > 0: + sorted_mm_per_seq = sorted(mm_per_seq.items(), key=lambda x: x[1], reverse=True) + n_worst = 5 if trainer.state.stage == "validate" else len(sorted_mm_per_seq) + if local_rank == 0: + Log.info( + f"monitored metric {monitor_metric} per sequence\n" + + "\n".join([f"{m:5.1f} : {s}" for s, m in sorted_mm_per_seq[:n_worst]]) + + "\n------" + ) + + # average over all batches + metrics_avg = {k: np.concatenate(list(v.values())).mean() for k, v in self.metric_aggregator.items()} + if local_rank == 0: + Log.info(f"[Metrics] 3DPW:\n" + "\n".join(f"{k}: {v:.1f}" for k, v in metrics_avg.items()) + "\n------") + + # save to logger if available + if pl_module.logger is not None: + cur_epoch = pl_module.current_epoch + for k, v in metrics_avg.items(): + pl_module.logger.log_metrics({f"val_metric_3DPW/{k}": v}, step=cur_epoch) + + # reset + for k in self.metric_aggregator: + self.metric_aggregator[k] = {} + + +node_3dpw = builds(MetricMocap) +MainStore.store(name="metric_3dpw", node=node_3dpw, group="callbacks", package="callbacks.metric_3dpw") diff --git a/hmr4d/model/gvhmr/callbacks/metric_emdb.py b/hmr4d/model/gvhmr/callbacks/metric_emdb.py new file mode 100644 index 0000000..36631a7 --- /dev/null +++ b/hmr4d/model/gvhmr/callbacks/metric_emdb.py @@ -0,0 +1,323 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from hmr4d.configs import MainStore, builds + +from hmr4d.utils.comm.gather import all_gather +from hmr4d.utils.pylogger import Log + +from hmr4d.utils.eval.eval_utils import ( + compute_camcoord_metrics, + compute_global_metrics, + compute_camcoord_perjoint_metrics, + rearrange_by_mask, + as_np_array, +) +from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay +from hmr4d.utils.smplx_utils import make_smplx +from einops import einsum, rearrange + +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines +from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static +from hmr4d.utils.geo.hmr_cam import estimate_focal_length +from 
hmr4d.utils.video_io_utils import read_video_np, save_video +import imageio +from tqdm import tqdm +from pathlib import Path +import numpy as np +import cv2 + + +class MetricMocap(pl.Callback): + def __init__(self, emdb_split=1): + """ + Args: + emdb_split: 1 to evaluate incam, 2 to evaluate global + """ + super().__init__() + # vid->result + if emdb_split == 1: + self.target_dataset_id = "EMDB_1" + self.metric_aggregator = { + "pa_mpjpe": {}, + "mpjpe": {}, + "pve": {}, + "accel": {}, + } + elif emdb_split == 2: + self.target_dataset_id = "EMDB_2" + self.metric_aggregator = { + "wa2_mpjpe": {}, + "waa_mpjpe": {}, + "rte": {}, + "jitter": {}, + "fs": {}, + } + else: + raise ValueError(f"Unknown emdb_split: {emdb_split}") + + # SMPL + self.smplx = make_smplx("supermotion") + self.smpl_model = {"male": make_smplx("smpl", gender="male"), "female": make_smplx("smpl", gender="female")} + + self.J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt") + self.smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt") + self.faces_smpl = self.smpl_model["male"].faces + self.faces_smplx = self.smplx.faces + + # The metrics are calculated similarly for val/test/predict + self.on_test_batch_end = self.on_validation_batch_end = self.on_predict_batch_end + + # Only validation record the metrics with logger + self.on_test_epoch_end = self.on_validation_epoch_end = self.on_predict_epoch_end + + # ================== Batch-based Computation ================== # + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): + """The behaviour is the same for val/test/predict""" + assert batch["B"] == 1 + dataset_id = batch["meta"][0]["dataset_id"] + if dataset_id != self.target_dataset_id: + return + + # Move to cuda if not + self.smplx = self.smplx.cuda() + for g in ["male", "female"]: + self.smpl_model[g] = self.smpl_model[g].cuda() + self.J_regressor = self.J_regressor.cuda() + self.smplx2smpl = self.smplx2smpl.cuda() + + vid = batch["meta"][0]["vid"] + seq_length = batch["length"][0].item() + gender = batch["gender"][0] + T_w2c = batch["T_w2c"][0] + mask = batch["mask"][0] + + # Groundtruth (world, cam) + target_w_params = {k: v[0] for k, v in batch["smpl_params"].items()} + target_w_output = self.smpl_model[gender](**target_w_params) + target_w_verts = target_w_output.vertices + target_w_j3d = torch.matmul(self.J_regressor, target_w_verts) + target_c_verts = apply_T_on_points(target_w_verts, T_w2c) + target_c_j3d = apply_T_on_points(target_w_j3d, T_w2c) + + # + Prediction -> Metric + if self.target_dataset_id == "EMDB_1": # in camera metrics + # 1. cam + pred_smpl_params_incam = outputs["pred_smpl_params_incam"] + smpl_out = self.smplx(**pred_smpl_params_incam) + pred_c_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices]) + pred_c_j3d = einsum(self.J_regressor, pred_c_verts, "j v, l v i -> l j i") + del smpl_out # Prevent OOM + + batch_eval = { + "pred_j3d": pred_c_j3d, + "target_j3d": target_c_j3d, + "pred_verts": pred_c_verts, + "target_verts": target_c_verts, + } + camcoord_metrics = compute_camcoord_metrics(batch_eval, mask=mask) + for k in camcoord_metrics: + self.metric_aggregator[k][vid] = as_np_array(camcoord_metrics[k]) + + elif self.target_dataset_id == "EMDB_2": # global metrics + # 2. 
global (align-y axis) + pred_smpl_params_global = outputs["pred_smpl_params_global"] + smpl_out = self.smplx(**pred_smpl_params_global) + pred_ay_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices]) + pred_ay_j3d = einsum(self.J_regressor, pred_ay_verts, "j v, l v i -> l j i") + del smpl_out # Prevent OOM + + batch_eval = { + "pred_j3d_glob": pred_ay_j3d, + "target_j3d_glob": target_w_j3d, + "pred_verts_glob": pred_ay_verts, + "target_verts_glob": target_w_verts, + } + global_metrics = compute_global_metrics(batch_eval, mask=mask) + for k in global_metrics: + self.metric_aggregator[k][vid] = as_np_array(global_metrics[k]) + + if False: # wis3d debug + wis3d = make_wis3d(name="debug-emdb-incam") + pred_cr_j3d = pred_c_j3d - pred_c_j3d[:, [0]] # (L, J, 3) + target_cr_j3d = target_c_j3d - target_c_j3d[:, [0]] # (L, J, 3) + add_motion_as_lines(pred_cr_j3d, wis3d, name="pred_cr_j3d", const_color="blue") + add_motion_as_lines(target_cr_j3d, wis3d, name="target_cr_j3d", const_color="green") + + if False: # Dump wis3d + vid = batch["meta"][0]["vid"] + split = batch["meta_render"][0]["split"] + wis3d = make_wis3d(name=f"dump_emdb{split}-{vid}") + R_cam_type = batch["meta_render"][0]["R_cam_type"] + + pred_cr_j3d = pred_c_j3d - pred_c_j3d[:, [0]] # (L, J, 3) + target_cr_j3d = target_c_j3d - target_c_j3d[:, [0]] # (L, J, 3) + add_motion_as_lines(pred_cr_j3d, wis3d, name="pred_cr_j3d", const_color="blue") + add_motion_as_lines(target_cr_j3d, wis3d, name="target_cr_j3d", const_color="green") + add_motion_as_lines(pred_ay_j3d, wis3d, name=f"pred_ay_j3d@{R_cam_type}") + # add_motion_as_lines(target_w_j3d, wis3d, name="target_ay_j3d") + + if False: # Render incam + # -- rendering code -- # + vname = batch["meta_render"][0]["name"] + video_path = batch["meta_render"][0]["video_path"] + width, height = batch["meta_render"][0]["width_height"] + K = batch["meta_render"][0]["K"] + faces = self.faces_smpl + split = batch["meta_render"][0]["split"] + + out_fn = f"outputs/dump_render_emdb{split}/{vname}.mp4" + Path(out_fn).parent.mkdir(exist_ok=True, parents=True) + + # renderer + renderer = Renderer(width, height, device="cuda", faces=faces, K=K) + # not skipping invalid frames + resize_factor = 0.25 + images = read_video_np(video_path, scale=resize_factor) # (F, H, W, 3), uint8, numpy + frame_id = batch["meta_render"][0]["frame_id"] + bbx_xys_render = batch["meta_render"][0]["bbx_xys"] + metric_vis = rearrange_by_mask(torch.from_numpy(self.metric_aggregator["mpjpe"][vid]), mask) + + # -- render mesh -- # + verts_incam = pred_c_verts + output_images = [] + for i in tqdm(range(len(images)), desc=f"Rendering {vname}"): + img = renderer.render_mesh(verts_incam[i].cuda(), images[i], [0.8, 0.8, 0.8]) + # bbx + bbx_xys_ = bbx_xys_render[i].cpu().numpy() + lu_point = (bbx_xys_[:2] - bbx_xys_[2:] / 2).astype(int) + rd_point = (bbx_xys_[:2] + bbx_xys_[2:] / 2).astype(int) + img = cv2.rectangle(img, lu_point, rd_point, (255, 178, 102), 2) + + if metric_vis[i] > 0: + text = f"pred mpjpe: {metric_vis[i]:.1f}" + text_color = (244, 10, 20) if metric_vis[i] > 80 else (0, 205, 0) # red or green + cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.75, text_color, 2) + + output_images.append(img) + save_video(output_images, out_fn, quality=5) + + if False: # Visualize incam + global results + + def move_to_start_point_face_z(verts): + "XZ to origin, Start from the ground, Face-Z" + verts = verts.clone() # (L, V, 3) + xz_mean = verts[0].mean(0)[[0, 2]] + y_min = verts[0, :, [1]].min() + 
offset = torch.tensor([[[xz_mean[0], y_min, xz_mean[1]]]]).to(verts) + verts = verts - offset + + T_ay2ayfz = compute_T_ayfz2ay(einsum(self.J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True) + verts = apply_T_on_points(verts, T_ay2ayfz) + return verts + + verts_incam = pred_c_verts.clone() + # verts_glob = move_to_start_point_face_z(target_ay_verts) # gt + verts_glob = move_to_start_point_face_z(pred_ay_verts) + global_R, global_T, global_lights = get_global_cameras_static(verts_glob.cpu()) + + # -- rendering code (global version FOV=55) -- # + vname = batch["meta_render"][0]["name"] + width, height = batch["meta_render"][0]["width_height"] + K = batch["meta_render"][0]["K"] + faces = self.faces_smpl + out_fn = f"outputs/dump_render_global/{vname}.mp4" + Path(out_fn).parent.mkdir(exist_ok=True, parents=True) + writer = imageio.get_writer(out_fn, fps=30, mode="I", format="FFMPEG", macro_block_size=1) + + # two renderers + renderer_incam = Renderer(width, height, device="cuda", faces=faces, K=K) + renderer_glob = Renderer(width, height, estimate_focal_length(width, height), device="cuda", faces=faces) + + # imgs + video_path = batch["meta_render"][0]["video_path"] + frame_id = batch["meta_render"][0]["frame_id"].cpu().numpy() + images = read_video_np(video_path, frame_id=frame_id) # (F, H/4, W/4, 3), uint8, numpy + + # Actual rendering + cx, cz = (verts_glob.mean(1).max(0)[0] + verts_glob.mean(1).min(0)[0])[[0, 2]] / 2.0 + scale = (verts_glob.mean(1).max(0)[0] - verts_glob.mean(1).min(0)[0])[[0, 2]].max() * 1.5 + renderer_glob.set_ground(scale, cx.item(), cz.item()) + color = torch.ones(3).float().cuda() * 0.8 + + for i in tqdm(range(seq_length), desc=f"Rendering {vname}"): + # incam + img_overlay_pred = renderer_incam.render_mesh(verts_incam[i].cuda(), images[i], [0.8, 0.8, 0.8]) + if batch["meta_render"][0].get("bbx_xys", None) is not None: # draw bbox lines + bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy() + lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int) + rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int) + img_overlay_pred = cv2.rectangle(img_overlay_pred, lu_point, rd_point, (255, 178, 102), 2) + pred_mpjpe_ = self.metric_aggregator["mpjpe"][vid][i] + text = f"pred mpjpe: {pred_mpjpe_:.1f}" + cv2.putText(img_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 100, 200), 2) + + # glob + cameras = renderer_glob.create_camera(global_R[i], global_T[i]) + img_glob = renderer_glob.render_with_ground(verts_glob[[i]], color[None], cameras, global_lights) + + # write + img = np.concatenate([img_overlay_pred, img_glob], axis=1) + writer.append_data(img) + writer.close() + pass + + # ================== Epoch Summary ================== # + def on_predict_epoch_end(self, trainer, pl_module): + """Without logger""" + local_rank, world_size = trainer.local_rank, trainer.world_size + if "mpjpe" in self.metric_aggregator: + monitor_metric = "mpjpe" + else: + monitor_metric = list(self.metric_aggregator.keys())[0] + + # Reduce metric_aggregator across all processes + metric_keys = list(self.metric_aggregator.keys()) + with torch.inference_mode(False): # allow in-place operation of all_gather + metric_aggregator_gathered = all_gather(self.metric_aggregator) # list of dict + for metric_key in metric_keys: + for d in metric_aggregator_gathered: + self.metric_aggregator[metric_key].update(d[metric_key]) + + total = len(self.metric_aggregator[monitor_metric]) + Log.info(f"{total} sequences evaluated in {self.__class__.__name__}") + if total == 0: + return + 
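# --- Editor's note: illustrative sketch, not part of this diff. -------------
# Every callback in this patch aggregates results as
#   metric_aggregator[metric_name][vid] = per-frame numpy array,
# and at epoch end concatenates all sequences before averaging, so longer
# clips contribute proportionally more frames. Minimal standalone example
# with dummy numbers and hypothetical vids:
import numpy as np

metric_aggregator = {"mpjpe": {"vid_a": np.array([50.0, 52.0, 48.0]),
                               "vid_b": np.array([80.0, 82.0])}}
per_seq_mean = {vid: arr.mean() for vid, arr in metric_aggregator["mpjpe"].items()}
overall = np.concatenate(list(metric_aggregator["mpjpe"].values())).mean()
print(per_seq_mean)       # {'vid_a': 50.0, 'vid_b': 81.0}
print(round(overall, 2))  # 62.4 -- frame-weighted, not a mean of per-sequence means
# ----------------------------------------------------------------------------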
+ # print monitored metric per sequence + mm_per_seq = {k: v.mean() for k, v in self.metric_aggregator[monitor_metric].items()} + if len(mm_per_seq) > 0: + sorted_mm_per_seq = sorted(mm_per_seq.items(), key=lambda x: x[1], reverse=True) + n_worst = 5 if trainer.state.stage == "validate" else len(sorted_mm_per_seq) + if local_rank == 0: + Log.info( + f"monitored metric {monitor_metric} per sequence\n" + + "\n".join([f"{m:5.1f} : {s}" for s, m in sorted_mm_per_seq[:n_worst]]) + + "\n------" + ) + + # average over all batches + metrics_avg = {k: np.concatenate(list(v.values())).mean() for k, v in self.metric_aggregator.items()} + if local_rank == 0: + Log.info( + f"[Metrics] {self.target_dataset_id}:\n" + + "\n".join(f"{k}: {v:.1f}" for k, v in metrics_avg.items()) + + "\n------" + ) + + # save to logger if available + if pl_module.logger is not None: + cur_epoch = pl_module.current_epoch + for k, v in metrics_avg.items(): + pl_module.logger.log_metrics({f"val_metric_{self.target_dataset_id}/{k}": v}, step=cur_epoch) + + # reset + for k in self.metric_aggregator: + self.metric_aggregator[k] = {} + + +emdb1_node = builds(MetricMocap, emdb_split=1) +emdb2_node = builds(MetricMocap, emdb_split=2) +MainStore.store(name="metric_emdb1", node=emdb1_node, group="callbacks", package="callbacks.metric_emdb1") +MainStore.store(name="metric_emdb2", node=emdb2_node, group="callbacks", package="callbacks.metric_emdb2") diff --git a/hmr4d/model/gvhmr/callbacks/metric_rich.py b/hmr4d/model/gvhmr/callbacks/metric_rich.py new file mode 100644 index 0000000..f60c9b5 --- /dev/null +++ b/hmr4d/model/gvhmr/callbacks/metric_rich.py @@ -0,0 +1,389 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from hmr4d.configs import MainStore, builds + +from hmr4d.utils.comm.gather import all_gather +from hmr4d.utils.pylogger import Log + +from hmr4d.utils.eval.eval_utils import ( + compute_camcoord_metrics, + compute_global_metrics, + compute_camcoord_perjoint_metrics, + as_np_array, +) +from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay +from hmr4d.utils.smplx_utils import make_smplx +from einops import einsum, rearrange + +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines, get_colors_by_conf +from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static, get_ground_params_from_points +from hmr4d.utils.geo.hmr_cam import estimate_focal_length +from hmr4d.utils.video_io_utils import read_video_np, save_video, get_writer +import imageio +from tqdm import tqdm +from pathlib import Path +import numpy as np +import cv2 + +from smplx.joint_names import JOINT_NAMES +from hmr4d.utils.net_utils import repeat_to_max_len, gaussian_smooth +from hmr4d.utils.geo.hmr_global import rollout_vel, get_static_joint_mask + + +class MetricMocap(pl.Callback): + def __init__(self): + super().__init__() + # vid->result + self.metric_aggregator = { + "pa_mpjpe": {}, + "mpjpe": {}, + "pve": {}, + "accel": {}, + "wa2_mpjpe": {}, + "waa_mpjpe": {}, + "rte": {}, + "jitter": {}, + "fs": {}, + } + + self.perjoint_metrics = False + if self.perjoint_metrics: + body_joint_names = JOINT_NAMES[:22] + ["left_hand", "right_hand"] + self.body_joint_names = body_joint_names + self.perjoint_metric_aggregator = { + "mpjpe": {k: {} for k in body_joint_names}, + } + self.perjoint_obs_metric_aggregator = { + "mpjpe": {k: {} for k in body_joint_names}, + } 
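# --- Editor's note: illustrative sketch, not part of this diff. -------------
# The buffers loaded just below follow a pattern used throughout these
# callbacks: `smplx2smpl` maps SMPL-X vertices (10475) onto the SMPL topology
# (6890), and `J_regressor` turns SMPL vertices into joints. Toy-sized dense
# stand-ins here (the real smplx2smpl is a sparse tensor loaded from a .pt):
import torch
from einops import einsum

L, V_SMPLX, V_SMPL, J = 4, 100, 60, 17          # toy sizes; J depends on the regressor
smplx_verts = torch.randn(L, V_SMPLX, 3)        # e.g. smplx(...).vertices
smplx2smpl = torch.rand(V_SMPL, V_SMPLX)        # stand-in for the sparse buffer
J_regressor = torch.rand(J, V_SMPL)

smpl_verts = torch.stack([torch.matmul(smplx2smpl, v) for v in smplx_verts])  # (L, V_SMPL, 3)
joints = einsum(J_regressor, smpl_verts, "j v, l v i -> l j i")               # (L, J, 3)
assert joints.shape == (L, J, 3)
# ----------------------------------------------------------------------------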
+ + # SMPL + self.smplx_model = { + "male": make_smplx("rich-smplx", gender="male"), + "female": make_smplx("rich-smplx", gender="female"), + "neutral": make_smplx("rich-smplx", gender="neutral"), + } + self.J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt") + self.smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt") + self.faces_smpl = make_smplx("smpl").faces + self.faces_smplx = self.smplx_model["neutral"].faces + + # The metrics are calculated similarly for val/test/predict + self.on_test_batch_end = self.on_validation_batch_end = self.on_predict_batch_end + + # Only validation record the metrics with logger + self.on_test_epoch_end = self.on_validation_epoch_end = self.on_predict_epoch_end + + # ================== Batch-based Computation ================== # + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): + """The behaviour is the same for val/test/predict""" + assert batch["B"] == 1 + dataset_id = batch["meta"][0]["dataset_id"] + if dataset_id != "RICH": + return + + # Move to cuda if not + for g in ["male", "female", "neutral"]: + self.smplx_model[g] = self.smplx_model[g].cuda() + self.J_regressor = self.J_regressor.cuda() + self.smplx2smpl = self.smplx2smpl.cuda() + + vid = batch["meta"][0]["vid"] + seq_length = batch["length"][0].item() + gender = batch["gender"][0] + T_w2ay = batch["T_w2ay"][0] + T_w2c = batch["T_w2c"][0] + + # Groundtruth (world, cam) + target_w_params = {k: v[0] for k, v in batch["gt_smpl_params"].items()} + target_w_output = self.smplx_model[gender](**target_w_params) + target_w_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in target_w_output.vertices]) + target_c_verts = apply_T_on_points(target_w_verts, T_w2c) + target_c_j3d = torch.matmul(self.J_regressor, target_c_verts) + offset = target_c_j3d[..., [1, 2], :].mean(-2, keepdim=True) # (L, 1, 3) + target_cr_j3d = target_c_j3d - offset + target_cr_verts = target_c_verts - offset + # optional: ay for visual comparison + target_ay_verts = apply_T_on_points(target_w_verts, T_w2ay) + target_ay_j3d = torch.matmul(self.J_regressor, target_ay_verts) + + # + Prediction -> Metric + # 1. cam + pred_smpl_params_incam = outputs["pred_smpl_params_incam"] + smpl_out = self.smplx_model["neutral"](**pred_smpl_params_incam) + pred_c_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices]) + pred_c_j3d = einsum(self.J_regressor, pred_c_verts, "j v, l v i -> l j i") + offset = pred_c_j3d[..., [1, 2], :].mean(-2, keepdim=True) # (L, 1, 3) + + # 2. 
ay + pred_smpl_params_global = outputs["pred_smpl_params_global"] + smpl_out = self.smplx_model["neutral"](**pred_smpl_params_global) + pred_ay_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices]) + pred_ay_j3d = einsum(self.J_regressor, pred_ay_verts, "j v, l v i -> l j i") + + # Metric of current sequence + batch_eval = { + "pred_j3d": pred_c_j3d, + "target_j3d": target_c_j3d, + "pred_verts": pred_c_verts, + "target_verts": target_c_verts, + } + camcoord_metrics = compute_camcoord_metrics(batch_eval) + for k in camcoord_metrics: + self.metric_aggregator[k][vid] = as_np_array(camcoord_metrics[k]) + + batch_eval = { + "pred_j3d_glob": pred_ay_j3d, + "target_j3d_glob": target_ay_j3d, + "pred_verts_glob": pred_ay_verts, + "target_verts_glob": target_ay_verts, + } + global_metrics = compute_global_metrics(batch_eval) + for k in global_metrics: + self.metric_aggregator[k][vid] = as_np_array(global_metrics[k]) + + if False: # global wi3d debug + wis3d = make_wis3d(name="debug-metric-global") + add_motion_as_lines(pred_ay_j3d, wis3d, name="pred_ay_j3d") + add_motion_as_lines(target_ay_j3d, wis3d, name="target_ay_j3d") + + if False: # incam visualize debug + # Print per-sequence error + Log.info( + f"seq {vid} metrics:\n" + + "\n".join( + f"{k}: {self.metric_aggregator[k][vid].mean():.1f} (obs:{camcoord_metrics[k].mean():.1f})" + for k in camcoord_metrics.keys() + ) + + "\n------\n" + ) + if self.perjoint_metrics: + Log.info( + f"\n".join( + f"{k}-{j}: {self.perjoint_metric_aggregator[k][j][vid].mean():.1f} (obs:{self.perjoint_obs_metric_aggregator[k][j][vid].mean():.1f})" + for j in self.body_joint_names + for k in self.perjoint_obs_metric_aggregator.keys() + ) + + "\n------" + ) + + # -- metric -- # + pred_mpjpe = self.metric_aggregator["mpjpe"][vid].mean() + obs_mpjpe = camcoord_metrics["mpjpe"].mean() + + # -- render mesh -- # + vertices_gt = target_c_verts + vertices_cr_gt = target_cr_verts + target_cr_verts.new([0, 0, 3.0]) # move forward +z + vertices_pred = pred_c_verts + vertices_cr_obs = obs_cr_verts + obs_cr_verts.new([0, 0, 3.0]) # move forward +z + vertices_cr_pred = pred_cr_verts + pred_cr_verts.new([0, 0, 3.0]) # move forward +z + + # -- rendering code -- # + vname = batch["meta_render"][0]["name"] + K = batch["meta_render"][0]["K"] + width, height = batch["meta_render"][0]["width_height"] + faces = self.faces_smpl + + renderer = Renderer(width, height, device="cuda", faces=faces, K=K) + out_fn = f"outputs/dump_render/{vname}.mp4" + Path(out_fn).parent.mkdir(exist_ok=True, parents=True) + writer = imageio.get_writer(out_fn, fps=30, mode="I", format="FFMPEG", macro_block_size=1) + + # imgs + video_path = batch["meta_render"][0]["video_path"] + frame_id = batch["meta_render"][0]["frame_id"].cpu().numpy() + vr = decord.VideoReader(video_path) + images = vr.get_batch(list(frame_id)).numpy() # (F, H/4, W/4, 3), uint8, numpy + + for i in tqdm(range(seq_length), desc=f"Rendering {vname}"): + img_overlay_gt = renderer.render_mesh(vertices_gt[i].cuda(), images[i], [39, 194, 128]) + if batch["meta_render"][0].get("bbx_xys", None) is not None: # draw bbox lines + bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy() + lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int) + rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int) + img_overlay_gt = cv2.rectangle(img_overlay_gt, lu_point, rd_point, (255, 178, 102), 2) + + img_overlay_pred = renderer.render_mesh(vertices_pred[i].cuda(), images[i]) + # img_overlay_pred = 
renderer.render_mesh(vertices_pred[i].cuda(), np.zeros_like(images[i])) + img = np.concatenate([img_overlay_gt, img_overlay_pred], axis=0) + + ####### overlay gt cr first, then overlay pred cr with error color ######## + # overlay gt cr first with blue color + black_overlay_obs = renderer.render_mesh( + vertices_cr_gt[i].cuda(), np.zeros_like(images[i]), colors=[39, 194, 128] + ) + black_overlay_pred = renderer.render_mesh( + vertices_cr_gt[i].cuda(), np.zeros_like(images[i]), colors=[39, 194, 128] + ) + + # get error color + obs_error = (vertices_cr_gt[i] - vertices_cr_obs[i]).norm(dim=-1) + pred_error = (vertices_cr_gt[i] - vertices_cr_pred[i]).norm(dim=-1) + max_error = max(obs_error.max(), pred_error.max()) + obs_error_color = torch.stack( + [obs_error / max_error, torch.ones_like(obs_error) * 0.6, torch.ones_like(obs_error) * 0.6], + dim=-1, + ) + obs_error_color = torch.clip(obs_error_color, 0, 1) + pred_error_color = torch.stack( + [pred_error / max_error, torch.ones_like(pred_error) * 0.6, torch.ones_like(pred_error) * 0.6], + dim=-1, + ) + pred_error_color = torch.clip(pred_error_color, 0, 1) + + # overlay cr with error color + black_overlay_obs = renderer.render_mesh( + vertices_cr_obs[i].cuda(), black_overlay_obs, colors=obs_error_color[None] + ) + black_overlay_pred = renderer.render_mesh( + vertices_cr_pred[i].cuda(), black_overlay_pred, colors=pred_error_color[None] + ) + + # write mpjpe on the img + obs_mpjpe_ = camcoord_metrics["mpjpe"][i] + text = f"obs mpjpe: {obs_mpjpe_:.1f} ({obs_mpjpe:.1f})" + cv2.putText(black_overlay_obs, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 200, 200), 2) + pred_mpjpe_ = self.metric_aggregator["mpjpe"][vid][i] + text = f"pred mpjpe: {pred_mpjpe_:.1f} ({pred_mpjpe:.1f})" + if pred_mpjpe_ > obs_mpjpe_: + # large error -> purple + cv2.putText(black_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 100, 200), 2) + else: + # small error -> yellow + cv2.putText(black_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 200, 100), 2) + black = np.concatenate([black_overlay_obs, black_overlay_pred], axis=0) + ########################################### + + img = np.concatenate([img, black], axis=1) + + writer.append_data(img) + writer.close() + + if False: # Visualize incam + global results + + def move_to_start_point_face_z(verts): + "XZ to origin, Start from the ground, Face-Z" + # position + verts = verts.clone() # (L, V, 3) + offset = einsum(self.J_regressor, verts[0], "j v, v i -> j i")[0] # (3) + offset[1] = verts[:, :, [1]].min() + verts = verts - offset + # face direction + T_ay2ayfz = compute_T_ayfz2ay(einsum(self.J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True) + verts = apply_T_on_points(verts, T_ay2ayfz) + return verts + + verts_incam = pred_c_verts.clone() + # verts_glob = move_to_start_point_face_z(target_ay_verts) # gt + verts_glob = move_to_start_point_face_z(pred_ay_verts) + joints_glob = einsum(self.J_regressor, verts_glob, "j v, l v i -> l j i") # (L, J, 3) + global_R, global_T, global_lights = get_global_cameras_static( + verts_glob.cpu(), + beta=4.0, + cam_height_degree=20, + target_center_height=1.0, + vec_rot=-45, + ) + + # -- rendering code (global version FOV=55) -- # + vname = batch["meta_render"][0]["name"] + width, height = batch["meta_render"][0]["width_height"] + K = batch["meta_render"][0]["K"] + faces = self.faces_smpl + out_fn = f"outputs/dump_render_global/{vname}.mp4" + Path(out_fn).parent.mkdir(exist_ok=True, parents=True) + + # two renderers + renderer_incam = 
Renderer(width, height, device="cuda", faces=faces, K=K) + renderer_glob = Renderer(width, height, estimate_focal_length(width, height), device="cuda", faces=faces) + + # imgs + video_path = batch["meta_render"][0]["video_path"] + frame_id = batch["meta_render"][0]["frame_id"].cpu().numpy() + images = read_video_np(video_path)[frame_id] # (F, H/4, W/4, 3), uint8, numpy + + # Actual rendering + scale, cx, cz = get_ground_params_from_points(joints_glob[:, 0], verts_glob) + renderer_glob.set_ground(scale * 1.5, cx, cz) + color = torch.ones(3).float().cuda() * 0.8 + + writer = get_writer(out_fn, fps=30, crf=23) + for i in tqdm(range(seq_length), desc=f"Rendering {vname}"): + # incam + img_overlay_pred = renderer_incam.render_mesh(verts_incam[i].cuda(), images[i], [0.8, 0.8, 0.8]) + # if batch["meta_render"][0].get("bbx_xys", None) is not None: # draw bbox lines + # bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy() + # lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int) + # rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int) + # img_overlay_pred = cv2.rectangle(img_overlay_pred, lu_point, rd_point, (255, 178, 102), 2) + # pred_mpjpe_ = self.metric_aggregator["mpjpe"][vid][i] + # text = f"pred mpjpe: {pred_mpjpe_:.1f}" + # cv2.putText(img_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 100, 200), 2) + + # glob + cameras = renderer_glob.create_camera(global_R[i], global_T[i]) + # img_glob = renderer_glob.render_with_ground(verts_glob[[i]], color_[None], cameras, global_lights) + img_glob = renderer_glob.render_with_ground( + verts_glob[[i]], color.clone()[None], cameras, global_lights + ) + + # write + img = np.concatenate([img_overlay_pred, img_glob], axis=1) + writer.write_frame(img) + writer.close() + + # ================== Epoch Summary ================== # + def on_predict_epoch_end(self, trainer, pl_module): + """Without logger""" + local_rank, world_size = trainer.local_rank, trainer.world_size + monitor_metric = "mpjpe" + + # Reduce metric_aggregator across all processes + metric_keys = list(self.metric_aggregator.keys()) + with torch.inference_mode(False): # allow in-place operation of all_gather + metric_aggregator_gathered = all_gather(self.metric_aggregator) # list of dict + for metric_key in metric_keys: + for d in metric_aggregator_gathered: + self.metric_aggregator[metric_key].update(d[metric_key]) + + if False: # debug to make sure the all_gather is correct + print(f"[RANK {local_rank}/{world_size}]: {self.metric_aggregator[monitor_metric].keys()}") + + total = len(self.metric_aggregator[monitor_metric]) + Log.info(f"{total} sequences evaluated in {self.__class__.__name__}") + if total == 0: + return + + # print monitored metric per sequence + mm_per_seq = {k: v.mean() for k, v in self.metric_aggregator[monitor_metric].items()} + if len(mm_per_seq) > 0: + sorted_mm_per_seq = sorted(mm_per_seq.items(), key=lambda x: x[1], reverse=True) + n_worst = 5 if trainer.state.stage == "validate" else len(sorted_mm_per_seq) + if local_rank == 0: + Log.info( + f"monitored metric {monitor_metric} per sequence\n" + + "\n".join([f"{m:5.1f} : {s}" for s, m in sorted_mm_per_seq[:n_worst]]) + + "\n------" + ) + + # average over all batches + metrics_avg = {k: np.concatenate(list(v.values())).mean() for k, v in self.metric_aggregator.items()} + if local_rank == 0: + Log.info(f"[Metrics] RICH:\n" + "\n".join(f"{k}: {v:.1f}" for k, v in metrics_avg.items()) + "\n------") + + # save to logger if available + if pl_module.logger is not None: + cur_epoch = 
pl_module.current_epoch + for k, v in metrics_avg.items(): + pl_module.logger.log_metrics({f"val_metric_RICH/{k}": v}, step=cur_epoch) + + # reset + for k in self.metric_aggregator: + self.metric_aggregator[k] = {} + + +rich_node = builds(MetricMocap) +MainStore.store(name="metric_rich", node=rich_node, group="callbacks", package="callbacks.metric_rich") diff --git a/hmr4d/model/gvhmr/gvhmr_pl.py b/hmr4d/model/gvhmr/gvhmr_pl.py new file mode 100644 index 0000000..af9d3a2 --- /dev/null +++ b/hmr4d/model/gvhmr/gvhmr_pl.py @@ -0,0 +1,324 @@ +from typing import Any, Dict +import numpy as np +from pathlib import Path +import torch +import pytorch_lightning as pl +from hydra.utils import instantiate +from hmr4d.utils.pylogger import Log +from einops import rearrange, einsum +from hmr4d.configs import MainStore, builds + +from hmr4d.utils.geo_transform import compute_T_ayfz2ay, apply_T_on_points +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines +from hmr4d.utils.smplx_utils import make_smplx +from hmr4d.utils.geo.augment_noisy_pose import ( + get_wham_aug_kp3d, + get_visible_mask, + get_invisible_legs_mask, + randomly_occlude_lower_half, + randomly_modify_hands_legs, +) +from hmr4d.utils.geo.hmr_cam import perspective_projection, normalize_kp2d, safely_render_x3d_K, get_bbx_xys + +from hmr4d.utils.video_io_utils import save_video +from hmr4d.utils.vis.cv2_utils import draw_bbx_xys_on_image_batch +from hmr4d.utils.geo.flip_utils import flip_smplx_params, avg_smplx_aa +from hmr4d.model.gvhmr.utils.postprocess import pp_static_joint, pp_static_joint_cam, process_ik + + +class GvhmrPL(pl.LightningModule): + def __init__( + self, + pipeline, + optimizer=None, + scheduler_cfg=None, + ignored_weights_prefix=["smplx", "pipeline.endecoder"], + ): + super().__init__() + self.pipeline = instantiate(pipeline, _recursive_=False) + self.optimizer = instantiate(optimizer) + self.scheduler_cfg = scheduler_cfg + + # Options + self.ignored_weights_prefix = ignored_weights_prefix + + # The test step is the same as validation + self.test_step = self.predict_step = self.validation_step + + # SMPLX + self.smplx = make_smplx("supermotion_v437coco17") + + def training_step(self, batch, batch_idx): + B, F = batch["smpl_params_c"]["body_pose"].shape[:2] + + # Create augmented noisy-obs : gt_j3d(coco17) + with torch.no_grad(): + gt_verts437, gt_j3d = self.smplx(**batch["smpl_params_c"]) + root_ = gt_j3d[:, :, [11, 12], :].mean(-2, keepdim=True) + batch["gt_j3d"] = gt_j3d + batch["gt_cr_coco17"] = gt_j3d - root_ + batch["gt_c_verts437"] = gt_verts437 + batch["gt_cr_verts437"] = gt_verts437 - root_ + + # bbx_xys + i_x2d = safely_render_x3d_K(gt_verts437, batch["K_fullimg"], thr=0.3) + bbx_xys = get_bbx_xys(i_x2d, do_augment=True) + if False: # trust image bbx_xys seems better + batch["bbx_xys"] = bbx_xys + else: + mask_bbx_xys = batch["mask"]["bbx_xys"] + batch["bbx_xys"][~mask_bbx_xys] = bbx_xys[~mask_bbx_xys] + if False: # visualize bbx_xys from an iPhone view + render_w, render_h = 120, 160 # iphone main-lens 24mm 3:4 + ratio = render_w / 1528 + offset = torch.tensor([764 - 500, 1019 - 500]).to(i_x2d) + i_x2d_render = (i_x2d + offset).clone() + i_x2d_render = (i_x2d_render * ratio).long().clone() + torch.clamp_(i_x2d_render[..., 0], 0, render_w - 1) + torch.clamp_(i_x2d_render[..., 1], 0, render_h - 1) + bbx_xys_render = bbx_xys.clone() + bbx_xys_render[..., :2] += offset + bbx_xys_render *= ratio + + output_dir = Path("outputs/simulated_bbx_xys") + output_dir.mkdir(parents=True, exist_ok=True) + 
video_list = [] + for bid in range(B): + images = torch.zeros(F, render_h, render_w, 3, device=i_x2d.device) + for fid in range(F): + images[fid, i_x2d_render[bid, fid, :, 1], i_x2d_render[bid, fid, :, 0]] = 255 + + images = draw_bbx_xys_on_image_batch(bbx_xys_render[bid].cpu().numpy(), images.cpu().numpy()) + images = np.stack(images).astype("uint8") # (L, H, W, 3) + images[:, 0, :] = np.array([255, 255, 255]) + images[:, -1, :] = np.array([255, 255, 255]) + images[:, :, 0] = np.array([255, 255, 255]) + images[:, :, -1] = np.array([255, 255, 255]) + video_list.append(images) + + # stack videos + video_output = [] + for i in range(0, len(video_list), 4): + if i + 4 <= len(video_list): + video_output.append(np.concatenate(video_list[i : i + 4], axis=2)) + video_output = np.concatenate(video_output, axis=1) + save_video(video_output, output_dir / f"{batch_idx}.mp4", fps=30, quality=5) + + # noisy_j3d -> project to i_j2d -> compute a bbx -> normalized kp2d [-1, 1] + noisy_j3d = gt_j3d + get_wham_aug_kp3d(gt_j3d.shape[:2]) + if True: + noisy_j3d = randomly_modify_hands_legs(noisy_j3d) + obs_i_j2d = perspective_projection(noisy_j3d, batch["K_fullimg"]) # (B, L, J, 2) + j2d_visible_mask = get_visible_mask(gt_j3d.shape[:2]).cuda() # (B, L, J) + j2d_visible_mask[noisy_j3d[..., 2] < 0.3] = False # Set close-to-image-plane points as invisible + if True: # Set both legs as invisible for a period + legs_invisible_mask = get_invisible_legs_mask(gt_j3d.shape[:2]).cuda() # (B, L, J) + j2d_visible_mask[legs_invisible_mask] = False + obs_kp2d = torch.cat([obs_i_j2d, j2d_visible_mask[:, :, :, None].float()], dim=-1) # (B, L, J, 3) + obs = normalize_kp2d(obs_kp2d, batch["bbx_xys"]) # (B, L, J, 3) + obs[~j2d_visible_mask] = 0 # if not visible, set to (0,0,0) + batch["obs"] = obs + + if True: # Use some detected vitpose (presave data) + prob = 0.5 + mask_real_vitpose = (torch.rand(B).to(obs_kp2d) < prob) * batch["mask"]["vitpose"] + batch["obs"][mask_real_vitpose] = normalize_kp2d(batch["kp2d"], batch["bbx_xys"])[mask_real_vitpose] + + # Set untrusted frames to False + batch["obs"][~batch["mask"]["valid"]] = 0 + + if False: # wis3d + wis3d = make_wis3d(name="debug-aug-kp3d") + add_motion_as_lines(gt_j3d[0], wis3d, name="gt_j3d", skeleton_type="coco17") + add_motion_as_lines(noisy_j3d[0], wis3d, name="noisy_j3d", skeleton_type="coco17") + + # f_imgseq: apply random aug on offline extracted features + # f_imgseq = batch["f_imgseq"] + torch.randn_like(batch["f_imgseq"]) * 0.1 + # f_imgseq[~batch["mask"]["f_imgseq"]] = 0 + # batch["f_imgseq"] = f_imgseq.clone() + + # Forward and get loss + outputs = self.pipeline.forward(batch, train=True) + + # Log + log_kwargs = { + "on_epoch": True, + "prog_bar": True, + "logger": True, + "batch_size": B, + "sync_dist": True, + } + self.log("train/loss", outputs["loss"], **log_kwargs) + for k, v in outputs.items(): + if "_loss" in k: + self.log(f"train/{k}", v, **log_kwargs) + + return outputs + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + # Options & Check + do_postproc = self.trainer.state.stage == "test" # Only apply postproc in test + do_flip_test = "flip_test" in batch + do_postproc_not_flip_test = do_postproc and not do_flip_test # later pp when flip_test + assert batch["B"] == 1, "Only support batch size 1 in evalution." 
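# --- Editor's note: illustrative sketch, not part of this diff. -------------
# `normalize_kp2d(kp2d, bbx_xys)` is called just below (and in training_step
# above) with bbx_xys = (center_x, center_y, size). A plausible stand-in that
# maps pixel keypoints into a bbox-relative [-1, 1] range while keeping the
# confidence channel; the real hmr4d.utils.geo.hmr_cam implementation may
# differ in details:
import torch

def normalize_kp2d_sketch(kp2d: torch.Tensor, bbx_xys: torch.Tensor) -> torch.Tensor:
    """kp2d: (..., J, 3) as (x, y, conf); bbx_xys: (..., 3) as (cx, cy, size)."""
    center = bbx_xys[..., None, :2]            # (..., 1, 2)
    half_size = bbx_xys[..., None, 2:3] / 2    # (..., 1, 1)
    xy_norm = (kp2d[..., :2] - center) / half_size
    return torch.cat([xy_norm, kp2d[..., 2:]], dim=-1)

kp2d = torch.tensor([[[320.0, 240.0, 0.9], [400.0, 300.0, 0.8]]])  # (1, J=2, 3)
bbx_xys = torch.tensor([[320.0, 240.0, 200.0]])                    # (1, 3)
print(normalize_kp2d_sketch(kp2d, bbx_xys))  # first joint -> (0, 0), second -> (0.8, 0.6)
# ----------------------------------------------------------------------------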
+ + # ROPE inference + obs = normalize_kp2d(batch["kp2d"], batch["bbx_xys"]) + if "mask" in batch: + obs[0, ~batch["mask"][0]] = 0 + + batch_ = { + "length": batch["length"], + "obs": obs, + "bbx_xys": batch["bbx_xys"], + "K_fullimg": batch["K_fullimg"], + "cam_angvel": batch["cam_angvel"], + "f_imgseq": batch["f_imgseq"], + } + outputs = self.pipeline.forward(batch_, train=False, postproc=do_postproc_not_flip_test) + outputs["pred_smpl_params_global"] = {k: v[0] for k, v in outputs["pred_smpl_params_global"].items()} + outputs["pred_smpl_params_incam"] = {k: v[0] for k, v in outputs["pred_smpl_params_incam"].items()} + + if do_flip_test: + flip_test = batch["flip_test"] + obs = normalize_kp2d(flip_test["kp2d"], flip_test["bbx_xys"]) + if "mask" in batch: + obs[0, ~batch["mask"][0]] = 0 + + batch_ = { + "length": batch["length"], + "obs": obs, + "bbx_xys": flip_test["bbx_xys"], + "K_fullimg": batch["K_fullimg"], + "cam_angvel": flip_test["cam_angvel"], + "f_imgseq": flip_test["f_imgseq"], + } + flipped_outputs = self.pipeline.forward(batch_, train=False) + + # First update incam results + flipped_outputs["pred_smpl_params_incam"] = { + k: v[0] for k, v in flipped_outputs["pred_smpl_params_incam"].items() + } + smpl_params1 = outputs["pred_smpl_params_incam"] + smpl_params2 = flip_smplx_params(flipped_outputs["pred_smpl_params_incam"]) + + smpl_params_avg = smpl_params1.copy() + smpl_params_avg["betas"] = (smpl_params1["betas"] + smpl_params2["betas"]) / 2 + smpl_params_avg["body_pose"] = avg_smplx_aa(smpl_params1["body_pose"], smpl_params2["body_pose"]) + smpl_params_avg["global_orient"] = avg_smplx_aa( + smpl_params1["global_orient"], smpl_params2["global_orient"] + ) + outputs["pred_smpl_params_incam"] = smpl_params_avg + + # Then update global results + outputs["pred_smpl_params_global"]["betas"] = smpl_params_avg["betas"] + outputs["pred_smpl_params_global"]["body_pose"] = smpl_params_avg["body_pose"] + + # Finally, apply postprocess + if do_postproc: + # temporarily recover the original batch-dim + outputs["pred_smpl_params_global"] = {k: v[None] for k, v in outputs["pred_smpl_params_global"].items()} + outputs["pred_smpl_params_global"]["transl"] = pp_static_joint(outputs, self.pipeline.endecoder) + body_pose = process_ik(outputs, self.pipeline.endecoder) + outputs["pred_smpl_params_global"] = {k: v[0] for k, v in outputs["pred_smpl_params_global"].items()} + + outputs["pred_smpl_params_global"]["body_pose"] = body_pose[0] + # outputs["pred_smpl_params_incam"]["body_pose"] = body_pose[0] + + if False: # wis3d + wis3d = make_wis3d(name="debug-rich-cap") + smplx_model = make_smplx("rich-smplx", gender="neutral").cuda() + gender = batch["gender"][0] + T_w2ay = batch["T_w2ay"][0] + + # Prediction + # add_motion_as_lines(outputs_window["pred_ayfz_motion"][bid], wis3d, name="pred_ayfz_motion") + + smplx_out = smplx_model(**pred_smpl_params_global) + for i in range(len(smplx_out.vertices)): + wis3d.set_scene_id(i) + wis3d.add_mesh(smplx_out.vertices[i], smplx_model.bm.faces, name=f"pred-smplx-global") + + # GT (w) + smplx_models = { + "male": make_smplx("rich-smplx", gender="male").cuda(), + "female": make_smplx("rich-smplx", gender="female").cuda(), + } + gt_smpl_params = {k: v[0, windows[0]] for k, v in batch["gt_smpl_params"].items()} + gt_smplx_out = smplx_models[gender](**gt_smpl_params) + + # GT (ayfz) + smplx_verts_ay = apply_T_on_points(gt_smplx_out.vertices, T_w2ay) + smplx_joints_ay = apply_T_on_points(gt_smplx_out.joints, T_w2ay) + T_ay2ayfz = 
compute_T_ayfz2ay(smplx_joints_ay[:1], inverse=True)[0] # (4, 4) + smplx_verts_ayfz = apply_T_on_points(smplx_verts_ay, T_ay2ayfz) # (F, 22, 3) + + for i in range(len(smplx_verts_ayfz)): + wis3d.set_scene_id(i) + wis3d.add_mesh(smplx_verts_ayfz[i], smplx_models[gender].bm.faces, name=f"gt-smplx-ayfz") + + breakpoint() + + if False: # o3d + prog_keys = [ + "pred_smpl_progress", + "pred_localjoints_progress", + "pred_incam_localjoints_progress", + ] + for k in prog_keys: + if k in outputs_window: + seq_out = torch.cat( + [v[:, :l] for v, l in zip(outputs_window[k], length)], dim=1 + ) # (B, P, L, J, 3) -> (P, L, J, 3) -> (P, CL, J, 3) + outputs[k] = seq_out[None] + + return outputs + + def configure_optimizers(self): + params = [] + for k, v in self.pipeline.named_parameters(): + if v.requires_grad: + params.append(v) + optimizer = self.optimizer(params=params) + + if self.scheduler_cfg["scheduler"] is None: + return optimizer + + scheduler_cfg = dict(self.scheduler_cfg) + scheduler_cfg["scheduler"] = instantiate(scheduler_cfg["scheduler"], optimizer=optimizer) + return [optimizer], [scheduler_cfg] + + # ============== Utils ================= # + def on_save_checkpoint(self, checkpoint) -> None: + for ig_keys in self.ignored_weights_prefix: + for k in list(checkpoint["state_dict"].keys()): + if k.startswith(ig_keys): + # Log.info(f"Remove key `{ig_keys}' from checkpoint.") + checkpoint["state_dict"].pop(k) + + def load_pretrained_model(self, ckpt_path): + """Load pretrained checkpoint, and assign each weight to the corresponding part.""" + Log.info(f"[PL-Trainer] Loading ckpt: {ckpt_path}") + + state_dict = torch.load(ckpt_path, "cpu")["state_dict"] + missing, unexpected = self.load_state_dict(state_dict, strict=False) + real_missing = [] + for k in missing: + ignored_when_saving = any(k.startswith(ig_keys) for ig_keys in self.ignored_weights_prefix) + if not ignored_when_saving: + real_missing.append(k) + + if len(real_missing) > 0: + Log.warn(f"Missing keys: {real_missing}") + if len(unexpected) > 0: + Log.warn(f"Unexpected keys: {unexpected}") + + +gvhmr_pl = builds( + GvhmrPL, + pipeline="${pipeline}", + optimizer="${optimizer}", + scheduler_cfg="${scheduler_cfg}", + populate_full_signature=True, # Adds all the arguments to the signature +) +MainStore.store(name="gvhmr_pl", node=gvhmr_pl, group="model/gvhmr") diff --git a/hmr4d/model/gvhmr/gvhmr_pl_demo.py b/hmr4d/model/gvhmr/gvhmr_pl_demo.py new file mode 100644 index 0000000..3cd49db --- /dev/null +++ b/hmr4d/model/gvhmr/gvhmr_pl_demo.py @@ -0,0 +1,60 @@ +import torch +import pytorch_lightning as pl +from hydra.utils import instantiate +from hmr4d.utils.pylogger import Log +from hmr4d.configs import MainStore, builds + +from hmr4d.utils.geo.hmr_cam import normalize_kp2d + + +class DemoPL(pl.LightningModule): + def __init__(self, pipeline): + super().__init__() + self.pipeline = instantiate(pipeline, _recursive_=False) + + @torch.no_grad() + def predict(self, data, static_cam=False): + """auto add batch dim + data: { + "length": int, or Torch.Tensor, + "kp2d": (F, 3) + "bbx_xys": (F, 3) + "K_fullimg": (F, 3, 3) + "cam_angvel": (F, 3) + "f_imgseq": (F, 3, 256, 256) + } + + """ + # ROPE inference + batch = { + "length": data["length"][None], + "obs": normalize_kp2d(data["kp2d"], data["bbx_xys"])[None], + "bbx_xys": data["bbx_xys"][None], + "K_fullimg": data["K_fullimg"][None], + "cam_angvel": data["cam_angvel"][None], + "f_imgseq": data["f_imgseq"][None], + } + batch = {k: v.cuda() for k, v in batch.items()} + outputs = 
self.pipeline.forward(batch, train=False, postproc=True, static_cam=static_cam) + + pred = { + "smpl_params_global": {k: v[0] for k, v in outputs["pred_smpl_params_global"].items()}, + "smpl_params_incam": {k: v[0] for k, v in outputs["pred_smpl_params_incam"].items()}, + "K_fullimg": data["K_fullimg"], + "net_outputs": outputs, # intermediate outputs + } + return pred + + def load_pretrained_model(self, ckpt_path): + """Load pretrained checkpoint, and assign each weight to the corresponding part.""" + Log.info(f"[PL-Trainer] Loading ckpt type: {ckpt_path}") + + state_dict = torch.load(ckpt_path, "cpu")["state_dict"] + missing, unexpected = self.load_state_dict(state_dict, strict=False) + if len(missing) > 0: + Log.warn(f"Missing keys: {missing}") + if len(unexpected) > 0: + Log.warn(f"Unexpected keys: {unexpected}") + + +MainStore.store(name="gvhmr_pl_demo", node=builds(DemoPL, pipeline="${pipeline}"), group="model/gvhmr") diff --git a/hmr4d/model/gvhmr/pipeline/gvhmr_pipeline.py b/hmr4d/model/gvhmr/pipeline/gvhmr_pipeline.py new file mode 100644 index 0000000..9e99b22 --- /dev/null +++ b/hmr4d/model/gvhmr/pipeline/gvhmr_pipeline.py @@ -0,0 +1,384 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast +import numpy as np +from einops import einsum, rearrange, repeat +from hydra.utils import instantiate +from hmr4d.utils.pylogger import Log +from hmr4d.utils.net_utils import gaussian_smooth + +from hmr4d.model.gvhmr.utils.endecoder import EnDecoder +from hmr4d.model.gvhmr.utils.postprocess import ( + pp_static_joint, + process_ik, + pp_static_joint_cam, +) +from hmr4d.model.gvhmr.utils import stats_compose + +from pytorch3d.transforms import ( + matrix_to_rotation_6d, + rotation_6d_to_matrix, + axis_angle_to_matrix, + matrix_to_axis_angle, +) +from hmr4d.utils.geo.hmr_cam import compute_bbox_info_bedlam, compute_transl_full_cam, get_a_pred_cam, project_to_bi01 +from hmr4d.utils.geo.hmr_global import ( + rollout_local_transl_vel, + get_static_joint_mask, + get_tgtcoord_rootparam, +) +from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines +from hmr4d.utils.smplx_utils import make_smplx + + +class Pipeline(nn.Module): + def __init__(self, args, args_denoiser3d, **kwargs): + super().__init__() + self.args = args + self.weights = args.weights # loss weights + + # Networks + self.denoiser3d = instantiate(args_denoiser3d, _recursive_=False) + # Log.info(self.denoiser3d) + + # Normalizer + self.endecoder: EnDecoder = instantiate(args.endecoder_opt, _recursive_=False) + if self.args.normalize_cam_angvel: + cam_angvel_stats = stats_compose.cam_angvel["manual"] + self.register_buffer("cam_angvel_mean", torch.tensor(cam_angvel_stats["mean"]), persistent=False) + self.register_buffer("cam_angvel_std", torch.tensor(cam_angvel_stats["std"]), persistent=False) + + # ========== Training ========== # + + def forward(self, inputs, train=False, postproc=False, static_cam=False): + outputs = dict() + length = inputs["length"] # (B,) effective length of each sample + + # *. 
Conditions + cliff_cam = compute_bbox_info_bedlam(inputs["bbx_xys"], inputs["K_fullimg"]) # (B, L, 3) + f_cam_angvel = inputs["cam_angvel"] + if self.args.normalize_cam_angvel: + f_cam_angvel = (f_cam_angvel - self.cam_angvel_mean) / self.cam_angvel_std + f_condition = { + "obs": inputs["obs"], # (B, L, J, 3) + "f_cliffcam": cliff_cam, # (B, L, 3) + "f_cam_angvel": f_cam_angvel, # (B, L, C=6) + "f_imgseq": inputs["f_imgseq"], # (B, L, C=1024) + } + if train: + f_condition = randomly_set_null_condition(f_condition, 0.1) + + # Forward & output + model_output = self.denoiser3d(length=length, **f_condition) # pred_x, pred_cam, static_conf_logits + decode_dict = self.endecoder.decode(model_output["pred_x"]) # (B, L, C) -> dict + outputs.update({"model_output": model_output, "decode_dict": decode_dict}) + + # Post-processing + outputs["pred_smpl_params_incam"] = { + "body_pose": decode_dict["body_pose"], # (B, L, 63) + "betas": decode_dict["betas"], # (B, L, 10) + "global_orient": decode_dict["global_orient"], # (B, L, 3) + "transl": compute_transl_full_cam(model_output["pred_cam"], inputs["bbx_xys"], inputs["K_fullimg"]), + } + if not train: + pred_smpl_params_global = get_smpl_params_w_Rt_v2( # This function has for-loop + global_orient_gv=decode_dict["global_orient_gv"], + local_transl_vel=decode_dict["local_transl_vel"], + global_orient_c=decode_dict["global_orient"], + cam_angvel=inputs["cam_angvel"], + ) + outputs["pred_smpl_params_global"] = { + "body_pose": decode_dict["body_pose"], + "betas": decode_dict["betas"], + **pred_smpl_params_global, + } + outputs["static_conf_logits"] = model_output["static_conf_logits"] + + if postproc: # apply post-processing + if static_cam: # extra post-processing to utilize static camera prior + outputs["pred_smpl_params_global"]["transl"] = pp_static_joint_cam(outputs, self.endecoder) + else: + outputs["pred_smpl_params_global"]["transl"] = pp_static_joint(outputs, self.endecoder) + body_pose = process_ik(outputs, self.endecoder) + decode_dict["body_pose"] = body_pose + outputs["pred_smpl_params_global"]["body_pose"] = body_pose + outputs["pred_smpl_params_incam"]["body_pose"] = body_pose + + return outputs + + # ========== Compute Loss ========== # + total_loss = 0 + mask = inputs["mask"]["valid"] # (B, L) + + # 1. Simple loss: MSE + pred_x = model_output["pred_x"] # (B, L, C) + target_x = self.endecoder.encode(inputs) # (B, L, C) + simple_loss = F.mse_loss(pred_x, target_x, reduction="none") + mask_simple = mask[:, :, None].expand(-1, -1, pred_x.size(2)).clone() # (B, L, C) + mask_simple[inputs["mask"]["spv_incam_only"], :, 142:] = False # 3dpw training + simple_loss = (simple_loss * mask_simple).mean() + total_loss += simple_loss + outputs["simple_loss"] = simple_loss + + # 2. 
Extra loss + extra_funcs = [ + compute_extra_incam_loss, + compute_extra_global_loss, + ] + for extra_func in extra_funcs: + extra_loss, extra_loss_dict = extra_func(inputs, outputs, self) + total_loss += extra_loss + outputs.update(extra_loss_dict) + + outputs["loss"] = total_loss + return outputs + + +def randomly_set_null_condition(f_condition, uncond_prob=0.1): + """Conditions are in shape (B, L, *)""" + keys = list(f_condition.keys()) + for k in keys: + if f_condition[k] is None: + continue + f_condition[k] = f_condition[k].clone() + mask = torch.rand(f_condition[k].shape[:2]) < uncond_prob + f_condition[k][mask] = 0.0 + return f_condition + + +def compute_extra_incam_loss(inputs, outputs, ppl): + model_output = outputs["model_output"] + decode_dict = outputs["decode_dict"] + endecoder = ppl.endecoder + weights = ppl.weights + args = ppl.args + + extra_loss_dict = {} + extra_loss = 0 + mask = inputs["mask"]["valid"] # effective length mask + mask_reproj = ~inputs["mask"]["spv_incam_only"] # do not supervise reproj for 3DPW + + # Incam FK + # prediction + pred_c_j3d = endecoder.fk_v2(**outputs["pred_smpl_params_incam"]) + pred_cr_j3d = pred_c_j3d - pred_c_j3d[:, :, :1] # (B, L, J, 3) + + # gt + gt_c_j3d = endecoder.fk_v2(**inputs["smpl_params_c"]) # (B, L, J, 3) + gt_cr_j3d = gt_c_j3d - gt_c_j3d[:, :, :1] # (B, L, J, 3) + + # Root aligned C-MPJPE Loss + if weights.cr_j3d > 0.0: + cr_j3d_loss = F.mse_loss(pred_cr_j3d, gt_cr_j3d, reduction="none") + cr_j3d_loss = (cr_j3d_loss * mask[..., None, None]).mean() + extra_loss += cr_j3d_loss * weights.cr_j3d + extra_loss_dict["cr_j3d_loss"] = cr_j3d_loss + + # Reprojection (to align with image) + if weights.transl_c > 0.0: + # pred_transl = decode_dict["transl"] # (B, L, 3) + # gt_transl = inputs["smpl_params_c"]["transl"] + # transl_c_loss = F.l1_loss(pred_transl, gt_transl, reduction="none") + # transl_c_loss = (transl_c_loss * mask[..., None]).mean() + + # Instead of supervising transl, we convert gt to pred_cam (prevent divide 0) + pred_cam = model_output["pred_cam"] # (B, L, 3) + gt_transl = inputs["smpl_params_c"]["transl"] # (B, L, 3) + gt_pred_cam = get_a_pred_cam(gt_transl, inputs["bbx_xys"], inputs["K_fullimg"]) # (B, L, 3) + gt_pred_cam[gt_pred_cam.isinf()] = -1 # this will be handled by valid_mask + # (compute_transl_full_cam(gt_pred_cam, inputs["bbx_xys"], inputs["K_fullimg"]) - gt_transl).abs().max() + + # Skip gts that are not good during random construction + gt_j3d_z_min = inputs["gt_j3d"][..., 2].min(dim=-1)[0] + valid_mask = ( + (gt_j3d_z_min > 0.3) + * (gt_pred_cam[..., 0] > 0.3) + * (gt_pred_cam[..., 0] < 5.0) + * (gt_pred_cam[..., 1] > -3.0) + * (gt_pred_cam[..., 1] < 3.0) + * (gt_pred_cam[..., 2] > -3.0) + * (gt_pred_cam[..., 2] < 3.0) + * (inputs["bbx_xys"][..., 2] > 0) + )[..., None] + transl_c_loss = F.mse_loss(pred_cam, gt_pred_cam, reduction="none") + transl_c_loss = (transl_c_loss * mask[..., None] * valid_mask).mean() + + extra_loss_dict["transl_c_loss"] = transl_c_loss + extra_loss += transl_c_loss * weights.transl_c + + if weights.j2d > 0.0: + # prevent divide 0 or small value to overflow(fp16) + reproj_z_thr = 0.3 + pred_c_j3d_z0_mask = pred_c_j3d[..., 2].abs() <= reproj_z_thr + pred_c_j3d[pred_c_j3d_z0_mask] = reproj_z_thr + gt_c_j3d_z0_mask = gt_c_j3d[..., 2].abs() <= reproj_z_thr + gt_c_j3d[gt_c_j3d_z0_mask] = reproj_z_thr + + pred_j2d_01 = project_to_bi01(pred_c_j3d, inputs["bbx_xys"], inputs["K_fullimg"]) + gt_j2d_01 = project_to_bi01(gt_c_j3d, inputs["bbx_xys"], inputs["K_fullimg"]) # (B, L, J, 2) + + 
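# --- Editor's note: illustrative sketch, not part of this diff. -------------
# Supervising `project_to_bi01` outputs instead of raw pixels keeps the 2D
# reprojection loss scale independent of image resolution and subject size.
# A plausible stand-in under a pinhole camera with bbx_xys = (cx, cy, size);
# the real hmr4d.utils.geo.hmr_cam implementation may differ:
import torch

def project_to_bi01_sketch(x3d: torch.Tensor, bbx_xys: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """x3d: (L, J, 3) camera-frame points; bbx_xys: (L, 3); K: (L, 3, 3)."""
    uv = torch.einsum("lij,lpj->lpi", K, x3d)            # (L, J, 3)
    uv = uv[..., :2] / uv[..., 2:3]                      # pixel coordinates
    top_left = bbx_xys[:, None, :2] - bbx_xys[:, None, 2:3] / 2
    return (uv - top_left) / bbx_xys[:, None, 2:3]       # in [0, 1] inside the bbox

L, J = 2, 17
x3d = torch.randn(L, J, 3) * 0.3 + torch.tensor([0.0, 0.0, 5.0])  # keep z in front of the camera
K = torch.tensor([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]]).expand(L, 3, 3)
bbx_xys = torch.tensor([[640.0, 360.0, 512.0]]).expand(L, 3)
print(project_to_bi01_sketch(x3d, bbx_xys, K).shape)  # torch.Size([2, 17, 2])
# ----------------------------------------------------------------------------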
valid_mask = ( + (gt_c_j3d[..., 2] > reproj_z_thr) + * (pred_c_j3d[..., 2] > reproj_z_thr) # Be safe + * (gt_j2d_01[..., 0] > 0.0) + * (gt_j2d_01[..., 0] < 1.0) + * (gt_j2d_01[..., 1] > 0.0) + * (gt_j2d_01[..., 1] < 1.0) + )[..., None] + valid_mask[~mask_reproj] = False # Do not supervise on 3dpw + j2d_loss = F.mse_loss(pred_j2d_01, gt_j2d_01, reduction="none") + j2d_loss = (j2d_loss * mask[..., None, None] * valid_mask).mean() + + extra_loss += j2d_loss * weights.j2d + extra_loss_dict["j2d_loss"] = j2d_loss + + if weights.cr_verts > 0: + # SMPL forward + pred_c_verts437, pred_c_j17 = endecoder.smplx_model(**outputs["pred_smpl_params_incam"]) + root_ = pred_c_j17[:, :, [11, 12], :].mean(-2, keepdim=True) + pred_cr_verts437 = pred_c_verts437 - root_ + + gt_cr_verts437 = inputs["gt_cr_verts437"] # (B, L, 437, 3) + cr_vert_loss = F.mse_loss(pred_cr_verts437, gt_cr_verts437, reduction="none") + cr_vert_loss = (cr_vert_loss * mask[:, :, None, None]).mean() + extra_loss += cr_vert_loss * weights.cr_verts + extra_loss_dict["cr_vert_loss"] = cr_vert_loss + + if weights.verts2d > 0: + gt_c_verts437 = inputs["gt_c_verts437"] # (B, L, 437, 3) + + # prevent divide 0 or small value to overflow(fp16) + reproj_z_thr = 0.3 + pred_c_verts437_z0_mask = pred_c_verts437[..., 2].abs() <= reproj_z_thr + pred_c_verts437[pred_c_verts437_z0_mask] = reproj_z_thr + gt_c_verts437_z0_mask = gt_c_verts437[..., 2].abs() <= reproj_z_thr + gt_c_verts437[gt_c_verts437_z0_mask] = reproj_z_thr + + pred_verts2d_01 = project_to_bi01(pred_c_verts437, inputs["bbx_xys"], inputs["K_fullimg"]) + gt_verts2d_01 = project_to_bi01(gt_c_verts437, inputs["bbx_xys"], inputs["K_fullimg"]) # (B, L, 437, 2) + + valid_mask = ( + (gt_c_verts437[..., 2] > reproj_z_thr) + * (pred_c_verts437[..., 2] > reproj_z_thr) # Be safe + * (gt_verts2d_01[..., 0] > 0.0) + * (gt_verts2d_01[..., 0] < 1.0) + * (gt_verts2d_01[..., 1] > 0.0) + * (gt_verts2d_01[..., 1] < 1.0) + )[..., None] + valid_mask[~mask_reproj] = False # Do not supervise on 3dpw + verts2d_loss = F.mse_loss(pred_verts2d_01, gt_verts2d_01, reduction="none") + verts2d_loss = (verts2d_loss * mask[..., None, None] * valid_mask).mean() + + extra_loss += verts2d_loss * weights.verts2d + extra_loss_dict["verts2d_loss"] = verts2d_loss + + return extra_loss, extra_loss_dict + + +def compute_extra_global_loss(inputs, outputs, ppl): + decode_dict = outputs["decode_dict"] + endecoder = ppl.endecoder + weights = ppl.weights + args = ppl.args + + extra_loss_dict = {} + extra_loss = 0 + mask = inputs["mask"]["valid"].clone() # (B, L) + mask[inputs["mask"]["spv_incam_only"]] = False + + if weights.transl_w > 0: + # compute pred_transl_w by rollout + gt_transl_w = inputs["smpl_params_w"]["transl"] + gt_global_orient_w = inputs["smpl_params_w"]["global_orient"] + local_transl_vel = decode_dict["local_transl_vel"] + pred_transl_w = rollout_local_transl_vel(local_transl_vel, gt_global_orient_w, gt_transl_w[:, [0]]) + + trans_w_loss = F.l1_loss(pred_transl_w, gt_transl_w, reduction="none") + trans_w_loss = (trans_w_loss * mask[..., None]).mean() + extra_loss += trans_w_loss * weights.transl_w + extra_loss_dict["transl_w_loss"] = trans_w_loss + + # Static-Conf loss + if weights.static_conf_bce > 0: + # Compute gt by thresholding velocity + vel_thr = args.static_conf.vel_thr + assert vel_thr > 0 + joint_ids = [7, 10, 8, 11, 20, 21] # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist] + gt_w_j3d = endecoder.fk_v2(**inputs["smpl_params_w"]) # (B, L, J=22, 3) + static_gt = get_static_joint_mask(gt_w_j3d, 
vel_thr=vel_thr, repeat_last=True) # (B, L, J) + static_gt = static_gt[:, :, joint_ids].float() # (B, L, J') + pred_static_conf_logits = outputs["model_output"]["static_conf_logits"] + + static_conf_loss = F.binary_cross_entropy_with_logits(pred_static_conf_logits, static_gt, reduction="none") + static_conf_loss = (static_conf_loss * mask[..., None]).mean() + extra_loss += static_conf_loss * weights.static_conf_bce + extra_loss_dict["static_conf_loss"] = static_conf_loss + + return extra_loss, extra_loss_dict + + +@autocast(enabled=False) +def get_smpl_params_w_Rt_v2( + global_orient_gv, + local_transl_vel, + global_orient_c, + cam_angvel, +): + """Get global R,t in GV0(ay) + Args: + cam_angvel: (B, L, 6), defined as R @ R_{w2c}^{t} = R_{w2c}^{t+1} + """ + + # Get R_ct_to_c0 from cam_angvel + def as_identity(R): + is_I = matrix_to_axis_angle(R).norm(dim=-1) < 1e-5 + R[is_I] = torch.eye(3)[None].expand(is_I.sum(), -1, -1).to(R) + return R + + B = cam_angvel.shape[0] + R_t_to_tp1 = rotation_6d_to_matrix(cam_angvel) # (B, L, 3, 3) + R_t_to_tp1 = as_identity(R_t_to_tp1) + + # Get R_c2gv + R_gv = axis_angle_to_matrix(global_orient_gv) # (B, L, 3, 3) + R_c = axis_angle_to_matrix(global_orient_c) # (B, L, 3, 3) + + # Camera view direction in GV coordinate: Rc2gv @ [0,0,1] + R_c2gv = R_gv @ R_c.mT + view_axis_gv = R_c2gv[:, :, :, 2] # (B, L, 3) Rc2gv is estimated, so the x-axis is not accurate, i.e. != 0 + + # Rotate axis use camera relative rotation + R_cnext2gv = R_c2gv @ R_t_to_tp1.mT + view_axis_gv_next = R_cnext2gv[..., 2] + + vec1_xyz = view_axis_gv.clone() + vec1_xyz[..., 1] = 0 + vec1_xyz = F.normalize(vec1_xyz, dim=-1) + vec2_xyz = view_axis_gv_next.clone() + vec2_xyz[..., 1] = 0 + vec2_xyz = F.normalize(vec2_xyz, dim=-1) + + aa_tp1_to_t = vec2_xyz.cross(vec1_xyz, dim=-1) + aa_tp1_to_t_angle = torch.acos(torch.clamp((vec1_xyz * vec2_xyz).sum(dim=-1, keepdim=True), -1.0, 1.0)) + aa_tp1_to_t = F.normalize(aa_tp1_to_t, dim=-1) * aa_tp1_to_t_angle + + aa_tp1_to_t = gaussian_smooth(aa_tp1_to_t, dim=-2) # Smooth + R_tp1_to_t = axis_angle_to_matrix(aa_tp1_to_t).mT # (B, L, 3) + + # Get R_t_to_0 + R_t_to_0 = [torch.eye(3)[None].expand(B, -1, -1).to(R_t_to_tp1)] + for i in range(1, R_t_to_tp1.shape[1]): + R_t_to_0.append(R_t_to_0[-1] @ R_tp1_to_t[:, i]) + R_t_to_0 = torch.stack(R_t_to_0, dim=1) # (B, L, 3, 3) + R_t_to_0 = as_identity(R_t_to_0) + + global_orient = matrix_to_axis_angle(R_t_to_0 @ R_gv) + + # Rollout to global transl + # Start from transl0, in gv0 -> flip y-axis of gv0 + transl = rollout_local_transl_vel(local_transl_vel, global_orient) + global_orient, transl, _ = get_tgtcoord_rootparam(global_orient, transl, tsf="any->ay") + + smpl_params_w_Rt = {"global_orient": global_orient, "transl": transl} + return smpl_params_w_Rt diff --git a/hmr4d/model/gvhmr/utils/endecoder.py b/hmr4d/model/gvhmr/utils/endecoder.py new file mode 100644 index 0000000..223cb6d --- /dev/null +++ b/hmr4d/model/gvhmr/utils/endecoder.py @@ -0,0 +1,199 @@ +import torch +import torch.nn as nn +from pytorch3d.transforms import ( + rotation_6d_to_matrix, + matrix_to_axis_angle, + axis_angle_to_matrix, + matrix_to_rotation_6d, + matrix_to_quaternion, + quaternion_to_matrix, +) +from hmr4d.configs import MainStore, builds +from hmr4d.utils.geo.augment_noisy_pose import gaussian_augment +import hmr4d.utils.matrix as matrix +from hmr4d.utils.pylogger import Log +from hmr4d.utils.geo.hmr_global import get_local_transl_vel, rollout_local_transl_vel +from hmr4d.utils.smplx_utils import make_smplx +from . 
import stats_compose + + +class EnDecoder(nn.Module): + def __init__(self, stats_name="DEFAULT_01", noise_pose_k=10): + super().__init__() + # Load mean, std + stats = getattr(stats_compose, stats_name) + Log.info(f"[EnDecoder] Use {stats_name} for statistics!") + self.register_buffer("mean", torch.tensor(stats["mean"]).float(), False) + self.register_buffer("std", torch.tensor(stats["std"]).float(), False) + + # option + self.noise_pose_k = noise_pose_k + + # smpl + self.smplx_model = make_smplx("supermotion_v437coco17") + parents = self.smplx_model.parents[:22] + self.register_buffer("parents_tensor", parents, False) + self.parents = parents.tolist() + + def get_noisyobs(self, data, return_type="r6d"): + """ + Noisy observation contains local pose with noise + Args: + data (dict): + body_pose: (B, L, J*3) or (B, L, J, 3) + Returns: + noisy_bosy_pose: (B, L, J, 6) or (B, L, J, 3) or (B, L, 3, 3) depends on return_type + """ + body_pose = data["body_pose"] # (B, L, 63) + B, L, _ = body_pose.shape + body_pose = body_pose.reshape(B, L, -1, 3) + + # (B, L, J, C) + return_mapping = {"R": 0, "r6d": 1, "aa": 2} + return_id = return_mapping[return_type] + noisy_bosy_pose = gaussian_augment(body_pose, self.noise_pose_k, to_R=True)[return_id] + return noisy_bosy_pose + + def normalize_body_pose_r6d(self, body_pose_r6d): + """body_pose_r6d: (B, L, {J*6}/{J, 6}) -> (B, L, J*6)""" + B, L = body_pose_r6d.shape[:2] + body_pose_r6d = body_pose_r6d.reshape(B, L, -1) + if self.mean.shape[-1] == 1: # no mean, std provided + return body_pose_r6d + body_pose_r6d = (body_pose_r6d - self.mean[:126]) / self.std[:126] # (B, L, C) + return body_pose_r6d + + def fk_v2(self, body_pose, betas, global_orient=None, transl=None, get_intermediate=False): + """ + Args: + body_pose: (B, L, 63) + betas: (B, L, 10) + global_orient: (B, L, 3) + Returns: + joints: (B, L, 22, 3) + """ + B, L = body_pose.shape[:2] + if global_orient is None: + global_orient = torch.zeros((B, L, 3), device=body_pose.device) + aa = torch.cat([global_orient, body_pose], dim=-1).reshape(B, L, -1, 3) + rotmat = axis_angle_to_matrix(aa) # (B, L, 22, 3, 3) + + skeleton = self.smplx_model.get_skeleton(betas)[..., :22, :] # (B, L, 22, 3) + local_skeleton = skeleton - skeleton[:, :, self.parents_tensor] + local_skeleton = torch.cat([skeleton[:, :, :1], local_skeleton[:, :, 1:]], dim=2) + + if transl is not None: + local_skeleton[..., 0, :] += transl # B, L, 22, 3 + + mat = matrix.get_TRS(rotmat, local_skeleton) # B, L, 22, 4, 4 + fk_mat = matrix.forward_kinematics(mat, self.parents) # B, L, 22, 4, 4 + joints = matrix.get_position(fk_mat) # B, L, 22, 3 + if not get_intermediate: + return joints + else: + return joints, mat, fk_mat + + def get_local_pos(self, betas): + skeleton = self.smplx_model.get_skeleton(betas)[..., :22, :] # (B, L, 22, 3) + local_skeleton = skeleton - skeleton[:, :, self.parents_tensor] + local_skeleton = torch.cat([skeleton[:, :, :1], local_skeleton[:, :, 1:]], dim=2) + return local_skeleton + + def encode(self, inputs): + """ + definition: { + body_pose_r6d, # (B, L, (J-1)*6) -> 0:126 + betas, # (B, L, 10) -> 126:136 + global_orient_r6d, # (B, L, 6) -> 136:142 incam + global_orient_gv_r6d: # (B, L, 6) -> 142:148 gv + local_transl_vel, # (B, L, 3) -> 148:151, smpl-coord + } + """ + B, L = inputs["smpl_params_c"]["body_pose"].shape[:2] + # cam + smpl_params_c = inputs["smpl_params_c"] + body_pose = smpl_params_c["body_pose"].reshape(B, L, 21, 3) + body_pose_r6d = matrix_to_rotation_6d(axis_angle_to_matrix(body_pose)).flatten(-2) + 
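+ # NOTE: decode() slices these exact channel ranges (0:126, 126:136, 136:142, 142:148, 148:151);
+ # keep the concatenation order below in sync with the layout documented in the docstring above.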
betas = smpl_params_c["betas"] + global_orient_R = axis_angle_to_matrix(smpl_params_c["global_orient"]) + global_orient_r6d = matrix_to_rotation_6d(global_orient_R) + + # global + R_c2gv = inputs["R_c2gv"] # (B, L, 3, 3) + global_orient_gv_r6d = matrix_to_rotation_6d(R_c2gv @ global_orient_R) + + # local_transl_vel + smpl_params_w = inputs["smpl_params_w"] + local_transl_vel = get_local_transl_vel(smpl_params_w["transl"], smpl_params_w["global_orient"]) + if False: # debug + transl_recover = rollout_local_transl_vel( + local_transl_vel, smpl_params_w["global_orient"], smpl_params_w["transl"][:, [0]] + ) + print((transl_recover - smpl_params_w["transl"]).abs().max()) + + # returns + x = torch.cat([body_pose_r6d, betas, global_orient_r6d, global_orient_gv_r6d, local_transl_vel], dim=-1) + x_norm = (x - self.mean) / self.std + return x_norm + + def encode_translw(self, inputs): + """ + definition: { + body_pose_r6d, # (B, L, (J-1)*6) -> 0:126 + betas, # (B, L, 10) -> 126:136 + global_orient_r6d, # (B, L, 6) -> 136:142 incam + global_orient_gv_r6d: # (B, L, 6) -> 142:148 gv + local_transl_vel, # (B, L, 3) -> 148:151, smpl-coord + } + """ + # local_transl_vel + smpl_params_w = inputs["smpl_params_w"] + local_transl_vel = get_local_transl_vel(smpl_params_w["transl"], smpl_params_w["global_orient"]) + + # returns + x = local_transl_vel + x_norm = (x - self.mean[-3:]) / self.std[-3:] + return x_norm + + def decode_translw(self, x_norm): + return x_norm * self.std[-3:] + self.mean[-3:] + + def decode(self, x_norm): + """x_norm: (B, L, C)""" + B, L, C = x_norm.shape + x = (x_norm * self.std) + self.mean + + body_pose_r6d = x[:, :, :126] + betas = x[:, :, 126:136] + global_orient_r6d = x[:, :, 136:142] + global_orient_gv_r6d = x[:, :, 142:148] + local_transl_vel = x[:, :, 148:151] + + body_pose = matrix_to_axis_angle(rotation_6d_to_matrix(body_pose_r6d.reshape(B, L, -1, 6))) + body_pose = body_pose.flatten(-2) + global_orient_c = matrix_to_axis_angle(rotation_6d_to_matrix(global_orient_r6d)) + global_orient_gv = matrix_to_axis_angle(rotation_6d_to_matrix(global_orient_gv_r6d)) + + output = { + "body_pose": body_pose, + "betas": betas, + "global_orient": global_orient_c, + "global_orient_gv": global_orient_gv, + "local_transl_vel": local_transl_vel, + } + + return output + + +group_name = "endecoder/gvhmr" +cfg_base = builds(EnDecoder, populate_full_signature=True) +MainStore.store(name="v1_no_stdmean", node=cfg_base, group=group_name) +MainStore.store(name="v1", node=cfg_base(stats_name="MM_V1"), group=group_name) +MainStore.store( + name="v1_amass_local_bedlam_cam", + node=cfg_base(stats_name="MM_V1_AMASS_LOCAL_BEDLAM_CAM"), + group=group_name, +) + +MainStore.store(name="v2", node=cfg_base(stats_name="MM_V2"), group=group_name) +MainStore.store(name="v2_1", node=cfg_base(stats_name="MM_V2_1"), group=group_name) diff --git a/hmr4d/model/gvhmr/utils/postprocess.py b/hmr4d/model/gvhmr/utils/postprocess.py new file mode 100644 index 0000000..9605c52 --- /dev/null +++ b/hmr4d/model/gvhmr/utils/postprocess.py @@ -0,0 +1,168 @@ +import torch +from torch.cuda.amp import autocast +from pytorch3d.transforms import ( + matrix_to_rotation_6d, + rotation_6d_to_matrix, + axis_angle_to_matrix, + matrix_to_axis_angle, +) + +import hmr4d.utils.matrix as matrix +from hmr4d.utils.ik.ccd_ik import CCD_IK +from hmr4d.utils.geo_transform import get_sequence_cammat, transform_mat, apply_T_on_points +from hmr4d.utils.net_utils import gaussian_smooth +from hmr4d.model.gvhmr.utils.endecoder import EnDecoder + +from 
hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines + + +@autocast(enabled=False) +def pp_static_joint(outputs, endecoder: EnDecoder): + # Global FK + pred_w_j3d = endecoder.fk_v2(**outputs["pred_smpl_params_global"]) + L = pred_w_j3d.shape[1] + joint_ids = [7, 10, 8, 11, 20, 21] # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist] + pred_j3d_static = pred_w_j3d.clone()[:, :, joint_ids] # (B, L, J, 3) + + ######## update overall movement with static info, and make displacement ~[0,0,0] + pred_j_disp = pred_j3d_static[:, 1:] - pred_j3d_static[:, :-1] # (B, L-1, J, 3) + + static_conf_logits = outputs["static_conf_logits"][:, :-1].clone() + static_label_ = static_conf_logits > 0 # (B, L-1, J) # avoid non-contact frame + static_conf_logits = static_conf_logits.float() - (~static_label_ * 1e6) # fp16 cannot go through softmax + is_static = static_label_.sum(dim=-1) > 0 # (B, L-1) + + pred_disp = pred_j_disp * static_conf_logits[..., None].softmax(dim=-2) # (B, L-1, J, 3) + pred_disp = pred_disp * is_static[..., None, None] # (B, L-1, J, 3) + pred_disp = pred_disp.sum(-2) # (B, L-1, 3) + #################### + + # Overwrite results: + if False: # for-loop + post_w_transl = outputs["pred_smpl_params_global"]["transl"].clone() # (B, L, 3) + for i in range(1, L): + post_w_transl[:, i:] -= pred_disp[:, i - 1 : i] + else: # vectorized + pred_w_transl = outputs["pred_smpl_params_global"]["transl"].clone() # (B, L, 3) + pred_w_disp = pred_w_transl[:, 1:] - pred_w_transl[:, :-1] # (B, L-1, 3) + pred_w_disp_new = pred_w_disp - pred_disp + post_w_transl = torch.cumsum(torch.cat([pred_w_transl[:, :1], pred_w_disp_new], dim=1), dim=1) + post_w_transl[..., 0] = gaussian_smooth(post_w_transl[..., 0], dim=-1) + post_w_transl[..., 2] = gaussian_smooth(post_w_transl[..., 2], dim=-1) + + # Put the sequence on the ground by -min(y), this does not consider foot height, for o3d vis + post_w_j3d = pred_w_j3d - pred_w_transl.unsqueeze(-2) + post_w_transl.unsqueeze(-2) + ground_y = post_w_j3d[..., 1].flatten(-2).min(dim=-1)[0] # (B,) Minimum y value + post_w_transl[..., 1] -= ground_y + + return post_w_transl + + +@autocast(enabled=False) +def pp_static_joint_cam(outputs, endecoder: EnDecoder): + """Use static joint and static camera assumption to postprocess the global transl""" + # input + pred_smpl_params_incam = outputs["pred_smpl_params_incam"].copy() + pred_smpl_params_global = outputs["pred_smpl_params_global"] + static_conf_logits = outputs["static_conf_logits"].clone()[:, :-1] # (B, L-1, J) + joint_ids = [7, 10, 8, 11, 20, 21] # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist] + B, L = pred_smpl_params_incam["transl"].shape[:2] + assert B == 1 + + # FK + pred_w_j3d = endecoder.fk_v2(**pred_smpl_params_global) # (B, L, J, 3) + # smooth incam results, as this could be noisy + pred_smpl_params_incam["transl"] = gaussian_smooth(pred_smpl_params_incam["transl"], sigma=5, dim=-2) + pred_c_j3d = endecoder.fk_v2(**pred_smpl_params_incam) # (B, L, J, 3) + + # compute T_c2w (static) from first frame + R_gv = axis_angle_to_matrix(pred_smpl_params_global["global_orient"][:, 0]) # (B, 3, 3) + R_c = axis_angle_to_matrix(pred_smpl_params_incam["global_orient"][:, 0]) # (B, 3, 3) + R_c2w = R_gv @ R_c.mT # (B, 3, 3) + t_c2w = pred_w_j3d[:, 0, 0] - torch.einsum("bij,bj->bi", R_c2w, pred_c_j3d[:, 0, 0]) # (B, 3) + T_c2w = transform_mat(R_c2w, t_c2w) # (B, 4, 4) + pred_c_j3d_in_w = apply_T_on_points(pred_c_j3d, T_c2w[:, None]) + + # 1. 
Make transl similar to incam + post_w_transl = pred_smpl_params_global["transl"].clone() # (B, L, 3) + post_w_j3d = pred_w_j3d.clone() # (B, L, J, 3) + cp_thr = torch.tensor([0.25, 0.25, 0.25]).to(post_w_j3d) # Only update very bad pred + for i in range(1, L): + cp_diff = post_w_j3d[:, i, 0] - pred_c_j3d_in_w[:, i, 0] # (B, 3) + cp_diff = cp_diff * ~((cp_diff > -cp_thr) * (cp_diff < cp_thr)) + cp_diff = torch.clamp(cp_diff, -0.02, 0.02) + post_w_transl[:, i:] -= cp_diff + post_w_j3d[:, i:] -= (cp_diff)[:, None, None] + + # 1. Make stationary joint stay stationary + # pred_j3d_static = pred_w_j3d.clone()[:, :, joint_ids] # (B, L, J, 3) + pred_j3d_static = post_w_j3d[:, :, joint_ids] # (B, L, J, 3) + pred_j_disp = pred_j3d_static[:, 1:] - pred_j3d_static[:, :-1] # (B, L-1, J, 3) + + static_label = static_conf_logits.sigmoid() > 0.8 # (B, L-1, J) + static_label_sumJ = static_label.sum(-1, keepdim=True) # (B, L-1, 1) + static_label_sumJ = torch.clamp_min(static_label_sumJ, 1) # replace 0 with 1 + pred_disp_sumJ = (pred_j_disp * static_label[..., None]).sum(-2) # (B, L-1, 3) + pred_disp = pred_disp_sumJ / static_label_sumJ # (B, L-1, 3) + pred_disp[:, :, 1] = 0 # do not modify y + + # Overwrite results (for-loop) + for i in range(1, L): + post_w_transl[:, i:] -= pred_disp[:, [i - 1]] + post_w_j3d[:, i:] -= pred_disp[:, [i - 1], None] + + # Put the sequence on the ground by -min(y), this does not consider foot height, for o3d vis + ground_y = post_w_j3d[..., 1].flatten(-2).min(dim=-1)[0] # (B,) Minimum y value + post_w_transl[..., 1] -= ground_y + + return post_w_transl + + +@autocast(enabled=False) +def process_ik(outputs, endecoder): + static_conf = outputs["static_conf_logits"].sigmoid() # (B, L, J) + post_w_j3d, local_mat, post_w_mat = endecoder.fk_v2(**outputs["pred_smpl_params_global"], get_intermediate=True) + + # sebas rollout merge + joint_ids = [7, 10, 8, 11, 20, 21] # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist] + post_target_j3d = post_w_j3d.clone() + for i in range(1, post_w_j3d.size(1)): + prev = post_target_j3d[:, i - 1, joint_ids] + this = post_w_j3d[:, i, joint_ids] + c_prev = static_conf[:, i - 1, :, None] + post_target_j3d[:, i, joint_ids] = prev * c_prev + this * (1 - c_prev) + + # ik + global_rot = matrix.get_rotation(post_w_mat) + parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] + left_leg_chain = [0, 1, 4, 7, 10] + right_leg_chain = [0, 2, 5, 8, 11] + left_hand_chain = [9, 13, 16, 18, 20] + right_hand_chain = [9, 14, 17, 19, 21] + + def ik(local_mat, target_pos, target_rot, target_ind, chain): + local_mat = local_mat.clone() + IK_solver = CCD_IK( + local_mat, + parents, + target_ind, + target_pos, + target_rot, + kinematic_chain=chain, + max_iter=2, + ) + + chain_local_mat = IK_solver.solve() + chain_rotmat = matrix.get_rotation(chain_local_mat) + local_mat[:, :, chain[1:], :-1, :-1] = chain_rotmat[:, :, 1:] # (B, L, J, 3, 3) + return local_mat + + local_mat = ik(local_mat, post_target_j3d[:, :, [7, 10]], global_rot[:, :, [7, 10]], [3, 4], left_leg_chain) + local_mat = ik(local_mat, post_target_j3d[:, :, [8, 11]], global_rot[:, :, [8, 11]], [3, 4], right_leg_chain) + local_mat = ik(local_mat, post_target_j3d[:, :, [20]], global_rot[:, :, [20]], [4], left_hand_chain) + local_mat = ik(local_mat, post_target_j3d[:, :, [21]], global_rot[:, :, [21]], [4], right_hand_chain) + + body_pose = matrix_to_axis_angle(matrix.get_rotation(local_mat[:, :, 1:])) # (B, L, J-1, 3, 3) + body_pose = body_pose.flatten(2) # (B, L, (J-1)*3) + + return 
body_pose diff --git a/hmr4d/model/gvhmr/utils/stats_compose.py b/hmr4d/model/gvhmr/utils/stats_compose.py new file mode 100644 index 0000000..0883d5e --- /dev/null +++ b/hmr4d/model/gvhmr/utils/stats_compose.py @@ -0,0 +1,220 @@ +# fmt:off +body_pose_r6d = { + "bedlam": { + "count": 5417929, + "mean": [ 0.9772, -0.0925, 0.0028, 0.1058, 0.9111, 0.1373, 0.9796, 0.0711, + -0.0193, -0.0816, 0.8910, 0.1953, 0.9935, 0.0072, 0.0270, -0.0046, + 0.9200, -0.2511, 0.9752, 0.0477, -0.0990, -0.0613, 0.8242, -0.2730, + 0.9836, -0.0400, 0.0067, 0.0148, 0.7836, -0.3471, 0.9931, -0.0300, + -0.0469, 0.0244, 0.9825, -0.0513, 0.9777, 0.0206, 0.1444, -0.0470, + 0.9603, 0.1521, 0.9804, -0.0362, -0.0902, 0.0500, 0.9546, 0.1337, + 0.9969, -0.0105, 0.0076, 0.0090, 0.9914, 0.0150, 0.9953, -0.0607, + 0.0089, 0.0602, 0.9942, 0.0146, 0.9934, -0.0682, -0.0171, 0.0680, + 0.9932, -0.0017, 0.9790, 0.0294, 0.0065, -0.0338, 0.9706, -0.0456, + 0.9056, 0.2457, -0.1029, -0.2279, 0.9262, 0.0145, 0.9233, -0.1301, + 0.1550, 0.1140, 0.9476, 0.0534, 0.9769, -0.0572, -0.0095, 0.0569, + 0.9690, 0.0472, 0.6782, 0.5746, -0.2378, -0.5546, 0.7212, 0.0917, + 0.6489, -0.5955, 0.2424, 0.5821, 0.6797, 0.0563, 0.5562, -0.1252, + -0.5860, 0.0937, 0.9176, -0.1287, 0.4453, 0.1421, 0.6119, -0.1427, + 0.8996, -0.1136, 0.9186, -0.0881, -0.1463, 0.1087, 0.8692, 0.0845, + 0.9175, 0.0257, 0.0663, -0.0385, 0.8603, 0.1020], + "std": [0.0429, 0.1392, 0.1236, 0.1323, 0.1645, 0.3086, 0.0375, 0.1406, 0.1172, + 0.1275, 0.1934, 0.3280, 0.0119, 0.0835, 0.0716, 0.0741, 0.1528, 0.2484, + 0.0349, 0.0947, 0.1633, 0.0924, 0.3469, 0.3370, 0.0273, 0.1009, 0.1411, + 0.0680, 0.3876, 0.3323, 0.0103, 0.0735, 0.0712, 0.0690, 0.0246, 0.1617, + 0.0216, 0.1097, 0.1016, 0.0924, 0.0509, 0.2035, 0.0245, 0.1188, 0.1212, + 0.1056, 0.0634, 0.2308, 0.0054, 0.0579, 0.0517, 0.0575, 0.0124, 0.1158, + 0.0076, 0.0654, 0.0367, 0.0644, 0.0118, 0.0592, 0.0116, 0.0829, 0.0361, + 0.0832, 0.0124, 0.0422, 0.0343, 0.1060, 0.1680, 0.1075, 0.0473, 0.2023, + 0.0701, 0.2344, 0.2213, 0.2632, 0.0589, 0.1318, 0.0767, 0.2456, 0.2009, + 0.2666, 0.0542, 0.1106, 0.0347, 0.1080, 0.1718, 0.1117, 0.0459, 0.2025, + 0.1882, 0.2769, 0.2032, 0.3072, 0.1447, 0.2204, 0.2018, 0.2820, 0.2126, + 0.3213, 0.1760, 0.2486, 0.4749, 0.1677, 0.2791, 0.2239, 0.0963, 0.2705, + 0.5540, 0.1846, 0.2572, 0.2411, 0.1287, 0.2878, 0.1151, 0.2993, 0.1557, + 0.2812, 0.1880, 0.3334, 0.1286, 0.3355, 0.1553, 0.3216, 0.1880, 0.3306] + }, + "amass": { + "count": 7114038, + "mean": [ 9.6969e-01, -5.9719e-02, -3.7700e-02, 5.8256e-02, 9.0800e-01, + 1.0972e-01, 9.7636e-01, 4.3401e-02, 4.3110e-03, -4.3032e-02, + 9.0261e-01, 1.4478e-01, 9.9288e-01, 3.5673e-03, 1.6264e-02, + -2.2260e-03, 9.3470e-01, -2.3495e-01, 9.7147e-01, 5.2553e-02, + -9.3666e-02, -5.4550e-02, 8.3321e-01, -2.4246e-01, 9.7971e-01, + -3.8429e-02, 5.3575e-03, 1.5537e-02, 8.1449e-01, -3.0926e-01, + 9.9532e-01, -9.4398e-03, -3.8328e-02, 8.5141e-03, 9.8880e-01, + 1.9976e-04, 9.5602e-01, -3.9528e-02, 2.0017e-01, 1.0363e-02, + 9.5965e-01, 1.3770e-01, 9.6223e-01, -4.6278e-02, -1.5177e-01, + 6.6705e-02, 9.5545e-01, 1.2519e-01, 9.9767e-01, -1.2616e-02, + -2.5442e-04, 1.1661e-02, 9.9376e-01, -3.6222e-02, 9.9511e-01, + -1.0583e-02, 1.2130e-02, 7.6461e-03, 9.9137e-01, 2.0029e-02, + 9.9295e-01, 7.2917e-03, 4.9454e-03, -8.0286e-03, 9.9137e-01, + 2.3707e-03, 9.7698e-01, 1.9943e-02, 1.3808e-03, -2.2006e-02, + 9.7375e-01, -6.7936e-02, 9.2804e-01, 2.5005e-01, -5.7167e-02, + -2.4047e-01, 9.4246e-01, 2.5863e-02, 9.2957e-01, -2.1329e-01, + 1.1112e-01, 2.0741e-01, 9.4876e-01, 2.9901e-02, 9.7683e-01, 
+ -4.1210e-02, 2.3248e-03, 4.0967e-02, 9.7365e-01, 5.7309e-03, + 6.4513e-01, 6.1999e-01, -2.5469e-01, -6.2342e-01, 6.8177e-01, + 3.5524e-02, 6.6192e-01, -5.9341e-01, 2.7136e-01, 5.9269e-01, + 6.8966e-01, 3.1309e-02, 6.8946e-01, -1.1676e-01, -4.9859e-01, + 4.0969e-02, 9.3656e-01, -1.4875e-01, 6.2787e-01, 1.3793e-01, + 5.4289e-01, -9.1946e-02, 9.2868e-01, -1.1927e-01, 9.3012e-01, + -8.3810e-02, -1.1951e-01, 9.7211e-02, 8.9118e-01, 5.9887e-02, + 9.3033e-01, 7.1047e-02, 7.5264e-02, -8.0679e-02, 8.8562e-01, + 4.8960e-02], + "std": [0.0612, 0.1390, 0.1779, 0.1415, 0.1826, 0.3268, 0.0440, 0.1382, 0.1542, + 0.1348, 0.1930, 0.3272, 0.0132, 0.0801, 0.0855, 0.0729, 0.1255, 0.2238, + 0.0554, 0.1088, 0.1727, 0.0939, 0.3294, 0.3559, 0.0532, 0.1082, 0.1554, + 0.0768, 0.3446, 0.3407, 0.0120, 0.0650, 0.0584, 0.0632, 0.0198, 0.1335, + 0.0631, 0.1250, 0.1574, 0.1047, 0.0730, 0.2091, 0.0759, 0.1241, 0.1667, + 0.1112, 0.0831, 0.2185, 0.0060, 0.0441, 0.0502, 0.0441, 0.0102, 0.0946, + 0.0237, 0.0722, 0.0610, 0.0738, 0.0479, 0.0949, 0.0369, 0.0943, 0.0610, + 0.0966, 0.0498, 0.0729, 0.0425, 0.1001, 0.1824, 0.0972, 0.0408, 0.1887, + 0.0594, 0.1842, 0.1884, 0.2020, 0.0457, 0.1018, 0.0640, 0.1990, 0.1854, + 0.2133, 0.0467, 0.0910, 0.0392, 0.1049, 0.1776, 0.1037, 0.0413, 0.1945, + 0.1733, 0.2612, 0.1905, 0.2963, 0.1512, 0.1861, 0.1710, 0.2663, 0.1896, + 0.3135, 0.1568, 0.2219, 0.3976, 0.1594, 0.2810, 0.1855, 0.0845, 0.2398, + 0.4398, 0.1629, 0.2685, 0.1990, 0.0998, 0.2556, 0.1137, 0.2837, 0.1419, + 0.2761, 0.1678, 0.2973, 0.1172, 0.3010, 0.1394, 0.2910, 0.1724, 0.3039] + } +} + +betas = { + "bedlam": { + "count": 37855, # so many subjects? + "mean": [ 0.0378, -0.3562, 0.1185, 0.2245, 0.0204, 0.0929, 0.0537, 0.1006, + -0.1180, 0.0936], + "std":[0.8070, 1.3480, 0.8964, 0.7390, 0.6433, 0.6089, 0.5374, 0.6984, 0.7263, + 0.5395], + }, + "amass": { + "count": 18086, + "mean": [ 0.2310, 0.1750, 0.2931, -0.1859, -1.1163, -1.1028, -0.2573, 0.3555, + 0.3732, 0.2852], + "std": [0.8831, 0.7965, 1.0899, 1.1788, 1.2128, 1.1081, 0.9780, 1.1434, 0.8498, + 1.1462], + } +} + +global_orient_c_r6d = { + "bedlam": { + "count": 5417929, + "mean": [-4.9862e-03, -8.7136e-04, -1.4187e-03, 1.4825e-02, -9.4419e-01, + -5.1653e-02], + "std": [0.7048, 0.1713, 0.6884, 0.1548, 0.1546, 0.2403], + }, +} + +global_orient_gv_r6d = { + "bedlam": { + "count": 5134187, + "mean": [ 3.6018e-04, -2.2327e-04, 2.2316e-03, -4.4879e-02, -9.7435e-01, + 1.0021e-01], + "std": [0.6070, 0.5355, 0.5873, 0.6285, 0.2336, 0.7675], + }, +} + +local_transl_vel = { + "none":{ + "mean": [0., 0., 0.], + "std": [1., 1., 1.] 
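+ # identity statistics (mean 0, std 1): leaves the local translational velocity unnormalized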
+ }, + "1e-2":{ + "mean": [0., 0., 0.], + "std": [1e-2, 1e-2, 1e-2] + }, + "bedlam": { + "count": 5417929, + "mean": [7.3057e-05, -2.2142e-04, 3.2444e-03], + "std": [0.0065, 0.0091, 0.0114], + }, + "amass": { + "count": 7113068, + "mean": [-0.0002, -0.0006, 0.0069], + "std": [0.0064, 0.0070, 0.0138], + }, + "alignhead":{ + "count": 7113068, + "mean":[-2.0822e-04, -1.7966e-06, 6.9816e-03], + "std":[0.0065, 0.0066, 0.0139], + }, + "alignhead_absy":{ + "count": 7113068, + "mean":[-0.0002, -0.0316, 0.0070], + "std":[0.0065, 0.1351, 0.0139], + }, + "alignhead_absgy":{ + "count": 7113068, + "mean":[[-2.0822e-04, 1.2627e+00, 6.9816e-03]], + "std":[0.0065, 0.1516, 0.0139], + } + +} + +pred_cam = { + "bedlam": { + "count": 5096332, + "mean": [1.0606, -0.0027, 0.2702], + "std": [0.1784, 0.0956, 0.0764], + } +} + +vitfeat = { + "bedlam": { + "count": 5546332, + "mean": [-1.3772, 0.2490, 0.0602, -0.1834, 0.2458, 0.5372, 0.3343, -0.3476, -0.1017, -0.0362, -0.0678, 0.2150, -0.2534, 0.1029, 0.8199, -0.4676, 0.6259, -0.3350, 0.0549, -0.4469, 0.2751, -0.1763, 0.1114, -0.2115, -0.0264, 0.5294, 0.8212, -0.4562, 0.4147, -0.0256, -0.1019, 0.2798, 0.9284, 0.4652, 0.6365, 0.6785, -0.0765, 0.0337, -0.2566, -0.0335, -0.1799, 0.7426, 0.2810, -0.7121, -0.0893, 0.1608, -0.2483, 1.5094, -1.4395, -0.3682, -0.4157, -0.0032, -0.0376, -0.0043, 0.2092, 0.3038, -0.2077, -0.4868, -0.1534, 0.2668, 1.2773, 0.2838, -0.4863, -1.2300, 0.0581, -0.3041, 0.1518, 0.7955, -0.4293, 1.4666, 0.3077, 0.3918, 0.1418, 0.1590, 0.8671, -0.3527, 0.5629, 0.1414, 0.0964, -0.1094, -0.0211, -0.0937, 0.1606, -0.7900, 0.0397, 0.0570, 0.7083, -0.5732, 0.1430, -0.2571, 0.5275, 0.6603, 0.3265, 0.4574, -0.3361, -0.1267, 0.3841, 0.1758, -0.6207, -0.3673, 0.8914, 0.4297, -0.8118, 0.2229, -0.2876, 0.2460, 0.4856, -0.1446, -0.2416, 0.1229, 0.2865, 0.7023, -0.2883, 0.3940, -1.5496, 0.4456, 0.6445, 0.2058, -0.4265, 0.3724, 0.1557, -1.4208, -0.1246, 0.1237, -0.3965, 0.0105, -0.0780, 0.6448, -0.1132, 0.8500, -0.2828, 0.4447, 0.6257, -0.2664, -0.8384, -1.8091, -0.2769, 0.1866, 0.6051, -0.2548, 0.9823, -0.2985, -0.2773, -0.4383, 0.1886, 0.2411, 0.2546, 0.2195, -0.0041, 0.1038, -0.6804, 1.2364, 0.5393, 0.0351, 0.4537, -0.8044, -0.1993, -2.1097, -0.8458, 0.1497, 1.6042, 0.6458, -0.5455, 0.0778, 0.0504, -0.5242, -0.3215, -0.0199, 1.1461, -0.3355, -0.3421, -0.3951, 0.0184, -0.0261, 0.2048, 0.0080, 0.6553, -1.3221, 0.5140, 0.5958, -0.2523, 0.9434, -0.0727, 0.1978, 1.1105, -0.4992, 0.3990, 0.2074, 0.3843, -0.0444, 0.0624, -0.8442, -0.0724, -0.5328, 1.1723, 0.8043, 0.6674, 1.5283, 4.2502, 0.0935, 0.3733, 0.1569, 0.0154, 0.0674, 0.0862, -0.2744, -0.4537, 0.1588, -1.9156, 0.0149, -1.0498, -0.0790, 0.0851, -0.5007, 0.3323, -0.1065, 0.0782, 0.0725, -0.5921, -0.1876, 0.0094, -0.3631, 0.0951, 0.1318, 0.0936, 0.5668, -0.0875, -0.4576, -0.4306, 0.5458, 1.0761, 1.1740, -0.0337, 1.3718, -0.2913, -0.3433, 0.5338, -0.4577, -0.4966, 0.2704, 0.3236, 0.4053, 0.0360, 1.1616, -0.2012, 0.7373, 0.0779, -0.0280, -0.4426, 0.0450, 0.2923, 0.0161, -0.4788, 0.1924, -0.3012, 0.0298, -0.7776, -0.2215, 0.4494, -0.1677, 0.2214, 0.0762, -0.3088, 0.4230, 0.0673, -1.0233, 0.0748, -0.4358, -0.2497, -0.0066, 0.1679, -0.1077, -0.4290, 2.5254, -0.8819, -0.8073, 0.2535, 2.0680, -0.4715, 0.3614, -2.9281, 3.1536, 0.3118, -0.0239, 0.7064, -0.6935, -1.1070, -0.1715, -0.0920, -0.2133, -1.0173, 0.0084, -0.1721, 0.2605, -0.6607, -0.0788, -0.3479, -0.2187, 1.0605, 0.2857, 0.7464, 0.9612, -1.1332, 1.5708, -1.0264, 0.6070, 0.4103, -0.1950, -0.0629, -0.0958, -0.2199, -0.2198, -0.4019, 0.2478, -0.3576, 0.0191, 
-5.8435, 0.0145, -0.2312, 0.9872, 1.1159, 0.3775, 0.1960, -0.5968, -0.2611, -0.0634, -0.1003, 0.7411, -0.8298, -0.1743, 1.8418, 0.3692, -0.4321, 0.0613, -1.9046, 0.5812, 0.2805, 0.1703, -0.2212, -0.0740, -0.2737, -0.3084, 2.9787, -0.1392, 0.3347, 0.0866, -0.8654, -0.4564, -0.7839, 0.1033, -0.0204, 0.1558, -0.1469, 0.2850, -0.1139, 0.8253, 0.7352, -0.6132, 0.0566, 0.3087, -0.1189, 0.1640, 0.2511, 0.5230, -0.0972, -0.5621, -2.5404, 0.3529, -0.2543, -0.6757, 0.2045, -0.0511, -0.2204, 0.1023, 0.0143, 0.4191, -0.3946, -1.0912, 0.8555, 1.0751, -0.0184, -0.3162, 0.1910, 0.6522, -0.5801, 0.2091, -0.8254, -0.3425, 0.3368, -0.0384, -0.4570, 2.5288, -0.3513, -0.1630, 0.1096, -0.5936, 1.5303, -0.4135, -0.2418, -0.0564, -2.6344, -0.1054, 0.8866, -0.2946, -0.4564, -0.6220, 0.2672, -0.9012, 0.3535, 0.2344, -0.0718, 0.0782, 0.0133, 0.2032, -1.2768, 0.1271, -0.5114, -0.0584, -0.8219, -0.1069, 1.5577, -0.1432, -0.6794, 0.9101, 0.6390, 0.3547, -0.6126, -0.1885, 0.2462, -1.1864, 0.0653, -0.7940, 0.5204, 0.5372, 0.5353, -0.4268, -0.2003, -0.2496, -0.0405, 0.3615, -0.1635, 0.1908, -0.0467, 0.7167, 0.1465, 0.4621, 0.1190, -1.6899, 0.6512, 1.3150, -0.1273, 0.0507, 0.2058, -0.1855, 0.1316, 0.1280, 0.5049, 0.0262, -0.0329, 2.0327, -0.6410, 0.4536, 0.0609, 0.1883, -0.5454, -0.5247, 0.1856, 0.7238, 1.4886, -0.1068, 1.7239, -0.8228, -0.2155, 0.5159, 0.2941, -0.0782, -0.0159, 0.1844, -0.1808, -0.1132, 0.4861, 4.0106, 0.0130, 0.2455, -0.1101, 0.0792, 0.4720, -0.1022, 2.0154, -0.4013, 0.5604, 1.3600, -0.5614, 0.3793, -0.1245, 0.2444, 0.1657, 1.7616, 0.6198, 0.1761, -0.6036, -0.1931, 0.4449, 0.2574, -0.2360, 1.1118, 0.0804, 1.1533, 0.2549, 0.3386, 0.2463, 0.0930, -0.6093, -0.1464, 0.2889, 0.2294, -0.5943, 0.1323, 0.5119, 0.1093, -1.0178, 0.4735, 0.3068, 0.3213, -0.0585, -0.3682, -0.6105, -0.7776, 0.1999, 0.9439, -0.4209, 0.1488, 1.3119, -0.4679, -0.3882, 0.2677, -0.1673, -0.5921, -1.2811, -1.0972, 0.3873, 0.0798, -0.0538, 0.0659, -0.1439, -1.3106, -0.5175, 0.4538, -1.0376, -0.9015, 0.7454, -0.0714, -0.4641, 0.2083, 0.0596, -2.9637, 0.3057, 0.2121, -0.2399, 0.6963, 0.1400, 1.7446, 0.9707, -0.3118, -0.3371, 0.0130, 1.0006, -0.2740, 0.1100, -0.9666, 0.7636, 1.2002, -0.0018, -0.3380, 0.1262, 0.5829, -0.0374, 0.0689, 0.2022, -2.0056, -0.2051, -0.4549, 0.0519, 0.4217, -0.7413, 0.0601, 0.4385, 2.8503, -2.7656, 1.2281, -0.1280, 0.6028, 0.4995, 0.0638, -0.3376, 0.2527, -0.1572, -0.4385, -0.6372, 0.2569, 0.4115, 0.4507, 0.6063, -0.1051, 1.2529, 0.2453, -0.7905, -0.3797, -0.2674, 0.2662, 1.5347, -0.3908, 0.8839, -0.6054, -0.4827, -0.3495, 1.2107, -0.4419, -0.6177, 0.1054, 1.0132, -0.3246, -0.1776, 1.1740, -0.0252, 0.0368, -0.7937, -0.9988, -0.0228, 0.0742, -2.4925, 0.5785, 2.3900, 1.2726, -0.3682, -0.8625, -0.3299, 0.3934, 1.4045, -0.6200, -0.0024, 0.2348, -0.1827, -0.5913, -0.6982, 0.2648, 0.2601, 0.9986, 0.1636, 0.8982, -0.4269, 1.7454, -1.9136, -0.9865, -0.0451, 0.2851, -0.5938, -0.3066, 0.0910, -0.3150, -0.4002, 0.4789, 0.0337, -0.6997, -0.2555, -0.6602, -3.0103, 0.2491, -1.0346, 0.3651, 0.2319, 1.0224, -0.2613, 1.6970, 0.7515, 2.1477, 0.1310, 0.2060, 0.1372, 1.0049, -0.8758, -0.3804, -2.1513, 0.8010, -0.2271, -0.2108, 0.3728, -1.7321, -1.0250, -0.2584, -0.2513, 0.2418, -0.7641, 0.2084, -1.3560, 0.5803, 0.1556, -0.3612, 1.3099, -0.2673, 0.4371, -0.8022, 0.1776, -0.5019, 0.1880, -0.2093, 0.0750, -0.7228, -1.3950, 0.1944, -1.5994, -0.2832, 0.0507, 0.1917, 1.2954, 0.0471, 0.3115, -2.2382, -0.3891, -0.0704, 0.3897, 0.0347, 0.9186, -0.8407, 0.9456, 0.5629, 0.3474, -0.4869, 0.4696, -0.4438, 0.0860, -0.8313, -0.0383, 0.2055, 0.4822, 
-0.1455, -0.1719, -0.2346, -0.4606, 0.8018, 0.3767, -0.0613, 1.9429, -0.6558, -0.0772, -0.1592, -0.1413, 0.4759, -0.0686, 0.9243, -0.2413, -0.1084, -0.2248, -0.0776, 1.4193, -0.0605, 0.1305, -0.2055, 0.0917, 0.6884, -0.0152, 0.1215, 0.2920, -0.0781, -0.0256, 0.3789, -0.1933, 0.1759, 2.3899, 1.0915, -0.7082, -0.4519, -0.2648, -1.2404, -0.2485, 1.0713, 0.1662, -0.1268, 0.3338, -0.0319, 0.1692, -0.5161, 0.9351, 0.1996, -0.2743, 0.0492, -0.0171, 0.1546, 0.2533, -0.0102, 0.6147, 0.0035, -0.2468, -0.2116, -1.7912, 0.2735, 0.4147, 0.4458, 0.6123, 0.0860, 0.2098, -0.3691, -0.2297, -0.6086, -1.0407, -0.7736, -0.3087, -0.0900, -0.1007, -0.3801, -0.3408, -0.4853, -0.3101, -0.8812, 0.0187, -0.9697, -0.2393, 0.1129, -0.5682, 0.4349, 0.1017, 0.2173, -0.0644, -0.9307, 0.9754, 0.2189, 0.2966, -0.4089, -0.2471, -0.7549, 0.3300, 0.7856, 0.1262, 0.2097, -0.5872, 0.9896, 0.5100, 1.0608, -0.7974, 0.1549, -0.1020, 0.4286, 0.0603, -0.6836, -0.4662, -1.2350, -0.0858, -0.5552, 0.0383, 0.2145, -0.4324, -0.5896, 0.9709, -0.0827, -0.2574, 0.2436, -0.1460, 0.5862, 0.4329, -1.2421, 0.0497, -0.0034, 0.2385, -0.1346, 2.0652, 0.8790, -0.2033, -2.6427, 0.3654, -0.1929, -0.0753, -0.9107, 0.9437, 0.3717, -0.7058, -0.2487, -1.0937, -0.7612, 0.9516, -0.7426, -0.0736, 1.2167, 0.6336, 0.2707, -0.7666, -0.1272, -0.8960, 0.3748, 0.7344, 0.7257, 0.3686, -0.5036, -0.2829, 0.0548, 0.3034, -0.2335, -0.3215, 0.0566, -0.2733, -0.3644, 0.0467, -0.0924, -0.5145, -1.7089, 0.4896, 0.0074, 0.2840, 0.1140, -0.0409, -0.3251, 1.0805, 3.0856, -0.3409, 1.2684, -0.0245, -0.0636, -0.0090, 0.1293, -0.3410, -0.0482, 0.1482, 0.2027, 0.5623, 0.0566, 0.6453, -0.0126, 0.0720, -0.0277, 0.0531, 0.1860, -0.1044, -0.6973, 0.3026, 0.4733, -0.1590, 0.4727, 0.8486, 0.4478, 0.1814, 1.0862, 0.0478, 0.2437, -0.5269, -0.0796, -0.4291, 0.4937, -0.0407, -0.6961, -0.0412, 0.6865, 0.0457, 0.1085, -0.4717, -0.1339, 0.8600, 0.6718, -0.3542, -0.5655, 1.3711, 0.0034, 0.3077, 0.0903, 0.3618, 0.3287, -0.1007, 0.0332, -0.3841, -0.3981, 0.1079, -0.4399, 0.1836, 0.0939, -0.1425, -0.2531, -1.2103, 0.0234, -1.3023, -0.0570, -0.0587, 1.1733, 0.0079, 1.0809, 0.4697, -0.1427, 3.3793, -0.1503, 0.4354, 0.0274, 0.3112, -0.3816, 0.0187, -0.1282, -0.4136, 0.3684, 0.6930, 1.3605, 0.4949, 0.4162, -2.2398, 0.4104, 0.6839, 0.4519, 0.0546, -0.0816, 0.0357, 0.1977, -0.8450, 0.1481, 0.1588, -0.1392, -0.3304, -0.3499, -0.8669, 0.1510, 0.1127, 0.9853, -0.3019, -0.3493, -0.0783, -0.8491, 0.0696, 0.7295, -1.0612, 0.1232], + "std": [0.9277, 0.7470, 0.6154, 0.8520, 0.8682, 0.7121, 0.7048, 0.6865, 0.7543, 0.6952, 0.6186, 0.4204, 0.4614, 0.4731, 0.4421, 0.4068, 0.6927, 0.6540, 0.4717, 0.4993, 0.5945, 0.5480, 0.4898, 0.6438, 0.5551, 0.5686, 0.7287, 0.6033, 0.5590, 0.3768, 0.5304, 0.6748, 0.5559, 0.5265, 0.6214, 0.6490, 0.4639, 0.6465, 0.5575, 0.6202, 0.5369, 1.2466, 0.7340, 0.5462, 0.6508, 0.5766, 0.5405, 0.5581, 0.5687, 0.7549, 0.5743, 0.4748, 0.6308, 0.6292, 0.6391, 0.6284, 0.4202, 0.5970, 0.5587, 0.5364, 0.4655, 0.5201, 0.7140, 0.6220, 0.4978, 0.4479, 0.5452, 0.7489, 0.5866, 0.4592, 0.7493, 0.6548, 0.5497, 0.4658, 0.8663, 0.4574, 0.5351, 0.5595, 0.4579, 0.5141, 0.4824, 0.5504, 0.5468, 0.5726, 0.5155, 0.6679, 0.8433, 0.5278, 0.5666, 0.7699, 0.5682, 0.9431, 0.5344, 0.6562, 0.4749, 0.5241, 0.6869, 0.4117, 0.5839, 0.5115, 0.8811, 0.5335, 0.6476, 0.4883, 0.6034, 0.5778, 0.4764, 0.8787, 0.8589, 0.5168, 0.4548, 0.8146, 0.5860, 0.6087, 0.6758, 0.7049, 0.8292, 0.6547, 0.6043, 0.7242, 0.6158, 0.6435, 0.5219, 0.6148, 0.7738, 0.4871, 0.7944, 0.7605, 0.6120, 0.5482, 0.6107, 0.6106, 0.4295, 0.4549, 0.4167, 0.6142, 
0.6368, 0.5432, 0.5412, 0.6568, 0.9641, 0.6413, 0.6634, 0.4222, 0.6917, 0.5664, 0.5554, 0.4098, 0.6949, 0.5890, 0.4995, 0.5475, 0.6446, 0.5599, 0.6439, 0.6220, 0.5761, 0.5862, 0.5126, 0.6037, 0.5377, 0.5817, 0.6216, 0.5986, 0.4834, 0.6929, 0.5819, 0.6781, 0.6088, 0.5425, 0.7211, 0.6253, 0.5408, 0.6826, 0.5454, 0.7614, 0.9767, 0.8721, 0.7527, 0.4022, 0.5061, 0.5921, 0.5945, 0.6048, 0.7206, 0.5533, 0.5506, 0.6816, 0.6116, 0.6424, 0.7484, 0.6350, 0.5953, 0.4941, 0.7675, 0.8244, 0.6885, 0.5751, 0.9304, 0.5252, 0.5741, 0.4537, 0.5610, 0.9873, 0.5155, 0.7180, 0.4421, 0.5171, 0.5343, 0.5225, 0.7952, 0.6149, 0.6401, 0.5667, 0.6946, 0.8172, 0.5188, 0.5082, 0.6298, 0.6904, 0.4820, 0.5600, 0.5584, 0.5600, 0.4776, 0.5008, 0.7215, 0.6071, 0.5571, 0.6174, 0.4049, 0.7368, 0.5996, 0.7888, 0.7609, 0.5913, 0.8778, 0.4462, 0.7460, 0.7240, 0.5705, 0.6267, 0.5684, 0.5707, 0.6560, 0.5310, 0.5278, 0.6833, 0.6420, 0.6696, 0.8815, 0.4767, 0.7171, 0.4826, 0.6736, 0.5483, 0.4913, 0.5840, 0.5242, 0.4310, 0.5846, 0.4389, 0.5164, 0.6203, 0.5625, 0.8495, 0.5091, 0.6904, 0.5490, 0.5467, 0.4746, 0.8446, 0.6030, 0.6563, 1.0108, 0.5633, 0.6324, 0.6339, 0.6269, 1.2128, 0.6877, 0.5998, 0.4763, 0.4979, 0.7968, 0.6549, 1.0234, 0.5385, 0.6164, 0.5485, 0.8526, 0.5776, 0.5292, 0.5716, 0.5458, 0.5332, 0.5264, 0.6239, 0.6668, 0.7481, 0.3929, 0.5932, 0.5741, 0.4433, 0.7519, 0.4940, 0.7438, 0.5315, 0.3895, 0.5528, 0.6656, 0.6665, 0.9897, 0.8098, 0.6000, 0.5226, 1.2953, 0.5624, 0.6416, 0.5880, 0.5828, 0.4779, 0.6721, 0.6273, 0.7918, 0.5498, 0.5262, 0.6396, 0.6185, 0.6117, 0.8871, 0.5688, 0.5335, 0.6402, 0.5994, 0.9472, 0.5072, 0.7688, 0.6257, 0.6548, 0.6070, 0.7646, 0.5362, 0.5151, 0.6852, 0.4533, 0.6976, 0.6170, 0.5700, 0.5819, 0.4350, 0.5755, 0.4902, 0.9396, 0.5110, 0.5461, 0.6380, 1.0192, 0.5009, 0.8211, 0.6223, 0.5970, 0.5465, 0.8314, 0.4997, 0.5066, 0.5824, 0.6241, 0.4910, 0.4849, 0.5292, 0.5357, 0.4856, 0.6120, 0.4212, 0.6712, 0.4599, 0.4625, 0.7568, 0.8765, 0.8095, 0.7385, 0.5748, 0.7405, 0.6474, 0.6466, 0.6481, 0.5660, 0.6876, 0.9852, 0.5923, 0.6319, 0.6818, 0.4716, 0.6599, 0.5343, 0.5384, 0.9786, 0.4421, 0.5543, 1.0386, 0.5640, 0.5990, 0.5060, 0.6141, 0.3880, 0.6767, 0.5753, 0.4797, 0.4623, 0.5802, 0.6813, 0.5792, 0.4790, 0.6855, 0.5186, 0.4890, 0.5740, 0.6117, 0.5177, 0.5032, 0.6367, 0.4555, 0.6749, 0.6680, 0.6878, 0.7425, 0.8106, 0.5460, 1.0575, 0.5022, 0.7639, 0.5132, 0.5433, 0.7702, 0.4572, 0.4274, 0.6779, 0.5277, 0.5634, 0.4814, 0.5491, 0.5790, 0.5750, 0.5573, 0.4652, 0.5240, 0.6244, 0.6247, 0.7397, 0.7107, 0.5964, 0.4891, 0.7089, 0.6531, 0.6979, 0.4630, 0.5348, 0.4308, 0.8983, 0.5416, 0.4521, 0.6261, 0.4931, 0.7247, 0.5689, 0.5254, 0.4913, 0.6307, 0.5586, 0.5804, 0.5692, 0.5211, 0.6549, 0.6069, 0.5216, 0.4617, 0.7538, 0.4234, 0.4868, 0.7661, 1.1726, 0.8879, 0.4984, 0.6142, 0.4203, 0.5944, 0.6758, 0.5682, 0.6554, 0.7316, 0.5552, 0.7454, 0.3907, 0.7559, 0.4752, 0.5638, 0.7824, 0.7995, 0.5728, 0.8546, 0.5663, 0.5545, 0.4785, 1.0497, 0.7177, 0.5461, 0.5134, 0.5432, 0.5964, 0.5879, 0.7046, 0.7501, 0.5707, 0.9907, 0.9337, 0.5682, 0.4887, 0.5970, 0.6229, 0.6501, 0.7529, 0.7062, 0.6775, 0.7286, 0.6250, 0.4521, 0.5357, 0.5479, 0.7957, 0.4596, 0.6440, 0.8665, 0.6024, 0.7485, 0.6478, 0.6483, 0.5785, 0.5500, 0.4802, 0.4465, 0.6829, 0.6890, 0.6180, 0.8767, 0.7419, 0.6193, 0.3918, 0.5888, 0.5440, 0.5146, 0.4297, 0.4410, 0.4894, 0.4422, 0.9614, 0.6290, 0.6717, 0.5415, 0.5442, 0.5862, 0.4967, 0.7102, 1.1356, 0.4818, 0.4557, 0.6403, 0.4971, 0.7491, 0.8534, 0.8754, 0.5308, 0.5591, 0.6415, 0.7715, 0.8137, 0.4898, 0.5460, 0.5476, 0.9199, 
0.6195, 0.5949, 0.7990, 0.4444, 0.6199, 0.5166, 0.4646, 0.9060, 0.6261, 0.5149, 0.6533, 0.7420, 0.4830, 0.5314, 0.5503, 0.5777, 0.6284, 0.7288, 0.5743, 0.6041, 0.5674, 0.4661, 0.6211, 0.6172, 0.4094, 0.5787, 0.8089, 0.6061, 0.5882, 0.5498, 0.7239, 0.6387, 0.7910, 0.5267, 0.5569, 0.6382, 0.5492, 0.5444, 0.6476, 0.8666, 0.9807, 0.5594, 0.6814, 0.5467, 0.8900, 0.5321, 0.5516, 1.0188, 0.7193, 0.5044, 0.5717, 0.9741, 0.7856, 0.6849, 0.5604, 1.0236, 0.8399, 0.5065, 0.6475, 0.4055, 0.7975, 0.4454, 0.5726, 0.4489, 0.6851, 0.6504, 0.4737, 0.5995, 0.6226, 0.5917, 0.5394, 0.5240, 0.7863, 0.6008, 0.5330, 0.4760, 0.6163, 0.4679, 0.5712, 0.7180, 0.4908, 1.0175, 0.5942, 0.5170, 0.7534, 0.5569, 0.8764, 0.7314, 0.5474, 0.9083, 0.6677, 0.6286, 0.6759, 0.5397, 0.5748, 0.6215, 0.4800, 0.5206, 0.5591, 0.5884, 0.6291, 0.6633, 0.7693, 0.5104, 0.6564, 0.5489, 0.6270, 0.5935, 0.6236, 0.6108, 0.4794, 0.5974, 0.7061, 0.6686, 0.6512, 0.4998, 0.5933, 0.4956, 0.6610, 0.7542, 0.5869, 0.8418, 0.9938, 0.9021, 0.6323, 0.5777, 0.4343, 0.6098, 0.5338, 0.5906, 0.7783, 0.7423, 0.6426, 0.6236, 0.9643, 0.5780, 1.0100, 1.1266, 0.7556, 0.5229, 0.8272, 0.6900, 0.5175, 0.4124, 0.5741, 0.4516, 0.6266, 0.5630, 0.5275, 0.5692, 0.5075, 0.7549, 0.6359, 0.5804, 0.6680, 0.7558, 0.6250, 0.4314, 0.6496, 0.5479, 0.7524, 0.7088, 0.6644, 0.7214, 0.6450, 0.4467, 0.7789, 0.5168, 0.6297, 0.6242, 0.4410, 0.8372, 0.5758, 0.4997, 0.8915, 0.6473, 0.5974, 0.5293, 0.7941, 0.4605, 0.9110, 0.5919, 0.5139, 0.5003, 0.4500, 0.6182, 0.5807, 0.4562, 0.5618, 0.6794, 0.7201, 0.6143, 0.8797, 0.8171, 0.6225, 0.7453, 0.7611, 0.4696, 1.0906, 0.8825, 0.7207, 0.5523, 0.7120, 0.5194, 0.5321, 1.0233, 0.5618, 0.5410, 0.4300, 0.7191, 0.5373, 0.4795, 0.4450, 0.6546, 0.7965, 0.7454, 0.6264, 0.5576, 0.7710, 0.5527, 0.6586, 0.5177, 0.4858, 0.5005, 0.5372, 0.5766, 0.4508, 0.5238, 0.8275, 0.4104, 0.5535, 0.8077, 0.4460, 0.7125, 0.7166, 0.6107, 0.4561, 0.6620, 0.4635, 0.6397, 0.4391, 0.6880, 0.6801, 0.5627, 0.8076, 0.7918, 1.0309, 0.5832, 0.6152, 0.7971, 0.4539, 0.5846, 0.7248, 0.4455, 0.6318, 0.6118, 0.4552, 0.6757, 0.5354, 0.6566, 0.6728, 0.4383, 0.6899, 1.0565, 0.6028, 0.6937, 0.5518, 0.8039, 0.4296, 0.6068, 0.5736, 0.4923, 0.7643, 0.7391, 0.4975, 0.5006, 0.5674, 0.5170, 0.4835, 0.4286, 0.5667, 0.6109, 0.6465, 0.6281, 0.7791, 0.5174, 0.5058, 0.6196, 0.6593, 0.5999, 0.5012, 0.5414, 0.7151, 0.6546, 0.6790, 0.5412, 0.4801, 0.6561, 1.0082, 0.5567, 0.6362, 0.4540, 0.8812, 0.6893, 0.6420, 0.6078, 0.5117, 0.7079, 0.8240, 0.7587, 0.6344, 0.6848, 0.4633, 0.5352, 0.6077, 0.5436, 0.7223, 0.5001, 0.9734, 0.5155, 0.5549, 0.4711, 0.9038, 0.5415, 1.0173, 0.5001, 0.5290, 0.5228, 0.5619, 0.9670, 0.7854, 0.5350, 0.5183, 0.9770, 0.5547, 0.9710, 0.5050, 0.4584, 0.6438, 0.4854, 0.5949, 0.6611, 0.4676, 0.4815, 0.8837, 0.6425, 0.6257, 0.6896, 0.4465, 0.7492, 0.6293, 0.7096, 0.5578, 0.5117, 0.4909, 0.5773, 0.4800, 0.5488, 0.6336, 0.6863, 0.5035, 0.6682, 0.7245, 0.5524, 0.4594, 0.5816, 0.5698, 0.6140, 0.5816, 0.5242, 0.4088, 0.4358, 0.6426, 0.4777, 0.6115, 0.4383, 0.5957, 0.8423, 0.5353, 0.5407, 0.8497, 0.6962, 0.7542, 0.5981, 0.5121, 0.6232, 0.5306, 0.5416, 0.5217, 0.5437, 0.5349, 0.5111, 0.8627, 0.6092, 0.5850, 0.5851, 0.7203, 0.3688, 0.5063, 0.5650, 0.5444, 0.5657, 0.7461, 0.4447, 0.7153, 0.4738, 0.5730, 0.4605, 0.4905, 0.6253, 0.8114, 0.8273, 0.5052, 0.6180, 0.6496, 0.4037, 0.5635, 0.5212, 0.7652, 0.4872, 0.5764, 0.7834, 0.6888, 0.5313, 0.5379, 0.5710, 0.7474, 0.6535, 0.9660, 0.5257, 0.7157, 0.7150, 0.5430, 0.5331, 0.6820, 0.6872, 0.4904, 0.6592, 0.6256, 0.6107, 0.4939, 0.5986, 0.5172, 0.4583], + 
}, + "emdb": { + "count": 62707, + "mean": [-1.1869, 0.1485, 0.1933, -0.6247, 0.0793, 0.5762, 0.1835, -0.2564, 0.1285, 0.3221, 0.0577, 0.1154, -0.0818, -0.2512, 0.9673, -0.5680, 0.5968, -0.2124, -0.0112, -0.5576, 0.5339, -0.1490, 0.3102, -0.4012, -0.0570, 0.6416, 0.9359, -0.2932, 0.8544, 0.1719, -0.4534, 0.1316, 0.8625, 0.3806, 0.4884, 1.0853, -0.3872, -0.2403, -0.4274, 0.1319, -0.3334, 0.6352, 0.5748, -0.8850, -0.4331, 0.3662, -0.3324, 1.3993, -1.5142, -0.3082, -0.5491, -0.1847, 0.0145, -0.0726, 0.0015, -0.0358, -0.2815, -0.4356, -0.3842, 0.1150, 1.1513, 0.6343, -0.7336, -1.1613, 0.1020, -0.1291, 0.1560, 0.4854, -0.4191, 1.6794, 0.4274, 0.4792, 0.3570, 0.0811, 1.0886, 0.0670, 0.5227, 0.1891, 0.1121, 0.1495, -0.2090, -0.2156, -0.2512, -0.9291, 0.1287, -0.0481, 0.6701, -0.4579, 0.2352, -0.1056, 0.5551, 0.4357, 0.8168, 0.6344, -0.6445, -0.1965, 0.5587, 0.3860, -0.2466, -0.1542, 0.6825, 0.5875, -0.5208, 0.1500, -0.3980, 0.2157, 0.8368, -0.1356, -0.3387, 0.1747, 0.1467, 0.2282, -0.1412, 0.6216, -1.8406, 0.0150, 0.2891, 0.0280, 0.0461, 0.8558, 0.2929, -1.3753, -0.5792, 0.2089, -0.3524, -0.1849, -0.0157, 0.4454, -0.5306, 0.8238, -0.3160, 0.3760, 0.8978, -0.1943, -0.9474, -1.7321, -0.0149, 0.2338, 0.6087, -0.4851, 0.5210, -0.4042, -0.5368, -0.6220, 0.1245, 0.3112, 0.6360, -0.1522, 0.0540, -0.2380, -0.8354, 1.7591, 0.5687, 0.1732, 0.7923, -0.5383, -0.3271, -2.0050, -0.5563, 0.2979, 1.6609, 0.7108, -1.0155, 0.3591, 0.0136, -0.4743, -0.5401, -0.0176, 1.3333, -0.2973, -0.1114, -0.1616, 0.1160, 0.1152, 0.0057, 0.2067, 0.3876, -1.5311, 0.0636, 0.4566, -0.2653, 1.0534, -0.4638, 0.2166, 0.8686, -0.1447, 0.5605, -0.3841, 0.7015, 0.0418, 0.0811, -0.6406, -0.2929, -0.6821, 1.3678, 0.7574, 0.8315, 2.0377, 4.9034, -0.0097, 0.0165, 0.3248, 0.2994, 0.0210, 0.2276, -0.6580, -0.6899, 0.1981, -2.3205, 0.0059, -0.9412, -0.3191, 0.0389, -0.4170, 0.3391, -0.1346, 0.1567, 0.1838, -0.4176, -0.2758, 0.1495, -0.2977, 0.0929, 0.7186, 0.1230, 0.8780, -0.1240, -0.7370, -0.7551, 0.3830, 1.0824, 1.4500, -0.1040, 1.4225, 0.0929, 0.4612, 0.5167, -0.7093, -0.4729, 0.2321, 0.4156, -0.0696, -0.0626, 1.3341, -0.2398, 0.8453, 0.4048, 0.1690, 0.0074, -0.0474, 0.4134, 0.2043, -0.5962, 0.1643, -0.3821, 0.3012, -0.5690, 0.0133, 0.1876, -0.0727, 0.2896, 0.3253, 0.0313, 0.5141, -0.0055, -1.2889, -0.0983, -0.3212, -0.4173, -0.0804, 0.2591, -0.4160, -0.4815, 2.2822, -1.0033, -0.9814, 0.5290, 1.7943, -0.4217, -0.0373, -3.3970, 3.3067, 0.1174, -0.1369, 0.3847, -0.6960, -0.8867, -0.3825, -0.0134, -0.4367, -1.0273, -0.0623, 0.1520, 0.3816, -0.6543, -0.0118, -0.3019, -0.1190, 1.0490, 0.6255, 0.8503, 0.9500, -1.1942, 1.6886, -1.3958, 0.9389, 0.2318, -0.0460, 0.1140, -0.2352, -0.5648, 0.0363, -0.5636, 0.0661, -0.8680, -0.1223, -6.5336, 0.2139, -0.2734, 1.1739, 0.6003, 0.2183, 0.2154, -0.5902, -0.2916, -0.2748, 0.0787, 0.9065, -0.9764, -0.2278, 1.6248, 0.7941, -0.5014, 0.2422, -2.1474, 0.7818, 0.4370, 0.1361, -0.3936, -0.7724, 0.0941, -0.5762, 3.2182, -0.1101, 0.2677, -0.0101, -1.1798, -0.0122, -0.8163, 0.1115, -0.1697, -0.1466, -0.3549, 0.5360, -0.5183, 0.7519, 0.7093, -0.5946, 0.2787, 0.4822, -0.2680, 0.0934, 0.1483, 0.6706, -0.1150, -0.1945, -2.6643, 0.2194, -0.5014, -0.5869, 0.1022, 0.1988, -0.2558, 0.3732, -0.0644, 0.6440, -0.7403, -1.0228, 0.8158, 0.9543, -0.1226, -0.0929, 0.2716, 0.7962, -0.5293, 0.1538, -1.2074, -0.5093, 0.2037, 0.2156, -0.4407, 2.6976, -0.3653, 0.0458, -0.0899, -0.7584, 1.8329, -0.5082, -0.4776, -0.0265, -2.9437, -0.1675, 1.2358, 0.1571, -0.5022, -0.6370, 0.4087, -0.9664, 0.3533, 0.0928, -0.5308, 0.4462, 0.2476, 0.0976, 
-1.8347, 0.0468, -0.9309, -0.3712, -0.8578, -0.0568, 1.7377, -0.1299, -0.7187, 0.9764, 0.6858, 0.4272, -0.9588, 0.1038, 0.2520, -1.3775, 0.1491, -0.8507, 0.7052, 0.6483, 0.2818, -0.3305, -0.5913, -0.0907, -0.2438, -0.1932, -0.0564, -0.0777, -0.0748, 0.6530, 0.2393, 0.4476, 0.3941, -1.7061, 0.8876, 1.1888, 0.1423, 0.1737, 0.1330, 0.1115, 0.1525, -0.3715, 0.4657, -0.4010, -0.3089, 2.0455, -0.9555, 0.5093, 0.1502, -0.0865, -0.7851, -0.5175, 0.1613, 0.8113, 1.1943, 0.0612, 1.7087, -1.1616, -0.3204, 0.4428, 0.6120, -0.2282, 0.0174, -0.3141, -0.0045, 0.2204, 0.3966, 4.1174, -0.1531, 0.4325, -0.0245, -0.0310, 0.6541, 0.2904, 1.9309, -0.5405, 0.8576, 1.0352, -0.3592, -0.1056, -0.0047, 0.7218, 0.2350, 1.8817, 0.7558, -0.1575, -0.0544, 0.0234, 0.5841, 0.0996, -0.0503, 1.4150, 0.2260, 0.9152, 0.0688, 0.5286, 0.5885, 0.4606, -0.9186, 0.0441, 0.5233, 0.5305, -0.9086, 0.3728, 0.6752, 0.5453, -1.1360, 0.0613, -0.2365, 0.8856, -0.0512, -0.2589, -0.7055, -0.8111, 0.1787, 1.0393, -0.2469, -0.0922, 1.1790, -0.3284, 0.0402, 0.0746, -0.1033, -0.7248, -1.3859, -1.0511, 0.2797, 0.2777, -0.0877, 0.0271, 0.0740, -1.5863, -0.7014, 0.3677, -1.6786, -1.0769, 0.5594, 0.2428, -0.2664, 0.3454, -0.0490, -3.3762, 0.2004, 0.1913, -0.6461, 0.7643, -0.1239, 1.6487, 0.4942, -0.3305, -0.5069, -0.2183, 1.1533, -0.4380, 0.0219, -0.6319, 0.6743, 1.0648, 0.0587, -0.0989, -0.0995, 0.3757, 0.1813, 0.2854, 0.4345, -2.2154, 0.3601, -0.6406, -0.1099, 0.3583, -0.3726, 0.2892, 0.5897, 3.4282, -2.8781, 0.8985, 0.1550, 0.1102, 0.8008, -0.0811, -0.4199, 0.3145, -0.3236, -0.2425, -0.4502, 0.2431, 0.8504, 0.4597, 0.6396, 0.0902, 1.3885, 0.1297, -1.1721, -0.3227, -0.4472, 0.2575, 1.6201, -0.5444, 0.8665, -0.9622, 0.0035, -0.5908, 1.6270, 0.0351, -0.3419, 0.0039, 1.1001, -0.3767, -0.2270, 1.3332, 0.3555, 0.0667, -0.5392, -1.3500, -0.0842, 0.2591, -2.8862, 0.3166, 2.3757, 1.1254, -0.5208, -0.7074, -0.8110, 0.3715, 1.3720, -0.7236, -0.0665, 0.2772, -0.2840, -0.3515, -0.4777, 0.3030, 0.5417, 0.7752, -0.0182, 1.1569, -0.1614, 1.6521, -2.2844, -0.9332, -0.1472, 0.6151, -0.5020, -0.0719, 0.3361, -0.2722, -0.1500, 0.5092, -0.0348, -0.6530, -0.4159, -0.6603, -3.6738, 0.1421, -1.1267, 0.4267, 0.0699, 1.6415, 0.1451, 1.3309, 0.7792, 2.1801, -0.0886, 0.4233, 0.2828, 1.3708, -1.2021, -0.2627, -2.1505, 0.7701, -0.0167, -0.0247, 0.4665, -1.5951, -0.9997, -0.1568, -0.1108, 0.1543, -1.0055, 0.0001, -1.0355, 0.8421, -0.0485, -0.3064, 1.2358, -0.0448, 0.4038, -0.7671, 0.3624, -0.6197, 0.7966, -0.2266, 0.1130, -0.5302, -1.5468, 0.0700, -1.1711, -0.3307, 0.0086, -0.0416, 1.2763, -0.0574, 0.0121, -2.6334, -0.3180, -0.1954, 0.3944, 0.0076, 1.2025, -0.5634, 0.9271, 0.4198, 0.3251, -0.0041, 0.5236, -0.5314, 0.0639, -0.8840, -0.2680, 0.4958, 0.7804, 0.2942, -0.1935, -0.1405, -0.5670, 0.9489, 0.5726, -0.2529, 1.8878, -0.7204, -0.0050, -0.2448, 0.1725, 0.4253, 0.0058, 1.0247, -0.2908, -0.3978, -0.0963, 0.2107, 1.3576, 0.3074, 0.5527, -0.0927, 0.1521, 0.6300, -0.1377, -0.0497, 0.0425, -0.2248, -0.1534, 0.5778, 0.0033, 0.1789, 2.4935, 1.3225, -0.8038, -0.8864, 0.1176, -1.0532, -0.2375, 1.4582, -0.1168, 0.0548, 0.4221, -0.3585, 0.4043, -0.4371, 1.3289, -0.3674, -0.4286, -0.1730, 0.0535, 0.1441, 0.2703, 0.3826, 0.5123, -0.0401, -0.1230, -0.3143, -1.7583, 0.2582, 0.3484, 0.5722, 0.8621, 0.4420, 0.4442, -0.2445, 0.0532, -0.8102, -1.4058, -0.6382, -0.5799, -0.2456, -0.0906, -0.3191, -0.3395, -0.4364, -0.5810, -0.7970, 0.0831, -1.1570, -0.2573, -0.0644, -0.7106, 0.1313, 0.1944, -0.2329, 0.1409, -1.2096, 1.0822, 0.5523, 0.2151, -0.1106, -0.1034, -0.4873, 0.6932, 1.0196, -0.0521, 
0.0569, -0.8759, 1.0084, 0.6800, 1.0768, -1.2878, -0.1161, 0.0447, 0.1888, -0.2371, -1.0470, -0.4027, -1.4363, 0.1606, -0.8026, -0.0244, -0.2893, -0.4938, -0.6921, 1.0140, -0.4158, -0.5957, 0.3313, -0.2462, 0.7703, 0.3403, -1.5113, -0.1231, -0.3776, 0.3326, 0.1634, 2.1520, 0.7302, -0.0300, -2.8234, 0.4553, -0.4652, -0.3331, -1.0286, 1.2882, -0.2797, -0.4759, 0.1470, -1.0253, -0.8175, 0.6936, -0.3728, -0.4594, 1.0876, 0.6229, -0.0461, -0.4342, -0.1686, -1.3960, 0.5283, 0.4002, 0.8179, 0.4787, -0.7147, -0.5052, -0.2552, 0.2817, -0.4022, -0.5289, 0.0815, -0.4814, -0.5451, -0.1384, -0.4303, -0.4506, -1.9036, 0.6884, 0.1361, 0.2678, -0.0052, 0.0119, -0.1882, 1.0507, 3.1094, -0.5746, 1.3087, -0.1831, -0.1917, 0.0633, 0.5083, -0.1448, -0.0134, 0.5002, 0.2579, 0.7755, 0.1579, 0.4157, -0.2610, -0.4953, 0.1709, 0.4063, 0.2068, 0.2666, -0.7872, 0.5325, 0.4910, -0.1599, 0.4387, 0.9262, 0.9245, 0.5763, 0.9292, -0.4531, -0.5367, -0.4911, 0.2302, -0.4182, 0.7188, 0.0342, -0.2079, 0.1310, 0.5718, -0.0331, 0.1861, -0.1287, -0.0427, 0.8478, 0.7278, -0.5664, -0.5335, 1.3976, 0.1697, 0.6063, -0.0220, 0.4921, -0.1349, -0.0531, -0.2408, -0.3858, -0.2741, 0.2285, -0.5532, 0.2704, -0.2687, -0.2161, -0.1179, -1.5228, -0.3683, -1.3004, 0.2431, -0.3305, 1.6118, -0.0328, 1.1503, 0.5712, -0.0423, 3.4830, -0.2760, 0.6307, -0.0419, 0.1553, -0.5602, 0.2106, -0.2213, -0.4543, 0.3034, 0.9189, 1.5738, 0.5071, 0.2238, -2.2069, 0.4104, 0.6224, 0.2836, -0.1620, -0.3043, -0.4012, 0.2410, -0.6261, -0.2435, 0.0211, -0.2227, -0.2392, -0.3634, -0.9207, 0.2260, 0.0929, 0.8206, -0.3214, -0.2296, 0.1274, -0.8615, 0.2329, 1.1085, -1.0565, 0.2258], + "std": [0.9963, 0.6391, 0.4956, 0.6280, 0.7591, 0.5610, 0.8236, 0.7139, 0.7494, 0.5686, 0.5042, 0.3464, 0.4228, 0.4171, 0.3526, 0.3710, 0.6288, 0.4674, 0.4413, 0.4741, 0.6553, 0.4882, 0.3697, 0.5507, 0.4961, 0.3683, 0.5604, 0.5302, 0.6027, 0.3023, 0.4882, 0.5746, 0.5314, 0.5031, 0.6145, 0.5994, 0.4285, 0.6399, 0.5362, 0.5403, 0.4677, 1.2902, 0.6126, 0.4145, 0.5068, 0.4667, 0.4825, 0.4275, 0.4381, 0.6758, 0.4866, 0.4136, 0.5262, 0.5698, 0.6550, 0.6492, 0.3450, 0.5948, 0.4219, 0.4973, 0.4483, 0.4336, 0.7440, 0.4595, 0.4366, 0.3634, 0.4430, 0.6587, 0.5073, 0.3533, 0.7036, 0.7039, 0.5312, 0.4701, 0.7512, 0.4102, 0.4227, 0.4488, 0.4158, 0.4676, 0.4521, 0.4560, 0.3917, 0.4757, 0.4348, 0.6013, 0.6715, 0.5179, 0.4834, 0.7451, 0.4845, 0.8893, 0.4188, 0.5963, 0.4306, 0.4551, 0.6417, 0.2886, 0.5378, 0.4316, 0.7568, 0.4818, 0.5494, 0.4736, 0.5841, 0.5043, 0.4265, 0.6994, 0.7652, 0.4344, 0.3931, 0.7198, 0.4169, 0.5794, 0.6720, 0.5694, 0.8603, 0.5307, 0.5893, 0.5763, 0.5292, 0.5228, 0.4156, 0.4901, 0.8334, 0.4574, 0.7241, 0.5346, 0.4063, 0.4147, 0.4979, 0.6599, 0.4173, 0.3715, 0.3828, 0.4492, 0.5576, 0.4060, 0.4353, 0.5315, 0.9834, 0.5548, 0.5679, 0.3506, 0.5419, 0.4256, 0.4187, 0.3570, 0.6316, 0.5870, 0.4832, 0.4862, 0.6072, 0.6781, 0.6152, 0.6708, 0.5008, 0.4435, 0.4229, 0.4973, 0.4301, 0.5363, 0.5478, 0.5388, 0.3952, 0.5961, 0.4721, 0.6389, 0.4450, 0.4841, 0.5594, 0.5234, 0.5224, 0.6326, 0.4469, 0.7397, 0.9551, 0.8426, 0.7576, 0.3893, 0.4382, 0.5222, 0.5234, 0.6035, 0.5764, 0.4043, 0.4741, 0.5471, 0.4229, 0.5962, 0.7127, 0.6205, 0.5671, 0.3766, 0.7455, 0.7315, 0.5891, 0.5372, 0.5957, 0.5342, 0.4010, 0.4453, 0.4609, 0.8789, 0.4353, 0.6297, 0.4126, 0.4149, 0.4597, 0.4859, 0.6733, 0.6096, 0.5719, 0.4494, 0.6353, 0.7537, 0.4643, 0.4577, 0.6485, 0.6069, 0.3603, 0.5821, 0.4807, 0.5192, 0.5329, 0.4153, 0.7329, 0.5444, 0.5742, 0.4593, 0.4003, 0.6770, 0.5428, 0.6781, 0.7920, 0.5037, 0.7615, 0.4537, 0.5931, 
0.7333, 0.4880, 0.5469, 0.4698, 0.4917, 0.6256, 0.4947, 0.3974, 0.7559, 0.5916, 0.6547, 0.7502, 0.4682, 0.4517, 0.4888, 0.6472, 0.4755, 0.3927, 0.5845, 0.4135, 0.4091, 0.5860, 0.4544, 0.4051, 0.5547, 0.5322, 0.7200, 0.4595, 0.5484, 0.4758, 0.5259, 0.4137, 0.7149, 0.5638, 0.6221, 0.9309, 0.5637, 0.5657, 0.5711, 0.5651, 1.0484, 0.4435, 0.4587, 0.3716, 0.4108, 0.8114, 0.5531, 1.0675, 0.5825, 0.3841, 0.4500, 0.7335, 0.4767, 0.4162, 0.5679, 0.4880, 0.4614, 0.5118, 0.5198, 0.5619, 0.6869, 0.3536, 0.5128, 0.4722, 0.3722, 0.7705, 0.4556, 0.5365, 0.4999, 0.3254, 0.5268, 0.7580, 0.5932, 0.9908, 0.6171, 0.4912, 0.4439, 0.9135, 0.4658, 0.6566, 0.5500, 0.5423, 0.4725, 0.5415, 0.5550, 0.7519, 0.4220, 0.6024, 0.4821, 0.5268, 0.4583, 0.7421, 0.5200, 0.4541, 0.5197, 0.4562, 0.8381, 0.4423, 0.7400, 0.6578, 0.6459, 0.5316, 0.6877, 0.5362, 0.4215, 0.6455, 0.4363, 0.6716, 0.5795, 0.5587, 0.5234, 0.4456, 0.4991, 0.4244, 0.8959, 0.4744, 0.4440, 0.4437, 0.8485, 0.4237, 0.6907, 0.5582, 0.4315, 0.5458, 0.7341, 0.4731, 0.5065, 0.6181, 0.5643, 0.4407, 0.4353, 0.4732, 0.3769, 0.4162, 0.5028, 0.3689, 0.6656, 0.4598, 0.3735, 0.6801, 0.7902, 0.7101, 0.6292, 0.5732, 0.7452, 0.6803, 0.5065, 0.5261, 0.4644, 0.5021, 0.6714, 0.5226, 0.4455, 0.7599, 0.4380, 0.5468, 0.4595, 0.5308, 0.8445, 0.4413, 0.5196, 0.9241, 0.5414, 0.5018, 0.3832, 0.4950, 0.3185, 0.5330, 0.4844, 0.4481, 0.4517, 0.5104, 0.6092, 0.5712, 0.4164, 0.6590, 0.4888, 0.3930, 0.5419, 0.5486, 0.5165, 0.4390, 0.5542, 0.3883, 0.4074, 0.6213, 0.6185, 0.7711, 0.6565, 0.4925, 1.0624, 0.4690, 0.7498, 0.5333, 0.5290, 0.6258, 0.4473, 0.3862, 0.6571, 0.4873, 0.5240, 0.4127, 0.4445, 0.5094, 0.4754, 0.5769, 0.4786, 0.4510, 0.5130, 0.4897, 0.7568, 0.7398, 0.5718, 0.4229, 0.4929, 0.7470, 0.5901, 0.3772, 0.4914, 0.4074, 0.9471, 0.4967, 0.4323, 0.5259, 0.3591, 0.7202, 0.6012, 0.4573, 0.4296, 0.5578, 0.5218, 0.4640, 0.4522, 0.4029, 0.8071, 0.6086, 0.4832, 0.4202, 0.6781, 0.3862, 0.3920, 0.7543, 1.0257, 0.8849, 0.4181, 0.4722, 0.4069, 0.4854, 0.5405, 0.4676, 0.5547, 0.6282, 0.4275, 0.8011, 0.3308, 0.7135, 0.4315, 0.4915, 0.6616, 0.7376, 0.5742, 0.7461, 0.5443, 0.4749, 0.4906, 1.0020, 0.6306, 0.4435, 0.4559, 0.4360, 0.4047, 0.5802, 0.6109, 0.7836, 0.5163, 0.9777, 0.9272, 0.4618, 0.3534, 0.5218, 0.4479, 0.6498, 0.7145, 0.6224, 0.5671, 0.5042, 0.3885, 0.4079, 0.4481, 0.5406, 0.6944, 0.3744, 0.5942, 0.6770, 0.5934, 0.7417, 0.5662, 0.4753, 0.5063, 0.5003, 0.4510, 0.4358, 0.6455, 0.7740, 0.4780, 0.8687, 0.5533, 0.5700, 0.3518, 0.4868, 0.4154, 0.4798, 0.3266, 0.3536, 0.3789, 0.3805, 0.7909, 0.5760, 0.5784, 0.4993, 0.5787, 0.5324, 0.4496, 0.8483, 1.0794, 0.4820, 0.4135, 0.6231, 0.4668, 0.6684, 0.7052, 0.7616, 0.4881, 0.4150, 0.5793, 0.8068, 0.7793, 0.4721, 0.5230, 0.4810, 0.9577, 0.5537, 0.5583, 0.6645, 0.4334, 0.6398, 0.5011, 0.4081, 0.6255, 0.5372, 0.4846, 0.6125, 0.6509, 0.4413, 0.4762, 0.4917, 0.5940, 0.4950, 0.6753, 0.6653, 0.5210, 0.5599, 0.4678, 0.4868, 0.5985, 0.4160, 0.4874, 0.8380, 0.5382, 0.5701, 0.5448, 0.6131, 0.5674, 0.7120, 0.4070, 0.4434, 0.5725, 0.4919, 0.4805, 0.5997, 0.7108, 0.9824, 0.4765, 0.7575, 0.4452, 0.8892, 0.4639, 0.4962, 1.0346, 0.7584, 0.4312, 0.4835, 0.8968, 0.4799, 0.6864, 0.5641, 1.0694, 0.6750, 0.4288, 0.5159, 0.3649, 0.7699, 0.4386, 0.4449, 0.3923, 0.6499, 0.5612, 0.4541, 0.6261, 0.5444, 0.4369, 0.4124, 0.4174, 0.6129, 0.5005, 0.4779, 0.3929, 0.4865, 0.4338, 0.4114, 0.6266, 0.3669, 1.0147, 0.4856, 0.4867, 0.6250, 0.5368, 0.6699, 0.6411, 0.5296, 0.7614, 0.5643, 0.5843, 0.6846, 0.3923, 0.3928, 0.4964, 0.4490, 0.4755, 0.4104, 0.5468, 0.6040, 0.5808, 0.6283, 
0.4316, 0.6127, 0.4635, 0.5303, 0.4261, 0.4668, 0.6121, 0.4063, 0.5571, 0.6130, 0.5874, 0.4987, 0.4113, 0.5401, 0.4028, 0.6598, 0.7740, 0.5384, 0.7890, 0.9379, 0.8801, 0.6222, 0.5356, 0.3990, 0.4802, 0.4107, 0.5475, 0.6936, 0.6865, 0.4776, 0.5211, 0.8844, 0.6517, 1.0729, 0.9252, 0.6953, 0.4177, 0.7587, 0.6628, 0.3629, 0.3685, 0.3758, 0.4439, 0.5236, 0.4905, 0.5290, 0.4184, 0.3940, 0.6498, 0.5411, 0.5662, 0.5519, 0.6107, 0.6385, 0.4127, 0.6277, 0.5255, 0.5926, 0.5653, 0.6570, 0.6034, 0.5312, 0.4128, 0.7292, 0.3620, 0.5067, 0.5314, 0.3908, 0.7561, 0.4494, 0.4501, 0.7682, 0.4939, 0.4198, 0.5256, 0.6339, 0.5123, 0.9018, 0.5054, 0.4879, 0.4567, 0.4145, 0.6046, 0.3835, 0.4289, 0.5254, 0.6191, 0.6610, 0.5933, 0.7890, 0.7817, 0.6299, 0.5977, 0.7094, 0.3737, 1.0318, 0.7045, 0.7785, 0.5376, 0.5861, 0.4233, 0.5538, 1.0604, 0.5690, 0.5249, 0.3747, 0.6036, 0.4707, 0.3617, 0.3665, 0.6184, 0.4878, 0.6193, 0.5311, 0.6187, 0.6748, 0.4493, 0.6137, 0.4601, 0.3855, 0.4183, 0.4986, 0.4832, 0.4192, 0.4416, 0.7202, 0.3724, 0.4899, 0.6939, 0.4272, 0.7122, 0.6950, 0.5565, 0.4417, 0.6186, 0.4753, 0.5919, 0.3763, 0.5643, 0.5347, 0.5454, 0.9336, 0.6594, 0.9747, 0.4970, 0.4725, 0.7820, 0.4113, 0.4942, 0.6699, 0.4159, 0.6766, 0.6564, 0.3947, 0.5381, 0.3874, 0.6686, 0.5628, 0.3904, 0.6647, 0.9821, 0.4343, 0.5455, 0.4879, 0.8165, 0.4153, 0.5544, 0.5179, 0.3821, 0.6678, 0.7883, 0.3372, 0.4702, 0.5044, 0.4584, 0.4769, 0.3787, 0.4377, 0.5435, 0.5899, 0.5378, 0.5986, 0.4887, 0.5390, 0.5464, 0.6330, 0.5010, 0.4244, 0.5249, 0.6770, 0.6314, 0.6404, 0.4605, 0.3649, 0.6489, 1.0657, 0.5497, 0.5357, 0.3651, 0.8484, 0.8126, 0.4873, 0.6711, 0.4401, 0.6181, 0.8585, 0.6000, 0.5654, 0.5416, 0.3504, 0.4671, 0.5499, 0.4409, 0.7650, 0.4980, 0.9734, 0.3568, 0.6037, 0.4361, 0.7880, 0.4726, 0.9902, 0.5020, 0.5178, 0.5065, 0.4543, 0.9039, 0.8296, 0.4451, 0.4436, 0.8518, 0.5201, 0.8668, 0.5122, 0.3412, 0.5849, 0.4815, 0.5795, 0.5664, 0.4384, 0.4593, 0.7974, 0.6570, 0.6522, 0.5490, 0.4195, 0.6821, 0.6133, 0.5692, 0.4780, 0.4574, 0.5090, 0.4488, 0.4269, 0.4153, 0.5143, 0.6560, 0.4480, 0.5482, 0.6997, 0.4377, 0.4166, 0.6103, 0.4671, 0.4449, 0.5672, 0.3296, 0.3898, 0.3778, 0.6572, 0.5555, 0.4047, 0.3720, 0.5728, 0.6867, 0.5435, 0.5001, 0.6808, 0.6373, 0.6849, 0.4826, 0.4767, 0.3736, 0.5070, 0.4442, 0.4302, 0.4339, 0.4614, 0.4735, 0.7977, 0.5657, 0.4047, 0.5261, 0.6204, 0.3413, 0.3996, 0.4236, 0.3303, 0.4193, 0.6074, 0.3941, 0.4802, 0.4114, 0.3880, 0.3460, 0.3767, 0.6491, 0.6893, 0.8560, 0.4244, 0.4307, 0.5702, 0.3635, 0.5170, 0.3975, 0.6187, 0.5012, 0.4976, 0.7149, 0.7001, 0.4834, 0.3844, 0.5179, 0.6909, 0.5862, 1.0062, 0.5099, 0.6410, 0.7432, 0.4219, 0.4655, 0.6067, 0.6674, 0.4618, 0.7115, 0.5300, 0.5284, 0.4208, 0.4955, 0.4561, 0.3723], + } +} + +cam_angvel = { + "emdb_none_test": { + "count": 42622, + "mean": [1., 0., 0., 0., 1., 0.], + "std": [5.5702e-05, 3.2200e-03, 5.6530e-03, 3.2191e-03, 2.4738e-05, 3.3406e-03], + }, + "manual": { + "mean": [1., 0., 0., 0., 1., 0.], + "std": [0.001, 0.1, 0.1, 0.1, 0.001, 0.1], # manually + } +} +# fmt:on + +# ====== Compose ====== # + + +def compose(targets, sources): + if len(sources) == 1: + sources = sources * len(targets) + mean = [] + std = [] + for t, s in zip(targets, sources): + mean.extend(t[s]["mean"]) + std.extend(t[s]["std"]) + return {"mean": mean, "std": std} + + +DEFAULT_01 = {"mean": [0.0], "std": [1.0]} + +MM_V1 = compose( + [body_pose_r6d, betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel], + ["bedlam"] * 5, +) +MM_V1_AMASS_LOCAL_BEDLAM_CAM = compose( + [body_pose_r6d, 
betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel], + ["amass", "amass", "bedlam", "bedlam", "amass"], +) + +MM_V2 = compose( + [body_pose_r6d, betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel], + ["bedlam", "bedlam", "bedlam", "bedlam", "none"], +) + +MM_V2_1 = compose( + [body_pose_r6d, betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel], + ["bedlam", "bedlam", "bedlam", "bedlam", "1e-2"], +) diff --git a/hmr4d/network/base_arch/embeddings/rotary_embedding.py b/hmr4d/network/base_arch/embeddings/rotary_embedding.py new file mode 100644 index 0000000..78c8f20 --- /dev/null +++ b/hmr4d/network/base_arch/embeddings/rotary_embedding.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from einops import repeat, rearrange +from torch.cuda.amp import autocast + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +@autocast(enabled=False) +def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:].to(t) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim=-1) + + +def get_encoding(d_model, max_seq_len=4096): + """Return: (L, D)""" + t = torch.arange(max_seq_len).float() + freqs = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model)) + freqs = torch.einsum("i, j -> i j", t, freqs) + freqs = repeat(freqs, "i j -> i (j r)", r=2) + return freqs + + +class ROPE(nn.Module): + """Minimal impl of a lang-style positional encoding.""" + + def __init__(self, d_model, max_seq_len=4096): + super().__init__() + self.d_model = d_model + self.max_seq_len = max_seq_len + + # Pre-cache a freqs tensor + encoding = get_encoding(d_model, max_seq_len) + self.register_buffer("encoding", encoding, False) + + def rotate_queries_or_keys(self, x): + """ + Args: + x : (B, H, L, D) + Returns: + rotated_x: (B, H, L, D) + """ + + seq_len, d_model = x.shape[-2:] + assert d_model == self.d_model + + # encoding: (L, D)s + if seq_len > self.max_seq_len: + encoding = get_encoding(d_model, seq_len).to(x) + else: + encoding = self.encoding[:seq_len] + + # encoding: (L, D) + # x: (B, H, L, D) + rotated_x = apply_rotary_emb(encoding, x, seq_dim=-2) + + return rotated_x diff --git a/hmr4d/network/base_arch/transformer/encoder_rope.py b/hmr4d/network/base_arch/transformer/encoder_rope.py new file mode 100644 index 0000000..7f16660 --- /dev/null +++ b/hmr4d/network/base_arch/transformer/encoder_rope.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from timm.models.vision_transformer import Mlp +from typing import Optional, Tuple +from einops import einsum, rearrange, repeat +from hmr4d.network.base_arch.embeddings.rotary_embedding import ROPE + + +class RoPEAttention(nn.Module): + def __init__(self, embed_dim, num_heads, dropout=0.1): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.rope = ROPE(self.head_dim, max_seq_len=4096) + + self.query = 
nn.Linear(embed_dim, embed_dim) + self.key = nn.Linear(embed_dim, embed_dim) + self.value = nn.Linear(embed_dim, embed_dim) + self.dropout = nn.Dropout(dropout) + self.proj = nn.Linear(embed_dim, embed_dim) + + def forward(self, x, attn_mask=None, key_padding_mask=None): + # x: (B, L, C) + # attn_mask: (L, L) + # key_padding_mask: (B, L) + B, L, _ = x.shape + xq, xk, xv = self.query(x), self.key(x), self.value(x) + + xq = xq.reshape(B, L, self.num_heads, -1).transpose(1, 2) + xk = xk.reshape(B, L, self.num_heads, -1).transpose(1, 2) + xv = xv.reshape(B, L, self.num_heads, -1).transpose(1, 2) + + xq = self.rope.rotate_queries_or_keys(xq) # B, N, L, C + xk = self.rope.rotate_queries_or_keys(xk) # B, N, L, C + + attn_score = einsum(xq, xk, "b n i c, b n j c -> b n i j") / math.sqrt(self.head_dim) + if attn_mask is not None: + attn_mask = attn_mask.reshape(1, 1, L, L).expand(B, self.num_heads, -1, -1) + attn_score = attn_score.masked_fill(attn_mask, float("-inf")) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.reshape(B, 1, 1, L).expand(-1, self.num_heads, L, -1) + attn_score = attn_score.masked_fill(key_padding_mask, float("-inf")) + + attn_score = torch.softmax(attn_score, dim=-1) + attn_score = self.dropout(attn_score) + output = einsum(attn_score, xv, "b n i j, b n j c -> b n i c") # B, N, L, C + output = output.transpose(1, 2).reshape(B, L, -1) # B, L, C + output = self.proj(output) # B, L, C + return output + + +class EncoderRoPEBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, dropout=0.1, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = RoPEAttention(hidden_size, num_heads, dropout) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=dropout) + + self.gate_msa = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.gate_mlp = nn.Parameter(torch.zeros(1, 1, hidden_size)) + + # Zero-out adaLN modulation layers + nn.init.constant_(self.gate_msa, 0) + nn.init.constant_(self.gate_mlp, 0) + + def forward(self, x, attn_mask=None, tgt_key_padding_mask=None): + x = x + self.gate_msa * self._sa_block( + self.norm1(x), attn_mask=attn_mask, key_padding_mask=tgt_key_padding_mask + ) + x = x + self.gate_mlp * self.mlp(self.norm2(x)) + return x + + def _sa_block(self, x, attn_mask=None, key_padding_mask=None): + # x: (B, L, C) + x = self.attn(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + return x diff --git a/hmr4d/network/base_arch/transformer/layer.py b/hmr4d/network/base_arch/transformer/layer.py new file mode 100644 index 0000000..1d4e75c --- /dev/null +++ b/hmr4d/network/base_arch/transformer/layer.py @@ -0,0 +1,12 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
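+    In this repo it is applied to the last nn.Linear of the condition embedders
+    (see relative_transformer.py), so each condition initially contributes a zero residual.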
+ """ + for p in module.parameters(): + p.detach().zero_() + return module diff --git a/hmr4d/network/gvhmr/relative_transformer.py b/hmr4d/network/gvhmr/relative_transformer.py new file mode 100644 index 0000000..e21df80 --- /dev/null +++ b/hmr4d/network/gvhmr/relative_transformer.py @@ -0,0 +1,194 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import einsum, rearrange, repeat +from hmr4d.configs import MainStore, builds + +from hmr4d.network.base_arch.transformer.encoder_rope import EncoderRoPEBlock +from hmr4d.network.base_arch.transformer.layer import zero_module + +from hmr4d.utils.net_utils import length_to_mask +from timm.models.vision_transformer import Mlp + + +class NetworkEncoderRoPE(nn.Module): + def __init__( + self, + # x + output_dim=151, + max_len=120, + # condition + cliffcam_dim=3, + cam_angvel_dim=6, + imgseq_dim=1024, + # intermediate + latent_dim=512, + num_layers=12, + num_heads=8, + mlp_ratio=4.0, + # output + pred_cam_dim=3, + static_conf_dim=6, + # training + dropout=0.1, + # other + avgbeta=True, + ): + super().__init__() + + # input + self.output_dim = output_dim + self.max_len = max_len + + # condition + self.cliffcam_dim = cliffcam_dim + self.cam_angvel_dim = cam_angvel_dim + self.imgseq_dim = imgseq_dim + + # intermediate + self.latent_dim = latent_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + + # ===== build model ===== # + # Input (Kp2d) + # Main token: map d_obs 2 to 32 + self.learned_pos_linear = nn.Linear(2, 32) + self.learned_pos_params = nn.Parameter(torch.randn(17, 32), requires_grad=True) + self.embed_noisyobs = Mlp( + 17 * 32, hidden_features=self.latent_dim * 2, out_features=self.latent_dim, drop=dropout + ) + + self._build_condition_embedder() + + # Transformer + self.blocks = nn.ModuleList( + [ + EncoderRoPEBlock(self.latent_dim, self.num_heads, mlp_ratio=mlp_ratio, dropout=dropout) + for _ in range(self.num_layers) + ] + ) + + # Output heads + self.final_layer = Mlp(self.latent_dim, out_features=self.output_dim) + self.pred_cam_head = pred_cam_dim > 0 # keep extra_output for easy-loading old ckpt + if self.pred_cam_head: + self.pred_cam_head = Mlp(self.latent_dim, out_features=pred_cam_dim) + self.register_buffer("pred_cam_mean", torch.tensor([1.0606, -0.0027, 0.2702]), False) + self.register_buffer("pred_cam_std", torch.tensor([0.1784, 0.0956, 0.0764]), False) + + self.static_conf_head = static_conf_dim > 0 + if self.static_conf_head: + self.static_conf_head = Mlp(self.latent_dim, out_features=static_conf_dim) + + self.avgbeta = avgbeta + + def _build_condition_embedder(self): + latent_dim = self.latent_dim + dropout = self.dropout + self.cliffcam_embedder = nn.Sequential( + nn.Linear(self.cliffcam_dim, latent_dim), + nn.SiLU(), + nn.Dropout(dropout), + zero_module(nn.Linear(latent_dim, latent_dim)), + ) + if self.cam_angvel_dim > 0: + self.cam_angvel_embedder = nn.Sequential( + nn.Linear(self.cam_angvel_dim, latent_dim), + nn.SiLU(), + nn.Dropout(dropout), + zero_module(nn.Linear(latent_dim, latent_dim)), + ) + if self.imgseq_dim > 0: + self.imgseq_embedder = nn.Sequential( + nn.LayerNorm(self.imgseq_dim), + zero_module(nn.Linear(self.imgseq_dim, latent_dim)), + ) + + def forward(self, length, obs=None, f_cliffcam=None, f_cam_angvel=None, f_imgseq=None): + """ + Args: + x: None we do not use it + timesteps: (B,) + length: (B), valid length of x, if None then use x.shape[2] + f_imgseq: (B, L, C) + f_cliffcam: (B, L, 3), CLIFF-Cam parameters (bbx-detection in the 
full-image) + f_noisyobs: (B, L, C), nosiy pose observation + f_cam_angvel: (B, L, 6), Camera angular velocity + """ + B, L, J, C = obs.shape + assert J == 17 and C == 3 + + # Main token from observation (2D pose) + obs = obs.clone() + visible_mask = obs[..., [2]] > 0.5 # (B, L, J, 1) + obs[~visible_mask[..., 0]] = 0 # set low-conf to all zeros + f_obs = self.learned_pos_linear(obs[..., :2]) # (B, L, J, 32) + f_obs = f_obs * visible_mask + self.learned_pos_params.repeat(B, L, 1, 1) * ~visible_mask + x = self.embed_noisyobs(f_obs.view(B, L, -1)) # (B, L, J*32) -> (B, L, C) + + # Condition + f_to_add = [] + f_to_add.append(self.cliffcam_embedder(f_cliffcam)) + if hasattr(self, "cam_angvel_embedder"): + f_to_add.append(self.cam_angvel_embedder(f_cam_angvel)) + if f_imgseq is not None and hasattr(self, "imgseq_embedder"): + f_to_add.append(self.imgseq_embedder(f_imgseq)) + + for f_delta in f_to_add: + x = x + f_delta + + # Setup length and make padding mask + assert B == length.size(0) + pmask = ~length_to_mask(length, L) # (B, L) + + if L > self.max_len: + attnmask = torch.ones((L, L), device=x.device, dtype=torch.bool) + for i in range(L): + min_ind = max(0, i - self.max_len // 2) + max_ind = min(L, i + self.max_len // 2) + max_ind = max(self.max_len, max_ind) + min_ind = min(L - self.max_len, min_ind) + attnmask[i, min_ind:max_ind] = False + else: + attnmask = None + + # Transformer + for block in self.blocks: + x = block(x, attn_mask=attnmask, tgt_key_padding_mask=pmask) + + # Output + sample = self.final_layer(x) # (B, L, C) + if self.avgbeta: + betas = (sample[..., 126:136] * (~pmask[..., None])).sum(1) / length[:, None] # (B, C) + betas = repeat(betas, "b c -> b l c", l=L) + sample = torch.cat([sample[..., :126], betas, sample[..., 136:]], dim=-1) + + # Output (extra) + pred_cam = None + if self.pred_cam_head: + pred_cam = self.pred_cam_head(x) + pred_cam = pred_cam * self.pred_cam_std + self.pred_cam_mean + torch.clamp_min_(pred_cam[..., 0], 0.25) # min_clamp s to 0.25 (prevent negative prediction) + + static_conf_logits = None + if self.static_conf_head: + static_conf_logits = self.static_conf_head(x) # (B, L, C') + + output = { + "pred_context": x, + "pred_x": sample, + "pred_cam": pred_cam, + "static_conf_logits": static_conf_logits, + } + return output + + +# Add to MainStore +group_name = "network/gvhmr" +MainStore.store( + name="relative_transformer", + node=builds(NetworkEncoderRoPE, populate_full_signature=True), + group=group_name, +) diff --git a/hmr4d/network/hmr2/__init__.py b/hmr4d/network/hmr2/__init__.py new file mode 100644 index 0000000..4be998b --- /dev/null +++ b/hmr4d/network/hmr2/__init__.py @@ -0,0 +1,32 @@ +import torch +from .hmr2 import HMR2 +from pathlib import Path +from .configs import get_config +from hmr4d import PROJ_ROOT + +HMR2A_CKPT = PROJ_ROOT / f"inputs/checkpoints/hmr2/epoch=10-step=25000.ckpt" # this is HMR2.0a, follow WHAM + + +def load_hmr2(checkpoint_path=HMR2A_CKPT): + model_cfg = str((Path(__file__).parent / "configs/model_config.yaml").resolve()) + model_cfg = get_config(model_cfg) + + # Override some config values, to crop bbox correctly + if (model_cfg.MODEL.BACKBONE.TYPE == "vit") and ("BBOX_SHAPE" not in model_cfg.MODEL): + model_cfg.defrost() + assert ( + model_cfg.MODEL.IMAGE_SIZE == 256 + ), f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone" + model_cfg.MODEL.BBOX_SHAPE = [192, 256] # (W, H) + model_cfg.freeze() + + # Setup model and Load weights. 
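+    # Only the "backbone.*" and "smpl_head.*" entries of the checkpoint are kept below;
+    # all other entries are discarded.
+    # Usage sketch (hypothetical `img_crops`: a normalized (B, 3, 256, 256) batch of person crops):
+    #   model = load_hmr2().eval()
+    #   with torch.no_grad():
+    #       tokens = model({"img": img_crops}, feat_mode=True)  # (B, 1024) feature tokens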
+ # model = HMR2.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg) + model = HMR2(model_cfg) + + state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"] + keys = [k for k in state_dict.keys() if k.split(".")[0] in ["backbone", "smpl_head"]] + state_dict = {k: v for k, v in state_dict.items() if k in keys} + model.load_state_dict(state_dict, strict=True) + + return model diff --git a/hmr4d/network/hmr2/components/__init__.py b/hmr4d/network/hmr2/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hmr4d/network/hmr2/components/pose_transformer.py b/hmr4d/network/hmr2/components/pose_transformer.py new file mode 100644 index 0000000..ac04971 --- /dev/null +++ b/hmr4d/network/hmr2/components/pose_transformer.py @@ -0,0 +1,358 @@ +from inspect import isfunction +from typing import Callable, Optional + +import torch +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn + +from .t_cond_mlp import ( + AdaptiveLayerNorm1D, + FrequencyEmbedder, + normalization_layer, +) +# from .vit import Attention, FeedForward + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1): + super().__init__() + self.norm = normalization_layer(norm, dim, norm_cond_dim) + self.fn = fn + + def forward(self, x: torch.Tensor, *args, **kwargs): + if isinstance(self.norm, AdaptiveLayerNorm1D): + return self.fn(self.norm(x, *args), **kwargs) + else: + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class CrossAttention(nn.Module): + def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + context_dim = default(context_dim, dim) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, context=None): 
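+        # x: (B, N_q, dim) queries; context: (B, N_kv, context_dim) keys/values.
+        # When context is None, this falls back to self-attention over x.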
+ context = default(context, x) + k, v = self.to_kv(context).chunk(2, dim=-1) + q = self.to_q(x) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v]) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args): + for attn, ff in self.layers: + x = attn(x, *args) + x + x = ff(x, *args) + x + return x + + +class TransformerCrossAttn(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + context_dim: Optional[int] = None, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ca = CrossAttention( + dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout + ) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args, context=None, context_list=None): + if context_list is None: + context_list = [context] * len(self.layers) + if len(context_list) != len(self.layers): + raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})") + + for i, (self_attn, cross_attn, ff) in enumerate(self.layers): + x = self_attn(x, *args) + x + x = cross_attn(x, *args, context=context_list[i]) + x + x = ff(x, *args) + x + return x + + +class DropTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool() + # TODO: permutation idx for each batch using torch.argsort + if zero_mask.any(): + x = x[:, ~zero_mask, :] + return x + + +class ZeroTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool() + # Zero-out the masked tokens + x[zero_mask, :] = 0 + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, 
+ heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = "drop", + emb_dropout_loc: str = "token", + norm: str = "layer", + norm_cond_dim: int = -1, + token_pe_numfreq: int = -1, + ): + super().__init__() + if token_pe_numfreq > 0: + token_dim_new = token_dim * (2 * token_pe_numfreq + 1) + self.to_token_embedding = nn.Sequential( + Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim), + FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1), + Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new), + nn.Linear(token_dim_new, dim), + ) + else: + self.to_token_embedding = nn.Linear(token_dim, dim) + self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + else: + raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}") + self.emb_dropout_loc = emb_dropout_loc + + self.transformer = Transformer( + dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim + ) + + def forward(self, inp: torch.Tensor, *args, **kwargs): + x = inp + + if self.emb_dropout_loc == "input": + x = self.dropout(x) + x = self.to_token_embedding(x) + + if self.emb_dropout_loc == "token": + x = self.dropout(x) + b, n, _ = x.shape + x += self.pos_embedding[:, :n] + + if self.emb_dropout_loc == "token_afterpos": + x = self.dropout(x) + x = self.transformer(x, *args) + return x + + +class TransformerDecoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = 'drop', + norm: str = "layer", + norm_cond_dim: int = -1, + context_dim: Optional[int] = None, + skip_token_embedding: bool = False, + ): + super().__init__() + if not skip_token_embedding: + self.to_token_embedding = nn.Linear(token_dim, dim) + else: + self.to_token_embedding = nn.Identity() + if token_dim != dim: + raise ValueError( + f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True" + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + elif emb_dropout_type == "normal": + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = TransformerCrossAttn( + dim, + depth, + heads, + dim_head, + mlp_dim, + dropout, + norm=norm, + norm_cond_dim=norm_cond_dim, + context_dim=context_dim, + ) + + def forward(self, inp: torch.Tensor, *args, context=None, context_list=None): + x = self.to_token_embedding(inp) + b, n, _ = x.shape + + x = self.dropout(x) + x += self.pos_embedding[:, :n] + + x = self.transformer(x, *args, context=context, context_list=context_list) + return x + diff --git a/hmr4d/network/hmr2/components/t_cond_mlp.py b/hmr4d/network/hmr2/components/t_cond_mlp.py new file mode 100644 index 0000000..44d5a09 --- /dev/null +++ b/hmr4d/network/hmr2/components/t_cond_mlp.py @@ -0,0 +1,199 @@ +import copy +from typing import List, Optional + +import torch + + +class AdaptiveLayerNorm1D(torch.nn.Module): + def __init__(self, data_dim: int, norm_cond_dim: int): + super().__init__() + if data_dim <= 0: + raise ValueError(f"data_dim must be positive, but got {data_dim}") + if norm_cond_dim <= 0: + raise 
ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}") + self.norm = torch.nn.LayerNorm( + data_dim + ) # TODO: Check if elementwise_affine=True is correct + self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim) + torch.nn.init.zeros_(self.linear.weight) + torch.nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # x: (batch, ..., data_dim) + # t: (batch, norm_cond_dim) + # return: (batch, data_dim) + x = self.norm(x) + alpha, beta = self.linear(t).chunk(2, dim=-1) + + # Add singleton dimensions to alpha and beta + if x.dim() > 2: + alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1]) + beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1]) + + return x * (1 + alpha) + beta + + +class SequentialCond(torch.nn.Sequential): + def forward(self, input, *args, **kwargs): + for module in self: + if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)): + # print(f'Passing on args to {module}', [a.shape for a in args]) + input = module(input, *args, **kwargs) + else: + # print(f'Skipping passing args to {module}', [a.shape for a in args]) + input = module(input) + return input + + +def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1): + if norm == "batch": + return torch.nn.BatchNorm1d(dim) + elif norm == "layer": + return torch.nn.LayerNorm(dim) + elif norm == "ada": + assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}" + return AdaptiveLayerNorm1D(dim, norm_cond_dim) + elif norm is None: + return torch.nn.Identity() + else: + raise ValueError(f"Unknown norm: {norm}") + + +def linear_norm_activ_dropout( + input_dim: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias)) + if norm is not None: + layers.append(normalization_layer(norm, output_dim, norm_cond_dim)) + layers.append(copy.deepcopy(activation)) + if dropout > 0.0: + layers.append(torch.nn.Dropout(dropout)) + return SequentialCond(*layers) + + +def create_simple_mlp( + input_dim: int, + hidden_dims: List[int], + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + prev_dim = input_dim + for hidden_dim in hidden_dims: + layers.extend( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias)) + return SequentialCond(*layers) + + +class ResidualMLPBlock(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, + ): + super().__init__() + if not (input_dim == output_dim == hidden_dim): + raise NotImplementedError( + f"input_dim {input_dim} != output_dim {output_dim} is not implemented" + ) + + layers = [] + prev_dim = input_dim + for i in range(num_hidden_layers): + layers.append( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, 
bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + self.model = SequentialCond(*layers) + self.skip = torch.nn.Identity() + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x + self.model(x, *args, **kwargs) + + +class ResidualMLP(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + num_blocks: int = 1, + norm_cond_dim: int = -1, + ): + super().__init__() + self.input_dim = input_dim + self.model = SequentialCond( + linear_norm_activ_dropout( + input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ), + *[ + ResidualMLPBlock( + hidden_dim, + hidden_dim, + num_hidden_layers, + hidden_dim, + activation, + bias, + norm, + dropout, + norm_cond_dim, + ) + for _ in range(num_blocks) + ], + torch.nn.Linear(hidden_dim, output_dim, bias=bias), + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return self.model(x, *args, **kwargs) + + +class FrequencyEmbedder(torch.nn.Module): + def __init__(self, num_frequencies, max_freq_log2): + super().__init__() + frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies) + self.register_buffer("frequencies", frequencies) + + def forward(self, x): + # x should be of size (N,) or (N, D) + N = x.size(0) + if x.dim() == 1: # (N,) + x = x.unsqueeze(1) # (N, D) where D=1 + x_unsqueezed = x.unsqueeze(-1) # (N, D, 1) + scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies) + s = torch.sin(scaled) + c = torch.cos(scaled) + embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view( + N, -1 + ) # (N, D * 2 * num_frequencies + D) + return embedded + diff --git a/hmr4d/network/hmr2/configs/__init__.py b/hmr4d/network/hmr2/configs/__init__.py new file mode 100644 index 0000000..68a1856 --- /dev/null +++ b/hmr4d/network/hmr2/configs/__init__.py @@ -0,0 +1,119 @@ +import os +from typing import Dict +from yacs.config import CfgNode as CN +from pathlib import Path + +# CACHE_DIR = os.path.join(os.environ.get("HOME"), "Code/4D-Humans/cache") +# CACHE_DIR_4DHUMANS = os.path.join(CACHE_DIR, "4DHumans") + + +def to_lower(x: Dict) -> Dict: + """ + Convert all dictionary keys to lowercase + Args: + x (dict): Input dictionary + Returns: + dict: Output dictionary with all keys converted to lowercase + """ + return {k.lower(): v for k, v in x.items()} + + +_C = CN(new_allowed=True) + +_C.GENERAL = CN(new_allowed=True) +_C.GENERAL.RESUME = True +_C.GENERAL.TIME_TO_RUN = 3300 +_C.GENERAL.VAL_STEPS = 100 +_C.GENERAL.LOG_STEPS = 100 +_C.GENERAL.CHECKPOINT_STEPS = 20000 +_C.GENERAL.CHECKPOINT_DIR = "checkpoints" +_C.GENERAL.SUMMARY_DIR = "tensorboard" +_C.GENERAL.NUM_GPUS = 1 +_C.GENERAL.NUM_WORKERS = 4 +_C.GENERAL.MIXED_PRECISION = True +_C.GENERAL.ALLOW_CUDA = True +_C.GENERAL.PIN_MEMORY = False +_C.GENERAL.DISTRIBUTED = False +_C.GENERAL.LOCAL_RANK = 0 +_C.GENERAL.USE_SYNCBN = False +_C.GENERAL.WORLD_SIZE = 1 + +_C.TRAIN = CN(new_allowed=True) +_C.TRAIN.NUM_EPOCHS = 100 +_C.TRAIN.BATCH_SIZE = 32 +_C.TRAIN.SHUFFLE = True +_C.TRAIN.WARMUP = False +_C.TRAIN.NORMALIZE_PER_IMAGE = False +_C.TRAIN.CLIP_GRAD = False +_C.TRAIN.CLIP_GRAD_VALUE = 1.0 +_C.LOSS_WEIGHTS = CN(new_allowed=True) + +_C.DATASETS = CN(new_allowed=True) + +_C.MODEL = CN(new_allowed=True) +_C.MODEL.IMAGE_SIZE = 224 + +_C.EXTRA = CN(new_allowed=True) 
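+# These defaults are merged with the config file passed to get_config(); values from the yaml file take precedence.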
+_C.EXTRA.FOCAL_LENGTH = 5000 + +_C.DATASETS.CONFIG = CN(new_allowed=True) +_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3 +_C.DATASETS.CONFIG.ROT_FACTOR = 30 +_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02 +_C.DATASETS.CONFIG.COLOR_SCALE = 0.2 +_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6 +_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.DO_FLIP = True +_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10 + + +def default_config() -> CN: + """ + Get a yacs CfgNode object with the default config values. + """ + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() + + +def dataset_config(name="datasets_tar.yaml") -> CN: + """ + Get dataset config file + Returns: + CfgNode: Dataset config as a yacs CfgNode object. + """ + cfg = CN(new_allowed=True) + config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), name) + cfg.merge_from_file(config_file) + cfg.freeze() + return cfg + + +def dataset_eval_config() -> CN: + return dataset_config("datasets_eval.yaml") + + +def get_config(config_file: str, merge: bool = True) -> CN: + """ + Read a config file and optionally merge it with the default config file. + Args: + config_file (str): Path to config file. + merge (bool): Whether to merge with the default config or not. + Returns: + CfgNode: Config as a yacs CfgNode object. + """ + if merge: + cfg = default_config() + else: + cfg = CN(new_allowed=True) + cfg.merge_from_file(config_file) + + # ---- Update ---- # + cfg.SMPL.MODEL_PATH = cfg.SMPL.MODEL_PATH # Not used + cfg.SMPL.JOINT_REGRESSOR_EXTRA = cfg.SMPL.JOINT_REGRESSOR_EXTRA # Not Used + cfg.SMPL.MEAN_PARAMS = str(Path(__file__).parent / "smpl_mean_params.npz") + # ---------------- # + + cfg.freeze() + return cfg diff --git a/hmr4d/network/hmr2/configs/model_config.yaml b/hmr4d/network/hmr2/configs/model_config.yaml new file mode 100644 index 0000000..1374229 --- /dev/null +++ b/hmr4d/network/hmr2/configs/model_config.yaml @@ -0,0 +1,131 @@ +task_name: train +tags: +- dev +train: true +test: false +ckpt_path: null +seed: null +DATASETS: + TRAIN: + H36M-TRAIN: + WEIGHT: 0.3 + MPII-TRAIN: + WEIGHT: 0.1 + COCO-TRAIN-2014: + WEIGHT: 0.4 + MPI-INF-TRAIN: + WEIGHT: 0.2 + VAL: + COCO-VAL: + WEIGHT: 1.0 + MOCAP: CMU-MOCAP + CONFIG: + SCALE_FACTOR: 0.3 + ROT_FACTOR: 30 + TRANS_FACTOR: 0.02 + COLOR_SCALE: 0.2 + ROT_AUG_RATE: 0.6 + TRANS_AUG_RATE: 0.5 + DO_FLIP: true + FLIP_AUG_RATE: 0.5 + EXTREME_CROP_AUG_RATE: 0.1 +trainer: + _target_: pytorch_lightning.Trainer + default_root_dir: ${paths.output_dir} + accelerator: gpu + devices: 8 + deterministic: false + num_sanity_val_steps: 0 + log_every_n_steps: ${GENERAL.LOG_STEPS} + val_check_interval: ${GENERAL.VAL_STEPS} + precision: 16 + max_steps: ${GENERAL.TOTAL_STEPS} + move_metrics_to_cpu: true + limit_val_batches: 1 + track_grad_norm: 2 + strategy: ddp + num_nodes: 1 + sync_batchnorm: true +paths: + root_dir: ${oc.env:PROJECT_ROOT} + data_dir: ${paths.root_dir}/data/ + log_dir: /fsx/shubham/code/hmr2023/logs_hydra/ + output_dir: ${hydra:runtime.output_dir} + work_dir: ${hydra:runtime.cwd} +extras: + ignore_warnings: false + enforce_tags: true + print_config: true +exp_name: 3001d +SMPL: + MODEL_PATH: data/smpl + GENDER: neutral + NUM_BODY_JOINTS: 23 + JOINT_REGRESSOR_EXTRA: data/SMPL_to_J19.pkl + MEAN_PARAMS: data/smpl_mean_params.npz +EXTRA: + FOCAL_LENGTH: 5000 + NUM_LOG_IMAGES: 4 + NUM_LOG_SAMPLES_PER_IMAGE: 8 + PELVIS_IND: 39 +MODEL: + IMAGE_SIZE: 256 + IMAGE_MEAN: + - 0.485 + - 0.456 + - 
0.406 + IMAGE_STD: + - 0.229 + - 0.224 + - 0.225 + BACKBONE: + TYPE: vit + FREEZE: true + NUM_LAYERS: 50 + OUT_CHANNELS: 2048 + ADD_NECK: false + FLOW: + DIM: 144 + NUM_LAYERS: 4 + CONTEXT_FEATURES: 2048 + LAYER_HIDDEN_FEATURES: 1024 + LAYER_DEPTH: 2 + FC_HEAD: + NUM_FEATURES: 1024 + SMPL_HEAD: + TYPE: transformer_decoder + IN_CHANNELS: 2048 + TRANSFORMER_DECODER: + depth: 6 + heads: 8 + mlp_dim: 1024 + dim_head: 64 + dropout: 0.0 + emb_dropout: 0.0 + norm: layer + context_dim: 1280 +GENERAL: + TOTAL_STEPS: 100000 + LOG_STEPS: 100 + VAL_STEPS: 100 + CHECKPOINT_STEPS: 1000 + CHECKPOINT_SAVE_TOP_K: -1 + NUM_WORKERS: 6 + PREFETCH_FACTOR: 2 +TRAIN: + LR: 0.0001 + WEIGHT_DECAY: 0.0001 + BATCH_SIZE: 512 + LOSS_REDUCTION: mean + NUM_TRAIN_SAMPLES: 2 + NUM_TEST_SAMPLES: 64 + POSE_2D_NOISE_RATIO: 0.01 + SMPL_PARAM_NOISE_RATIO: 0.005 +LOSS_WEIGHTS: + KEYPOINTS_3D: 0.05 + KEYPOINTS_2D: 0.01 + GLOBAL_ORIENT: 0.001 + BODY_POSE: 0.001 + BETAS: 0.0005 + ADVERSARIAL: 0.0005 +local: {} diff --git a/hmr4d/network/hmr2/configs/smpl_mean_params.npz b/hmr4d/network/hmr2/configs/smpl_mean_params.npz new file mode 100755 index 0000000..b599137 Binary files /dev/null and b/hmr4d/network/hmr2/configs/smpl_mean_params.npz differ diff --git a/hmr4d/network/hmr2/hmr2.py b/hmr4d/network/hmr2/hmr2.py new file mode 100644 index 0000000..b018fb8 --- /dev/null +++ b/hmr4d/network/hmr2/hmr2.py @@ -0,0 +1,55 @@ +import torch +import pytorch_lightning as pl +from yacs.config import CfgNode +from .vit import ViT +from .smpl_head import SMPLTransformerDecoderHead + +from pytorch3d.transforms import matrix_to_axis_angle +from hmr4d.utils.geo.hmr_cam import compute_transl_full_cam + + +class HMR2(pl.LightningModule): + def __init__(self, cfg: CfgNode): + super().__init__() + self.cfg = cfg + self.backbone = ViT( + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + ) + self.smpl_head = SMPLTransformerDecoderHead(cfg) + + def forward(self, batch, feat_mode=True): + """this file has been modified + Args: + feat_mode: default True, as we only need the feature token output for the HMR4D project; + when False, the full process of HMR2 will be executed. 
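+            Note: batch["img"] is assumed to be a normalized (B, 3, 256, 256) crop;
+                only the central 256x192 region (columns 32:-32) is passed to the ViT backbone.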
+ """ + # Backbone + x = batch["img"][:, :, :, 32:-32] + vit_feats = self.backbone(x) + + # Output head + if feat_mode: + token_out = self.smpl_head(vit_feats, only_return_token_out=True) # (B, 1024) + return token_out + + # return full process + pred_smpl_params, pred_cam, _, token_out = self.smpl_head(vit_feats, only_return_token_out=False) + output = {} + output["token_out"] = token_out + output["smpl_params"] = { + "body_pose": matrix_to_axis_angle(pred_smpl_params["body_pose"]).flatten(-2), # (B, 23, 3) + "betas": pred_smpl_params["betas"], # (B, 10) + "global_orient": matrix_to_axis_angle(pred_smpl_params["global_orient"])[:, 0], # (B, 3) + "transl": compute_transl_full_cam(pred_cam, batch["bbx_xys"], batch["K_fullimg"]), # (B, 3) + } + + return output diff --git a/hmr4d/network/hmr2/smpl_head.py b/hmr4d/network/hmr2/smpl_head.py new file mode 100644 index 0000000..5af5cac --- /dev/null +++ b/hmr4d/network/hmr2/smpl_head.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import einops + +from .utils.geometry import rot6d_to_rotmat, aa_to_rotmat +from .components.pose_transformer import TransformerDecoder + + +class SMPLTransformerDecoderHead(nn.Module): + """Cross-attention based SMPL Transformer decoder""" + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get("JOINT_REP", "6d") + self.joint_rep_dim = {"6d": 6, "aa": 3}[self.joint_rep_type] + npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS + 1) + self.npose = npose + self.input_is_mean_shape = cfg.MODEL.SMPL_HEAD.get("TRANSFORMER_INPUT", "zero") == "mean_shape" + transformer_args = dict( + num_tokens=1, + token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1, + dim=1024, + ) + transformer_args.update(**dict(cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER)) + self.transformer = TransformerDecoder(**transformer_args) + dim = transformer_args["dim"] + self.decpose = nn.Linear(dim, npose) + self.decshape = nn.Linear(dim, 10) + self.deccam = nn.Linear(dim, 3) + + if cfg.MODEL.SMPL_HEAD.get("INIT_DECODER_XAVIER", False): + # True by default in MLP. False by default in Transformer + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + mean_params = np.load(cfg.SMPL.MEAN_PARAMS) + init_body_pose = torch.from_numpy(mean_params["pose"].astype(np.float32)).unsqueeze(0) + init_betas = torch.from_numpy(mean_params["shape"].astype("float32")).unsqueeze(0) + init_cam = torch.from_numpy(mean_params["cam"].astype(np.float32)).unsqueeze(0) + self.register_buffer("init_body_pose", init_body_pose) + self.register_buffer("init_betas", init_betas) + self.register_buffer("init_cam", init_cam) + + def forward(self, x, only_return_token_out=False): + batch_size = x.shape[0] + # vit pretrained backbone is channel-first. 
Change to token-first + x = einops.rearrange(x, "b c h w -> b (h w) c") + + init_body_pose = self.init_body_pose.expand(batch_size, -1) + init_betas = self.init_betas.expand(batch_size, -1) + init_cam = self.init_cam.expand(batch_size, -1) + + # TODO: Convert init_body_pose to aa rep if needed + if self.joint_rep_type == "aa": + raise NotImplementedError + + pred_body_pose = init_body_pose + pred_betas = init_betas + pred_cam = init_cam + pred_body_pose_list = [] + pred_betas_list = [] + pred_cam_list = [] + for i in range(self.cfg.MODEL.SMPL_HEAD.get("IEF_ITERS", 1)): + assert i == 0, "Only support 1 iteration for now" + + # Input token to transformer is zero token + if self.input_is_mean_shape: + token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[:, None, :] + else: + token = torch.zeros(batch_size, 1, 1).to(x.device) + + # Pass through transformer + token_out = self.transformer(token, context=x) + token_out = token_out.squeeze(1) # (B, C) + + if only_return_token_out: + return token_out + else: + # Readout from token_out + pred_body_pose = self.decpose(token_out) + pred_body_pose + pred_betas = self.decshape(token_out) + pred_betas + pred_cam = self.deccam(token_out) + pred_cam + pred_body_pose_list.append(pred_body_pose) + pred_betas_list.append(pred_betas) + pred_cam_list.append(pred_cam) + + # Convert self.joint_rep_type -> rotmat + joint_conversion_fn = {"6d": rot6d_to_rotmat, "aa": lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())}[ + self.joint_rep_type + ] + + pred_smpl_params_list = {} + pred_smpl_params_list["body_pose"] = torch.cat( + [joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_body_pose_list], dim=0 + ) + pred_smpl_params_list["betas"] = torch.cat(pred_betas_list, dim=0) + pred_smpl_params_list["cam"] = torch.cat(pred_cam_list, dim=0) + pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, self.cfg.SMPL.NUM_BODY_JOINTS + 1, 3, 3) + + pred_smpl_params = { + "global_orient": pred_body_pose[:, [0]], + "body_pose": pred_body_pose[:, 1:], + "betas": pred_betas, + } + return pred_smpl_params, pred_cam, pred_smpl_params_list, token_out diff --git a/hmr4d/network/hmr2/utils/geometry.py b/hmr4d/network/hmr2/utils/geometry.py new file mode 100644 index 0000000..e128ba8 --- /dev/null +++ b/hmr4d/network/hmr2/utils/geometry.py @@ -0,0 +1,118 @@ +from typing import Optional +import torch +from torch.nn import functional as F + + +def aa_to_rotmat(theta: torch.Tensor): + """ + Convert axis-angle representation to rotation matrix. + Works by first converting it to a quaternion. + Args: + theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations. + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim=1) + return quat_to_rotmat(quat) + + +def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion representation to rotation matrix. + Args: + quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z). + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). 
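+    Note: the quaternion is L2-normalized internally, so unnormalized inputs are accepted.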
+ """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack( + [ + w2 + x2 - y2 - z2, + 2 * xy - 2 * wz, + 2 * wy + 2 * xz, + 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, + 2 * xz - 2 * wy, + 2 * wx + 2 * yz, + w2 - x2 - y2 + z2, + ], + dim=1, + ).view(B, 3, 3) + return rotMat + + +def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def perspective_projection( + points: torch.Tensor, + translation: torch.Tensor, + focal_length: torch.Tensor, + camera_center: Optional[torch.Tensor] = None, + rotation: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Computes the perspective projection of a set of 3D points. + Args: + points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points. + translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation. + focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels. + camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels. + rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation. + Returns: + torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points. + """ + batch_size = points.shape[0] + if rotation is None: + rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1) + if camera_center is None: + camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype) + # Populate intrinsic camera matrix K. 
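+    #   K = [[fx,  0, cx],
+    #        [ 0, fy, cy],
+    #        [ 0,  0,  1]]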
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype) + K[:, 0, 0] = focal_length[:, 0] + K[:, 1, 1] = focal_length[:, 1] + K[:, 2, 2] = 1.0 + K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum("bij,bkj->bki", rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum("bij,bkj->bki", K, projected_points) + + return projected_points[:, :, :-1] diff --git a/hmr4d/network/hmr2/utils/preproc.py b/hmr4d/network/hmr2/utils/preproc.py new file mode 100644 index 0000000..3db9dcf --- /dev/null +++ b/hmr4d/network/hmr2/utils/preproc.py @@ -0,0 +1,52 @@ +import cv2 +import numpy as np +import torch +from pathlib import Path + +IMAGE_MEAN = torch.tensor([0.485, 0.456, 0.406]) +IMAGE_STD = torch.tensor([0.229, 0.224, 0.225]) + + +def expand_to_aspect_ratio(input_shape, target_aspect_ratio=[192, 256]): + """Increase the size of the bounding box to match the target shape.""" + if target_aspect_ratio is None: + return input_shape + + try: + w, h = input_shape + except (ValueError, TypeError): + return input_shape + + w_t, h_t = target_aspect_ratio + if h / w < h_t / w_t: + h_new = max(w * h_t / w_t, h) + w_new = w + else: + h_new = h + w_new = max(h * w_t / h_t, w) + if h_new < h or w_new < w: + breakpoint() + return np.array([w_new, h_new]) + + +def crop_and_resize(img, bbx_xy, bbx_s, dst_size=256, enlarge_ratio=1.2): + """ + Args: + img: (H, W, 3) + bbx_xy: (2,) + bbx_s: scalar + """ + hs = bbx_s * enlarge_ratio / 2 + src = np.stack( + [ + bbx_xy - hs, # left-up corner + bbx_xy + np.array([hs, -hs]), # right-up corner + bbx_xy, # center + ] + ).astype(np.float32) + dst = np.array([[0, 0], [dst_size - 1, 0], [dst_size / 2 - 0.5, dst_size / 2 - 0.5]], dtype=np.float32) + A = cv2.getAffineTransform(src, dst) + + img_crop = cv2.warpAffine(img, A, (dst_size, dst_size), flags=cv2.INTER_LINEAR) + bbx_xys_final = np.array([*bbx_xy, bbx_s * enlarge_ratio]) + return img_crop, bbx_xys_final diff --git a/hmr4d/network/hmr2/utils/smpl_wrapper.py b/hmr4d/network/hmr2/utils/smpl_wrapper.py new file mode 100644 index 0000000..839a83d --- /dev/null +++ b/hmr4d/network/hmr2/utils/smpl_wrapper.py @@ -0,0 +1,45 @@ +import torch +import numpy as np +import pickle +from typing import Optional +import smplx +from smplx.lbs import vertices2joints +from smplx.utils import SMPLOutput + + +class SMPL(smplx.SMPLLayer): + def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs): + """ + Extension of the official SMPL implementation to support more joints. + Args: + Same as SMPLLayer. + joint_regressor_extra (str): Path to extra joint regressor. + """ + super(SMPL, self).__init__(*args, **kwargs) + smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34] + + if joint_regressor_extra is not None: + self.register_buffer( + "joint_regressor_extra", + torch.tensor(pickle.load(open(joint_regressor_extra, "rb"), encoding="latin1"), dtype=torch.float32), + ) + self.register_buffer("joint_map", torch.tensor(smpl_to_openpose, dtype=torch.long)) + self.update_hips = update_hips + + def forward(self, *args, **kwargs) -> SMPLOutput: + """ + Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified. 
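+        Joints are re-ordered to the OpenPose convention via `joint_map`; extra regressed joints, if any, are appended at the end.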
+ """ + smpl_output = super(SMPL, self).forward(*args, **kwargs) + joints = smpl_output.joints[:, self.joint_map, :] + if self.update_hips: + joints[:, [9, 12]] = ( + joints[:, [9, 12]] + + 0.25 * (joints[:, [9, 12]] - joints[:, [12, 9]]) + + 0.5 * (joints[:, [8]] - 0.5 * (joints[:, [9, 12]] + joints[:, [12, 9]])) + ) + if hasattr(self, "joint_regressor_extra"): + extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices) + joints = torch.cat([joints, extra_joints], dim=1) + smpl_output.joints = joints + return smpl_output diff --git a/hmr4d/network/hmr2/vit.py b/hmr4d/network/hmr2/vit.py new file mode 100644 index 0000000..c56c718 --- /dev/null +++ b/hmr4d/network/hmr2/vit.py @@ -0,0 +1,348 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +def vit(cfg): + return ViT( + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + ) + +def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + cls_token = None + B, L, C = abs_pos.shape + if has_cls_token: + cls_token = abs_pos[:, 0:1] + abs_pos = abs_pos[:, 1:] + + if ori_h != h or ori_w != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).reshape(B, -1, C) + + else: + new_abs_pos = abs_pos + + if cls_token is not None: + new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) + return new_abs_pos + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class ViT(nn.Module): + + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, + patch_padding='pad', freeze_attn=False, freeze_ffn=False, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + # since the pretraining model has class token + self.pos_embed = 
nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp + + def forward(self, x): + x = self.forward_features(x) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/hmr4d/utils/body_model/README.md b/hmr4d/utils/body_model/README.md new file mode 100644 index 0000000..03397aa --- /dev/null +++ b/hmr4d/utils/body_model/README.md @@ -0,0 +1,3 @@ +# README + +Contents of this folder are modified from HuMoR repository. 
\ No newline at end of file diff --git a/hmr4d/utils/body_model/__init__.py b/hmr4d/utils/body_model/__init__.py new file mode 100644 index 0000000..c8a3598 --- /dev/null +++ b/hmr4d/utils/body_model/__init__.py @@ -0,0 +1,3 @@ +from .body_model import BodyModel +from .body_model_smplh import BodyModelSMPLH +from .body_model_smplx import BodyModelSMPLX diff --git a/hmr4d/utils/body_model/body_model.py b/hmr4d/utils/body_model/body_model.py new file mode 100644 index 0000000..5f8ac47 --- /dev/null +++ b/hmr4d/utils/body_model/body_model.py @@ -0,0 +1,127 @@ +from turtle import forward +import numpy as np + +import torch +import torch.nn as nn + +from smplx import SMPL, SMPLH, SMPLX +from smplx.vertex_ids import vertex_ids +from smplx.utils import Struct + + +class BodyModel(nn.Module): + """ + Wrapper around SMPLX body model class. + modified by Zehong Shen + """ + + def __init__(self, + bm_path, + num_betas=16, + use_vtx_selector=False, + model_type='smplh'): + super().__init__() + ''' + Creates the body model object at the given path. + + :param bm_path: path to the body model pkl file + :param model_type: one of [smpl, smplh, smplx] + :param use_vtx_selector: if true, returns additional vertices as joints that correspond to OpenPose joints + ''' + self.use_vtx_selector = use_vtx_selector + cur_vertex_ids = None + if self.use_vtx_selector: + cur_vertex_ids = vertex_ids[model_type] + data_struct = None + if '.npz' in bm_path: + # smplx does not support .npz by default, so have to load in manually + smpl_dict = np.load(bm_path, encoding='latin1') + data_struct = Struct(**smpl_dict) + # print(smpl_dict.files) + if model_type == 'smplh': + data_struct.hands_componentsl = np.zeros((0)) + data_struct.hands_componentsr = np.zeros((0)) + data_struct.hands_meanl = np.zeros((15 * 3)) + data_struct.hands_meanr = np.zeros((15 * 3)) + V, D, B = data_struct.shapedirs.shape + data_struct.shapedirs = np.concatenate([data_struct.shapedirs, np.zeros( + (V, D, SMPL.SHAPE_SPACE_DIM-B))], axis=-1) # super hacky way to let smplh use 16-size beta + kwargs = { + 'model_type': model_type, + 'data_struct': data_struct, + 'num_betas': num_betas, + 'vertex_ids': cur_vertex_ids, + 'use_pca': False, + 'flat_hand_mean': True, + # - enable variable batchsize, since we don't need module variable - # + 'create_body_pose': False, + 'create_betas': False, + 'create_global_orient': False, + 'create_transl': False, + 'create_left_hand_pose': False, + 'create_right_hand_pose': False, + } + assert(model_type in ['smpl', 'smplh', 'smplx']) + if model_type == 'smpl': + self.bm = SMPL(bm_path, **kwargs) + self.num_joints = SMPL.NUM_JOINTS + elif model_type == 'smplh': + self.bm = SMPLH(bm_path, **kwargs) + self.num_joints = SMPLH.NUM_JOINTS + elif model_type == 'smplx': + self.bm = SMPLX(bm_path, **kwargs) + self.num_joints = SMPLX.NUM_JOINTS + + self.model_type = model_type + + def forward(self, root_orient=None, pose_body=None, pose_hand=None, pose_jaw=None, pose_eye=None, betas=None, + trans=None, dmpls=None, expression=None, return_dict=False, **kwargs): + ''' + Note dmpls are not supported. 
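For context on how this wrapper is typically driven, a hedged usage sketch (not part of the diff): the `.npz` path below is a placeholder, and an SMPL-H body model file must be supplied by the user.

```python
import torch

# Sketch only; "path/to/smplh/neutral/model.npz" is a placeholder, not a path defined by this repo.
bm = BodyModel(bm_path="path/to/smplh/neutral/model.npz", num_betas=16, model_type="smplh")

B = 2
out = bm(
    root_orient=torch.zeros(B, 3),      # global orientation, axis-angle
    pose_body=torch.zeros(B, 21 * 3),   # 21 SMPL-H body joints, axis-angle
    betas=torch.zeros(16),              # a (16,) vector is broadcast to (B, 16) inside forward
    trans=torch.zeros(B, 3),
)
# pose_hand defaults to zeros; with use_vtx_selector=False the joints are truncated to 52.
print(out.v.shape, out.Jtr.shape)       # (B, 6890, 3) vertices, (B, 52, 3) joints
```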
+ ''' + assert(dmpls is None) + B = pose_body.shape[0] + if pose_hand is None: + pose_hand = torch.zeros((B, 2*SMPLH.NUM_HAND_JOINTS*3), device=pose_body.device) + if len(betas.shape) == 1: + betas = betas.reshape((1, -1)).expand(B, -1) + + out_obj = self.bm( + betas=betas, + global_orient=root_orient, + body_pose=pose_body, + left_hand_pose=pose_hand[:, :(SMPLH.NUM_HAND_JOINTS*3)], + right_hand_pose=pose_hand[:, (SMPLH.NUM_HAND_JOINTS*3):], + transl=trans, + expression=expression, + jaw_pose=pose_jaw, + leye_pose=None if pose_eye is None else pose_eye[:, :3], + reye_pose=None if pose_eye is None else pose_eye[:, 3:], + return_full_pose=True, + **kwargs + ) + + out = { + 'v': out_obj.vertices, + 'f': self.bm.faces_tensor, + 'Jtr': out_obj.joints, + } + + if not self.use_vtx_selector: + # don't need extra joints + out['Jtr'] = out['Jtr'][:, :self.num_joints+1] # add one for the root + + if not return_dict: + out = Struct(**out) + + return out + + def forward_motion(self, **kwargs): + B, W, _ = kwargs['pose_body'].shape + kwargs = {k: v.reshape(B*W, v.shape[-1]) for k, v in kwargs.items()} + + smpl_opt = self.forward(**kwargs) + smpl_opt.v = smpl_opt.v.reshape(B, W, -1, 3) + smpl_opt.Jtr = smpl_opt.Jtr.reshape(B, W, -1, 3) + + return smpl_opt diff --git a/hmr4d/utils/body_model/body_model_smplh.py b/hmr4d/utils/body_model/body_model_smplh.py new file mode 100644 index 0000000..27c30fa --- /dev/null +++ b/hmr4d/utils/body_model/body_model_smplh.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import smplx + +kwargs_disable_member_var = { + "create_body_pose": False, + "create_betas": False, + "create_global_orient": False, + "create_transl": False, + "create_left_hand_pose": False, + "create_right_hand_pose": False, +} + + +class BodyModelSMPLH(nn.Module): + """Support Batch inference""" + + def __init__(self, model_path, **kwargs): + super().__init__() + # enable flexible batchsize, handle missing variable at forward() + kwargs.update(kwargs_disable_member_var) + self.bm = smplx.create(model_path=model_path, **kwargs) + self.faces = self.bm.faces + self.is_smpl = kwargs.get("model_type", "smpl") == "smpl" + if not self.is_smpl: + self.hand_pose_dim = self.bm.num_pca_comps if self.bm.use_pca else 3 * self.bm.NUM_HAND_JOINTS + + # For fast computing of skeleton under beta + shapedirs = self.bm.shapedirs # (V, 3, 10) + J_regressor = self.bm.J_regressor[:22, :] # (22, V) + v_template = self.bm.v_template # (V, 3) + J_template = J_regressor @ v_template # (22, 3) + J_shapedirs = torch.einsum("jv, vcd -> jcd", J_regressor, shapedirs) # (22, 3, 10) + self.register_buffer("J_template", J_template, False) + self.register_buffer("J_shapedirs", J_shapedirs, False) + + def forward( + self, + betas=None, + global_orient=None, + transl=None, + body_pose=None, + left_hand_pose=None, + right_hand_pose=None, + **kwargs + ): + + device, dtype = self.bm.shapedirs.device, self.bm.shapedirs.dtype + + model_vars = [betas, global_orient, body_pose, transl, left_hand_pose, right_hand_pose] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if body_pose is None: + body_pose = ( + torch.zeros(3 * self.bm.NUM_BODY_JOINTS, device=device, dtype=dtype)[None] + .expand(batch_size, -1) + .contiguous() + ) + if not self.is_smpl: + if left_hand_pose is None: + left_hand_pose = ( + torch.zeros(self.hand_pose_dim, device=device, dtype=dtype)[None] + 
.expand(batch_size, -1) + .contiguous() + ) + if right_hand_pose is None: + right_hand_pose = ( + torch.zeros(self.hand_pose_dim, device=device, dtype=dtype)[None] + .expand(batch_size, -1) + .contiguous() + ) + if betas is None: + betas = torch.zeros([batch_size, self.bm.num_betas], dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + bm_out = self.bm( + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + transl=transl, + **kwargs + ) + + return bm_out + + def get_skeleton(self, betas): + """betas: (*, 10) -> skeleton_beta: (*, 22, 3)""" + skeleton_beta = self.J_template + torch.einsum("...d, jcd -> ...jc", betas, self.J_shapedirs) # (22, 3) + return skeleton_beta diff --git a/hmr4d/utils/body_model/body_model_smplx.py b/hmr4d/utils/body_model/body_model_smplx.py new file mode 100644 index 0000000..ccc8aab --- /dev/null +++ b/hmr4d/utils/body_model/body_model_smplx.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import smplx + +kwargs_disable_member_var = { + "create_body_pose": False, + "create_betas": False, + "create_global_orient": False, + "create_transl": False, + "create_left_hand_pose": False, + "create_right_hand_pose": False, + "create_expression": False, + "create_jaw_pose": False, + "create_leye_pose": False, + "create_reye_pose": False, +} + + +class BodyModelSMPLX(nn.Module): + """Support Batch inference""" + + def __init__(self, model_path, **kwargs): + super().__init__() + # enable flexible batchsize, handle missing variable at forward() + kwargs.update(kwargs_disable_member_var) + self.bm = smplx.create(model_path=model_path, **kwargs) + self.faces = self.bm.faces + self.hand_pose_dim = self.bm.num_pca_comps if self.bm.use_pca else 3 * self.bm.NUM_HAND_JOINTS + + # For fast computing of skeleton under beta + shapedirs = self.bm.shapedirs # (V, 3, 10) + J_regressor = self.bm.J_regressor[:22, :] # (22, V) + v_template = self.bm.v_template # (V, 3) + J_template = J_regressor @ v_template # (22, 3) + J_shapedirs = torch.einsum("jv, vcd -> jcd", J_regressor, shapedirs) # (22, 3, 10) + self.register_buffer("J_template", J_template, False) + self.register_buffer("J_shapedirs", J_shapedirs, False) + + def forward( + self, + betas=None, + global_orient=None, + transl=None, + body_pose=None, + left_hand_pose=None, + right_hand_pose=None, + expression=None, + jaw_pose=None, + leye_pose=None, + reye_pose=None, + **kwargs + ): + + device, dtype = self.bm.shapedirs.device, self.bm.shapedirs.dtype + + model_vars = [ + betas, + global_orient, + body_pose, + transl, + expression, + left_hand_pose, + right_hand_pose, + jaw_pose, + leye_pose, + reye_pose, + ] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if body_pose is None: + body_pose = ( + torch.zeros(3 * self.bm.NUM_BODY_JOINTS, device=device, dtype=dtype)[None] + .expand(batch_size, -1) + .contiguous() + ) + if left_hand_pose is None: + left_hand_pose = ( + torch.zeros(self.hand_pose_dim, device=device, dtype=dtype)[None].expand(batch_size, -1).contiguous() + ) + if right_hand_pose is None: + right_hand_pose = ( + torch.zeros(self.hand_pose_dim, device=device, dtype=dtype)[None].expand(batch_size, -1).contiguous() + ) + if jaw_pose is None: + jaw_pose = torch.zeros([batch_size, 3], dtype=dtype, 
device=device) + if leye_pose is None: + leye_pose = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if reye_pose is None: + reye_pose = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if expression is None: + expression = torch.zeros([batch_size, self.bm.num_expression_coeffs], dtype=dtype, device=device) + if betas is None: + betas = torch.zeros([batch_size, self.bm.num_betas], dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + bm_out = self.bm( + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + transl=transl, + expression=expression, + jaw_pose=jaw_pose, + leye_pose=leye_pose, + reye_pose=reye_pose, + **kwargs + ) + + return bm_out + + def get_skeleton(self, betas): + """betas: (*, 10) -> skeleton_beta: (*, 22, 3)""" + skeleton_beta = self.J_template + torch.einsum("...d, jcd -> ...jc", betas, self.J_shapedirs) # (22, 3) + return skeleton_beta + + def forward_bfc(self, **kwargs): + """Wrap (B, F, C) to (B*F, C) and unwrap (B*F, C) to (B, F, C)""" + for k in kwargs: + assert len(kwargs[k].shape) == 3 + B, F = kwargs["body_pose"].shape[:2] + smplx_out = self.forward(**{k: v.reshape(B * F, -1) for k, v in kwargs.items()}) + smplx_out.vertices = smplx_out.vertices.reshape(B, F, -1, 3) + smplx_out.joints = smplx_out.joints.reshape(B, F, -1, 3) + return smplx_out diff --git a/hmr4d/utils/body_model/coco_aug_dict.pth b/hmr4d/utils/body_model/coco_aug_dict.pth new file mode 100644 index 0000000..e508a4f Binary files /dev/null and b/hmr4d/utils/body_model/coco_aug_dict.pth differ diff --git a/hmr4d/utils/body_model/min_lbs.py b/hmr4d/utils/body_model/min_lbs.py new file mode 100644 index 0000000..8c7fbd9 --- /dev/null +++ b/hmr4d/utils/body_model/min_lbs.py @@ -0,0 +1,106 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch3d.transforms import axis_angle_to_matrix +from smplx.utils import Struct, to_np, to_tensor +from hmr4d.utils.smplx_utils import forward_kinematics_motion + + +class MinimalLBS(nn.Module): + def __init__(self, sp_ids, bm_dir='models/smplh', num_betas=16, model_type='smplh', **kwargs): + super().__init__() + self.num_betas = num_betas + self.sensor_point_vid = torch.tensor(sp_ids) + + # load struct data on predefined sensor-point + self.load_struct_on_sp(f'{bm_dir}/male/model.npz', prefix='male') + self.load_struct_on_sp(f'{bm_dir}/female/model.npz', prefix='female') + + def load_struct_on_sp(self, bm_path, prefix='m'): + """ + Load 4 weights from body-model-struct. + Keep the sensor points only. Use prefix to label different bm. 
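The per-gender buffer scheme used here can be illustrated with a tiny self-contained sketch (the `ToyLBS` class below is hypothetical, not part of the repo): buffers are registered under a gender prefix with `persistent=False`, then gathered per sample by name at forward time.

```python
import torch
import torch.nn as nn

class ToyLBS(nn.Module):
    def __init__(self):
        super().__init__()
        for prefix in ("male", "female"):
            # persistent=False keeps the buffers out of the state_dict, as in MinimalLBS
            self.register_buffer(f"{prefix}_v_template_sp", torch.randn(8, 3), persistent=False)

    def forward(self, genders):
        # Stack one template per sample according to its gender -> (B, 8, 3)
        return torch.stack([getattr(self, f"{g}_v_template_sp") for g in genders])

verts = ToyLBS()(["male", "female", "male"])
print(verts.shape)  # torch.Size([3, 8, 3])
```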
+ """ + num_betas = self.num_betas + sp_vid = self.sensor_point_vid + # load data + data_struct = Struct(**np.load(bm_path, encoding='latin1')) + + # v-template + v_template = to_tensor(to_np(data_struct.v_template)) # (V, 3) + v_template_sp = v_template[sp_vid] # (N, 3) + self.register_buffer(f'{prefix}_v_template_sp', v_template_sp, False) + + # shapedirs + shapedirs = to_tensor(to_np(data_struct.shapedirs[:, :, :num_betas])) # (V, 3, NB) + shapedirs_sp = shapedirs[sp_vid] + self.register_buffer(f'{prefix}_shapedirs_sp', shapedirs_sp, False) + + # posedirs + posedirs = to_tensor(to_np(data_struct.posedirs)) # (V, 3, 51*9) + posedirs_sp = posedirs[sp_vid] + posedirs_sp = posedirs_sp.reshape(len(sp_vid)*3, -1).T # (51*9, N*3) + self.register_buffer(f'{prefix}_posedirs_sp', posedirs_sp, False) + + # lbs_weights + lbs_weights = to_tensor(to_np(data_struct.weights)) # (V, J+1) + lbs_weights_sp = lbs_weights[sp_vid] + self.register_buffer(f'{prefix}_lbs_weights_sp', lbs_weights_sp, False) + + def forward(self, root_orient=None, pose_body=None, trans=None, betas=None, A=None, recompute_A=False, genders=None, + joints_zero=None): + """ + Args: + root_orient, Optional: (B, T, 3) + pose_body: (B, T, J*3) + trans: (B, T, 3) + betas: (B, T, 16) + A, Optional: (B, T, J+1, 4, 4) + recompute_A: if True, root_orient should be given, otherwise use A + genders, List: ['male', 'female', ...] + joints_zero: (B, J+1, 3), required when recompute_A is True + Returns: + sensor_verts: (B, T, N, 3) + """ + B, T = pose_body.shape[:2] + + v_template = torch.stack([getattr(self, f'{g}_v_template_sp') for g in genders]) # (B, N, 3) + shapedirs = torch.stack([getattr(self, f'{g}_shapedirs_sp') for g in genders]) # (B, N, 3, NB) + posedirs = torch.stack([getattr(self, f'{g}_posedirs_sp') for g in genders]) # (B, 51*9, N*3) + lbs_weights = torch.stack([getattr(self, f'{g}_lbs_weights_sp') for g in genders]) # (B, N, J+1) + + # ===== LBS, handle T ===== # + # 2. Add shape contribution + if betas.shape[1] == 1: + betas = betas.expand(-1, T, -1) + blend_shape = torch.einsum('btl,bmkl->btmk', [betas, shapedirs]) + v_shaped = v_template[:, None] + blend_shape + + # 3. Add pose blend shapes + ident = torch.eye(3).to(pose_body) + aa = pose_body.reshape(B, T, -1, 3) + R = axis_angle_to_matrix(aa) + pose_feature = (R - ident).view(B, T, -1) + dim_pf = pose_feature.shape[-1] + # (B, T, P) @ (B, P, N*3) -> (B, T, N, 3) + pose_offsets = torch.matmul(pose_feature, posedirs[:, :dim_pf]).view(B, T, -1, 3) + v_posed = pose_offsets + v_shaped + + # 4. Compute A + if recompute_A: + _, _, A = forward_kinematics_motion(root_orient, pose_body, trans, joints_zero) + + # 5. Skinning + W = lbs_weights + # (B, 1, N, J+1)) @ (B, T, J+1, 16) + num_joints = A.shape[-3] # 22 + Ts = torch.matmul(W[:, None, :, :num_joints], A.view(B, T, num_joints, 16)) + Ts = Ts.view(B, T, -1, 4, 4) # (B, T, N, 4, 4) + v_posed_homo = F.pad(v_posed, (0, 1), value=1) # (B, T, N, 4) + v_homo = torch.matmul(Ts, torch.unsqueeze(v_posed_homo, dim=-1)) + + # 6. 
translate + sensor_verts = v_homo[:, :, :, :3, 0] + trans[:, :, None] + + return sensor_verts diff --git a/hmr4d/utils/body_model/seg_part_info.npy b/hmr4d/utils/body_model/seg_part_info.npy new file mode 100644 index 0000000..ddfb247 Binary files /dev/null and b/hmr4d/utils/body_model/seg_part_info.npy differ diff --git a/hmr4d/utils/body_model/smpl_3dpw14_J_regressor_sparse.pt b/hmr4d/utils/body_model/smpl_3dpw14_J_regressor_sparse.pt new file mode 100644 index 0000000..dfddd70 Binary files /dev/null and b/hmr4d/utils/body_model/smpl_3dpw14_J_regressor_sparse.pt differ diff --git a/hmr4d/utils/body_model/smpl_coco17_J_regressor.pt b/hmr4d/utils/body_model/smpl_coco17_J_regressor.pt new file mode 100644 index 0000000..f9e310c Binary files /dev/null and b/hmr4d/utils/body_model/smpl_coco17_J_regressor.pt differ diff --git a/hmr4d/utils/body_model/smpl_lite.py b/hmr4d/utils/body_model/smpl_lite.py new file mode 100644 index 0000000..4e84fc7 --- /dev/null +++ b/hmr4d/utils/body_model/smpl_lite.py @@ -0,0 +1,143 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from pytorch3d.transforms import axis_angle_to_matrix +from smplx.utils import Struct, to_np, to_tensor +from einops import einsum, rearrange +from time import time + +import pickle + +from .smplx_lite import batch_rigid_transform_v2 + + +class SmplLite(nn.Module): + def __init__( + self, + model_path="inputs/checkpoints/body_models/smpl", + gender="neutral", + num_betas=10, + ): + super().__init__() + + # Load the model + model_path = Path(model_path) + if model_path.is_dir(): + smpl_path = Path(model_path) / f"SMPL_{gender.upper()}.pkl" + else: + smpl_path = model_path + assert smpl_path.exists() + with open(smpl_path, "rb") as smpl_file: + data_struct = Struct(**pickle.load(smpl_file, encoding="latin1")) + self.faces = data_struct.f # (F, 3) + + self.register_smpl_buffers(data_struct, num_betas) + self.register_fast_skeleton_computing_buffers() + + def register_smpl_buffers(self, data_struct, num_betas): + # shapedirs, (V, 3, N_betas), V=10475 for SMPLX + shapedirs = to_tensor(to_np(data_struct.shapedirs[:, :, :num_betas])).float() + self.register_buffer("shapedirs", shapedirs, False) + + # v_template, (V, 3) + v_template = to_tensor(to_np(data_struct.v_template)).float() + self.register_buffer("v_template", v_template, False) + + # J_regressor, (J, V), J=55 for SMPLX + J_regressor = to_tensor(to_np(data_struct.J_regressor)).float() + self.register_buffer("J_regressor", J_regressor, False) + + # posedirs, (54*9, V, 3), note that the first global_orient is not included + posedirs = to_tensor(to_np(data_struct.posedirs)).float() # (V, 3, 54*9) + posedirs = rearrange(posedirs, "v c n -> n v c") + self.register_buffer("posedirs", posedirs, False) + + # lbs_weights, (V, J), J=55 + lbs_weights = to_tensor(to_np(data_struct.weights)).float() + self.register_buffer("lbs_weights", lbs_weights, False) + + # parents, (J), long + parents = to_tensor(to_np(data_struct.kintree_table[0])).long() + parents[0] = -1 + self.register_buffer("parents", parents, False) + + def register_fast_skeleton_computing_buffers(self): + # For fast computing of skeleton under beta + J_template = self.J_regressor @ self.v_template # (J, 3) + J_shapedirs = torch.einsum("jv, vcd -> jcd", self.J_regressor, self.shapedirs) # (J, 3, 10) + self.register_buffer("J_template", J_template, False) + self.register_buffer("J_shapedirs", J_shapedirs, False) + + def get_skeleton(self, betas): + return 
self.J_template + einsum(betas, self.J_shapedirs, "... k, j c k -> ... j c") + + def forward( + self, + body_pose, + betas, + global_orient, + transl, + ): + """ + Args: + body_pose: (B, L, 63) + betas: (B, L, 10) + global_orient: (B, L, 3) + transl: (B, L, 3) + Returns: + vertices: (B, L, V, 3) + """ + # 1. Convert [global_orient, body_pose] to rot_mats + full_pose = torch.cat([global_orient, body_pose], dim=-1) + rot_mats = axis_angle_to_matrix(full_pose.reshape(*full_pose.shape[:-1], full_pose.shape[-1] // 3, 3)) + + # 2. Forward Kinematics + J = self.get_skeleton(betas) # (*, 55, 3) + A = batch_rigid_transform_v2(rot_mats, J, self.parents)[1] + + # 3. Canonical v_posed = v_template + shaped_offsets + pose_offsets + pose_feature = rot_mats[..., 1:, :, :] - rot_mats.new([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + pose_feature = pose_feature.view(*pose_feature.shape[:-3], -1) # (*, 55*3*3) + v_posed = ( + self.v_template + + einsum(betas, self.shapedirs, "... k, v c k -> ... v c") + + einsum(pose_feature, self.posedirs, "... k, k v c -> ... v c") + ) + del pose_feature, rot_mats, full_pose + + # 4. Skinning + T = einsum(self.lbs_weights, A, "v j, ... j c d -> ... v c d") + verts = einsum(T[..., :3, :3], v_posed, "... v c d, ... v d -> ... v c") + T[..., :3, 3] + + # 5. Translation + verts = verts + transl[..., None, :] + return verts + + +class SmplxLiteJ24(SmplLite): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Compute mapping + smpl2j24 = self.J_regressor # (24, 6890) + + jids, smplx_vids = torch.where(smpl2j24 != 0) + interestd = torch.zeros([len(smplx_vids), 24]) + for idx, (jid, smplx_vid) in enumerate(zip(jids, smplx_vids)): + interestd[idx, jid] = smpl2j24[jid, smplx_vid] + self.register_buffer("interestd", interestd, False) # (236, 24) + + # Update to vertices of interest + self.v_template = self.v_template[smplx_vids].clone() # (V', 3) + self.shapedirs = self.shapedirs[smplx_vids].clone() # (V', 3, K) + self.posedirs = self.posedirs[:, smplx_vids].clone() # (K, V', 3) + self.lbs_weights = self.lbs_weights[smplx_vids].clone() # (V', J) + + def forward(self, body_pose, betas, global_orient, transl): + """Returns: joints (*, J, 3). (B, L) or (B,) are both supported.""" + # Use super class's forward to get verts + verts = super().forward(body_pose, betas, global_orient, transl) # (*, 236, 3) + joints = einsum(self.interestd, verts, "v j, ... v c -> ... 
j c") + return joints diff --git a/hmr4d/utils/body_model/smpl_neutral_J_regressor.pt b/hmr4d/utils/body_model/smpl_neutral_J_regressor.pt new file mode 100644 index 0000000..ba3f5ed Binary files /dev/null and b/hmr4d/utils/body_model/smpl_neutral_J_regressor.pt differ diff --git a/hmr4d/utils/body_model/smpl_vert_segmentation.json b/hmr4d/utils/body_model/smpl_vert_segmentation.json new file mode 100644 index 0000000..b3244cc --- /dev/null +++ b/hmr4d/utils/body_model/smpl_vert_segmentation.json @@ -0,0 +1,7440 @@ +{ + "rightHand": [ + 5442, + 5443, + 5444, + 5445, + 5446, + 5447, + 5448, + 5449, + 5450, + 5451, + 5452, + 5453, + 5454, + 5455, + 5456, + 5457, + 5458, + 5459, + 5460, + 5461, + 5462, + 5463, + 5464, + 5465, + 5466, + 5467, + 5468, + 5469, + 5470, + 5471, + 5472, + 5473, + 5474, + 5475, + 5476, + 5477, + 5478, + 5479, + 5480, + 5481, + 5482, + 5483, + 5484, + 5485, + 5486, + 5487, + 5492, + 5493, + 5494, + 5495, + 5496, + 5497, + 5502, + 5503, + 5504, + 5505, + 5506, + 5507, + 5508, + 5509, + 5510, + 5511, + 5512, + 5513, + 5514, + 5515, + 5516, + 5517, + 5518, + 5519, + 5520, + 5521, + 5522, + 5523, + 5524, + 5525, + 5526, + 5527, + 5530, + 5531, + 5532, + 5533, + 5534, + 5535, + 5536, + 5537, + 5538, + 5539, + 5540, + 5541, + 5542, + 5543, + 5544, + 5545, + 5546, + 5547, + 5548, + 5549, + 5550, + 5551, + 5552, + 5553, + 5554, + 5555, + 5556, + 5557, + 5558, + 5559, + 5560, + 5561, + 5562, + 5569, + 5571, + 5574, + 5575, + 5576, + 5577, + 5578, + 5579, + 5580, + 5581, + 5582, + 5583, + 5588, + 5589, + 5592, + 5593, + 5594, + 5595, + 5596, + 5597, + 5598, + 5599, + 5600, + 5601, + 5602, + 5603, + 5604, + 5605, + 5610, + 5611, + 5612, + 5613, + 5614, + 5621, + 5622, + 5625, + 5631, + 5632, + 5633, + 5634, + 5635, + 5636, + 5637, + 5638, + 5639, + 5640, + 5641, + 5643, + 5644, + 5645, + 5646, + 5649, + 5650, + 5652, + 5653, + 5654, + 5655, + 5656, + 5657, + 5658, + 5659, + 5660, + 5661, + 5662, + 5663, + 5664, + 5667, + 5670, + 5671, + 5672, + 5673, + 5674, + 5675, + 5682, + 5683, + 5684, + 5685, + 5686, + 5687, + 5688, + 5689, + 5690, + 5692, + 5695, + 5697, + 5698, + 5699, + 5700, + 5701, + 5707, + 5708, + 5709, + 5710, + 5711, + 5712, + 5713, + 5714, + 5715, + 5716, + 5717, + 5718, + 5719, + 5720, + 5721, + 5723, + 5724, + 5725, + 5726, + 5727, + 5728, + 5729, + 5730, + 5731, + 5732, + 5735, + 5736, + 5737, + 5738, + 5739, + 5740, + 5745, + 5746, + 5748, + 5749, + 5750, + 5751, + 5752, + 6056, + 6057, + 6066, + 6067, + 6158, + 6159, + 6160, + 6161, + 6162, + 6163, + 6164, + 6165, + 6166, + 6167, + 6168, + 6169, + 6170, + 6171, + 6172, + 6173, + 6174, + 6175, + 6176, + 6177, + 6178, + 6179, + 6180, + 6181, + 6182, + 6183, + 6184, + 6185, + 6186, + 6187, + 6188, + 6189, + 6190, + 6191, + 6192, + 6193, + 6194, + 6195, + 6196, + 6197, + 6198, + 6199, + 6200, + 6201, + 6202, + 6203, + 6204, + 6205, + 6206, + 6207, + 6208, + 6209, + 6210, + 6211, + 6212, + 6213, + 6214, + 6215, + 6216, + 6217, + 6218, + 6219, + 6220, + 6221, + 6222, + 6223, + 6224, + 6225, + 6226, + 6227, + 6228, + 6229, + 6230, + 6231, + 6232, + 6233, + 6234, + 6235, + 6236, + 6237, + 6238, + 6239 + ], + "rightUpLeg": [ + 4320, + 4321, + 4323, + 4324, + 4333, + 4334, + 4335, + 4336, + 4337, + 4338, + 4339, + 4340, + 4356, + 4357, + 4358, + 4359, + 4360, + 4361, + 4362, + 4363, + 4364, + 4365, + 4366, + 4367, + 4383, + 4384, + 4385, + 4386, + 4387, + 4388, + 4389, + 4390, + 4391, + 4392, + 4393, + 4394, + 4395, + 4396, + 4397, + 4398, + 4399, + 4400, + 4401, + 4419, + 4420, + 4421, + 4422, + 4430, + 4431, + 4432, 
+ 4433, + 4434, + 4435, + 4436, + 4437, + 4438, + 4439, + 4440, + 4441, + 4442, + 4443, + 4444, + 4445, + 4446, + 4447, + 4448, + 4449, + 4450, + 4451, + 4452, + 4453, + 4454, + 4455, + 4456, + 4457, + 4458, + 4459, + 4460, + 4461, + 4462, + 4463, + 4464, + 4465, + 4466, + 4467, + 4468, + 4469, + 4470, + 4471, + 4472, + 4473, + 4474, + 4475, + 4476, + 4477, + 4478, + 4479, + 4480, + 4481, + 4482, + 4483, + 4484, + 4485, + 4486, + 4487, + 4488, + 4489, + 4490, + 4491, + 4492, + 4493, + 4494, + 4495, + 4496, + 4497, + 4498, + 4499, + 4500, + 4501, + 4502, + 4503, + 4504, + 4505, + 4506, + 4507, + 4508, + 4509, + 4510, + 4511, + 4512, + 4513, + 4514, + 4515, + 4516, + 4517, + 4518, + 4519, + 4520, + 4521, + 4522, + 4523, + 4524, + 4525, + 4526, + 4527, + 4528, + 4529, + 4530, + 4531, + 4532, + 4623, + 4624, + 4625, + 4626, + 4627, + 4628, + 4629, + 4630, + 4631, + 4632, + 4633, + 4634, + 4645, + 4646, + 4647, + 4648, + 4649, + 4650, + 4651, + 4652, + 4653, + 4654, + 4655, + 4656, + 4657, + 4658, + 4659, + 4660, + 4670, + 4671, + 4672, + 4673, + 4704, + 4705, + 4706, + 4707, + 4708, + 4709, + 4710, + 4711, + 4712, + 4713, + 4745, + 4746, + 4757, + 4758, + 4759, + 4760, + 4801, + 4802, + 4829, + 4834, + 4835, + 4836, + 4837, + 4838, + 4839, + 4840, + 4841, + 4924, + 4925, + 4926, + 4928, + 4929, + 4930, + 4931, + 4932, + 4933, + 4934, + 4935, + 4936, + 4948, + 4949, + 4950, + 4951, + 4952, + 4970, + 4971, + 4972, + 4973, + 4983, + 4984, + 4985, + 4986, + 4987, + 4988, + 4989, + 4990, + 4991, + 4992, + 4993, + 5004, + 5005, + 6546, + 6547, + 6548, + 6549, + 6552, + 6553, + 6554, + 6555, + 6556, + 6873, + 6877 + ], + "leftArm": [ + 626, + 627, + 628, + 629, + 634, + 635, + 680, + 681, + 716, + 717, + 718, + 719, + 769, + 770, + 771, + 772, + 773, + 774, + 775, + 776, + 777, + 778, + 779, + 780, + 784, + 785, + 786, + 787, + 788, + 789, + 790, + 791, + 792, + 793, + 1231, + 1232, + 1233, + 1234, + 1258, + 1259, + 1260, + 1261, + 1271, + 1281, + 1282, + 1310, + 1311, + 1314, + 1315, + 1340, + 1341, + 1342, + 1343, + 1355, + 1356, + 1357, + 1358, + 1376, + 1377, + 1378, + 1379, + 1380, + 1381, + 1382, + 1383, + 1384, + 1385, + 1386, + 1387, + 1388, + 1389, + 1390, + 1391, + 1392, + 1393, + 1394, + 1395, + 1396, + 1397, + 1398, + 1399, + 1400, + 1402, + 1403, + 1405, + 1406, + 1407, + 1408, + 1409, + 1410, + 1411, + 1412, + 1413, + 1414, + 1415, + 1416, + 1428, + 1429, + 1430, + 1431, + 1432, + 1433, + 1438, + 1439, + 1440, + 1441, + 1442, + 1443, + 1444, + 1445, + 1502, + 1505, + 1506, + 1507, + 1508, + 1509, + 1510, + 1538, + 1541, + 1542, + 1543, + 1545, + 1619, + 1620, + 1621, + 1622, + 1631, + 1632, + 1633, + 1634, + 1635, + 1636, + 1637, + 1638, + 1639, + 1640, + 1641, + 1642, + 1645, + 1646, + 1647, + 1648, + 1649, + 1650, + 1651, + 1652, + 1653, + 1654, + 1655, + 1656, + 1658, + 1659, + 1661, + 1662, + 1664, + 1666, + 1667, + 1668, + 1669, + 1670, + 1671, + 1672, + 1673, + 1674, + 1675, + 1676, + 1677, + 1678, + 1679, + 1680, + 1681, + 1682, + 1683, + 1684, + 1696, + 1697, + 1698, + 1703, + 1704, + 1705, + 1706, + 1707, + 1708, + 1709, + 1710, + 1711, + 1712, + 1713, + 1714, + 1715, + 1716, + 1717, + 1718, + 1719, + 1720, + 1725, + 1731, + 1732, + 1733, + 1734, + 1735, + 1737, + 1739, + 1740, + 1745, + 1746, + 1747, + 1748, + 1749, + 1751, + 1761, + 1830, + 1831, + 1844, + 1845, + 1846, + 1850, + 1851, + 1854, + 1855, + 1858, + 1860, + 1865, + 1866, + 1867, + 1869, + 1870, + 1871, + 1874, + 1875, + 1876, + 1877, + 1878, + 1882, + 1883, + 1888, + 1889, + 1892, + 1900, + 1901, + 1902, + 1903, + 
1904, + 1909, + 2819, + 2820, + 2821, + 2822, + 2895, + 2896, + 2897, + 2898, + 2899, + 2900, + 2901, + 2902, + 2903, + 2945, + 2946, + 2974, + 2975, + 2976, + 2977, + 2978, + 2979, + 2980, + 2981, + 2982, + 2983, + 2984, + 2985, + 2986, + 2987, + 2988, + 2989, + 2990, + 2991, + 2992, + 2993, + 2994, + 2995, + 2996, + 3002, + 3013 + ], + "leftLeg": [ + 995, + 998, + 999, + 1002, + 1004, + 1005, + 1008, + 1010, + 1012, + 1015, + 1016, + 1018, + 1019, + 1043, + 1044, + 1047, + 1048, + 1049, + 1050, + 1051, + 1052, + 1053, + 1054, + 1055, + 1056, + 1057, + 1058, + 1059, + 1060, + 1061, + 1062, + 1063, + 1064, + 1065, + 1066, + 1067, + 1068, + 1069, + 1070, + 1071, + 1072, + 1073, + 1074, + 1075, + 1076, + 1077, + 1078, + 1079, + 1080, + 1081, + 1082, + 1083, + 1084, + 1085, + 1086, + 1087, + 1088, + 1089, + 1090, + 1091, + 1092, + 1093, + 1094, + 1095, + 1096, + 1097, + 1098, + 1099, + 1100, + 1101, + 1102, + 1103, + 1104, + 1105, + 1106, + 1107, + 1108, + 1109, + 1110, + 1111, + 1112, + 1113, + 1114, + 1115, + 1116, + 1117, + 1118, + 1119, + 1120, + 1121, + 1122, + 1123, + 1124, + 1125, + 1126, + 1127, + 1128, + 1129, + 1130, + 1131, + 1132, + 1133, + 1134, + 1135, + 1136, + 1148, + 1149, + 1150, + 1151, + 1152, + 1153, + 1154, + 1155, + 1156, + 1157, + 1158, + 1175, + 1176, + 1177, + 1178, + 1179, + 1180, + 1181, + 1182, + 1183, + 1369, + 1370, + 1371, + 1372, + 1373, + 1374, + 1375, + 1464, + 1465, + 1466, + 1467, + 1468, + 1469, + 1470, + 1471, + 1472, + 1473, + 1474, + 1522, + 1523, + 1524, + 1525, + 1526, + 1527, + 1528, + 1529, + 1530, + 1531, + 1532, + 3174, + 3175, + 3176, + 3177, + 3178, + 3179, + 3180, + 3181, + 3182, + 3183, + 3184, + 3185, + 3186, + 3187, + 3188, + 3189, + 3190, + 3191, + 3192, + 3193, + 3194, + 3195, + 3196, + 3197, + 3198, + 3199, + 3200, + 3201, + 3202, + 3203, + 3204, + 3205, + 3206, + 3207, + 3208, + 3209, + 3210, + 3319, + 3320, + 3321, + 3322, + 3323, + 3324, + 3325, + 3326, + 3327, + 3328, + 3329, + 3330, + 3331, + 3332, + 3333, + 3334, + 3335, + 3432, + 3433, + 3434, + 3435, + 3436, + 3469, + 3472, + 3473, + 3474 + ], + "leftToeBase": [ + 3211, + 3212, + 3213, + 3214, + 3215, + 3216, + 3217, + 3218, + 3219, + 3220, + 3221, + 3222, + 3223, + 3224, + 3225, + 3226, + 3227, + 3228, + 3229, + 3230, + 3231, + 3232, + 3233, + 3234, + 3235, + 3236, + 3237, + 3238, + 3239, + 3240, + 3241, + 3242, + 3243, + 3244, + 3245, + 3246, + 3247, + 3248, + 3249, + 3250, + 3251, + 3252, + 3253, + 3254, + 3255, + 3256, + 3257, + 3258, + 3259, + 3260, + 3261, + 3262, + 3263, + 3264, + 3265, + 3266, + 3267, + 3268, + 3269, + 3270, + 3271, + 3272, + 3273, + 3274, + 3275, + 3276, + 3277, + 3278, + 3279, + 3280, + 3281, + 3282, + 3283, + 3284, + 3285, + 3286, + 3287, + 3288, + 3289, + 3290, + 3291, + 3292, + 3293, + 3294, + 3295, + 3296, + 3297, + 3298, + 3299, + 3300, + 3301, + 3302, + 3303, + 3304, + 3305, + 3306, + 3307, + 3308, + 3309, + 3310, + 3311, + 3312, + 3313, + 3314, + 3315, + 3316, + 3317, + 3318, + 3336, + 3337, + 3340, + 3342, + 3344, + 3346, + 3348, + 3350, + 3352, + 3354, + 3357, + 3358, + 3360, + 3362 + ], + "leftFoot": [ + 3327, + 3328, + 3329, + 3330, + 3331, + 3332, + 3333, + 3334, + 3335, + 3336, + 3337, + 3338, + 3339, + 3340, + 3341, + 3342, + 3343, + 3344, + 3345, + 3346, + 3347, + 3348, + 3349, + 3350, + 3351, + 3352, + 3353, + 3354, + 3355, + 3356, + 3357, + 3358, + 3359, + 3360, + 3361, + 3362, + 3363, + 3364, + 3365, + 3366, + 3367, + 3368, + 3369, + 3370, + 3371, + 3372, + 3373, + 3374, + 3375, + 3376, + 3377, + 3378, + 3379, + 3380, + 3381, + 3382, + 
3383, + 3384, + 3385, + 3386, + 3387, + 3388, + 3389, + 3390, + 3391, + 3392, + 3393, + 3394, + 3395, + 3396, + 3397, + 3398, + 3399, + 3400, + 3401, + 3402, + 3403, + 3404, + 3405, + 3406, + 3407, + 3408, + 3409, + 3410, + 3411, + 3412, + 3413, + 3414, + 3415, + 3416, + 3417, + 3418, + 3419, + 3420, + 3421, + 3422, + 3423, + 3424, + 3425, + 3426, + 3427, + 3428, + 3429, + 3430, + 3431, + 3432, + 3433, + 3434, + 3435, + 3436, + 3437, + 3438, + 3439, + 3440, + 3441, + 3442, + 3443, + 3444, + 3445, + 3446, + 3447, + 3448, + 3449, + 3450, + 3451, + 3452, + 3453, + 3454, + 3455, + 3456, + 3457, + 3458, + 3459, + 3460, + 3461, + 3462, + 3463, + 3464, + 3465, + 3466, + 3467, + 3468, + 3469 + ], + "spine1": [ + 598, + 599, + 600, + 601, + 610, + 611, + 612, + 613, + 614, + 615, + 616, + 617, + 618, + 619, + 620, + 621, + 642, + 645, + 646, + 647, + 652, + 653, + 658, + 659, + 660, + 661, + 668, + 669, + 670, + 671, + 684, + 685, + 686, + 687, + 688, + 689, + 690, + 691, + 692, + 722, + 723, + 724, + 725, + 736, + 750, + 751, + 761, + 764, + 766, + 767, + 794, + 795, + 891, + 892, + 893, + 894, + 925, + 926, + 927, + 928, + 929, + 940, + 941, + 942, + 943, + 1190, + 1191, + 1192, + 1193, + 1194, + 1195, + 1196, + 1197, + 1200, + 1201, + 1202, + 1212, + 1236, + 1252, + 1253, + 1254, + 1255, + 1268, + 1269, + 1270, + 1329, + 1330, + 1348, + 1349, + 1351, + 1420, + 1421, + 1423, + 1424, + 1425, + 1426, + 1436, + 1437, + 1756, + 1757, + 1758, + 2839, + 2840, + 2841, + 2842, + 2843, + 2844, + 2845, + 2846, + 2847, + 2848, + 2849, + 2850, + 2851, + 2870, + 2871, + 2883, + 2906, + 2908, + 3014, + 3017, + 3025, + 3030, + 3033, + 3034, + 3037, + 3039, + 3040, + 3041, + 3042, + 3043, + 3044, + 3076, + 3077, + 3079, + 3480, + 3505, + 3511, + 4086, + 4087, + 4088, + 4089, + 4098, + 4099, + 4100, + 4101, + 4102, + 4103, + 4104, + 4105, + 4106, + 4107, + 4108, + 4109, + 4130, + 4131, + 4134, + 4135, + 4140, + 4141, + 4146, + 4147, + 4148, + 4149, + 4156, + 4157, + 4158, + 4159, + 4172, + 4173, + 4174, + 4175, + 4176, + 4177, + 4178, + 4179, + 4180, + 4210, + 4211, + 4212, + 4213, + 4225, + 4239, + 4240, + 4249, + 4250, + 4255, + 4256, + 4282, + 4283, + 4377, + 4378, + 4379, + 4380, + 4411, + 4412, + 4413, + 4414, + 4415, + 4426, + 4427, + 4428, + 4429, + 4676, + 4677, + 4678, + 4679, + 4680, + 4681, + 4682, + 4683, + 4686, + 4687, + 4688, + 4695, + 4719, + 4735, + 4736, + 4737, + 4740, + 4751, + 4752, + 4753, + 4824, + 4825, + 4828, + 4893, + 4894, + 4895, + 4897, + 4898, + 4899, + 4908, + 4909, + 5223, + 5224, + 5225, + 6300, + 6301, + 6302, + 6303, + 6304, + 6305, + 6306, + 6307, + 6308, + 6309, + 6310, + 6311, + 6312, + 6331, + 6332, + 6342, + 6366, + 6367, + 6475, + 6477, + 6478, + 6481, + 6482, + 6485, + 6487, + 6488, + 6489, + 6490, + 6491, + 6878 + ], + "spine2": [ + 570, + 571, + 572, + 573, + 584, + 585, + 586, + 587, + 588, + 589, + 590, + 591, + 592, + 593, + 594, + 595, + 596, + 597, + 602, + 603, + 604, + 605, + 606, + 607, + 608, + 609, + 622, + 623, + 624, + 625, + 638, + 639, + 640, + 641, + 643, + 644, + 648, + 649, + 650, + 651, + 666, + 667, + 672, + 673, + 674, + 675, + 680, + 681, + 682, + 683, + 693, + 694, + 695, + 696, + 697, + 698, + 699, + 700, + 701, + 702, + 703, + 704, + 713, + 714, + 715, + 716, + 717, + 726, + 727, + 728, + 729, + 730, + 731, + 732, + 733, + 735, + 737, + 738, + 739, + 740, + 741, + 742, + 743, + 744, + 745, + 746, + 747, + 748, + 749, + 752, + 753, + 754, + 755, + 756, + 757, + 758, + 759, + 760, + 762, + 763, + 803, + 804, + 805, + 806, + 811, + 812, + 813, + 
814, + 817, + 818, + 819, + 820, + 821, + 824, + 825, + 826, + 827, + 828, + 895, + 896, + 930, + 931, + 1198, + 1199, + 1213, + 1214, + 1215, + 1216, + 1217, + 1218, + 1219, + 1220, + 1235, + 1237, + 1256, + 1257, + 1271, + 1272, + 1273, + 1279, + 1280, + 1283, + 1284, + 1285, + 1286, + 1287, + 1288, + 1289, + 1290, + 1291, + 1292, + 1293, + 1294, + 1295, + 1296, + 1297, + 1298, + 1299, + 1300, + 1301, + 1302, + 1303, + 1304, + 1305, + 1306, + 1307, + 1308, + 1309, + 1312, + 1313, + 1319, + 1320, + 1346, + 1347, + 1350, + 1352, + 1401, + 1417, + 1418, + 1419, + 1422, + 1427, + 1434, + 1435, + 1503, + 1504, + 1536, + 1537, + 1544, + 1545, + 1753, + 1754, + 1755, + 1759, + 1760, + 1761, + 1762, + 1763, + 1808, + 1809, + 1810, + 1811, + 1816, + 1817, + 1818, + 1819, + 1820, + 1834, + 1835, + 1836, + 1837, + 1838, + 1839, + 1868, + 1879, + 1880, + 2812, + 2813, + 2852, + 2853, + 2854, + 2855, + 2856, + 2857, + 2858, + 2859, + 2860, + 2861, + 2862, + 2863, + 2864, + 2865, + 2866, + 2867, + 2868, + 2869, + 2872, + 2875, + 2876, + 2877, + 2878, + 2881, + 2882, + 2884, + 2885, + 2886, + 2904, + 2905, + 2907, + 2931, + 2932, + 2933, + 2934, + 2935, + 2936, + 2937, + 2941, + 2950, + 2951, + 2952, + 2953, + 2954, + 2955, + 2956, + 2957, + 2958, + 2959, + 2960, + 2961, + 2962, + 2963, + 2964, + 2965, + 2966, + 2967, + 2968, + 2969, + 2970, + 2971, + 2972, + 2973, + 2997, + 2998, + 3006, + 3007, + 3012, + 3015, + 3026, + 3027, + 3028, + 3029, + 3031, + 3032, + 3035, + 3036, + 3038, + 3059, + 3060, + 3061, + 3062, + 3063, + 3064, + 3065, + 3066, + 3067, + 3073, + 3074, + 3075, + 3078, + 3168, + 3169, + 3171, + 3470, + 3471, + 3482, + 3483, + 3495, + 3496, + 3497, + 3498, + 3506, + 3508, + 4058, + 4059, + 4060, + 4061, + 4072, + 4073, + 4074, + 4075, + 4076, + 4077, + 4078, + 4079, + 4080, + 4081, + 4082, + 4083, + 4084, + 4085, + 4090, + 4091, + 4092, + 4093, + 4094, + 4095, + 4096, + 4097, + 4110, + 4111, + 4112, + 4113, + 4126, + 4127, + 4128, + 4129, + 4132, + 4133, + 4136, + 4137, + 4138, + 4139, + 4154, + 4155, + 4160, + 4161, + 4162, + 4163, + 4168, + 4169, + 4170, + 4171, + 4181, + 4182, + 4183, + 4184, + 4185, + 4186, + 4187, + 4188, + 4189, + 4190, + 4191, + 4192, + 4201, + 4202, + 4203, + 4204, + 4207, + 4214, + 4215, + 4216, + 4217, + 4218, + 4219, + 4220, + 4221, + 4223, + 4224, + 4226, + 4227, + 4228, + 4229, + 4230, + 4231, + 4232, + 4233, + 4234, + 4235, + 4236, + 4237, + 4238, + 4241, + 4242, + 4243, + 4244, + 4245, + 4246, + 4247, + 4248, + 4251, + 4252, + 4291, + 4292, + 4293, + 4294, + 4299, + 4300, + 4301, + 4302, + 4305, + 4306, + 4307, + 4308, + 4309, + 4312, + 4313, + 4314, + 4315, + 4381, + 4382, + 4416, + 4417, + 4684, + 4685, + 4696, + 4697, + 4698, + 4699, + 4700, + 4701, + 4702, + 4703, + 4718, + 4720, + 4738, + 4739, + 4754, + 4755, + 4756, + 4761, + 4762, + 4765, + 4766, + 4767, + 4768, + 4769, + 4770, + 4771, + 4772, + 4773, + 4774, + 4775, + 4776, + 4777, + 4778, + 4779, + 4780, + 4781, + 4782, + 4783, + 4784, + 4785, + 4786, + 4787, + 4788, + 4789, + 4792, + 4793, + 4799, + 4800, + 4822, + 4823, + 4826, + 4827, + 4874, + 4890, + 4891, + 4892, + 4896, + 4900, + 4907, + 4910, + 4975, + 4976, + 5007, + 5008, + 5013, + 5014, + 5222, + 5226, + 5227, + 5228, + 5229, + 5230, + 5269, + 5270, + 5271, + 5272, + 5277, + 5278, + 5279, + 5280, + 5281, + 5295, + 5296, + 5297, + 5298, + 5299, + 5300, + 5329, + 5340, + 5341, + 6273, + 6274, + 6313, + 6314, + 6315, + 6316, + 6317, + 6318, + 6319, + 6320, + 6321, + 6322, + 6323, + 6324, + 6325, + 6326, + 6327, + 6328, + 6329, + 6330, + 
6333, + 6336, + 6337, + 6340, + 6341, + 6343, + 6344, + 6345, + 6363, + 6364, + 6365, + 6390, + 6391, + 6392, + 6393, + 6394, + 6395, + 6396, + 6398, + 6409, + 6410, + 6411, + 6412, + 6413, + 6414, + 6415, + 6416, + 6417, + 6418, + 6419, + 6420, + 6421, + 6422, + 6423, + 6424, + 6425, + 6426, + 6427, + 6428, + 6429, + 6430, + 6431, + 6432, + 6456, + 6457, + 6465, + 6466, + 6476, + 6479, + 6480, + 6483, + 6484, + 6486, + 6496, + 6497, + 6498, + 6499, + 6500, + 6501, + 6502, + 6503, + 6879 + ], + "leftShoulder": [ + 591, + 604, + 605, + 606, + 609, + 634, + 635, + 636, + 637, + 674, + 706, + 707, + 708, + 709, + 710, + 711, + 712, + 713, + 715, + 717, + 730, + 733, + 734, + 735, + 781, + 782, + 783, + 1238, + 1239, + 1240, + 1241, + 1242, + 1243, + 1244, + 1245, + 1290, + 1291, + 1294, + 1316, + 1317, + 1318, + 1401, + 1402, + 1403, + 1404, + 1509, + 1535, + 1545, + 1808, + 1810, + 1811, + 1812, + 1813, + 1814, + 1815, + 1818, + 1819, + 1821, + 1822, + 1823, + 1824, + 1825, + 1826, + 1827, + 1828, + 1829, + 1830, + 1831, + 1832, + 1833, + 1837, + 1840, + 1841, + 1842, + 1843, + 1844, + 1845, + 1846, + 1847, + 1848, + 1849, + 1850, + 1851, + 1852, + 1853, + 1854, + 1855, + 1856, + 1857, + 1858, + 1859, + 1861, + 1862, + 1863, + 1864, + 1872, + 1873, + 1880, + 1881, + 1884, + 1885, + 1886, + 1887, + 1890, + 1891, + 1893, + 1894, + 1895, + 1896, + 1897, + 1898, + 1899, + 2879, + 2880, + 2881, + 2886, + 2887, + 2888, + 2889, + 2890, + 2891, + 2892, + 2893, + 2894, + 2903, + 2938, + 2939, + 2940, + 2941, + 2942, + 2943, + 2944, + 2945, + 2946, + 2947, + 2948, + 2949, + 2965, + 2967, + 2969, + 2999, + 3000, + 3001, + 3002, + 3003, + 3004, + 3005, + 3008, + 3009, + 3010, + 3011 + ], + "rightShoulder": [ + 4077, + 4091, + 4092, + 4094, + 4095, + 4122, + 4123, + 4124, + 4125, + 4162, + 4194, + 4195, + 4196, + 4197, + 4198, + 4199, + 4200, + 4201, + 4203, + 4207, + 4218, + 4219, + 4222, + 4223, + 4269, + 4270, + 4271, + 4721, + 4722, + 4723, + 4724, + 4725, + 4726, + 4727, + 4728, + 4773, + 4774, + 4778, + 4796, + 4797, + 4798, + 4874, + 4875, + 4876, + 4877, + 4982, + 5006, + 5014, + 5269, + 5271, + 5272, + 5273, + 5274, + 5275, + 5276, + 5279, + 5281, + 5282, + 5283, + 5284, + 5285, + 5286, + 5287, + 5288, + 5289, + 5290, + 5291, + 5292, + 5293, + 5294, + 5298, + 5301, + 5302, + 5303, + 5304, + 5305, + 5306, + 5307, + 5308, + 5309, + 5310, + 5311, + 5312, + 5313, + 5314, + 5315, + 5316, + 5317, + 5318, + 5319, + 5320, + 5322, + 5323, + 5324, + 5325, + 5333, + 5334, + 5341, + 5342, + 5345, + 5346, + 5347, + 5348, + 5351, + 5352, + 5354, + 5355, + 5356, + 5357, + 5358, + 5359, + 5360, + 6338, + 6339, + 6340, + 6345, + 6346, + 6347, + 6348, + 6349, + 6350, + 6351, + 6352, + 6353, + 6362, + 6397, + 6398, + 6399, + 6400, + 6401, + 6402, + 6403, + 6404, + 6405, + 6406, + 6407, + 6408, + 6424, + 6425, + 6428, + 6458, + 6459, + 6460, + 6461, + 6462, + 6463, + 6464, + 6467, + 6468, + 6469, + 6470 + ], + "rightFoot": [ + 6727, + 6728, + 6729, + 6730, + 6731, + 6732, + 6733, + 6734, + 6735, + 6736, + 6737, + 6738, + 6739, + 6740, + 6741, + 6742, + 6743, + 6744, + 6745, + 6746, + 6747, + 6748, + 6749, + 6750, + 6751, + 6752, + 6753, + 6754, + 6755, + 6756, + 6757, + 6758, + 6759, + 6760, + 6761, + 6762, + 6763, + 6764, + 6765, + 6766, + 6767, + 6768, + 6769, + 6770, + 6771, + 6772, + 6773, + 6774, + 6775, + 6776, + 6777, + 6778, + 6779, + 6780, + 6781, + 6782, + 6783, + 6784, + 6785, + 6786, + 6787, + 6788, + 6789, + 6790, + 6791, + 6792, + 6793, + 6794, + 6795, + 6796, + 6797, + 6798, + 6799, + 6800, + 6801, + 
6802, + 6803, + 6804, + 6805, + 6806, + 6807, + 6808, + 6809, + 6810, + 6811, + 6812, + 6813, + 6814, + 6815, + 6816, + 6817, + 6818, + 6819, + 6820, + 6821, + 6822, + 6823, + 6824, + 6825, + 6826, + 6827, + 6828, + 6829, + 6830, + 6831, + 6832, + 6833, + 6834, + 6835, + 6836, + 6837, + 6838, + 6839, + 6840, + 6841, + 6842, + 6843, + 6844, + 6845, + 6846, + 6847, + 6848, + 6849, + 6850, + 6851, + 6852, + 6853, + 6854, + 6855, + 6856, + 6857, + 6858, + 6859, + 6860, + 6861, + 6862, + 6863, + 6864, + 6865, + 6866, + 6867, + 6868, + 6869 + ], + "head": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 220, + 221, + 225, + 226, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 235, + 236, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 246, + 247, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 286, + 287, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 303, + 304, + 306, + 307, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 329, + 330, + 331, + 332, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 383, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 395, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 427, + 428, + 429, + 430, + 431, + 432, + 433, + 434, + 435, + 436, + 437, + 438, + 439, + 442, + 443, + 444, + 445, + 446, + 447, + 448, + 449, + 450, + 454, + 455, + 456, + 457, + 458, + 459, + 461, + 462, + 463, + 464, + 465, + 466, + 467, + 468, + 469, + 470, + 471, + 472, + 473, + 474, + 475, + 476, + 477, + 478, + 479, + 480, + 481, + 482, + 483, + 484, + 485, + 486, + 487, + 488, + 489, + 490, + 491, + 
492, + 493, + 494, + 495, + 496, + 497, + 498, + 499, + 500, + 501, + 502, + 503, + 504, + 505, + 506, + 507, + 508, + 509, + 510, + 511, + 512, + 513, + 514, + 515, + 516, + 517, + 518, + 519, + 520, + 521, + 522, + 523, + 524, + 525, + 526, + 527, + 528, + 529, + 530, + 531, + 532, + 533, + 534, + 535, + 536, + 537, + 538, + 539, + 540, + 541, + 542, + 543, + 544, + 545, + 546, + 547, + 548, + 549, + 550, + 551, + 552, + 553, + 554, + 555, + 556, + 557, + 558, + 559, + 560, + 561, + 562, + 563, + 564, + 565, + 566, + 567, + 568, + 569, + 574, + 575, + 576, + 577, + 578, + 579, + 580, + 581, + 582, + 583, + 1764, + 1765, + 1766, + 1770, + 1771, + 1772, + 1773, + 1774, + 1775, + 1776, + 1777, + 1778, + 1905, + 1906, + 1907, + 1908, + 2779, + 2780, + 2781, + 2782, + 2783, + 2784, + 2785, + 2786, + 2787, + 2788, + 2789, + 2790, + 2791, + 2792, + 2793, + 2794, + 2795, + 2796, + 2797, + 2798, + 2799, + 2800, + 2801, + 2802, + 2803, + 2804, + 2805, + 2806, + 2807, + 2808, + 2809, + 2810, + 2811, + 2814, + 2815, + 2816, + 2817, + 2818, + 3045, + 3046, + 3047, + 3048, + 3051, + 3052, + 3053, + 3054, + 3055, + 3056, + 3058, + 3069, + 3070, + 3071, + 3072, + 3161, + 3162, + 3163, + 3165, + 3166, + 3167, + 3485, + 3486, + 3487, + 3488, + 3489, + 3490, + 3491, + 3492, + 3493, + 3494, + 3499, + 3512, + 3513, + 3514, + 3515, + 3516, + 3517, + 3518, + 3519, + 3520, + 3521, + 3522, + 3523, + 3524, + 3525, + 3526, + 3527, + 3528, + 3529, + 3530, + 3531, + 3532, + 3533, + 3534, + 3535, + 3536, + 3537, + 3538, + 3539, + 3540, + 3541, + 3542, + 3543, + 3544, + 3545, + 3546, + 3547, + 3548, + 3549, + 3550, + 3551, + 3552, + 3553, + 3554, + 3555, + 3556, + 3557, + 3558, + 3559, + 3560, + 3561, + 3562, + 3563, + 3564, + 3565, + 3566, + 3567, + 3568, + 3569, + 3570, + 3571, + 3572, + 3573, + 3574, + 3575, + 3576, + 3577, + 3578, + 3579, + 3580, + 3581, + 3582, + 3583, + 3584, + 3585, + 3586, + 3587, + 3588, + 3589, + 3590, + 3591, + 3592, + 3593, + 3594, + 3595, + 3596, + 3597, + 3598, + 3599, + 3600, + 3601, + 3602, + 3603, + 3604, + 3605, + 3606, + 3607, + 3608, + 3609, + 3610, + 3611, + 3612, + 3613, + 3614, + 3615, + 3616, + 3617, + 3618, + 3619, + 3620, + 3621, + 3622, + 3623, + 3624, + 3625, + 3626, + 3627, + 3628, + 3629, + 3630, + 3631, + 3632, + 3633, + 3634, + 3635, + 3636, + 3637, + 3638, + 3639, + 3640, + 3641, + 3642, + 3643, + 3644, + 3645, + 3646, + 3647, + 3648, + 3649, + 3650, + 3651, + 3652, + 3653, + 3654, + 3655, + 3656, + 3657, + 3658, + 3659, + 3660, + 3661, + 3666, + 3667, + 3668, + 3669, + 3670, + 3671, + 3672, + 3673, + 3674, + 3675, + 3676, + 3677, + 3678, + 3679, + 3680, + 3681, + 3682, + 3683, + 3684, + 3685, + 3688, + 3689, + 3690, + 3691, + 3692, + 3693, + 3694, + 3695, + 3696, + 3697, + 3698, + 3699, + 3700, + 3701, + 3702, + 3703, + 3704, + 3705, + 3706, + 3707, + 3708, + 3709, + 3710, + 3711, + 3712, + 3713, + 3714, + 3715, + 3716, + 3717, + 3732, + 3733, + 3737, + 3738, + 3739, + 3740, + 3741, + 3742, + 3743, + 3744, + 3745, + 3746, + 3747, + 3748, + 3749, + 3750, + 3751, + 3752, + 3753, + 3754, + 3755, + 3756, + 3757, + 3758, + 3759, + 3760, + 3761, + 3762, + 3763, + 3764, + 3765, + 3766, + 3767, + 3770, + 3771, + 3772, + 3773, + 3774, + 3775, + 3776, + 3777, + 3778, + 3779, + 3780, + 3781, + 3782, + 3783, + 3784, + 3785, + 3786, + 3787, + 3788, + 3789, + 3790, + 3791, + 3792, + 3793, + 3794, + 3795, + 3798, + 3799, + 3800, + 3801, + 3802, + 3803, + 3804, + 3805, + 3806, + 3807, + 3815, + 3816, + 3819, + 3820, + 3821, + 3822, + 3823, + 3824, + 3825, + 3826, + 3827, + 3828, + 
3829, + 3830, + 3831, + 3832, + 3833, + 3834, + 3835, + 3836, + 3837, + 3838, + 3841, + 3842, + 3843, + 3844, + 3845, + 3846, + 3847, + 3848, + 3849, + 3850, + 3851, + 3852, + 3853, + 3854, + 3855, + 3856, + 3857, + 3858, + 3859, + 3860, + 3861, + 3862, + 3863, + 3864, + 3865, + 3866, + 3867, + 3868, + 3869, + 3870, + 3871, + 3872, + 3873, + 3874, + 3875, + 3876, + 3877, + 3878, + 3879, + 3880, + 3881, + 3882, + 3883, + 3884, + 3885, + 3886, + 3887, + 3888, + 3889, + 3890, + 3891, + 3892, + 3893, + 3894, + 3895, + 3896, + 3897, + 3898, + 3899, + 3900, + 3901, + 3902, + 3903, + 3904, + 3905, + 3906, + 3907, + 3908, + 3909, + 3910, + 3911, + 3912, + 3913, + 3914, + 3915, + 3916, + 3917, + 3922, + 3923, + 3924, + 3925, + 3926, + 3927, + 3928, + 3929, + 3930, + 3931, + 3932, + 3933, + 3936, + 3937, + 3938, + 3939, + 3940, + 3941, + 3945, + 3946, + 3947, + 3948, + 3949, + 3950, + 3951, + 3952, + 3953, + 3954, + 3955, + 3956, + 3957, + 3958, + 3959, + 3960, + 3961, + 3962, + 3963, + 3964, + 3965, + 3966, + 3967, + 3968, + 3969, + 3970, + 3971, + 3972, + 3973, + 3974, + 3975, + 3976, + 3977, + 3978, + 3979, + 3980, + 3981, + 3982, + 3983, + 3984, + 3985, + 3986, + 3987, + 3988, + 3989, + 3990, + 3991, + 3992, + 3993, + 3994, + 3995, + 3996, + 3997, + 3998, + 3999, + 4000, + 4001, + 4002, + 4003, + 4004, + 4005, + 4006, + 4007, + 4008, + 4009, + 4010, + 4011, + 4012, + 4013, + 4014, + 4015, + 4016, + 4017, + 4018, + 4019, + 4020, + 4021, + 4022, + 4023, + 4024, + 4025, + 4026, + 4027, + 4028, + 4029, + 4030, + 4031, + 4032, + 4033, + 4034, + 4035, + 4036, + 4037, + 4038, + 4039, + 4040, + 4041, + 4042, + 4043, + 4044, + 4045, + 4046, + 4047, + 4048, + 4049, + 4050, + 4051, + 4052, + 4053, + 4054, + 4055, + 4056, + 4057, + 4062, + 4063, + 4064, + 4065, + 4066, + 4067, + 4068, + 4069, + 4070, + 4071, + 5231, + 5232, + 5233, + 5235, + 5236, + 5237, + 5238, + 5239, + 5240, + 5241, + 5242, + 5243, + 5366, + 5367, + 5368, + 5369, + 6240, + 6241, + 6242, + 6243, + 6244, + 6245, + 6246, + 6247, + 6248, + 6249, + 6250, + 6251, + 6252, + 6253, + 6254, + 6255, + 6256, + 6257, + 6258, + 6259, + 6260, + 6261, + 6262, + 6263, + 6264, + 6265, + 6266, + 6267, + 6268, + 6269, + 6270, + 6271, + 6272, + 6275, + 6276, + 6277, + 6278, + 6279, + 6492, + 6493, + 6494, + 6495, + 6880, + 6881, + 6882, + 6883, + 6884, + 6885, + 6886, + 6887, + 6888, + 6889 + ], + "rightArm": [ + 4114, + 4115, + 4116, + 4117, + 4122, + 4125, + 4168, + 4171, + 4204, + 4205, + 4206, + 4207, + 4257, + 4258, + 4259, + 4260, + 4261, + 4262, + 4263, + 4264, + 4265, + 4266, + 4267, + 4268, + 4272, + 4273, + 4274, + 4275, + 4276, + 4277, + 4278, + 4279, + 4280, + 4281, + 4714, + 4715, + 4716, + 4717, + 4741, + 4742, + 4743, + 4744, + 4756, + 4763, + 4764, + 4790, + 4791, + 4794, + 4795, + 4816, + 4817, + 4818, + 4819, + 4830, + 4831, + 4832, + 4833, + 4849, + 4850, + 4851, + 4852, + 4853, + 4854, + 4855, + 4856, + 4857, + 4858, + 4859, + 4860, + 4861, + 4862, + 4863, + 4864, + 4865, + 4866, + 4867, + 4868, + 4869, + 4870, + 4871, + 4872, + 4873, + 4876, + 4877, + 4878, + 4879, + 4880, + 4881, + 4882, + 4883, + 4884, + 4885, + 4886, + 4887, + 4888, + 4889, + 4901, + 4902, + 4903, + 4904, + 4905, + 4906, + 4911, + 4912, + 4913, + 4914, + 4915, + 4916, + 4917, + 4918, + 4974, + 4977, + 4978, + 4979, + 4980, + 4981, + 4982, + 5009, + 5010, + 5011, + 5012, + 5014, + 5088, + 5089, + 5090, + 5091, + 5100, + 5101, + 5102, + 5103, + 5104, + 5105, + 5106, + 5107, + 5108, + 5109, + 5110, + 5111, + 5114, + 5115, + 5116, + 5117, + 5118, + 5119, + 5120, + 5121, 
+ 5122, + 5123, + 5124, + 5125, + 5128, + 5129, + 5130, + 5131, + 5134, + 5135, + 5136, + 5137, + 5138, + 5139, + 5140, + 5141, + 5142, + 5143, + 5144, + 5145, + 5146, + 5147, + 5148, + 5149, + 5150, + 5151, + 5152, + 5153, + 5165, + 5166, + 5167, + 5172, + 5173, + 5174, + 5175, + 5176, + 5177, + 5178, + 5179, + 5180, + 5181, + 5182, + 5183, + 5184, + 5185, + 5186, + 5187, + 5188, + 5189, + 5194, + 5200, + 5201, + 5202, + 5203, + 5204, + 5206, + 5208, + 5209, + 5214, + 5215, + 5216, + 5217, + 5218, + 5220, + 5229, + 5292, + 5293, + 5303, + 5306, + 5309, + 5311, + 5314, + 5315, + 5318, + 5319, + 5321, + 5326, + 5327, + 5328, + 5330, + 5331, + 5332, + 5335, + 5336, + 5337, + 5338, + 5339, + 5343, + 5344, + 5349, + 5350, + 5353, + 5361, + 5362, + 5363, + 5364, + 5365, + 5370, + 6280, + 6281, + 6282, + 6283, + 6354, + 6355, + 6356, + 6357, + 6358, + 6359, + 6360, + 6361, + 6362, + 6404, + 6405, + 6433, + 6434, + 6435, + 6436, + 6437, + 6438, + 6439, + 6440, + 6441, + 6442, + 6443, + 6444, + 6445, + 6446, + 6447, + 6448, + 6449, + 6450, + 6451, + 6452, + 6453, + 6454, + 6455, + 6461, + 6471 + ], + "leftHandIndex1": [ + 2027, + 2028, + 2029, + 2030, + 2037, + 2038, + 2039, + 2040, + 2057, + 2067, + 2068, + 2123, + 2124, + 2125, + 2126, + 2127, + 2128, + 2129, + 2130, + 2132, + 2145, + 2146, + 2152, + 2153, + 2154, + 2156, + 2157, + 2158, + 2159, + 2160, + 2161, + 2162, + 2163, + 2164, + 2165, + 2166, + 2167, + 2168, + 2169, + 2177, + 2178, + 2179, + 2181, + 2186, + 2187, + 2190, + 2191, + 2204, + 2205, + 2215, + 2216, + 2217, + 2218, + 2219, + 2220, + 2232, + 2233, + 2245, + 2246, + 2247, + 2258, + 2259, + 2261, + 2262, + 2263, + 2269, + 2270, + 2272, + 2273, + 2274, + 2276, + 2277, + 2280, + 2281, + 2282, + 2283, + 2291, + 2292, + 2293, + 2294, + 2295, + 2296, + 2297, + 2298, + 2299, + 2300, + 2301, + 2302, + 2303, + 2304, + 2305, + 2306, + 2307, + 2308, + 2309, + 2310, + 2311, + 2312, + 2313, + 2314, + 2315, + 2316, + 2317, + 2318, + 2319, + 2320, + 2321, + 2322, + 2323, + 2324, + 2325, + 2326, + 2327, + 2328, + 2329, + 2330, + 2331, + 2332, + 2333, + 2334, + 2335, + 2336, + 2337, + 2338, + 2339, + 2340, + 2341, + 2342, + 2343, + 2344, + 2345, + 2346, + 2347, + 2348, + 2349, + 2350, + 2351, + 2352, + 2353, + 2354, + 2355, + 2356, + 2357, + 2358, + 2359, + 2360, + 2361, + 2362, + 2363, + 2364, + 2365, + 2366, + 2367, + 2368, + 2369, + 2370, + 2371, + 2372, + 2373, + 2374, + 2375, + 2376, + 2377, + 2378, + 2379, + 2380, + 2381, + 2382, + 2383, + 2384, + 2385, + 2386, + 2387, + 2388, + 2389, + 2390, + 2391, + 2392, + 2393, + 2394, + 2395, + 2396, + 2397, + 2398, + 2399, + 2400, + 2401, + 2402, + 2403, + 2404, + 2405, + 2406, + 2407, + 2408, + 2409, + 2410, + 2411, + 2412, + 2413, + 2414, + 2415, + 2416, + 2417, + 2418, + 2419, + 2420, + 2421, + 2422, + 2423, + 2424, + 2425, + 2426, + 2427, + 2428, + 2429, + 2430, + 2431, + 2432, + 2433, + 2434, + 2435, + 2436, + 2437, + 2438, + 2439, + 2440, + 2441, + 2442, + 2443, + 2444, + 2445, + 2446, + 2447, + 2448, + 2449, + 2450, + 2451, + 2452, + 2453, + 2454, + 2455, + 2456, + 2457, + 2458, + 2459, + 2460, + 2461, + 2462, + 2463, + 2464, + 2465, + 2466, + 2467, + 2468, + 2469, + 2470, + 2471, + 2472, + 2473, + 2474, + 2475, + 2476, + 2477, + 2478, + 2479, + 2480, + 2481, + 2482, + 2483, + 2484, + 2485, + 2486, + 2487, + 2488, + 2489, + 2490, + 2491, + 2492, + 2493, + 2494, + 2495, + 2496, + 2497, + 2498, + 2499, + 2500, + 2501, + 2502, + 2503, + 2504, + 2505, + 2506, + 2507, + 2508, + 2509, + 2510, + 2511, + 2512, + 2513, + 2514, + 2515, + 2516, + 2517, 
+ 2518, + 2519, + 2520, + 2521, + 2522, + 2523, + 2524, + 2525, + 2526, + 2527, + 2528, + 2529, + 2530, + 2531, + 2532, + 2533, + 2534, + 2535, + 2536, + 2537, + 2538, + 2539, + 2540, + 2541, + 2542, + 2543, + 2544, + 2545, + 2546, + 2547, + 2548, + 2549, + 2550, + 2551, + 2552, + 2553, + 2554, + 2555, + 2556, + 2557, + 2558, + 2559, + 2560, + 2561, + 2562, + 2563, + 2564, + 2565, + 2566, + 2567, + 2568, + 2569, + 2570, + 2571, + 2572, + 2573, + 2574, + 2575, + 2576, + 2577, + 2578, + 2579, + 2580, + 2581, + 2582, + 2583, + 2584, + 2585, + 2586, + 2587, + 2588, + 2589, + 2590, + 2591, + 2592, + 2593, + 2594, + 2596, + 2597, + 2599, + 2600, + 2601, + 2602, + 2603, + 2604, + 2606, + 2607, + 2609, + 2610, + 2611, + 2612, + 2613, + 2614, + 2615, + 2616, + 2617, + 2618, + 2619, + 2620, + 2621, + 2622, + 2623, + 2624, + 2625, + 2626, + 2627, + 2628, + 2629, + 2630, + 2631, + 2632, + 2633, + 2634, + 2635, + 2636, + 2637, + 2638, + 2639, + 2640, + 2641, + 2642, + 2643, + 2644, + 2645, + 2646, + 2647, + 2648, + 2649, + 2650, + 2651, + 2652, + 2653, + 2654, + 2655, + 2656, + 2657, + 2658, + 2659, + 2660, + 2661, + 2662, + 2663, + 2664, + 2665, + 2666, + 2667, + 2668, + 2669, + 2670, + 2671, + 2672, + 2673, + 2674, + 2675, + 2676, + 2677, + 2678, + 2679, + 2680, + 2681, + 2682, + 2683, + 2684, + 2685, + 2686, + 2687, + 2688, + 2689, + 2690, + 2691, + 2692, + 2693, + 2694, + 2695, + 2696 + ], + "rightLeg": [ + 4481, + 4482, + 4485, + 4486, + 4491, + 4492, + 4493, + 4495, + 4498, + 4500, + 4501, + 4505, + 4506, + 4529, + 4532, + 4533, + 4534, + 4535, + 4536, + 4537, + 4538, + 4539, + 4540, + 4541, + 4542, + 4543, + 4544, + 4545, + 4546, + 4547, + 4548, + 4549, + 4550, + 4551, + 4552, + 4553, + 4554, + 4555, + 4556, + 4557, + 4558, + 4559, + 4560, + 4561, + 4562, + 4563, + 4564, + 4565, + 4566, + 4567, + 4568, + 4569, + 4570, + 4571, + 4572, + 4573, + 4574, + 4575, + 4576, + 4577, + 4578, + 4579, + 4580, + 4581, + 4582, + 4583, + 4584, + 4585, + 4586, + 4587, + 4588, + 4589, + 4590, + 4591, + 4592, + 4593, + 4594, + 4595, + 4596, + 4597, + 4598, + 4599, + 4600, + 4601, + 4602, + 4603, + 4604, + 4605, + 4606, + 4607, + 4608, + 4609, + 4610, + 4611, + 4612, + 4613, + 4614, + 4615, + 4616, + 4617, + 4618, + 4619, + 4620, + 4621, + 4622, + 4634, + 4635, + 4636, + 4637, + 4638, + 4639, + 4640, + 4641, + 4642, + 4643, + 4644, + 4661, + 4662, + 4663, + 4664, + 4665, + 4666, + 4667, + 4668, + 4669, + 4842, + 4843, + 4844, + 4845, + 4846, + 4847, + 4848, + 4937, + 4938, + 4939, + 4940, + 4941, + 4942, + 4943, + 4944, + 4945, + 4946, + 4947, + 4993, + 4994, + 4995, + 4996, + 4997, + 4998, + 4999, + 5000, + 5001, + 5002, + 5003, + 6574, + 6575, + 6576, + 6577, + 6578, + 6579, + 6580, + 6581, + 6582, + 6583, + 6584, + 6585, + 6586, + 6587, + 6588, + 6589, + 6590, + 6591, + 6592, + 6593, + 6594, + 6595, + 6596, + 6597, + 6598, + 6599, + 6600, + 6601, + 6602, + 6603, + 6604, + 6605, + 6606, + 6607, + 6608, + 6609, + 6610, + 6719, + 6720, + 6721, + 6722, + 6723, + 6724, + 6725, + 6726, + 6727, + 6728, + 6729, + 6730, + 6731, + 6732, + 6733, + 6734, + 6735, + 6832, + 6833, + 6834, + 6835, + 6836, + 6869, + 6870, + 6871, + 6872 + ], + "rightHandIndex1": [ + 5488, + 5489, + 5490, + 5491, + 5498, + 5499, + 5500, + 5501, + 5518, + 5528, + 5529, + 5584, + 5585, + 5586, + 5587, + 5588, + 5589, + 5590, + 5591, + 5592, + 5606, + 5607, + 5613, + 5615, + 5616, + 5617, + 5618, + 5619, + 5620, + 5621, + 5622, + 5623, + 5624, + 5625, + 5626, + 5627, + 5628, + 5629, + 5630, + 5638, + 5639, + 5640, + 5642, + 5647, + 5648, + 5650, + 
5651, + 5665, + 5666, + 5676, + 5677, + 5678, + 5679, + 5680, + 5681, + 5693, + 5694, + 5706, + 5707, + 5708, + 5719, + 5721, + 5722, + 5723, + 5724, + 5730, + 5731, + 5733, + 5734, + 5735, + 5737, + 5738, + 5741, + 5742, + 5743, + 5744, + 5752, + 5753, + 5754, + 5755, + 5756, + 5757, + 5758, + 5759, + 5760, + 5761, + 5762, + 5763, + 5764, + 5765, + 5766, + 5767, + 5768, + 5769, + 5770, + 5771, + 5772, + 5773, + 5774, + 5775, + 5776, + 5777, + 5778, + 5779, + 5780, + 5781, + 5782, + 5783, + 5784, + 5785, + 5786, + 5787, + 5788, + 5789, + 5790, + 5791, + 5792, + 5793, + 5794, + 5795, + 5796, + 5797, + 5798, + 5799, + 5800, + 5801, + 5802, + 5803, + 5804, + 5805, + 5806, + 5807, + 5808, + 5809, + 5810, + 5811, + 5812, + 5813, + 5814, + 5815, + 5816, + 5817, + 5818, + 5819, + 5820, + 5821, + 5822, + 5823, + 5824, + 5825, + 5826, + 5827, + 5828, + 5829, + 5830, + 5831, + 5832, + 5833, + 5834, + 5835, + 5836, + 5837, + 5838, + 5839, + 5840, + 5841, + 5842, + 5843, + 5844, + 5845, + 5846, + 5847, + 5848, + 5849, + 5850, + 5851, + 5852, + 5853, + 5854, + 5855, + 5856, + 5857, + 5858, + 5859, + 5860, + 5861, + 5862, + 5863, + 5864, + 5865, + 5866, + 5867, + 5868, + 5869, + 5870, + 5871, + 5872, + 5873, + 5874, + 5875, + 5876, + 5877, + 5878, + 5879, + 5880, + 5881, + 5882, + 5883, + 5884, + 5885, + 5886, + 5887, + 5888, + 5889, + 5890, + 5891, + 5892, + 5893, + 5894, + 5895, + 5896, + 5897, + 5898, + 5899, + 5900, + 5901, + 5902, + 5903, + 5904, + 5905, + 5906, + 5907, + 5908, + 5909, + 5910, + 5911, + 5912, + 5913, + 5914, + 5915, + 5916, + 5917, + 5918, + 5919, + 5920, + 5921, + 5922, + 5923, + 5924, + 5925, + 5926, + 5927, + 5928, + 5929, + 5930, + 5931, + 5932, + 5933, + 5934, + 5935, + 5936, + 5937, + 5938, + 5939, + 5940, + 5941, + 5942, + 5943, + 5944, + 5945, + 5946, + 5947, + 5948, + 5949, + 5950, + 5951, + 5952, + 5953, + 5954, + 5955, + 5956, + 5957, + 5958, + 5959, + 5960, + 5961, + 5962, + 5963, + 5964, + 5965, + 5966, + 5967, + 5968, + 5969, + 5970, + 5971, + 5972, + 5973, + 5974, + 5975, + 5976, + 5977, + 5978, + 5979, + 5980, + 5981, + 5982, + 5983, + 5984, + 5985, + 5986, + 5987, + 5988, + 5989, + 5990, + 5991, + 5992, + 5993, + 5994, + 5995, + 5996, + 5997, + 5998, + 5999, + 6000, + 6001, + 6002, + 6003, + 6004, + 6005, + 6006, + 6007, + 6008, + 6009, + 6010, + 6011, + 6012, + 6013, + 6014, + 6015, + 6016, + 6017, + 6018, + 6019, + 6020, + 6021, + 6022, + 6023, + 6024, + 6025, + 6026, + 6027, + 6028, + 6029, + 6030, + 6031, + 6032, + 6033, + 6034, + 6035, + 6036, + 6037, + 6038, + 6039, + 6040, + 6041, + 6042, + 6043, + 6044, + 6045, + 6046, + 6047, + 6048, + 6049, + 6050, + 6051, + 6052, + 6053, + 6054, + 6055, + 6058, + 6059, + 6060, + 6061, + 6062, + 6063, + 6064, + 6065, + 6068, + 6069, + 6070, + 6071, + 6072, + 6073, + 6074, + 6075, + 6076, + 6077, + 6078, + 6079, + 6080, + 6081, + 6082, + 6083, + 6084, + 6085, + 6086, + 6087, + 6088, + 6089, + 6090, + 6091, + 6092, + 6093, + 6094, + 6095, + 6096, + 6097, + 6098, + 6099, + 6100, + 6101, + 6102, + 6103, + 6104, + 6105, + 6106, + 6107, + 6108, + 6109, + 6110, + 6111, + 6112, + 6113, + 6114, + 6115, + 6116, + 6117, + 6118, + 6119, + 6120, + 6121, + 6122, + 6123, + 6124, + 6125, + 6126, + 6127, + 6128, + 6129, + 6130, + 6131, + 6132, + 6133, + 6134, + 6135, + 6136, + 6137, + 6138, + 6139, + 6140, + 6141, + 6142, + 6143, + 6144, + 6145, + 6146, + 6147, + 6148, + 6149, + 6150, + 6151, + 6152, + 6153, + 6154, + 6155, + 6156, + 6157 + ], + "leftForeArm": [ + 1546, + 1547, + 1548, + 1549, + 1550, + 1551, + 1552, + 1553, + 1554, + 
1555, + 1556, + 1557, + 1558, + 1559, + 1560, + 1561, + 1562, + 1563, + 1564, + 1565, + 1566, + 1567, + 1568, + 1569, + 1570, + 1571, + 1572, + 1573, + 1574, + 1575, + 1576, + 1577, + 1578, + 1579, + 1580, + 1581, + 1582, + 1583, + 1584, + 1585, + 1586, + 1587, + 1588, + 1589, + 1590, + 1591, + 1592, + 1593, + 1594, + 1595, + 1596, + 1597, + 1598, + 1599, + 1600, + 1601, + 1602, + 1603, + 1604, + 1605, + 1606, + 1607, + 1608, + 1609, + 1610, + 1611, + 1612, + 1613, + 1614, + 1615, + 1616, + 1617, + 1618, + 1620, + 1621, + 1623, + 1624, + 1625, + 1626, + 1627, + 1628, + 1629, + 1630, + 1643, + 1644, + 1646, + 1647, + 1650, + 1651, + 1654, + 1655, + 1657, + 1658, + 1659, + 1660, + 1661, + 1662, + 1663, + 1664, + 1665, + 1666, + 1685, + 1686, + 1687, + 1688, + 1689, + 1690, + 1691, + 1692, + 1693, + 1694, + 1695, + 1699, + 1700, + 1701, + 1702, + 1721, + 1722, + 1723, + 1724, + 1725, + 1726, + 1727, + 1728, + 1729, + 1730, + 1732, + 1736, + 1738, + 1741, + 1742, + 1743, + 1744, + 1750, + 1752, + 1900, + 1909, + 1910, + 1911, + 1912, + 1913, + 1914, + 1915, + 1916, + 1917, + 1918, + 1919, + 1920, + 1921, + 1922, + 1923, + 1924, + 1925, + 1926, + 1927, + 1928, + 1929, + 1930, + 1931, + 1932, + 1933, + 1934, + 1935, + 1936, + 1937, + 1938, + 1939, + 1940, + 1941, + 1942, + 1943, + 1944, + 1945, + 1946, + 1947, + 1948, + 1949, + 1950, + 1951, + 1952, + 1953, + 1954, + 1955, + 1956, + 1957, + 1958, + 1959, + 1960, + 1961, + 1962, + 1963, + 1964, + 1965, + 1966, + 1967, + 1968, + 1969, + 1970, + 1971, + 1972, + 1973, + 1974, + 1975, + 1976, + 1977, + 1978, + 1979, + 1980, + 2019, + 2059, + 2060, + 2073, + 2089, + 2098, + 2099, + 2100, + 2101, + 2102, + 2103, + 2104, + 2105, + 2106, + 2107, + 2108, + 2109, + 2110, + 2111, + 2112, + 2147, + 2148, + 2206, + 2207, + 2208, + 2209, + 2228, + 2230, + 2234, + 2235, + 2241, + 2242, + 2243, + 2244, + 2279, + 2286, + 2873, + 2874 + ], + "rightForeArm": [ + 5015, + 5016, + 5017, + 5018, + 5019, + 5020, + 5021, + 5022, + 5023, + 5024, + 5025, + 5026, + 5027, + 5028, + 5029, + 5030, + 5031, + 5032, + 5033, + 5034, + 5035, + 5036, + 5037, + 5038, + 5039, + 5040, + 5041, + 5042, + 5043, + 5044, + 5045, + 5046, + 5047, + 5048, + 5049, + 5050, + 5051, + 5052, + 5053, + 5054, + 5055, + 5056, + 5057, + 5058, + 5059, + 5060, + 5061, + 5062, + 5063, + 5064, + 5065, + 5066, + 5067, + 5068, + 5069, + 5070, + 5071, + 5072, + 5073, + 5074, + 5075, + 5076, + 5077, + 5078, + 5079, + 5080, + 5081, + 5082, + 5083, + 5084, + 5085, + 5086, + 5087, + 5090, + 5091, + 5092, + 5093, + 5094, + 5095, + 5096, + 5097, + 5098, + 5099, + 5112, + 5113, + 5116, + 5117, + 5120, + 5121, + 5124, + 5125, + 5126, + 5127, + 5128, + 5129, + 5130, + 5131, + 5132, + 5133, + 5134, + 5135, + 5154, + 5155, + 5156, + 5157, + 5158, + 5159, + 5160, + 5161, + 5162, + 5163, + 5164, + 5168, + 5169, + 5170, + 5171, + 5190, + 5191, + 5192, + 5193, + 5194, + 5195, + 5196, + 5197, + 5198, + 5199, + 5202, + 5205, + 5207, + 5210, + 5211, + 5212, + 5213, + 5219, + 5221, + 5361, + 5370, + 5371, + 5372, + 5373, + 5374, + 5375, + 5376, + 5377, + 5378, + 5379, + 5380, + 5381, + 5382, + 5383, + 5384, + 5385, + 5386, + 5387, + 5388, + 5389, + 5390, + 5391, + 5392, + 5393, + 5394, + 5395, + 5396, + 5397, + 5398, + 5399, + 5400, + 5401, + 5402, + 5403, + 5404, + 5405, + 5406, + 5407, + 5408, + 5409, + 5410, + 5411, + 5412, + 5413, + 5414, + 5415, + 5416, + 5417, + 5418, + 5419, + 5420, + 5421, + 5422, + 5423, + 5424, + 5425, + 5426, + 5427, + 5428, + 5429, + 5430, + 5431, + 5432, + 5433, + 5434, + 5435, + 5436, + 5437, + 
5438, + 5439, + 5440, + 5441, + 5480, + 5520, + 5521, + 5534, + 5550, + 5559, + 5560, + 5561, + 5562, + 5563, + 5564, + 5565, + 5566, + 5567, + 5568, + 5569, + 5570, + 5571, + 5572, + 5573, + 5608, + 5609, + 5667, + 5668, + 5669, + 5670, + 5689, + 5691, + 5695, + 5696, + 5702, + 5703, + 5704, + 5705, + 5740, + 5747, + 6334, + 6335 + ], + "neck": [ + 148, + 150, + 151, + 152, + 153, + 172, + 174, + 175, + 201, + 202, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 215, + 216, + 217, + 218, + 219, + 222, + 223, + 224, + 225, + 256, + 257, + 284, + 285, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 306, + 307, + 308, + 309, + 333, + 334, + 423, + 424, + 425, + 426, + 440, + 441, + 451, + 452, + 453, + 460, + 461, + 571, + 572, + 824, + 825, + 826, + 827, + 828, + 829, + 1279, + 1280, + 1312, + 1313, + 1319, + 1320, + 1331, + 3049, + 3050, + 3057, + 3058, + 3059, + 3068, + 3164, + 3661, + 3662, + 3663, + 3664, + 3665, + 3685, + 3686, + 3687, + 3714, + 3715, + 3716, + 3717, + 3718, + 3719, + 3720, + 3721, + 3722, + 3723, + 3724, + 3725, + 3726, + 3727, + 3728, + 3729, + 3730, + 3731, + 3734, + 3735, + 3736, + 3737, + 3768, + 3769, + 3796, + 3797, + 3807, + 3808, + 3809, + 3810, + 3811, + 3812, + 3813, + 3814, + 3815, + 3816, + 3817, + 3818, + 3819, + 3839, + 3840, + 3918, + 3919, + 3920, + 3921, + 3934, + 3935, + 3942, + 3943, + 3944, + 3950, + 4060, + 4061, + 4312, + 4313, + 4314, + 4315, + 4761, + 4762, + 4792, + 4793, + 4799, + 4800, + 4807 + ], + "rightToeBase": [ + 6611, + 6612, + 6613, + 6614, + 6615, + 6616, + 6617, + 6618, + 6619, + 6620, + 6621, + 6622, + 6623, + 6624, + 6625, + 6626, + 6627, + 6628, + 6629, + 6630, + 6631, + 6632, + 6633, + 6634, + 6635, + 6636, + 6637, + 6638, + 6639, + 6640, + 6641, + 6642, + 6643, + 6644, + 6645, + 6646, + 6647, + 6648, + 6649, + 6650, + 6651, + 6652, + 6653, + 6654, + 6655, + 6656, + 6657, + 6658, + 6659, + 6660, + 6661, + 6662, + 6663, + 6664, + 6665, + 6666, + 6667, + 6668, + 6669, + 6670, + 6671, + 6672, + 6673, + 6674, + 6675, + 6676, + 6677, + 6678, + 6679, + 6680, + 6681, + 6682, + 6683, + 6684, + 6685, + 6686, + 6687, + 6688, + 6689, + 6690, + 6691, + 6692, + 6693, + 6694, + 6695, + 6696, + 6697, + 6698, + 6699, + 6700, + 6701, + 6702, + 6703, + 6704, + 6705, + 6706, + 6707, + 6708, + 6709, + 6710, + 6711, + 6712, + 6713, + 6714, + 6715, + 6716, + 6717, + 6718, + 6736, + 6739, + 6741, + 6743, + 6745, + 6747, + 6749, + 6750, + 6752, + 6754, + 6757, + 6758, + 6760, + 6762 + ], + "spine": [ + 616, + 617, + 630, + 631, + 632, + 633, + 654, + 655, + 656, + 657, + 662, + 663, + 664, + 665, + 720, + 721, + 765, + 766, + 767, + 768, + 796, + 797, + 798, + 799, + 889, + 890, + 916, + 917, + 918, + 919, + 921, + 922, + 923, + 924, + 925, + 926, + 1188, + 1189, + 1211, + 1212, + 1248, + 1249, + 1250, + 1251, + 1264, + 1265, + 1266, + 1267, + 1323, + 1324, + 1325, + 1326, + 1327, + 1328, + 1332, + 1333, + 1334, + 1335, + 1336, + 1344, + 1345, + 1481, + 1482, + 1483, + 1484, + 1485, + 1486, + 1487, + 1488, + 1489, + 1490, + 1491, + 1492, + 1493, + 1494, + 1495, + 1496, + 1767, + 2823, + 2824, + 2825, + 2826, + 2827, + 2828, + 2829, + 2830, + 2831, + 2832, + 2833, + 2834, + 2835, + 2836, + 2837, + 2838, + 2839, + 2840, + 2841, + 2842, + 2843, + 2844, + 2845, + 2847, + 2848, + 2851, + 3016, + 3017, + 3018, + 3019, + 3020, + 3023, + 3024, + 3124, + 3173, + 3476, + 3477, + 3478, + 3480, + 3500, + 3501, + 3502, + 3504, + 3509, + 3511, + 4103, + 4104, + 4118, + 4119, + 4120, + 4121, + 4142, + 
4143, + 4144, + 4145, + 4150, + 4151, + 4152, + 4153, + 4208, + 4209, + 4253, + 4254, + 4255, + 4256, + 4284, + 4285, + 4286, + 4287, + 4375, + 4376, + 4402, + 4403, + 4405, + 4406, + 4407, + 4408, + 4409, + 4410, + 4411, + 4412, + 4674, + 4675, + 4694, + 4695, + 4731, + 4732, + 4733, + 4734, + 4747, + 4748, + 4749, + 4750, + 4803, + 4804, + 4805, + 4806, + 4808, + 4809, + 4810, + 4811, + 4812, + 4820, + 4821, + 4953, + 4954, + 4955, + 4956, + 4957, + 4958, + 4959, + 4960, + 4961, + 4962, + 4963, + 4964, + 4965, + 4966, + 4967, + 4968, + 5234, + 6284, + 6285, + 6286, + 6287, + 6288, + 6289, + 6290, + 6291, + 6292, + 6293, + 6294, + 6295, + 6296, + 6297, + 6298, + 6299, + 6300, + 6301, + 6302, + 6303, + 6304, + 6305, + 6306, + 6308, + 6309, + 6312, + 6472, + 6473, + 6474, + 6545, + 6874, + 6875, + 6876, + 6878 + ], + "leftUpLeg": [ + 833, + 834, + 838, + 839, + 847, + 848, + 849, + 850, + 851, + 852, + 853, + 854, + 870, + 871, + 872, + 873, + 874, + 875, + 876, + 877, + 878, + 879, + 880, + 881, + 897, + 898, + 899, + 900, + 901, + 902, + 903, + 904, + 905, + 906, + 907, + 908, + 909, + 910, + 911, + 912, + 913, + 914, + 915, + 933, + 934, + 935, + 936, + 944, + 945, + 946, + 947, + 948, + 949, + 950, + 951, + 952, + 953, + 954, + 955, + 956, + 957, + 958, + 959, + 960, + 961, + 962, + 963, + 964, + 965, + 966, + 967, + 968, + 969, + 970, + 971, + 972, + 973, + 974, + 975, + 976, + 977, + 978, + 979, + 980, + 981, + 982, + 983, + 984, + 985, + 986, + 987, + 988, + 989, + 990, + 991, + 992, + 993, + 994, + 995, + 996, + 997, + 998, + 999, + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, + 1011, + 1012, + 1013, + 1014, + 1015, + 1016, + 1017, + 1018, + 1019, + 1020, + 1021, + 1022, + 1023, + 1024, + 1025, + 1026, + 1027, + 1028, + 1029, + 1030, + 1031, + 1032, + 1033, + 1034, + 1035, + 1036, + 1037, + 1038, + 1039, + 1040, + 1041, + 1042, + 1043, + 1044, + 1045, + 1046, + 1137, + 1138, + 1139, + 1140, + 1141, + 1142, + 1143, + 1144, + 1145, + 1146, + 1147, + 1148, + 1159, + 1160, + 1161, + 1162, + 1163, + 1164, + 1165, + 1166, + 1167, + 1168, + 1169, + 1170, + 1171, + 1172, + 1173, + 1174, + 1184, + 1185, + 1186, + 1187, + 1221, + 1222, + 1223, + 1224, + 1225, + 1226, + 1227, + 1228, + 1229, + 1230, + 1262, + 1263, + 1274, + 1275, + 1276, + 1277, + 1321, + 1322, + 1354, + 1359, + 1360, + 1361, + 1362, + 1365, + 1366, + 1367, + 1368, + 1451, + 1452, + 1453, + 1455, + 1456, + 1457, + 1458, + 1459, + 1460, + 1461, + 1462, + 1463, + 1475, + 1477, + 1478, + 1479, + 1480, + 1498, + 1499, + 1500, + 1501, + 1511, + 1512, + 1513, + 1514, + 1516, + 1517, + 1518, + 1519, + 1520, + 1521, + 1522, + 1533, + 1534, + 3125, + 3126, + 3127, + 3128, + 3131, + 3132, + 3133, + 3134, + 3135, + 3475, + 3479 + ], + "leftHand": [ + 1981, + 1982, + 1983, + 1984, + 1985, + 1986, + 1987, + 1988, + 1989, + 1990, + 1991, + 1992, + 1993, + 1994, + 1995, + 1996, + 1997, + 1998, + 1999, + 2000, + 2001, + 2002, + 2003, + 2004, + 2005, + 2006, + 2007, + 2008, + 2009, + 2010, + 2011, + 2012, + 2013, + 2014, + 2015, + 2016, + 2017, + 2018, + 2019, + 2020, + 2021, + 2022, + 2023, + 2024, + 2025, + 2026, + 2031, + 2032, + 2033, + 2034, + 2035, + 2036, + 2041, + 2042, + 2043, + 2044, + 2045, + 2046, + 2047, + 2048, + 2049, + 2050, + 2051, + 2052, + 2053, + 2054, + 2055, + 2056, + 2057, + 2058, + 2059, + 2060, + 2061, + 2062, + 2063, + 2064, + 2065, + 2066, + 2069, + 2070, + 2071, + 2072, + 2073, + 2074, + 2075, + 2076, + 2077, + 2078, + 2079, + 2080, + 2081, + 2082, + 2083, + 2084, + 2085, + 
2086, + 2087, + 2088, + 2089, + 2090, + 2091, + 2092, + 2093, + 2094, + 2095, + 2096, + 2097, + 2098, + 2099, + 2100, + 2101, + 2107, + 2111, + 2113, + 2114, + 2115, + 2116, + 2117, + 2118, + 2119, + 2120, + 2121, + 2122, + 2127, + 2130, + 2131, + 2132, + 2133, + 2134, + 2135, + 2136, + 2137, + 2138, + 2139, + 2140, + 2141, + 2142, + 2143, + 2144, + 2149, + 2150, + 2151, + 2152, + 2155, + 2160, + 2163, + 2164, + 2170, + 2171, + 2172, + 2173, + 2174, + 2175, + 2176, + 2177, + 2178, + 2179, + 2180, + 2182, + 2183, + 2184, + 2185, + 2188, + 2189, + 2191, + 2192, + 2193, + 2194, + 2195, + 2196, + 2197, + 2198, + 2199, + 2200, + 2201, + 2202, + 2203, + 2207, + 2209, + 2210, + 2211, + 2212, + 2213, + 2214, + 2221, + 2222, + 2223, + 2224, + 2225, + 2226, + 2227, + 2228, + 2229, + 2231, + 2234, + 2236, + 2237, + 2238, + 2239, + 2240, + 2246, + 2247, + 2248, + 2249, + 2250, + 2251, + 2252, + 2253, + 2254, + 2255, + 2256, + 2257, + 2258, + 2259, + 2260, + 2262, + 2263, + 2264, + 2265, + 2266, + 2267, + 2268, + 2269, + 2270, + 2271, + 2274, + 2275, + 2276, + 2277, + 2278, + 2279, + 2284, + 2285, + 2287, + 2288, + 2289, + 2290, + 2293, + 2595, + 2598, + 2605, + 2608, + 2697, + 2698, + 2699, + 2700, + 2701, + 2702, + 2703, + 2704, + 2705, + 2706, + 2707, + 2708, + 2709, + 2710, + 2711, + 2712, + 2713, + 2714, + 2715, + 2716, + 2717, + 2718, + 2719, + 2720, + 2721, + 2722, + 2723, + 2724, + 2725, + 2726, + 2727, + 2728, + 2729, + 2730, + 2731, + 2732, + 2733, + 2734, + 2735, + 2736, + 2737, + 2738, + 2739, + 2740, + 2741, + 2742, + 2743, + 2744, + 2745, + 2746, + 2747, + 2748, + 2749, + 2750, + 2751, + 2752, + 2753, + 2754, + 2755, + 2756, + 2757, + 2758, + 2759, + 2760, + 2761, + 2762, + 2763, + 2764, + 2765, + 2766, + 2767, + 2768, + 2769, + 2770, + 2771, + 2772, + 2773, + 2774, + 2775, + 2776, + 2777, + 2778 + ], + "hips": [ + 631, + 632, + 654, + 657, + 662, + 665, + 676, + 677, + 678, + 679, + 705, + 720, + 796, + 799, + 800, + 801, + 802, + 807, + 808, + 809, + 810, + 815, + 816, + 822, + 823, + 830, + 831, + 832, + 833, + 834, + 835, + 836, + 837, + 838, + 839, + 840, + 841, + 842, + 843, + 844, + 845, + 846, + 855, + 856, + 857, + 858, + 859, + 860, + 861, + 862, + 863, + 864, + 865, + 866, + 867, + 868, + 869, + 871, + 878, + 881, + 882, + 883, + 884, + 885, + 886, + 887, + 888, + 889, + 890, + 912, + 915, + 916, + 917, + 918, + 919, + 920, + 932, + 937, + 938, + 939, + 1163, + 1166, + 1203, + 1204, + 1205, + 1206, + 1207, + 1208, + 1209, + 1210, + 1246, + 1247, + 1262, + 1263, + 1276, + 1277, + 1278, + 1321, + 1336, + 1337, + 1338, + 1339, + 1353, + 1354, + 1361, + 1362, + 1363, + 1364, + 1446, + 1447, + 1448, + 1449, + 1450, + 1454, + 1476, + 1497, + 1511, + 1513, + 1514, + 1515, + 1533, + 1534, + 1539, + 1540, + 1768, + 1769, + 1779, + 1780, + 1781, + 1782, + 1783, + 1784, + 1785, + 1786, + 1787, + 1788, + 1789, + 1790, + 1791, + 1792, + 1793, + 1794, + 1795, + 1796, + 1797, + 1798, + 1799, + 1800, + 1801, + 1802, + 1803, + 1804, + 1805, + 1806, + 1807, + 2909, + 2910, + 2911, + 2912, + 2913, + 2914, + 2915, + 2916, + 2917, + 2918, + 2919, + 2920, + 2921, + 2922, + 2923, + 2924, + 2925, + 2926, + 2927, + 2928, + 2929, + 2930, + 3018, + 3019, + 3021, + 3022, + 3080, + 3081, + 3082, + 3083, + 3084, + 3085, + 3086, + 3087, + 3088, + 3089, + 3090, + 3091, + 3092, + 3093, + 3094, + 3095, + 3096, + 3097, + 3098, + 3099, + 3100, + 3101, + 3102, + 3103, + 3104, + 3105, + 3106, + 3107, + 3108, + 3109, + 3110, + 3111, + 3112, + 3113, + 3114, + 3115, + 3116, + 3117, + 3118, + 3119, + 3120, + 3121, + 
3122, + 3123, + 3124, + 3128, + 3129, + 3130, + 3136, + 3137, + 3138, + 3139, + 3140, + 3141, + 3142, + 3143, + 3144, + 3145, + 3146, + 3147, + 3148, + 3149, + 3150, + 3151, + 3152, + 3153, + 3154, + 3155, + 3156, + 3157, + 3158, + 3159, + 3160, + 3170, + 3172, + 3481, + 3484, + 3500, + 3502, + 3503, + 3507, + 3510, + 4120, + 4121, + 4142, + 4143, + 4150, + 4151, + 4164, + 4165, + 4166, + 4167, + 4193, + 4208, + 4284, + 4285, + 4288, + 4289, + 4290, + 4295, + 4296, + 4297, + 4298, + 4303, + 4304, + 4310, + 4311, + 4316, + 4317, + 4318, + 4319, + 4320, + 4321, + 4322, + 4323, + 4324, + 4325, + 4326, + 4327, + 4328, + 4329, + 4330, + 4331, + 4332, + 4341, + 4342, + 4343, + 4344, + 4345, + 4346, + 4347, + 4348, + 4349, + 4350, + 4351, + 4352, + 4353, + 4354, + 4355, + 4356, + 4364, + 4365, + 4368, + 4369, + 4370, + 4371, + 4372, + 4373, + 4374, + 4375, + 4376, + 4398, + 4399, + 4402, + 4403, + 4404, + 4405, + 4406, + 4418, + 4423, + 4424, + 4425, + 4649, + 4650, + 4689, + 4690, + 4691, + 4692, + 4693, + 4729, + 4730, + 4745, + 4746, + 4759, + 4760, + 4801, + 4812, + 4813, + 4814, + 4815, + 4829, + 4836, + 4837, + 4919, + 4920, + 4921, + 4922, + 4923, + 4927, + 4969, + 4983, + 4984, + 4986, + 5004, + 5005, + 5244, + 5245, + 5246, + 5247, + 5248, + 5249, + 5250, + 5251, + 5252, + 5253, + 5254, + 5255, + 5256, + 5257, + 5258, + 5259, + 5260, + 5261, + 5262, + 5263, + 5264, + 5265, + 5266, + 5267, + 5268, + 6368, + 6369, + 6370, + 6371, + 6372, + 6373, + 6374, + 6375, + 6376, + 6377, + 6378, + 6379, + 6380, + 6381, + 6382, + 6383, + 6384, + 6385, + 6386, + 6387, + 6388, + 6389, + 6473, + 6474, + 6504, + 6505, + 6506, + 6507, + 6508, + 6509, + 6510, + 6511, + 6512, + 6513, + 6514, + 6515, + 6516, + 6517, + 6518, + 6519, + 6520, + 6521, + 6522, + 6523, + 6524, + 6525, + 6526, + 6527, + 6528, + 6529, + 6530, + 6531, + 6532, + 6533, + 6534, + 6535, + 6536, + 6537, + 6538, + 6539, + 6540, + 6541, + 6542, + 6543, + 6544, + 6545, + 6549, + 6550, + 6551, + 6557, + 6558, + 6559, + 6560, + 6561, + 6562, + 6563, + 6564, + 6565, + 6566, + 6567, + 6568, + 6569, + 6570, + 6571, + 6572, + 6573 + ] +} \ No newline at end of file diff --git a/hmr4d/utils/body_model/smplx2smpl_sparse.pt b/hmr4d/utils/body_model/smplx2smpl_sparse.pt new file mode 100644 index 0000000..df81861 Binary files /dev/null and b/hmr4d/utils/body_model/smplx2smpl_sparse.pt differ diff --git a/hmr4d/utils/body_model/smplx_lite.py b/hmr4d/utils/body_model/smplx_lite.py new file mode 100644 index 0000000..013a268 --- /dev/null +++ b/hmr4d/utils/body_model/smplx_lite.py @@ -0,0 +1,302 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from pytorch3d.transforms import axis_angle_to_matrix, rotation_6d_to_matrix +from smplx.utils import Struct, to_np, to_tensor +from einops import einsum, rearrange +from time import time + +from hmr4d import PROJ_ROOT + + +class SmplxLite(nn.Module): + def __init__( + self, + model_path=PROJ_ROOT / "inputs/checkpoints/body_models/smplx", + gender="neutral", + num_betas=10, + ): + super().__init__() + + # Load the model + model_path = Path(model_path) + if model_path.is_dir(): + smplx_path = Path(model_path) / f"SMPLX_{gender.upper()}.npz" + else: + smplx_path = model_path + assert smplx_path.exists() + model_data = np.load(smplx_path, allow_pickle=True) + + data_struct = Struct(**model_data) + self.faces = data_struct.f # (F, 3) + + self.register_smpl_buffers(data_struct, num_betas) + # self.register_smplh_buffers(data_struct, num_pca_comps, 
flat_hand_mean) + # self.register_smplx_buffers(data_struct) + self.register_fast_skeleton_computing_buffers() + + # default_pose (99,) for torch.cat([global_orient, body_pose, default_pose]) + other_default_pose = torch.cat( + [ + torch.zeros(9), + to_tensor(data_struct.hands_meanl).float(), + to_tensor(data_struct.hands_meanr).float(), + ] + ) + self.register_buffer("other_default_pose", other_default_pose, False) + + def register_smpl_buffers(self, data_struct, num_betas): + # shapedirs, (V, 3, N_betas), V=10475 for SMPLX + shapedirs = to_tensor(to_np(data_struct.shapedirs[:, :, :num_betas])).float() + self.register_buffer("shapedirs", shapedirs, False) + + # v_template, (V, 3) + v_template = to_tensor(to_np(data_struct.v_template)).float() + self.register_buffer("v_template", v_template, False) + + # J_regressor, (J, V), J=55 for SMPLX + J_regressor = to_tensor(to_np(data_struct.J_regressor)).float() + self.register_buffer("J_regressor", J_regressor, False) + + # posedirs, (54*9, V, 3), note that the first global_orient is not included + posedirs = to_tensor(to_np(data_struct.posedirs)).float() # (V, 3, 54*9) + posedirs = rearrange(posedirs, "v c n -> n v c") + self.register_buffer("posedirs", posedirs, False) + + # lbs_weights, (V, J), J=55 + lbs_weights = to_tensor(to_np(data_struct.weights)).float() + self.register_buffer("lbs_weights", lbs_weights, False) + + # parents, (J), long + parents = to_tensor(to_np(data_struct.kintree_table[0])).long() + parents[0] = -1 + self.register_buffer("parents", parents, False) + + def register_smplh_buffers(self, data_struct, num_pca_comps, flat_hand_mean): + # hand_pca, (N_pca, 45) + left_hand_components = to_tensor(data_struct.hands_componentsl[:num_pca_comps]).float() + right_hand_components = to_tensor(data_struct.hands_componentsr[:num_pca_comps]).float() + self.register_buffer("left_hand_components", left_hand_components, False) + self.register_buffer("right_hand_components", right_hand_components, False) + + # hand_mean, (45,) + left_hand_mean = to_tensor(data_struct.hands_meanl).float() + right_hand_mean = to_tensor(data_struct.hands_meanr).float() + if not flat_hand_mean: + left_hand_mean = torch.zeros_like(left_hand_mean) + right_hand_mean = torch.zeros_like(right_hand_mean) + self.register_buffer("left_hand_mean", left_hand_mean, False) + self.register_buffer("right_hand_mean", right_hand_mean, False) + + def register_smplx_buffers(self, data_struct): + # expr_dirs, (V, 3, N_expr) + expr_dirs = to_tensor(to_np(data_struct.shapedirs[:, :, 300:310])).float() + self.register_buffer("expr_dirs", expr_dirs, False) + + def register_fast_skeleton_computing_buffers(self): + # For fast computing of skeleton under beta + J_template = self.J_regressor @ self.v_template # (J, 3) + J_shapedirs = torch.einsum("jv, vcd -> jcd", self.J_regressor, self.shapedirs) # (J, 3, 10) + self.register_buffer("J_template", J_template, False) + self.register_buffer("J_shapedirs", J_shapedirs, False) + + def get_skeleton(self, betas): + return self.J_template + einsum(betas, self.J_shapedirs, "... k, j c k -> ... j c") + + def forward( + self, + body_pose, + betas, + global_orient, + transl=None, + rotation_type="aa", + ): + """ + Args: + body_pose: (B, L, 63) + betas: (B, L, 10) + global_orient: (B, L, 3) + transl: (B, L, 3) + Returns: + vertices: (B, L, V, 3) + """ + # 1. 
Convert [global_orient, body_pose, other_default_pose] to rot_mats + other_default_pose = self.other_default_pose # (99,) + if rotation_type == "aa": + other_default_pose = other_default_pose.expand(*body_pose.shape[:-1], -1) + full_pose = torch.cat([global_orient, body_pose, other_default_pose], dim=-1) + rot_mats = axis_angle_to_matrix(full_pose.reshape(*full_pose.shape[:-1], 55, 3)) + del full_pose, other_default_pose + else: + assert rotation_type == "r6d" # useful when doing smplify + other_default_pose = axis_angle_to_matrix(other_default_pose.view(33, 3)) + part_full_pose = torch.cat([global_orient, body_pose], dim=-1) + rot_mats = rotation_6d_to_matrix(part_full_pose.view(*part_full_pose.shape[:-1], 22, 6)) + other_default_pose = other_default_pose.expand(*rot_mats.shape[:-3], -1, -1, -1) + rot_mats = torch.cat([rot_mats, other_default_pose], dim=-3) + del part_full_pose, other_default_pose + + # 2. Forward Kinematics + J = self.get_skeleton(betas) # (*, 55, 3) + A = batch_rigid_transform_v2(rot_mats, J, self.parents)[1] + + # 3. Canonical v_posed = v_template + shaped_offsets + pose_offsets + pose_feature = rot_mats[..., 1:, :, :] - rot_mats.new([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + pose_feature = pose_feature.view(*pose_feature.shape[:-3], -1) # (*, 55*3*3) + v_posed = ( + self.v_template + + einsum(betas, self.shapedirs, "... k, v c k -> ... v c") + + einsum(pose_feature, self.posedirs, "... k, k v c -> ... v c") + ) + del pose_feature, rot_mats + + # 4. Skinning + T = einsum(self.lbs_weights, A, "v j, ... j c d -> ... v c d") + verts = einsum(T[..., :3, :3], v_posed, "... v c d, ... v d -> ... v c") + T[..., :3, 3] + + # 5. Translation + if transl is not None: + verts = verts + transl[..., None, :] + return verts + + +class SmplxLiteCoco17(SmplxLite): + """Output COCO17 joints (Faster, but cannot output vertices)""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Compute mapping + smplx2smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx2smpl_sparse.pt") + COCO17_regressor = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smpl_coco17_J_regressor.pt") + smplx2coco17 = torch.matmul(COCO17_regressor, smplx2smpl.to_dense()) + + jids, smplx_vids = torch.where(smplx2coco17 != 0) + smplx2coco17_interestd = torch.zeros([len(smplx_vids), 17]) + for idx, (jid, smplx_vid) in enumerate(zip(jids, smplx_vids)): + smplx2coco17_interestd[idx, jid] = smplx2coco17[jid, smplx_vid] + self.register_buffer("smplx2coco17_interestd", smplx2coco17_interestd, False) # (132, 17) + + # Update to vertices of interest + self.v_template = self.v_template[smplx_vids].clone() # (V', 3) + self.shapedirs = self.shapedirs[smplx_vids].clone() # (V', 3, K) + self.posedirs = self.posedirs[:, smplx_vids].clone() # (K, V', 3) + self.lbs_weights = self.lbs_weights[smplx_vids].clone() # (V', J) + + def forward(self, body_pose, betas, global_orient, transl): + """Returns: joints (*, 17, 3). (B, L) or (B,) are both supported.""" + # Use super class's forward to get verts + verts = super().forward(body_pose, betas, global_orient, transl) # (*, 132, 3) + joints = einsum(self.smplx2coco17_interestd, verts, "v j, ... v c -> ... 
j c") + return joints + + +class SmplxLiteV437Coco17(SmplxLite): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Compute mapping (COCO17) + smplx2smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx2smpl_sparse.pt") + COCO17_regressor = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smpl_coco17_J_regressor.pt") + smplx2coco17 = torch.matmul(COCO17_regressor, smplx2smpl.to_dense()) + + jids, smplx_vids = torch.where(smplx2coco17 != 0) + smplx2coco17_interestd = torch.zeros([len(smplx_vids), 17]) + for idx, (jid, smplx_vid) in enumerate(zip(jids, smplx_vids)): + smplx2coco17_interestd[idx, jid] = smplx2coco17[jid, smplx_vid] + self.register_buffer("smplx2coco17_interestd", smplx2coco17_interestd, False) # (132, 17) + assert len(smplx_vids) == 132 + + # Verts437 + smplx_vids2 = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx_verts437.pt") + smplx_vids = torch.cat([smplx_vids, smplx_vids2]) + + # Update to vertices of interest + self.v_template = self.v_template[smplx_vids].clone() # (V', 3) + self.shapedirs = self.shapedirs[smplx_vids].clone() # (V', 3, K) + self.posedirs = self.posedirs[:, smplx_vids].clone() # (K, V', 3) + self.lbs_weights = self.lbs_weights[smplx_vids].clone() # (V', J) + + def forward(self, body_pose, betas, global_orient, transl): + """ + Returns: + verts_437: (*, 437, 3) + joints (*, 17, 3). (B, L) or (B,) are both supported. + """ + # Use super class's forward to get verts + verts = super().forward(body_pose, betas, global_orient, transl) # (*, 132+437, 3) + + verts_437 = verts[..., 132:, :].clone() + joints = einsum(self.smplx2coco17_interestd, verts[..., :132, :], "v j, ... v c -> ... j c") + return verts_437, joints + + +class SmplxLiteSmplN24(SmplxLite): + """Output SMPL(not smplx)-Neutral 24 joints (Faster, but cannot output vertices)""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Compute mapping + smplx2smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx2smpl_sparse.pt") + smpl2joints = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smpl_neutral_J_regressor.pt") + smplx2joints = torch.matmul(smpl2joints, smplx2smpl.to_dense()) + + jids, smplx_vids = torch.where(smplx2joints != 0) + smplx2joints_interested = torch.zeros([len(smplx_vids), smplx2joints.size(0)]) + for idx, (jid, smplx_vid) in enumerate(zip(jids, smplx_vids)): + smplx2joints_interested[idx, jid] = smplx2joints[jid, smplx_vid] + self.register_buffer("smplx2joints_interested", smplx2joints_interested, False) # (V', J) + + # Update to vertices of interest + self.v_template = self.v_template[smplx_vids].clone() # (V', 3) + self.shapedirs = self.shapedirs[smplx_vids].clone() # (V', 3, K) + self.posedirs = self.posedirs[:, smplx_vids].clone() # (K, V', 3) + self.lbs_weights = self.lbs_weights[smplx_vids].clone() # (V', J) + + def forward(self, body_pose, betas, global_orient, transl): + """Returns: joints (*, J, 3). (B, L) or (B,) are both supported.""" + # Use super class's forward to get verts + verts = super().forward(body_pose, betas, global_orient, transl) # (*, V', 3) + joints = einsum(self.smplx2joints_interested, verts, "v j, ... v c -> ... 
j c") + return joints + + +def batch_rigid_transform_v2(rot_mats, joints, parents): + """ + Args: + rot_mats: (*, J, 3, 3) + joints: (*, J, 3) + """ + # check shape, since sometimes beta has shape=1 + rot_mats_shape_prefix = rot_mats.shape[:-3] + if rot_mats_shape_prefix != joints.shape[:-2]: + joints = joints.expand(*rot_mats_shape_prefix, -1, -1) + + rel_joints = joints.clone() + rel_joints[..., 1:, :] -= joints[..., parents[1:], :] + transforms_mat = torch.cat([rot_mats, rel_joints[..., :, None]], dim=-1) # (*, J, 3, 4) + transforms_mat = F.pad(transforms_mat, [0, 0, 0, 1], value=0.0) + transforms_mat[..., 3, 3] = 1.0 # (*, J, 4, 4) + + transform_chain = [transforms_mat[..., 0, :, :]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[..., i, :, :]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=-3) # (*, J, 4, 4) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[..., :3, 3].clone() + rel_transforms = transforms.clone() + rel_transforms[..., :3, 3] -= einsum(transforms[..., :3, :3], joints, "... j c d, ... j d -> ... j c") + return posed_joints, rel_transforms + + +def sync_time(): + torch.cuda.synchronize() + return time() diff --git a/hmr4d/utils/body_model/smplx_verts437.pt b/hmr4d/utils/body_model/smplx_verts437.pt new file mode 100644 index 0000000..dbf87a1 Binary files /dev/null and b/hmr4d/utils/body_model/smplx_verts437.pt differ diff --git a/hmr4d/utils/body_model/utils.py b/hmr4d/utils/body_model/utils.py new file mode 100644 index 0000000..17856e6 --- /dev/null +++ b/hmr4d/utils/body_model/utils.py @@ -0,0 +1,229 @@ +import os +import numpy as np +import torch + +SMPLH_JOINT_NAMES = [ + 'pelvis', + 'left_hip', + 'right_hip', + 'spine1', + 'left_knee', + 'right_knee', + 'spine2', + 'left_ankle', + 'right_ankle', + 'spine3', + 'left_foot', + 'right_foot', + 'neck', + 'left_collar', + 'right_collar', + 'head', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'left_index1', + 'left_index2', + 'left_index3', + 'left_middle1', + 'left_middle2', + 'left_middle3', + 'left_pinky1', + 'left_pinky2', + 'left_pinky3', + 'left_ring1', + 'left_ring2', + 'left_ring3', + 'left_thumb1', + 'left_thumb2', + 'left_thumb3', + 'right_index1', + 'right_index2', + 'right_index3', + 'right_middle1', + 'right_middle2', + 'right_middle3', + 'right_pinky1', + 'right_pinky2', + 'right_pinky3', + 'right_ring1', + 'right_ring2', + 'right_ring3', + 'right_thumb1', + 'right_thumb2', + 'right_thumb3', + 'nose', + 'right_eye', + 'left_eye', + 'right_ear', + 'left_ear', + 'left_big_toe', + 'left_small_toe', + 'left_heel', + 'right_big_toe', + 'right_small_toe', + 'right_heel', + 'left_thumb', + 'left_index', + 'left_middle', + 'left_ring', + 'left_pinky', + 'right_thumb', + 'right_index', + 'right_middle', + 'right_ring', + 'right_pinky', +] + +SMPLH_LEFT_LEG = ['left_hip', 'left_knee', 'left_ankle', 'left_foot'] +SMPLH_RIGHT_LEG = ['right_hip', 'right_knee', 'right_ankle', 'right_foot'] +SMPLH_LEFT_ARM = ['left_collar', 'left_shoulder', 'left_elbow', 'left_wrist'] +SMPLH_RIGHT_ARM = ['right_collar', 'right_shoulder', 'right_elbow', 'right_wrist'] +SMPLH_HEAD = ['neck', 'head'] +SMPLH_SPINE = ['spine1', 'spine2', 'spine3'] + +# name to 21 index (without pelvis, hand, and extra) +_name_2_idx = {j: i for i, 
j in enumerate(SMPLH_JOINT_NAMES[1:22])} +SMPLH_PART_IDX = { + 'left_leg': [_name_2_idx[x] for x in SMPLH_LEFT_LEG], + 'right_leg': [_name_2_idx[x] for x in SMPLH_RIGHT_LEG], + 'left_arm': [_name_2_idx[x] for x in SMPLH_LEFT_ARM], + 'right_arm': [_name_2_idx[x] for x in SMPLH_RIGHT_ARM], + 'two_legs': [_name_2_idx[x] for x in SMPLH_LEFT_LEG + SMPLH_RIGHT_LEG], + 'left_arm_and_leg': [_name_2_idx[x] for x in SMPLH_LEFT_ARM + SMPLH_LEFT_LEG], + 'right_arm_and_leg': [_name_2_idx[x] for x in SMPLH_RIGHT_ARM + SMPLH_RIGHT_LEG], +} + +# name to full index +_name_2_idx_full = {j: i for i, j in enumerate(SMPLH_JOINT_NAMES)} +SMPLH_PART_IDX_FULL = { + 'lower_body': [_name_2_idx_full[x] for x in ['pelvis'] + SMPLH_LEFT_LEG + SMPLH_RIGHT_LEG] +} + +# ===== ⬇️ Fitting optimizer ⬇️ ===== # +SMPL_JOINTS = {'hips': 0, 'leftUpLeg': 1, 'rightUpLeg': 2, 'spine': 3, 'leftLeg': 4, 'rightLeg': 5, + 'spine1': 6, 'leftFoot': 7, 'rightFoot': 8, 'spine2': 9, 'leftToeBase': 10, 'rightToeBase': 11, + 'neck': 12, 'leftShoulder': 13, 'rightShoulder': 14, 'head': 15, 'leftArm': 16, 'rightArm': 17, + 'leftForeArm': 18, 'rightForeArm': 19, 'leftHand': 20, 'rightHand': 21} + +# chosen virtual mocap markers that are "keypoints" to work with +KEYPT_VERTS = [4404, 920, 3076, 3169, 823, 4310, 1010, 1085, 4495, 4569, 6615, 3217, 3313, 6713, + 6785, 3383, 6607, 3207, 1241, 1508, 4797, 4122, 1618, 1569, 5135, 5040, 5691, 5636, + 5404, 2230, 2173, 2108, 134, 3645, 6543, 3123, 3024, 4194, 1306, 182, 3694, 4294, 744] + + +# From https://github.com/vchoutas/smplify-x/blob/master/smplifyx/utils.py +# Please see license for usage restrictions. +def smpl_to_openpose(model_type='smplx', use_hands=True, use_face=True, + use_face_contour=False, openpose_format='coco25'): + ''' Returns the indices of the permutation that maps SMPL to OpenPose + + Parameters + ---------- + model_type: str, optional + The type of SMPL-like model that is used. The default mapping + returned is for the SMPLX model + use_hands: bool, optional + Flag for adding to the returned permutation the mapping for the + hand keypoints. Defaults to True + use_face: bool, optional + Flag for adding to the returned permutation the mapping for the + face keypoints. Defaults to True + use_face_contour: bool, optional + Flag for appending the facial contour keypoints. Defaults to False + openpose_format: bool, optional + The output format of OpenPose. For now only COCO-25 and COCO-19 is + supported. 
Defaults to 'coco25' + + ''' + if openpose_format.lower() == 'coco25': + if model_type == 'smpl': + return np.array([24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, + 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + dtype=np.int32) + elif model_type == 'smplh': + body_mapping = np.array([52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, + 8, 1, 4, 7, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62], dtype=np.int32) + mapping = [body_mapping] + if use_hands: + lhand_mapping = np.array([20, 34, 35, 36, 63, 22, 23, 24, 64, + 25, 26, 27, 65, 31, 32, 33, 66, 28, + 29, 30, 67], dtype=np.int32) + rhand_mapping = np.array([21, 49, 50, 51, 68, 37, 38, 39, 69, + 40, 41, 42, 70, 46, 47, 48, 71, 43, + 44, 45, 72], dtype=np.int32) + mapping += [lhand_mapping, rhand_mapping] + return np.concatenate(mapping) + # SMPLX + elif model_type == 'smplx': + body_mapping = np.array([55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, + 8, 1, 4, 7, 56, 57, 58, 59, 60, 61, 62, + 63, 64, 65], dtype=np.int32) + mapping = [body_mapping] + if use_hands: + lhand_mapping = np.array([20, 37, 38, 39, 66, 25, 26, 27, + 67, 28, 29, 30, 68, 34, 35, 36, 69, + 31, 32, 33, 70], dtype=np.int32) + rhand_mapping = np.array([21, 52, 53, 54, 71, 40, 41, 42, 72, + 43, 44, 45, 73, 49, 50, 51, 74, 46, + 47, 48, 75], dtype=np.int32) + + mapping += [lhand_mapping, rhand_mapping] + if use_face: + # end_idx = 127 + 17 * use_face_contour + face_mapping = np.arange(76, 127 + 17 * use_face_contour, + dtype=np.int32) + mapping += [face_mapping] + + return np.concatenate(mapping) + else: + raise ValueError('Unknown model type: {}'.format(model_type)) + elif openpose_format == 'coco19': + if model_type == 'smpl': + return np.array([24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, + 1, 4, 7, 25, 26, 27, 28], + dtype=np.int32) + elif model_type == 'smplh': + body_mapping = np.array([52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, + 8, 1, 4, 7, 53, 54, 55, 56], + dtype=np.int32) + mapping = [body_mapping] + if use_hands: + lhand_mapping = np.array([20, 34, 35, 36, 57, 22, 23, 24, 58, + 25, 26, 27, 59, 31, 32, 33, 60, 28, + 29, 30, 61], dtype=np.int32) + rhand_mapping = np.array([21, 49, 50, 51, 62, 37, 38, 39, 63, + 40, 41, 42, 64, 46, 47, 48, 65, 43, + 44, 45, 66], dtype=np.int32) + mapping += [lhand_mapping, rhand_mapping] + return np.concatenate(mapping) + # SMPLX + elif model_type == 'smplx': + body_mapping = np.array([55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, + 8, 1, 4, 7, 56, 57, 58, 59], + dtype=np.int32) + mapping = [body_mapping] + if use_hands: + lhand_mapping = np.array([20, 37, 38, 39, 60, 25, 26, 27, + 61, 28, 29, 30, 62, 34, 35, 36, 63, + 31, 32, 33, 64], dtype=np.int32) + rhand_mapping = np.array([21, 52, 53, 54, 65, 40, 41, 42, 66, + 43, 44, 45, 67, 49, 50, 51, 68, 46, + 47, 48, 69], dtype=np.int32) + + mapping += [lhand_mapping, rhand_mapping] + if use_face: + face_mapping = np.arange(70, 70 + 51 + + 17 * use_face_contour, + dtype=np.int32) + mapping += [face_mapping] + + return np.concatenate(mapping) + else: + raise ValueError('Unknown model type: {}'.format(model_type)) + else: + raise ValueError('Unknown joint format: {}'.format(openpose_format)) diff --git a/hmr4d/utils/callbacks/lr_monitor.py b/hmr4d/utils/callbacks/lr_monitor.py new file mode 100644 index 0000000..fd26ec1 --- /dev/null +++ b/hmr4d/utils/callbacks/lr_monitor.py @@ -0,0 +1,5 @@ +from pytorch_lightning.callbacks import LearningRateMonitor +from hmr4d.configs import builds, MainStore + + +MainStore.store(name="pl", node=builds(LearningRateMonitor), group="callbacks/lr_monitor") diff --git 
a/hmr4d/utils/callbacks/prog_bar.py b/hmr4d/utils/callbacks/prog_bar.py new file mode 100644 index 0000000..97798fb --- /dev/null +++ b/hmr4d/utils/callbacks/prog_bar.py @@ -0,0 +1,436 @@ +from collections import OrderedDict +from numbers import Number +from datetime import datetime, timedelta +from typing import Any, Dict, Union +from pytorch_lightning.utilities.types import STEP_OUTPUT +import torch +from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar, Tqdm, convert_inf +from pytorch_lightning.callbacks.progress import ProgressBar +from pytorch_lightning.utilities import rank_zero_only +import pytorch_lightning as pl + +from hmr4d.utils.pylogger import Log +from time import time +from collections import deque +import sys +from hmr4d.configs import MainStore, builds + +# ========== Helper functions ========== # + + +def format_num(n): + f = "{0:.3g}".format(n).replace("+0", "+").replace("-0", "-") + n = str(n) + return f if len(f) < len(n) else n + + +def convert_kwargs_to_str(**kwargs): + # Sort in alphabetical order to be more deterministic + postfix = OrderedDict([]) + for key in sorted(kwargs.keys()): + new_key = key.split("/")[-1] + postfix[new_key] = kwargs[key] + # Preprocess stats according to datatype + for key in postfix.keys(): + # Number: limit the length of the string + if isinstance(postfix[key], Number): + postfix[key] = format_num(postfix[key]) + # Else for any other type, try to get the string conversion + elif not isinstance(postfix[key], str): + postfix[key] = str(postfix[key]) + # Else if it's a string, don't need to preprocess anything + # Stitch together to get the final postfix + postfix = ", ".join(key + "=" + postfix[key].strip() for key in postfix.keys()) + return postfix + + +def convert_t_to_str(t): + """Convert time in second to string in format hour:minute:second. + If hour is 0, don't show it. Always show minute and second. + """ + t_str = timedelta(seconds=t) # e.g. 0:00:00.704186 + t_str = str(t_str).split(".")[0] # e.g. 0:00:00 + if t_str[:2] == "0:": + t_str = t_str[2:] + return t_str + + +class MyTQDMProgressBar(TQDMProgressBar, pl.Callback): + def init_train_tqdm(self): + bar = Tqdm( + desc="Training", # this will be overwritten anyway + bar_format="{desc}{percentage:3.0f}%[{bar:10}][{n_fmt}/{total_fmt}, {elapsed}→{remaining},{rate_fmt}]{postfix}", + position=(2 * self.process_position), + disable=self.is_disabled, + leave=False, + smoothing=0, + dynamic_ncols=False, + ) + return bar + + @rank_zero_only + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # this function also updates the main progress bar + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) + # in this function, we only set the postfix of the main progress bar + n = batch_idx + 1 + if self._should_update(n, self.train_progress_bar.total): + # Set post-fix string + # 1. maximum GPU usage + max_mem = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0 + post_fix_str = f"maxGPU={max_mem:.1f}GB" + + # 2. 
training metrics + training_metrics = self.get_metrics(trainer, pl_module) + training_metrics.pop("v_num", None) + post_fix_str += ", " + convert_kwargs_to_str(**training_metrics) + + # extra message if applicable + if "message" in outputs: + post_fix_str += ", " + outputs["message"] + + self.train_progress_bar.set_postfix_str(post_fix_str) + + +class ProgressReporter(ProgressBar, pl.Callback): + def __init__( + self, + log_every_percent: float = 0.1, # report interval + exp_name=None, # if None, use pl_module.exp_name or "Unnamed Experiment" + data_name=None, # if None, use pl_module.exp_name or "Unknown Data" + **kwargs, + ): + super().__init__() + self.enable = True + # 1. Store experiment meta data. + self.log_every_percent = log_every_percent + self.exp_name = exp_name + self.data_name = data_name + self.batch_time_queue = deque(maxlen=5) + self.start_prompt = "🚀" + self.finish_prompt = "✅" + # 2. Utils for evaluation + self.n_finished = 0 + + def disable(self): + self.enable = False + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + # Connect to the trainer object. + super().setup(trainer, pl_module, stage) + self.stage = stage + self.time_exp_start = time() + self.epoch_exp_start = trainer.current_epoch + + if self.exp_name is None: + if hasattr(pl_module, "exp_name"): + self.exp_name = pl_module.exp_name + else: + self.exp_name = "Unnamed Experiment" + if self.data_name is None: + if hasattr(pl_module, "data_name"): + self.data_name = pl_module.data_name + else: + self.data_name = "Unknown Data" + + def print(self, *args: Any, **kwargs: Any) -> None: + print(*args) + + def get_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> Dict[str, Union[str, float]]: + """Get metrics from trainer for progress bar.""" + items = super().get_metrics(trainer, pl_module) + items.pop("v_num", None) + return items + + def _should_update(self, n_finished: int, total: int) -> bool: + """ + Rule: Log every `log_every_percent` percent, or the last batch. 
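+        For example, with total=200 and log_every_percent=0.1, reports fire at every 20th batch, additionally at batches 5 and 10, and always at the last batch.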
+ """ + log_interval = max(int(total * self.log_every_percent), 1) + able = n_finished % log_interval == 0 or n_finished == total + if log_interval > 10: + able = able or n_finished in [5, 10] # always log + able = able and self.enable + return able + + @rank_zero_only + def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: + self.print("=" * 80) + Log.info( + f"{self.start_prompt}[FIT][Epoch {trainer.current_epoch}] Data: {self.data_name} Experiment: {self.exp_name}" + ) + self.time_train_epoch_start = time() + + @rank_zero_only + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) # don't forget this :) + total = self.total_train_batches + + # Speed + n_finished = batch_idx + 1 + percent = 100 * n_finished / total + time_current = time() + self.batch_time_queue.append(time_current) + time_elapsed = time_current - self.time_train_epoch_start # second + time_remaining = time_elapsed * (total - n_finished) / n_finished # second + if len(self.batch_time_queue) == 1: # cannot compute speed + speed = 1 / time_elapsed + else: + speed = (len(self.batch_time_queue) - 1) / (self.batch_time_queue[-1] - self.batch_time_queue[0]) + + # Skip if not update + if not self._should_update(n_finished, total): + return + + # ===== Set Prefix string ===== # + # General + desc = f"[Train]" + + # Speed: Get elapsed time and estimated remaining time + time_elapsed_str = convert_t_to_str(time_elapsed) + time_remaining_str = convert_t_to_str(time_remaining) + speed_str = f"{speed:.2f}it/s" if speed > 1 else f"{1/speed:.1f}s/it" + n_digit = len(str(total)) + desc_speed = ( + f"[{n_finished:{n_digit}d}/{total}={percent:3.0f}%, {time_elapsed_str} → {time_remaining_str}, {speed_str}]" + ) + + # ===== Set postfix string ===== # + # 1. maximum GPU usage + max_mem = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0 + post_fix_str = f"maxGPU={max_mem:.1f}GB" + + # 2. training step metrics + train_metrics = self.get_metrics(trainer, pl_module) + train_metrics = {k: v for k, v in train_metrics.items() if ("train" in k and "epoch" not in k)} + post_fix_str += ", " + convert_kwargs_to_str(**train_metrics) + + # extra message if applicable + if "message" in outputs: + post_fix_str += ", " + outputs["message"] + post_fix_str = f"[{post_fix_str}]" + + # ===== Output ===== # + bar_output = f"{desc}{desc_speed}{post_fix_str}" + self.print(bar_output) + + @rank_zero_only + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + super().on_train_epoch_end(trainer, pl_module) + + # Clear + self.batch_time_queue.clear() + + # Estimate Epoch time + n_finished = trainer.current_epoch + 1 - self.epoch_exp_start + n_to_finish = trainer.max_epochs - trainer.current_epoch - 1 + time_current = time() + time_elapsed = time_current - self.time_exp_start + time_remaining = time_elapsed * n_to_finish / n_finished + time_elapsed_str = convert_t_to_str(time_elapsed) + time_remaining_str = convert_t_to_str(time_remaining) + + # Metrics + # training epoch metrics + train_metrics = self.get_metrics(trainer, pl_module) + train_metrics = {k: v for k, v in train_metrics.items() if ("train" in k and "epoch" in k)} + train_metrics_str = convert_kwargs_to_str(**train_metrics) + + Log.info( + f"{self.finish_prompt}[FIT][Epoch {trainer.current_epoch}] finished! 
{time_elapsed_str}→{time_remaining_str} | {train_metrics_str}" + ) + + # ===== Validation/Test/Prediction ===== # + @rank_zero_only + def on_validation_epoch_start(self, trainer, pl_module): + self.time_val_epoch_start = time() + + @rank_zero_only + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): + self.n_finished += 1 + n_finished = self.n_finished + total = self.total_val_batches + if not self._should_update(n_finished, total): + return + + # General + desc = f"[Val]" + + # Speed + percent = 100 * n_finished / total + time_current = time() + time_elapsed = time_current - self.time_val_epoch_start # second + time_remaining = time_elapsed * (total - n_finished) / n_finished # second + time_elapsed_str = convert_t_to_str(time_elapsed) + time_remaining_str = convert_t_to_str(time_remaining) + desc_speed = f"[{n_finished}/{total} ={percent:3.0f}%, {time_elapsed_str}→{time_remaining_str}]" + + # Output + bar_output = f"{desc} {desc_speed}" + self.print(bar_output) + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + # Reset + self.n_finished = 0 + + +class EmojiProgressReporter(ProgressBar, pl.Callback): + def __init__( + self, + refresh_rate_batch: Union[int, None] = 1, # report interval of batch, set None to disable it + refresh_rate_epoch: int = 1, # report interval of epoch + **kwargs, + ): + super().__init__() + self.enable = True + # Store experiment meta data. + self.refresh_rate_batch = refresh_rate_batch + self.refresh_rate_epoch = refresh_rate_epoch + + # Style of the progress bar. + self.title_prompt = "📝" + self.prog_prompt = "🚀" + self.timer_prompt = "⌛️" + self.metric_prompt = "📌" + self.finish_prompt = "✅" + + def disable(self): + self.enable = False + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str): + # Connect to the trainer object. + super().setup(trainer, pl_module, stage) + self.stage = stage + self.time_start_batch = None + self.time_start_epoch = None + if hasattr(pl_module, "exp_name"): + self.exp_name = pl_module.exp_name + else: + self.exp_name = "Unnamed Experiment" + Log.warn("Experiment name not found, please set it to `pl_module.exp_name`!") + + def print(self, *args: Any, **kwargs: Any): + print(*args) + + def get_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> Dict[str, Union[str, float]]: + """Get metrics from trainer for progress bar.""" + items = super().get_metrics(trainer, pl_module) + items.pop("v_num", None) + return dict(sorted(items.items())) + + def _should_log_batch(self, n: int) -> bool: + # Disable batch log. + if self.refresh_rate_batch is None: + return False + # Log at the first & last batch, and every `self.refresh_rate_batch` batches. + able = n % self.refresh_rate_batch == 0 or n == self.total_train_batches - 1 + able = able and self.enable + return able + + def _should_log_epoch(self, n: int) -> bool: + # Log at the first & last epoch, and every `self.refresh_rate_epoch` epochs. + able = n % self.refresh_rate_epoch == 0 or n == self.trainer.max_epochs - 1 + able = able and self.enable + return able + + def timestamp_delta_to_str(self, timestamp_delta: float): + """Convert delta timestamp to string.""" + time_rest = timedelta(seconds=timestamp_delta) + hours, remainder = divmod(time_rest.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + time_str = "" + + # Check if the time is valid. Note that, if `hours` is visible, then `minutes` must be visible. 
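+        # e.g. 3725s -> "1h 2m 5s", 65s -> "1m 5s"; zero components are dropped entirely, and whole days are not shown (timedelta.seconds wraps at 24h).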
+ if hours <= 0: + hours = None + if minutes <= 0: + minutes = None + if seconds <= 0: + seconds = None + + time_str += f"{hours}h " if hours is not None else "" + time_str += f"{minutes}m " if minutes is not None else "" + time_str += f"{seconds}s" if seconds is not None else "" + return time_str + + @rank_zero_only + def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int): + super().on_train_batch_start(trainer, pl_module, batch, batch_idx) + # Initialize some meta data. + if self.time_start_batch is None: + self.time_start_batch = datetime.now().timestamp() + + @rank_zero_only + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) # don't forget this :) + # Get some meta data. + epoch_idx = trainer.current_epoch + percent = 100 * (batch_idx + 1) / (self.total_train_batches + 1) + metrics = self.get_metrics(trainer, pl_module) + + # Current time. + time_cur_stamp = datetime.now().timestamp() + time_cur_str = datetime.fromtimestamp(time_cur_stamp).strftime("%m-%d %H:%M:%S") + # Rest time. + time_rest_stamp = (time_cur_stamp - self.time_start_batch) * (100 - percent) / percent + time_rest_str = self.timestamp_delta_to_str(time_rest_stamp) + + if not self._should_log_batch(batch_idx): + return + + # Print the logs. + self.print(f"{self.title_prompt} [{self.stage.upper()}] Exp: {self.exp_name}...") + self.print( + f"{self.prog_prompt} Ep {epoch_idx}: {int(percent):02d}% <= [{batch_idx}/{self.total_train_batches}]" + ) + self.print(f"{self.timer_prompt} Time: {time_cur_str} | Ep Rest: {time_rest_str}") + for k, v in metrics.items(): + self.print(f"{self.metric_prompt} {k}: {v}") + self.print("") # Add a blank line. + + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + super().on_train_epoch_start(trainer, pl_module) + # Initialize some meta data. + self.time_start_batch = None + if self.time_start_epoch is None: + self.time_start_epoch = datetime.now().timestamp() + + @rank_zero_only + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + super().on_train_epoch_end(trainer, pl_module) + # Get some meta data. + epoch_idx = trainer.current_epoch + percent = 100 * (epoch_idx + 1) / (self.trainer.max_epochs + 1) + metrics = self.get_metrics(trainer, pl_module) + + # Current time. + time_cur = datetime.now().timestamp() + time_str = datetime.fromtimestamp(time_cur).strftime("%m-%d %H: %M:%S") + # Rest time. + time_rest_stamp = (time_cur - self.time_start_epoch) * (100 - percent) / percent + time_rest_str = self.timestamp_delta_to_str(time_rest_stamp) + + if not self._should_log_batch(epoch_idx): + return + + # Print the logs. + self.print(f">> >> >> >>") + self.print(f"{self.title_prompt} [{self.stage.upper()}] Exp: {self.exp_name}") + self.print(f"{self.finish_prompt} Ep {epoch_idx} finished!") + self.print(f"{self.timer_prompt} Time: {time_str} | Rest: {time_rest_str}") + for k, v in metrics.items(): + self.print(f"{self.metric_prompt} {k}: {v}") + self.print(f"<< << << <<") + self.print("") # Add a blank line. 
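+
+# Usage sketch: besides being selected through the Hydra config group registered
+# below, the reporter can also be attached to a plain pytorch_lightning.Trainer.
+# `MyLitModule` and `my_datamodule` are hypothetical placeholders, not names from
+# this repository; only the ProgressReporter/Trainer wiring follows the code above.
+#
+#   reporter = ProgressReporter(log_every_percent=0.2, exp_name="demo_exp", data_name="demo_data")
+#   trainer = pl.Trainer(max_epochs=10, callbacks=[reporter])
+#   trainer.fit(MyLitModule(), datamodule=my_datamodule)
+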
+ + +group_name = "callbacks/prog_bar" +prog_reporter_base = builds( + ProgressReporter, + log_every_percent=0.1, + exp_name="${exp_name}", + data_name="${data_name}", + populate_full_signature=True, +) +MainStore.store(name="prog_reporter_every0.1", node=prog_reporter_base, group=group_name) +MainStore.store(name="prog_reporter_every0.2", node=prog_reporter_base(log_every_percent=0.2), group=group_name) diff --git a/hmr4d/utils/callbacks/simple_ckpt_saver.py b/hmr4d/utils/callbacks/simple_ckpt_saver.py new file mode 100644 index 0000000..5565339 --- /dev/null +++ b/hmr4d/utils/callbacks/simple_ckpt_saver.py @@ -0,0 +1,93 @@ +from pathlib import Path +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks.checkpoint import Checkpoint +from pytorch_lightning.utilities import rank_zero_only + +from hmr4d.utils.pylogger import Log +from hmr4d.configs import MainStore, builds + + +class SimpleCkptSaver(Checkpoint): + """ + This callback runs at the end of each training epoch. + Check {every_n_epochs} and save at most {save_top_k} model if it is time. + """ + + def __init__( + self, + output_dir, + filename="e{epoch:03d}-s{step:06d}.ckpt", + save_top_k=1, + every_n_epochs=1, + save_last=None, + save_weights_only=True, + ): + super().__init__() + self.output_dir = Path(output_dir) + self.filename = filename + self.save_top_k = save_top_k + self.every_n_epochs = every_n_epochs + self.save_last = save_last + self.save_weights_only = save_weights_only + + # Setup output dir + if rank_zero_only.rank == 0: + self.output_dir.mkdir(parents=True, exist_ok=True) + Log.info(f"[Simple Ckpt Saver]: Save to `{self.output_dir}'") + + @rank_zero_only + def on_train_epoch_end(self, trainer, pl_module): + """Save a checkpoint at the end of the training epoch.""" + if self.every_n_epochs >= 1 and (trainer.current_epoch + 1) % self.every_n_epochs == 0: + if self.save_top_k == 0: + return + + # Current saved ckpts in the output_dir + model_paths = [] + for p in sorted(list(self.output_dir.glob("*.ckpt"))): + model_paths.append(p) + model_to_remove = model_paths[0] if len(model_paths) >= self.save_top_k else None + + # Save cureent checkpoint + filepath = self.output_dir / self.filename.format(epoch=trainer.current_epoch, step=trainer.global_step) + checkpoint = { + "epoch": trainer.current_epoch, + "global_step": trainer.global_step, + "pytorch-lightning_version": pl.__version__, + "state_dict": pl_module.state_dict(), + } + pl_module.on_save_checkpoint(checkpoint) + + if not self.save_weights_only: + # optimizer + optimizer_states = [] + for i, optimizer in enumerate(trainer.optimizers): + # Rely on accelerator to dump optimizer state + optimizer_state = trainer.strategy.optimizer_state(optimizer) + optimizer_states.append(optimizer_state) + checkpoint["optimizer_states"] = optimizer_states + + # lr_scheduler + lr_schedulers = [] + for config in trainer.lr_scheduler_configs: + lr_schedulers.append(config.scheduler.state_dict()) + checkpoint["lr_schedulers"] = lr_schedulers + + # trainer.strategy.checkpoint_io.save_checkpoint(checkpoint, filepath) + torch.save(checkpoint, filepath) + + # Remove the earliest checkpoint + if model_to_remove: + trainer.strategy.remove_checkpoint(model_paths[0]) + + +group_name = "callbacks/simple_ckpt_saver" +base = builds(SimpleCkptSaver, output_dir="${output_dir}/checkpoints/", populate_full_signature=True) +MainStore.store(name="base", node=base, group=group_name) +MainStore.store(name="every1e", node=base, group=group_name) +MainStore.store(name="every2e", 
node=base(every_n_epochs=2), group=group_name) +MainStore.store(name="every5e", node=base(every_n_epochs=5), group=group_name) +MainStore.store(name="every5e_top100", node=base(every_n_epochs=5, save_top_k=100), group=group_name) +MainStore.store(name="every10e", node=base(every_n_epochs=10), group=group_name) +MainStore.store(name="every10e_top100", node=base(every_n_epochs=10, save_top_k=100), group=group_name) diff --git a/hmr4d/utils/callbacks/train_speed_timer.py b/hmr4d/utils/callbacks/train_speed_timer.py new file mode 100644 index 0000000..6c8c038 --- /dev/null +++ b/hmr4d/utils/callbacks/train_speed_timer.py @@ -0,0 +1,71 @@ +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from time import time +from collections import deque + +from hmr4d.configs import MainStore, builds + + +class TrainSpeedTimer(pl.Callback): + def __init__(self, N_avg=5): + """ + This callback times the training speed (averge over recent 5 iterations) + 1. Data waiting time: this should be small, otherwise the data loading should be improved + 2. Single batch time: this is the time for one batch of training (excluding data waiting) + """ + super().__init__() + self.last_batch_end = None + self.this_batch_start = None + + # time queues for averaging + self.data_waiting_time_queue = deque(maxlen=N_avg) + self.single_batch_time_queue = deque(maxlen=N_avg) + + @rank_zero_only + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + """Count the time of data waiting""" + if self.last_batch_end is not None: + # This should be small, otherwise the data loading should be improved + data_waiting = time() - self.last_batch_end + + # Average the time + self.data_waiting_time_queue.append(data_waiting) + average_time = sum(self.data_waiting_time_queue) / len(self.data_waiting_time_queue) + + # Log to prog-bar + pl_module.log( + "train_timer/data_waiting", average_time, on_step=True, on_epoch=False, prog_bar=True, logger=True + ) + + self.this_batch_start = time() + + @rank_zero_only + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # Effective training time elapsed (excluding data waiting) + single_batch = time() - self.this_batch_start + + # Average the time + self.single_batch_time_queue.append(single_batch) + average_time = sum(self.single_batch_time_queue) / len(self.single_batch_time_queue) + + # Log iter time + pl_module.log( + "train_timer/single_batch", average_time, on_step=True, on_epoch=False, prog_bar=False, logger=True + ) + + # Set timer for counting data waiting + self.last_batch_end = time() + + @rank_zero_only + def on_train_epoch_end(self, trainer, pl_module): + # Reset the timer + self.last_batch_end = None + self.this_batch_start = None + # Clear the queue + self.data_waiting_time_queue.clear() + self.single_batch_time_queue.clear() + + +group_name = "callbacks/train_speed_timer" +base = builds(TrainSpeedTimer, populate_full_signature=True) +MainStore.store(name="base", node=base, group=group_name) diff --git a/hmr4d/utils/comm/gather.py b/hmr4d/utils/comm/gather.py new file mode 100644 index 0000000..741a6cf --- /dev/null +++ b/hmr4d/utils/comm/gather.py @@ -0,0 +1,257 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +[Copied from detectron2] +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. 
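+
+Typical (illustrative) use: after evaluation, all_gather(local_results) returns a list
+with every rank's object on all ranks, while gather(local_results, dst=0) returns that
+list only on rank 0 (and an empty list on other ranks); both accept arbitrary picklable
+objects.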
+""" + +import functools +import logging +import numpy as np +import pickle +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024**3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024**3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert world_size >= 1, "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)] + dist.all_gather(size_list, local_size, group=group) + + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. 
By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group=group) == 1: + return [data] + rank = dist.get_rank(group=group) + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving Tensor from all ranks + if rank == dst: + max_size = max(size_list) + tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2**31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. 
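+
+ Example (illustrative; the loss names are placeholders):
+ loss_dict = {"loss_pose": loss_pose.detach(), "loss_kp2d": loss_kp2d.detach()}
+ reduced = reduce_dict(loss_dict) # averaged over ranks; only meaningful on rank 0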
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/hmr4d/utils/eval/eval_utils.py b/hmr4d/utils/eval/eval_utils.py new file mode 100644 index 0000000..6f48898 --- /dev/null +++ b/hmr4d/utils/eval/eval_utils.py @@ -0,0 +1,457 @@ +import torch +import numpy as np + + +@torch.no_grad() +def compute_camcoord_metrics(batch, pelvis_idxs=[1, 2], fps=30, mask=None): + """ + Args: + batch (dict): { + "pred_j3d": (..., J, 3) tensor + "target_j3d": + "pred_verts": + "target_verts": + } + Returns: + cam_coord_metrics (dict): { + "pa_mpjpe": (..., ) numpy array + "mpjpe": + "pve": + "accel": + } + """ + # All data is in camera coordinates + pred_j3d = batch["pred_j3d"].cpu() # (..., J, 3) + target_j3d = batch["target_j3d"].cpu() + pred_verts = batch["pred_verts"].cpu() + target_verts = batch["target_verts"].cpu() + + if mask is not None: + mask = mask.cpu() + pred_j3d = pred_j3d[mask].clone() + target_j3d = target_j3d[mask].clone() + pred_verts = pred_verts[mask].clone() + target_verts = target_verts[mask].clone() + assert "mask" not in batch + + # Align by pelvis + pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis( + [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs=pelvis_idxs + ) + + # Metrics + m2mm = 1000 + S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d) + pa_mpjpe = compute_jpe(S1_hat, target_j3d) * m2mm + mpjpe = compute_jpe(pred_j3d, target_j3d) * m2mm + pve = compute_jpe(pred_verts, target_verts) * m2mm + accel = compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d, fps=fps) + + camcoord_metrics = { + "pa_mpjpe": pa_mpjpe, + "mpjpe": mpjpe, + "pve": pve, + "accel": accel, + } + return camcoord_metrics + + +@torch.no_grad() +def compute_global_metrics(batch, mask=None): + """Follow WHAM, the input has skipped invalid frames + Args: + batch (dict): { + "pred_j3d_glob": (F, J, 3) tensor + "target_j3d_glob": + "pred_verts_glob": + "target_verts_glob": + } + Returns: + global_metrics (dict): { + "wa2_mpjpe": (F, ) numpy array + "waa_mpjpe": + "rte": + "jitter": + "fs": + } + """ + # All data is in global coordinates + pred_j3d_glob = batch["pred_j3d_glob"].cpu() # (..., J, 3) + target_j3d_glob = batch["target_j3d_glob"].cpu() + pred_verts_glob = batch["pred_verts_glob"].cpu() + target_verts_glob = batch["target_verts_glob"].cpu() + if mask is not None: + mask = mask.cpu() + pred_j3d_glob = pred_j3d_glob[mask].clone() + target_j3d_glob = target_j3d_glob[mask].clone() + pred_verts_glob = pred_verts_glob[mask].clone() + target_verts_glob = target_verts_glob[mask].clone() + assert "mask" not in batch + + seq_length = pred_j3d_glob.shape[0] + + # Use chunk to compare + chunk_length = 100 + wa2_mpjpe, waa_mpjpe = [], [] + for start in range(0, seq_length, chunk_length): + end = min(seq_length, start + chunk_length) + + target_j3d = target_j3d_glob[start:end].clone().cpu() + pred_j3d = pred_j3d_glob[start:end].clone().cpu() + + w_j3d = first_align_joints(target_j3d, pred_j3d) + wa_j3d = 
global_align_joints(target_j3d, pred_j3d) + + if False: + from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines + + wis3d = make_wis3d(name="debug-metric_utils") + add_motion_as_lines(target_j3d, wis3d, name="target_j3d") + add_motion_as_lines(pred_j3d, wis3d, name="pred_j3d") + add_motion_as_lines(w_j3d, wis3d, name="pred_w2_j3d") + add_motion_as_lines(wa_j3d, wis3d, name="pred_wa_j3d") + + wa2_mpjpe.append(compute_jpe(target_j3d, w_j3d)) + waa_mpjpe.append(compute_jpe(target_j3d, wa_j3d)) + + # Metrics + m2mm = 1000 + wa2_mpjpe = np.concatenate(wa2_mpjpe) * m2mm + waa_mpjpe = np.concatenate(waa_mpjpe) * m2mm + + # Additional Metrics + rte = compute_rte(target_j3d_glob[:, 0].cpu(), pred_j3d_glob[:, 0].cpu()) * 1e2 + jitter = compute_jitter(pred_j3d_glob, fps=30) + foot_sliding = compute_foot_sliding(target_verts_glob, pred_verts_glob) * m2mm + + global_metrics = { + "wa2_mpjpe": wa2_mpjpe, + "waa_mpjpe": waa_mpjpe, + "rte": rte, + "jitter": jitter, + "fs": foot_sliding, + } + return global_metrics + + +@torch.no_grad() +def compute_camcoord_perjoint_metrics(batch, pelvis_idxs=[1, 2]): + """ + Args: + batch (dict): { + "pred_j3d": (..., J, 3) tensor + "target_j3d": + } + Returns: + cam_coord_metrics (dict): { + "pa_mpjpe": (..., ) numpy array + "mpjpe": + "pve": + "accel": + } + """ + # All data is in camera coordinates + pred_j3d = batch["pred_j3d"].cpu() # (..., J, 3) + target_j3d = batch["target_j3d"].cpu() + pred_verts = batch["pred_verts"].cpu() + target_verts = batch["target_verts"].cpu() + + # Align by pelvis + pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis( + [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs=pelvis_idxs + ) + # Metrics + m2mm = 1000 + perjoint_mpjpe = compute_perjoint_jpe(pred_j3d, target_j3d) * m2mm + + camcoord_perjoint_metrics = { + "mpjpe": perjoint_mpjpe, + } + return camcoord_perjoint_metrics + + +# ===== Utilities ===== + + +def compute_jpe(S1, S2): + return torch.sqrt(((S1 - S2) ** 2).sum(dim=-1)).mean(dim=-1).numpy() + + +def compute_perjoint_jpe(S1, S2): + return torch.sqrt(((S1 - S2) ** 2).sum(dim=-1)).numpy() + + +def batch_align_by_pelvis(data_list, pelvis_idxs=[1, 2]): + """ + Assumes data is given as [pred_j3d, target_j3d, pred_verts, target_verts]. + Each data is in shape of (frames, num_points, 3) + Pelvis is notated as one / two joints indices. + Align all data to the corresponding pelvis location. + """ + + pred_j3d, target_j3d, pred_verts, target_verts = data_list + + pred_pelvis = pred_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone() + target_pelvis = target_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone() + + # Align to the pelvis + pred_j3d = pred_j3d - pred_pelvis + target_j3d = target_j3d - target_pelvis + pred_verts = pred_verts - pred_pelvis + target_verts = target_verts - target_pelvis + + return (pred_j3d, target_j3d, pred_verts, target_verts) + + +def batch_compute_similarity_transform_torch(S1, S2): + """ + Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + """ + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.permute(0, 2, 1) + S2 = S2.permute(0, 2, 1) + transposed = True + assert S2.shape[1] == S1.shape[1] + + # 1. Remove mean. + mu1 = S1.mean(axis=-1, keepdims=True) + mu2 = S2.mean(axis=-1, keepdims=True) + + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. 
Compute variance of X1 used for scale. + var1 = torch.sum(X1**2, dim=1).sum(dim=1) + + # 3. The outer product of X1 and X2. + K = X1.bmm(X2.permute(0, 2, 1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, V = torch.svd(K) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0) + Z = Z.repeat(U.shape[0], 1, 1) + Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1)))) + + # Construct R. + R = V.bmm(Z.bmm(U.permute(0, 2, 1))) + + # 5. Recover scale. + scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1 + + # 6. Recover translation. + t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1))) + + # 7. Error: + S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t + + if transposed: + S1_hat = S1_hat.permute(0, 2, 1) + + return S1_hat + + +def compute_error_accel(joints_gt, joints_pred, valid_mask=None, fps=None): + """ + Use [i-1, i, i+1] to compute acc at frame_i. The acceleration error: + 1/(n-2) \sum_{i=1}^{n-1} X_{i-1} - 2X_i + X_{i+1} + Note that for each frame that is not visible, three entries(-1, 0, +1) in the + acceleration error will be zero'd out. + Args: + joints_gt : (F, J, 3) + joints_pred : (F, J, 3) + valid_mask : (F) + Returns: + error_accel (F-2) when valid_mask is None, else (F'), F' <= F-2 + """ + # (F, J, 3) -> (F-2) per-joint + accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:] + accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:] + normed = np.linalg.norm(accel_pred - accel_gt, axis=-1).mean(axis=-1) + if fps is not None: + normed = normed * fps**2 + + if valid_mask is None: + new_vis = np.ones(len(normed), dtype=bool) + else: + invis = np.logical_not(valid_mask) + invis1 = np.roll(invis, -1) + invis2 = np.roll(invis, -2) + new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2] + new_vis = np.logical_not(new_invis) + if new_vis.sum() == 0: + print("Warning!!! no valid acceleration error to compute.") + + return normed[new_vis] + + +def compute_rte(target_trans, pred_trans): + # Compute the global alignment + _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True) + pred_trans_hat = (torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :])[0] + + # Compute the entire displacement of ground truth trajectory + disps, disp = [], 0 + for p1, p2 in zip(target_trans, target_trans[1:]): + delta = (p2 - p1).norm(2, dim=-1) + disp += delta + disps.append(disp) + + # Compute absolute root-translation-error (RTE) + rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1) + + # Normalize it to the displacement + return (rte / disp).numpy() + + +def compute_jitter(joints, fps=30): + """compute jitter of the motion + Args: + joints (N, J, 3). + fps (float). + Returns: + jitter (N-3). + """ + pred_jitter = torch.norm( + (joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]) * (fps**3), + dim=2, + ).mean(dim=-1) + + return pred_jitter.cpu().numpy() / 10.0 + + +def compute_foot_sliding(target_verts, pred_verts, thr=1e-2): + """compute foot sliding error + The foot ground contact label is computed by the threshold of 1 cm/frame + Args: + target_verts (N, 6890, 3). + pred_verts (N, 6890, 3). + Returns: + error (N frames in contact). 
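+ Note: with the default thr=1e-2 (meters per frame), a foot counts as "in contact"
+ when it moves less than 1 cm between consecutive frames (about 0.3 m/s at 30 fps).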
+ """ + assert target_verts.shape == pred_verts.shape + assert target_verts.shape[-2] == 6890 + + # Foot vertices idxs + foot_idxs = [3216, 3387, 6617, 6787] + + # Compute contact label + foot_loc = target_verts[:, foot_idxs] + foot_disp = (foot_loc[1:] - foot_loc[:-1]).norm(2, dim=-1) + contact = foot_disp[:] < thr + + pred_feet_loc = pred_verts[:, foot_idxs] + pred_disp = (pred_feet_loc[1:] - pred_feet_loc[:-1]).norm(2, dim=-1) + + error = pred_disp[contact] + + return error.cpu().numpy() + + +def convert_joints22_to_24(joints22, ratio2220=0.3438, ratio2321=0.3345): + joints24 = torch.zeros(*joints22.shape[:-2], 24, 3).to(joints22.device) + joints24[..., :22, :] = joints22 + joints24[..., 22, :] = joints22[..., 20, :] + ratio2220 * (joints22[..., 20, :] - joints22[..., 18, :]) + joints24[..., 23, :] = joints22[..., 21, :] + ratio2321 * (joints22[..., 21, :] - joints22[..., 19, :]) + return joints24 + + +def align_pcl(Y, X, weight=None, fixed_scale=False): + """align similarity transform to align X with Y using umeyama method + X' = s * R * X + t is aligned with Y + :param Y (*, N, 3) first trajectory + :param X (*, N, 3) second trajectory + :param weight (*, N, 1) optional weight of valid correspondences + :returns s (*, 1), R (*, 3, 3), t (*, 3) + """ + *dims, N, _ = Y.shape + N = torch.ones(*dims, 1, 1) * N + + if weight is not None: + Y = Y * weight + X = X * weight + N = weight.sum(dim=-2, keepdim=True) # (*, 1, 1) + + # subtract mean + my = Y.sum(dim=-2) / N[..., 0] # (*, 3) + mx = X.sum(dim=-2) / N[..., 0] + y0 = Y - my[..., None, :] # (*, N, 3) + x0 = X - mx[..., None, :] + + if weight is not None: + y0 = y0 * weight + x0 = x0 * weight + + # correlation + C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3) + U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3) + + S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1) + neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0 + S[neg, 2, 2] = -1 + + R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3) + + D = torch.diag_embed(D) # (*, 3, 3) + if fixed_scale: + s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32) + else: + var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1) + s = torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(dim=-1, keepdim=True) / var[..., 0] # (*, 1) + + t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3) + + return s, R, t + + +def global_align_joints(gt_joints, pred_joints): + """ + :param gt_joints (T, J, 3) + :param pred_joints (T, J, 3) + """ + s_glob, R_glob, t_glob = align_pcl(gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3)) + pred_glob = s_glob * torch.einsum("ij,tnj->tni", R_glob, pred_joints) + t_glob[None, None] + return pred_glob + + +def first_align_joints(gt_joints, pred_joints): + """ + align the first two frames + :param gt_joints (T, J, 3) + :param pred_joints (T, J, 3) + """ + # (1, 1), (1, 3, 3), (1, 3) + s_first, R_first, t_first = align_pcl(gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3)) + pred_first = s_first * torch.einsum("tij,tnj->tni", R_first, pred_joints) + t_first[:, None] + return pred_first + + +def rearrange_by_mask(x, mask): + """ + x (L, *) + mask (M,), M >= L + """ + M = mask.size(0) + L = x.size(0) + if M == L: + return x + assert M > L + assert mask.sum() == L + x_rearranged = torch.zeros((M, *x.size()[1:]), dtype=x.dtype, device=x.device) + x_rearranged[mask] = x + return x_rearranged + + +def as_np_array(d): + if isinstance(d, torch.Tensor): + return 
d.cpu().numpy() + elif isinstance(d, np.ndarray): + return d + else: + return np.array(d) diff --git a/hmr4d/utils/geo/augment_noisy_pose.py b/hmr4d/utils/geo/augment_noisy_pose.py new file mode 100644 index 0000000..3c8b9e5 --- /dev/null +++ b/hmr4d/utils/geo/augment_noisy_pose.py @@ -0,0 +1,207 @@ +import torch +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle, matrix_to_rotation_6d +import hmr4d.utils.matrix as matrix +from hmr4d import PROJ_ROOT + +COCO17_AUG = {k: v.flatten() for k, v in torch.load(PROJ_ROOT / "hmr4d/utils/body_model/coco_aug_dict.pth").items()} +COCO17_AUG_CUDA = {} +COCO17_TREE = [[5, 6], 0, 0, 1, 2, -1, -1, 5, 6, 7, 8, -1, -1, 11, 12, 13, 14, 15, 15, 15, 16, 16, 16] + + +def gaussian_augment(body_pose, std_angle=10.0, to_R=True): + """ + Args: + body_pose torch.Tensor: (..., J, 3) axis-angle if to_R is True, else rotmat (..., J, 3, 3) + std_angle: scalar or list, in degree + """ + + body_pose = body_pose.clone() + + if to_R: + body_pose_R = axis_angle_to_matrix(body_pose) # (B, L, J, 3, 3) + else: + body_pose_R = body_pose + shape = body_pose_R.shape[:-2] + device = body_pose.device + + # 1. Simulate noise + # angle: + std_angle = torch.tensor(std_angle).to(device).reshape(-1) # allow scalar or list + noise_angle = torch.randn(shape, device=device) * std_angle * torch.pi / 180 + + # axis: avoid zero vector + noise_axis = torch.rand((*shape, 3), device=device) + mask_ = torch.norm(noise_axis, dim=-1) < 1e-6 + noise_axis[mask_] = 1 + + noise_axis = noise_axis / torch.norm(noise_axis, dim=-1, keepdim=True) + noise_aa = noise_angle[..., None] * noise_axis # (B, L, J, 3) + noise_R = axis_angle_to_matrix(noise_aa) # (B, L, J, 3, 3) + + # 2. Add noise to body pose + new_body_pose_R = matrix.get_mat_BfromA(body_pose_R, noise_R) # (B, L, J, 3, 3) + # new_body_pose_R = torch.matmul(noise_R, body_pose_R) + new_body_pose_r6d = matrix_to_rotation_6d(new_body_pose_R) # (B, L, J, 6) + new_body_pose_aa = matrix_to_axis_angle(new_body_pose_R) # (B, L, J, 3) + + return new_body_pose_R, new_body_pose_r6d, new_body_pose_aa + + +# ========= Augment Joint 3D ======== # + + +def get_jitter(shape=(8, 120), s_jittering=5e-2): + """Guassian jitter modeling.""" + jittering_noise = ( + torch.normal( + mean=torch.zeros((*shape, 17, 3)), + std=COCO17_AUG["jittering"].reshape(1, 1, 17, 1).expand(*shape, -1, 3), + ) + * s_jittering + ) + return jittering_noise + + +def get_jitter_cuda(shape=(8, 120), s_jittering=5e-2): + if "jittering" not in COCO17_AUG_CUDA: + COCO17_AUG_CUDA["jittering"] = COCO17_AUG["jittering"].cuda().reshape(1, 1, 17, 1) + jittering = COCO17_AUG_CUDA["jittering"] + jittering_noise = torch.randn((*shape, 17, 3), device="cuda") * jittering * s_jittering + return jittering_noise + + +def get_lfhp(shape=(8, 120), s_peak=3e-1, s_peak_mask=5e-3): + """Low-frequency high-peak noise modeling.""" + + def get_peak_noise_mask(): + peak_noise_mask = torch.rand(*shape, 17) * COCO17_AUG["pmask"] + peak_noise_mask = peak_noise_mask < s_peak_mask + return peak_noise_mask + + peak_noise_mask = get_peak_noise_mask() # (B, L, 17) + peak_noise = peak_noise_mask.float().unsqueeze(-1).repeat(1, 1, 1, 3) + peak_noise = peak_noise * torch.randn(3) * COCO17_AUG["peak"].reshape(17, 1) * s_peak + return peak_noise + + +def get_lfhp_cuda(shape=(8, 120), s_peak=3e-1, s_peak_mask=5e-3): + if "peak" not in COCO17_AUG_CUDA: + COCO17_AUG_CUDA["pmask"] = COCO17_AUG["pmask"].cuda() + COCO17_AUG_CUDA["peak"] = COCO17_AUG["peak"].cuda().reshape(17, 1) + + pmask = 
COCO17_AUG_CUDA["pmask"] + peak = COCO17_AUG_CUDA["peak"] + peak_noise_mask = torch.rand(*shape, 17, device="cuda") * pmask < s_peak_mask + peak_noise = ( + peak_noise_mask.float().unsqueeze(-1).expand(-1, -1, -1, 3) * torch.randn(3, device="cuda") * peak * s_peak + ) + return peak_noise + + +def get_bias(shape=(8, 120), s_bias=1e-1): + """Bias noise modeling.""" + b, l = shape + bias_noise = torch.normal(mean=torch.zeros((b, 17, 3)), std=COCO17_AUG["bias"].reshape(1, 17, 1)) * s_bias + bias_noise = bias_noise[:, None].expand(-1, l, -1, -1) # (B, L, J, 3), the whole sequence is moved by the same bias + return bias_noise + + +def get_bias_cuda(shape=(8, 120), s_bias=1e-1): + if "bias" not in COCO17_AUG_CUDA: + COCO17_AUG_CUDA["bias"] = COCO17_AUG["bias"].cuda().reshape(1, 17, 1) + + bias = COCO17_AUG_CUDA["bias"] + bias_noise = torch.randn((shape[0], 17, 3), device="cuda") * bias * s_bias + bias_noise = bias_noise[:, None].expand(-1, shape[1], -1, -1) + return bias_noise + + +def get_wham_aug_kp3d(shape=(8, 120)): + # aug = get_bias(shape).cuda() + get_lfhp(shape).cuda() + get_jitter(shape).cuda() + aug = get_bias_cuda(shape) + get_lfhp_cuda(shape) + get_jitter_cuda(shape) + return aug + + +def get_visible_mask(shape=(8, 120), s_mask=0.03): + """Mask modeling.""" + # Per-frame and joint + mask = torch.rand(*shape, 17) < s_mask + visible = (~mask).clone() # (B, L, 17) + + visible = visible.reshape(-1, 17) # (BL, 17) + for child in range(17): + parent = COCO17_TREE[child] + if parent == -1: + continue + if isinstance(parent, list): + visible[:, child] *= visible[:, parent[0]] * visible[:, parent[1]] + else: + visible[:, child] *= visible[:, parent] + visible = visible.reshape(*shape, 17).clone() # (B, L, J) + return visible + + +def get_invisible_legs_mask(shape, s_mask=0.03): + """ + Both legs are invisible for a random duration. + """ + B, L = shape + starts = torch.randint(0, L - 90, (B,)) + ends = starts + torch.randint(30, 90, (B,)) + mask_range = torch.arange(L).unsqueeze(0).expand(B, -1) + mask_to_apply = (mask_range >= starts.unsqueeze(1)) & (mask_range < ends.unsqueeze(1)) + mask_to_apply = mask_to_apply.unsqueeze(2).expand(-1, -1, 17).clone() + mask_to_apply[:, :, :11] = False # only both legs are invisible + mask_to_apply = mask_to_apply & (torch.rand(B, 1, 1) < s_mask) + return mask_to_apply + + +def randomly_occlude_lower_half(i_x2d, s_mask=0.03): + """ + Randomly occlude the lower half of the image. 
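+ (Unfinished: the body raises NotImplementedError; the lines after it only sketch
+ the intended masking logic.)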
+ """ + raise NotImplementedError + B, L, N, _ = i_x2d.shape + i_x2d = i_x2d.clone() + + # a period of time when the lower half of the image is invisible + starts = torch.randint(0, L - 90, (B,)) + ends = starts + torch.randint(30, 90, (B,)) + mask_range = torch.arange(L).unsqueeze(0).expand(B, -1) + mask_to_apply = (mask_range >= starts.unsqueeze(1)) & (mask_range < ends.unsqueeze(1)) + mask_to_apply = mask_to_apply.unsqueeze(2).expand(-1, -1, N) # (B, L, N) + + # only the lower half of the image is invisible + i_x2d + i_x2d[..., 1] / 2 + + mask_to_apply = mask_to_apply & (torch.rand(B, 1, 1) < s_mask) + return mask_to_apply + + +def randomly_modify_hands_legs(j3d): + hands = [9, 10] + legs = [15, 16] + + B, L, J, _ = j3d.shape + p_switch_hand = 0.001 + p_switch_leg = 0.001 + p_wrong_hand0 = 0.001 + p_wrong_hand1 = 0.001 + p_wrong_leg0 = 0.001 + p_wrong_leg1 = 0.001 + + mask = torch.rand(B, L) < p_switch_hand + j3d[mask][:, hands] = j3d[mask][:, hands[::-1]] + mask = torch.rand(B, L) < p_switch_leg + j3d[mask][:, legs] = j3d[mask][:, legs[::-1]] + mask = torch.rand(B, L) < p_wrong_hand0 + j3d[mask][:, 9] = j3d[mask][:, 10] + mask = torch.rand(B, L) < p_wrong_hand1 + j3d[mask][:, 10] = j3d[mask][:, 9] + mask = torch.rand(B, L) < p_wrong_leg0 + j3d[mask][:, 15] = j3d[mask][:, 16] + mask = torch.rand(B, L) < p_wrong_leg1 + j3d[mask][:, 16] = j3d[mask][:, 15] + + return j3d diff --git a/hmr4d/utils/geo/flip_utils.py b/hmr4d/utils/geo/flip_utils.py new file mode 100644 index 0000000..81fc815 --- /dev/null +++ b/hmr4d/utils/geo/flip_utils.py @@ -0,0 +1,86 @@ +import torch +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle + + +def flip_heatmap_coco17(output_flipped): + assert output_flipped.ndim == 4, "output_flipped should be [B, J, H, W]" + shape_ori = output_flipped.shape + channels = 1 + output_flipped = output_flipped.reshape(shape_ori[0], -1, channels, shape_ori[2], shape_ori[3]) + output_flipped_back = output_flipped.clone() + + # Swap left-right parts + for left, right in [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]: + output_flipped_back[:, left, ...] = output_flipped[:, right, ...] + output_flipped_back[:, right, ...] = output_flipped[:, left, ...] + output_flipped_back = output_flipped_back.reshape(shape_ori) + # Flip horizontally + output_flipped_back = output_flipped_back.flip(3) + return output_flipped_back + + +def flip_bbx_xys(bbx_xys, w): + """ + bbx_xys: (F, 3) + """ + bbx_xys_flip = bbx_xys.clone() + bbx_xys_flip[:, 0] = w - bbx_xys_flip[:, 0] + return bbx_xys_flip + + +def flip_kp2d_coco17(kp2d, w): + """Flip keypoints.""" + kp2d = kp2d.clone() + flipped_parts = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + kp2d = kp2d[..., flipped_parts, :] + kp2d[..., 0] = w - kp2d[..., 0] + return kp2d + + +def flip_smplx_params(smplx_params): + """Flip pose. + The flipping is based on SMPLX parameters. 
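+ Only global_orient and body_pose (root + 21 body joints) are mirrored; any other
+ entries in the input dict are copied through unchanged.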
+ """ + rotation = torch.cat([smplx_params["global_orient"], smplx_params["body_pose"]], dim=1) + + BN = rotation.shape[0] + pose = rotation.reshape(BN, -1).transpose(0, 1) + + SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20] # , 23, 22] + SMPL_POSE_FLIP_PERM = [] + for i in SMPL_JOINTS_FLIP_PERM: + SMPL_POSE_FLIP_PERM.append(3 * i) + SMPL_POSE_FLIP_PERM.append(3 * i + 1) + SMPL_POSE_FLIP_PERM.append(3 * i + 2) + + pose = pose[SMPL_POSE_FLIP_PERM] + + # we also negate the second and the third dimension of the axis-angle + pose[1::3] = -pose[1::3] + pose[2::3] = -pose[2::3] + pose = pose.transpose(0, 1).reshape(BN, -1, 3) + + smplx_params_flipped = smplx_params.copy() + smplx_params_flipped["global_orient"] = pose[:, :1] + smplx_params_flipped["body_pose"] = pose[:, 1:] + return smplx_params_flipped + + +def avg_smplx_aa(aa1, aa2): + def avg_rot(rot): + # input [B,...,3,3] --> output [...,3,3] + rot = rot.mean(dim=0) + U, _, V = torch.svd(rot) + rot = U @ V.transpose(-1, -2) + return rot + + B, J3 = aa1.shape + aa1 = aa1.reshape(B, -1, 3) + aa2 = aa2.reshape(B, -1, 3) + + R1 = axis_angle_to_matrix(aa1) + R2 = axis_angle_to_matrix(aa2) + R_avg = avg_rot(torch.stack([R1, R2])) + aa_avg = matrix_to_axis_angle(R_avg).reshape(B, -1) + + return aa_avg diff --git a/hmr4d/utils/geo/hmr_cam.py b/hmr4d/utils/geo/hmr_cam.py new file mode 100644 index 0000000..48a7365 --- /dev/null +++ b/hmr4d/utils/geo/hmr_cam.py @@ -0,0 +1,398 @@ +import torch +import numpy as np +from hmr4d.utils.geo_transform import project_p2d, convert_bbx_xys_to_lurb, cvt_to_bi01_p2d + + +def estimate_focal_length(img_w, img_h): + return (img_w**2 + img_h**2) ** 0.5 # Diagonal FOV = 2*arctan(0.5) * 180/pi = 53 + + +def estimate_K(img_w, img_h): + focal_length = estimate_focal_length(img_w, img_h) + K = torch.eye(3).float() + K[0, 0] = focal_length + K[1, 1] = focal_length + K[0, 2] = img_w / 2.0 + K[1, 2] = img_h / 2.0 + return K + + +def convert_K_to_K4(K): + K4 = torch.stack([K[0, 0], K[1, 1], K[0, 2], K[1, 2]]).float() + return K4 + + +def convert_f_to_K(focal_length, img_w, img_h): + K = torch.eye(3).float() + K[0, 0] = focal_length + K[1, 1] = focal_length + K[0, 2] = img_w / 2.0 + K[1, 2] = img_h / 2.0 + return K + + +def resize_K(K, f=0.5): + K = K.clone() * f + K[..., 2, 2] = 1.0 + return K + + +def create_camera_sensor(width=None, height=None, f_fullframe=None): + if width is None or height is None: + # The 4:3 aspect ratio is widely adopted by image sensors in mobile phones. + if np.random.rand() < 0.5: + width, height = 1200, 1600 + else: + width, height = 1600, 1200 + + # Sample FOV from common options: + # 1. wide-angle lenses are common in mobile phones, + # 2. telephoto lenses has less perspective effect, which should makes it easy to learn + if f_fullframe is None: + f_fullframe_options = [24, 26, 28, 30, 35, 40, 50, 60, 70] + f_fullframe = np.random.choice(f_fullframe_options) + + # We use diag to map focal-length: https://www.nikonians.org/reviews/fov-tables + diag_fullframe = (24**2 + 36**2) ** 0.5 + diag_img = (width**2 + height**2) ** 0.5 + focal_length = diag_img / diag_fullframe * f_fullframe + + K_fullimg = torch.eye(3) + K_fullimg[0, 0] = focal_length + K_fullimg[1, 1] = focal_length + K_fullimg[0, 2] = width / 2 + K_fullimg[1, 2] = height / 2 + + return width, height, K_fullimg + + +# ====== Compute cliffcam ===== # + + +def convert_xys_to_cliff_cam_wham(xys, res): + """ + Args: + xys: (N, 3) in pixel. 
Note s should not be touched by 200 + res: (2), e.g. [4112., 3008.] (w,h) + Returns: + cliff_cam: (N, 3), normalized representation + """ + + def normalize_keypoints_to_image(x, res): + """ + Args: + x: (N, 2), centers + res: (2), e.g. [4112., 3008.] + Returns: + x_normalized: (N, 2) + """ + res = res.to(x.device) + scale = res.max(-1)[0].reshape(-1) + mean = torch.stack([res[..., 0] / scale, res[..., 1] / scale], dim=-1).to(x.device) + x = 2 * x / scale.reshape(*[1 for i in range(len(x.shape[1:]))]) - mean.reshape( + *[1 for i in range(len(x.shape[1:-1]))], -1 + ) + return x + + centers = normalize_keypoints_to_image(xys[:, :2], res) # (N, 2) + scale = xys[:, 2:] / res.max() + location = torch.cat((centers, scale), dim=-1) + return location + + +def compute_bbox_info_bedlam(bbx_xys, K_fullimg): + """impl as in BEDLAM + Args: + bbx_xys: ((B), N, 3), in pixel space described by K_fullimg + K_fullimg: ((B), (N), 3, 3) + Returns: + bbox_info: ((B), N, 3) + """ + fl = K_fullimg[..., 0, 0].unsqueeze(-1) + icx = K_fullimg[..., 0, 2] + icy = K_fullimg[..., 1, 2] + + cx, cy, b = bbx_xys[..., 0], bbx_xys[..., 1], bbx_xys[..., 2] + bbox_info = torch.stack([cx - icx, cy - icy, b], dim=-1) + bbox_info = bbox_info / fl + return bbox_info + + +# ====== Convert Prediction to Cam-t ===== # + + +def compute_transl_full_cam(pred_cam, bbx_xys, K_fullimg): + s, tx, ty = pred_cam[..., 0], pred_cam[..., 1], pred_cam[..., 2] + focal_length = K_fullimg[..., 0, 0] + + icx = K_fullimg[..., 0, 2] + icy = K_fullimg[..., 1, 2] + sb = s * bbx_xys[..., 2] + cx = 2 * (bbx_xys[..., 0] - icx) / (sb + 1e-9) + cy = 2 * (bbx_xys[..., 1] - icy) / (sb + 1e-9) + tz = 2 * focal_length / (sb + 1e-9) + + cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1) + return cam_t + + +def get_a_pred_cam(transl, bbx_xys, K_fullimg): + """Inverse operation of compute_transl_full_cam""" + assert transl.ndim == bbx_xys.ndim # (*, L, 3) + assert K_fullimg.ndim == (bbx_xys.ndim + 1) # (*, L, 3, 3) + f = K_fullimg[..., 0, 0] + cx = K_fullimg[..., 0, 2] + cy = K_fullimg[..., 1, 2] + gt_s = 2 * f / (transl[..., 2] * bbx_xys[..., 2]) # (B, L) + gt_x = transl[..., 0] - transl[..., 2] / f * (bbx_xys[..., 0] - cx) + gt_y = transl[..., 1] - transl[..., 2] / f * (bbx_xys[..., 1] - cy) + gt_pred_cam = torch.stack([gt_s, gt_x, gt_y], dim=-1) + return gt_pred_cam + + +# ====== 3D to 2D ===== # + + +def project_to_bi01(points, bbx_xys, K_fullimg): + """ + points: (B, L, J, 3) + bbx_xys: (B, L, 3) + K_fullimg: (B, L, 3, 3) + """ + # p2d = project_p2d(points, K_fullimg) + p2d = perspective_projection(points, K_fullimg) + bbx_lurb = convert_bbx_xys_to_lurb(bbx_xys) + p2d_bi01 = cvt_to_bi01_p2d(p2d, bbx_lurb) + return p2d_bi01 + + +def perspective_projection(points, K): + # points: (B, L, J, 3) + # K: (B, L, 3, 3) + projected_points = points / points[..., -1].unsqueeze(-1) + projected_points = torch.einsum("...ij,...kj->...ki", K, projected_points.float()) + return projected_points[..., :-1] + + +# ====== 2D (bbx from j2d) ===== # + + +def normalize_kp2d(obs_kp2d, bbx_xys, clamp_scale_min=False): + """ + Args: + obs_kp2d: (B, L, J, 3) [x, y, c] + bbx_xys: (B, L, 3) + Returns: + obs: (B, L, J, 3) [x, y, c] + """ + obs_xy = obs_kp2d[..., :2] # (B, L, J, 2) + obs_conf = obs_kp2d[..., 2] # (B, L, J) + center = bbx_xys[..., :2] + scale = bbx_xys[..., [2]] + + # Mark keypoints outside the bounding box as invisible + xy_max = center + scale / 2 + xy_min = center - scale / 2 + invisible_mask = ( + (obs_xy[..., 0] < xy_min[..., None, 0]) + + (obs_xy[..., 0] > xy_max[..., 
None, 0]) + + (obs_xy[..., 1] < xy_min[..., None, 1]) + + (obs_xy[..., 1] > xy_max[..., None, 1]) + ) + obs_conf = obs_conf * ~invisible_mask + if clamp_scale_min: + scale = scale.clamp(min=1e-5) + normalized_obs_xy = 2 * (obs_xy - center.unsqueeze(-2)) / scale.unsqueeze(-2) + + return torch.cat([normalized_obs_xy, obs_conf[..., None]], dim=-1) + + +def get_bbx_xys(i_j2d, bbx_ratio=[192, 256], do_augment=False, base_enlarge=1.2): + """Args: (B, L, J, 3) [x,y,c] -> Returns: (B, L, 3)""" + # Center + min_x = i_j2d[..., 0].min(-1)[0] + max_x = i_j2d[..., 0].max(-1)[0] + min_y = i_j2d[..., 1].min(-1)[0] + max_y = i_j2d[..., 1].max(-1)[0] + center_x = (min_x + max_x) / 2 + center_y = (min_y + max_y) / 2 + + # Size + h = max_y - min_y # (B, L) + w = max_x - min_x # (B, L) + + if True: # fit w and h into aspect-ratio + aspect_ratio = bbx_ratio[0] / bbx_ratio[1] + mask1 = w > aspect_ratio * h + h[mask1] = w[mask1] / aspect_ratio + mask2 = w < aspect_ratio * h + w[mask2] = h[mask2] * aspect_ratio + + # apply a common factor to enlarge the bounding box + bbx_size = torch.max(h, w) * base_enlarge + + if do_augment: + B, L = bbx_size.shape[:2] + device = bbx_size.device + if True: + scaleFactor = torch.rand((B, L), device=device) * 0.3 + 1.05 # 1.05~1.35 + txFactor = torch.rand((B, L), device=device) * 1.6 - 0.8 # -0.8~0.8 + tyFactor = torch.rand((B, L), device=device) * 1.6 - 0.8 # -0.8~0.8 + else: + scaleFactor = torch.rand((B, 1), device=device) * 0.3 + 1.05 # 1.05~1.35 + txFactor = torch.rand((B, 1), device=device) * 1.6 - 0.8 # -0.8~0.8 + tyFactor = torch.rand((B, 1), device=device) * 1.6 - 0.8 # -0.8~0.8 + + raw_bbx_size = bbx_size / base_enlarge + bbx_size = raw_bbx_size * scaleFactor + center_x += raw_bbx_size / 2 * ((scaleFactor - 1) * txFactor) + center_y += raw_bbx_size / 2 * ((scaleFactor - 1) * tyFactor) + + return torch.stack([center_x, center_y, bbx_size], dim=-1) + + +def safely_render_x3d_K(x3d, K_fullimg, thr): + """ + Args: + x3d: (B, L, V, 3), should as least have a safe points (not examined here) + K_fullimg: (B, L, 3, 3) + Returns: + bbx_xys: (B, L, 3) + i_x2d: (B, L, V, 2) + """ + # For each frame, update unsafe z ( 0: + x3d[..., 2][x3d_unsafe_mask] = thr + if False: + from hmr4d.utils.wis3d_utils import make_wis3d + + wis3d = make_wis3d(name="debug-update-z") + bs, ls, vs = torch.where(x3d_unsafe_mask) + bs = torch.unique(bs) + for b in bs: + for f in range(x3d.size(1)): + wis3d.set_scene_id(f) + wis3d.add_point_cloud(x3d[b, f], name="unsafe") + pass + + # renfer + i_x2d = perspective_projection(x3d, K_fullimg) # (B, L, V, 2) + return i_x2d + + +def get_bbx_xys_from_xyxy(bbx_xyxy, base_enlarge=1.2): + """ + Args: + bbx_xyxy: (N, 4) [x1, y1, x2, y2] + Returns: + bbx_xys: (N, 3) [center_x, center_y, size] + """ + + i_p2d = torch.stack([bbx_xyxy[:, [0, 1]], bbx_xyxy[:, [2, 3]]], dim=1) # (L, 2, 2) + bbx_xys = get_bbx_xys(i_p2d[None], base_enlarge=base_enlarge)[0] + return bbx_xys + + +def bbx_xyxy_from_x(p2d): + """ + Args: + p2d: (*, V, 2) - Tensor containing 2D points. + + Returns: + bbx_xyxy: (*, 4) - Bounding box coordinates in the format (xmin, ymin, xmax, ymax). + """ + # Compute the minimum and maximum coordinates for the bounding box + xy_min = p2d.min(dim=-2).values # (*, 2) + xy_max = p2d.max(dim=-2).values # (*, 2) + + # Concatenate min and max coordinates to form the bounding box + bbx_xyxy = torch.cat([xy_min, xy_max], dim=-1) # (*, 4) + + return bbx_xyxy + + +def bbx_xyxy_from_masked_x(p2d, mask): + """ + Args: + p2d: (*, V, 2) - Tensor containing 2D points. 
+ mask: (*, V) - Boolean tensor indicating valid points. + + Returns: + bbx_xyxy: (*, 4) - Bounding box coordinates in the format (xmin, ymin, xmax, ymax). + """ + # Ensure the shapes of p2d and mask are compatible + assert p2d.shape[:-1] == mask.shape, "The shape of p2d and mask are not compatible." + + # Flatten the input tensors for batch processing + p2d_flat = p2d.view(-1, p2d.shape[-2], p2d.shape[-1]) + mask_flat = mask.view(-1, mask.shape[-1]) + + # Set masked out values to a large positive and negative value respectively + p2d_min = torch.where(mask_flat.unsqueeze(-1), p2d_flat, torch.tensor(float("inf")).to(p2d_flat)) + p2d_max = torch.where(mask_flat.unsqueeze(-1), p2d_flat, torch.tensor(float("-inf")).to(p2d_flat)) + + # Compute the minimum and maximum coordinates for the bounding box + xy_min = p2d_min.min(dim=1).values # (BL, 2) + xy_max = p2d_max.max(dim=1).values # (BL, 2) + + # Concatenate min and max coordinates to form the bounding box + bbx_xyxy = torch.cat([xy_min, xy_max], dim=-1) # (BL, 4) + + # Reshape back to the original shape prefix + bbx_xyxy = bbx_xyxy.view(*p2d.shape[:-2], 4) + + return bbx_xyxy + + +def bbx_xyxy_ratio(xyxy1, xyxy2): + """Designed for fov/unbounded + Args: + xyxy1: (*, 4) + xyxy2: (*, 4) + Return: + ratio: (*), squared_area(xyxy1) / squared_area(xyxy2) + """ + area1 = (xyxy1[..., 2] - xyxy1[..., 0]) * (xyxy1[..., 3] - xyxy1[..., 1]) + area2 = (xyxy2[..., 2] - xyxy2[..., 0]) * (xyxy2[..., 3] - xyxy2[..., 1]) + # Check + area1[~torch.isfinite(area1)] = 0 # replace inf in area1 with 0 + assert (area2 > 0).all(), "area2 should be positive" + return area1 / area2 + + +def get_mesh_in_fov_category(mask): + """mask: (L, V) + The definition: + 1. FullyVisible: The mesh in every frame is entirely within the field of view (FOV). + 2. PartiallyVisible: In some frames, parts of the mesh are outside the FOV, while other parts are within the FOV. + 3. PartiallyOut: In some frames, the mesh is completely outside the FOV, while in others, it is visible. + 4. FullyOut: The mesh is completely outside the FOV in every frame. 
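+ Returns (class_type, mask_frame_any_verts): class_type is an int in {1, 2, 3, 4}
+ following the numbering above, and mask_frame_any_verts[i] is True when at least
+ one vertex is inside the FOV at frame i.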
+ """ + mask = mask.clone().cpu() + is_class1 = mask.all() # FullyVisible + is_class2 = mask.any(1).all() * ~is_class1 # PartiallyVisible + is_class4 = ~(mask.any()) # PartiallyOut + is_class3 = ~is_class1 * ~is_class2 * ~is_class4 # FullyOut + + mask_frame_any_verts = mask.any(1) + assert is_class1.int() + is_class2.int() + is_class3.int() + is_class4.int() == 1 + class_type = is_class1.int() + 2 * is_class2.int() + 3 * is_class3.int() + 4 * is_class4.int() + return class_type.item(), mask_frame_any_verts + + +def get_infov_mask(p2d, w_real, h_real): + """ + Args: + p2d: (B, L, V, 2) + w_real, h_real: (B, L) or int + Returns: + mask: (B, L, V) + """ + x, y = p2d[..., 0], p2d[..., 1] + if isinstance(w_real, int): + mask = (x >= 0) * (x < w_real) * (y >= 0) * (y < h_real) + else: + mask = (x >= 0) * (x < w_real[..., None]) * (y >= 0) * (y < h_real[..., None]) + return mask diff --git a/hmr4d/utils/geo/hmr_global.py b/hmr4d/utils/geo/hmr_global.py new file mode 100644 index 0000000..5483370 --- /dev/null +++ b/hmr4d/utils/geo/hmr_global.py @@ -0,0 +1,345 @@ +import torch +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle, matrix_to_quaternion, quaternion_to_matrix +import hmr4d.utils.matrix as matrix +from hmr4d.utils.net_utils import gaussian_smooth + + +def get_R_c2gv(R_w2c, axis_gravity_in_w=[0, 0, -1]): + """ + Args: + R_w2c: (*, 3, 3) + Returns: + R_c2gv: (*, 3, 3) + """ + if isinstance(axis_gravity_in_w, list): + axis_gravity_in_w = torch.tensor(axis_gravity_in_w).float() # gravity direction in world coord + axis_z_in_c = torch.tensor([0, 0, 1]).float() + + # get gv-coord axes in in c-coord + axis_y_of_gv = R_w2c @ axis_gravity_in_w # (*, 3) + axis_x_of_gv = axis_y_of_gv.cross(axis_z_in_c.expand_as(axis_y_of_gv), dim=-1) + # normalize + axis_x_of_gv_norm = axis_x_of_gv.norm(dim=-1, keepdim=True) + axis_x_of_gv = axis_x_of_gv / (axis_x_of_gv_norm + 1e-5) + axis_x_of_gv[axis_x_of_gv_norm.squeeze(-1) < 1e-5] = torch.tensor([1.0, 0.0, 0.0]) # use cam x-axis as axis_x_of_gv + axis_z_of_gv = axis_x_of_gv.cross(axis_y_of_gv, dim=-1) + + R_gv2c = torch.stack([axis_x_of_gv, axis_y_of_gv, axis_z_of_gv], dim=-1) # (*, 3, 3) + R_c2gv = R_gv2c.transpose(-1, -2) # (*, 3, 3) + return R_c2gv + + +tsf_axisangle = { + "ay->ay": [0, 0, 0], + "any->ay": [0, 0, torch.pi], + "az->ay": [-torch.pi / 2, 0, 0], + "ay->any": [0, 0, torch.pi], +} + + +def get_tgtcoord_rootparam(global_orient, transl, gravity_vec=None, tgt_gravity_vec=None, tsf="ay->ay"): + """Rotate around the origin center, to match the new gravity direction + Args: + global_orient: torch.tensor, (*, 3) + transl: torch.tensor, (*, 3) + gravity_vec: torch.tensor, (3,) + tgt_gravity_vec: torch.tensor, (3,) + Returns: + tgt_global_orient: torch.tensor, (*, 3) + tgt_transl: torch.tensor, (*, 3) + R_g2tg: (3, 3) + """ + # get rotation matrix + device = global_orient.device + if gravity_vec is None and tgt_gravity_vec is None: + aa = torch.tensor(tsf_axisangle[tsf]).to(device) + R_g2tg = axis_angle_to_matrix(aa) # (3, 3) + else: + raise NotImplementedError + # TODO: Impl this function + gravity_vec = torch.tensor(gravity_vec).float().to(device) + gravity_vec = gravity_vec / gravity_vec.norm() + tgt_gravity_vec = torch.tensor(tgt_gravity_vec).float().to(device) + tgt_gravity_vec = tgt_gravity_vec / tgt_gravity_vec.norm() + # pick one identity axis + axis_identity = torch.tensor([0, 0, 0]).float().to(device) + for i in (gravity_vec == 0) & (tgt_global_orient == 0): + if i: + axis_identity[i] = 1 + break + + # rotate + 
global_orient_R = axis_angle_to_matrix(global_orient) # (*, 3, 3) + tgt_global_orient = matrix_to_axis_angle(R_g2tg @ global_orient_R) # (*, 3, 3) + tgt_transl = torch.einsum("...ij,...j->...i", R_g2tg, transl) + + return tgt_global_orient, tgt_transl, R_g2tg + + +def get_c_rootparam(global_orient, transl, T_w2c, offset): + """ + Args: + global_orient: torch.tensor, (F, 3) + transl: torch.tensor, (F, 3) + T_w2c: torch.tensor, (*, 4, 4) + offset: torch.tensor, (3,) + Returns: + R_c: torch.tensor, (F, 3) + t_c: torch.tensor, (F, 3) + """ + assert global_orient.shape == transl.shape and len(global_orient.shape) == 2 + R_w = axis_angle_to_matrix(global_orient) # (F, 3, 3) + t_w = transl # (F, 3) + + R_w2c = T_w2c[..., :3, :3] # (*, 3, 3) + t_w2c = T_w2c[..., :3, 3] # (*, 3) + if len(R_w2c.shape) == 2: + R_w2c = R_w2c[None].expand(R_w.size(0), -1, -1) # (F, 3, 3) + t_w2c = t_w2c[None].expand(t_w.size(0), -1) + + R_c = matrix_to_axis_angle(R_w2c @ R_w) # (F, 3) + t_c = torch.einsum("fij,fj->fi", R_w2c, t_w + offset) + t_w2c - offset # (F, 3) + return R_c, t_c + + +def get_T_w2c_from_wcparams(global_orient_w, transl_w, global_orient_c, transl_c, offset): + """ + Args: + global_orient_w: torch.tensor, (F, 3) + transl_w: torch.tensor, (F, 3) + global_orient_c: torch.tensor, (F, 3) + transl_c: torch.tensor, (F, 3) + offset: torch.tensor, (*, 3) + Returns: + T_w2c: torch.tensor, (F, 4, 4) + """ + assert global_orient_w.shape == transl_w.shape and len(global_orient_w.shape) == 2 + assert global_orient_c.shape == transl_c.shape and len(global_orient_c.shape) == 2 + + R_w = axis_angle_to_matrix(global_orient_w) # (F, 3, 3) + t_w = transl_w # (F, 3) + R_c = axis_angle_to_matrix(global_orient_c) # (F, 3, 3) + t_c = transl_c # (F, 3) + + R_w2c = R_c @ R_w.transpose(-1, -2) # (F, 3, 3) + t_w2c = t_c + offset - torch.einsum("fij,fj->fi", R_w2c, t_w + offset) # (F, 3) + T_w2c = torch.eye(4, device=global_orient_w.device).repeat(R_w.size(0), 1, 1) # (F, 4, 4) + T_w2c[..., :3, :3] = R_w2c # (F, 3, 3) + T_w2c[..., :3, 3] = t_w2c # (F, 3) + return T_w2c + + +def get_local_transl_vel(transl, global_orient): + """ + transl velocity is in local coordinate (or, SMPL-coord) + Args: + transl: (*, L, 3) + global_orient: (*, L, 3) + Returns: + transl_vel: (*, L, 3) + """ + assert len(transl.shape) == len(global_orient.shape) + global_orient_R = axis_angle_to_matrix(global_orient) # (B, L, 3, 3) + transl_vel = transl[..., 1:, :] - transl[..., :-1, :] # (B, L-1, 3) + transl_vel = torch.cat([transl_vel, transl_vel[..., [-1], :]], dim=-2) # (B, L, 3) last-padding + + # v_local = R^T @ v_global + local_transl_vel = torch.einsum("...lij,...li->...lj", global_orient_R, transl_vel) + return local_transl_vel + + +def rollout_local_transl_vel(local_transl_vel, global_orient, transl_0=None): + """ + transl velocity is in local coordinate (or, SMPL-coord) + Args: + local_transl_vel: (*, L, 3) + global_orient: (*, L, 3) + transl_0: (*, 1, 3), if not provided, the start point is 0 + Returns: + transl: (*, L, 3) + """ + global_orient_R = axis_angle_to_matrix(global_orient) + transl_vel = torch.einsum("...lij,...lj->...li", global_orient_R, local_transl_vel) + + # set start point + if transl_0 is None: + transl_0 = transl_vel[..., :1, :].clone().detach().zero_() + transl_ = torch.cat([transl_0, transl_vel[..., :-1, :]], dim=-2) + + # rollout from start point + transl = torch.cumsum(transl_, dim=-2) + return transl + + +def get_local_transl_vel_alignhead(transl, global_orient): + # assume global_orient is ay + global_orient_rot = 
axis_angle_to_matrix(global_orient) # (*, 3, 3) + global_orient_quat = matrix_to_quaternion(global_orient_rot) # (*, 4) + + global_orient_quat_xyzw = matrix.quat_wxyz2xyzw(global_orient_quat) # (*, 4) + head_quat_xyzw = matrix.calc_heading_quat(global_orient_quat_xyzw, head_ind=2, gravity_axis="y") # (*, 4) + head_quat = matrix.quat_xyzw2wxyz(head_quat_xyzw) # (*, 4) + head_rot = quaternion_to_matrix(head_quat) + head_aa = matrix_to_axis_angle(head_rot) + + local_transl_vel_alignhead = get_local_transl_vel(transl, head_aa) + return local_transl_vel_alignhead + + +def rollout_local_transl_vel_alignhead(local_transl_vel_alignhead, global_orient, transl_0=None): + # assume global_orient is ay + global_orient_rot = axis_angle_to_matrix(global_orient) # (*, 3, 3) + global_orient_quat = matrix_to_quaternion(global_orient_rot) # (*, 4) + + global_orient_quat_xyzw = matrix.quat_wxyz2xyzw(global_orient_quat) # (*, 4) + head_quat_xyzw = matrix.calc_heading_quat(global_orient_quat_xyzw, head_ind=2, gravity_axis="y") # (*, 4) + head_quat = matrix.quat_xyzw2wxyz(head_quat_xyzw) # (*, 4) + head_rot = quaternion_to_matrix(head_quat) + head_aa = matrix_to_axis_angle(head_rot) + + transl = rollout_local_transl_vel(local_transl_vel_alignhead, head_aa, transl_0) + return transl + + +def get_local_transl_vel_alignhead_absy(transl, global_orient): + # assume global_orient is ay + global_orient_rot = axis_angle_to_matrix(global_orient) # (*, 3, 3) + global_orient_quat = matrix_to_quaternion(global_orient_rot) # (*, 4) + + global_orient_quat_xyzw = matrix.quat_wxyz2xyzw(global_orient_quat) # (*, 4) + head_quat_xyzw = matrix.calc_heading_quat(global_orient_quat_xyzw, head_ind=2, gravity_axis="y") # (*, 4) + head_quat = matrix.quat_xyzw2wxyz(head_quat_xyzw) # (*, 4) + head_rot = quaternion_to_matrix(head_quat) + head_aa = matrix_to_axis_angle(head_rot) + + local_transl_vel_alignhead = get_local_transl_vel(transl, head_aa) + abs_y = torch.cumsum(local_transl_vel_alignhead[..., [1]], dim=-2) # (*, L, 1) + local_transl_vel_alignhead_absy = torch.cat( + [local_transl_vel_alignhead[..., [0]], abs_y, local_transl_vel_alignhead[..., [2]]], dim=-1 + ) + + return local_transl_vel_alignhead_absy + + +def rollout_local_transl_vel_alignhead_absy(local_transl_vel_alignhead_absy, global_orient, transl_0=None): + # assume global_orient is ay + global_orient_rot = axis_angle_to_matrix(global_orient) # (*, 3, 3) + global_orient_quat = matrix_to_quaternion(global_orient_rot) # (*, 4) + + global_orient_quat_xyzw = matrix.quat_wxyz2xyzw(global_orient_quat) # (*, 4) + head_quat_xyzw = matrix.calc_heading_quat(global_orient_quat_xyzw, head_ind=2, gravity_axis="y") # (*, 4) + head_quat = matrix.quat_xyzw2wxyz(head_quat_xyzw) # (*, 4) + head_rot = quaternion_to_matrix(head_quat) + head_aa = matrix_to_axis_angle(head_rot) + + local_transl_vel_alignhead_y = ( + local_transl_vel_alignhead_absy[..., 1:, [1]] - local_transl_vel_alignhead_absy[..., :-1, [1]] + ) + local_transl_vel_alignhead_y = torch.cat( + [local_transl_vel_alignhead_absy[..., :1, [1]], local_transl_vel_alignhead_y], dim=-2 + ) + local_transl_vel_alignhead = torch.cat( + [ + local_transl_vel_alignhead_absy[..., [0]], + local_transl_vel_alignhead_y, + local_transl_vel_alignhead_absy[..., [2]], + ], + dim=-1, + ) + + transl = rollout_local_transl_vel(local_transl_vel_alignhead, head_aa, transl_0) + return transl + + +def get_local_transl_vel_alignhead_absgy(transl, global_orient): + # assume global_orient is ay + global_orient_rot = axis_angle_to_matrix(global_orient) # (*, 3, 
3) + global_orient_quat = matrix_to_quaternion(global_orient_rot) # (*, 4) + + global_orient_quat_xyzw = matrix.quat_wxyz2xyzw(global_orient_quat) # (*, 4) + head_quat_xyzw = matrix.calc_heading_quat(global_orient_quat_xyzw, head_ind=2, gravity_axis="y") # (*, 4) + head_quat = matrix.quat_xyzw2wxyz(head_quat_xyzw) # (*, 4) + head_rot = quaternion_to_matrix(head_quat) + head_aa = matrix_to_axis_angle(head_rot) + + local_transl_vel_alignhead = get_local_transl_vel(transl, head_aa) + abs_y = transl[..., [1]] # (*, L, 1) + local_transl_vel_alignhead_absy = torch.cat( + [local_transl_vel_alignhead[..., [0]], abs_y, local_transl_vel_alignhead[..., [2]]], dim=-1 + ) + + return local_transl_vel_alignhead_absy + + +def rollout_local_transl_vel_alignhead_absgy(local_transl_vel_alignhead_absgy, global_orient, transl_0=None): + # assume global_orient is ay + global_orient_rot = axis_angle_to_matrix(global_orient) # (*, 3, 3) + global_orient_quat = matrix_to_quaternion(global_orient_rot) # (*, 4) + + global_orient_quat_xyzw = matrix.quat_wxyz2xyzw(global_orient_quat) # (*, 4) + head_quat_xyzw = matrix.calc_heading_quat(global_orient_quat_xyzw, head_ind=2, gravity_axis="y") # (*, 4) + head_quat = matrix.quat_xyzw2wxyz(head_quat_xyzw) # (*, 4) + head_rot = quaternion_to_matrix(head_quat) + head_aa = matrix_to_axis_angle(head_rot) + + local_transl_vel_alignhead_y = ( + local_transl_vel_alignhead_absgy[..., 1:, [1]] - local_transl_vel_alignhead_absgy[..., :-1, [1]] + ) + local_transl_vel_alignhead_y = torch.cat( + [local_transl_vel_alignhead_y, local_transl_vel_alignhead_y[..., -1:, :]], dim=-2 + ) + if transl_0 is not None: + transl_0 = transl_0.clone() + transl_0[..., 1] = local_transl_vel_alignhead_absgy[..., :1, 1] + else: + transl_0 = local_transl_vel_alignhead_absgy.clone()[..., :1, :] # (*, 1, 3) + transl_0[..., :1, 0] = 0.0 + transl_0[..., :1, 2] = 0.0 + + local_transl_vel_alignhead = torch.cat( + [ + local_transl_vel_alignhead_absgy[..., [0]], + local_transl_vel_alignhead_y, + local_transl_vel_alignhead_absgy[..., [2]], + ], + dim=-1, + ) + + transl = rollout_local_transl_vel(local_transl_vel_alignhead, head_aa, transl_0) + return transl + + +def rollout_vel(vel, transl_0=None): + """ + Args: + vel: (*, L, 3) + transl_0: (*, 1, 3), if not provided, the start point is 0 + Returns: + transl: (*, L, 3) + """ + # set start point + if transl_0 is None: + assert len(vel.shape) == len(transl_0.shape) + transl_0 = vel[..., :1, :].clone().detach().zero_() + transl_ = torch.cat([transl_0, vel[..., :-1, :]], dim=-2) + + # rollout from start point + transl = torch.cumsum(transl_, dim=-2) + return transl + + +def get_static_joint_mask(w_j3d, vel_thr=0.25, smooth=False, repeat_last=False): + """ + w_j3d: (*, L, J, 3) + vel_thr: HuMoR uses 0.15m/s + """ + joint_v_ = (w_j3d[..., 1:, :, :] - w_j3d[..., :-1, :, :]).pow(2).sum(-1).sqrt() / 0.033 # (*, L-1, J) + if smooth: + joint_v_ = gaussian_smooth(joint_v_, 3, -2) + + static_joint_mask = joint_v_ < vel_thr # 1 as stable, 0 as moving + + if repeat_last: # repeat the last frame, this makes the shape same as w_j3d + static_joint_mask = torch.cat([static_joint_mask, static_joint_mask[..., [-1], :]], dim=-2) + + return static_joint_mask diff --git a/hmr4d/utils/geo/quaternion.py b/hmr4d/utils/geo/quaternion.py new file mode 100644 index 0000000..2aaff90 --- /dev/null +++ b/hmr4d/utils/geo/quaternion.py @@ -0,0 +1,440 @@ +# Copyright (c) 2018-present, Facebook, Inc. +# All rights reserved. 
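# A minimal usage sketch (assumed import path: the `hmr4d.utils.geo.quaternion`
# module added by this patch; requires the repo to be importable): it illustrates
# the real-part-first (w, x, y, z) quaternion convention used by qmul/qrot/qinv below.
import torch
from hmr4d.utils.geo.quaternion import qmul, qrot, qinv, qnormalize

q = qnormalize(torch.tensor([[0.7071, 0.0, 0.0, 0.7071]]))  # ~90 deg about +z, w first
v = torch.tensor([[1.0, 0.0, 0.0]])                         # unit x-axis
print(qrot(q, v))                 # ~[0, 1, 0]: x maps to y under a +90 deg z-rotation
print(qmul(q, qinv(q)))           # ~[1, 0, 0, 0]: q * q^-1 is the identity quaternion
print(qrot(qinv(q), qrot(q, v)))  # ~[1, 0, 0]: the inverse rotation recovers v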
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import numpy as np + +_EPS4 = np.finfo(float).eps * 4.0 + +try: + _FLOAT_EPS = np.finfo(np.float).eps +except: + _FLOAT_EPS = np.finfo(float).eps + + +# PyTorch-backed implementations +def qinv(q): + assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)" + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qinv_np(q): + assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)" + return qinv(torch.from_numpy(q).float()).numpy() + + +def qnormalize(q): + assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)" + return q / torch.clamp(torch.norm(q, dim=-1, keepdim=True), min=1e-8) + + +def qmul(q, r): + """ + Multiply quaternion(s) q with quaternion(s) r. + Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. + Returns q*r as a tensor of shape (*, 4). + """ + assert q.shape[-1] == 4 + assert r.shape[-1] == 4 + + original_shape = q.shape + + # Compute outer product + terms = torch.bmm(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4)) + + w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] + x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] + y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] + z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] + return torch.stack((w, x, y, z), dim=1).view(original_shape) + + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def qeuler(q, order, epsilon=0, deg=True): + """ + Convert quaternion(s) q to Euler angles. + Expects a tensor of shape (*, 4), where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
+ """ + assert q.shape[-1] == 4 + + original_shape = list(q.shape) + original_shape[-1] = 3 + q = q.view(-1, 4) + + q0 = q[:, 0] + q1 = q[:, 1] + q2 = q[:, 2] + q3 = q[:, 3] + + if order == "xyz": + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == "yzx": + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) + elif order == "zxy": + x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == "xzy": + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) + elif order == "yxz": + x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == "zyx": + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + else: + raise + + if deg: + return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi + else: + return torch.stack((x, y, z), dim=1).view(original_shape) + + +# Numpy-backed implementations + + +def qmul_np(q, r): + q = torch.from_numpy(q).contiguous().float() + r = torch.from_numpy(r).contiguous().float() + return qmul(q, r).numpy() + + +def qrot_np(q, v): + q = torch.from_numpy(q).contiguous().float() + v = torch.from_numpy(v).contiguous().float() + return qrot(q, v).numpy() + + +def qeuler_np(q, order, epsilon=0, use_gpu=False): + if use_gpu: + q = torch.from_numpy(q).cuda().float() + return qeuler(q, order, epsilon).cpu().numpy() + else: + q = torch.from_numpy(q).contiguous().float() + return qeuler(q, order, epsilon).numpy() + + +def qfix(q): + """ + Enforce quaternion continuity across the time dimension by selecting + the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) + between two consecutive frames. + + Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. + Returns a tensor of the same shape. + """ + assert len(q.shape) == 3 + assert q.shape[-1] == 4 + + result = q.copy() + dot_products = np.sum(q[1:] * q[:-1], axis=2) + mask = dot_products < 0 + mask = (np.cumsum(mask, axis=0) % 2).astype(bool) + result[1:][mask] *= -1 + return result + + +def euler2quat(e, order, deg=True): + """ + Convert Euler angles to quaternions. 
+ """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.view(-1, 3) + + ## if euler angles in degrees + if deg: + e = e * np.pi / 180.0 + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) + ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) + rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) + + result = None + for coord in order: + if coord == "x": + r = rx + elif coord == "y": + r = ry + elif coord == "z": + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ["xyz", "yzx", "zxy"]: + result *= -1 + + return result.view(original_shape) + + +def expmap_to_quaternion(e): + """ + Convert axis-angle rotations (aka exponential maps) to quaternions. + Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". + Expects a tensor of shape (*, 3), where * denotes any number of dimensions. + Returns a tensor of shape (*, 4). + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + e = e.reshape(-1, 3) + + theta = np.linalg.norm(e, axis=1).reshape(-1, 1) + w = np.cos(0.5 * theta).reshape(-1, 1) + xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e + return np.concatenate((w, xyz), axis=1).reshape(original_shape) + + +def euler_to_quaternion(e, order): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.reshape(-1, 3) + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) + ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) + rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) + + result = None + for coord in order: + if coord == "x": + r = rx + elif coord == "y": + r = ry + elif coord == "z": + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul_np(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ["xyz", "yzx", "zxy"]: + result *= -1 + + return result.reshape(original_shape) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_np(quaternions): + q = torch.from_numpy(quaternions).contiguous().float() + return quaternion_to_matrix(q).numpy() + + +def quaternion_to_cont6d_np(quaternions): + rotation_mat = quaternion_to_matrix_np(quaternions) + cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) + return cont_6d + + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d): + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def cont6d_to_matrix_np(cont6d): + q = torch.from_numpy(cont6d).contiguous().float() + return cont6d_to_matrix(q).numpy() + + +def qpow(q0, t, dtype=torch.float): + """q0 : tensor of quaternions + t: tensor of powers + """ + q0 = qnormalize(q0) + theta0 = torch.acos(q0[..., :1]) + + ## if theta0 is close to zero, add epsilon to avoid NaNs + mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) + mask = mask.float() + theta0 = (1 - mask) * theta0 + mask * 10e-10 + v0 = q0[..., 1:] / torch.sin(theta0) + + if isinstance(t, torch.Tensor): + # Do not check here + q = torch.zeros(t.shape + q0.shape, device=q0.device) + theta = t.view(-1, 1) * theta0.view(1, -1) + else: ## if t is a number + q = torch.zeros(q0.shape, device=q0.device) + theta = t * theta0 + + q[..., :1] = torch.cos(theta) + q[..., 1:] = v0 * torch.sin(theta) + + return q.to(dtype) + + +def qslerp(q0, q1, t): + """ + q0: starting quaternion + q1: ending quaternion + t: array of points along the way + + Returns: + Tensor of Slerps: t.shape + q0.shape + """ + + q0 = qnormalize(q0) + q1 = qnormalize(q1) + q_ = qpow(qmul(q1, qinv(q0)), t) + + return qmul(q_, q0) + + +def qbetween(v0, v1): + """ + find the quaternion used to rotate v0 to v1 + """ + assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)" + assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)" + + v = torch.cross(v0, v1, dim=-1) + + w = torch.sqrt((v0**2).sum(dim=-1, keepdim=True) * (v1**2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum( + dim=-1, keepdim=True + ) + y_vec = torch.zeros_like(v) + y_vec[..., 1] = 1.0 + # if v0 is (0, 0, -1), v1 is (0, 0, 1), v will be 0 and w will also be 0 -> this makes below situation comes v=1 w = 2 + mask = v.norm(dim=-1) == 0 + # if v0 is (0, 0, 1), v1 is (0, 0, 1), v will be 0 and w will be 2 -> do nothing + mask2 = w.sum(dim=-1).abs() <= 1e-4 + mask = torch.logical_and(mask, mask2) + v[mask] = y_vec[mask] + + return qnormalize(torch.cat([w, v], dim=-1)) + + +def qbetween_np(v0, v1): + """ + find the quaternion used to rotate v0 to v1 + """ + assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)" + assert v1.shape[-1] == 3, "v1 must be 
of the shape (*, 3)" + + v0 = torch.from_numpy(v0).float() + v1 = torch.from_numpy(v1).float() + return qbetween(v0, v1).numpy() + + +def lerp(p0, p1, t): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]) + + new_shape = t.shape + p0.shape + new_view_t = t.shape + torch.Size([1] * len(p0.shape)) + new_view_p = torch.Size([1] * len(t.shape)) + p0.shape + p0 = p0.view(new_view_p).expand(new_shape) + p1 = p1.view(new_view_p).expand(new_shape) + t = t.view(new_view_t).expand(new_shape) + + return p0 + t * (p1 - p0) diff --git a/hmr4d/utils/geo/transforms.py b/hmr4d/utils/geo/transforms.py new file mode 100644 index 0000000..3f6ad21 --- /dev/null +++ b/hmr4d/utils/geo/transforms.py @@ -0,0 +1,25 @@ +import torch + + +def axis_rotate_to_matrix(angle, axis="x"): + """Get rotation matrix for rotating around one axis + Args: + angle: (N, 1) + Returns: + R: (N, 3, 3) + """ + if isinstance(angle, float): + angle = torch.tensor([angle], dtype=torch.float) + + c = torch.cos(angle) + s = torch.sin(angle) + z = torch.zeros_like(angle) + o = torch.ones_like(angle) + if axis == "x": + R = torch.stack([o, z, z, z, c, -s, z, s, c], dim=1).view(-1, 3, 3) + elif axis == "y": + R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) + else: + assert axis == "z" + R = torch.stack([c, -s, z, s, c, z, z, z, o], dim=1).view(-1, 3, 3) + return R diff --git a/hmr4d/utils/geo_transform.py b/hmr4d/utils/geo_transform.py new file mode 100644 index 0000000..028cc35 --- /dev/null +++ b/hmr4d/utils/geo_transform.py @@ -0,0 +1,673 @@ +import numpy as np +import cv2 +import torch +import torch.nn.functional as F +from pytorch3d.transforms import so3_exp_map, so3_log_map +from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_axis_angle, matrix_to_rotation_6d +import pytorch3d.ops.knn as knn +from hmr4d.utils.pylogger import Log +from pytorch3d.transforms import euler_angles_to_matrix +import hmr4d.utils.matrix as matrix +from einops import einsum, rearrange, repeat +from hmr4d.utils.geo.quaternion import qbetween + + +def homo_points(points): + """ + Args: + points: (..., C) + Returns: (..., C+1), with 1 padded + """ + return F.pad(points, [0, 1], value=1.0) + + +def apply_Ts_on_seq_points(points, Ts): + """ + perform translation matrix on related point + Args: + points: (..., N, 3) + Ts: (..., N, 4, 4) + Returns: (..., N, 3) + """ + points = torch.torch.einsum("...ki,...i->...k", Ts[..., :3, :3], points) + Ts[..., :3, 3] + return points + + +def apply_T_on_points(points, T): + """ + Args: + points: (..., N, 3) + T: (..., 4, 4) + Returns: (..., N, 3) + """ + points_T = torch.einsum("...ki,...ji->...jk", T[..., :3, :3], points) + T[..., None, :3, 3] + return points_T + + +def T_transforms_points(T, points, pattern): + """manual mode of apply_T_on_points + T: (..., 4, 4) + points: (..., 3) + pattern: "... c d, ... d -> ... 
c" + """ + return einsum(T, homo_points(points), pattern)[..., :3] + + +def project_p2d(points, K=None, is_pinhole=True): + """ + Args: + points: (..., (N), 3) + K: (..., 3, 3) + Returns: shape is similar to points but without z + """ + points = points.clone() + if is_pinhole: + z = points[..., [-1]] + z.masked_fill_(z.abs() < 1e-6, 1e-6) + points_proj = points / z + else: # orthogonal + points_proj = F.pad(points[..., :2], (0, 1), value=1) + + if K is not None: + # Handle N + if len(points_proj.shape) == len(K.shape): + p2d_h = torch.einsum("...ki,...ji->...jk", K, points_proj) + else: + p2d_h = torch.einsum("...ki,...i->...k", K, points_proj) + else: + p2d_h = points_proj[..., :2] + + return p2d_h[..., :2] + + +def gen_uv_from_HW(H, W, device="cpu"): + """Returns: (H, W, 2), as float. Note: uv not ij""" + grid_v, grid_u = torch.meshgrid(torch.arange(H), torch.arange(W)) + return ( + torch.stack( + [grid_u, grid_v], + dim=-1, + ) + .float() + .to(device) + ) # (H, W, 2) + + +def unproject_p2d(uv, z, K): + """we assume a pinhole camera for unprojection + uv: (B, N, 2) + z: (B, N, 1) + K: (B, 3, 3) + Returns: (B, N, 3) + """ + xy_atz1 = (uv - K[:, None, :2, 2]) / K[:, None, [0, 1], [0, 1]] # (B, N, 2) + xyz = torch.cat([xy_atz1 * z, z], dim=-1) + return xyz + + +def cvt_p2d_from_i_to_c(uv, K): + """ + Args: + uv: (..., 2) or (..., N, 2) + K: (..., 3, 3) + Returns: the same shape as input uv + """ + if len(uv.shape) == len(K.shape): + xy = (uv - K[..., None, :2, 2]) / K[..., None, [0, 1], [0, 1]] + else: # without N + xy = (uv - K[..., :2, 2]) / K[..., [0, 1], [0, 1]] + return xy + + +def cvt_to_bi01_p2d(p2d, bbx_lurb): + """ + p2d: (..., (N), 2) + bbx_lurb: (..., 4) + """ + if len(p2d.shape) == len(bbx_lurb.shape) + 1: + bbx_lurb = bbx_lurb[..., None, :] + + bbx_wh = bbx_lurb[..., 2:] - bbx_lurb[..., :2] + bi01_p2d = (p2d - bbx_lurb[..., :2]) / bbx_wh + return bi01_p2d + + +def cvt_from_bi01_p2d(bi01_p2d, bbx_lurb): + """Use bbx_lurb to resize bi01_p2d to p2d (image-coordinates) + Args: + p2d: (..., 2) or (..., N, 2) + bbx_lurb: (..., 4) + Returns: + p2d: shape is the same as input + """ + bbx_wh = bbx_lurb[..., 2:] - bbx_lurb[..., :2] # (..., 2) + if len(bi01_p2d.shape) == len(bbx_wh.shape) + 1: + p2d = (bi01_p2d * bbx_wh.unsqueeze(-2)) + bbx_lurb[..., None, :2] + else: + p2d = (bi01_p2d * bbx_wh) + bbx_lurb[..., :2] + return p2d + + +def cvt_p2d_from_bi01_to_c(bi01, bbxs_lurb, Ks): + """ + Args: + bi01: (..., (N), 2), value in range (0,1), the point in the bbx image + bbxs_lurb: (..., 4) + Ks: (..., 3, 3) + Returns: + c: (..., (N), 2) + """ + i = cvt_from_bi01_p2d(bi01, bbxs_lurb) + c = cvt_p2d_from_i_to_c(i, Ks) + return c + + +def cvt_p2d_from_pm1_to_i(p2d_pm1, bbx_xys): + """ + Args: + p2d_pm1: (..., (N), 2), value in range (-1,1), the point in the bbx image + bbx_xys: (..., 3) + Returns: + p2d: (..., (N), 2) + """ + return bbx_xys[..., :2] + p2d_pm1 * bbx_xys[..., [2]] / 2 + + +def uv2l_index(uv, W): + return uv[..., 0] + uv[..., 1] * W + + +def l2uv_index(l, W): + v = torch.div(l, W, rounding_mode="floor") + u = l % W + return torch.stack([u, v], dim=-1) + + +def transform_mat(R, t): + """ + Args: + R: Bx3x3 array of a batch of rotation matrices + t: Bx3x(1) array of a batch of translation vectors + Returns: + T: Bx4x4 Transformation matrix + """ + # No padding left or right, only add an extra row + if len(R.shape) > len(t.shape): + t = t[..., None] + return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=-1) + + +def axis_angle_to_matrix_exp_map(aa): + 
"""use pytorch3d so3_exp_map + Args: + aa: (*, 3) + Returns: + R: (*, 3, 3) + """ + print("Use pytorch3d.transforms.axis_angle_to_matrix instead!!!") + ori_shape = aa.shape[:-1] + return so3_exp_map(aa.reshape(-1, 3)).reshape(*ori_shape, 3, 3) + + +def matrix_to_axis_angle_log_map(R): + """use pytorch3d so3_log_map + Args: + aa: (*, 3, 3) + Returns: + R: (*, 3) + """ + print("WARINING! I met singularity problem with this function, use matrix_to_axis_angle instead!") + ori_shape = R.shape[:-2] + return so3_log_map(R.reshape(-1, 3, 3)).reshape(*ori_shape, 3) + + +def matrix_to_axis_angle(R): + """use pytorch3d so3_log_map + Args: + aa: (*, 3, 3) + Returns: + R: (*, 3) + """ + return quaternion_to_axis_angle(matrix_to_quaternion(R)) + + +def ransac_PnP(K, pts_2d, pts_3d, err_thr=10): + """solve pnp""" + dist_coeffs = np.zeros(shape=[8, 1], dtype="float64") + + pts_2d = np.ascontiguousarray(pts_2d.astype(np.float64)) + pts_3d = np.ascontiguousarray(pts_3d.astype(np.float64)) + K = K.astype(np.float64) + + try: + _, rvec, tvec, inliers = cv2.solvePnPRansac( + pts_3d, pts_2d, K, dist_coeffs, reprojectionError=err_thr, iterationsCount=10000, flags=cv2.SOLVEPNP_EPNP + ) + + rotation = cv2.Rodrigues(rvec)[0] + + pose = np.concatenate([rotation, tvec], axis=-1) + pose_homo = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0) + + inliers = [] if inliers is None else inliers + + return pose, pose_homo, inliers + except cv2.error: + print("CV ERROR") + return np.eye(4)[:3], np.eye(4), [] + + +def ransac_PnP_batch(K_raw, pts_2d, pts_3d, err_thr=10): + fit_R, fit_t = [], [] + for b in range(K_raw.shape[0]): + pose, _, inliers = ransac_PnP(K_raw[b], pts_2d[b], pts_3d[b], err_thr=err_thr) + fit_R.append(pose[:3, :3]) + fit_t.append(pose[:3, 3]) + fit_R = np.stack(fit_R, axis=0) + fit_t = np.stack(fit_t, axis=0) + return fit_R, fit_t + + +def triangulate_point(Ts_w2c, c_p2d, **kwargs): + from hmr4d.utils.geo.triangulation import triangulate_persp + + print("Deprecated, please import from hmr4d.utils.geo.triangulation") + return triangulate_persp(Ts_w2c, c_p2d, **kwargs) + + +def triangulate_point_ortho(Ts_w2c, c_p2d, **kwargs): + from hmr4d.utils.geo.triangulation import triangulate_ortho + + print("Deprecated, please import from hmr4d.utils.geo.triangulation") + return triangulate_ortho(Ts_w2c, c_p2d, **kwargs) + + +def get_nearby_points(points, query_verts, padding=0.0, p=1): + """ + points: (S, 3) + query_verts: (V, 3) + """ + if p == 1: + max_xyz = query_verts.max(0)[0] + padding + min_xyz = query_verts.min(0)[0] - padding + idx = (((points - min_xyz) > 0).all(dim=-1) * ((points - max_xyz) < 0).all(dim=-1)).nonzero().squeeze(-1) + nearby_points = points[idx] + elif p == 2: + squared_dist, _, _ = knn.knn_points(points[None], query_verts[None], K=1, return_nn=False) + mask = squared_dist[0, :, 0] < padding**2 # (S,) + nearby_points = points[mask] + + return nearby_points + + +def unproj_bbx_to_fst(bbx_lurb, K, near_z=0.5, far_z=12.5): + B = bbx_lurb.size(0) + uv = bbx_lurb[:, [[0, 1], [2, 1], [2, 3], [0, 3], [0, 1], [2, 1], [2, 3], [0, 3]]] + if isinstance(near_z, float): + z = uv.new([near_z] * 4 + [far_z] * 4).reshape(1, 8, 1).repeat(B, 1, 1) + else: + z = torch.cat([near_z[:, None, None].repeat(1, 4, 1), far_z[:, None, None].repeat(1, 4, 1)], dim=1) + c_frustum_points = unproject_p2d(uv, z, K) # (B, 8, 3) + return c_frustum_points + + +def convert_bbx_xys_to_lurb(bbx_xys): + """ + Args: bbx_xys (..., 3) -> bbx_lurb (..., 4) + """ + size = bbx_xys[..., 2:] + center = bbx_xys[..., :2] + lurb = 
torch.cat([center - size / 2, center + size / 2], dim=-1) + return lurb + + +def convert_lurb_to_bbx_xys(bbx_lurb): + """ + Args: bbx_lurb (..., 4) -> bbx_xys (..., 3) be aware that it is squared + """ + size = (bbx_lurb[..., 2:] - bbx_lurb[..., :2]).max(-1, keepdim=True)[0] + center = (bbx_lurb[..., :2] + bbx_lurb[..., 2:]) / 2 + return torch.cat([center, size], dim=-1) + + +# ================== AZ/AY Transformations ================== # + + +def compute_T_ayf2az(joints, inverse=False): + """ + Args: + joints: (B, J, 3), in the start-frame, az-coordinate + Returns: + if inverse == False: + T_af2az: (B, 4, 4) + else : + T_az2af: (B, 4, 4) + """ + + t_ayf2az = joints[:, 0, :].detach().clone() + t_ayf2az[:, 2] = 0 # do not modify z + + RL_xy_h = joints[:, 1, [0, 1]] - joints[:, 2, [0, 1]] # (B, 2), hip point to left side + RL_xy_s = joints[:, 16, [0, 1]] - joints[:, 17, [0, 1]] # (B, 2), shoulder point to left side + RL_xy = RL_xy_h + RL_xy_s + I_mask = RL_xy.pow(2).sum(-1) < 1e-4 # do not rotate, when can't decided the face direction + if I_mask.sum() > 0: + Log.warn("{} samples can't decide the face direction".format(I_mask.sum())) + x_dir = F.pad(F.normalize(RL_xy, 2, -1), (0, 1), value=0) # (B, 3) + y_dir = torch.zeros_like(x_dir) + y_dir[..., 2] = 1 + z_dir = torch.cross(x_dir, y_dir, dim=-1) + R_ayf2az = torch.stack([x_dir, y_dir, z_dir], dim=-1) # (B, 3, 3) + R_ayf2az[I_mask] = torch.eye(3).to(R_ayf2az) + + if inverse: + R_az2ayf = R_ayf2az.transpose(1, 2) # (B, 3, 3) + t_az2ayf = -einsum(R_ayf2az, t_ayf2az, "b i j , b i -> b j") # (B, 3) + return transform_mat(R_az2ayf, t_az2ayf) + else: + return transform_mat(R_ayf2az, t_ayf2az) + + +def compute_T_ayfz2ay(joints, inverse=False): + """ + Args: + joints: (B, J, 3), in the start-frame, ay-coordinate + Returns: + if inverse == False: + T_ayfz2ay: (B, 4, 4) + else : + T_ay2ayfz: (B, 4, 4) + """ + t_ayfz2ay = joints[:, 0, :].detach().clone() + t_ayfz2ay[:, 1] = 0 # do not modify y + + RL_xz_h = joints[:, 1, [0, 2]] - joints[:, 2, [0, 2]] # (B, 2), hip point to left side + RL_xz_s = joints[:, 16, [0, 2]] - joints[:, 17, [0, 2]] # (B, 2), shoulder point to left side + RL_xz = RL_xz_h + RL_xz_s + I_mask = RL_xz.pow(2).sum(-1) < 1e-4 # do not rotate, when can't decided the face direction + if I_mask.sum() > 0: + Log.warn("{} samples can't decide the face direction".format(I_mask.sum())) + + x_dir = torch.zeros_like(t_ayfz2ay) # (B, 3) + x_dir[:, [0, 2]] = F.normalize(RL_xz, 2, -1) + y_dir = torch.zeros_like(x_dir) + y_dir[..., 1] = 1 # (B, 3) + z_dir = torch.cross(x_dir, y_dir, dim=-1) + R_ayfz2ay = torch.stack([x_dir, y_dir, z_dir], dim=-1) # (B, 3, 3) + R_ayfz2ay[I_mask] = torch.eye(3).to(R_ayfz2ay) + + if inverse: + R_ay2ayfz = R_ayfz2ay.transpose(1, 2) + t_ay2ayfz = -einsum(R_ayfz2ay, t_ayfz2ay, "b i j , b i -> b j") + return transform_mat(R_ay2ayfz, t_ay2ayfz) + else: + return transform_mat(R_ayfz2ay, t_ayfz2ay) + + +def compute_T_ay2ayrot(joints): + """ + Args: + joints: (B, J, 3), in the start-frame, ay-coordinate + Returns: + T_ay2ayrot: (B, 4, 4) + """ + t_ayrot2ay = joints[:, 0, :].detach().clone() + t_ayrot2ay[:, 1] = 0 # do not modify y + + B = joints.shape[0] + euler_angle = torch.zeros((B, 3), device=joints.device) + yrot_angle = torch.rand((B,), device=joints.device) * 2 * torch.pi + euler_angle[:, 0] = yrot_angle + R_ay2ayrot = euler_angles_to_matrix(euler_angle, "YXZ") # (B, 3, 3) + + R_ayrot2ay = R_ay2ayrot.transpose(1, 2) + t_ay2ayrot = -einsum(R_ayrot2ay, t_ayrot2ay, "b i j , b i -> b j") + return transform_mat(R_ay2ayrot, 
t_ay2ayrot) + + +def compute_root_quaternion_ay(joints): + """ + Args: + joints: (B, J, 3), in the start-frame, ay-coordinate + Returns: + root_quat: (B, 4) from z-axis to fz + """ + joints_shape = joints.shape + joints = joints.reshape((-1,) + joints_shape[-2:]) + t_ayfz2ay = joints[:, 0, :].detach().clone() + t_ayfz2ay[:, 1] = 0 # do not modify y + + RL_xz_h = joints[:, 1, [0, 2]] - joints[:, 2, [0, 2]] # (B, 2), hip point to left side + RL_xz_s = joints[:, 16, [0, 2]] - joints[:, 17, [0, 2]] # (B, 2), shoulder point to left side + RL_xz = RL_xz_h + RL_xz_s + I_mask = RL_xz.pow(2).sum(-1) < 1e-4 # do not rotate, when can't decided the face direction + if I_mask.sum() > 0: + Log.warn("{} samples can't decide the face direction".format(I_mask.sum())) + + x_dir = torch.zeros_like(t_ayfz2ay) # (B, 3) + x_dir[:, [0, 2]] = F.normalize(RL_xz, 2, -1) + y_dir = torch.zeros_like(x_dir) + y_dir[..., 1] = 1 # (B, 3) + z_dir = torch.cross(x_dir, y_dir, dim=-1) + + z_dir[..., 2] += 1e-9 + pos_z_vec = torch.tensor([0, 0, 1]).to(joints.device).float() # (3,) + root_quat = qbetween(pos_z_vec[None], z_dir) # (B, 4) + root_quat = root_quat.reshape(joints_shape[:-2] + (4,)) + return root_quat + + +# ================== Transformations between two sets of features ================== # + + +def similarity_transform_batch(S1, S2): + """ + Computes a similarity transform (sR, t) that solves the orthogonal Procrutes problem. + Args: + S1, S2: (*, L, 3) + """ + assert S1.shape == S2.shape + S_shape = S1.shape + S1 = S1.reshape(-1, *S_shape[-2:]) + S2 = S2.reshape(-1, *S_shape[-2:]) + + S1 = S1.transpose(-2, -1) + S2 = S2.transpose(-2, -1) + + # --- The code is borrowed from WHAM --- + # 1. Remove mean. + mu1 = S1.mean(axis=-1, keepdims=True) # axis is along N, S1(B, 3, N) + mu2 = S2.mean(axis=-1, keepdims=True) + + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = torch.sum(X1**2, dim=1).sum(dim=1) + + # 3. The outer product of X1 and X2. + K = X1.bmm(X2.permute(0, 2, 1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, V = torch.svd(K) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0) + Z = Z.repeat(U.shape[0], 1, 1) + Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1)))) + + # Construct R. + R = V.bmm(Z.bmm(U.permute(0, 2, 1))) + + # 5. Recover scale. + scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1 + + # 6. Recover translation. + t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1))) + + # ------- + # reshape back + # sR = scale[:, None, None] * R + # sR = sR.reshape(*S_shape[:-2], 3, 3) + scale = scale.reshape(*S_shape[:-2], 1, 1) + R = R.reshape(*S_shape[:-2], 3, 3) + t = t.reshape(*S_shape[:-2], 3, 1) + + return (scale, R), t + + +def kabsch_algorithm_batch(X1, X2): + """ + Computes a rigid transform (R, t) + Args: + X1, X2: (*, L, 3) + """ + assert X1.shape == X2.shape + X_shape = X1.shape + X1 = X1.reshape(-1, *X_shape[-2:]) + X2 = X2.reshape(-1, *X_shape[-2:]) + + # 1. 计算质心 + centroid_X1 = torch.mean(X1, dim=-2, keepdim=True) + centroid_X2 = torch.mean(X2, dim=-2, keepdim=True) + + # 2. 去中心化 + X1_centered = X1 - centroid_X1 + X2_centered = X2 - centroid_X2 + + # 3. 计算协方差矩阵 + H = torch.matmul(X1_centered.transpose(-2, -1), X2_centered) + + # 4. 奇异值分解 + U, S, Vt = torch.linalg.svd(H) + + # 5. 
计算旋转矩阵 + R = torch.matmul(Vt.transpose(-2, -1), U.transpose(-2, -1)) + + # 修正反射矩阵 + d = (torch.det(R) < 0).unsqueeze(-1).unsqueeze(-1) + Vt = torch.where(d, -Vt, Vt) + R = torch.matmul(Vt.transpose(-2, -1), U.transpose(-2, -1)) + + # 6. 计算平移向量 + t = centroid_X2.transpose(-2, -1) - torch.matmul(R, centroid_X1.transpose(-2, -1)) + + # ------- + # reshape back + R = R.reshape(*X_shape[:-2], 3, 3) + t = t.reshape(*X_shape[:-2], 3, 1) + + return R, t + + +# ===== WHAM cam_angvel ===== # + + +def compute_cam_angvel(R_w2c, padding_last=True): + """ + R_w2c : (F, 3, 3) + """ + # R @ R0 = R1, so R = R1 @ R0^T + cam_angvel = matrix_to_rotation_6d(R_w2c[1:] @ R_w2c[:-1].transpose(-1, -2)) # (F-1, 6) + # cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]])) * FPS + assert padding_last + cam_angvel = torch.cat([cam_angvel, cam_angvel[-1:]], dim=0) # (F, 6) + return cam_angvel.float() + + +def ransac_gravity_vec(xyz, num_iterations=100, threshold=0.05, verbose=False): + # xyz: (L, 3) + N = xyz.shape[0] + max_inliers = [] + best_model = None + norms = xyz.norm(dim=-1) # (L,) + + for _ in range(num_iterations): + # 随机选择一个样本 + sample_index = np.random.randint(N) + sample = xyz[sample_index] # (3,) + + # 计算所有点与样本点的角度差 + dot_product = (xyz * sample).sum(dim=-1) # (L,) + angles = dot_product / norms * norms[sample_index] # (L,) + angles = torch.clamp(angles, -1, 1) # 防止数值误差导致的异常 + angles = torch.acos(angles) + + # 确定内点 + inliers = xyz[angles < threshold] + + if len(inliers) > len(max_inliers): + max_inliers = inliers + best_model = sample + if len(max_inliers) == N: + break + if verbose: + print(f"Inliers: {len(max_inliers)} / {N}") + result = max_inliers.mean(dim=0) + + return result, max_inliers + + +def sequence_best_cammat(w_j3d, c_j3d, cam_rot): + # get best camera estimation along the sequence, requires static camera + # w_j3d: (L, J, 3) + # c_j3d: (L, J, 3) + # cam_rot: (L, 3, 3) + + L, J, _ = w_j3d.shape + + root_in_w = w_j3d[:, 0] # (L, 3) + root_in_c = c_j3d[:, 0] # (L, 3) + cam_mat = matrix.get_TRS(cam_rot, root_in_w) # (L, 4, 4) + cam_pos = matrix.get_position_from(-root_in_c[:, None], cam_mat)[:, 0] # (L, 3) + cam_mat = matrix.set_position(cam_mat, cam_pos) # (L, 4, 4) + + w_j3d_expand = w_j3d[None].expand(L, -1, -1, -1) # (L, L, J, 3) + w_j3d_expand = w_j3d_expand.reshape(L, -1, 3) # (L, L*J, 3) + + # get reproject error + w_j3d_expand_in_c = matrix.get_relative_position_to(w_j3d_expand, cam_mat) # (L, L*J, 3) + w_j2d_expand_in_c = project_p2d(w_j3d_expand_in_c) # (L, L*J, 2) + w_j2d_expand_in_c = w_j2d_expand_in_c.reshape(L, L, J, 2) # (L, L, J, 2) + c_j2d = project_p2d(c_j3d) # (L, J, 2) + error = w_j2d_expand_in_c - c_j2d[None] # (L, L, J, 2) + error = error.norm(dim=-1).mean(dim=-1) # (L, L) + error = error.mean(dim=-1) # (L,) + ind = error.argmin() + return cam_mat[ind], ind + + +def get_sequence_cammat(w_j3d, c_j3d, cam_rot): + # w_j3d: (L, J, 3) + # c_j3d: (L, J, 3) + # cam_rot: (L, 3, 3) + + L, J, _ = w_j3d.shape + + root_in_w = w_j3d[:, 0] # (L, 3) + root_in_c = c_j3d[:, 0] # (L, 3) + cam_mat = matrix.get_TRS(cam_rot, root_in_w) # (L, 4, 4) + cam_pos = matrix.get_position_from(-root_in_c[:, None], cam_mat)[:, 0] # (L, 3) + cam_mat = matrix.set_position(cam_mat, cam_pos) # (L, 4, 4) + return cam_mat + + +def ransac_vec(vel, min_multiply=20, verbose=False): + # xyz: (L, 3) + # remove outlier velocity + N = vel.shape[0] + vel_1 = vel[None].expand(N, -1, -1) # (L, L, 3) + vel_2 = vel[:, None].expand(-1, N, -1) # (L, L, 3) + dist_mat = (vel_1 - vel_2).norm(dim=-1) # (L, L) + 
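    # What follows: self-distances on the diagonal are masked out with a large
    # constant, the inlier threshold is the smallest remaining pairwise distance
    # times `min_multiply`, the velocity sample with the most inliers is selected,
    # and the returned estimate is the mean of that sample's inlier velocities.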
big_identity = torch.eye(N, device=vel.device) * 1e6 + dist_mat_ = dist_mat + big_identity + threshold = dist_mat_.min() * min_multiply + inner_mask = dist_mat < threshold # (L, L) + inner_num = inner_mask.sum(dim=-1) # (L, ) + ind = inner_num.argmax() + result = vel[inner_mask[ind]].mean(dim=0) # (3,) + if verbose: + print(inner_mask[ind].sum().item()) + + return result, inner_mask[ind] diff --git a/hmr4d/utils/ik/ccd_ik.py b/hmr4d/utils/ik/ccd_ik.py new file mode 100644 index 0000000..9bc606c --- /dev/null +++ b/hmr4d/utils/ik/ccd_ik.py @@ -0,0 +1,149 @@ +# Sebastian IK +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import einsum, rearrange, repeat + +from pytorch3d.transforms import ( + matrix_to_rotation_6d, + rotation_6d_to_matrix, + axis_angle_to_matrix, + matrix_to_axis_angle, + quaternion_to_matrix, + matrix_to_quaternion, +) +import hmr4d.utils.matrix as matrix +from hmr4d.utils.geo.quaternion import qbetween, qslerp, qinv, qmul, qrot + + +class CCD_IK: + def __init__( + self, + local_mat, + parent, + target_ind, + target_pos=None, + target_rot=None, + kinematic_chain=None, + max_iter=2, # sebas sets 25 but with converged flag, 2 is enough + threshold=0.001, + pos_weight=1.0, + rot_weight=0.0, # this makes optimization unstable, although sebas uses 1.0 + ): + if kinematic_chain is None: + kinematic_chain = range(local_mat.shape[-3]) + global_mat = matrix.forward_kinematics(local_mat, parent) + + # get kinematic chain only local mat and assign root mat (do not modify root during IK) + local_mat = local_mat.clone() + local_mat = local_mat[..., kinematic_chain, :, :] + local_mat[..., 0, :, :] = global_mat[..., kinematic_chain[0], :, :] + + parent = [i - 1 for i in range(len(kinematic_chain))] + self.local_mat = local_mat + self.global_mat = matrix.forward_kinematics(local_mat, parent) # (*, J, 4, 4) + self.parent = parent + + self.target_ind = target_ind + if target_pos is not None: + self.target_pos = target_pos # (*, O, 3) + else: + self.target_pos = None + if target_rot is not None: + self.target_q = matrix_to_quaternion(target_rot) # (*, O, 4) + else: + self.target_q = None + + self.threshold = threshold + self.J_N = self.local_mat.shape[-3] + self.target_N = len(target_ind) + self.max_iter = max_iter + self.pos_weight = pos_weight + self.rot_weight = rot_weight + + def is_converged(self): + end_pos = matrix.get_position(self.global_mat)[..., self.target_ind, :] # (*, OJ, 3) + converged_mask = (self.target_pos - end_pos).norm(dim=-1) < self.threshold + self.converged_mask = converged_mask + if self.converged_mask.sum() > 0: + return False + return True + + def solve(self): + for _ in range(self.max_iter): + # if self.is_converged(): + # return self.local_mat + # do not optimize root, so start from 1 + self.optimize(1) + return self.local_mat + + def optimize(self, i): + # i: joint_i + if i == self.J_N - 1: + return + pos = matrix.get_position(self.global_mat)[..., i, :] # (*, 3) + rot = matrix.get_rotation(self.global_mat)[..., i, :, :] # (*, 3, 3) + quat = matrix_to_quaternion(rot) # (*, 4) + x_vec = torch.zeros((quat.shape[:-1] + (3,)), device=quat.device) + x_vec[..., 0] = 1.0 + x_vec_sum = torch.zeros_like(x_vec) + y_vec = torch.zeros((quat.shape[:-1] + (3,)), device=quat.device) + y_vec[..., 1] = 1.0 + y_vec_sum = torch.zeros_like(y_vec) + + count = 0 + + for target_i, j in enumerate(self.target_ind): + if i >= j: + # do not optimise same joint or child joint of targets + continue + end_pos = matrix.get_position(self.global_mat)[..., j, :] 
# (*, 3) + end_rot = matrix.get_rotation(self.global_mat)[..., j, :, :] # (*, 3, 3) + end_quat = matrix_to_quaternion(end_rot) # (*, 4) + + if self.target_pos is not None: + target_pos = self.target_pos[..., target_i, :] # (*, 3) + # Solve objective position + solved_pos_target_quat = qslerp( + quat, + qmul(qbetween(end_pos - pos, target_pos - pos), quat), + self.get_weight(i), + ) + + x_vec_sum += qrot(solved_pos_target_quat, x_vec) + y_vec_sum += qrot(solved_pos_target_quat, y_vec) + if self.pos_weight > 0: + count += 1 + + if self.target_q is not None: + if target_i < self.target_N - 1: + # multiple rot target makes more unstable, only keep the last one + continue + # optimize rotation target is not stable + target_q = self.target_q[..., target_i, :] # (*, 4) + # Solve objective rotation + solved_q_target_quat = qslerp( + quat, + qmul(qmul(target_q, qinv(end_quat)), quat), + self.get_weight(i), + ) + x_vec_sum += qrot(solved_q_target_quat, x_vec) * self.rot_weight + y_vec_sum += qrot(solved_q_target_quat, y_vec) * self.rot_weight + if self.rot_weight > 0: + count += 1 + + if count > 0: + x_vec_avg = matrix.normalize(x_vec_sum / count) + y_vec_avg = matrix.normalize(y_vec_sum / count) + z_vec_avg = torch.cross(x_vec_avg, y_vec_avg, dim=-1) + solved_rot = torch.stack([x_vec_avg, y_vec_avg, z_vec_avg], dim=-1) # column + + parent_rot = matrix.get_rotation(self.global_mat)[..., self.parent[i], :, :] + solved_local_rot = matrix.get_mat_BtoA(parent_rot, solved_rot) + self.local_mat[..., i, :-1, :-1] = solved_local_rot + self.global_mat = matrix.forward_kinematics(self.local_mat, self.parent) + self.optimize(i + 1) + + def get_weight(self, i): + weight = (i + 1) / self.J_N + return weight diff --git a/hmr4d/utils/kpts/kp2d_utils.py b/hmr4d/utils/kpts/kp2d_utils.py new file mode 100644 index 0000000..b08be88 --- /dev/null +++ b/hmr4d/utils/kpts/kp2d_utils.py @@ -0,0 +1,372 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +import cv2 +import numpy as np + +# expose _taylor to outside +__all__ = ["keypoints_from_heatmaps"] + + +def _taylor(heatmap, coord): + """Distribution aware coordinate decoding method. + + Note: + - heatmap height: H + - heatmap width: W + + Args: + heatmap (np.ndarray[H, W]): Heatmap of a particular joint type. + coord (np.ndarray[2,]): Coordinates of the predicted keypoints. + + Returns: + np.ndarray[2,]: Updated coordinates. + """ + H, W = heatmap.shape[:2] + px, py = int(coord[0]), int(coord[1]) + if 1 < px < W - 2 and 1 < py < H - 2: + dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1]) + dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px]) + dxx = 0.25 * (heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2]) + dxy = 0.25 * ( + heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] - heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1] + ) + dyy = 0.25 * (heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] + heatmap[py - 2 * 1][px]) + derivative = np.array([[dx], [dy]]) + hessian = np.array([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = np.linalg.inv(hessian) + offset = -hessianinv @ derivative + offset = np.squeeze(np.array(offset.T), axis=0) + coord += offset + return coord + + +def _get_max_preds(heatmaps): + """Get keypoint predictions from score maps. + + Note: + batch_size: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + + Returns: + tuple: A tuple containing aggregated results. 
+ + - preds (np.ndarray[N, K, 2]): Predicted keypoint location. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray" + assert heatmaps.ndim == 4, "batch_images should be 4-ndim" + + N, K, _, W = heatmaps.shape + heatmaps_reshaped = heatmaps.reshape((N, K, -1)) + idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) + maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + preds[:, :, 0] = preds[:, :, 0] % W + preds[:, :, 1] = preds[:, :, 1] // W + + preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1) + return preds, maxvals + + +def post_dark_udp(coords, batch_heatmaps, kernel=3): + """DARK post-pocessing. Implemented by udp. Paper ref: Huang et al. The + Devil is in the Details: Delving into Unbiased Data Processing for Human + Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + + Note: + - batch size: B + - num keypoints: K + - num persons: N + - height of heatmaps: H + - width of heatmaps: W + + B=1 for bottom_up paradigm where all persons share the same heatmap. + B=N for top_down paradigm where each person has its own heatmaps. + + Args: + coords (np.ndarray[N, K, 2]): Initial coordinates of human pose. + batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps + kernel (int): Gaussian kernel size (K) for modulation. + + Returns: + np.ndarray([N, K, 2]): Refined coordinates. + """ + if not isinstance(batch_heatmaps, np.ndarray): + batch_heatmaps = batch_heatmaps.cpu().numpy() + B, K, H, W = batch_heatmaps.shape + N = coords.shape[0] + assert B == 1 or B == N + for heatmaps in batch_heatmaps: + for heatmap in heatmaps: + cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap) + np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps) + np.log(batch_heatmaps, batch_heatmaps) + + batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten() + + index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K) + index = index.astype(int).reshape(-1, 1) + i_ = batch_heatmaps_pad[index] + ix1 = batch_heatmaps_pad[index + 1] + iy1 = batch_heatmaps_pad[index + W + 2] + ix1y1 = batch_heatmaps_pad[index + W + 3] + ix1_y1_ = batch_heatmaps_pad[index - W - 3] + ix1_ = batch_heatmaps_pad[index - 1] + iy1_ = batch_heatmaps_pad[index - 2 - W] + + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1) + derivative = derivative.reshape(N, K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) + hessian = hessian.reshape(N, K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze() + return coords + + +def _gaussian_blur(heatmaps, kernel=11): + """Modulate heatmap distribution with Gaussian. + sigma = 0.3*((kernel_size-1)*0.5-1)+0.8 + sigma~=3 if k=17 + sigma=2 if k=11; + sigma~=1.5 if k=7; + sigma~=1 if k=3; + + Note: + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. 
+ K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([N, K, H, W]): Modulated heatmap distribution. + """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + batch_size = heatmaps.shape[0] + num_joints = heatmaps.shape[1] + height = heatmaps.shape[2] + width = heatmaps.shape[3] + for i in range(batch_size): + for j in range(num_joints): + origin_max = np.max(heatmaps[i, j]) + dr = np.zeros((height + 2 * border, width + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = heatmaps[i, j].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmaps[i, j] = dr[border:-border, border:-border].copy() + heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j]) + return heatmaps + + +def keypoints_from_heatmaps( + heatmaps, + center, + scale, + unbiased=False, + post_process="default", + kernel=11, + valid_radius_factor=0.0546875, + use_udp=False, + target_type="GaussianHeatmap", +): + """Get final keypoint predictions from heatmaps and transform them back to + the image. + + Note: + - batch size: N + - num keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + post_process (str/None): Choice of methods to post-process + heatmaps. Currently supported: None, 'default', 'unbiased', + 'megvii'. + unbiased (bool): Option to use unbiased decoding. Mutually + exclusive with megvii. + Note: this arg is deprecated and unbiased=True can be replaced + by post_process='unbiased' + Paper ref: Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + valid_radius_factor (float): The radius factor of the positive area + in classification heatmap for UDP. + use_udp (bool): Use unbiased data processing. + target_type (str): 'GaussianHeatmap' or 'CombinedTarget'. + GaussianHeatmap: Classification target with gaussian distribution. + CombinedTarget: The combination of classification target + (response map) and regression target (offset map). + Paper ref: Huang et al. The Devil is in the Details: Delving into + Unbiased Data Processing for Human Pose Estimation (CVPR 2020). + + Returns: + tuple: A tuple containing keypoint predictions and scores. + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 
+ """ + # Avoid being affected + heatmaps = heatmaps.copy() + + # detect conflicts + if unbiased: + assert post_process not in [False, None, "megvii"] + if post_process in ["megvii", "unbiased"]: + assert kernel > 0 + if use_udp: + assert not post_process == "megvii" + + # normalize configs + if post_process is False: + warnings.warn("post_process=False is deprecated, " "please use post_process=None instead", DeprecationWarning) + post_process = None + elif post_process is True: + if unbiased is True: + warnings.warn( + "post_process=True, unbiased=True is deprecated," " please use post_process='unbiased' instead", + DeprecationWarning, + ) + post_process = "unbiased" + else: + warnings.warn( + "post_process=True, unbiased=False is deprecated, " "please use post_process='default' instead", + DeprecationWarning, + ) + post_process = "default" + elif post_process == "default": + if unbiased is True: + warnings.warn( + "unbiased=True is deprecated, please use " "post_process='unbiased' instead", DeprecationWarning + ) + post_process = "unbiased" + + # start processing + if post_process == "megvii": + heatmaps = _gaussian_blur(heatmaps, kernel=kernel) + + N, K, H, W = heatmaps.shape + if use_udp: + if target_type.lower() == "GaussianHeatMap".lower(): + preds, maxvals = _get_max_preds(heatmaps) + preds = post_dark_udp(preds, heatmaps, kernel=kernel) + elif target_type.lower() == "CombinedTarget".lower(): + for person_heatmaps in heatmaps: + for i, heatmap in enumerate(person_heatmaps): + kt = 2 * kernel + 1 if i % 3 == 0 else kernel + cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap) + # valid radius is in direct proportion to the height of heatmap. + valid_radius = valid_radius_factor * H + offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius + offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius + heatmaps = heatmaps[:, ::3, :] + preds, maxvals = _get_max_preds(heatmaps) + index = preds[..., 0] + preds[..., 1] * W + index += W * H * np.arange(0, N * K / 3) + index = index.astype(int).reshape(N, K // 3, 1) + preds += np.concatenate((offset_x[index], offset_y[index]), axis=2) + else: + raise ValueError("target_type should be either " "'GaussianHeatmap' or 'CombinedTarget'") + else: + preds, maxvals = _get_max_preds(heatmaps) + if post_process == "unbiased": # alleviate biased coordinate + # apply Gaussian distribution modulation. + heatmaps = np.log(np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10)) + for n in range(N): + for k in range(K): + preds[n][k] = _taylor(heatmaps[n][k], preds[n][k]) + elif post_process is not None: + # add +/-0.25 shift to the predicted locations for higher acc. + for n in range(N): + for k in range(K): + heatmap = heatmaps[n][k] + px = int(preds[n][k][0]) + py = int(preds[n][k][1]) + if 1 < px < W - 1 and 1 < py < H - 1: + diff = np.array( + [heatmap[py][px + 1] - heatmap[py][px - 1], heatmap[py + 1][px] - heatmap[py - 1][px]] + ) + preds[n][k] += np.sign(diff) * 0.25 + if post_process == "megvii": + preds[n][k] += 0.5 + + # Transform back to the image + for i in range(N): + preds[i] = transform_preds(preds[i], center[i], scale[i], [W, H], use_udp=use_udp) + + if post_process == "megvii": + maxvals = maxvals / 255.0 + 0.5 + + return preds, maxvals + + +def transform_preds(coords, center, scale, output_size, use_udp=False): + """Get final keypoint predictions from heatmaps and apply scaling and + translation to map them back to the image. 
+ + Note: + num_keypoints: K + + Args: + coords (np.ndarray[K, ndims]): + + * If ndims=2, corrds are predicted keypoint location. + * If ndims=4, corrds are composed of (x, y, scores, tags) + * If ndims=5, corrds are composed of (x, y, scores, tags, + flipped_tags) + + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + use_udp (bool): Use unbiased data processing + + Returns: + np.ndarray: Predicted coordinates in the images. + """ + assert coords.shape[1] in (2, 4, 5) + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + + # Recover the scale which is normalized by a factor of 200. + scale = scale * 200.0 + + if use_udp: + scale_x = scale[0] / (output_size[0] - 1.0) + scale_y = scale[1] / (output_size[1] - 1.0) + else: + scale_x = scale[0] / output_size[0] + scale_y = scale[1] / output_size[1] + + target_coords = np.ones_like(coords) + target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5 + target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5 + + return target_coords diff --git a/hmr4d/utils/matrix.py b/hmr4d/utils/matrix.py new file mode 100644 index 0000000..ce9b1e8 --- /dev/null +++ b/hmr4d/utils/matrix.py @@ -0,0 +1,1677 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from typing import List, Optional + +import numpy as np + +import math + + +def identity_mat(x=None, device="cpu", is_numpy=False): + if x is not None: + if isinstance(x, torch.Tensor): + mat = torch.eye(4, device=device) + mat = mat.repeat(x.shape[:-2] + (1, 1)) + elif isinstance(x, np.ndarray): + mat = np.eye(4, dtype=np.float32) + if x is not None: + for _ in range(len(x.shape) - 2): + mat = mat[None] + mat = np.tile(mat, x.shape[:-2] + (1, 1)) + else: + raise ValueError + else: + # (4, 4) + if is_numpy: + mat = np.eye(4, dtype=np.float32) + else: + mat = torch.eye(4, device=device) + + return mat + + +def vec2mat(vec): + """_summary_ + + Args: + vec (tensor): [12], pos, forward, up and right + + Returns: + mat_world(tensor): [4, 4] + """ + # Assume bs = 1 + v = np.tile(np.array([[0, 0, 0, 1]]), (1, 1)) + if isinstance(vec, torch.Tensor): + v = torch.tensor( + v, + device=vec.device, + dtype=vec.dtype, + ) + pos = vec[:3] + forward = vec[3:6] + up = vec[6:9] + right = vec[9:12] + + if isinstance(vec, torch.Tensor): + mat_world = torch.stack([right, up, forward, pos], dim=-1) + mat_world = torch.cat([mat_world, v], dim=-2) + elif isinstance(vec, np.ndarray): + mat_world = np.stack([right, up, forward, pos], axis=-1) + mat_world = np.concatenate([mat_world, v], axis=-2) + else: + raise ValueError + mat_world = normalized_matrix(mat_world) + return mat_world + + +def mat2vec(mat): + """_summary_ + + Args: + mat(tensor): [4, 4] + + Returns: + vec (tensor): [12], pos, forward, up and right + """ + # Assume bs = 1 + pos = mat[:-1, 3] + forward = normalized(mat[:-1, 2]) + up = normalized(mat[:-1, 1]) + right = normalized(mat[:-1, 0]) + if isinstance(mat, torch.Tensor): + vec = torch.cat((pos, forward, up, right)) + elif isinstance(mat, np.ndarray): + vec = np.concatenate((pos, forward, up, right)) + else: + raise ValueError + + return vec + + +def vec2mat_batch(vec): + """_summary_ + + Args: + vec (tensor): [B, 12], pos, forward, up and right + + Returns: + mat_world(tensor): [B, 4, 4] + """ + # Assume bs = 1 + + v = np.tile(np.array([[0, 0, 0, 1]], 
dtype=np.float32), (vec.shape[0], 1, 1)) + if isinstance(vec, torch.Tensor): + v = torch.tensor( + v, + device=vec.device, + dtype=vec.dtype, + ) + pos = vec[..., :3] + forward = vec[..., 3:6] + up = vec[..., 6:9] + right = vec[..., 9:12] + if isinstance(vec, torch.Tensor): + mat_world = torch.stack([right, up, forward, pos], dim=-1) + mat_world = torch.cat([mat_world, v], dim=-2) + elif isinstance(vec, np.ndarray): + mat_world = np.stack([right, up, forward, pos], axis=-1) + mat_world = np.concatenate([mat_world, v], axis=-2) + else: + raise ValueError + + mat_world = normalized_matrix(mat_world) + return mat_world + + +def rotmat2tan_norm(mat): + """_summary_ + + Args: + mat(tensor): [B, 3, 3] + + Returns: + vec (tensor): [B, 6], tan norm + """ + if isinstance(mat, np.ndarray): + tan = np.zeros_like(mat[..., 2]) + norm = np.zeros_like(mat[..., 0]) + elif isinstance(mat, torch.Tensor): + tan = torch.zeros_like(mat[..., 2]) + norm = torch.zeros_like(mat[..., 0]) + else: + raise ValueError + tan[...] = mat[..., 2, ::-1] + tan[..., -1] *= -1 + norm[...] = mat[..., 0, ::-1] + norm[..., -1] *= -1 + if isinstance(mat, np.ndarray): + tan_norm = np.concatenate((tan, norm), axis=-1) + elif isinstance(mat, torch.Tensor): + tan_norm = torch.cat((tan, norm), dim=-1) + else: + raise ValueError + return tan_norm + + +def mat2tan_norm(mat): + """_summary_ + + Args: + mat(tensor): [B, 4, 4] + + Returns: + vec (tensor): [B, 6], tan norm + """ + rot_mat = mat[..., :-1, :-1] + return rotmat2tan_norm(rot_mat) + + +def rotmat2tan_norm(mat): + """_summary_ + + Args: + mat(tensor): [B, 3, 3] + + Returns: + vec (tensor): [B, 6], tan norm + """ + if isinstance(mat, np.ndarray): + tan = np.zeros_like(mat[..., 2]) + norm = np.zeros_like(mat[..., 0]) + tan[...] = mat[..., 2, ::-1] + norm[...] = mat[..., 0, ::-1] + elif isinstance(mat, torch.Tensor): + tan = torch.zeros_like(mat[..., 2]) + norm = torch.zeros_like(mat[..., 0]) + tan[...] = torch.flip(mat[..., 2], dims=[-1]) + norm[...] 
= torch.flip(mat[..., 0], dims=[-1]) + else: + raise ValueError + tan[..., -1] *= -1 + norm[..., -1] *= -1 + if isinstance(mat, np.ndarray): + tan_norm = np.concatenate((tan, norm), axis=-1) + elif isinstance(mat, torch.Tensor): + tan_norm = torch.cat((tan, norm), dim=-1) + else: + raise ValueError + return tan_norm + + +def tan_norm2rotmat(tan_norm): + """_summary_ + + Args: + mat(tensor): [B, 6] + + Returns: + vec (tensor): [B, 3] + """ + tan = copy.deepcopy(tan_norm[..., :3]) + norm = copy.deepcopy(tan_norm[..., 3:]) + tan[..., -1] *= -1 + norm[..., -1] *= -1 + if isinstance(tan_norm, np.ndarray): + rotmat = np.zeros(tan_norm.shape[:-1] + (3, 3)) + tan = tan[..., ::-1] + norm = norm[..., ::-1] + other = np.cross(tan, norm) + elif isinstance(tan_norm, torch.Tensor): + rotmat = torch.zeros(tan_norm.shape[:-1] + (3, 3), device=tan_norm.device) + tan = torch.flip(tan, dims=[-1]) + norm = torch.flip(norm, dims=[-1]) + other = torch.cross(tan, norm) + else: + raise ValueError + rotmat[..., 2, :] = tan + rotmat[..., 0, :] = norm + rotmat[..., 1, :] = other + return rotmat + + +def rotmat332vec_batch(mat): + """_summary_ + + Args: + mat(tensor): [B, 3, 3] + + Returns: + vec (tensor): [B, 6], forward, up, right + """ + # Assume bs = 1 + mat = normalized_matrix(mat) + forward = mat[..., :, 2] + up = mat[..., :, 1] + right = mat[..., :, 0] + if isinstance(mat, torch.Tensor): + vec = torch.cat((forward, up, right), dim=-1) + elif isinstance(mat, np.ndarray): + vec = np.concatenate((forward, up, right), axis=-1) + else: + raise ValueError + return vec + + +def rotmat2vec_batch(mat): + """_summary_ + + Args: + mat(tensor): [B, 4, 4] + + Returns: + vec (tensor): [B, 9], forward, up, right + """ + # Assume bs = 1 + mat = normalized_matrix(mat) + forward = mat[..., :-1, 2] + up = mat[..., :-1, 1] + right = mat[..., :-1, 0] + if isinstance(mat, torch.Tensor): + vec = torch.cat((forward, up, right), dim=-1) + elif isinstance(mat, np.ndarray): + vec = np.concatenate((forward, up, right), axis=-1) + else: + raise ValueError + return vec + + +def mat2vec_batch(mat): + """_summary_ + + Args: + mat(tensor): [B, 4, 4] + + Returns: + vec (tensor): [B, 12], pos, forward, up and right + """ + # Assume bs = 1 + mat = normalized_matrix(mat) + pos = mat[..., :-1, 3] + forward = mat[..., :-1, 2] + up = mat[..., :-1, 1] + right = mat[..., :-1, 0] + if isinstance(mat, torch.Tensor): + vec = torch.cat((pos, forward, up, right), dim=-1) + elif isinstance(mat, np.ndarray): + vec = np.concatenate((pos, forward, up, right), axis=-1) + else: + raise ValueError + return vec + + +def mat2pose_batch(mat, returnvel=True): + """_summary_ + + Args: + mat(tensor): [B, 4, 4] + + Returns: + vec (tensor): [B, 12], pos, forward, up, zeros + """ + # Assume bs = 1 + mat = normalized_matrix(mat) + pos = mat[..., :-1, 3] + forward = mat[..., :-1, 2] + up = mat[..., :-1, 1] + if isinstance(mat, torch.Tensor): + if returnvel: + vel = torch.zeros_like(up) + vec = torch.cat((pos, forward, up, vel), dim=-1) + else: + vec = torch.cat((pos, forward, up), dim=-1) + elif isinstance(mat, np.ndarray): + if returnvel: + vel = np.zeros_like(up) + vec = np.concatenate((pos, forward, up, vel), axis=-1) + else: + vec = np.concatenate((pos, forward, up), axis=-1) + else: + raise ValueError + return vec + + +def get_mat_BinA(matCtoA, matCtoB): + """ + given matrix of the same object in two coordinate A and B, + return matrix B in the coordinate of A + + Args: + matCtoA (tensor): [4, 4] world matrix + matCtoB (tensor): [4, 4] world matrix + """ + if 
isinstance(matCtoA, torch.Tensor): + matCtoB_inv = torch.inverse(matCtoB) + elif isinstance(matCtoA, np.ndarray): + matCtoB_inv = np.linalg.inv(matCtoB) + else: + raise ValueError + matCtoB_inv = normalized_matrix(matCtoB_inv) + if isinstance(matCtoA, torch.Tensor): + mat_BtoA = torch.matmul(matCtoA, matCtoB_inv) + elif isinstance(matCtoA, np.ndarray): + mat_BtoA = np.matmul(matCtoA, matCtoB_inv) + mat_BtoA = normalized_matrix(mat_BtoA) + return mat_BtoA + + +def get_mat_BtoA(matA, matB): + """ + return matrix B in the coordinate of A + + Args: + matA (tensor): [4, 4] world matrix + matB (tensor): [4, 4] world matrix + """ + if isinstance(matA, torch.Tensor): + matA_inv = torch.inverse(matA) + elif isinstance(matA, np.ndarray): + matA_inv = np.linalg.inv(matA) + else: + raise ValueError + matA_inv = normalized_matrix(matA_inv) + if isinstance(matA, torch.Tensor): + mat_BtoA = torch.matmul(matA_inv, matB) + elif isinstance(matA, np.ndarray): + mat_BtoA = np.matmul(matA_inv, matB) + mat_BtoA = normalized_matrix(mat_BtoA) + return mat_BtoA + + +def get_mat_BfromA(matA, matBtoA): + """ + return world matrix B given matrix A and mat B realtive to A + + Args: + matA (_type_): [4, 4] world matrix + matBtoA (_type_): [4, 4] matrix B relative to A + """ + if isinstance(matA, torch.Tensor): + matB = torch.matmul(matA, matBtoA) + if isinstance(matA, np.ndarray): + matB = np.matmul(matA, matBtoA) + matB = normalized_matrix(matB) + return matB + + +def get_relative_position_to(pos, mat): + """_summary_ + + Args: + pos (_type_): [N, M, 3] or [N, 3] + mat (_type_): [N, 4, 4] or [4, 4] + + Returns: + _type_: _description_ + """ + if isinstance(mat, torch.Tensor): + mat_inv = torch.inverse(mat) + elif isinstance(mat, np.ndarray): + mat_inv = np.linalg.inv(mat) + else: + raise ValueError + mat_inv = normalized_matrix(mat_inv) + if isinstance(mat, torch.Tensor): + rot_pos = torch.matmul(mat_inv[..., :-1, :-1], pos.transpose(-1, -2)).transpose(-1, -2) + elif isinstance(mat, np.ndarray): + rot_pos = np.matmul(mat_inv[..., :-1, :-1], pos.swapaxes(-1, -2)).swapaxes(-1, -2) + world_pos = rot_pos + mat_inv[..., None, :-1, 3] + return world_pos + + +def get_rotation(mat): + """_summary_ + + Args: + mat (_type_): [..., 4, 4] + + Returns: + _type_: _description_ + """ + return mat[..., :-1, :-1] + + +def set_rotation(mat, rotmat): + """_summary_ + + Args: + mat (_type_): [..., 4, 4] + + Returns: + _type_: _description_ + """ + mat[..., :-1, :-1] = rotmat + return mat + + +def set_position(mat, pos): + """_summary_ + + Args: + mat (_type_): [..., 4, 4] + + Returns: + _type_: _description_ + """ + mat[..., :-1, 3] = pos + return mat + + +def get_position(mat): + """_summary_ + + Args: + mat (_type_): [..., 4, 4] + + Returns: + _type_: _description_ + """ + return mat[..., :-1, 3] + + +def get_position_from(pos, mat): + """_summary_ + + Args: + pos (_type_): [N, M, 3] or [N, 3] + mat (_type_): [N, 4, 4] or [4, 4] + + Returns: + _type_: _description_ + """ + if isinstance(mat, torch.Tensor): + rot_pos = torch.matmul(mat[..., :-1, :-1], pos.transpose(-1, -2)).transpose(-1, -2) + elif isinstance(mat, np.ndarray): + rot_pos = np.matmul(mat[..., :-1, :-1], pos.swapaxes(-1, -2)).swapaxes(-1, -2) + else: + raise ValueError + + world_pos = rot_pos + mat[..., None, :-1, 3] + return world_pos + + +def get_position_from_rotmat(pos, mat): + """_summary_ + + Args: + pos (_type_): [N, M, 3] or [N, 3] + mat (_type_): [N, 4, 4] or [4, 4] + + Returns: + _type_: _description_ + """ + if isinstance(mat, torch.Tensor): + rot_pos = 
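get_mat_BtoA and get_mat_BfromA above are inverses of each other: the first expresses a world pose B in the frame of A, the second maps that relative pose back to a world pose. A standalone NumPy sketch of the same algebra with two made-up yaw-and-translate poses:

import numpy as np

def yaw_trs(yaw_deg, t):
    # Build a 4x4 pose from a rotation about +y and a translation (helper for this example only).
    c, s = np.cos(np.radians(yaw_deg)), np.sin(np.radians(yaw_deg))
    m = np.eye(4)
    m[:3, :3] = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
    m[:3, 3] = t
    return m

A = yaw_trs(90, [1.0, 0.0, 0.0])   # world pose of frame A
B = yaw_trs(30, [0.0, 0.0, 2.0])   # world pose of frame B

B_in_A = np.linalg.inv(A) @ B      # what get_mat_BtoA(A, B) computes (up to re-normalization)
assert np.allclose(A @ B_in_A, B)  # get_mat_BfromA(A, B_in_A) recovers B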
torch.matmul(mat, pos.transpose(-1, -2)).transpose(-1, -2) + elif isinstance(mat, np.ndarray): + rot_pos = np.matmul(mat, pos.swapaxes(-1, -2)).swapaxes(-1, -2) + else: + raise ValueError + return rot_pos + + +def get_relative_direction_to(dir, mat): + """_summary_ + + Args: + dir (_type_): [N, M, 3] or [N, 3] + mat (_type_): [N, 4, 4] or [4, 4] + + Returns: + _type_: _description_ + """ + if isinstance(mat, torch.Tensor): + mat_inv = torch.inverse(mat) + elif isinstance(mat, np.ndarray): + mat_inv = np.linalg.inv(mat) + else: + raise ValueError + mat_inv = normalized_matrix(mat_inv) + rot_mat_inv = mat_inv[..., :3, :3] + if isinstance(mat, torch.Tensor): + rel_dir = torch.matmul(rot_mat_inv, dir.transpose(-1, -2)) + return rel_dir.transpose(-1, -2) + elif isinstance(mat, np.ndarray): + rel_dir = np.matmul(rot_mat_inv, dir.swapaxes(-1, -2)) + return rel_dir.swapaxes(-1, -2) + else: + raise ValueError + return + + +def get_direction_from(dir, mat): + """_summary_ + + Args: + dir (_type_): [N, M, 3] or [N, 3] + mat (_type_): [N, 4, 4] or [4, 4] + + Returns: + tensor: [N, M, 3] or [N, 3] + """ + rot_mat = mat[..., :3, :3] + if isinstance(mat, torch.Tensor): + world_dir = torch.matmul(rot_mat, dir.transpose(-1, -2)) + return world_dir.transpose(-1, -2) + elif isinstance(mat, np.ndarray): + world_dir = np.matmul(rot_mat, dir.swapaxes(-1, -2)) + return world_dir.swapaxes(-1, -2) + else: + raise ValueError + return + + +def get_coord_vis(pos, rot_mat, scale=1.0): + forward = rot_mat[..., :, 2] + up = rot_mat[..., :, 1] + right = rot_mat[..., :, 0] + return pos + right * scale, pos + up * scale, pos + forward * scale + + +def project_vec(vec): + """_summary_ + + Args: + vec (tensor): [*, 12], pos, forward, up and right + + Returns: + proj_vec (tensor): [*, 4], posx, posz, forwardx, forwardz + """ + posx = vec[..., 0:1] + posz = vec[..., 2:3] + forwardx = vec[..., 3:4] + forwardz = vec[..., 5:6] + if isinstance(vec, torch.Tensor): + proj_vec = torch.cat((posx, posz, forwardx, forwardz), dim=-1) + elif isinstance(vec, np.ndarray): + proj_vec = np.concatenate((posx, posz, forwardx, forwardz), axis=-1) + else: + raise ValueError + + return proj_vec + + +def xz2xyz(vec): + x = vec[..., 0:1] + z = vec[..., 1:2] + if isinstance(vec, torch.Tensor): + y = torch.zeros(vec.shape[:-1] + (1,), device=vec.device) + xyz_vec = torch.cat((x, y, z), dim=-1) + elif isinstance(vec, np.ndarray): + y = np.zeros(vec.shape[:-1] + (1,)) + xyz_vec = np.concatenate((x, y, z), axis=-1) + else: + raise ValueError + + return xyz_vec + + +def normalized(vec): + if isinstance(vec, torch.Tensor): + norm_vec = vec / (vec.norm(2, dim=-1, keepdim=True) + 1e-9) + elif isinstance(vec, np.ndarray): + norm_vec = vec / (np.linalg.norm(vec, ord=2, axis=-1, keepdims=True) + 1e-9) + else: + raise ValueError + + return norm_vec + + +def normalized_matrix(mat): + if mat.shape[-1] == 4: + rot_mat = mat[..., :-1, :-1] + else: + rot_mat = mat + if isinstance(mat, torch.Tensor): + rot_mat_norm = rot_mat / (rot_mat.norm(2, dim=-2, keepdim=True) + 1e-9) + norm_mat = torch.zeros_like(mat) + elif isinstance(mat, np.ndarray): + rot_mat_norm = rot_mat / (np.linalg.norm(rot_mat, ord=2, axis=-2, keepdims=True) + 1e-9) + norm_mat = np.zeros_like(mat) + else: + raise ValueError + if mat.shape[-1] == 4: + norm_mat[..., :-1, :-1] = rot_mat_norm + norm_mat[..., :-1, -1] = mat[..., :-1, -1] + norm_mat[..., -1, -1] = 1.0 + else: + norm_mat = rot_mat_norm + return norm_mat + + +def get_rot_mat_from_forward(forward): + """_summary_ + + Args: + forward (tensor): 
[N, M, 3] + + Returns: + mat (tensor): [N, M, 3, 3] + """ + if isinstance(forward, torch.Tensor): + mat = torch.eye(3, device=forward.device).repeat(forward.shape[:-1] + (1, 1)) + right = torch.zeros_like(forward) + elif isinstance(forward, np.ndarray): + mat = np.eye(3, dtype=np.float32) + for _ in range(len(forward.shape) - 1): + mat = mat[None] + mat = np.tile(mat, forward.shape[:-1] + (1, 1)) + right = np.zeros_like(forward) + else: + raise ValueError + + right[..., 0] = forward[..., 2] + right[..., 1] = 0.0 + right[..., 2] = -forward[..., 0] + # right = torch.cross(mat[..., 1], forward) # cannot backward + + mat[..., 2] = normalized(forward) + right = normalized(right) + mat[..., 0] = right + return mat + + +def get_rot_mat_from_forward_up(forward, up): + """_summary_ + + Args: + forward (tensor): [N, M, 3] + up (tensor): [N, M, 3] + + Returns: + mat (tensor): [N, M, 3, 3] + """ + if isinstance(forward, torch.Tensor): + mat = torch.eye(3, device=forward.device).repeat(forward.shape[:-1] + (1, 1)) + right = torch.cross(up, forward) + elif isinstance(forward, np.ndarray): + mat = np.eye(3, dtype=np.float32) + for _ in range(len(forward.shape) - 1): + mat = mat[None] + mat = np.tile(mat, forward.shape[:-1] + (1, 1)) + right = np.cross(up, forward) + else: + raise ValueError + + right = normalized(right) + mat[..., 2] = normalized(forward) + mat[..., 1] = normalized(up) + mat[..., 0] = right + return mat + + +def get_rot_mat_from_pose_vec(vec): + """_summary_ + + Args: + vec (tensor): [N, M, 6] + + Returns: + mat (tensor): [N, M, 3, 3] + """ + forward = vec[..., :3] + up = vec[..., 3:6] + return get_rot_mat_from_forward_up(forward, up) + + +def get_TRS(rot_mat, pos): + """_summary_ + + Args: + rot_mat (tensor): [N, 3, 3] + pos (tensor): [N, 3] + + Returns: + mat (tensor): [N, 4, 4] + """ + if isinstance(rot_mat, torch.Tensor): + mat = torch.eye(4, device=pos.device).repeat(pos.shape[:-1] + (1, 1)) + elif isinstance(rot_mat, np.ndarray): + mat = np.eye(4, dtype=np.float32) + for _ in range(len(pos.shape) - 1): + mat = mat[None] + mat = np.tile(mat, pos.shape[:-1] + (1, 1)) + else: + raise ValueError + mat[..., :3, :3] = rot_mat + mat[..., :3, 3] = pos + mat = normalized_matrix(mat) + return mat + + +def xzvec2mat(vec): + """_summary_ + + Args: + vec (tensor): [N, 4] + + Returns: + mat (tensor): [N, 4, 4] + """ + vec_shape = vec.shape[:-1] + if isinstance(vec, torch.Tensor): + pos = torch.zeros(vec_shape + (3,)) + forward = torch.zeros(vec_shape + (3,)) + elif isinstance(vec, np.ndarray): + pos = np.zeros(vec_shape + (3,)) + forward = np.zeros(vec_shape + (3,)) + else: + raise ValueError + + pos[..., 0] = vec[..., 0] + pos[..., 2] = vec[..., 1] + forward[..., 0] = vec[..., 2] + forward[..., 2] = vec[..., 3] + rot_mat = get_rot_mat_from_forward(forward) + mat = get_TRS(rot_mat, pos) + return mat + + +def distance(vec1, vec2): + return ((vec1 - vec2) ** 2).sum() ** 0.5 + + +def get_relative_pose_from_vec(pose, root, N): + root_p_mat = xzvec2mat(root) + pose = pose.reshape(-1, N, 12) + pose[..., :3] = get_position_from(pose[..., :3], root_p_mat) + pose[..., 3:6] = get_direction_from(pose[..., 3:6], root_p_mat) + pose[..., 6:9] = get_direction_from(pose[..., 6:9], root_p_mat) + pose[..., 9:] = get_direction_from(pose[..., 9:], root_p_mat) + pos = pose[..., 0, :3] + rot = pose[..., 3:9].reshape(-1, N * 6) + pose = np.concatenate((pos, rot), axis=-1) + return pose + + +def get_forward_from_pos(pos): + """_summary_ + + Args: + pos (N, J, 3): joints positions of each frame + + Returns: + _type_: 
_description_ + """ + + pos_y_vec = torch.tensor([0, 1, 0], dtype=torch.float32).to(pos.device) + face_joint_indx = [2, 1, 17, 16] + r_hip, l_hip, r_sdr, l_sdr = face_joint_indx # use hip and shoulder to get the cross vector + cross_hip = pos[..., 0, r_hip, :] - pos[..., 0, l_hip, :] + cross_sdr = pos[..., 0, r_sdr, :] - pos[..., 0, l_sdr, :] + cross_vec = cross_hip + cross_sdr # (3, ) + forward_vec = torch.cross(pos_y_vec, cross_vec, dim=-1) + forward_vec = normalized(forward_vec) + return forward_vec + + +def project_point_along_ray(p, ray, keepnorm=False): + """_summary_ + + Args: + p (*, 3): point positions + ray (*, 3): ray direction + keepnorm: False -> project point on the ray, + True -> project point on the ray and keep the point length + + Returns: + _type_: _description_ + """ + ray = normalized(ray) + if keepnorm: + new_p = ray * p.norm(dim=-1, keepdim=True) + else: + dot_product = torch.sum(p * ray, dim=-1, keepdim=True) + new_p = dot_product * ray + return new_p + + +def solve_point_along_ray_with_constraint(c, ray, p, constraint="x"): + """_summary_ + + Args: + c (*,): constraint value + ray (*, 3): ray direction + p (*, 3): start point of the ray + + Returns: + _type_: _description_ + """ + ray = normalized(ray) + if constraint == "x": + ind = 0 + elif constraint == "y": + ind = 1 + elif constraint == "z": + ind = 2 + else: + raise ValueError + t = (c - p[..., ind]) / ray[..., ind] + out_p = ray * t[..., None] + p + + return out_p + + +def calc_cosine(vec1, vec2, return_angle=False): + """_summary_ + + Args: + vec1 (*, 3): vector + vec2 (*, 3): vector + return_angle: True -> return angle, False -> return cosine + + Returns: + _type_: _description_ + """ + vec1 = normalized(vec1) + vec2 = normalized(vec2) + cosine = torch.sum(vec1 * vec2, dim=-1) + if return_angle: + return torch.acos(cosine) + return cosine + + +############################################ +# +# quaternion assumes xyzw +# +############################################ + + +def quat_xyzw2wxyz(quat): + new_quat = torch.cat([quat[..., 3:4], quat[..., :3]], dim=-1) + return new_quat + + +def quat_wxyz2xyzw(quat): + new_quat = torch.cat([quat[..., 1:4], quat[..., :1]], dim=-1) + return new_quat + + +def quat_mul(a, b): + """ + quaternion multiplication + """ + x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] + x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] + + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + + return torch.stack([x, y, z, w], dim=-1) + + +def quat_pos(x): + """ + make all the real part of the quaternion positive + """ + q = x + z = (q[..., 3:] < 0).float() + q = (1 - 2 * z) * q + return q + + +def quat_abs(x): + """ + quaternion norm (unit quaternion represents a 3D rotation, which has norm of 1) + """ + x = x.norm(p=2, dim=-1) + return x + + +def quat_unit(x): + """ + normalized quaternion with norm of 1 + """ + norm = quat_abs(x).unsqueeze(-1) + return x / (norm.clamp(min=1e-4)) + + +def quat_conjugate(x): + """ + quaternion with its imaginary part negated + """ + return torch.cat([-x[..., :3], x[..., 3:]], dim=-1) + + +def quat_real(x): + """ + real component of the quaternion + """ + return x[..., 3] + + +def quat_imaginary(x): + """ + imaginary components of the quaternion + """ + return x[..., :3] + + +def quat_norm_check(x): + """ + verify that a quaternion has norm 1 + """ + assert bool((abs(x.norm(p=2, dim=-1) - 1) < 1e-3).all()), "the 
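The quaternion helpers in this file use (x, y, z, w) ordering, as the header comment notes. A self-contained check of the q * (v, 0) * conj(q) rule that quat_rotate implements, using a 90-degree rotation about +y (values chosen only for illustration):

import math
import torch

q = torch.tensor([0.0, math.sin(math.pi / 4), 0.0, math.cos(math.pi / 4)])  # 90 deg about +y, xyzw
v = torch.tensor([1.0, 0.0, 0.0])

def qmul(a, b):
    # Same component formulas as quat_mul above, for two single quaternions.
    x1, y1, z1, w1 = a
    x2, y2, z2, w2 = b
    return torch.stack([
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2,
        w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    ])

v_quat = torch.cat([v, torch.zeros(1)])      # embed the vector as a pure quaternion
q_conj = torch.cat([-q[:3], q[3:]])          # conjugate negates the imaginary part
print(qmul(qmul(q, v_quat), q_conj)[:3])     # ~(0, 0, -1): +x rotated 90 deg about +y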
quaternion is has non-1 norm: {}".format( + abs(x.norm(p=2, dim=-1) - 1) + ) + assert bool((x[..., 3] >= 0).all()), "the quaternion has negative real part" + + +def quat_normalize(q): + """ + Construct 3D rotation from quaternion (the quaternion needs not to be normalized). + """ + q = quat_unit(quat_pos(q)) # normalized to positive and unit quaternion + return q + + +def quat_from_xyz(xyz): + """ + Construct 3D rotation from the imaginary component + """ + w = (1.0 - xyz.norm()).unsqueeze(-1) + assert bool((w >= 0).all()), "xyz has its norm greater than 1" + return torch.cat([xyz, w], dim=-1) + + +def quat_identity(shape: List[int]): + """ + Construct 3D identity rotation given shape + """ + w = torch.ones(shape + (1,)) + xyz = torch.zeros(shape + (3,)) + q = torch.cat([xyz, w], dim=-1) + return quat_normalize(q) + + +def tgm_quat_from_angle_axis(angle, axis, degree: bool = False): + """Create a 3D rotation from angle and axis of rotation. The rotation is counter-clockwise + along the axis. + + The rotation can be interpreted as a_R_b where frame "b" is the new frame that + gets rotated counter-clockwise along the axis from frame "a" + + :param angle: angle of rotation + :type angle: Tensor + :param axis: axis of rotation + :type axis: Tensor + :param degree: put True here if the angle is given by degree + :type degree: bool, optional, default=False + """ + if degree: + angle = angle / 180.0 * math.pi + theta = (angle / 2).unsqueeze(-1) + axis = axis / (axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4)) + xyz = axis * theta.sin() + w = theta.cos() + return quat_normalize(torch.cat([w, xyz], dim=-1)) + + +def quat_from_rotation_matrix(m): + """ + Construct a 3D rotation from a valid 3x3 rotation matrices. + Reference can be found here: + http://www.cg.info.hiroshima-cu.ac.jp/~miyazaki/knowledge/teche52.html + + :param m: 3x3 orthogonal rotation matrices. + :type m: Tensor + + :rtype: Tensor + """ + m = m.unsqueeze(0) + diag0 = m[..., 0, 0] + diag1 = m[..., 1, 1] + diag2 = m[..., 2, 2] + + # Math stuff. + w = (((diag0 + diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5 + x = (((diag0 - diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5 + y = (((-diag0 + diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5 + z = (((-diag0 - diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5 + + # Only modify quaternions where w > x, y, z. + c0 = (w >= x) & (w >= y) & (w >= z) + x[c0] *= (m[..., 2, 1][c0] - m[..., 1, 2][c0]).sign() + y[c0] *= (m[..., 0, 2][c0] - m[..., 2, 0][c0]).sign() + z[c0] *= (m[..., 1, 0][c0] - m[..., 0, 1][c0]).sign() + + # Only modify quaternions where x > w, y, z + c1 = (x >= w) & (x >= y) & (x >= z) + w[c1] *= (m[..., 2, 1][c1] - m[..., 1, 2][c1]).sign() + y[c1] *= (m[..., 1, 0][c1] + m[..., 0, 1][c1]).sign() + z[c1] *= (m[..., 0, 2][c1] + m[..., 2, 0][c1]).sign() + + # Only modify quaternions where y > w, x, z. + c2 = (y >= w) & (y >= x) & (y >= z) + w[c2] *= (m[..., 0, 2][c2] - m[..., 2, 0][c2]).sign() + x[c2] *= (m[..., 1, 0][c2] + m[..., 0, 1][c2]).sign() + z[c2] *= (m[..., 2, 1][c2] + m[..., 1, 2][c2]).sign() + + # Only modify quaternions where z > w, x, y. + c3 = (z >= w) & (z >= x) & (z >= y) + w[c3] *= (m[..., 1, 0][c3] - m[..., 0, 1][c3]).sign() + x[c3] *= (m[..., 2, 0][c3] + m[..., 0, 2][c3]).sign() + y[c3] *= (m[..., 2, 1][c3] + m[..., 1, 2][c3]).sign() + + return quat_normalize(torch.stack([x, y, z, w], dim=-1)).squeeze(0) + + +def quat_mul_norm(x, y): + """ + Combine two set of 3D rotations together using \**\* operator. 
The shape needs to be + broadcastable + """ + return quat_normalize(quat_mul(x, y)) + + +def quat_rotate(rot, vec): + """ + Rotate a 3D vector with the 3D rotation + """ + other_q = torch.cat([vec, torch.zeros_like(vec[..., :1])], dim=-1) + return quat_imaginary(quat_mul(quat_mul(rot, other_q), quat_conjugate(rot))) + + +def quat_inverse(x): + """ + The inverse of the rotation + """ + return quat_conjugate(x) + + +def quat_identity_like(x): + """ + Construct identity 3D rotation with the same shape + """ + return quat_identity(x.shape[:-1]) + + +def quat_angle_axis(x): + """ + The (angle, axis) representation of the rotation. The axis is normalized to unit length. + The angle is guaranteed to be between [0, pi]. + """ + s = 2 * (x[..., 3] ** 2) - 1 + angle = s.clamp(-1, 1).arccos() # just to be safe + axis = x[..., :3] + axis /= axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4) + return angle, axis + + +def quat_yaw_rotation(x, z_up: bool = True): + """ + Yaw rotation (rotation along z-axis) + """ + q = x + if z_up: + q = torch.cat([torch.zeros_like(q[..., 0:2]), q[..., 2:3], q[..., 3:]], dim=-1) + else: + q = torch.cat( + [ + torch.zeros_like(q[..., 0:1]), + q[..., 1:2], + torch.zeros_like(q[..., 2:3]), + q[..., 3:4], + ], + dim=-1, + ) + return quat_normalize(q) + + +def transform_from_rotation_translation(r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None): + """ + Construct a transform from a quaternion and 3D translation. Only one of them can be None. + """ + assert r is not None or t is not None, "rotation and translation can't be all None" + if r is None: + assert t is not None + r = quat_identity(list(t.shape)) + if t is None: + t = torch.zeros(list(r.shape) + [3]) + return torch.cat([r, t], dim=-1) + + +def transform_identity(shape: List[int]): + """ + Identity transformation with given shape + """ + r = quat_identity(shape) + t = torch.zeros(shape + [3]) + return transform_from_rotation_translation(r, t) + + +def transform_rotation(x): + """Get rotation from transform""" + return x[..., :4] + + +def transform_translation(x): + """Get translation from transform""" + return x[..., 4:] + + +def transform_inverse(x): + """ + Inverse transformation + """ + inv_so3 = quat_inverse(transform_rotation(x)) + return transform_from_rotation_translation(r=inv_so3, t=quat_rotate(inv_so3, -transform_translation(x))) + + +def transform_identity_like(x): + """ + identity transformation with the same shape + """ + return transform_identity(x.shape) + + +def transform_mul(x, y): + """ + Combine two transformation together + """ + z = transform_from_rotation_translation( + r=quat_mul_norm(transform_rotation(x), transform_rotation(y)), + t=quat_rotate(transform_rotation(x), transform_translation(y)) + transform_translation(x), + ) + return z + + +def transform_apply(rot, vec): + """ + Transform a 3D vector + """ + assert isinstance(vec, torch.Tensor) + return quat_rotate(transform_rotation(rot), vec) + transform_translation(rot) + + +def rot_matrix_det(x): + """ + Return the determinant of the 3x3 matrix. 
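The transform_* helpers pack a rigid transform as a 7-vector: an (x, y, z, w) quaternion followed by a 3-D translation. A short sketch composing and applying two pure translations this way (assumes the hmr4d package is importable; the numbers are arbitrary):

import torch
from hmr4d.utils.matrix import transform_from_rotation_translation, transform_mul, transform_apply

q_id = torch.tensor([[0.0, 0.0, 0.0, 1.0]])  # identity rotation in xyzw
t1 = transform_from_rotation_translation(r=q_id, t=torch.tensor([[1.0, 0.0, 0.0]]))
t2 = transform_from_rotation_translation(r=q_id, t=torch.tensor([[0.0, 2.0, 0.0]]))
t12 = transform_mul(t1, t2)                                    # apply t2 first, then t1
print(transform_apply(t12, torch.tensor([[0.0, 0.0, 0.0]])))   # ~[[1., 2., 0.]]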
The shape of the tensor will be as same as the + shape of the matrix + """ + a, b, c = x[..., 0, 0], x[..., 0, 1], x[..., 0, 2] + d, e, f = x[..., 1, 0], x[..., 1, 1], x[..., 1, 2] + g, h, i = x[..., 2, 0], x[..., 2, 1], x[..., 2, 2] + t1 = a * (e * i - f * h) + t2 = b * (d * i - f * g) + t3 = c * (d * h - e * g) + return t1 - t2 + t3 + + +def rot_matrix_integrity_check(x): + """ + Verify that a rotation matrix has a determinant of one and is orthogonal + """ + det = rot_matrix_det(x) + assert bool((abs(det - 1) < 1e-3).all()), "the matrix has non-one determinant" + rtr = x @ x.permute(torch.arange(x.dim() - 2), -1, -2) + rtr_gt = rtr.zeros_like() + rtr_gt[..., 0, 0] = 1 + rtr_gt[..., 1, 1] = 1 + rtr_gt[..., 2, 2] = 1 + assert bool(((rtr - rtr_gt) < 1e-3).all()), "the matrix is not orthogonal" + + +def rot_matrix_from_quaternion(q): + """ + Construct rotation matrix from quaternion + """ + # Shortcuts for individual elements (using wikipedia's convention) + qi, qj, qk, qr = q[..., 0], q[..., 1], q[..., 2], q[..., 3] + + # Set individual elements + R00 = 1.0 - 2.0 * (qj**2 + qk**2) + R01 = 2 * (qi * qj - qk * qr) + R02 = 2 * (qi * qk + qj * qr) + R10 = 2 * (qi * qj + qk * qr) + R11 = 1.0 - 2.0 * (qi**2 + qk**2) + R12 = 2 * (qj * qk - qi * qr) + R20 = 2 * (qi * qk - qj * qr) + R21 = 2 * (qj * qk + qi * qr) + R22 = 1.0 - 2.0 * (qi**2 + qj**2) + + R0 = torch.stack([R00, R01, R02], dim=-1) + R1 = torch.stack([R10, R11, R12], dim=-1) + R2 = torch.stack([R20, R21, R22], dim=-1) + + R = torch.stack([R0, R1, R2], dim=-2) + + return R + + +def euclidean_to_rotation_matrix(x): + """ + Get the rotation matrix on the top-left corner of a Euclidean transformation matrix + """ + return x[..., :3, :3] + + +def euclidean_integrity_check(x): + euclidean_to_rotation_matrix(x) # check 3d-rotation matrix + assert bool((x[..., 3, :3] == 0).all()), "the last row is illegal" + assert bool((x[..., 3, 3] == 1).all()), "the last row is illegal" + + +def euclidean_translation(x): + """ + Get the translation vector located at the last column of the matrix + """ + return x[..., :3, 3] + + +def euclidean_inverse(x): + """ + Compute the matrix that represents the inverse rotation + """ + s = x.zeros_like() + irot = quat_inverse(quat_from_rotation_matrix(x)) + s[..., :3, :3] = irot + s[..., :3, 4] = quat_rotate(irot, -euclidean_translation(x)) + return s + + +def euclidean_to_transform(transformation_matrix): + """ + Construct a transform from a Euclidean transformation matrix + """ + return transform_from_rotation_translation( + r=quat_from_rotation_matrix(m=euclidean_to_rotation_matrix(transformation_matrix)), + t=euclidean_translation(transformation_matrix), + ) + + +def to_torch(x, dtype=torch.float, device="cuda:0", requires_grad=False): + return torch.tensor(x, dtype=dtype, device=device, requires_grad=requires_grad) + + +def quat_mul(a, b): + assert a.shape == b.shape + shape = a.shape + a = a.reshape(-1, 4) + b = b.reshape(-1, 4) + + x1, y1, z1, w1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3] + x2, y2, z2, w2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3] + ww = (z1 + x1) * (x2 + y2) + yy = (w1 - y1) * (w2 + z2) + zz = (w1 + y1) * (w2 - z2) + xx = ww + yy + zz + qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) + w = qq - ww + (z1 - y1) * (y2 - z2) + x = qq - xx + (x1 + w1) * (x2 + w2) + y = qq - yy + (w1 - x1) * (y2 + z2) + z = qq - zz + (z1 + y1) * (w2 - x2) + + quat = torch.stack([x, y, z, w], dim=-1).view(shape) + + return quat + + +def normalize(x, eps: float = 1e-9): + return x / x.norm(p=2, dim=-1).clamp(min=eps, 
max=None).unsqueeze(-1) + + +def quat_apply(a, b): + shape = b.shape + a = a.reshape(-1, 4) + b = b.reshape(-1, 3) + xyz = a[:, :3] + t = xyz.cross(b, dim=-1) * 2 + return (b + a[:, 3:] * t + xyz.cross(t, dim=-1)).view(shape) + + +def quat_rotate(q, v): + shape = q.shape + q_w = q[:, -1] + q_vec = q[:, :3] + a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) + b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 + c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0 + return a + b + c + + +def quat_rotate_inverse(q, v): + shape = q.shape + q_w = q[:, -1] + q_vec = q[:, :3] + a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) + b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 + c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0 + return a - b + c + + +def quat_conjugate(a): + shape = a.shape + a = a.reshape(-1, 4) + return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape) + + +def quat_unit(a): + return normalize(a) + + +def quat_from_angle_axis(angle, axis): + theta = (angle / 2).unsqueeze(-1) + xyz = normalize(axis) * torch.sin(theta.clone()) + w = torch.cos(theta.clone()) + return quat_unit(torch.cat([xyz, w], dim=-1)) + + +def normalize_angle(x): + return torch.atan2(torch.sin(x.clone()), torch.cos(x.clone())) + + +def tf_inverse(q, t): + q_inv = quat_conjugate(q) + return q_inv, -quat_apply(q_inv, t) + + +def tf_apply(q, t, v): + return quat_apply(q, v) + t + + +def tf_vector(q, v): + return quat_apply(q, v) + + +def tf_combine(q1, t1, q2, t2): + return quat_mul(q1, q2), quat_apply(q1, t2) + t1 + + +def get_basis_vector(q, v): + return quat_rotate(q, v) + + +def get_axis_params(value, axis_idx, x_value=0.0, dtype=float, n_dims=3): + """construct arguments to `Vec` according to axis index.""" + zs = np.zeros((n_dims,)) + assert axis_idx < n_dims, "the axis dim should be within the vector dimensions" + zs[axis_idx] = 1.0 + params = np.where(zs == 1.0, value, zs) + params[0] = x_value + return list(params.astype(dtype)) + + +def copysign(a, b): + # type: (float, Tensor) -> Tensor + a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0]) + return torch.abs(a) * torch.sign(b) + + +def get_euler_xyz(q): + qx, qy, qz, qw = 0, 1, 2, 3 + # roll (x-axis rotation) + sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz]) + cosr_cosp = q[:, qw] * q[:, qw] - q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz] + roll = torch.atan2(sinr_cosp, cosr_cosp) + + # pitch (y-axis rotation) + sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) + pitch = torch.where(torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)) + + # yaw (z-axis rotation) + siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) + cosy_cosp = q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz] + yaw = torch.atan2(siny_cosp, cosy_cosp) + + return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi) + + +def quat_from_euler_xyz(roll, pitch, yaw): + cy = torch.cos(yaw * 0.5) + sy = torch.sin(yaw * 0.5) + cr = torch.cos(roll * 0.5) + sr = torch.sin(roll * 0.5) + cp = torch.cos(pitch * 0.5) + sp = torch.sin(pitch * 0.5) + + qw = cy * cr * cp + sy * sr * sp + qx = cy * sr * cp - sy * cr * sp + qy = cy * cr * sp + sy * sr * cp + qz = sy * cr * cp - cy * sr * sp + + return torch.stack([qx, qy, qz, qw], dim=-1) + + +def torch_rand_float(lower, upper, shape, device): + # type: (float, float, Tuple[int, int], str) -> Tensor + return (upper - lower) * 
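get_euler_xyz and quat_from_euler_xyz are written as a consistent pair, so for small positive angles the round trip recovers the inputs (the returned angles are wrapped to [0, 2*pi)). A quick sketch, assuming hmr4d is importable:

import torch
from hmr4d.utils.matrix import quat_from_euler_xyz, get_euler_xyz

roll, pitch, yaw = torch.tensor([0.1]), torch.tensor([0.2]), torch.tensor([0.3])
q = quat_from_euler_xyz(roll, pitch, yaw)   # (1, 4) quaternion in xyzw
print(get_euler_xyz(q))                     # approximately (0.1, 0.2, 0.3)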
torch.rand(*shape, device=device) + lower + + +def torch_random_dir_2(shape, device): + # type: (Tuple[int, int], str) -> Tensor + angle = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1) + return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1) + + +def tensor_clamp(t, min_t, max_t): + return torch.max(torch.min(t, max_t), min_t) + + +def scale(x, lower, upper): + return 0.5 * (x + 1.0) * (upper - lower) + lower + + +def unscale(x, lower, upper): + return (2.0 * x - upper - lower) / (upper - lower) + + +def unscale_np(x, lower, upper): + return (2.0 * x - upper - lower) / (upper - lower) + + +def quat_to_angle_axis(q): + # type: (Tensor) -> Tuple[Tensor, Tensor] + # computes axis-angle representation from quaternion q + # q must be normalized + min_theta = 1e-5 + qx, qy, qz, qw = 0, 1, 2, 3 + + sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw]) + angle = 2 * torch.acos(q[..., qw]) + angle = normalize_angle(angle) + sin_theta_expand = sin_theta.unsqueeze(-1) + axis = q[..., qx:qw] / sin_theta_expand + + mask = torch.abs(sin_theta) > min_theta + default_axis = torch.zeros_like(axis) + default_axis[..., -1] = 1 + + angle = torch.where(mask, angle, torch.zeros_like(angle)) + mask_expand = mask.unsqueeze(-1) + axis = torch.where(mask_expand, axis, default_axis) + return angle, axis + + +def angle_axis_to_exp_map(angle, axis): + # type: (Tensor, Tensor) -> Tensor + # compute exponential map from axis-angle + angle_expand = angle.unsqueeze(-1) + exp_map = angle_expand * axis + return exp_map + + +def quat_to_exp_map(q): + # type: (Tensor) -> Tensor + # compute exponential map from quaternion + # q must be normalized + angle, axis = quat_to_angle_axis(q) + exp_map = angle_axis_to_exp_map(angle, axis) + return exp_map + + +def quat_to_tan_norm(q): + # type: (Tensor) -> Tensor + # represents a rotation using the tangent and normal vectors + ref_tan = torch.zeros_like(q[..., 0:3]) + ref_tan[..., 0] = 1 + tan = quat_rotate(q, ref_tan) + + ref_norm = torch.zeros_like(q[..., 0:3]) + ref_norm[..., -1] = 1 + norm = quat_rotate(q, ref_norm) + + norm_tan = torch.cat([tan, norm], dim=len(tan.shape) - 1) + return norm_tan + + +def euler_xyz_to_exp_map(roll, pitch, yaw): + # type: (Tensor, Tensor, Tensor) -> Tensor + q = quat_from_euler_xyz(roll, pitch, yaw) + exp_map = quat_to_exp_map(q) + return exp_map + + +def exp_map_to_angle_axis(exp_map): + min_theta = 1e-5 + + angle = torch.norm(exp_map.clone(), dim=-1) + 1e-6 + angle_exp = torch.unsqueeze(angle, dim=-1) + axis = exp_map.clone() / angle_exp.clone() + angle = normalize_angle(angle) + + default_axis = torch.zeros_like(exp_map) + default_axis[..., -1] = 1 + + mask = torch.abs(angle) > min_theta + angle = torch.where(mask, angle, torch.zeros_like(angle)) + mask_expand = mask.unsqueeze(-1) + axis = torch.where(mask_expand, axis, default_axis) + + return angle, axis + + +def exp_map_to_quat(exp_map): + angle, axis = exp_map_to_angle_axis(exp_map) + q = quat_from_angle_axis(angle, axis) + return q + + +def slerp(q0, q1, t): + # type: (Tensor, Tensor, Tensor) -> Tensor + cos_half_theta = torch.sum(q0 * q1, dim=-1) + + neg_mask = cos_half_theta < 0 + q1 = q1.clone() + q1[neg_mask] = -q1[neg_mask] + cos_half_theta = torch.abs(cos_half_theta) + cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1) + + half_theta = torch.acos(cos_half_theta) + sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta) + + ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta + ratioB = torch.sin(t * half_theta) / sin_half_theta + + new_q = 
ratioA * q0 + ratioB * q1 + + new_q = torch.where(torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q) + new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q) + + return new_q + + +def calc_heading_vec(q, head_ind=0): + # type: (Tensor, int) -> Tensor + # calculate heading direction from quaternion + # the heading is the direction vector + # q must be normalized + ref_dir = torch.zeros_like(q[..., 0:3]) + ref_dir[..., head_ind] = 1 + rot_dir = quat_rotate(q, ref_dir) + + return rot_dir + + +def calc_heading(q, head_ind=0, gravity_axis="z"): + # type: (Tensor, int, str) -> Tensor + # calculate heading direction from quaternion + # the heading is the direction on the xy plane + # q must be normalized + ref_dir = torch.zeros_like(q[..., 0:3]) + ref_dir[..., head_ind] = 1 + # ref_dir[..., 0] = 1 + shape = ref_dir.shape[:-1] + q = q.reshape((-1, 4)) + ref_dir = ref_dir.reshape(-1, 3) + rot_dir = quat_rotate(q, ref_dir) + rot_dir = rot_dir.reshape(shape + (3,)) + if gravity_axis == "z": + heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0]) + elif gravity_axis == "y": + heading = torch.atan2(rot_dir[..., 0], rot_dir[..., 2]) + elif gravity_axis == "x": + heading = torch.atan2(rot_dir[..., 2], rot_dir[..., 1]) + return heading + + +def calc_heading_quat(q, head_ind=0, gravity_axis="z"): + # type: (Tensor, int, str) -> Tensor + # calculate heading rotation from quaternion + # the heading is the direction on the xy plane + # q must be normalized + heading = calc_heading(q, head_ind, gravity_axis=gravity_axis) + axis = torch.zeros_like(q[..., 0:3]) + if gravity_axis == "z": + g_axis = 2 + elif gravity_axis == "y": + g_axis = 1 + elif gravity_axis == "x": + g_axis = 0 + axis[..., g_axis] = 1 + + heading_q = quat_from_angle_axis(heading, axis) + return heading_q + + +def calc_heading_quat_inv(q, head_ind=0): + # type: (Tensor, int) -> Tensor + # calculate heading rotation from quaternion + # the heading is the direction on the xy plane + # q must be normalized + heading = calc_heading(q, head_ind) + axis = torch.zeros_like(q[..., 0:3]) + axis[..., 2] = 1 + + heading_q = quat_from_angle_axis(-heading, axis) + return heading_q + + +def forward_kinematics(mat, parent): + """_summary_ + + Args: + mat ([..., N, 3, 3]): _description_ + parent (): _description_ + """ + if isinstance(mat, torch.Tensor): + rotations = torch.eye(mat.shape[-1], device=mat.device) + rotations = rotations.repeat(mat.shape[:-2] + (1, 1)) + else: + rotations = np.eye(mat.shape[-1], dtype=np.float32) + rotations = np.tile(rotations, mat.shape[:-2] + (1, 1)) + for i in range(mat.shape[-3]): + if parent[i] != -1: + if isinstance(mat, torch.Tensor): + # this way make gradient flow + new_mat = get_mat_BfromA(rotations[..., parent[i], :, :], mat[..., i, :, :]) + rotations = torch.cat( + ( + rotations[..., :i, :, :], + new_mat[..., None, :, :], + rotations[..., i + 1 :, :, :], + ), + dim=-3, + ) + else: + rotations[..., i, :, :] = get_mat_BfromA(rotations[..., parent[i], :, :], mat[..., i, :, :]) + else: + if isinstance(mat, torch.Tensor): + # this way make gradient flow + rotations = torch.cat((mat[..., : i + 1, :, :], rotations[..., i + 1 :, :, :]), dim=-3) + else: + rotations[..., i, :, :] = mat[..., i, :, :] + return rotations diff --git a/hmr4d/utils/net_utils.py b/hmr4d/utils/net_utils.py new file mode 100644 index 0000000..698d61b --- /dev/null +++ b/hmr4d/utils/net_utils.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from 
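calc_heading rotates a reference direction by q, projects it onto the ground plane, and returns its angle, i.e. the rotation about the gravity axis. A small check with a pure-yaw quaternion (assumes hmr4d is importable; the 0.3 rad value is arbitrary):

import torch
from hmr4d.utils.matrix import quat_from_euler_xyz, calc_heading

zero = torch.tensor([0.0])
q = quat_from_euler_xyz(zero, zero, torch.tensor([0.3]))   # yaw-only rotation, z-up
print(calc_heading(q, head_ind=0, gravity_axis="z"))       # ~tensor([0.3])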
hmr4d.utils.pylogger import Log +from pytorch_lightning.utilities.memory import recursive_detach +from einops import repeat, rearrange +from scipy.ndimage._filters import _gaussian_kernel1d + + +def load_pretrained_model(model, ckpt_path): + """ + Load ckpt to model with strategy + """ + assert Path(ckpt_path).exists() + # use model's own load_pretrained_model method + if hasattr(model, "load_pretrained_model"): + model.load_pretrained_model(ckpt_path) + else: + Log.info(f"Loading ckpt: {ckpt_path}") + ckpt = torch.load(ckpt_path, "cpu") + model.load_state_dict(ckpt, strict=True) + + +def find_last_ckpt_path(dirpath): + """ + Assume ckpt is named as e{}* or last*, following the convention of pytorch-lightning. + """ + assert dirpath is not None + dirpath = Path(dirpath) + assert dirpath.exists() + # Priority 1: last.ckpt + auto_last_ckpt_path = dirpath / "last.ckpt" + if auto_last_ckpt_path.exists(): + return auto_last_ckpt_path + + # Priority 2 + model_paths = [] + for p in sorted(list(dirpath.glob("*.ckpt"))): + if "last" in p.name: + continue + model_paths.append(p) + if len(model_paths) > 0: + return model_paths[-1] + else: + Log.info("No checkpoint found, set model_path to None") + return None + + +def get_resume_ckpt_path(resume_mode, ckpt_dir=None): + if Path(resume_mode).exists(): # This is a path + return resume_mode + assert resume_mode == "last" + return find_last_ckpt_path(ckpt_dir) + + +def select_state_dict_by_prefix(state_dict, prefix, new_prefix=""): + """ + For each weight that start with {old_prefix}, remove the {old_prefic} and form a new state_dict. + Args: + state_dict: dict + prefix: str + new_prefix: str, if exists, the new key will be {new_prefix} + {old_key[len(prefix):]} + Returns: + state_dict_new: dict + """ + state_dict_new = {} + for k in list(state_dict.keys()): + if k.startswith(prefix): + new_key = new_prefix + k[len(prefix) :] + state_dict_new[new_key] = state_dict[k] + return state_dict_new + + +def detach_to_cpu(in_dict): + return recursive_detach(in_dict, to_cpu=True) + + +def to_cuda(data): + """Move data in the batch to cuda(), carefully handle data that is not tensor""" + if isinstance(data, torch.Tensor): + return data.cuda() + elif isinstance(data, dict): + return {k: to_cuda(v) for k, v in data.items()} + elif isinstance(data, list): + return [to_cuda(v) for v in data] + else: + return data + + +def get_valid_mask(max_len, valid_len, device="cpu"): + mask = torch.zeros(max_len, dtype=torch.bool).to(device) + mask[:valid_len] = True + return mask + + +def length_to_mask(lengths, max_len): + """ + Returns: (B, max_len) + """ + mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def repeat_to_max_len(x, max_len, dim=0): + """Repeat last frame to max_len along dim""" + assert isinstance(x, torch.Tensor) + if x.shape[dim] == max_len: + return x + elif x.shape[dim] < max_len: + x = x.clone() + x = x.transpose(0, dim) + x = torch.cat([x, repeat(x[-1:], "b ... -> (b r) ...", r=max_len - x.shape[0])]) + x = x.transpose(0, dim) + return x + else: + raise ValueError(f"Unexpected length v.s. max_len: {x.shape[0]} v.s. 
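Two of the smaller helpers in net_utils.py are easy to misread from the diff alone; a quick usage sketch (assumes hmr4d is importable; the dummy state dict and lengths are made up):

import torch
from hmr4d.utils.net_utils import select_state_dict_by_prefix, length_to_mask

sd = {"model.head.weight": torch.zeros(2), "optimizer.lr": torch.zeros(1)}
print(select_state_dict_by_prefix(sd, "model.").keys())   # dict_keys(['head.weight'])

lengths = torch.tensor([2, 4])
print(length_to_mask(lengths, max_len=5))                 # (2, 5) boolean mask, True for valid frames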
{max_len}") + + +def repeat_to_max_len_dict(x_dict, max_len, dim=0): + for k, v in x_dict.items(): + x_dict[k] = repeat_to_max_len(v, max_len, dim=dim) + return x_dict + + +class Transpose(nn.Module): + def __init__(self, dim1, dim2): + super(Transpose, self).__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x): + return x.transpose(self.dim1, self.dim2) + + +class GaussianSmooth(nn.Module): + def __init__(self, sigma=3, dim=-1): + super(GaussianSmooth, self).__init__() + kernel_smooth = _gaussian_kernel1d(sigma=sigma, order=0, radius=int(4 * sigma + 0.5)) + kernel_smooth = torch.from_numpy(kernel_smooth).float()[None, None] # (1, 1, K) + self.register_buffer("kernel_smooth", kernel_smooth, persistent=False) + self.dim = dim + + def forward(self, x): + """x (..., f, ...) f at dim""" + rad = self.kernel_smooth.size(-1) // 2 + + x = x.transpose(self.dim, -1) + x_shape = x.shape[:-1] + x = rearrange(x, "... f -> (...) 1 f") # (NB, 1, f) + x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0] + x = F.conv1d(x, self.kernel_smooth) + x = x.squeeze(1).reshape(*x_shape, -1) # (..., f) + x = x.transpose(-1, self.dim) + return x + + +def gaussian_smooth(x, sigma=3, dim=-1): + kernel_smooth = _gaussian_kernel1d(sigma=sigma, order=0, radius=int(4 * sigma + 0.5)) + kernel_smooth = torch.from_numpy(kernel_smooth).float()[None, None].to(x) # (1, 1, K) + rad = kernel_smooth.size(-1) // 2 + + x = x.transpose(dim, -1) + x_shape = x.shape[:-1] + x = rearrange(x, "... f -> (...) 1 f") # (NB, 1, f) + x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0] + x = F.conv1d(x, kernel_smooth) + x = x.squeeze(1).reshape(*x_shape, -1) # (..., f) + x = x.transpose(-1, dim) + return x + + +def moving_average_smooth(x, window_size=5, dim=-1): + kernel_smooth = torch.ones(window_size).float() / window_size + kernel_smooth = kernel_smooth[None, None].to(x) # (1, 1, window_size) + rad = kernel_smooth.size(-1) // 2 + + x = x.transpose(dim, -1) + x_shape = x.shape[:-1] + x = rearrange(x, "... f -> (...) 
1 f") # (NB, 1, f) + x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0] + x = F.conv1d(x, kernel_smooth) + x = x.squeeze(1).reshape(*x_shape, -1) # (..., f) + x = x.transpose(-1, dim) + return x diff --git a/hmr4d/utils/preproc/__init__.py b/hmr4d/utils/preproc/__init__.py new file mode 100644 index 0000000..322b467 --- /dev/null +++ b/hmr4d/utils/preproc/__init__.py @@ -0,0 +1,7 @@ +try: + from hmr4d.utils.preproc.tracker import Tracker + from hmr4d.utils.preproc.vitfeat_extractor import Extractor + from hmr4d.utils.preproc.vitpose import VitPoseExtractor + from hmr4d.utils.preproc.slam import SLAMModel +except: + pass diff --git a/hmr4d/utils/preproc/slam.py b/hmr4d/utils/preproc/slam.py new file mode 100644 index 0000000..ccd75d7 --- /dev/null +++ b/hmr4d/utils/preproc/slam.py @@ -0,0 +1,104 @@ +import cv2 +import time +import torch +from multiprocessing import Process, Queue + +try: + from dpvo.utils import Timer + from dpvo.dpvo import DPVO + from dpvo.config import cfg +except: + pass + + +from hmr4d import PROJ_ROOT +from hmr4d.utils.geo.hmr_cam import estimate_focal_length + + +class SLAMModel(object): + def __init__(self, video_path, width, height, intrinsics=None, stride=1, skip=0, buffer=2048, resize=0.5): + """ + Args: + intrinsics: [fx, fy, cx, cy] + """ + if intrinsics is None: + print("Estimating focal length") + focal_length = estimate_focal_length(width, height) + intrinsics = torch.tensor([focal_length, focal_length, width / 2.0, height / 2.0]) + else: + intrinsics = intrinsics.clone() + + self.dpvo_cfg = str(PROJ_ROOT / "third-party/DPVO/config/default.yaml") + self.dpvo_ckpt = "inputs/checkpoints/dpvo/dpvo.pth" + + self.buffer = buffer + self.times = [] + self.slam = None + self.queue = Queue(maxsize=8) + self.reader = Process(target=video_stream, args=(self.queue, video_path, intrinsics, stride, skip, resize)) + self.reader.start() + + def track(self): + (t, image, intrinsics) = self.queue.get() + + if t < 0: + return False + + image = torch.from_numpy(image).permute(2, 0, 1).cuda() + intrinsics = intrinsics.cuda() # [fx, fy, cx, cy] + + if self.slam is None: + cfg.merge_from_file(self.dpvo_cfg) + cfg.BUFFER_SIZE = self.buffer + self.slam = DPVO(cfg, self.dpvo_ckpt, ht=image.shape[1], wd=image.shape[2], viz=False) + + with Timer("SLAM", enabled=False): + t = time.time() + self.slam(t, image, intrinsics) + self.times.append(time.time() - t) + + return True + + def process(self): + for _ in range(12): + self.slam.update() + + self.reader.join() + return self.slam.terminate()[0] + + +def video_stream(queue, imagedir, intrinsics, stride, skip=0, resize=0.5): + """video generator""" + assert len(intrinsics) == 4, "intrinsics should be [fx, fy, cx, cy]" + + cap = cv2.VideoCapture(imagedir) + t = 0 + for _ in range(skip): + ret, image = cap.read() + + while True: + # Capture frame-by-frame + for _ in range(stride): + ret, image = cap.read() + # if frame is read correctly ret is True + if not ret: + break + + if not ret: + break + + image = cv2.resize(image, None, fx=resize, fy=resize, interpolation=cv2.INTER_AREA) + h, w, _ = image.shape + image = image[: h - h % 16, : w - w % 16] + + intrinsics_ = intrinsics.clone() * resize + queue.put((t, image, intrinsics_)) + + t += 1 + + queue.put((-1, image, intrinsics)) # -1 will terminate the process + cap.release() + + # wait for the queue to be empty, otherwise the process will end immediately + while not queue.empty(): + time.sleep(1) diff --git a/hmr4d/utils/preproc/tracker.py b/hmr4d/utils/preproc/tracker.py new file mode 
100644 index 0000000..b093ed4 --- /dev/null +++ b/hmr4d/utils/preproc/tracker.py @@ -0,0 +1,95 @@ +from ultralytics import YOLO +from hmr4d import PROJ_ROOT + +import torch +import numpy as np +from tqdm import tqdm +from collections import defaultdict + +from hmr4d.utils.seq_utils import ( + get_frame_id_list_from_mask, + linear_interpolate_frame_ids, + frame_id_to_mask, + rearrange_by_mask, +) +from hmr4d.utils.video_io_utils import get_video_lwh +from hmr4d.utils.net_utils import moving_average_smooth + + +class Tracker: + def __init__(self) -> None: + # https://docs.ultralytics.com/modes/predict/ + self.yolo = YOLO(PROJ_ROOT / "inputs/checkpoints/yolo/yolov8x.pt") + + def track(self, video_path): + track_history = [] + cfg = { + "device": "cuda", + "conf": 0.5, # default 0.25, wham 0.5 + "classes": 0, # human + "verbose": False, + "stream": True, + } + results = self.yolo.track(video_path, **cfg) + # frame-by-frame tracking + track_history = [] + for result in tqdm(results, total=get_video_lwh(video_path)[0], desc="YoloV8 Tracking"): + if result.boxes.id is not None: + track_ids = result.boxes.id.int().cpu().tolist() # (N) + bbx_xyxy = result.boxes.xyxy.cpu().numpy() # (N, 4) + result_frame = [{"id": track_ids[i], "bbx_xyxy": bbx_xyxy[i]} for i in range(len(track_ids))] + else: + result_frame = [] + track_history.append(result_frame) + + return track_history + + @staticmethod + def sort_track_length(track_history, video_path): + """This handles the track history from YOLO tracker.""" + id_to_frame_ids = defaultdict(list) + id_to_bbx_xyxys = defaultdict(list) + # parse to {det_id : [frame_id]} + for frame_id, frame in enumerate(track_history): + for det in frame: + id_to_frame_ids[det["id"]].append(frame_id) + id_to_bbx_xyxys[det["id"]].append(det["bbx_xyxy"]) + for k, v in id_to_bbx_xyxys.items(): + id_to_bbx_xyxys[k] = np.array(v) + + # Sort by length of each track (max to min) + id_length = {k: len(v) for k, v in id_to_frame_ids.items()} + id2length = dict(sorted(id_length.items(), key=lambda item: item[1], reverse=True)) + + # Sort by area sum (max to min) + id_area_sum = {} + l, w, h = get_video_lwh(video_path) + for k, v in id_to_bbx_xyxys.items(): + bbx_wh = v[:, 2:] - v[:, :2] + id_area_sum[k] = (bbx_wh[:, 0] * bbx_wh[:, 1] / w / h).sum() + id2area_sum = dict(sorted(id_area_sum.items(), key=lambda item: item[1], reverse=True)) + id_sorted = list(id2area_sum.keys()) + + return id_to_frame_ids, id_to_bbx_xyxys, id_sorted + + def get_one_track(self, video_path): + # track + track_history = self.track(video_path) + + # parse track_history & use top1 track + id_to_frame_ids, id_to_bbx_xyxys, id_sorted = self.sort_track_length(track_history, video_path) + track_id = id_sorted[0] + frame_ids = torch.tensor(id_to_frame_ids[track_id]) # (N,) + bbx_xyxys = torch.tensor(id_to_bbx_xyxys[track_id]) # (N, 4) + + # interpolate missing frames + mask = frame_id_to_mask(frame_ids, get_video_lwh(video_path)[0]) + bbx_xyxy_one_track = rearrange_by_mask(bbx_xyxys, mask) # (F, 4), missing filled with 0 + missing_frame_id_list = get_frame_id_list_from_mask(~mask) # list of list + bbx_xyxy_one_track = linear_interpolate_frame_ids(bbx_xyxy_one_track, missing_frame_id_list) + assert (bbx_xyxy_one_track.sum(1) != 0).all() + + bbx_xyxy_one_track = moving_average_smooth(bbx_xyxy_one_track, window_size=5, dim=0) + bbx_xyxy_one_track = moving_average_smooth(bbx_xyxy_one_track, window_size=5, dim=0) + + return bbx_xyxy_one_track diff --git a/hmr4d/utils/preproc/vitfeat_extractor.py 
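A usage sketch for Tracker (assumes the YOLOv8 checkpoint at inputs/checkpoints/yolo/yolov8x.pt, a CUDA device, and a made-up video path): get_one_track runs detection plus tracking, keeps the most prominent person track, fills missing frames by linear interpolation, and smooths the boxes.

from hmr4d.utils.preproc.tracker import Tracker

tracker = Tracker()
bbx_xyxy = tracker.get_one_track("inputs/demo/video.mp4")  # (F, 4) per-frame box with no gaps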
b/hmr4d/utils/preproc/vitfeat_extractor.py new file mode 100644 index 0000000..6ae37f5 --- /dev/null +++ b/hmr4d/utils/preproc/vitfeat_extractor.py @@ -0,0 +1,89 @@ +import torch +from hmr4d.network.hmr2 import load_hmr2, HMR2 + + +from hmr4d.utils.video_io_utils import read_video_np +import cv2 +import numpy as np + +from hmr4d.network.hmr2.utils.preproc import crop_and_resize, IMAGE_MEAN, IMAGE_STD +from tqdm import tqdm + + +def get_batch(input_path, bbx_xys, img_ds=0.5, img_dst_size=256, path_type="video"): + if path_type == "video": + imgs = read_video_np(input_path, scale=img_ds) + elif path_type == "image": + imgs = cv2.imread(str(input_path))[..., ::-1] + imgs = cv2.resize(imgs, (0, 0), fx=img_ds, fy=img_ds) + imgs = imgs[None] + elif path_type == "np": + assert isinstance(input_path, np.ndarray) + assert img_ds == 1.0 # this is safe + imgs = input_path + + gt_center = bbx_xys[:, :2] + gt_bbx_size = bbx_xys[:, 2] + + # Blur image to avoid aliasing artifacts + if True: + gt_bbx_size_ds = gt_bbx_size * img_ds + ds_factors = ((gt_bbx_size_ds * 1.0) / img_dst_size / 2.0).numpy() + imgs = np.stack( + [ + # gaussian(v, sigma=(d - 1) / 2, channel_axis=2, preserve_range=True) if d > 1.1 else v + cv2.GaussianBlur(v, (5, 5), (d - 1) / 2) if d > 1.1 else v + for v, d in zip(imgs, ds_factors) + ] + ) + + # Output + imgs_list = [] + bbx_xys_ds_list = [] + for i in range(len(imgs)): + img, bbx_xys_ds = crop_and_resize( + imgs[i], + gt_center[i] * img_ds, + gt_bbx_size[i] * img_ds, + img_dst_size, + enlarge_ratio=1.0, + ) + imgs_list.append(img) + bbx_xys_ds_list.append(bbx_xys_ds) + imgs = torch.from_numpy(np.stack(imgs_list)) # (F, 256, 256, 3), RGB + bbx_xys = torch.from_numpy(np.stack(bbx_xys_ds_list)) / img_ds # (F, 3) + + imgs = ((imgs / 255.0 - IMAGE_MEAN) / IMAGE_STD).permute(0, 3, 1, 2) # (F, 3, 256, 256 + return imgs, bbx_xys + + +class Extractor: + def __init__(self, tqdm_leave=True): + self.extractor: HMR2 = load_hmr2().cuda().eval() + self.tqdm_leave = tqdm_leave + + def extract_video_features(self, video_path, bbx_xys, img_ds=0.5): + """ + img_ds makes the image smaller, which is useful for faster processing + """ + # Get the batch + if isinstance(video_path, str): + imgs, bbx_xys = get_batch(video_path, bbx_xys, img_ds=img_ds) + else: + assert isinstance(video_path, torch.Tensor) + imgs = video_path + + # Inference + F, _, H, W = imgs.shape # (F, 3, H, W) + imgs = imgs.cuda() + batch_size = 16 # 5GB GPU memory, occupies all CUDA cores of 3090 + features = [] + for j in tqdm(range(0, F, batch_size), desc="HMR2 Feature", leave=self.tqdm_leave): + imgs_batch = imgs[j : j + batch_size] + + with torch.no_grad(): + feature = self.extractor({"img": imgs_batch}) + features.append(feature.detach().cpu()) + + features = torch.cat(features, dim=0).clone() # (F, 1024) + return features diff --git a/hmr4d/utils/preproc/vitpose.py b/hmr4d/utils/preproc/vitpose.py new file mode 100644 index 0000000..215adc8 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose.py @@ -0,0 +1,144 @@ +import torch +import torch.nn.functional as F +import numpy as np +from .vitpose_pytorch import build_model +from .vitfeat_extractor import get_batch +from tqdm import tqdm + +from hmr4d.utils.kpts.kp2d_utils import keypoints_from_heatmaps +from hmr4d.utils.geo_transform import cvt_p2d_from_pm1_to_i +from hmr4d.utils.geo.flip_utils import flip_heatmap_coco17 + + +class VitPoseExtractor: + def __init__(self, tqdm_leave=True): + ckpt_path = "inputs/checkpoints/vitpose/vitpose-h-multi-coco.pth" + self.pose = 
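A sketch of the image-feature extractor (assumes the HMR2 checkpoint is in place, a CUDA device, and that hmr4d is importable; the box values are made up). bbx_xys is (F, 3): box center (x, y) and a square box size in pixels.

import torch
from hmr4d.utils.preproc.vitfeat_extractor import Extractor

extractor = Extractor()
bbx_xys = torch.tensor([[320.0, 240.0, 200.0]]).repeat(30, 1)   # 30 frames, fixed box
features = extractor.extract_video_features("inputs/demo/video.mp4", bbx_xys)
print(features.shape)                                           # (30, 1024)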
build_model("ViTPose_huge_coco_256x192", ckpt_path) + self.pose.cuda().eval() + + self.flip_test = True + self.tqdm_leave = tqdm_leave + + @torch.no_grad() + def extract(self, video_path, bbx_xys, img_ds=0.5): + # Get the batch + if isinstance(video_path, str): + imgs, bbx_xys = get_batch(video_path, bbx_xys, img_ds=img_ds) + else: + assert isinstance(video_path, torch.Tensor) + imgs = video_path + + # Inference + L, _, H, W = imgs.shape # (L, 3, H, W) + batch_size = 16 + vitpose = [] + for j in tqdm(range(0, L, batch_size), desc="ViTPose", leave=self.tqdm_leave): + # Heat map + imgs_batch = imgs[j : j + batch_size, :, :, 32:224].cuda() + if self.flip_test: + heatmap, heatmap_flipped = self.pose(torch.cat([imgs_batch, imgs_batch.flip(3)], dim=0)).chunk(2) + heatmap_flipped = flip_heatmap_coco17(heatmap_flipped) + heatmap = (heatmap + heatmap_flipped) * 0.5 + del heatmap_flipped + else: + heatmap = self.pose(imgs_batch.clone()) # (B, J, 64, 48) + + if False: + # Get joint + bbx_xys_batch = bbx_xys[j : j + batch_size].cuda() + method = "hard" + if method == "hard": + kp2d_pm1, conf = get_heatmap_preds(heatmap) + elif method == "soft": + kp2d_pm1, conf = get_heatmap_preds(heatmap, soft=True) + + # Convert 64, 48 to 64, 64 + kp2d_pm1[:, :, 0] *= 24 / 32 + kp2d = cvt_p2d_from_pm1_to_i(kp2d_pm1, bbx_xys_batch[:, None]) + kp2d = torch.cat([kp2d, conf], dim=-1) + + else: # postprocess from mmpose + bbx_xys_batch = bbx_xys[j : j + batch_size] + heatmap = heatmap.clone().cpu().numpy() + center = bbx_xys_batch[:, :2].numpy() + scale = (torch.cat((bbx_xys_batch[:, [2]] * 24 / 32, bbx_xys_batch[:, [2]]), dim=1) / 200).numpy() + preds, maxvals = keypoints_from_heatmaps(heatmaps=heatmap, center=center, scale=scale, use_udp=True) + kp2d = np.concatenate((preds, maxvals), axis=-1) + kp2d = torch.from_numpy(kp2d) + + vitpose.append(kp2d.detach().cpu().clone()) + + vitpose = torch.cat(vitpose, dim=0).clone() # (F, 17, 3) + return vitpose + + +def get_heatmap_preds(heatmap, normalize_keypoints=True, thr=0.0, soft=False): + """ + heatmap: (B, J, H, W) + """ + assert heatmap.ndim == 4, "batch_images should be 4-ndim" + + B, J, H, W = heatmap.shape + heatmaps_reshaped = heatmap.reshape((B, J, -1)) + + maxvals, idx = torch.max(heatmaps_reshaped, 2) + maxvals = maxvals.reshape((B, J, 1)) + idx = idx.reshape((B, J, 1)) + preds = idx.repeat(1, 1, 2).float() + preds[:, :, 0] = (preds[:, :, 0]) % W + preds[:, :, 1] = torch.floor((preds[:, :, 1]) / W) + + pred_mask = torch.gt(maxvals, thr).repeat(1, 1, 2) + pred_mask = pred_mask.float() + preds *= pred_mask + + # soft peak + if soft: + patch_size = 5 + patch_half = patch_size // 2 + patches = torch.zeros((B, J, patch_size, patch_size)).to(heatmap) + default_patch = torch.zeros(patch_size, patch_size).to(heatmap) + default_patch[patch_half, patch_half] = 1 + for b in range(B): + for j in range(17): + x, y = preds[b, j].int() + if x >= patch_half and x <= W - patch_half and y >= patch_half and y <= H - patch_half: + patches[b, j] = heatmap[ + b, j, y - patch_half : y + patch_half + 1, x - patch_half : x + patch_half + 1 + ] + else: + patches[b, j] = default_patch + + dx, dy = soft_patch_dx_dy(patches) + preds[:, :, 0] += dx + preds[:, :, 1] += dy + + if normalize_keypoints: # to [-1, 1] + preds[:, :, 0] = preds[:, :, 0] / (W - 1) * 2 - 1 + preds[:, :, 1] = preds[:, :, 1] / (H - 1) * 2 - 1 + + return preds, maxvals + + +def soft_patch_dx_dy(p): + """p (B,J,P,P)""" + p_batch_shape = p.shape[:-2] + patch_size = p.size(-1) + temperature = 1.0 + score = F.softmax(p.view(-1, 
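get_heatmap_preds decodes each joint by taking the argmax over the flattened heatmap and converting the flat index back to (x, y). A standalone toy example of that decoding (made-up heatmap size and peak location):

import torch

B, J, H, W = 1, 1, 64, 48
heatmap = torch.zeros(B, J, H, W)
heatmap[0, 0, 20, 10] = 1.0                  # peak at x=10, y=20

maxvals, idx = heatmap.reshape(B, J, -1).max(dim=2)
x, y = idx % W, idx // W                     # flat index back to column / row
print(x.item(), y.item(), maxvals.item())    # 10 20 1.0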
patch_size**2) * temperature, dim=-1) + + # get a offset_grid (BN, P, P, 2) for dx, dy + offset_grid = torch.meshgrid(torch.arange(patch_size), torch.arange(patch_size))[::-1] + offset_grid = torch.stack(offset_grid, dim=-1).float() - (patch_size - 1) / 2 + offset_grid = offset_grid.view(1, 1, patch_size, patch_size, 2).to(p.device) + + score = score.view(*p_batch_shape, patch_size, patch_size) + dx = torch.sum(score * offset_grid[..., 0], dim=(-2, -1)) + dy = torch.sum(score * offset_grid[..., 1], dim=(-2, -1)) + + if False: + b, j = 0, 0 + print(torch.stack([dx[b, j], dy[b, j]])) + print(p[b, j]) + + return dx, dy diff --git a/hmr4d/utils/preproc/vitpose_pytorch/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/__init__.py new file mode 100644 index 0000000..df55ce4 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/__init__.py @@ -0,0 +1 @@ +from .src.vitpose_infer.model_builder import build_model diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/__init__.py new file mode 100644 index 0000000..44586e8 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# from .alexnet import AlexNet +# from .cpm import CPM +# from .hourglass import HourglassNet +# from .hourglass_ae import HourglassAENet +# from .hrformer import HRFormer +# from .hrnet import HRNet +# from .litehrnet import LiteHRNet +# from .mobilenet_v2 import MobileNetV2 +# from .mobilenet_v3 import MobileNetV3 +# from .mspn import MSPN +# from .regnet import RegNet +# from .resnest import ResNeSt +# from .resnet import ResNet, ResNetV1d +# from .resnext import ResNeXt +# from .rsn import RSN +# from .scnet import SCNet +# from .seresnet import SEResNet +# from .seresnext import SEResNeXt +# from .shufflenet_v1 import ShuffleNetV1 +# from .shufflenet_v2 import ShuffleNetV2 +# from .tcn import TCN +# from .v2v_net import V2VNet +# from .vgg import VGG +# from .vipnas_mbv3 import ViPNAS_MobileNetV3 +# from .vipnas_resnet import ViPNAS_ResNet +from .vit import ViT + +# __all__ = [ +# 'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2', +# 'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet', +# 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN', +# 'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3', +# 'LiteHRNet', 'V2VNet', 'HRFormer', 'ViT' +# ] diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/alexnet.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/alexnet.py new file mode 100644 index 0000000..a8efd74 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/alexnet.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +@BACKBONES.register_module() +class AlexNet(BaseBackbone): + """`AlexNet `__ backbone. 
+ + The input for AlexNet is a 224x224 RGB image. + + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + """ + + def __init__(self, num_classes=-1): + super().__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + + return x diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/cpm.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/cpm.py new file mode 100644 index 0000000..458245d --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/cpm.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.utils import get_root_logger +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .utils import load_checkpoint + + +class CpmBlock(nn.Module): + """CpmBlock for Convolutional Pose Machine. + + Args: + in_channels (int): Input channels of this block. + channels (list): Output channels of each conv module. + kernels (list): Kernel sizes of each conv module. + """ + + def __init__(self, + in_channels, + channels=(128, 128, 128), + kernels=(11, 11, 11), + norm_cfg=None): + super().__init__() + + assert len(channels) == len(kernels) + layers = [] + for i in range(len(channels)): + if i == 0: + input_channels = in_channels + else: + input_channels = channels[i - 1] + layers.append( + ConvModule( + input_channels, + channels[i], + kernels[i], + padding=(kernels[i] - 1) // 2, + norm_cfg=norm_cfg)) + self.model = nn.Sequential(*layers) + + def forward(self, x): + """Model forward function.""" + out = self.model(x) + return out + + +@BACKBONES.register_module() +class CPM(BaseBackbone): + """CPM backbone. + + Convolutional Pose Machines. + More details can be found in the `paper + `__ . + + Args: + in_channels (int): The input channels of the CPM. + out_channels (int): The output channels of the CPM. + feat_channels (int): Feature channel of each CPM stage. + middle_channels (int): Feature channel of conv after the middle stage. + num_stages (int): Number of stages. + norm_cfg (dict): Dictionary to construct and config norm layer. + + Example: + >>> from mmpose.models import CPM + >>> import torch + >>> self = CPM(3, 17) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 368, 368) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... 
print(tuple(level_output.shape)) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + """ + + def __init__(self, + in_channels, + out_channels, + feat_channels=128, + middle_channels=32, + num_stages=6, + norm_cfg=dict(type='BN', requires_grad=True)): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + + assert in_channels == 3 + + self.num_stages = num_stages + assert self.num_stages >= 1 + + self.stem = nn.Sequential( + ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 32, 5, padding=2, norm_cfg=norm_cfg), + ConvModule(32, 512, 9, padding=4, norm_cfg=norm_cfg), + ConvModule(512, 512, 1, padding=0, norm_cfg=norm_cfg), + ConvModule(512, out_channels, 1, padding=0, act_cfg=None)) + + self.middle = nn.Sequential( + ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + self.cpm_stages = nn.ModuleList([ + CpmBlock( + middle_channels + out_channels, + channels=[feat_channels, feat_channels, feat_channels], + kernels=[11, 11, 11], + norm_cfg=norm_cfg) for _ in range(num_stages - 1) + ]) + + self.middle_conv = nn.ModuleList([ + nn.Sequential( + ConvModule( + 128, middle_channels, 5, padding=2, norm_cfg=norm_cfg)) + for _ in range(num_stages - 1) + ]) + + self.out_convs = nn.ModuleList([ + nn.Sequential( + ConvModule( + feat_channels, + feat_channels, + 1, + padding=0, + norm_cfg=norm_cfg), + ConvModule(feat_channels, out_channels, 1, act_cfg=None)) + for _ in range(num_stages - 1) + ]) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Model forward function.""" + stage1_out = self.stem(x) + middle_out = self.middle(x) + out_feats = [] + + out_feats.append(stage1_out) + + for ind in range(self.num_stages - 1): + single_stage = self.cpm_stages[ind] + out_conv = self.out_convs[ind] + + inp_feat = torch.cat( + [out_feats[-1], self.middle_conv[ind](middle_out)], 1) + cpm_feat = single_stage(inp_feat) + out_feat = out_conv(cpm_feat) + out_feats.append(out_feat) + + return out_feats diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass.py new file mode 100644 index 0000000..bf75fad --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
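# --------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the original patch):
# The HourglassModule defined in this file builds an encoder-decoder
# recursively: each level halves the resolution, recurses (or bottoms out),
# upsamples back, and adds a same-resolution skip path, so the output keeps
# the input's (H, W). A minimal stand-in using plain torch.nn; the name
# ToyHourglass and the single-conv blocks are hypothetical simplifications
# of the ResLayer blocks used by the real module.
import torch
import torch.nn as nn


class ToyHourglass(nn.Module):
    """Recursive down/up block: output has the same (H, W) as the input."""

    def __init__(self, depth, channels):
        super().__init__()
        self.skip = nn.Conv2d(channels, channels, 3, padding=1)       # up1 path
        self.down = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        # Recurse until depth 1, then use a plain block at the bottleneck.
        self.inner = (ToyHourglass(depth - 1, channels) if depth > 1
                      else nn.Conv2d(channels, channels, 3, padding=1))
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x):
        # skip(x) keeps full resolution; the inner branch goes down and back up.
        return self.skip(x) + self.up(self.inner(self.down(x)))


# Sanity check (assumes H and W are divisible by 2**depth):
# >>> y = ToyHourglass(3, 16)(torch.rand(1, 16, 64, 64))
# >>> y.shape  ->  torch.Size([1, 16, 64, 64])
# --------------------------------------------------------------------------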
+import copy + +import torch.nn as nn +from mmcv.cnn import ConvModule, constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.utils import get_root_logger +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .resnet import BasicBlock, ResLayer +from .utils import load_checkpoint + + +class HourglassModule(nn.Module): + """Hourglass Module for HourglassNet backbone. + + Generate module recursively and use BasicBlock as the base unit. + + Args: + depth (int): Depth of current HourglassModule. + stage_channels (list[int]): Feature channels of sub-modules in current + and follow-up HourglassModule. + stage_blocks (list[int]): Number of sub-modules stacked in current and + follow-up HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + """ + + def __init__(self, + depth, + stage_channels, + stage_blocks, + norm_cfg=dict(type='BN', requires_grad=True)): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + + self.depth = depth + + cur_block = stage_blocks[0] + next_block = stage_blocks[1] + + cur_channel = stage_channels[0] + next_channel = stage_channels[1] + + self.up1 = ResLayer( + BasicBlock, cur_block, cur_channel, cur_channel, norm_cfg=norm_cfg) + + self.low1 = ResLayer( + BasicBlock, + cur_block, + cur_channel, + next_channel, + stride=2, + norm_cfg=norm_cfg) + + if self.depth > 1: + self.low2 = HourglassModule(depth - 1, stage_channels[1:], + stage_blocks[1:]) + else: + self.low2 = ResLayer( + BasicBlock, + next_block, + next_channel, + next_channel, + norm_cfg=norm_cfg) + + self.low3 = ResLayer( + BasicBlock, + cur_block, + next_channel, + cur_channel, + norm_cfg=norm_cfg, + downsample_first=False) + + self.up2 = nn.Upsample(scale_factor=2) + + def forward(self, x): + """Model forward function.""" + up1 = self.up1(x) + low1 = self.low1(x) + low2 = self.low2(low1) + low3 = self.low3(low2) + up2 = self.up2(low3) + return up1 + up2 + + +@BACKBONES.register_module() +class HourglassNet(BaseBackbone): + """HourglassNet backbone. + + Stacked Hourglass Networks for Human Pose Estimation. + More details can be found in the `paper + `__ . + + Args: + downsample_times (int): Downsample times in a HourglassModule. + num_stacks (int): Number of HourglassModule modules stacked, + 1 for Hourglass-52, 2 for Hourglass-104. + stage_channels (list[int]): Feature channel of each sub-module in a + HourglassModule. + stage_blocks (list[int]): Number of sub-modules stacked in a + HourglassModule. + feat_channel (int): Feature channel of conv after a HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + + Example: + >>> from mmpose.models import HourglassNet + >>> import torch + >>> self = HourglassNet() + >>> self.eval() + >>> inputs = torch.rand(1, 3, 511, 511) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... 
print(tuple(level_output.shape)) + (1, 256, 128, 128) + (1, 256, 128, 128) + """ + + def __init__(self, + downsample_times=5, + num_stacks=2, + stage_channels=(256, 256, 384, 384, 384, 512), + stage_blocks=(2, 2, 2, 2, 2, 4), + feat_channel=256, + norm_cfg=dict(type='BN', requires_grad=True)): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + + self.num_stacks = num_stacks + assert self.num_stacks >= 1 + assert len(stage_channels) == len(stage_blocks) + assert len(stage_channels) > downsample_times + + cur_channel = stage_channels[0] + + self.stem = nn.Sequential( + ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg), + ResLayer(BasicBlock, 1, 128, 256, stride=2, norm_cfg=norm_cfg)) + + self.hourglass_modules = nn.ModuleList([ + HourglassModule(downsample_times, stage_channels, stage_blocks) + for _ in range(num_stacks) + ]) + + self.inters = ResLayer( + BasicBlock, + num_stacks - 1, + cur_channel, + cur_channel, + norm_cfg=norm_cfg) + + self.conv1x1s = nn.ModuleList([ + ConvModule( + cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) + for _ in range(num_stacks - 1) + ]) + + self.out_convs = nn.ModuleList([ + ConvModule( + cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg) + for _ in range(num_stacks) + ]) + + self.remap_convs = nn.ModuleList([ + ConvModule( + feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) + for _ in range(num_stacks - 1) + ]) + + self.relu = nn.ReLU(inplace=True) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Model forward function.""" + inter_feat = self.stem(x) + out_feats = [] + + for ind in range(self.num_stacks): + single_hourglass = self.hourglass_modules[ind] + out_conv = self.out_convs[ind] + + hourglass_feat = single_hourglass(inter_feat) + out_feat = out_conv(hourglass_feat) + out_feats.append(out_feat) + + if ind < self.num_stacks - 1: + inter_feat = self.conv1x1s[ind]( + inter_feat) + self.remap_convs[ind]( + out_feat) + inter_feat = self.inters[ind](self.relu(inter_feat)) + + return out_feats diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass_ae.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass_ae.py new file mode 100644 index 0000000..5a700e5 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass_ae.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +from mmcv.cnn import ConvModule, MaxPool2d, constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.utils import get_root_logger +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .utils import load_checkpoint + + +class HourglassAEModule(nn.Module): + """Modified Hourglass Module for HourglassNet_AE backbone. + + Generate module recursively and use BasicBlock as the base unit. + + Args: + depth (int): Depth of current HourglassModule. 
+ stage_channels (list[int]): Feature channels of sub-modules in current + and follow-up HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + """ + + def __init__(self, + depth, + stage_channels, + norm_cfg=dict(type='BN', requires_grad=True)): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + + self.depth = depth + + cur_channel = stage_channels[0] + next_channel = stage_channels[1] + + self.up1 = ConvModule( + cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg) + + self.pool1 = MaxPool2d(2, 2) + + self.low1 = ConvModule( + cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg) + + if self.depth > 1: + self.low2 = HourglassAEModule(depth - 1, stage_channels[1:]) + else: + self.low2 = ConvModule( + next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg) + + self.low3 = ConvModule( + next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg) + + self.up2 = nn.UpsamplingNearest2d(scale_factor=2) + + def forward(self, x): + """Model forward function.""" + up1 = self.up1(x) + pool1 = self.pool1(x) + low1 = self.low1(pool1) + low2 = self.low2(low1) + low3 = self.low3(low2) + up2 = self.up2(low3) + return up1 + up2 + + +@BACKBONES.register_module() +class HourglassAENet(BaseBackbone): + """Hourglass-AE Network proposed by Newell et al. + + Associative Embedding: End-to-End Learning for Joint + Detection and Grouping. + + More details can be found in the `paper + `__ . + + Args: + downsample_times (int): Downsample times in a HourglassModule. + num_stacks (int): Number of HourglassModule modules stacked, + 1 for Hourglass-52, 2 for Hourglass-104. + stage_channels (list[int]): Feature channel of each sub-module in a + HourglassModule. + stage_blocks (list[int]): Number of sub-modules stacked in a + HourglassModule. + feat_channels (int): Feature channel of conv after a HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + + Example: + >>> from mmpose.models import HourglassAENet + >>> import torch + >>> self = HourglassAENet() + >>> self.eval() + >>> inputs = torch.rand(1, 3, 512, 512) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... 
print(tuple(level_output.shape)) + (1, 34, 128, 128) + """ + + def __init__(self, + downsample_times=4, + num_stacks=1, + out_channels=34, + stage_channels=(256, 384, 512, 640, 768), + feat_channels=256, + norm_cfg=dict(type='BN', requires_grad=True)): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + + self.num_stacks = num_stacks + assert self.num_stacks >= 1 + assert len(stage_channels) > downsample_times + + cur_channels = stage_channels[0] + + self.stem = nn.Sequential( + ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg), + ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg), + MaxPool2d(2, 2), + ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg), + ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg), + ) + + self.hourglass_modules = nn.ModuleList([ + nn.Sequential( + HourglassAEModule( + downsample_times, stage_channels, norm_cfg=norm_cfg), + ConvModule( + feat_channels, + feat_channels, + 3, + padding=1, + norm_cfg=norm_cfg), + ConvModule( + feat_channels, + feat_channels, + 3, + padding=1, + norm_cfg=norm_cfg)) for _ in range(num_stacks) + ]) + + self.out_convs = nn.ModuleList([ + ConvModule( + cur_channels, + out_channels, + 1, + padding=0, + norm_cfg=None, + act_cfg=None) for _ in range(num_stacks) + ]) + + self.remap_out_convs = nn.ModuleList([ + ConvModule( + out_channels, + feat_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None) for _ in range(num_stacks - 1) + ]) + + self.remap_feature_convs = nn.ModuleList([ + ConvModule( + feat_channels, + feat_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None) for _ in range(num_stacks - 1) + ]) + + self.relu = nn.ReLU(inplace=True) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Model forward function.""" + inter_feat = self.stem(x) + out_feats = [] + + for ind in range(self.num_stacks): + single_hourglass = self.hourglass_modules[ind] + out_conv = self.out_convs[ind] + + hourglass_feat = single_hourglass(inter_feat) + out_feat = out_conv(hourglass_feat) + out_feats.append(out_feat) + + if ind < self.num_stacks - 1: + inter_feat = inter_feat + self.remap_out_convs[ind]( + out_feat) + self.remap_feature_convs[ind]( + hourglass_feat) + + return out_feats diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hrformer.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hrformer.py new file mode 100644 index 0000000..b843300 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hrformer.py @@ -0,0 +1,746 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
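# --------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the original patch):
# The LocalWindowSelfAttention defined in this file runs multi-head attention
# inside non-overlapping Wh x Ww windows; the reshuffling it relies on is a
# reshape/permute round trip. A minimal stand-alone version follows; the
# function names are hypothetical, and unlike the real module (which
# center-pads H and W to window multiples first) this sketch assumes the
# window size already divides the feature map evenly.
import torch


def window_partition(x, wh, ww):
    """(B, H, W, C) -> (B * num_windows, wh * ww, C), assuming H % wh == W % ww == 0."""
    B, H, W, C = x.shape
    x = x.view(B, H // wh, wh, W // ww, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, wh * ww, C)


def window_reverse(windows, wh, ww, H, W):
    """Inverse of window_partition: (B * num_windows, wh * ww, C) -> (B, H, W, C)."""
    B = windows.shape[0] // ((H // wh) * (W // ww))
    x = windows.view(B, H // wh, W // ww, wh, ww, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)


# Round-trip check: partitioning and then reversing recovers the input exactly.
# >>> t = torch.rand(2, 14, 14, 32)
# >>> assert torch.equal(window_reverse(window_partition(t, 7, 7), 7, 7, 14, 14), t)
# --------------------------------------------------------------------------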
+ +import math + +import torch +import torch.nn as nn +# from timm.models.layers import to_2tuple, trunc_normal_ +from mmcv.cnn import (build_activation_layer, build_conv_layer, + build_norm_layer, trunc_normal_init) +from mmcv.cnn.bricks.transformer import build_dropout +from mmcv.runner import BaseModule +from torch.nn.functional import pad + +from ..builder import BACKBONES +from .hrnet import Bottleneck, HRModule, HRNet + + +def nlc_to_nchw(x, hw_shape): + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + return x.transpose(1, 2).reshape(B, C, H, W) + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +def build_drop_path(drop_path_rate): + """Build drop path layer.""" + return build_dropout(dict(type='DropPath', drop_prob=drop_path_rate)) + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + with_rpe (bool, optional): If True, use relative position bias. + Default: True. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + with_rpe=True, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + self.with_rpe = with_rpe + if self.with_rpe: + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_init(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (B*num_windows, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.with_rpe: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class LocalWindowSelfAttention(BaseModule): + r""" Local-window Self Attention (LSA) module with relative position bias. + + This module is the short-range self-attention module in the + Interlaced Sparse Self-Attention `_. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int] | int): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + with_rpe (bool, optional): If True, use relative position bias. 
+ Default: True. + with_pad_mask (bool, optional): If True, mask out the padded tokens in + the attention process. Default: False. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + with_rpe=True, + with_pad_mask=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(window_size, int): + window_size = (window_size, window_size) + self.window_size = window_size + self.with_pad_mask = with_pad_mask + self.attn = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + with_rpe=with_rpe, + init_cfg=init_cfg) + + def forward(self, x, H, W, **kwargs): + """Forward function.""" + B, N, C = x.shape + x = x.view(B, H, W, C) + Wh, Ww = self.window_size + + # center-pad the feature on H and W axes + pad_h = math.ceil(H / Wh) * Wh - H + pad_w = math.ceil(W / Ww) * Ww - W + x = pad(x, (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + + # permute + x = x.view(B, math.ceil(H / Wh), Wh, math.ceil(W / Ww), Ww, C) + x = x.permute(0, 1, 3, 2, 4, 5) + x = x.reshape(-1, Wh * Ww, C) # (B*num_window, Wh*Ww, C) + + # attention + if self.with_pad_mask and pad_h > 0 and pad_w > 0: + pad_mask = x.new_zeros(1, H, W, 1) + pad_mask = pad( + pad_mask, [ + 0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ], + value=-float('inf')) + pad_mask = pad_mask.view(1, math.ceil(H / Wh), Wh, + math.ceil(W / Ww), Ww, 1) + pad_mask = pad_mask.permute(1, 3, 0, 2, 4, 5) + pad_mask = pad_mask.reshape(-1, Wh * Ww) + pad_mask = pad_mask[:, None, :].expand([-1, Wh * Ww, -1]) + out = self.attn(x, pad_mask, **kwargs) + else: + out = self.attn(x, **kwargs) + + # reverse permutation + out = out.reshape(B, math.ceil(H / Wh), math.ceil(W / Ww), Wh, Ww, C) + out = out.permute(0, 1, 3, 2, 4, 5) + out = out.reshape(B, H + pad_h, W + pad_w, C) + + # de-pad + out = out[:, pad_h // 2:H + pad_h // 2, pad_w // 2:W + pad_w // 2] + return out.reshape(B, N, C) + + +class CrossFFN(BaseModule): + r"""FFN with Depthwise Conv of HRFormer. + + Args: + in_features (int): The feature dimension. + hidden_features (int, optional): The hidden dimension of FFNs. + Defaults: The same as in_features. + act_cfg (dict, optional): Config of activation layer. + Default: dict(type='GELU'). + dw_act_cfg (dict, optional): Config of activation layer appended + right after DW Conv. Default: dict(type='GELU'). + norm_cfg (dict, optional): Config of norm layer. + Default: dict(type='SyncBN'). + init_cfg (dict | list | None, optional): The init config. + Default: None. 
+ """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + dw_act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1) + self.act1 = build_activation_layer(act_cfg) + self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1] + self.dw3x3 = nn.Conv2d( + hidden_features, + hidden_features, + kernel_size=3, + stride=1, + groups=hidden_features, + padding=1) + self.act2 = build_activation_layer(dw_act_cfg) + self.norm2 = build_norm_layer(norm_cfg, hidden_features)[1] + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1) + self.act3 = build_activation_layer(act_cfg) + self.norm3 = build_norm_layer(norm_cfg, out_features)[1] + + # put the modules togather + self.layers = [ + self.fc1, self.norm1, self.act1, self.dw3x3, self.norm2, self.act2, + self.fc2, self.norm3, self.act3 + ] + + def forward(self, x, H, W): + """Forward function.""" + x = nlc_to_nchw(x, (H, W)) + for layer in self.layers: + x = layer(x) + x = nchw_to_nlc(x) + return x + + +class HRFormerBlock(BaseModule): + """High-Resolution Block for HRFormer. + + Args: + in_features (int): The input dimension. + out_features (int): The output dimension. + num_heads (int): The number of head within each LSA. + window_size (int, optional): The window size for the LSA. + Default: 7 + mlp_ratio (int, optional): The expansion ration of FFN. + Default: 4 + act_cfg (dict, optional): Config of activation layer. + Default: dict(type='GELU'). + norm_cfg (dict, optional): Config of norm layer. + Default: dict(type='SyncBN'). + transformer_norm_cfg (dict, optional): Config of transformer norm + layer. Default: dict(type='LN', eps=1e-6). + init_cfg (dict | list | None, optional): The init config. + Default: None. 
+ """ + + expansion = 1 + + def __init__(self, + in_features, + out_features, + num_heads, + window_size=7, + mlp_ratio=4.0, + drop_path=0.0, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN'), + transformer_norm_cfg=dict(type='LN', eps=1e-6), + init_cfg=None, + **kwargs): + super(HRFormerBlock, self).__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.norm1 = build_norm_layer(transformer_norm_cfg, in_features)[1] + self.attn = LocalWindowSelfAttention( + in_features, + num_heads=num_heads, + window_size=window_size, + init_cfg=None, + **kwargs) + + self.norm2 = build_norm_layer(transformer_norm_cfg, out_features)[1] + self.ffn = CrossFFN( + in_features=in_features, + hidden_features=int(in_features * mlp_ratio), + out_features=out_features, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dw_act_cfg=act_cfg, + init_cfg=None) + + self.drop_path = build_drop_path( + drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + """Forward function.""" + B, C, H, W = x.size() + # Attention + x = x.view(B, C, -1).permute(0, 2, 1) + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + # FFN + x = x + self.drop_path(self.ffn(self.norm2(x), H, W)) + x = x.permute(0, 2, 1).view(B, C, H, W) + return x + + def extra_repr(self): + """(Optional) Set the extra information about this module.""" + return 'num_heads={}, window_size={}, mlp_ratio={}'.format( + self.num_heads, self.window_size, self.mlp_ratio) + + +class HRFomerModule(HRModule): + """High-Resolution Module for HRFormer. + + Args: + num_branches (int): The number of branches in the HRFormerModule. + block (nn.Module): The building block of HRFormer. + The block should be the HRFormerBlock. + num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + num_inchannels (tuple): The number of input channels in each branch. + The length must be equal to num_branches. + num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + num_heads (tuple): The number of heads within the LSAs. + num_window_sizes (tuple): The window size for the LSAs. + num_mlp_ratios (tuple): The expansion ratio for the FFNs. + drop_path (int, optional): The drop path rate of HRFomer. + Default: 0.0 + multiscale_output (bool, optional): Whether to output multi-level + features produced by multiple branches. If False, only the first + level feature will be output. Default: True. + conv_cfg (dict, optional): Config of the conv layers. + Default: None. + norm_cfg (dict, optional): Config of the norm layers appended + right after conv. Default: dict(type='SyncBN', requires_grad=True) + transformer_norm_cfg (dict, optional): Config of the norm layers. + Default: dict(type='LN', eps=1e-6) + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False + upsample_cfg(dict, optional): The config of upsample layers in fuse + layers. 
Default: dict(mode='bilinear', align_corners=False) + """ + + def __init__(self, + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + multiscale_output=True, + drop_paths=0.0, + with_rpe=True, + with_pad_mask=False, + conv_cfg=None, + norm_cfg=dict(type='SyncBN', requires_grad=True), + transformer_norm_cfg=dict(type='LN', eps=1e-6), + with_cp=False, + upsample_cfg=dict(mode='bilinear', align_corners=False)): + + self.transformer_norm_cfg = transformer_norm_cfg + self.drop_paths = drop_paths + self.num_heads = num_heads + self.num_window_sizes = num_window_sizes + self.num_mlp_ratios = num_mlp_ratios + self.with_rpe = with_rpe + self.with_pad_mask = with_pad_mask + + super().__init__(num_branches, block, num_blocks, num_inchannels, + num_channels, multiscale_output, with_cp, conv_cfg, + norm_cfg, upsample_cfg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Build one branch.""" + # HRFormerBlock does not support down sample layer yet. + assert stride == 1 and self.in_channels[branch_index] == num_channels[ + branch_index] + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + num_heads=self.num_heads[branch_index], + window_size=self.num_window_sizes[branch_index], + mlp_ratio=self.num_mlp_ratios[branch_index], + drop_path=self.drop_paths[0], + norm_cfg=self.norm_cfg, + transformer_norm_cfg=self.transformer_norm_cfg, + init_cfg=None, + with_rpe=self.with_rpe, + with_pad_mask=self.with_pad_mask)) + + self.in_channels[ + branch_index] = self.in_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + num_heads=self.num_heads[branch_index], + window_size=self.num_window_sizes[branch_index], + mlp_ratio=self.num_mlp_ratios[branch_index], + drop_path=self.drop_paths[i], + norm_cfg=self.norm_cfg, + transformer_norm_cfg=self.transformer_norm_cfg, + init_cfg=None, + with_rpe=self.with_rpe, + with_pad_mask=self.with_pad_mask)) + return nn.Sequential(*layers) + + def _make_fuse_layers(self): + """Build fuse layers.""" + if self.num_branches == 1: + return None + num_branches = self.num_branches + num_inchannels = self.in_channels + fuse_layers = [] + for i in range(num_branches if self.multiscale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_inchannels[j], + num_inchannels[i], + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_inchannels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), + mode=self.upsample_cfg['mode'], + align_corners=self. 
+ upsample_cfg['align_corners']))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + with_out_act = False + else: + num_outchannels_conv3x3 = num_inchannels[j] + with_out_act = True + sub_modules = [ + build_conv_layer( + self.conv_cfg, + num_inchannels[j], + num_inchannels[j], + kernel_size=3, + stride=2, + padding=1, + groups=num_inchannels[j], + bias=False, + ), + build_norm_layer(self.norm_cfg, + num_inchannels[j])[1], + build_conv_layer( + self.conv_cfg, + num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=1, + stride=1, + bias=False, + ), + build_norm_layer(self.norm_cfg, + num_outchannels_conv3x3)[1] + ] + if with_out_act: + sub_modules.append(nn.ReLU(False)) + conv3x3s.append(nn.Sequential(*sub_modules)) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + """Return the number of input channels.""" + return self.in_channels + + +@BACKBONES.register_module() +class HRFormer(HRNet): + """HRFormer backbone. + + This backbone is the implementation of `HRFormer: High-Resolution + Transformer for Dense Prediction `_. + + Args: + extra (dict): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of block. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + in_channels (int): Number of input image channels. Normally 3. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Config of norm layer. + Use `SyncBN` by default. + transformer_norm_cfg (dict): Config of transformer norm layer. + Use `LN` by default. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. 
+ Example: + >>> from mmpose.models import HRFormer + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(2, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='HRFORMER', + >>> window_sizes=(7, 7), + >>> num_heads=(1, 2), + >>> mlp_ratios=(4, 4), + >>> num_blocks=(2, 2), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='HRFORMER', + >>> window_sizes=(7, 7, 7), + >>> num_heads=(1, 2, 4), + >>> mlp_ratios=(4, 4, 4), + >>> num_blocks=(2, 2, 2), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=2, + >>> num_branches=4, + >>> block='HRFORMER', + >>> window_sizes=(7, 7, 7, 7), + >>> num_heads=(1, 2, 4, 8), + >>> mlp_ratios=(4, 4, 4, 4), + >>> num_blocks=(2, 2, 2, 2), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRFormer(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BOTTLENECK': Bottleneck, 'HRFORMERBLOCK': HRFormerBlock} + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + transformer_norm_cfg=dict(type='LN', eps=1e-6), + norm_eval=False, + with_cp=False, + zero_init_residual=False, + frozen_stages=-1): + + # stochastic depth + depths = [ + extra[stage]['num_blocks'][0] * extra[stage]['num_modules'] + for stage in ['stage2', 'stage3', 'stage4'] + ] + depth_s2, depth_s3, _ = depths + drop_path_rate = extra['drop_path_rate'] + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + extra['stage2']['drop_path_rates'] = dpr[0:depth_s2] + extra['stage3']['drop_path_rates'] = dpr[depth_s2:depth_s2 + depth_s3] + extra['stage4']['drop_path_rates'] = dpr[depth_s2 + depth_s3:] + + # HRFormer use bilinear upsample as default + upsample_cfg = extra.get('upsample', { + 'mode': 'bilinear', + 'align_corners': False + }) + extra['upsample'] = upsample_cfg + self.transformer_norm_cfg = transformer_norm_cfg + self.with_rpe = extra.get('with_rpe', True) + self.with_pad_mask = extra.get('with_pad_mask', False) + + super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval, + with_cp, zero_init_residual, frozen_stages) + + def _make_stage(self, + layer_config, + num_inchannels, + multiscale_output=True): + """Make each stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + num_heads = layer_config['num_heads'] + num_window_sizes = layer_config['window_sizes'] + num_mlp_ratios = layer_config['mlp_ratios'] + drop_path_rates = layer_config['drop_path_rates'] + + modules = [] + for i in range(num_modules): + # multiscale_output is only used at the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + modules.append( + HRFomerModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + reset_multiscale_output, + drop_paths=drop_path_rates[num_blocks[0] * + i:num_blocks[0] * (i + 1)], + with_rpe=self.with_rpe, + 
with_pad_mask=self.with_pad_mask, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + transformer_norm_cfg=self.transformer_norm_cfg, + with_cp=self.with_cp, + upsample_cfg=self.upsample_cfg)) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/litehrnet.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/litehrnet.py new file mode 100644 index 0000000..9543688 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/litehrnet.py @@ -0,0 +1,984 @@ +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/HRNet/Lite-HRNet +# Original licence: Apache License 2.0. +# ------------------------------------------------------------------------------ + +import mmcv +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, + build_conv_layer, build_norm_layer, constant_init, + normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.utils import get_root_logger +from ..builder import BACKBONES +from .utils import channel_shuffle, load_checkpoint + + +class SpatialWeighting(nn.Module): + """Spatial weighting module. + + Args: + channels (int): The channels of the module. + ratio (int): channel reduction ratio. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + act_cfg (dict): Config dict for activation layer. + Default: (dict(type='ReLU'), dict(type='Sigmoid')). + The last ConvModule uses Sigmoid by default. + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + norm_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))): + super().__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmcv.is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out + + +class CrossResolutionWeighting(nn.Module): + """Cross-resolution channel weighting module. + + Args: + channels (int): The channels of the module. + ratio (int): channel reduction ratio. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + act_cfg (dict): Config dict for activation layer. + Default: (dict(type='ReLU'), dict(type='Sigmoid')). + The last ConvModule uses Sigmoid by default. 
+ """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + norm_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))): + super().__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmcv.is_tuple_of(act_cfg, dict) + self.channels = channels + total_channel = sum(channels) + self.conv1 = ConvModule( + in_channels=total_channel, + out_channels=int(total_channel / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(total_channel / ratio), + out_channels=total_channel, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + mini_size = x[-1].size()[-2:] + out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]] + out = torch.cat(out, dim=1) + out = self.conv1(out) + out = self.conv2(out) + out = torch.split(out, self.channels, dim=1) + out = [ + s * F.interpolate(a, size=s.size()[-2:], mode='nearest') + for s, a in zip(x, out) + ] + return out + + +class ConditionalChannelWeighting(nn.Module): + """Conditional channel weighting block. + + Args: + in_channels (int): The input channels of the block. + stride (int): Stride of the 3x3 convolution layer. + reduce_ratio (int): channel reduction ratio. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels, + stride, + reduce_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.stride = stride + assert stride in [1, 2] + + branch_channels = [channel // 2 for channel in in_channels] + + self.cross_resolution_weighting = CrossResolutionWeighting( + branch_channels, + ratio=reduce_ratio, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + self.depthwise_convs = nn.ModuleList([ + ConvModule( + channel, + channel, + kernel_size=3, + stride=self.stride, + padding=1, + groups=channel, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) for channel in branch_channels + ]) + + self.spatial_weighting = nn.ModuleList([ + SpatialWeighting(channels=channel, ratio=4) + for channel in branch_channels + ]) + + def forward(self, x): + + def _inner_forward(x): + x = [s.chunk(2, dim=1) for s in x] + x1 = [s[0] for s in x] + x2 = [s[1] for s in x] + + x2 = self.cross_resolution_weighting(x2) + x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)] + x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)] + + out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)] + out = [channel_shuffle(s, 2) for s in out] + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class Stem(nn.Module): + """Stem network block. + + Args: + in_channels (int): The input channels of the block. + stem_channels (int): Output channels of the stem layer. + out_channels (int): The output channels of the block. + expand_ratio (int): adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. 
+ Default: dict(type='BN'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels, + stem_channels, + out_channels, + expand_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=stem_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='ReLU')) + + mid_channels = int(round(stem_channels * expand_ratio)) + branch_channels = stem_channels // 2 + if stem_channels == self.out_channels: + inc_channels = self.out_channels - branch_channels + else: + inc_channels = self.out_channels - stem_channels + + self.branch1 = nn.Sequential( + ConvModule( + branch_channels, + branch_channels, + kernel_size=3, + stride=2, + padding=1, + groups=branch_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_channels, + inc_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')), + ) + + self.expand_conv = ConvModule( + branch_channels, + mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.depthwise_conv = ConvModule( + mid_channels, + mid_channels, + kernel_size=3, + stride=2, + padding=1, + groups=mid_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + self.linear_conv = ConvModule( + mid_channels, + branch_channels + if stem_channels == self.out_channels else stem_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + + def forward(self, x): + + def _inner_forward(x): + x = self.conv1(x) + x1, x2 = x.chunk(2, dim=1) + + x2 = self.expand_conv(x2) + x2 = self.depthwise_conv(x2) + x2 = self.linear_conv(x2) + + out = torch.cat((self.branch1(x1), x2), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class IterativeHead(nn.Module): + """Extra iterative head for feature learning. + + Args: + in_channels (int): The input channels of the block. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). 
+ """ + + def __init__(self, in_channels, norm_cfg=dict(type='BN')): + super().__init__() + projects = [] + num_branchs = len(in_channels) + self.in_channels = in_channels[::-1] + + for i in range(num_branchs): + if i != num_branchs - 1: + projects.append( + DepthwiseSeparableConvModule( + in_channels=self.in_channels[i], + out_channels=self.in_channels[i + 1], + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + dw_act_cfg=None, + pw_act_cfg=dict(type='ReLU'))) + else: + projects.append( + DepthwiseSeparableConvModule( + in_channels=self.in_channels[i], + out_channels=self.in_channels[i], + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + dw_act_cfg=None, + pw_act_cfg=dict(type='ReLU'))) + self.projects = nn.ModuleList(projects) + + def forward(self, x): + x = x[::-1] + + y = [] + last_x = None + for i, s in enumerate(x): + if last_x is not None: + last_x = F.interpolate( + last_x, + size=s.size()[-2:], + mode='bilinear', + align_corners=True) + s = s + last_x + s = self.projects[i](s) + y.append(s) + last_x = s + + return y[::-1] + + +class ShuffleUnit(nn.Module): + """InvertedResidual block for ShuffleNetV2 backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 convolution layer. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. 
+ """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super().__init__() + self.stride = stride + self.with_cp = with_cp + + branch_features = out_channels // 2 + if self.stride == 1: + assert in_channels == branch_features * 2, ( + f'in_channels ({in_channels}) should equal to ' + f'branch_features * 2 ({branch_features * 2}) ' + 'when stride is 1') + + if in_channels != branch_features * 2: + assert self.stride != 1, ( + f'stride ({self.stride}) should not equal 1 when ' + f'in_channels != branch_features * 2') + + if self.stride > 1: + self.branch1 = nn.Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=self.stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + + self.branch2 = nn.Sequential( + ConvModule( + in_channels if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + groups=branch_features, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + + def _inner_forward(x): + if self.stride > 1: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + else: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class LiteHRModule(nn.Module): + """High-Resolution Module for LiteHRNet. + + It contains conditional channel weighting blocks and + shuffle blocks. + + + Args: + num_branches (int): Number of branches in the module. + num_blocks (int): Number of blocks in the module. + in_channels (list(int)): Number of input image channels. + reduce_ratio (int): Channel reduction ratio. + module_type (str): 'LITE' or 'NAIVE' + multiscale_output (bool): Whether to output multi-scale features. + with_fuse (bool): Whether to use fuse layers. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ """ + + def __init__( + self, + num_branches, + num_blocks, + in_channels, + reduce_ratio, + module_type, + multiscale_output=False, + with_fuse=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False, + ): + super().__init__() + self._check_branches(num_branches, in_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.module_type = module_type + self.multiscale_output = multiscale_output + self.with_fuse = with_fuse + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + + if self.module_type.upper() == 'LITE': + self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio) + elif self.module_type.upper() == 'NAIVE': + self.layers = self._make_naive_branches(num_branches, num_blocks) + else: + raise ValueError("module_type should be either 'LITE' or 'NAIVE'.") + if self.with_fuse: + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU() + + def _check_branches(self, num_branches, in_channels): + """Check input to avoid ValueError.""" + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_INCHANNELS({len(in_channels)})' + raise ValueError(error_msg) + + def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1): + """Make channel weighting blocks.""" + layers = [] + for i in range(num_blocks): + layers.append( + ConditionalChannelWeighting( + self.in_channels, + stride=stride, + reduce_ratio=reduce_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + with_cp=self.with_cp)) + + return nn.Sequential(*layers) + + def _make_one_branch(self, branch_index, num_blocks, stride=1): + """Make one branch.""" + layers = [] + layers.append( + ShuffleUnit( + self.in_channels[branch_index], + self.in_channels[branch_index], + stride=stride, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='ReLU'), + with_cp=self.with_cp)) + for i in range(1, num_blocks): + layers.append( + ShuffleUnit( + self.in_channels[branch_index], + self.in_channels[branch_index], + stride=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='ReLU'), + with_cp=self.with_cp)) + + return nn.Sequential(*layers) + + def _make_naive_branches(self, num_branches, num_blocks): + """Make branches.""" + branches = [] + + for i in range(num_branches): + branches.append(self._make_one_branch(i, num_blocks)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + """Make fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + groups=in_channels[j], + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + 
bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + groups=in_channels[j], + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.layers[0](x[0])] + + if self.module_type.upper() == 'LITE': + out = self.layers(x) + elif self.module_type.upper() == 'NAIVE': + for i in range(self.num_branches): + x[i] = self.layers[i](x[i]) + out = x + + if self.with_fuse: + out_fuse = [] + for i in range(len(self.fuse_layers)): + # `y = 0` will lead to decreased accuracy (0.5~1 mAP) + y = out[0] if i == 0 else self.fuse_layers[i][0](out[0]) + for j in range(self.num_branches): + if i == j: + y += out[j] + else: + y += self.fuse_layers[i][j](out[j]) + out_fuse.append(self.relu(y)) + out = out_fuse + if not self.multiscale_output: + out = [out[0]] + return out + + +@BACKBONES.register_module() +class LiteHRNet(nn.Module): + """Lite-HRNet backbone. + + `Lite-HRNet: A Lightweight High-Resolution Network + `_. + + Code adapted from 'https://github.com/HRNet/Lite-HRNet'. + + Args: + extra (dict): detailed configuration for each stage of HRNet. + in_channels (int): Number of input image channels. Default: 3. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + + Example: + >>> from mmpose.models import LiteHRNet + >>> import torch + >>> extra=dict( + >>> stem=dict(stem_channels=32, out_channels=32, expand_ratio=1), + >>> num_stages=3, + >>> stages_spec=dict( + >>> num_modules=(2, 4, 2), + >>> num_branches=(2, 3, 4), + >>> num_blocks=(2, 2, 2), + >>> module_type=('LITE', 'LITE', 'LITE'), + >>> with_fuse=(True, True, True), + >>> reduce_ratios=(8, 8, 8), + >>> num_channels=( + >>> (40, 80), + >>> (40, 80, 160), + >>> (40, 80, 160, 320), + >>> )), + >>> with_head=False) + >>> self = LiteHRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 40, 8, 8) + """ + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + norm_eval=False, + with_cp=False): + super().__init__() + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.stem = Stem( + in_channels, + stem_channels=self.extra['stem']['stem_channels'], + out_channels=self.extra['stem']['out_channels'], + expand_ratio=self.extra['stem']['expand_ratio'], + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + + self.num_stages = self.extra['num_stages'] + self.stages_spec = self.extra['stages_spec'] + + num_channels_last = [ + self.stem.out_channels, + ] + for i in range(self.num_stages): + num_channels = self.stages_spec['num_channels'][i] + num_channels = [num_channels[i] for i in range(len(num_channels))] + setattr( + self, f'transition{i}', + self._make_transition_layer(num_channels_last, num_channels)) + + stage, num_channels_last = self._make_stage( + self.stages_spec, i, num_channels, multiscale_output=True) + setattr(self, f'stage{i}', stage) + + self.with_head = self.extra['with_head'] + if self.with_head: + self.head_layer = IterativeHead( + in_channels=num_channels_last, + norm_cfg=self.norm_cfg, + ) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_pre_layer[i], + kernel_size=3, + stride=1, + padding=1, + groups=num_channels_pre_layer[i], + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_pre_layer[i])[1], + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU())) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=1, + groups=in_channels, + bias=False), + build_norm_layer(self.norm_cfg, in_channels)[1], + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU())) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_stage(self, + stages_spec, + stage_index, + in_channels, + multiscale_output=True): + num_modules = stages_spec['num_modules'][stage_index] + num_branches = stages_spec['num_branches'][stage_index] + num_blocks = stages_spec['num_blocks'][stage_index] + reduce_ratio = stages_spec['reduce_ratios'][stage_index] + with_fuse = stages_spec['with_fuse'][stage_index] + module_type = stages_spec['module_type'][stage_index] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multiscale_output and i == 
num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + modules.append( + LiteHRModule( + num_branches, + num_blocks, + in_channels, + reduce_ratio, + module_type, + multiscale_output=reset_multiscale_output, + with_fuse=with_fuse, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + with_cp=self.with_cp)) + in_channels = modules[-1].in_channels + + return nn.Sequential(*modules), in_channels + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.stem(x) + + y_list = [x] + for i in range(self.num_stages): + x_list = [] + transition = getattr(self, f'transition{i}') + for j in range(self.stages_spec['num_branches'][i]): + if transition[j]: + if j >= len(y_list): + x_list.append(transition[j](y_list[-1])) + else: + x_list.append(transition[j](y_list[j])) + else: + x_list.append(y_list[j]) + y_list = getattr(self, f'stage{i}')(x_list) + + x = y_list + if self.with_head: + x = self.head_layer(x) + + return [x[0]] + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v2.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v2.py new file mode 100644 index 0000000..5dc0cd1 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v2.py @@ -0,0 +1,275 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, constant_init, kaiming_init +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .utils import load_checkpoint, make_divisible + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. 
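+
+    Example (illustrative sketch; channel and spatial sizes are arbitrary
+        and only demonstrate the stride-1 residual case):
+        >>> import torch
+        >>> block = InvertedResidual(32, 32, stride=1, expand_ratio=6)
+        >>> x = torch.rand(1, 32, 16, 16)
+        >>> tuple(block(x).shape)
+        (1, 32, 16, 16)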
+ """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=1, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@BACKBONES.register_module() +class MobileNetV2(BaseBackbone): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks, stride. + arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], + [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], + [6, 320, 1, 1]] + + def __init__(self, + widen_factor=1., + out_indices=(7, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.widen_factor = widen_factor + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 8): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). 
' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks, stride = layer_cfg + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + if widen_factor > 1.0: + self.out_channel = int(1280 * widen_factor) + else: + self.out_channel = 1280 + + layer = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channel, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.add_module('conv2', layer) + self.layers.append('conv2') + + def make_layer(self, out_channels, num_blocks, stride, expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. Default: 6. + """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride, + expand_ratio=expand_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v3.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v3.py new file mode 100644 index 0000000..d640abe --- /dev/null +++ 
b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v3.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging + +import torch.nn as nn +from mmcv.cnn import ConvModule, constant_init, kaiming_init +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .utils import InvertedResidual, load_checkpoint + + +@BACKBONES.register_module() +class MobileNetV3(BaseBackbone): + """MobileNetV3 backbone. + + Args: + arch (str): Architecture of mobilnetv3, from {small, big}. + Default: small. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (None or Sequence[int]): Output from which stages. + Default: (-1, ), which means output tensors from final stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'big': [[3, 16, 16, False, 'ReLU', 1], + [3, 64, 24, False, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(-1, ), + frozen_stages=-1, + norm_eval=False, + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + assert arch in self.arch_settings + for index in out_indices: + if index not in range(-len(self.arch_settings[arch]), + len(self.arch_settings[arch])): + raise ValueError('the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch])): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])}). 
' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = 16 + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='HSwish')) + + self.layers = self._make_layer() + self.feat_dim = self.arch_settings[arch][-1][2] + + def _make_layer(self): + layers = [] + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), dict(type='HSigmoid'))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=self.in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=True, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + self.in_channels = out_channels + layer_name = f'layer{i + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + return layers + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices or \ + i - len(self.layers) in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mspn.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mspn.py new file mode 100644 index 0000000..71cee34 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mspn.py @@ -0,0 +1,513 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy as cp +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init, + normal_init) +from mmcv.runner.checkpoint import load_state_dict + +from mmpose.utils import get_root_logger +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .resnet import Bottleneck as _Bottleneck +from .utils.utils import get_state_dict + + +class Bottleneck(_Bottleneck): + expansion = 4 + """Bottleneck block for MSPN. + + Args: + in_channels (int): Input channels of this block. 
+ out_channels (int): Output channels of this block. + stride (int): stride of the block. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + """ + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels * 4, **kwargs) + + +class DownsampleModule(nn.Module): + """Downsample module for MSPN. + + Args: + block (nn.Module): Downsample block. + num_blocks (list): Number of blocks in each downsample unit. + num_units (int): Numbers of downsample units. Default: 4 + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the input feature to + downsample module. Default: 64 + """ + + def __init__(self, + block, + num_blocks, + num_units=4, + has_skip=False, + norm_cfg=dict(type='BN'), + in_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.has_skip = has_skip + self.in_channels = in_channels + assert len(num_blocks) == num_units + self.num_blocks = num_blocks + self.num_units = num_units + self.norm_cfg = norm_cfg + self.layer1 = self._make_layer(block, in_channels, num_blocks[0]) + for i in range(1, num_units): + module_name = f'layer{i + 1}' + self.add_module( + module_name, + self._make_layer( + block, in_channels * pow(2, i), num_blocks[i], stride=2)) + + def _make_layer(self, block, out_channels, blocks, stride=1): + downsample = None + if stride != 1 or self.in_channels != out_channels * block.expansion: + downsample = ConvModule( + self.in_channels, + out_channels * block.expansion, + kernel_size=1, + stride=stride, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + units = list() + units.append( + block( + self.in_channels, + out_channels, + stride=stride, + downsample=downsample, + norm_cfg=self.norm_cfg)) + self.in_channels = out_channels * block.expansion + for _ in range(1, blocks): + units.append(block(self.in_channels, out_channels)) + + return nn.Sequential(*units) + + def forward(self, x, skip1, skip2): + out = list() + for i in range(self.num_units): + module_name = f'layer{i + 1}' + module_i = getattr(self, module_name) + x = module_i(x) + if self.has_skip: + x = x + skip1[i] + skip2[i] + out.append(x) + out.reverse() + + return tuple(out) + + +class UpsampleUnit(nn.Module): + """Upsample unit for upsample module. + + Args: + ind (int): Indicates whether to interpolate (>0) and whether to + generate feature map for the next hourglass-like module. + num_units (int): Number of units that form a upsample module. Along + with ind and gen_cross_conv, nm_units is used to decide whether + to generate feature map for the next hourglass-like module. + in_channels (int): Channel number of the skip-in feature maps from + the corresponding downsample unit. + unit_channels (int): Channel number in this unit. Default:256. + gen_skip: (bool): Whether or not to generate skips for the posterior + downsample module. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (int): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. 
Default:64 + """ + + def __init__(self, + ind, + num_units, + in_channels, + unit_channels=256, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.num_units = num_units + self.norm_cfg = norm_cfg + self.in_skip = ConvModule( + in_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + self.relu = nn.ReLU(inplace=True) + + self.ind = ind + if self.ind > 0: + self.up_conv = ConvModule( + unit_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + self.gen_skip = gen_skip + if self.gen_skip: + self.out_skip1 = ConvModule( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.out_skip2 = ConvModule( + unit_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.gen_cross_conv = gen_cross_conv + if self.ind == num_units - 1 and self.gen_cross_conv: + self.cross_conv = ConvModule( + unit_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + def forward(self, x, up_x): + out = self.in_skip(x) + + if self.ind > 0: + up_x = F.interpolate( + up_x, + size=(x.size(2), x.size(3)), + mode='bilinear', + align_corners=True) + up_x = self.up_conv(up_x) + out = out + up_x + out = self.relu(out) + + skip1 = None + skip2 = None + if self.gen_skip: + skip1 = self.out_skip1(x) + skip2 = self.out_skip2(out) + + cross_conv = None + if self.ind == self.num_units - 1 and self.gen_cross_conv: + cross_conv = self.cross_conv(out) + + return out, skip1, skip2, cross_conv + + +class UpsampleModule(nn.Module): + """Upsample module for MSPN. + + Args: + unit_channels (int): Channel number in the upsample units. + Default:256. + num_units (int): Numbers of upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (int): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. 
Default:64 + """ + + def __init__(self, + unit_channels=256, + num_units=4, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.in_channels = list() + for i in range(num_units): + self.in_channels.append(Bottleneck.expansion * out_channels * + pow(2, i)) + self.in_channels.reverse() + self.num_units = num_units + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.norm_cfg = norm_cfg + for i in range(num_units): + module_name = f'up{i + 1}' + self.add_module( + module_name, + UpsampleUnit( + i, + self.num_units, + self.in_channels[i], + unit_channels, + self.gen_skip, + self.gen_cross_conv, + norm_cfg=self.norm_cfg, + out_channels=64)) + + def forward(self, x): + out = list() + skip1 = list() + skip2 = list() + cross_conv = None + for i in range(self.num_units): + module_i = getattr(self, f'up{i + 1}') + if i == 0: + outi, skip1_i, skip2_i, _ = module_i(x[i], None) + elif i == self.num_units - 1: + outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1]) + else: + outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1]) + out.append(outi) + skip1.append(skip1_i) + skip2.append(skip2_i) + skip1.reverse() + skip2.reverse() + + return out, skip1, skip2, cross_conv + + +class SingleStageNetwork(nn.Module): + """Single_stage Network. + + Args: + unit_channels (int): Channel number in the upsample units. Default:256. + num_units (int): Numbers of downsample/upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + num_blocks (list): Number of blocks in each downsample unit. + Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks) + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the feature from ResNetTop. + Default: 64. + """ + + def __init__(self, + has_skip=False, + gen_skip=False, + gen_cross_conv=False, + unit_channels=256, + num_units=4, + num_blocks=[2, 2, 2, 2], + norm_cfg=dict(type='BN'), + in_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__() + assert len(num_blocks) == num_units + self.has_skip = has_skip + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.num_units = num_units + self.unit_channels = unit_channels + self.num_blocks = num_blocks + self.norm_cfg = norm_cfg + + self.downsample = DownsampleModule(Bottleneck, num_blocks, num_units, + has_skip, norm_cfg, in_channels) + self.upsample = UpsampleModule(unit_channels, num_units, gen_skip, + gen_cross_conv, norm_cfg, in_channels) + + def forward(self, x, skip1, skip2): + mid = self.downsample(x, skip1, skip2) + out, skip1, skip2, cross_conv = self.upsample(mid) + + return out, skip1, skip2, cross_conv + + +class ResNetTop(nn.Module): + """ResNet top for MSPN. + + Args: + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + channels (int): Number of channels of the feature output by ResNetTop. 
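+
+    Example (illustrative sketch; a 256x256 input is assumed only to show
+        the 4x spatial reduction of the stem):
+        >>> import torch
+        >>> top = ResNetTop()
+        >>> tuple(top(torch.rand(1, 3, 256, 256)).shape)
+        (1, 64, 64, 64)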
+ """ + + def __init__(self, norm_cfg=dict(type='BN'), channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.top = nn.Sequential( + ConvModule( + 3, + channels, + kernel_size=7, + stride=2, + padding=3, + norm_cfg=norm_cfg, + inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1)) + + def forward(self, img): + return self.top(img) + + +@BACKBONES.register_module() +class MSPN(BaseBackbone): + """MSPN backbone. Paper ref: Li et al. "Rethinking on Multi-Stage Networks + for Human Pose Estimation" (CVPR 2020). + + Args: + unit_channels (int): Number of Channels in an upsample unit. + Default: 256 + num_stages (int): Number of stages in a multi-stage MSPN. Default: 4 + num_units (int): Number of downsample/upsample units in a single-stage + network. Default: 4 + Note: Make sure num_units == len(self.num_blocks) + num_blocks (list): Number of bottlenecks in each + downsample unit. Default: [2, 2, 2, 2] + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + res_top_channels (int): Number of channels of feature from ResNetTop. + Default: 64. + + Example: + >>> from mmpose.models import MSPN + >>> import torch + >>> self = MSPN(num_stages=2,num_units=2,num_blocks=[2,2]) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 511, 511) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... for feature in level_output: + ... print(tuple(feature.shape)) + ... + (1, 256, 64, 64) + (1, 256, 128, 128) + (1, 256, 64, 64) + (1, 256, 128, 128) + """ + + def __init__(self, + unit_channels=256, + num_stages=4, + num_units=4, + num_blocks=[2, 2, 2, 2], + norm_cfg=dict(type='BN'), + res_top_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__() + self.unit_channels = unit_channels + self.num_stages = num_stages + self.num_units = num_units + self.num_blocks = num_blocks + self.norm_cfg = norm_cfg + + assert self.num_stages > 0 + assert self.num_units > 1 + assert self.num_units == len(self.num_blocks) + self.top = ResNetTop(norm_cfg=norm_cfg) + self.multi_stage_mspn = nn.ModuleList([]) + for i in range(self.num_stages): + if i == 0: + has_skip = False + else: + has_skip = True + if i != self.num_stages - 1: + gen_skip = True + gen_cross_conv = True + else: + gen_skip = False + gen_cross_conv = False + self.multi_stage_mspn.append( + SingleStageNetwork(has_skip, gen_skip, gen_cross_conv, + unit_channels, num_units, num_blocks, + norm_cfg, res_top_channels)) + + def forward(self, x): + """Model forward function.""" + out_feats = [] + skip1 = None + skip2 = None + x = self.top(x) + for i in range(self.num_stages): + out, skip1, skip2, x = self.multi_stage_mspn[i](x, skip1, skip2) + out_feats.append(out) + + return out_feats + + def init_weights(self, pretrained=None): + """Initialize model weights.""" + if isinstance(pretrained, str): + logger = get_root_logger() + state_dict_tmp = get_state_dict(pretrained) + state_dict = OrderedDict() + state_dict['top'] = OrderedDict() + state_dict['bottlenecks'] = OrderedDict() + for k, v in state_dict_tmp.items(): + if k.startswith('layer'): + if 'downsample.0' in k: + state_dict['bottlenecks'][k.replace( + 'downsample.0', 'downsample.conv')] = v + elif 'downsample.1' in k: + state_dict['bottlenecks'][k.replace( + 'downsample.1', 'downsample.bn')] = v + else: + state_dict['bottlenecks'][k] = v + elif k.startswith('conv1'): + 
state_dict['top'][k.replace('conv1', 'top.0.conv')] = v + elif k.startswith('bn1'): + state_dict['top'][k.replace('bn1', 'top.0.bn')] = v + + load_state_dict( + self.top, state_dict['top'], strict=False, logger=logger) + for i in range(self.num_stages): + load_state_dict( + self.multi_stage_mspn[i].downsample, + state_dict['bottlenecks'], + strict=False, + logger=logger) + else: + for m in self.multi_stage_mspn.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + elif isinstance(m, nn.Linear): + normal_init(m, std=0.01) + + for m in self.top.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/regnet.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/regnet.py new file mode 100644 index 0000000..693417c --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/regnet.py @@ -0,0 +1,317 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import numpy as np +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from .resnet import ResNet +from .resnext import Bottleneck + + +@BACKBONES.register_module() +class RegNet(ResNet): + """RegNet backbone. + + More details can be found in `paper `__ . + + Args: + arch (dict): The parameter of RegNets. + - w0 (int): initial width + - wa (float): slope of width + - wm (float): quantization parameter to quantize the width + - depth (int): depth of the backbone + - group_w (int): width of group + - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck. + strides (Sequence[int]): Strides of the first block of each stage. + base_channels (int): Base channels after stem layer. + in_channels (int): Number of input image channels. Default: 3. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: "pytorch". + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Default: -1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpose.models import RegNet + >>> import torch + >>> self = RegNet( + arch=dict( + w0=88, + wa=26.31, + wm=2.25, + group_w=48, + depth=25, + bot_mul=1.0), + out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 96, 8, 8) + (1, 192, 4, 4) + (1, 432, 2, 2) + (1, 1008, 1, 1) + """ + arch_settings = { + 'regnetx_400mf': + dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0), + 'regnetx_800mf': + dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0), + 'regnetx_1.6gf': + dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0), + 'regnetx_3.2gf': + dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0), + 'regnetx_4.0gf': + dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0), + 'regnetx_6.4gf': + dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0), + 'regnetx_8.0gf': + dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0), + 'regnetx_12gf': + dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0), + } + + def __init__(self, + arch, + in_channels=3, + stem_channels=32, + base_channels=32, + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super(ResNet, self).__init__() + + # Generate RegNet parameters first + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the' \ + ' arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + widths, num_stages = self.generate_regnet( + arch['w0'], + arch['wa'], + arch['wm'], + arch['depth'], + ) + # Convert to per stage format + stage_widths, stage_blocks = self.get_stages_from_blocks(widths) + # Generate group widths and bot muls + group_widths = [arch['group_w'] for _ in range(num_stages)] + self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)] + # Adjust the compatibility of stage_widths and group_widths + stage_widths, group_widths = self.adjust_width_group( + stage_widths, self.bottleneck_ratio, group_widths) + + # Group params by stage + self.stage_widths = stage_widths + self.group_widths = group_widths + self.depth = sum(stage_blocks) + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert 1 <= num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + if self.deep_stem: + raise NotImplementedError( + 'deep_stem has not been implemented for RegNet') + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.stage_blocks = stage_blocks[:num_stages] + + self._make_stem_layer(in_channels, stem_channels) + + _in_channels = stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + group_width = self.group_widths[i] + width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i])) + stage_groups = width // group_width + + res_layer = self.make_res_layer( + block=Bottleneck, + num_blocks=num_blocks, + in_channels=_in_channels, + 
out_channels=self.stage_widths[i], + expansion=1, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + base_channels=self.stage_widths[i], + groups=stage_groups, + width_per_group=group_width) + _in_channels = self.stage_widths[i] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = stage_widths[-1] + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + @staticmethod + def generate_regnet(initial_width, + width_slope, + width_parameter, + depth, + divisor=8): + """Generates per block width from RegNet parameters. + + Args: + initial_width ([int]): Initial width of the backbone + width_slope ([float]): Slope of the quantized linear function + width_parameter ([int]): Parameter used to quantize the width. + depth ([int]): Depth of the backbone. + divisor (int, optional): The divisor of channels. Defaults to 8. + + Returns: + list, int: return a list of widths of each stage and the number of + stages + """ + assert width_slope >= 0 + assert initial_width > 0 + assert width_parameter > 1 + assert initial_width % divisor == 0 + widths_cont = np.arange(depth) * width_slope + initial_width + ks = np.round( + np.log(widths_cont / initial_width) / np.log(width_parameter)) + widths = initial_width * np.power(width_parameter, ks) + widths = np.round(np.divide(widths, divisor)) * divisor + num_stages = len(np.unique(widths)) + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages + + @staticmethod + def quantize_float(number, divisor): + """Converts a float to closest non-zero int divisible by divior. + + Args: + number (int): Original number to be quantized. + divisor (int): Divisor used to quantize the number. + + Returns: + int: quantized number that is divisible by devisor. + """ + return int(round(number / divisor) * divisor) + + def adjust_width_group(self, widths, bottleneck_ratio, groups): + """Adjusts the compatibility of widths and groups. + + Args: + widths (list[int]): Width of each stage. + bottleneck_ratio (float): Bottleneck ratio. + groups (int): number of groups in each stage + + Returns: + tuple(list): The adjusted widths and groups of each stage. + """ + bottleneck_width = [ + int(w * b) for w, b in zip(widths, bottleneck_ratio) + ] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)] + bottleneck_width = [ + self.quantize_float(w_bot, g) + for w_bot, g in zip(bottleneck_width, groups) + ] + widths = [ + int(w_bot / b) + for w_bot, b in zip(bottleneck_width, bottleneck_ratio) + ] + return widths, groups + + def get_stages_from_blocks(self, widths): + """Gets widths/stage_blocks of network at each stage. + + Args: + widths (list[int]): Width in each stage. 
+ + Returns: + tuple(list): width and depth of each stage + """ + width_diff = [ + width != width_prev + for width, width_prev in zip(widths + [0], [0] + widths) + ] + stage_widths = [ + width for width, diff in zip(widths, width_diff[:-1]) if diff + ] + stage_blocks = np.diff([ + depth for depth, diff in zip(range(len(width_diff)), width_diff) + if diff + ]).tolist() + return stage_widths, stage_blocks + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + return tuple(outs) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnest.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnest.py new file mode 100644 index 0000000..0a2d408 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnest.py @@ -0,0 +1,338 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. 
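+
+    Example (illustrative sketch; 64 channels and a 16x16 input are
+        arbitrary, and ``eval()`` is called because the attention branch
+        applies BatchNorm to a 1x1 pooled map):
+        >>> import torch
+        >>> conv = SplitAttentionConv2d(64, 64, kernel_size=3, padding=1)
+        >>> conv.eval()
+        >>> tuple(conv(torch.rand(1, 64, 16, 16)).shape)
+        (1, 64, 16, 16)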
+ """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN')): + super().__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + return getattr(self, self.norm0_name) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ """ + + def __init__(self, + in_channels, + out_channels, + groups=1, + width_per_group=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = SplitAttentionConv2d( + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Please refer to the `paper `__ + for details. + + Args: + depth (int): Network depth, from {50, 101, 152, 200}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. 
If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)), + 269: (Bottleneck, (3, 30, 48, 8)) + } + + def __init__(self, + depth, + groups=1, + width_per_group=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.width_per_group = width_per_group + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super().__init__(depth=depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnext.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnext.py new file mode 100644 index 0000000..c10dc33 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnext.py @@ -0,0 +1,162 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@BACKBONES.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. 
+ + Example: + >>> from mmpose.models import ResNeXt + >>> import torch + >>> self = ResNeXt(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super().__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/rsn.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/rsn.py new file mode 100644 index 0000000..29038af --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/rsn.py @@ -0,0 +1,616 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy as cp + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init, + normal_init) + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +class RSB(nn.Module): + """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate + Local Representations for Multi-Person Pose Estimation" (ECCV 2020). + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + num_steps (int): Numbers of steps in RSB + stride (int): stride of the block. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + expand_times (int): Times by which the in_channels are expanded. + Default:26. + res_top_channels (int): Number of channels of feature output by + ResNet_top. Default:64. 
+ """ + + expansion = 1 + + def __init__(self, + in_channels, + out_channels, + num_steps=4, + stride=1, + downsample=None, + with_cp=False, + norm_cfg=dict(type='BN'), + expand_times=26, + res_top_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + assert num_steps > 1 + self.in_channels = in_channels + self.branch_channels = self.in_channels * expand_times + self.branch_channels //= res_top_channels + self.out_channels = out_channels + self.stride = stride + self.downsample = downsample + self.with_cp = with_cp + self.norm_cfg = norm_cfg + self.num_steps = num_steps + self.conv_bn_relu1 = ConvModule( + self.in_channels, + self.num_steps * self.branch_channels, + kernel_size=1, + stride=self.stride, + padding=0, + norm_cfg=self.norm_cfg, + inplace=False) + for i in range(self.num_steps): + for j in range(i + 1): + module_name = f'conv_bn_relu2_{i + 1}_{j + 1}' + self.add_module( + module_name, + ConvModule( + self.branch_channels, + self.branch_channels, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg, + inplace=False)) + self.conv_bn3 = ConvModule( + self.num_steps * self.branch_channels, + self.out_channels * self.expansion, + kernel_size=1, + stride=1, + padding=0, + act_cfg=None, + norm_cfg=self.norm_cfg, + inplace=False) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + """Forward function.""" + + identity = x + x = self.conv_bn_relu1(x) + spx = torch.split(x, self.branch_channels, 1) + outputs = list() + outs = list() + for i in range(self.num_steps): + outputs_i = list() + outputs.append(outputs_i) + for j in range(i + 1): + if j == 0: + inputs = spx[i] + else: + inputs = outputs[i][j - 1] + if i > j: + inputs = inputs + outputs[i - 1][j] + module_name = f'conv_bn_relu2_{i + 1}_{j + 1}' + module_i_j = getattr(self, module_name) + outputs[i].append(module_i_j(inputs)) + + outs.append(outputs[i][i]) + out = torch.cat(tuple(outs), 1) + out = self.conv_bn3(out) + + if self.downsample is not None: + identity = self.downsample(identity) + out = out + identity + + out = self.relu(out) + + return out + + +class Downsample_module(nn.Module): + """Downsample module for RSN. + + Args: + block (nn.Module): Downsample block. + num_blocks (list): Number of blocks in each downsample unit. + num_units (int): Numbers of downsample units. Default: 4 + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + num_steps (int): Number of steps in a block. Default:4 + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the input feature to + downsample module. Default: 64 + expand_times (int): Times by which the in_channels are expanded. + Default:26. 
+ """ + + def __init__(self, + block, + num_blocks, + num_steps=4, + num_units=4, + has_skip=False, + norm_cfg=dict(type='BN'), + in_channels=64, + expand_times=26): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.has_skip = has_skip + self.in_channels = in_channels + assert len(num_blocks) == num_units + self.num_blocks = num_blocks + self.num_units = num_units + self.num_steps = num_steps + self.norm_cfg = norm_cfg + self.layer1 = self._make_layer( + block, + in_channels, + num_blocks[0], + expand_times=expand_times, + res_top_channels=in_channels) + for i in range(1, num_units): + module_name = f'layer{i + 1}' + self.add_module( + module_name, + self._make_layer( + block, + in_channels * pow(2, i), + num_blocks[i], + stride=2, + expand_times=expand_times, + res_top_channels=in_channels)) + + def _make_layer(self, + block, + out_channels, + blocks, + stride=1, + expand_times=26, + res_top_channels=64): + downsample = None + if stride != 1 or self.in_channels != out_channels * block.expansion: + downsample = ConvModule( + self.in_channels, + out_channels * block.expansion, + kernel_size=1, + stride=stride, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + units = list() + units.append( + block( + self.in_channels, + out_channels, + num_steps=self.num_steps, + stride=stride, + downsample=downsample, + norm_cfg=self.norm_cfg, + expand_times=expand_times, + res_top_channels=res_top_channels)) + self.in_channels = out_channels * block.expansion + for _ in range(1, blocks): + units.append( + block( + self.in_channels, + out_channels, + num_steps=self.num_steps, + expand_times=expand_times, + res_top_channels=res_top_channels)) + + return nn.Sequential(*units) + + def forward(self, x, skip1, skip2): + out = list() + for i in range(self.num_units): + module_name = f'layer{i + 1}' + module_i = getattr(self, module_name) + x = module_i(x) + if self.has_skip: + x = x + skip1[i] + skip2[i] + out.append(x) + out.reverse() + + return tuple(out) + + +class Upsample_unit(nn.Module): + """Upsample unit for upsample module. + + Args: + ind (int): Indicates whether to interpolate (>0) and whether to + generate feature map for the next hourglass-like module. + num_units (int): Number of units that form a upsample module. Along + with ind and gen_cross_conv, nm_units is used to decide whether + to generate feature map for the next hourglass-like module. + in_channels (int): Channel number of the skip-in feature maps from + the corresponding downsample unit. + unit_channels (int): Channel number in this unit. Default:256. + gen_skip: (bool): Whether or not to generate skips for the posterior + downsample module. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (in): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. 
Default:64 + """ + + def __init__(self, + ind, + num_units, + in_channels, + unit_channels=256, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.num_units = num_units + self.norm_cfg = norm_cfg + self.in_skip = ConvModule( + in_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + self.relu = nn.ReLU(inplace=True) + + self.ind = ind + if self.ind > 0: + self.up_conv = ConvModule( + unit_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + self.gen_skip = gen_skip + if self.gen_skip: + self.out_skip1 = ConvModule( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.out_skip2 = ConvModule( + unit_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.gen_cross_conv = gen_cross_conv + if self.ind == num_units - 1 and self.gen_cross_conv: + self.cross_conv = ConvModule( + unit_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + def forward(self, x, up_x): + out = self.in_skip(x) + + if self.ind > 0: + up_x = F.interpolate( + up_x, + size=(x.size(2), x.size(3)), + mode='bilinear', + align_corners=True) + up_x = self.up_conv(up_x) + out = out + up_x + out = self.relu(out) + + skip1 = None + skip2 = None + if self.gen_skip: + skip1 = self.out_skip1(x) + skip2 = self.out_skip2(out) + + cross_conv = None + if self.ind == self.num_units - 1 and self.gen_cross_conv: + cross_conv = self.cross_conv(out) + + return out, skip1, skip2, cross_conv + + +class Upsample_module(nn.Module): + """Upsample module for RSN. + + Args: + unit_channels (int): Channel number in the upsample units. + Default:256. + num_units (int): Numbers of upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (int): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. 
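Editor's note: each Upsample_unit above fuses a lateral skip feature with the bilinearly resized output of the previous, coarser unit. A standalone sketch of just that fusion, with plain nn.Conv2d standing in for ConvModule and made-up channel sizes (illustrative only):

import torch
import torch.nn as nn
import torch.nn.functional as F

in_skip = nn.Conv2d(128, 256, 1)
up_conv = nn.Conv2d(256, 256, 1)

x = torch.randn(1, 128, 32, 32)       # skip feature from the downsample path
up_x = torch.randn(1, 256, 16, 16)    # output of the previous upsample unit

out = in_skip(x)
out = out + up_conv(F.interpolate(up_x, size=x.shape[2:], mode='bilinear',
                                  align_corners=True))
out = F.relu(out)
print(out.shape)                      # torch.Size([1, 256, 32, 32])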
Default:64 + """ + + def __init__(self, + unit_channels=256, + num_units=4, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.in_channels = list() + for i in range(num_units): + self.in_channels.append(RSB.expansion * out_channels * pow(2, i)) + self.in_channels.reverse() + self.num_units = num_units + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.norm_cfg = norm_cfg + for i in range(num_units): + module_name = f'up{i + 1}' + self.add_module( + module_name, + Upsample_unit( + i, + self.num_units, + self.in_channels[i], + unit_channels, + self.gen_skip, + self.gen_cross_conv, + norm_cfg=self.norm_cfg, + out_channels=64)) + + def forward(self, x): + out = list() + skip1 = list() + skip2 = list() + cross_conv = None + for i in range(self.num_units): + module_i = getattr(self, f'up{i + 1}') + if i == 0: + outi, skip1_i, skip2_i, _ = module_i(x[i], None) + elif i == self.num_units - 1: + outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1]) + else: + outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1]) + out.append(outi) + skip1.append(skip1_i) + skip2.append(skip2_i) + skip1.reverse() + skip2.reverse() + + return out, skip1, skip2, cross_conv + + +class Single_stage_RSN(nn.Module): + """Single_stage Residual Steps Network. + + Args: + unit_channels (int): Channel number in the upsample units. Default:256. + num_units (int): Numbers of downsample/upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + num_steps (int): Number of steps in RSB. Default: 4 + num_blocks (list): Number of blocks in each downsample unit. + Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks) + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the feature from ResNet_Top. + Default: 64. + expand_times (int): Times by which the in_channels are expanded in RSB. + Default:26. + """ + + def __init__(self, + has_skip=False, + gen_skip=False, + gen_cross_conv=False, + unit_channels=256, + num_units=4, + num_steps=4, + num_blocks=[2, 2, 2, 2], + norm_cfg=dict(type='BN'), + in_channels=64, + expand_times=26): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__() + assert len(num_blocks) == num_units + self.has_skip = has_skip + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.num_units = num_units + self.num_steps = num_steps + self.unit_channels = unit_channels + self.num_blocks = num_blocks + self.norm_cfg = norm_cfg + + self.downsample = Downsample_module(RSB, num_blocks, num_steps, + num_units, has_skip, norm_cfg, + in_channels, expand_times) + self.upsample = Upsample_module(unit_channels, num_units, gen_skip, + gen_cross_conv, norm_cfg, in_channels) + + def forward(self, x, skip1, skip2): + mid = self.downsample(x, skip1, skip2) + out, skip1, skip2, cross_conv = self.upsample(mid) + + return out, skip1, skip2, cross_conv + + +class ResNet_top(nn.Module): + """ResNet top for RSN. + + Args: + norm_cfg (dict): dictionary to construct and config norm layer. 
+ Default: dict(type='BN') + channels (int): Number of channels of the feature output by ResNet_top. + """ + + def __init__(self, norm_cfg=dict(type='BN'), channels=64): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.top = nn.Sequential( + ConvModule( + 3, + channels, + kernel_size=7, + stride=2, + padding=3, + norm_cfg=norm_cfg, + inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1)) + + def forward(self, img): + return self.top(img) + + +@BACKBONES.register_module() +class RSN(BaseBackbone): + """Residual Steps Network backbone. Paper ref: Cai et al. "Learning + Delicate Local Representations for Multi-Person Pose Estimation" (ECCV + 2020). + + Args: + unit_channels (int): Number of Channels in an upsample unit. + Default: 256 + num_stages (int): Number of stages in a multi-stage RSN. Default: 4 + num_units (int): NUmber of downsample/upsample units in a single-stage + RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks) + num_blocks (list): Number of RSBs (Residual Steps Block) in each + downsample unit. Default: [2, 2, 2, 2] + num_steps (int): Number of steps in a RSB. Default:4 + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + res_top_channels (int): Number of channels of feature from ResNet_top. + Default: 64. + expand_times (int): Times by which the in_channels are expanded in RSB. + Default:26. + Example: + >>> from mmpose.models import RSN + >>> import torch + >>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2]) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 511, 511) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... for feature in level_output: + ... print(tuple(feature.shape)) + ... 
+ (1, 256, 64, 64) + (1, 256, 128, 128) + (1, 256, 64, 64) + (1, 256, 128, 128) + """ + + def __init__(self, + unit_channels=256, + num_stages=4, + num_units=4, + num_blocks=[2, 2, 2, 2], + num_steps=4, + norm_cfg=dict(type='BN'), + res_top_channels=64, + expand_times=26): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__() + self.unit_channels = unit_channels + self.num_stages = num_stages + self.num_units = num_units + self.num_blocks = num_blocks + self.num_steps = num_steps + self.norm_cfg = norm_cfg + + assert self.num_stages > 0 + assert self.num_steps > 1 + assert self.num_units > 1 + assert self.num_units == len(self.num_blocks) + self.top = ResNet_top(norm_cfg=norm_cfg) + self.multi_stage_rsn = nn.ModuleList([]) + for i in range(self.num_stages): + if i == 0: + has_skip = False + else: + has_skip = True + if i != self.num_stages - 1: + gen_skip = True + gen_cross_conv = True + else: + gen_skip = False + gen_cross_conv = False + self.multi_stage_rsn.append( + Single_stage_RSN(has_skip, gen_skip, gen_cross_conv, + unit_channels, num_units, num_steps, + num_blocks, norm_cfg, res_top_channels, + expand_times)) + + def forward(self, x): + """Model forward function.""" + out_feats = [] + skip1 = None + skip2 = None + x = self.top(x) + for i in range(self.num_stages): + out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2) + out_feats.append(out) + + return out_feats + + def init_weights(self, pretrained=None): + """Initialize model weights.""" + for m in self.multi_stage_rsn.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + elif isinstance(m, nn.Linear): + normal_init(m, std=0.01) + + for m in self.top.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/scnet.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/scnet.py new file mode 100644 index 0000000..3786c57 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/scnet.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from .resnet import Bottleneck, ResNet + + +class SCConv(nn.Module): + """SCConv (Self-calibrated Convolution) + + Args: + in_channels (int): The input channels of the SCConv. + out_channels (int): The output channel of the SCConv. + stride (int): stride of SCConv. + pooling_r (int): size of pooling for scconv. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. 
+ Default: dict(type='BN') + """ + + def __init__(self, + in_channels, + out_channels, + stride, + pooling_r, + conv_cfg=None, + norm_cfg=dict(type='BN', momentum=0.1)): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + + assert in_channels == out_channels + + self.k2 = nn.Sequential( + nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), + build_conv_layer( + conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(norm_cfg, in_channels)[1], + ) + self.k3 = nn.Sequential( + build_conv_layer( + conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(norm_cfg, in_channels)[1], + ) + self.k4 = nn.Sequential( + build_conv_layer( + conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + nn.ReLU(inplace=True), + ) + + def forward(self, x): + """Forward function.""" + identity = x + + out = torch.sigmoid( + torch.add(identity, F.interpolate(self.k2(x), + identity.size()[2:]))) + out = torch.mul(self.k3(x), out) + out = self.k4(out) + + return out + + +class SCBottleneck(Bottleneck): + """SC(Self-calibrated) Bottleneck. + + Args: + in_channels (int): The input channels of the SCBottleneck block. + out_channels (int): The output channel of the SCBottleneck block. + """ + + pooling_r = 4 + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.mid_channels = out_channels // self.expansion // 2 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=1, + bias=False) + self.add_module(self.norm1_name, norm1) + + self.k1 = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.stride, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, self.mid_channels)[1], + nn.ReLU(inplace=True)) + + self.conv2 = build_conv_layer( + self.conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=1, + bias=False) + self.add_module(self.norm2_name, norm2) + + self.scconv = SCConv(self.mid_channels, self.mid_channels, self.stride, + self.pooling_r, self.conv_cfg, self.norm_cfg) + + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels * 2, + out_channels, + kernel_size=1, + stride=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out_a = self.conv1(x) + out_a = self.norm1(out_a) + out_a = self.relu(out_a) + + out_a = self.k1(out_a) + + out_b = self.conv2(x) + out_b = self.norm2(out_b) + out_b = self.relu(out_b) + + out_b = self.scconv(out_b) + + out = self.conv3(torch.cat([out_a, out_b], dim=1)) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class SCNet(ResNet): + """SCNet backbone. 
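Editor's note: the self-calibration in SCConv.forward above reduces to the following standalone sketch, assuming plain nn.Conv2d in place of the configurable layers, no batch norm, and stride fixed to 1 (the real module does not make these simplifications):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySCConv(nn.Module):
    def __init__(self, channels, pooling_r=4):
        super().__init__()
        self.pool = nn.AvgPool2d(pooling_r, pooling_r)
        self.k2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.k3 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.k4 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)

    def forward(self, x):
        # calibration map computed on the down-sampled branch, resized back up
        gate = torch.sigmoid(x + F.interpolate(self.k2(self.pool(x)), x.shape[2:]))
        return self.k4(self.k3(x) * gate)

y = TinySCConv(16)(torch.randn(1, 16, 32, 32))
print(y.shape)   # torch.Size([1, 16, 32, 32])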
+ + Improving Convolutional Networks with Self-Calibrated Convolutions, + Jiang-Jiang Liu, Qibin Hou, Ming-Ming Cheng, Changhu Wang, Jiashi Feng, + IEEE CVPR, 2020. + http://mftp.mmcheng.net/Papers/20cvprSCNet.pdf + + Args: + depth (int): Depth of scnet, from {50, 101}. + in_channels (int): Number of input image channels. Normally 3. + base_channels (int): Number of base channels of hidden layer. + num_stages (int): SCNet stages, normally 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmpose.models import SCNet + >>> import torch + >>> self = SCNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + + arch_settings = { + 50: (SCBottleneck, [3, 4, 6, 3]), + 101: (SCBottleneck, [3, 4, 23, 3]) + } + + def __init__(self, depth, **kwargs): + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for SCNet') + super().__init__(depth, **kwargs) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnet.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnet.py new file mode 100644 index 0000000..ac2d53b --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnet.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.utils.checkpoint as cp + +from ..builder import BACKBONES +from .resnet import Bottleneck, ResLayer, ResNet +from .utils.se_layer import SELayer + + +class SEBottleneck(Bottleneck): + """SEBottleneck block for SEResNet. + + Args: + in_channels (int): The input channels of the SEBottleneck block. + out_channels (int): The output channel of the SEBottleneck block. + se_ratio (int): Squeeze ratio in SELayer. 
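Editor's note: the SELayer wired into the forward pass below is imported from .utils.se_layer and does not appear in this diff. A generic squeeze-and-excitation gate of the kind it implements looks roughly like this (illustrative sketch only; the project's actual SELayer may differ in details):

import torch
import torch.nn as nn

class TinySEGate(nn.Module):
    def __init__(self, channels, ratio=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // ratio, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // ratio, channels, 1),
            nn.Sigmoid())

    def forward(self, x):
        return x * self.fc(x)   # channel-wise re-weighting

print(TinySEGate(256)(torch.randn(1, 256, 14, 14)).shape)   # torch.Size([1, 256, 14, 14])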
Default: 16 + """ + + def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.se_layer = SELayer(out_channels, ratio=se_ratio) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + out = self.se_layer(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class SEResNet(ResNet): + """SEResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpose.models import SEResNet + >>> import torch + >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, se_ratio=16, **kwargs): + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for SEResNet') + self.se_ratio = se_ratio + super().__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer(se_ratio=self.se_ratio, **kwargs) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnext.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnext.py new file mode 100644 index 0000000..c5c4e4c --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnext.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from .resnet import ResLayer +from .seresnet import SEBottleneck as _SEBottleneck +from .seresnet import SEResNet + + +class SEBottleneck(_SEBottleneck): + """SEBottleneck block for SEResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + base_channels (int): Middle channels of the first stage. Default: 64. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + se_ratio=16, + **kwargs): + super().__init__(in_channels, out_channels, se_ratio, **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # We follow the same rational of ResNext to compute mid_channels. + # For SEResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for SEResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. 
+ if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@BACKBONES.register_module() +class SEResNeXt(SEResNet): + """SEResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpose.models import SEResNeXt + >>> import torch + >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super().__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v1.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v1.py new file mode 100644 index 0000000..9f98cbd --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v1.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, build_activation_layer, constant_init, + normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .utils import channel_shuffle, load_checkpoint, make_divisible + + +class ShuffleUnit(nn.Module): + """ShuffleUnit block. + + ShuffleNet unit with pointwise group convolution (GConv) and channel + shuffle. + + Args: + in_channels (int): The input channels of the ShuffleUnit. + out_channels (int): The output channels of the ShuffleUnit. + groups (int, optional): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3 + first_block (bool, optional): Whether it is the first ShuffleUnit of a + sequential ShuffleUnits. Default: True, which means not using the + grouped 1x1 convolution. + combine (str, optional): The ways to combine the input and output + branches. Default: 'add'. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + groups=3, + first_block=True, + combine='add', + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.first_block = first_block + self.combine = combine + self.groups = groups + self.bottleneck_channels = self.out_channels // 4 + self.with_cp = with_cp + + if self.combine == 'add': + self.depthwise_stride = 1 + self._combine_func = self._add + assert in_channels == out_channels, ( + 'in_channels must be equal to out_channels when combine ' + 'is add') + elif self.combine == 'concat': + self.depthwise_stride = 2 + self._combine_func = self._concat + self.out_channels -= self.in_channels + self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + raise ValueError(f'Cannot combine tensors with {self.combine}. 
' + 'Only "add" and "concat" are supported') + + self.first_1x1_groups = 1 if first_block else self.groups + self.g_conv_1x1_compress = ConvModule( + in_channels=self.in_channels, + out_channels=self.bottleneck_channels, + kernel_size=1, + groups=self.first_1x1_groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.depthwise_conv3x3_bn = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.bottleneck_channels, + kernel_size=3, + stride=self.depthwise_stride, + padding=1, + groups=self.bottleneck_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.g_conv_1x1_expand = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.out_channels, + kernel_size=1, + groups=self.groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.act = build_activation_layer(act_cfg) + + @staticmethod + def _add(x, out): + # residual connection + return x + out + + @staticmethod + def _concat(x, out): + # concatenate along channel axis + return torch.cat((x, out), 1) + + def forward(self, x): + + def _inner_forward(x): + residual = x + + out = self.g_conv_1x1_compress(x) + out = self.depthwise_conv3x3_bn(out) + + if self.groups > 1: + out = channel_shuffle(out, self.groups) + + out = self.g_conv_1x1_expand(out) + + if self.combine == 'concat': + residual = self.avgpool(residual) + out = self.act(out) + out = self._combine_func(residual, out) + else: + out = self._combine_func(residual, out) + out = self.act(out) + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@BACKBONES.register_module() +class ShuffleNetV1(BaseBackbone): + """ShuffleNetV1 backbone. + + Args: + groups (int, optional): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3. + widen_factor (float, optional): Width multiplier - adjusts the number + of channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (2, ) + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + groups=3, + widen_factor=1.0, + out_indices=(2, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.stage_blocks = [4, 8, 4] + self.groups = groups + + for index in out_indices: + if index not in range(0, 3): + raise ValueError('the item in out_indices must in ' + f'range(0, 3). But received {index}') + + if frozen_stages not in range(-1, 3): + raise ValueError('frozen_stages must be in range(-1, 3). 
' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if groups == 1: + channels = (144, 288, 576) + elif groups == 2: + channels = (200, 400, 800) + elif groups == 3: + channels = (240, 480, 960) + elif groups == 4: + channels = (272, 544, 1088) + elif groups == 8: + channels = (384, 768, 1536) + else: + raise ValueError(f'{groups} groups is not supported for 1x1 ' + 'Grouped Convolutions') + + channels = [make_divisible(ch * widen_factor, 8) for ch in channels] + + self.in_channels = int(24 * widen_factor) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + first_block = (i == 0) + layer = self.make_layer(channels[i], num_blocks, first_block) + self.layers.append(layer) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'conv1' in name: + normal_init(m, mean=0, std=0.01) + else: + normal_init(m, mean=0, std=1.0 / m.weight.shape[1]) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, val=1, bias=0.0001) + if isinstance(m, _BatchNorm): + if m.running_mean is not None: + nn.init.constant_(m.running_mean, 0) + else: + raise TypeError('pretrained must be a str or None. But received ' + f'{type(pretrained)}') + + def make_layer(self, out_channels, num_blocks, first_block=False): + """Stack ShuffleUnit blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): Number of blocks. + first_block (bool, optional): Whether is the first ShuffleUnit of a + sequential ShuffleUnits. Default: False, which means using + the grouped 1x1 convolution. 
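Editor's note: the make_divisible helper used in __init__ above is imported from .utils and is not shown in this diff. The sketch below assumes the common round-to-a-multiple-of-8 formulation, only to show how widen_factor scales the groups=3 channel table (the real helper may differ in rounding details):

def make_divisible(value, divisor=8, min_value=None, min_ratio=0.9):
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    if new_value < min_ratio * value:   # avoid rounding down by more than 10%
        new_value += divisor
    return new_value

channels = (240, 480, 960)              # the groups=3 table above
for widen_factor in (0.5, 1.0):
    print([make_divisible(c * widen_factor, 8) for c in channels])
# prints: [120, 240, 480] then [240, 480, 960]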
+ """ + layers = [] + for i in range(num_blocks): + first_block = first_block if i == 0 else False + combine_mode = 'concat' if i == 0 else 'add' + layers.append( + ShuffleUnit( + self.in_channels, + out_channels, + groups=self.groups, + first_block=first_block, + combine=combine_mode, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v2.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v2.py new file mode 100644 index 0000000..e935333 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v2.py @@ -0,0 +1,302 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .utils import channel_shuffle, load_checkpoint + + +class InvertedResidual(nn.Module): + """InvertedResidual block for ShuffleNetV2 backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 convolution layer. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. 
+ """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.stride = stride + self.with_cp = with_cp + + branch_features = out_channels // 2 + if self.stride == 1: + assert in_channels == branch_features * 2, ( + f'in_channels ({in_channels}) should equal to ' + f'branch_features * 2 ({branch_features * 2}) ' + 'when stride is 1') + + if in_channels != branch_features * 2: + assert self.stride != 1, ( + f'stride ({self.stride}) should not equal 1 when ' + f'in_channels != branch_features * 2') + + if self.stride > 1: + self.branch1 = nn.Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=self.stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + + self.branch2 = nn.Sequential( + ConvModule( + in_channels if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + groups=branch_features, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + + def _inner_forward(x): + if self.stride > 1: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + else: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@BACKBONES.register_module() +class ShuffleNetV2(BaseBackbone): + """ShuffleNetV2 backbone. + + Args: + widen_factor (float): Width multiplier - adjusts the number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. 
+ """ + + def __init__(self, + widen_factor=1.0, + out_indices=(3, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.stage_blocks = [4, 8, 4] + for index in out_indices: + if index not in range(0, 4): + raise ValueError('the item in out_indices must in ' + f'range(0, 4). But received {index}') + + if frozen_stages not in range(-1, 4): + raise ValueError('frozen_stages must be in range(-1, 4). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if widen_factor == 0.5: + channels = [48, 96, 192, 1024] + elif widen_factor == 1.0: + channels = [116, 232, 464, 1024] + elif widen_factor == 1.5: + channels = [176, 352, 704, 1024] + elif widen_factor == 2.0: + channels = [244, 488, 976, 2048] + else: + raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. ' + f'But received {widen_factor}') + + self.in_channels = 24 + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + layer = self._make_layer(channels[i], num_blocks) + self.layers.append(layer) + + output_channels = channels[-1] + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=output_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def _make_layer(self, out_channels, num_blocks): + """Stack blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): number of blocks. + """ + layers = [] + for i in range(num_blocks): + stride = 2 if i == 0 else 1 + layers.append( + InvertedResidual( + in_channels=self.in_channels, + out_channels=out_channels, + stride=stride, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'conv1' in name: + normal_init(m, mean=0, std=0.01) + else: + normal_init(m, mean=0, std=1.0 / m.weight.shape[1]) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m.weight, val=1, bias=0.0001) + if isinstance(m, _BatchNorm): + if m.running_mean is not None: + nn.init.constant_(m.running_mean, 0) + else: + raise TypeError('pretrained must be a str or None. 
But received ' + f'{type(pretrained)}') + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/tcn.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/tcn.py new file mode 100644 index 0000000..deca229 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/tcn.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmpose.core import WeightNormClipHook +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +class BasicTemporalBlock(nn.Module): + """Basic block for VideoPose3D. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + mid_channels (int): The output channels of conv1. Default: 1024. + kernel_size (int): Size of the convolving kernel. Default: 3. + dilation (int): Spacing between kernel elements. Default: 3. + dropout (float): Dropout rate. Default: 0.25. + causal (bool): Use causal convolutions instead of symmetric + convolutions (for real-time applications). Default: False. + residual (bool): Use residual connection. Default: True. + use_stride_conv (bool): Use optimized TCN that designed + specifically for single-frame batching, i.e. where batches have + input length = receptive field, and output length = 1. This + implementation replaces dilated convolutions with strided + convolutions to avoid generating unused intermediate results. + Default: False. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: dict(type='Conv1d'). + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN1d'). 
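Putting the ShuffleNetV2 backbone defined above together, a usage sketch (assumption: this module and its mmcv/mmpose dependencies import cleanly; the shapes follow from the strides in the code: conv1 and maxpool each halve the resolution, and the first block of every stage uses stride 2).

import torch

model = ShuffleNetV2(widen_factor=1.0, out_indices=(0, 1, 2, 3))
model.eval()

with torch.no_grad():
    feats = model(torch.rand(1, 3, 224, 224))

for f in feats:
    print(tuple(f.shape))
# Expected for widen_factor=1.0:
# (1, 116, 28, 28)   stage 1, overall stride 8
# (1, 232, 14, 14)   stage 2, overall stride 16
# (1, 464, 7, 7)     stage 3, overall stride 32
# (1, 1024, 7, 7)    final 1x1 conv layer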
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels=1024, + kernel_size=3, + dilation=3, + dropout=0.25, + causal=False, + residual=True, + use_stride_conv=False, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d')): + # Protect mutable default arguments + conv_cfg = copy.deepcopy(conv_cfg) + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = mid_channels + self.kernel_size = kernel_size + self.dilation = dilation + self.dropout = dropout + self.causal = causal + self.residual = residual + self.use_stride_conv = use_stride_conv + + self.pad = (kernel_size - 1) * dilation // 2 + if use_stride_conv: + self.stride = kernel_size + self.causal_shift = kernel_size // 2 if causal else 0 + self.dilation = 1 + else: + self.stride = 1 + self.causal_shift = kernel_size // 2 * dilation if causal else 0 + + self.conv1 = nn.Sequential( + ConvModule( + in_channels, + mid_channels, + kernel_size=kernel_size, + stride=self.stride, + dilation=self.dilation, + bias='auto', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + self.conv2 = nn.Sequential( + ConvModule( + mid_channels, + out_channels, + kernel_size=1, + bias='auto', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + if residual and in_channels != out_channels: + self.short_cut = build_conv_layer(conv_cfg, in_channels, + out_channels, 1) + else: + self.short_cut = None + + self.dropout = nn.Dropout(dropout) if dropout > 0 else None + + def forward(self, x): + """Forward function.""" + if self.use_stride_conv: + assert self.causal_shift + self.kernel_size // 2 < x.shape[2] + else: + assert 0 <= self.pad + self.causal_shift < x.shape[2] - \ + self.pad + self.causal_shift <= x.shape[2] + + out = self.conv1(x) + if self.dropout is not None: + out = self.dropout(out) + + out = self.conv2(out) + if self.dropout is not None: + out = self.dropout(out) + + if self.residual: + if self.use_stride_conv: + res = x[:, :, self.causal_shift + + self.kernel_size // 2::self.kernel_size] + else: + res = x[:, :, + (self.pad + self.causal_shift):(x.shape[2] - self.pad + + self.causal_shift)] + + if self.short_cut is not None: + res = self.short_cut(res) + out = out + res + + return out + + +@BACKBONES.register_module() +class TCN(BaseBackbone): + """TCN backbone. + + Temporal Convolutional Networks. + More details can be found in the + `paper `__ . + + Args: + in_channels (int): Number of input channels, which equals to + num_keypoints * num_features. + stem_channels (int): Number of feature channels. Default: 1024. + num_blocks (int): NUmber of basic temporal convolutional blocks. + Default: 2. + kernel_sizes (Sequence[int]): Sizes of the convolving kernel of + each basic block. Default: ``(3, 3, 3)``. + dropout (float): Dropout rate. Default: 0.25. + causal (bool): Use causal convolutions instead of symmetric + convolutions (for real-time applications). + Default: False. + residual (bool): Use residual connection. Default: True. + use_stride_conv (bool): Use TCN backbone optimized for + single-frame batching, i.e. where batches have input length = + receptive field, and output length = 1. This implementation + replaces dilated convolutions with strided convolutions to avoid + generating unused intermediate results. The weights are + interchangeable with the reference implementation. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: dict(type='Conv1d'). 
+ norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN1d'). + max_norm (float|None): if not None, the weight of convolution layers + will be clipped to have a maximum norm of max_norm. + + Example: + >>> from mmpose.models import TCN + >>> import torch + >>> self = TCN(in_channels=34) + >>> self.eval() + >>> inputs = torch.rand(1, 34, 243) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 1024, 235) + (1, 1024, 217) + """ + + def __init__(self, + in_channels, + stem_channels=1024, + num_blocks=2, + kernel_sizes=(3, 3, 3), + dropout=0.25, + causal=False, + residual=True, + use_stride_conv=False, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + max_norm=None): + # Protect mutable default arguments + conv_cfg = copy.deepcopy(conv_cfg) + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + self.in_channels = in_channels + self.stem_channels = stem_channels + self.num_blocks = num_blocks + self.kernel_sizes = kernel_sizes + self.dropout = dropout + self.causal = causal + self.residual = residual + self.use_stride_conv = use_stride_conv + self.max_norm = max_norm + + assert num_blocks == len(kernel_sizes) - 1 + for ks in kernel_sizes: + assert ks % 2 == 1, 'Only odd filter widths are supported.' + + self.expand_conv = ConvModule( + in_channels, + stem_channels, + kernel_size=kernel_sizes[0], + stride=kernel_sizes[0] if use_stride_conv else 1, + bias='auto', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + dilation = kernel_sizes[0] + self.tcn_blocks = nn.ModuleList() + for i in range(1, num_blocks + 1): + self.tcn_blocks.append( + BasicTemporalBlock( + in_channels=stem_channels, + out_channels=stem_channels, + mid_channels=stem_channels, + kernel_size=kernel_sizes[i], + dilation=dilation, + dropout=dropout, + causal=causal, + residual=residual, + use_stride_conv=use_stride_conv, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + dilation *= kernel_sizes[i] + + if self.max_norm is not None: + # Apply weight norm clip to conv layers + weight_clip = WeightNormClipHook(self.max_norm) + for module in self.modules(): + if isinstance(module, nn.modules.conv._ConvNd): + weight_clip.register(module) + + self.dropout = nn.Dropout(dropout) if dropout > 0 else None + + def forward(self, x): + """Forward function.""" + x = self.expand_conv(x) + + if self.dropout is not None: + x = self.dropout(x) + + outs = [] + for i in range(self.num_blocks): + x = self.tcn_blocks[i](x) + outs.append(x) + + return tuple(outs) + + def init_weights(self, pretrained=None): + """Initialize the weights.""" + super().init_weights(pretrained) + if pretrained is None: + for m in self.modules(): + if isinstance(m, nn.modules.conv._ConvNd): + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, _BatchNorm): + constant_init(m, 1) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/test_torch.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/test_torch.py new file mode 100644 index 0000000..c6833af --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/test_torch.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Net(nn.Module): + + def __init__(self): + super(Net, self).__init__() + # 1 input image channel, 6 output channels, 5x5 square convolution + # kernel + self.conv1 = nn.Conv2d(1, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + # an affine operation: 
y = Wx + b + self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5*5 from image dimension + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + # Max pooling over a (2, 2) window + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + # If the size is a square, you can specify with a single number + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +net = Net() +# print(net) + +net.train() + +input = torch.randn(1, 1, 32, 32) +# out = net(input) +# print(out) +output = net(input) +target = torch.randn(10) # a dummy target, for example +target = target.view(1, -1) # make it the same shape as output +criterion = nn.MSELoss() + +# loss = criterion(output.cuda(), target.cuda()) + +import torch.optim as optim + +# create your optimizer +optimizer = optim.SGD(net.parameters(), lr=0.01) + +# in your training loop: +optimizer.zero_grad() # zero the gradient buffers +output = net(input) +loss = criterion(output, target) + +loss.backward() + +optimizer.step() + +# print(loss) \ No newline at end of file diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/__init__.py new file mode 100644 index 0000000..52a30ca --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .channel_shuffle import channel_shuffle +from .inverted_residual import InvertedResidual +from .make_divisible import make_divisible +from .se_layer import SELayer +from .utils import load_checkpoint + +__all__ = [ + 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', + 'load_checkpoint' +] diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/channel_shuffle.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/channel_shuffle.py new file mode 100644 index 0000000..27006a8 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/channel_shuffle.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def channel_shuffle(x, groups): + """Channel Shuffle operation. + + This function enables cross-group information flow for multiple groups + convolution layers. + + Args: + x (Tensor): The input tensor. + groups (int): The number of groups to divide the input tensor + in the channel dimension. + + Returns: + Tensor: The output tensor after channel shuffle operation. + """ + + batch_size, num_channels, height, width = x.size() + assert (num_channels % groups == 0), ('num_channels should be ' + 'divisible by groups') + channels_per_group = num_channels // groups + + x = x.view(batch_size, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + x = x.view(batch_size, -1, height, width) + + return x diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/inverted_residual.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/inverted_residual.py new file mode 100644 index 0000000..dff762c --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/inverted_residual.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
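A quick self-contained check of the channel_shuffle semantics defined above (plain PyTorch re-statement of the same function): with groups=2, channels [0, 1, 2, 3] come out as [0, 2, 1, 3], i.e. the two groups are interleaved so that subsequent group convolutions see channels from both branches.

import torch


def channel_shuffle(x, groups):
    b, c, h, w = x.size()
    x = x.view(b, groups, c // groups, h, w)
    return torch.transpose(x, 1, 2).contiguous().view(b, -1, h, w)


x = torch.arange(4).float().view(1, 4, 1, 1)     # channel i holds the value i
print(channel_shuffle(x, 2).flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]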
+import copy + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """Inverted Residual Block. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + groups (None or int): The group number of the depthwise convolution. + Default: None, which means group number = mid_channels. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. + Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + groups=None, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if groups is None: + groups = mid_channels + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if self.with_se: + self.se = SELayer(**se_cfg) + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/make_divisible.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/make_divisible.py new file mode 100644 index 0000000..b7666be --- /dev/null +++ 
b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/make_divisible.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number down to the nearest value that can + be divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float, optional): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/se_layer.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/se_layer.py new file mode 100644 index 0000000..07f7080 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/se_layer.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import torch.nn as nn +from mmcv.cnn import ConvModule + + +class SELayer(nn.Module): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Default: 16. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Default: (dict(type='ReLU'), dict(type='Sigmoid')) + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))): + super().__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmcv.is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/utils.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/utils.py new file mode 100644 index 0000000..a9ac948 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/utils.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
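Worked example for make_divisible (pure Python, function body copied from the file above): the value is rounded to the nearest multiple of the divisor, and the min_ratio guard bumps it back up if that rounding would lose more than 10% of the original channel count.

def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    if new_value < min_ratio * value:   # never shrink by more than (1 - min_ratio)
        new_value += divisor
    return new_value


print(make_divisible(112, 8))   # 112 (already a multiple of 8)
print(make_divisible(115, 8))   # 112 (nearest multiple, within the 10% budget)
print(make_divisible(30, 8))    # 32  (nearest multiple happens to round up)
print(make_divisible(20, 8))    # 24  (int(20 + 4) // 8 * 8 = 24)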
+from collections import OrderedDict + +from mmcv.runner.checkpoint import _load_checkpoint, load_state_dict + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict_tmp = checkpoint['state_dict'] + else: + state_dict_tmp = checkpoint + + state_dict = OrderedDict() + # strip prefix of state_dict + for k, v in state_dict_tmp.items(): + if k.startswith('module.backbone.'): + state_dict[k[16:]] = v + elif k.startswith('module.'): + state_dict[k[7:]] = v + elif k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def get_state_dict(filename, map_location='cpu'): + """Get state_dict from a file or URI. + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. + map_location (str): Same as :func:`torch.load`. + + Returns: + OrderedDict: The state_dict. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict_tmp = checkpoint['state_dict'] + else: + state_dict_tmp = checkpoint + + state_dict = OrderedDict() + # strip prefix of state_dict + for k, v in state_dict_tmp.items(): + if k.startswith('module.backbone.'): + state_dict[k[16:]] = v + elif k.startswith('module.'): + state_dict[k[7:]] = v + elif k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + + return state_dict diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vgg.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vgg.py new file mode 100644 index 0000000..f7d4670 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vgg.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
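The prefix stripping done by load_checkpoint/get_state_dict above, illustrated on a plain dict (the key names below are hypothetical): keys saved from a DataParallel-wrapped detector are remapped so they match the bare backbone module before load_state_dict is called.

from collections import OrderedDict

state_dict_tmp = {
    'module.backbone.conv1.weight': 'w0',    # DataParallel + detector wrapper
    'backbone.layer1.0.conv1.weight': 'w1',  # detector wrapper only
    'module.fc.bias': 'b0',                  # DataParallel only
    'head.weight': 'w2',                     # already bare
}

state_dict = OrderedDict()
for k, v in state_dict_tmp.items():
    if k.startswith('module.backbone.'):
        state_dict[k[16:]] = v
    elif k.startswith('module.'):
        state_dict[k[7:]] = v
    elif k.startswith('backbone.'):
        state_dict[k[9:]] = v
    else:
        state_dict[k] = v

print(list(state_dict))
# ['conv1.weight', 'layer1.0.conv1.weight', 'fc.bias', 'head.weight']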
+import torch.nn as nn +from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init +from mmcv.utils.parrots_wrapper import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +def make_vgg_layer(in_channels, + out_channels, + num_blocks, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dilation=1, + with_norm=False, + ceil_mode=False): + layers = [] + for _ in range(num_blocks): + layer = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=dilation, + padding=dilation, + bias=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + layers.append(layer) + in_channels = out_channels + layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) + + return layers + + +@BACKBONES.register_module() +class VGG(BaseBackbone): + """VGG backbone. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_norm (bool): Use BatchNorm or not. + num_classes (int): number of classes for classification. + num_stages (int): VGG stages, normally 5. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. When it is None, the default behavior depends on + whether num_classes is specified. If num_classes <= 0, the default + value is (4, ), outputting the last feature map before classifier. + If num_classes > 0, the default value is (5, ), outputting the + classification score. Default: None. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False. + with_last_pool (bool): Whether to keep the last pooling before + classifier. Default: True. + """ + + # Parameters to build layers. Each element specifies the number of conv in + # each stage. For example, VGG11 contains 11 layers with learnable + # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3, + # where 3 indicates the last three fully-connected layers. 
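A quick pure-Python sanity check of the comment above, using the same arch_settings table that follows: the conv layers per stage plus the three fully-connected layers add up to the nominal VGG depth.

arch_settings = {
    11: (1, 1, 2, 2, 2),
    13: (2, 2, 2, 2, 2),
    16: (2, 2, 3, 3, 3),
    19: (2, 2, 4, 4, 4),
}
for depth, stage_blocks in arch_settings.items():
    assert sum(stage_blocks) + 3 == depth
    print(f'VGG{depth}: {sum(stage_blocks)} conv layers + 3 FC layers')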
+ arch_settings = { + 11: (1, 1, 2, 2, 2), + 13: (2, 2, 2, 2, 2), + 16: (2, 2, 3, 3, 3), + 19: (2, 2, 4, 4, 4) + } + + def __init__(self, + depth, + num_classes=-1, + num_stages=5, + dilations=(1, 1, 1, 1, 1), + out_indices=None, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + norm_eval=False, + ceil_mode=False, + with_last_pool=True): + super().__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for vgg') + assert num_stages >= 1 and num_stages <= 5 + stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + assert len(dilations) == num_stages + + self.num_classes = num_classes + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + with_norm = norm_cfg is not None + + if out_indices is None: + out_indices = (5, ) if num_classes > 0 else (4, ) + assert max(out_indices) <= num_stages + self.out_indices = out_indices + + self.in_channels = 3 + start_idx = 0 + vgg_layers = [] + self.range_sub_modules = [] + for i, num_blocks in enumerate(self.stage_blocks): + num_modules = num_blocks + 1 + end_idx = start_idx + num_modules + dilation = dilations[i] + out_channels = 64 * 2**i if i < 4 else 512 + vgg_layer = make_vgg_layer( + self.in_channels, + out_channels, + num_blocks, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dilation=dilation, + with_norm=with_norm, + ceil_mode=ceil_mode) + vgg_layers.extend(vgg_layer) + self.in_channels = out_channels + self.range_sub_modules.append([start_idx, end_idx]) + start_idx = end_idx + if not with_last_pool: + vgg_layers.pop(-1) + self.range_sub_modules[-1][1] -= 1 + self.module_name = 'features' + self.add_module(self.module_name, nn.Sequential(*vgg_layers)) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def init_weights(self, pretrained=None): + super().init_weights(pretrained) + if pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, _BatchNorm): + constant_init(m, 1) + elif isinstance(m, nn.Linear): + normal_init(m, std=0.01) + + def forward(self, x): + outs = [] + vgg_layers = getattr(self, self.module_name) + for i in range(len(self.stage_blocks)): + for j in range(*self.range_sub_modules[i]): + vgg_layer = vgg_layers[j] + x = vgg_layer(x) + if i in self.out_indices: + outs.append(x) + if self.num_classes > 0: + x = x.view(x.size(0), -1) + x = self.classifier(x) + outs.append(x) + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + vgg_layers = getattr(self, self.module_name) + for i in range(self.frozen_stages): + for j in range(*self.range_sub_modules[i]): + m = vgg_layers[j] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_mbv3.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_mbv3.py new file mode 100644 index 0000000..ed990e3 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_mbv3.py @@ -0,0 +1,179 @@ +# Copyright 
(c) OpenMMLab. All rights reserved. +import copy +import logging + +import torch.nn as nn +from mmcv.cnn import ConvModule +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone +from .utils import InvertedResidual, load_checkpoint + + +@BACKBONES.register_module() +class ViPNAS_MobileNetV3(BaseBackbone): + """ViPNAS_MobileNetV3 backbone. + + "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search" + More details can be found in the `paper + `__ . + + Args: + wid (list(int)): Searched width config for each stage. + expan (list(int)): Searched expansion ratio config for each stage. + dep (list(int)): Searched depth config for each stage. + ks (list(int)): Searched kernel size config for each stage. + group (list(int)): Searched group number config for each stage. + att (list(bool)): Searched attention config for each stage. + stride (list(int)): Stride config for each stage. + act (list(dict)): Activation config for each stage. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + """ + + def __init__(self, + wid=[16, 16, 24, 40, 80, 112, 160], + expan=[None, 1, 5, 4, 5, 5, 6], + dep=[None, 1, 4, 4, 4, 4, 4], + ks=[3, 3, 7, 7, 5, 7, 5], + group=[None, 8, 120, 20, 100, 280, 240], + att=[None, True, True, False, True, True, True], + stride=[2, 1, 2, 2, 2, 1, 2], + act=[ + 'HSwish', 'ReLU', 'ReLU', 'ReLU', 'HSwish', 'HSwish', + 'HSwish' + ], + conv_cfg=None, + norm_cfg=dict(type='BN'), + frozen_stages=-1, + norm_eval=False, + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + self.wid = wid + self.expan = expan + self.dep = dep + self.ks = ks + self.group = group + self.att = att + self.stride = stride + self.act = act + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.wid[0], + kernel_size=self.ks[0], + stride=self.stride[0], + padding=self.ks[0] // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type=self.act[0])) + + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + layer_index = 0 + for i, dep in enumerate(self.dep[1:]): + mid_channels = self.wid[i + 1] * self.expan[i + 1] + + if self.att[i + 1]: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), dict(type='HSigmoid'))) + else: + se_cfg = None + + if self.expan[i + 1] == 1: + with_expand_conv = False + else: + with_expand_conv = True + + for j in range(dep): + if j == 0: + stride = self.stride[i + 1] + in_channels = self.wid[i] + else: + stride = 1 + in_channels = self.wid[i + 1] + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=self.wid[i + 1], + mid_channels=mid_channels, + kernel_size=self.ks[i + 1], + groups=self.group[i + 1], + stride=stride, + se_cfg=se_cfg, + 
with_expand_conv=with_expand_conv, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=self.act[i + 1]), + with_cp=self.with_cp) + layer_index += 1 + layer_name = f'layer{layer_index}' + self.add_module(layer_name, layer) + layers.append(layer_name) + return layers + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_resnet.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_resnet.py new file mode 100644 index 0000000..81b028e --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_resnet.py @@ -0,0 +1,589 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmcv.cnn.bricks import ContextBlock +from mmcv.utils.parrots_wrapper import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +class ViPNAS_Bottleneck(nn.Module): + """Bottleneck block for ViPNAS_ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. Default: 4. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None. + style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the + stride-two layer is the 3x3 conv layer, otherwise the stride-two + layer is the first 1x1 conv layer. Default: "pytorch". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + kernel_size (int): kernel size of conv2 searched in ViPANS. + groups (int): group number of conv2 searched in ViPNAS. + attention (bool): whether to use attention module in the end of + the block. 
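How the searched ViPNAS_MobileNetV3 defaults above expand into stages, mirrored in pure Python (same arithmetic as _make_layer): stage i uses wid[i+1] output channels, wid[i+1] * expan[i+1] mid channels inside its inverted residual blocks, and dep[i+1] repeated blocks.

wid = [16, 16, 24, 40, 80, 112, 160]
expan = [None, 1, 5, 4, 5, 5, 6]
dep = [None, 1, 4, 4, 4, 4, 4]

for i, blocks in enumerate(dep[1:]):
    out_ch = wid[i + 1]
    mid_ch = out_ch * expan[i + 1]
    print(f'stage {i + 1}: {blocks} blocks, out_channels={out_ch}, mid_channels={mid_ch}')

print('total inverted-residual blocks:', sum(dep[1:]))   # 21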
+ """ + + def __init__(self, + in_channels, + out_channels, + expansion=4, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + kernel_size=3, + groups=1, + attention=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + assert style in ['pytorch', 'caffe'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=kernel_size, + stride=self.conv2_stride, + padding=kernel_size // 2, + groups=groups, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + if attention: + self.attention = ContextBlock(out_channels, + max(1.0 / 16, 16.0 / out_channels)) + else: + self.attention = None + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: the normalization layer named "norm3" """ + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.attention is not None: + out = self.attention(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def get_expansion(block, expansion=None): + """Get the expansion of a residual block. + + The block expansion will be obtained by the following order: + + 1. If ``expansion`` is given, just return it. + 2. If ``block`` has the attribute ``expansion``, then return + ``block.expansion``. + 3. Return the default value according the the block type: + 4 for ``ViPNAS_Bottleneck``. + + Args: + block (class): The block class. + expansion (int | None): The given expansion ratio. + + Returns: + int: The expansion of the block. 
+ """ + if isinstance(expansion, int): + assert expansion > 0 + elif expansion is None: + if hasattr(block, 'expansion'): + expansion = block.expansion + elif issubclass(block, ViPNAS_Bottleneck): + expansion = 1 + else: + raise TypeError(f'expansion is not specified for {block.__name__}') + else: + raise TypeError('expansion must be an integer or None') + + return expansion + + +class ViPNAS_ResLayer(nn.Sequential): + """ViPNAS_ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): Residual block used to build ViPNAS ResLayer. + num_blocks (int): Number of blocks. + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int, optional): The expansion for BasicBlock/Bottleneck. + If not specified, it will firstly be obtained via + ``block.expansion``. If the block has no attribute "expansion", + the following default values will be used: 1 for BasicBlock and + 4 for Bottleneck. Default: None. + stride (int): stride of the first block. Default: 1. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Default: True + kernel_size (int): Kernel Size of the corresponding convolution layer + searched in the block. + groups (int): Group number of the corresponding convolution layer + searched in the block. + attention (bool): Whether to use attention module in the end of the + block. + """ + + def __init__(self, + block, + num_blocks, + in_channels, + out_channels, + expansion=None, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + downsample_first=True, + kernel_size=3, + groups=1, + attention=False, + **kwargs): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + self.block = block + self.expansion = get_expansion(block, expansion) + + downsample = None + if stride != 1 or in_channels != out_channels: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if downsample_first: + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + groups=groups, + attention=attention, + **kwargs)) + in_channels = out_channels + for _ in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + groups=groups, + attention=attention, + **kwargs)) + else: # downsample_first=False is for HourglassModule + for i in range(0, num_blocks - 1): + layers.append( + block( + in_channels=in_channels, + out_channels=in_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + 
groups=groups, + attention=attention, + **kwargs)) + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + groups=groups, + attention=attention, + **kwargs)) + + super().__init__(*layers) + + +@BACKBONES.register_module() +class ViPNAS_ResNet(BaseBackbone): + """ViPNAS_ResNet backbone. + + "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search" + More details can be found in the `paper + `__ . + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + wid (list(int)): Searched width config for each stage. + expan (list(int)): Searched expansion ratio config for each stage. + dep (list(int)): Searched depth config for each stage. + ks (list(int)): Searched kernel size config for each stage. + group (list(int)): Searched group number config for each stage. + att (list(bool)): Searched attention config for each stage. 
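A minimal plain-PyTorch sketch of the `avg_down` identity path that ViPNAS_ResLayer builds above (using nn.Conv2d/nn.BatchNorm2d in place of mmcv's build_conv_layer/build_norm_layer): instead of a strided 1x1 conv, the shortcut average-pools first and then applies a stride-1 1x1 conv, so no input positions are simply skipped.

import torch
import torch.nn as nn

stride, in_ch, out_ch = 2, 64, 128
avg_down_shortcut = nn.Sequential(
    nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
    nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
    nn.BatchNorm2d(out_ch),
)

x = torch.rand(1, in_ch, 56, 56)
print(avg_down_shortcut(x).shape)   # torch.Size([1, 128, 28, 28])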
+ """ + + arch_settings = { + 50: ViPNAS_Bottleneck, + } + + def __init__(self, + depth, + in_channels=3, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + wid=[48, 80, 160, 304, 608], + expan=[None, 1, 1, 1, 1], + dep=[None, 4, 6, 7, 3], + ks=[7, 3, 5, 5, 5], + group=[None, 16, 16, 16, 16], + att=[None, True, False, True, True]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + self.stem_channels = dep[0] + self.num_stages = num_stages + assert 1 <= num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.block = self.arch_settings[depth] + self.stage_blocks = dep[1:1 + num_stages] + + self._make_stem_layer(in_channels, wid[0], ks[0]) + + self.res_layers = [] + _in_channels = wid[0] + for i, num_blocks in enumerate(self.stage_blocks): + expansion = get_expansion(self.block, expan[i + 1]) + _out_channels = wid[i + 1] * expansion + stride = strides[i] + dilation = dilations[i] + res_layer = self.make_res_layer( + block=self.block, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=_out_channels, + expansion=expansion, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=ks[i + 1], + groups=group[i + 1], + attention=att[i + 1]) + _in_channels = _out_channels + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = res_layer[-1].out_channels + + def make_res_layer(self, **kwargs): + """Make a ViPNAS ResLayer.""" + return ViPNAS_ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels, kernel_size): + """Make stem layer.""" + if self.deep_stem: + self.stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize model weights.""" + super().init_weights(pretrained) + if pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ['bias']: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + if len(outs) == 1: + return outs[0] + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vit.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vit.py new file mode 100644 index 0000000..465dfad --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vit.py @@ -0,0 +1,308 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +# from ..builder import BACKBONES +# from .base_backbone import BaseBackbone + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
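What the DropPath module below does, sketched in plain PyTorch (timm's drop_path, which it wraps, has the same per-sample semantics): during training each sample's residual branch is zeroed with probability p and the surviving samples are rescaled by 1 / (1 - p); at eval time it is a no-op. This is only an illustrative re-statement, not the timm implementation.

import torch


def drop_path_sketch(x, drop_prob, training):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dimensions.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.bernoulli(torch.full(mask_shape, keep_prob, device=x.device, dtype=x.dtype))
    return x * mask / keep_prob


x = torch.ones(4, 3, 8, 8)
out = drop_path_sketch(x, drop_prob=0.5, training=True)
print(out.flatten(1).amax(dim=1))   # each sample is either all zeros or scaled to 2.0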
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +# @BACKBONES.register_module() +class ViT(nn.Module): + + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, + patch_padding='pad', freeze_attn=False, freeze_ffn=False, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + # since the pretraining model has class token + 
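+        # The "+ 1" slot keeps the table the same size as the pretrained ViT
+        # (patch tokens + class token), so pretrained position embeddings load
+        # without reshaping; forward_features below folds the class-token slot
+        # back in via pos_embed[:, 1:] + pos_embed[:, :1] instead of prepending
+        # a class token.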
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + super().init_weights(pretrained, patch_padding=self.patch_padding) + + if pretrained is None: + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp + + def forward(self, x): + x = self.forward_features(x) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_coco_256x192.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_coco_256x192.py new file mode 100644 index 0000000..de927db --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_coco_256x192.py @@ -0,0 +1,168 @@ +_base_ = [ + '../../../../_base_/default_runtime.py', + '../../../../_base_/datasets/coco.py' +] +evaluation = dict(interval=10, 
metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.75, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=17, + dataset_joints=17, + dataset_channel=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=768, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict(final_conv_kernel=1, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict()) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = 'data/coco' +# data = dict( +# samples_per_gpu=64, +# workers_per_gpu=4, +# val_dataloader=dict(samples_per_gpu=32), +# test_dataloader=dict(samples_per_gpu=32), +# train=dict( +# type='TopDownCocoDataset', +# ann_file=f'{data_root}/annotations/person_keypoints_train2017.json', +# img_prefix=f'{data_root}/train2017/', +# data_cfg=data_cfg, +# pipeline=train_pipeline, +# 
dataset_info={{_base_.dataset_info}}), +# val=dict( +# type='TopDownCocoDataset', +# ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', +# img_prefix=f'{data_root}/val2017/', +# data_cfg=data_cfg, +# pipeline=val_pipeline, +# dataset_info={{_base_.dataset_info}}), +# test=dict( +# type='TopDownCocoDataset', +# ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', +# img_prefix=f'{data_root}/val2017/', +# data_cfg=data_cfg, +# pipeline=test_pipeline, +# dataset_info={{_base_.dataset_info}}), +# ) + +def make_cfg(model=model,data_cfg=data_cfg): + cfg={} + cfg['model'] = model + cfg['data_cfg'] = data_cfg \ No newline at end of file diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_simple_coco_256x192.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_simple_coco_256x192.py new file mode 100644 index 0000000..d410a15 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_simple_coco_256x192.py @@ -0,0 +1,171 @@ +_base_ = [ + '../../../../_base_/default_runtime.py', + '../../../../_base_/datasets/coco.py' +] +evaluation = dict(interval=10, metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.75, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=17, + dataset_joints=17, + dataset_channel=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=768, + num_deconv_layers=0, + num_deconv_filters=[], + num_deconv_kernels=[], + upsample=4, + extra=dict(final_conv_kernel=3, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type=target_type, + modulate_kernel=11, + use_udp=True)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + 
num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = 'data/coco' +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_train2017.json', + img_prefix=f'{data_root}/train2017/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + test=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) + diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_coco_256x192.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_coco_256x192.py new file mode 100644 index 0000000..298b2b5 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_coco_256x192.py @@ -0,0 +1,170 @@ +_base_ = [ + '../../../../_base_/default_runtime.py', + '../../../../_base_/datasets/coco.py' +] +evaluation = dict(interval=10, metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=32, + layer_decay_rate=0.85, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) 
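+                        # NOTE: 'decay_multi' in the 'bias' entry above appears to be a
+                        # typo for 'decay_mult'; the same spelling is repeated in the
+                        # other configs in this folder.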
+ } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=17, + dataset_joints=17, + dataset_channel=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=1280, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict(final_conv_kernel=1, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type=target_type, + modulate_kernel=11, + use_udp=True)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = 'data/coco' +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_train2017.json', + img_prefix=f'{data_root}/train2017/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + 
test=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) + diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_simple_coco_256x192.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_simple_coco_256x192.py new file mode 100644 index 0000000..f9a86f0 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_simple_coco_256x192.py @@ -0,0 +1,171 @@ +_base_ = [ + '../../../../_base_/default_runtime.py', + '../../../../_base_/datasets/coco.py' +] +evaluation = dict(interval=10, metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=32, + layer_decay_rate=0.85, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=17, + dataset_joints=17, + dataset_channel=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=1280, + num_deconv_layers=0, + num_deconv_filters=[], + num_deconv_kernels=[], + upsample=4, + extra=dict(final_conv_kernel=3, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type=target_type, + modulate_kernel=11, + use_udp=True)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 
'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = 'data/coco' +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_train2017.json', + img_prefix=f'{data_root}/train2017/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + test=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) + diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_coco_256x192.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_coco_256x192.py new file mode 100644 index 0000000..0753a3c --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_coco_256x192.py @@ -0,0 +1,170 @@ +_base_ = [ + '../../../../_base_/default_runtime.py', + '../../../../_base_/datasets/coco.py' +] +evaluation = dict(interval=10, metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=16, + layer_decay_rate=0.8, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) 
+ } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=17, + dataset_joints=17, + dataset_channel=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.5, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=1024, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict(final_conv_kernel=1, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type=target_type, + modulate_kernel=11, + use_udp=True)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = 'data/coco' +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_train2017.json', + img_prefix=f'{data_root}/train2017/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + 
test=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) + diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_simple_coco_256x192.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_simple_coco_256x192.py new file mode 100644 index 0000000..63c7949 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_simple_coco_256x192.py @@ -0,0 +1,171 @@ +_base_ = [ + '../../../../_base_/default_runtime.py', + '../../../../_base_/datasets/coco.py' +] +evaluation = dict(interval=10, metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=24, + layer_decay_rate=0.8, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=17, + dataset_joints=17, + dataset_channel=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.5, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=1024, + num_deconv_layers=0, + num_deconv_filters=[], + num_deconv_kernels=[], + upsample=4, + extra=dict(final_conv_kernel=3, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type=target_type, + modulate_kernel=11, + use_udp=True)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 
'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = 'data/coco' +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_train2017.json', + img_prefix=f'{data_root}/train2017/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + test=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) + diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/__init__.py new file mode 100644 index 0000000..bccac75 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# from .ae_higher_resolution_head import AEHigherResolutionHead +# from .ae_multi_stage_head import AEMultiStageHead +# from .ae_simple_head import AESimpleHead +# from .deconv_head import DeconvHead +# from .deeppose_regression_head import DeepposeRegressionHead +# from .hmr_head import HMRMeshHead +# from .interhand_3d_head import Interhand3DHead +# from .temporal_regression_head import TemporalRegressionHead +from .topdown_heatmap_base_head import TopdownHeatmapBaseHead +# from .topdown_heatmap_multi_stage_head import (TopdownHeatmapMSMUHead, +# TopdownHeatmapMultiStageHead) +from .topdown_heatmap_simple_head import TopdownHeatmapSimpleHead +# from .vipnas_heatmap_simple_head import ViPNASHeatmapSimpleHead +# from .voxelpose_head import CuboidCenterHead, CuboidPoseHead + +# __all__ = [ +# 'TopdownHeatmapSimpleHead', 'TopdownHeatmapMultiStageHead', +# 'TopdownHeatmapMSMUHead', 'TopdownHeatmapBaseHead', +# 'AEHigherResolutionHead', 'AESimpleHead', 'AEMultiStageHead', +# 'DeepposeRegressionHead', 'TemporalRegressionHead', 'Interhand3DHead', +# 'HMRMeshHead', 'DeconvHead', 'ViPNASHeatmapSimpleHead', 'CuboidCenterHead', +# 'CuboidPoseHead' +# ] diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deconv_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deconv_head.py new file mode 100644 index 0000000..90846d2 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deconv_head.py @@ -0,0 +1,295 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, + constant_init, normal_init) + +from mmpose.models.builder import HEADS, build_loss +from mmpose.models.utils.ops import resize + + +@HEADS.register_module() +class DeconvHead(nn.Module): + """Simple deconv head. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means + no deconv layers. + num_deconv_filters (list|tuple): Number of filters. + If num_deconv_layers > 0, the length of + num_deconv_kernels (list|tuple): Kernel sizes. + in_index (int|Sequence[int]): Input feature index. Default: 0 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + Default: None. + + - 'resize_concat': Multiple feature maps will be resized to the + same size as the first one and then concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_keypoint (dict): Config for loss. Default: None. 
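+
+    Example (illustrative only; the deconv settings mirror the COCO base
+    config above and assume mmpose's ``JointsMSELoss`` is registered)::
+
+        head = DeconvHead(
+            in_channels=768,
+            out_channels=17,
+            num_deconv_layers=2,
+            num_deconv_filters=(256, 256),
+            num_deconv_kernels=(4, 4),
+            extra=dict(final_conv_kernel=1),
+            loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+        heatmaps = head(torch.randn(1, 768, 16, 12))[0]  # -> (1, 17, 64, 48)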
+ """ + + def __init__(self, + in_channels=3, + out_channels=17, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None, + in_index=0, + input_transform=None, + align_corners=False, + loss_keypoint=None): + super().__init__() + + self.in_channels = in_channels + self.loss = build_loss(loss_keypoint) + + self._init_inputs(in_channels, in_index, input_transform) + self.in_index = in_index + self.align_corners = align_corners + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, + ) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0, 1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + elif extra['final_conv_kernel'] == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_final_layer = True + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + if identity_final_layer: + self.final_layer = nn.Identity() + else: + conv_channels = num_deconv_filters[ + -1] if num_deconv_layers > 0 else self.in_channels + + layers = [] + if extra is not None: + num_conv_layers = extra.get('num_conv_layers', 0) + num_conv_kernels = extra.get('num_conv_kernels', + [1] * num_conv_layers) + + for i in range(num_conv_layers): + layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=num_conv_kernels[i], + stride=1, + padding=(num_conv_kernels[i] - 1) // 2)) + layers.append( + build_norm_layer(dict(type='BN'), conv_channels)[1]) + layers.append(nn.ReLU(inplace=True)) + + layers.append( + build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding)) + + if len(layers) > 1: + self.final_layer = nn.Sequential(*layers) + else: + self.final_layer = layers[0] + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform is not None, in_channels and in_index must be + list or tuple, with the same length. + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + + - 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. 
+ """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor] | Tensor): multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + if not isinstance(inputs, list): + return inputs + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding + + def get_loss(self, outputs, targets, masks): + """Calculate bottom-up masked mse loss. + + Note: + - batch_size: N + - num_channels: C + - heatmaps height: H + - heatmaps weight: W + + Args: + outputs (List(torch.Tensor[N,C,H,W])): Multi-scale outputs. + targets (List(torch.Tensor[N,C,H,W])): Multi-scale targets. + masks (List(torch.Tensor[N,H,W])): Masks of multi-scale targets. 
+ """ + + losses = dict() + + for idx in range(len(targets)): + if 'loss' not in losses: + losses['loss'] = self.loss(outputs[idx], targets[idx], + masks[idx]) + else: + losses['loss'] += self.loss(outputs[idx], targets[idx], + masks[idx]) + + return losses + + def forward(self, x): + """Forward function.""" + x = self._transform_inputs(x) + final_outputs = [] + x = self.deconv_layers(x) + y = self.final_layer(x) + final_outputs.append(y) + return final_outputs + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deeppose_regression_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deeppose_regression_head.py new file mode 100644 index 0000000..f326e26 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deeppose_regression_head.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import normal_init + +from mmpose.core.evaluation import (keypoint_pck_accuracy, + keypoints_from_regression) +from mmpose.core.post_processing import fliplr_regression +from mmpose.models.builder import HEADS, build_loss + + +@HEADS.register_module() +class DeepposeRegressionHead(nn.Module): + """Deeppose regression head with fully connected layers. + + "DeepPose: Human Pose Estimation via Deep Neural Networks". + + Args: + in_channels (int): Number of input channels + num_joints (int): Number of joints + loss_keypoint (dict): Config for keypoint loss. Default: None. + """ + + def __init__(self, + in_channels, + num_joints, + loss_keypoint=None, + train_cfg=None, + test_cfg=None): + super().__init__() + + self.in_channels = in_channels + self.num_joints = num_joints + + self.loss = build_loss(loss_keypoint) + + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + + self.fc = nn.Linear(self.in_channels, self.num_joints * 2) + + def forward(self, x): + """Forward function.""" + output = self.fc(x) + N, C = output.shape + return output.reshape([N, C // 2, 2]) + + def get_loss(self, output, target, target_weight): + """Calculate top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + output (torch.Tensor[N, K, 2]): Output keypoints. + target (torch.Tensor[N, K, 2]): Target keypoints. + target_weight (torch.Tensor[N, K, 2]): + Weights across different joint types. + """ + + losses = dict() + assert not isinstance(self.loss, nn.Sequential) + assert target.dim() == 3 and target_weight.dim() == 3 + losses['reg_loss'] = self.loss(output, target, target_weight) + + return losses + + def get_accuracy(self, output, target, target_weight): + """Calculate accuracy for top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + output (torch.Tensor[N, K, 2]): Output keypoints. + target (torch.Tensor[N, K, 2]): Target keypoints. + target_weight (torch.Tensor[N, K, 2]): + Weights across different joint types. 
+ """ + + accuracy = dict() + + N = output.shape[0] + + _, avg_acc, cnt = keypoint_pck_accuracy( + output.detach().cpu().numpy(), + target.detach().cpu().numpy(), + target_weight[:, :, 0].detach().cpu().numpy() > 0, + thr=0.05, + normalize=np.ones((N, 2), dtype=np.float32)) + accuracy['acc_pose'] = avg_acc + + return accuracy + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_regression (np.ndarray): Output regression. + + Args: + x (torch.Tensor[N, K, 2]): Input features. + flip_pairs (None | list[tuple()): + Pairs of keypoints which are mirrored. + """ + output = self.forward(x) + + if flip_pairs is not None: + output_regression = fliplr_regression( + output.detach().cpu().numpy(), flip_pairs) + else: + output_regression = output.detach().cpu().numpy() + return output_regression + + def decode(self, img_metas, output, **kwargs): + """Decode the keypoints from output regression. + + Args: + img_metas (list(dict)): Information about data augmentation + By default this includes: + + - "image_file: path to the image file + - "center": center of the bbox + - "scale": scale of the bbox + - "rotation": rotation of the bbox + - "bbox_score": score of bbox + output (np.ndarray[N, K, 2]): predicted regression vector. + kwargs: dict contains 'img_size'. + img_size (tuple(img_width, img_height)): input image size. + """ + batch_size = len(img_metas) + + if 'bbox_id' in img_metas[0]: + bbox_ids = [] + else: + bbox_ids = None + + c = np.zeros((batch_size, 2), dtype=np.float32) + s = np.zeros((batch_size, 2), dtype=np.float32) + image_paths = [] + score = np.ones(batch_size) + for i in range(batch_size): + c[i, :] = img_metas[i]['center'] + s[i, :] = img_metas[i]['scale'] + image_paths.append(img_metas[i]['image_file']) + + if 'bbox_score' in img_metas[i]: + score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1) + if bbox_ids is not None: + bbox_ids.append(img_metas[i]['bbox_id']) + + preds, maxvals = keypoints_from_regression(output, c, s, + kwargs['img_size']) + + all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32) + all_boxes = np.zeros((batch_size, 6), dtype=np.float32) + all_preds[:, :, 0:2] = preds[:, :, 0:2] + all_preds[:, :, 2:3] = maxvals + all_boxes[:, 0:2] = c[:, 0:2] + all_boxes[:, 2:4] = s[:, 0:2] + all_boxes[:, 4] = np.prod(s * 200.0, axis=1) + all_boxes[:, 5] = score + + result = {} + + result['preds'] = all_preds + result['boxes'] = all_boxes + result['image_paths'] = image_paths + result['bbox_ids'] = bbox_ids + + return result + + def init_weights(self): + normal_init(self.fc, mean=0, std=0.01, bias=0) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/hmr_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/hmr_head.py new file mode 100644 index 0000000..015a307 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/hmr_head.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import xavier_init + +from ..builder import HEADS +from ..utils.geometry import rot6d_to_rotmat + + +@HEADS.register_module() +class HMRMeshHead(nn.Module): + """SMPL parameters regressor head of simple baseline. "End-to-end Recovery + of Human Shape and Pose", CVPR'2018. 
+ + Args: + in_channels (int): Number of input channels + smpl_mean_params (str): The file name of the mean SMPL parameters + n_iter (int): The iterations of estimating delta parameters + """ + + def __init__(self, in_channels, smpl_mean_params=None, n_iter=3): + super().__init__() + + self.in_channels = in_channels + self.n_iter = n_iter + + npose = 24 * 6 + nbeta = 10 + ncam = 3 + hidden_dim = 1024 + + self.fc1 = nn.Linear(in_channels + npose + nbeta + ncam, hidden_dim) + self.drop1 = nn.Dropout() + self.fc2 = nn.Linear(hidden_dim, hidden_dim) + self.drop2 = nn.Dropout() + self.decpose = nn.Linear(hidden_dim, npose) + self.decshape = nn.Linear(hidden_dim, nbeta) + self.deccam = nn.Linear(hidden_dim, ncam) + + # Load mean SMPL parameters + if smpl_mean_params is None: + init_pose = torch.zeros([1, npose]) + init_shape = torch.zeros([1, nbeta]) + init_cam = torch.FloatTensor([[1, 0, 0]]) + else: + mean_params = np.load(smpl_mean_params) + init_pose = torch.from_numpy( + mean_params['pose'][:]).unsqueeze(0).float() + init_shape = torch.from_numpy( + mean_params['shape'][:]).unsqueeze(0).float() + init_cam = torch.from_numpy( + mean_params['cam']).unsqueeze(0).float() + self.register_buffer('init_pose', init_pose) + self.register_buffer('init_shape', init_shape) + self.register_buffer('init_cam', init_cam) + + def forward(self, x): + """Forward function. + + x is the image feature map and is expected to be in shape (batch size x + channel number x height x width) + """ + batch_size = x.shape[0] + # extract the global feature vector by average along + # spatial dimension. + x = x.mean(dim=-1).mean(dim=-1) + + init_pose = self.init_pose.expand(batch_size, -1) + init_shape = self.init_shape.expand(batch_size, -1) + init_cam = self.init_cam.expand(batch_size, -1) + + pred_pose = init_pose + pred_shape = init_shape + pred_cam = init_cam + for _ in range(self.n_iter): + xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1) + xc = self.fc1(xc) + xc = self.drop1(xc) + xc = self.fc2(xc) + xc = self.drop2(xc) + pred_pose = self.decpose(xc) + pred_pose + pred_shape = self.decshape(xc) + pred_shape + pred_cam = self.deccam(xc) + pred_cam + + pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3) + out = (pred_rotmat, pred_shape, pred_cam) + return out + + def init_weights(self): + """Initialize model weights.""" + xavier_init(self.decpose, gain=0.01) + xavier_init(self.decshape, gain=0.01) + xavier_init(self.deccam, gain=0.01) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/interhand_3d_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/interhand_3d_head.py new file mode 100644 index 0000000..aebe4a5 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/interhand_3d_head.py @@ -0,0 +1,521 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, + constant_init, normal_init) + +from mmpose.core.evaluation.top_down_eval import ( + keypoints_from_heatmaps3d, multilabel_classification_accuracy) +from mmpose.core.post_processing import flip_back +from mmpose.models.builder import build_loss +from mmpose.models.necks import GlobalAveragePooling +from ..builder import HEADS + + +class Heatmap3DHead(nn.Module): + """Heatmap3DHead is a sub-module of Interhand3DHead, and outputs 3D + heatmaps. 
Heatmap3DHead is composed of (>=0) number of deconv layers and a + simple conv2d layer. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + depth_size (int): Number of depth discretization size + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means no deconv layers. + num_deconv_filters (list|tuple): Number of filters. + num_deconv_kernels (list|tuple): Kernel sizes. + extra (dict): Configs for extra conv layers. Default: None + """ + + def __init__(self, + in_channels, + out_channels, + depth_size=64, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None): + + super().__init__() + + assert out_channels % depth_size == 0 + self.depth_size = depth_size + self.in_channels = in_channels + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, + ) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0, 1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + elif extra['final_conv_kernel'] == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_final_layer = True + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + if identity_final_layer: + self.final_layer = nn.Identity() + else: + conv_channels = num_deconv_filters[ + -1] if num_deconv_layers > 0 else self.in_channels + + layers = [] + if extra is not None: + num_conv_layers = extra.get('num_conv_layers', 0) + num_conv_kernels = extra.get('num_conv_kernels', + [1] * num_conv_layers) + + for i in range(num_conv_layers): + layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=num_conv_kernels[i], + stride=1, + padding=(num_conv_kernels[i] - 1) // 2)) + layers.append( + build_norm_layer(dict(type='BN'), conv_channels)[1]) + layers.append(nn.ReLU(inplace=True)) + + layers.append( + build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding)) + + if len(layers) > 1: + self.final_layer = nn.Sequential(*layers) + else: + self.final_layer = layers[0] + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return 
nn.Sequential(*layers) + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding + + def forward(self, x): + """Forward function.""" + x = self.deconv_layers(x) + x = self.final_layer(x) + N, C, H, W = x.shape + # reshape the 2D heatmap to 3D heatmap + x = x.reshape(N, C // self.depth_size, self.depth_size, H, W) + return x + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + + +class Heatmap1DHead(nn.Module): + """Heatmap1DHead is a sub-module of Interhand3DHead, and outputs 1D + heatmaps. + + Args: + in_channels (int): Number of input channels + heatmap_size (int): Heatmap size + hidden_dims (list|tuple): Number of feature dimension of FC layers. + """ + + def __init__(self, in_channels=2048, heatmap_size=64, hidden_dims=(512, )): + super().__init__() + + self.in_channels = in_channels + self.heatmap_size = heatmap_size + + feature_dims = [in_channels, *hidden_dims, heatmap_size] + self.fc = self._make_linear_layers(feature_dims, relu_final=False) + + def soft_argmax_1d(self, heatmap1d): + heatmap1d = F.softmax(heatmap1d, 1) + accu = heatmap1d * torch.arange( + self.heatmap_size, dtype=heatmap1d.dtype, + device=heatmap1d.device)[None, :] + coord = accu.sum(dim=1) + return coord + + def _make_linear_layers(self, feat_dims, relu_final=False): + """Make linear layers.""" + layers = [] + for i in range(len(feat_dims) - 1): + layers.append(nn.Linear(feat_dims[i], feat_dims[i + 1])) + if i < len(feat_dims) - 2 or \ + (i == len(feat_dims) - 2 and relu_final): + layers.append(nn.ReLU(inplace=True)) + return nn.Sequential(*layers) + + def forward(self, x): + """Forward function.""" + heatmap1d = self.fc(x) + value = self.soft_argmax_1d(heatmap1d).view(-1, 1) + return value + + def init_weights(self): + """Initialize model weights.""" + for m in self.fc.modules(): + if isinstance(m, nn.Linear): + normal_init(m, mean=0, std=0.01, bias=0) + + +class MultilabelClassificationHead(nn.Module): + """MultilabelClassificationHead is a sub-module of Interhand3DHead, and + outputs hand type classification. + + Args: + in_channels (int): Number of input channels + num_labels (int): Number of labels + hidden_dims (list|tuple): Number of hidden dimension of FC layers. 
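+
+    ``forward`` applies a small fully connected network followed by a sigmoid
+    and returns per-label probabilities of shape ``(N, num_labels)``.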
+ """ + + def __init__(self, in_channels=2048, num_labels=2, hidden_dims=(512, )): + super().__init__() + + self.in_channels = in_channels + self.num_labesl = num_labels + + feature_dims = [in_channels, *hidden_dims, num_labels] + self.fc = self._make_linear_layers(feature_dims, relu_final=False) + + def _make_linear_layers(self, feat_dims, relu_final=False): + """Make linear layers.""" + layers = [] + for i in range(len(feat_dims) - 1): + layers.append(nn.Linear(feat_dims[i], feat_dims[i + 1])) + if i < len(feat_dims) - 2 or \ + (i == len(feat_dims) - 2 and relu_final): + layers.append(nn.ReLU(inplace=True)) + return nn.Sequential(*layers) + + def forward(self, x): + """Forward function.""" + labels = torch.sigmoid(self.fc(x)) + return labels + + def init_weights(self): + for m in self.fc.modules(): + if isinstance(m, nn.Linear): + normal_init(m, mean=0, std=0.01, bias=0) + + +@HEADS.register_module() +class Interhand3DHead(nn.Module): + """Interhand 3D head of paper ref: Gyeongsik Moon. "InterHand2.6M: A + Dataset and Baseline for 3D Interacting Hand Pose Estimation from a Single + RGB Image". + + Args: + keypoint_head_cfg (dict): Configs of Heatmap3DHead for hand + keypoint estimation. + root_head_cfg (dict): Configs of Heatmap1DHead for relative + hand root depth estimation. + hand_type_head_cfg (dict): Configs of MultilabelClassificationHead + for hand type classification. + loss_keypoint (dict): Config for keypoint loss. Default: None. + loss_root_depth (dict): Config for relative root depth loss. + Default: None. + loss_hand_type (dict): Config for hand type classification + loss. Default: None. + """ + + def __init__(self, + keypoint_head_cfg, + root_head_cfg, + hand_type_head_cfg, + loss_keypoint=None, + loss_root_depth=None, + loss_hand_type=None, + train_cfg=None, + test_cfg=None): + super().__init__() + + # build sub-module heads + self.right_hand_head = Heatmap3DHead(**keypoint_head_cfg) + self.left_hand_head = Heatmap3DHead(**keypoint_head_cfg) + self.root_head = Heatmap1DHead(**root_head_cfg) + self.hand_type_head = MultilabelClassificationHead( + **hand_type_head_cfg) + self.neck = GlobalAveragePooling() + + # build losses + self.keypoint_loss = build_loss(loss_keypoint) + self.root_depth_loss = build_loss(loss_root_depth) + self.hand_type_loss = build_loss(loss_hand_type) + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap') + + def init_weights(self): + self.left_hand_head.init_weights() + self.right_hand_head.init_weights() + self.root_head.init_weights() + self.hand_type_head.init_weights() + + def get_loss(self, output, target, target_weight): + """Calculate loss for hand keypoint heatmaps, relative root depth and + hand type. + + Args: + output (list[Tensor]): a list of outputs from multiple heads. + target (list[Tensor]): a list of targets for multiple heads. + target_weight (list[Tensor]): a list of targets weight for + multiple heads. 
+ """ + losses = dict() + + # hand keypoint loss + assert not isinstance(self.keypoint_loss, nn.Sequential) + out, tar, tar_weight = output[0], target[0], target_weight[0] + assert tar.dim() == 5 and tar_weight.dim() == 3 + losses['hand_loss'] = self.keypoint_loss(out, tar, tar_weight) + + # relative root depth loss + assert not isinstance(self.root_depth_loss, nn.Sequential) + out, tar, tar_weight = output[1], target[1], target_weight[1] + assert tar.dim() == 2 and tar_weight.dim() == 2 + losses['rel_root_loss'] = self.root_depth_loss(out, tar, tar_weight) + + # hand type loss + assert not isinstance(self.hand_type_loss, nn.Sequential) + out, tar, tar_weight = output[2], target[2], target_weight[2] + assert tar.dim() == 2 and tar_weight.dim() in [1, 2] + losses['hand_type_loss'] = self.hand_type_loss(out, tar, tar_weight) + + return losses + + def get_accuracy(self, output, target, target_weight): + """Calculate accuracy for hand type. + + Args: + output (list[Tensor]): a list of outputs from multiple heads. + target (list[Tensor]): a list of targets for multiple heads. + target_weight (list[Tensor]): a list of targets weight for + multiple heads. + """ + accuracy = dict() + avg_acc = multilabel_classification_accuracy( + output[2].detach().cpu().numpy(), + target[2].detach().cpu().numpy(), + target_weight[2].detach().cpu().numpy(), + ) + accuracy['acc_classification'] = float(avg_acc) + return accuracy + + def forward(self, x): + """Forward function.""" + outputs = [] + outputs.append( + torch.cat([self.right_hand_head(x), + self.left_hand_head(x)], dim=1)) + x = self.neck(x) + outputs.append(self.root_head(x)) + outputs.append(self.hand_type_head(x)) + return outputs + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output (list[np.ndarray]): list of output hand keypoint + heatmaps, relative root depth and hand type. + + Args: + x (torch.Tensor[N,K,H,W]): Input features. + flip_pairs (None | list[tuple()): + Pairs of keypoints which are mirrored. + """ + + output = self.forward(x) + + if flip_pairs is not None: + # flip 3D heatmap + heatmap_3d = output[0] + N, K, D, H, W = heatmap_3d.shape + # reshape 3D heatmap to 2D heatmap + heatmap_3d = heatmap_3d.reshape(N, K * D, H, W) + # 2D heatmap flip + heatmap_3d_flipped_back = flip_back( + heatmap_3d.detach().cpu().numpy(), + flip_pairs, + target_type=self.target_type) + # reshape back to 3D heatmap + heatmap_3d_flipped_back = heatmap_3d_flipped_back.reshape( + N, K, D, H, W) + # feature is not aligned, shift flipped heatmap for higher accuracy + if self.test_cfg.get('shift_heatmap', False): + heatmap_3d_flipped_back[..., + 1:] = heatmap_3d_flipped_back[..., :-1] + output[0] = heatmap_3d_flipped_back + + # flip relative hand root depth + output[1] = -output[1].detach().cpu().numpy() + + # flip hand type + hand_type = output[2].detach().cpu().numpy() + hand_type_flipped_back = hand_type.copy() + hand_type_flipped_back[:, 0] = hand_type[:, 1] + hand_type_flipped_back[:, 1] = hand_type[:, 0] + output[2] = hand_type_flipped_back + else: + output = [out.detach().cpu().numpy() for out in output] + + return output + + def decode(self, img_metas, output, **kwargs): + """Decode hand keypoint, relative root depth and hand type. 
+ + Args: + img_metas (list(dict)): Information about data augmentation + By default this includes: + + - "image_file: path to the image file + - "center": center of the bbox + - "scale": scale of the bbox + - "rotation": rotation of the bbox + - "bbox_score": score of bbox + - "heatmap3d_depth_bound": depth bound of hand keypoint + 3D heatmap + - "root_depth_bound": depth bound of relative root depth + 1D heatmap + output (list[np.ndarray]): model predicted 3D heatmaps, relative + root depth and hand type. + """ + + batch_size = len(img_metas) + result = {} + + heatmap3d_depth_bound = np.ones(batch_size, dtype=np.float32) + root_depth_bound = np.ones(batch_size, dtype=np.float32) + center = np.zeros((batch_size, 2), dtype=np.float32) + scale = np.zeros((batch_size, 2), dtype=np.float32) + image_paths = [] + score = np.ones(batch_size, dtype=np.float32) + if 'bbox_id' in img_metas[0]: + bbox_ids = [] + else: + bbox_ids = None + + for i in range(batch_size): + heatmap3d_depth_bound[i] = img_metas[i]['heatmap3d_depth_bound'] + root_depth_bound[i] = img_metas[i]['root_depth_bound'] + center[i, :] = img_metas[i]['center'] + scale[i, :] = img_metas[i]['scale'] + image_paths.append(img_metas[i]['image_file']) + + if 'bbox_score' in img_metas[i]: + score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1) + if bbox_ids is not None: + bbox_ids.append(img_metas[i]['bbox_id']) + + all_boxes = np.zeros((batch_size, 6), dtype=np.float32) + all_boxes[:, 0:2] = center[:, 0:2] + all_boxes[:, 2:4] = scale[:, 0:2] + # scale is defined as: bbox_size / 200.0, so we + # need multiply 200.0 to get bbox size + all_boxes[:, 4] = np.prod(scale * 200.0, axis=1) + all_boxes[:, 5] = score + result['boxes'] = all_boxes + result['image_paths'] = image_paths + result['bbox_ids'] = bbox_ids + + # decode 3D heatmaps of hand keypoints + heatmap3d = output[0] + preds, maxvals = keypoints_from_heatmaps3d(heatmap3d, center, scale) + keypoints_3d = np.zeros((batch_size, preds.shape[1], 4), + dtype=np.float32) + keypoints_3d[:, :, 0:3] = preds[:, :, 0:3] + keypoints_3d[:, :, 3:4] = maxvals + # transform keypoint depth to camera space + keypoints_3d[:, :, 2] = \ + (keypoints_3d[:, :, 2] / self.right_hand_head.depth_size - 0.5) \ + * heatmap3d_depth_bound[:, np.newaxis] + + result['preds'] = keypoints_3d + + # decode relative hand root depth + # transform relative root depth to camera space + result['rel_root_depth'] = (output[1] / self.root_head.heatmap_size - + 0.5) * root_depth_bound + + # decode hand type + result['hand_type'] = output[2] > 0.5 + return result diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/temporal_regression_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/temporal_regression_head.py new file mode 100644 index 0000000..97a07f9 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/temporal_regression_head.py @@ -0,0 +1,319 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import build_conv_layer, constant_init, kaiming_init +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmpose.core import (WeightNormClipHook, compute_similarity_transform, + fliplr_regression) +from mmpose.models.builder import HEADS, build_loss + + +@HEADS.register_module() +class TemporalRegressionHead(nn.Module): + """Regression head of VideoPose3D. + + "3D human pose estimation in video with temporal convolutions and + semi-supervised training", CVPR'2019. 
+ + Args: + in_channels (int): Number of input channels + num_joints (int): Number of joints + loss_keypoint (dict): Config for keypoint loss. Default: None. + max_norm (float|None): if not None, the weight of convolution layers + will be clipped to have a maximum norm of max_norm. + is_trajectory (bool): If the model only predicts root joint + position, then this arg should be set to True. In this case, + traj_loss will be calculated. Otherwise, it should be set to + False. Default: False. + """ + + def __init__(self, + in_channels, + num_joints, + max_norm=None, + loss_keypoint=None, + is_trajectory=False, + train_cfg=None, + test_cfg=None): + super().__init__() + + self.in_channels = in_channels + self.num_joints = num_joints + self.max_norm = max_norm + self.loss = build_loss(loss_keypoint) + self.is_trajectory = is_trajectory + if self.is_trajectory: + assert self.num_joints == 1 + + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + + self.conv = build_conv_layer( + dict(type='Conv1d'), in_channels, num_joints * 3, 1) + + if self.max_norm is not None: + # Apply weight norm clip to conv layers + weight_clip = WeightNormClipHook(self.max_norm) + for module in self.modules(): + if isinstance(module, nn.modules.conv._ConvNd): + weight_clip.register(module) + + @staticmethod + def _transform_inputs(x): + """Transform inputs for decoder. + + Args: + inputs (tuple or list of Tensor | Tensor): multi-level features. + + Returns: + Tensor: The transformed inputs + """ + if not isinstance(x, (list, tuple)): + return x + + assert len(x) > 0 + + # return the top-level feature of the 1D feature pyramid + return x[-1] + + def forward(self, x): + """Forward function.""" + x = self._transform_inputs(x) + + assert x.ndim == 3 and x.shape[2] == 1, f'Invalid shape {x.shape}' + output = self.conv(x) + N = output.shape[0] + return output.reshape(N, self.num_joints, 3) + + def get_loss(self, output, target, target_weight): + """Calculate keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + output (torch.Tensor[N, K, 3]): Output keypoints. + target (torch.Tensor[N, K, 3]): Target keypoints. + target_weight (torch.Tensor[N, K, 3]): + Weights across different joint types. + If self.is_trajectory is True and target_weight is None, + target_weight will be set inversely proportional to joint + depth. + """ + losses = dict() + assert not isinstance(self.loss, nn.Sequential) + + # trajectory model + if self.is_trajectory: + if target.dim() == 2: + target.unsqueeze_(1) + + if target_weight is None: + target_weight = (1 / target[:, :, 2:]).expand(target.shape) + assert target.dim() == 3 and target_weight.dim() == 3 + + losses['traj_loss'] = self.loss(output, target, target_weight) + + # pose model + else: + if target_weight is None: + target_weight = target.new_ones(target.shape) + assert target.dim() == 3 and target_weight.dim() == 3 + losses['reg_loss'] = self.loss(output, target, target_weight) + + return losses + + def get_accuracy(self, output, target, target_weight, metas): + """Calculate accuracy for keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + output (torch.Tensor[N, K, 3]): Output keypoints. + target (torch.Tensor[N, K, 3]): Target keypoints. + target_weight (torch.Tensor[N, K, 3]): + Weights across different joint types. 
+ metas (list(dict)): Information about data augmentation including: + + - target_image_path (str): Optional, path to the image file + - target_mean (float): Optional, normalization parameter of + the target pose. + - target_std (float): Optional, normalization parameter of the + target pose. + - root_position (np.ndarray[3,1]): Optional, global + position of the root joint. + - root_index (torch.ndarray[1,]): Optional, original index of + the root joint before root-centering. + """ + + accuracy = dict() + + N = output.shape[0] + output_ = output.detach().cpu().numpy() + target_ = target.detach().cpu().numpy() + # Denormalize the predicted pose + if 'target_mean' in metas[0] and 'target_std' in metas[0]: + target_mean = np.stack([m['target_mean'] for m in metas]) + target_std = np.stack([m['target_std'] for m in metas]) + output_ = self._denormalize_joints(output_, target_mean, + target_std) + target_ = self._denormalize_joints(target_, target_mean, + target_std) + + # Restore global position + if self.test_cfg.get('restore_global_position', False): + root_pos = np.stack([m['root_position'] for m in metas]) + root_idx = metas[0].get('root_position_index', None) + output_ = self._restore_global_position(output_, root_pos, + root_idx) + target_ = self._restore_global_position(target_, root_pos, + root_idx) + # Get target weight + if target_weight is None: + target_weight_ = np.ones_like(target_) + else: + target_weight_ = target_weight.detach().cpu().numpy() + if self.test_cfg.get('restore_global_position', False): + root_idx = metas[0].get('root_position_index', None) + root_weight = metas[0].get('root_joint_weight', 1.0) + target_weight_ = self._restore_root_target_weight( + target_weight_, root_weight, root_idx) + + mpjpe = np.mean( + np.linalg.norm((output_ - target_) * target_weight_, axis=-1)) + + transformed_output = np.zeros_like(output_) + for i in range(N): + transformed_output[i, :, :] = compute_similarity_transform( + output_[i, :, :], target_[i, :, :]) + p_mpjpe = np.mean( + np.linalg.norm( + (transformed_output - target_) * target_weight_, axis=-1)) + + accuracy['mpjpe'] = output.new_tensor(mpjpe) + accuracy['p_mpjpe'] = output.new_tensor(p_mpjpe) + + return accuracy + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_regression (np.ndarray): Output regression. + + Args: + x (torch.Tensor[N, K, 2]): Input features. + flip_pairs (None | list[tuple()): + Pairs of keypoints which are mirrored. + """ + output = self.forward(x) + + if flip_pairs is not None: + output_regression = fliplr_regression( + output.detach().cpu().numpy(), + flip_pairs, + center_mode='static', + center_x=0) + else: + output_regression = output.detach().cpu().numpy() + return output_regression + + def decode(self, metas, output): + """Decode the keypoints from output regression. + + Args: + metas (list(dict)): Information about data augmentation. + By default this includes: + + - "target_image_path": path to the image file + output (np.ndarray[N, K, 3]): predicted regression vector. + metas (list(dict)): Information about data augmentation including: + + - target_image_path (str): Optional, path to the image file + - target_mean (float): Optional, normalization parameter of + the target pose. + - target_std (float): Optional, normalization parameter of the + target pose. + - root_position (np.ndarray[3,1]): Optional, global + position of the root joint. + - root_index (torch.ndarray[1,]): Optional, original index of + the root joint before root-centering. 
+ """ + + # Denormalize the predicted pose + if 'target_mean' in metas[0] and 'target_std' in metas[0]: + target_mean = np.stack([m['target_mean'] for m in metas]) + target_std = np.stack([m['target_std'] for m in metas]) + output = self._denormalize_joints(output, target_mean, target_std) + + # Restore global position + if self.test_cfg.get('restore_global_position', False): + root_pos = np.stack([m['root_position'] for m in metas]) + root_idx = metas[0].get('root_position_index', None) + output = self._restore_global_position(output, root_pos, root_idx) + + target_image_paths = [m.get('target_image_path', None) for m in metas] + result = {'preds': output, 'target_image_paths': target_image_paths} + + return result + + @staticmethod + def _denormalize_joints(x, mean, std): + """Denormalize joint coordinates with given statistics mean and std. + + Args: + x (np.ndarray[N, K, 3]): Normalized joint coordinates. + mean (np.ndarray[K, 3]): Mean value. + std (np.ndarray[K, 3]): Std value. + """ + assert x.ndim == 3 + assert x.shape == mean.shape == std.shape + + return x * std + mean + + @staticmethod + def _restore_global_position(x, root_pos, root_idx=None): + """Restore global position of the root-centered joints. + + Args: + x (np.ndarray[N, K, 3]): root-centered joint coordinates + root_pos (np.ndarray[N,1,3]): The global position of the + root joint. + root_idx (int|None): If not none, the root joint will be inserted + back to the pose at the given index. + """ + x = x + root_pos + if root_idx is not None: + x = np.insert(x, root_idx, root_pos.squeeze(1), axis=1) + return x + + @staticmethod + def _restore_root_target_weight(target_weight, root_weight, root_idx=None): + """Restore the target weight of the root joint after the restoration of + the global position. + + Args: + target_weight (np.ndarray[N, K, 1]): Target weight of relativized + joints. + root_weight (float): The target weight value of the root joint. + root_idx (int|None): If not none, the root joint weight will be + inserted back to the target weight at the given index. + """ + if root_idx is not None: + root_weight = np.full( + target_weight.shape[0], root_weight, dtype=target_weight.dtype) + target_weight = np.insert( + target_weight, root_idx, root_weight[:, None], axis=1) + return target_weight + + def init_weights(self): + """Initialize the weights.""" + for m in self.modules(): + if isinstance(m, nn.modules.conv._ConvNd): + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, _BatchNorm): + constant_init(m, 1) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_base_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_base_head.py new file mode 100644 index 0000000..08d483c --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_base_head.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +import numpy as np +import torch.nn as nn + +# from mmpose.core.evaluation.top_down_eval import keypoints_from_heatmaps + + +class TopdownHeatmapBaseHead(nn.Module): + """Base class for top-down heatmap heads. + + All top-down heatmap heads should subclass it. + All subclass should overwrite: + + Methods:`get_loss`, supporting to calculate loss. + Methods:`get_accuracy`, supporting to calculate accuracy. + Methods:`forward`, supporting to forward model. + Methods:`inference_model`, supporting to inference model. 
+ """ + + __metaclass__ = ABCMeta + + @abstractmethod + def get_loss(self, **kwargs): + """Gets the loss.""" + + @abstractmethod + def get_accuracy(self, **kwargs): + """Gets the accuracy.""" + + @abstractmethod + def forward(self, **kwargs): + """Forward function.""" + + @abstractmethod + def inference_model(self, **kwargs): + """Inference function.""" + + def decode(self, img_metas, output, **kwargs): + """Decode keypoints from heatmaps. + + Args: + img_metas (list(dict)): Information about data augmentation + By default this includes: + + - "image_file: path to the image file + - "center": center of the bbox + - "scale": scale of the bbox + - "rotation": rotation of the bbox + - "bbox_score": score of bbox + output (np.ndarray[N, K, H, W]): model predicted heatmaps. + """ + # batch_size = len(img_metas) + + # if 'bbox_id' in img_metas[0]: + # bbox_ids = [] + # else: + # bbox_ids = None + + # c = np.zeros((batch_size, 2), dtype=np.float32) + # s = np.zeros((batch_size, 2), dtype=np.float32) + # image_paths = [] + # score = np.ones(batch_size) + # for i in range(batch_size): + # c[i, :] = img_metas[i]['center'] + # s[i, :] = img_metas[i]['scale'] + # image_paths.append(img_metas[i]['image_file']) + + # if 'bbox_score' in img_metas[i]: + # score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1) + # if bbox_ids is not None: + # bbox_ids.append(img_metas[i]['bbox_id']) + + # preds, maxvals = keypoints_from_heatmaps( + # output, + # c, + # s, + # unbiased=self.test_cfg.get('unbiased_decoding', False), + # post_process=self.test_cfg.get('post_process', 'default'), + # kernel=self.test_cfg.get('modulate_kernel', 11), + # valid_radius_factor=self.test_cfg.get('valid_radius_factor', + # 0.0546875), + # use_udp=self.test_cfg.get('use_udp', False), + # target_type=self.test_cfg.get('target_type', 'GaussianHeatmap')) + + # all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32) + # all_boxes = np.zeros((batch_size, 6), dtype=np.float32) + # all_preds[:, :, 0:2] = preds[:, :, 0:2] + # all_preds[:, :, 2:3] = maxvals + # all_boxes[:, 0:2] = c[:, 0:2] + # all_boxes[:, 2:4] = s[:, 0:2] + # all_boxes[:, 4] = np.prod(s * 200.0, axis=1) + # all_boxes[:, 5] = score + + # result = {} + + # result['preds'] = all_preds + # result['boxes'] = all_boxes + # result['image_paths'] = image_paths + # result['bbox_ids'] = bbox_ids + + return None + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_multi_stage_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_multi_stage_head.py new file mode 100644 index 0000000..c439f5b --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_multi_stage_head.py @@ -0,0 +1,572 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy as cp + +import torch.nn as nn +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, Linear, + build_activation_layer, build_conv_layer, + build_norm_layer, build_upsample_layer, constant_init, + kaiming_init, normal_init) + +from mmpose.core.evaluation import pose_pck_accuracy +from mmpose.core.post_processing import flip_back +from mmpose.models.builder import build_loss +from ..builder import HEADS +from .topdown_heatmap_base_head import TopdownHeatmapBaseHead + + +@HEADS.register_module() +class TopdownHeatmapMultiStageHead(TopdownHeatmapBaseHead): + """Top-down heatmap multi-stage head. + + TopdownHeatmapMultiStageHead is consisted of multiple branches, + each of which has num_deconv_layers(>=0) number of deconv layers + and a simple conv2d layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_stages (int): Number of stages. + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means + no deconv layers. + num_deconv_filters (list|tuple): Number of filters. + If num_deconv_layers > 0, the length of + num_deconv_kernels (list|tuple): Kernel sizes. + loss_keypoint (dict): Config for keypoint loss. Default: None. + """ + + def __init__(self, + in_channels=512, + out_channels=17, + num_stages=1, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None, + loss_keypoint=None, + train_cfg=None, + test_cfg=None): + super().__init__() + + self.in_channels = in_channels + self.num_stages = num_stages + self.loss = build_loss(loss_keypoint) + + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap') + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + # build multi-stage deconv layers + self.multi_deconv_layers = nn.ModuleList([]) + for _ in range(self.num_stages): + if num_deconv_layers > 0: + deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, + ) + elif num_deconv_layers == 0: + deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + self.multi_deconv_layers.append(deconv_layers) + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0, 1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + elif extra['final_conv_kernel'] == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_final_layer = True + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + # build multi-stage final layers + self.multi_final_layers = nn.ModuleList([]) + for i in range(self.num_stages): + if identity_final_layer: + final_layer = nn.Identity() + else: + final_layer = build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=num_deconv_filters[-1] + if num_deconv_layers > 0 else in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding) + self.multi_final_layers.append(final_layer) + + def get_loss(self, output, target, target_weight): + """Calculate top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - num_outputs: O + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): + Output heatmaps. 
+ target (torch.Tensor[N,K,H,W]): + Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. + """ + + losses = dict() + + assert isinstance(output, list) + assert target.dim() == 4 and target_weight.dim() == 3 + + if isinstance(self.loss, nn.Sequential): + assert len(self.loss) == len(output) + for i in range(len(output)): + target_i = target + target_weight_i = target_weight + if isinstance(self.loss, nn.Sequential): + loss_func = self.loss[i] + else: + loss_func = self.loss + loss_i = loss_func(output[i], target_i, target_weight_i) + if 'heatmap_loss' not in losses: + losses['heatmap_loss'] = loss_i + else: + losses['heatmap_loss'] += loss_i + + return losses + + def get_accuracy(self, output, target, target_weight): + """Calculate accuracy for top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. + """ + + accuracy = dict() + + if self.target_type == 'GaussianHeatmap': + _, avg_acc, _ = pose_pck_accuracy( + output[-1].detach().cpu().numpy(), + target.detach().cpu().numpy(), + target_weight.detach().cpu().numpy().squeeze(-1) > 0) + accuracy['acc_pose'] = float(avg_acc) + + return accuracy + + def forward(self, x): + """Forward function. + + Returns: + out (list[Tensor]): a list of heatmaps from multiple stages. + """ + out = [] + assert isinstance(x, list) + for i in range(self.num_stages): + y = self.multi_deconv_layers[i](x[i]) + y = self.multi_final_layers[i](y) + out.append(y) + return out + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_heatmap (np.ndarray): Output heatmaps. + + Args: + x (List[torch.Tensor[NxKxHxW]]): Input features. + flip_pairs (None | list[tuple()): + Pairs of keypoints which are mirrored. 
+ """ + output = self.forward(x) + assert isinstance(output, list) + output = output[-1] + + if flip_pairs is not None: + # perform flip + output_heatmap = flip_back( + output.detach().cpu().numpy(), + flip_pairs, + target_type=self.target_type) + # feature is not aligned, shift flipped heatmap for higher accuracy + if self.test_cfg.get('shift_heatmap', False): + output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1] + else: + output_heatmap = output.detach().cpu().numpy() + + return output_heatmap + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.multi_deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.multi_final_layers.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + + +class PredictHeatmap(nn.Module): + """Predict the heat map for an input feature. + + Args: + unit_channels (int): Number of input channels. + out_channels (int): Number of output channels. + out_shape (tuple): Shape of the output heatmap. + use_prm (bool): Whether to use pose refine machine. Default: False. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + """ + + def __init__(self, + unit_channels, + out_channels, + out_shape, + use_prm=False, + norm_cfg=dict(type='BN')): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.unit_channels = unit_channels + self.out_channels = out_channels + self.out_shape = out_shape + self.use_prm = use_prm + if use_prm: + self.prm = PRM(out_channels, norm_cfg=norm_cfg) + self.conv_layers = nn.Sequential( + ConvModule( + unit_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + inplace=False), + ConvModule( + unit_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=None, + inplace=False)) + + def forward(self, feature): + feature = self.conv_layers(feature) + output = nn.functional.interpolate( + feature, size=self.out_shape, mode='bilinear', align_corners=True) + if self.use_prm: + output = self.prm(output) + return output + + +class PRM(nn.Module): + """Pose Refine Machine. + + Please refer to "Learning Delicate Local Representations + for Multi-Person Pose Estimation" (ECCV 2020). + + Args: + out_channels (int): Channel number of the output. Equals to + the number of key points. + norm_cfg (dict): dictionary to construct and config norm layer. 
+ Default: dict(type='BN') + """ + + def __init__(self, out_channels, norm_cfg=dict(type='BN')): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + self.out_channels = out_channels + self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) + self.middle_path = nn.Sequential( + Linear(self.out_channels, self.out_channels), + build_norm_layer(dict(type='BN1d'), out_channels)[1], + build_activation_layer(dict(type='ReLU')), + Linear(self.out_channels, self.out_channels), + build_norm_layer(dict(type='BN1d'), out_channels)[1], + build_activation_layer(dict(type='ReLU')), + build_activation_layer(dict(type='Sigmoid'))) + + self.bottom_path = nn.Sequential( + ConvModule( + self.out_channels, + self.out_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + inplace=False), + DepthwiseSeparableConvModule( + self.out_channels, + 1, + kernel_size=9, + stride=1, + padding=4, + norm_cfg=norm_cfg, + inplace=False), build_activation_layer(dict(type='Sigmoid'))) + self.conv_bn_relu_prm_1 = ConvModule( + self.out_channels, + self.out_channels, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + inplace=False) + + def forward(self, x): + out = self.conv_bn_relu_prm_1(x) + out_1 = out + + out_2 = self.global_pooling(out_1) + out_2 = out_2.view(out_2.size(0), -1) + out_2 = self.middle_path(out_2) + out_2 = out_2.unsqueeze(2) + out_2 = out_2.unsqueeze(3) + + out_3 = self.bottom_path(out_1) + out = out_1 * (1 + out_2 * out_3) + + return out + + +@HEADS.register_module() +class TopdownHeatmapMSMUHead(TopdownHeatmapBaseHead): + """Heads for multi-stage multi-unit heads used in Multi-Stage Pose + estimation Network (MSPN), and Residual Steps Networks (RSN). + + Args: + unit_channels (int): Number of input channels. + out_channels (int): Number of output channels. + out_shape (tuple): Shape of the output heatmap. + num_stages (int): Number of stages. + num_units (int): Number of units in each stage. + use_prm (bool): Whether to use pose refine machine (PRM). + Default: False. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + loss_keypoint (dict): Config for keypoint loss. Default: None. + """ + + def __init__(self, + out_shape, + unit_channels=256, + out_channels=17, + num_stages=4, + num_units=4, + use_prm=False, + norm_cfg=dict(type='BN'), + loss_keypoint=None, + train_cfg=None, + test_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__() + + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap') + + self.out_shape = out_shape + self.unit_channels = unit_channels + self.out_channels = out_channels + self.num_stages = num_stages + self.num_units = num_units + + self.loss = build_loss(loss_keypoint) + + self.predict_layers = nn.ModuleList([]) + for i in range(self.num_stages): + for j in range(self.num_units): + self.predict_layers.append( + PredictHeatmap( + unit_channels, + out_channels, + out_shape, + use_prm, + norm_cfg=norm_cfg)) + + def get_loss(self, output, target, target_weight): + """Calculate top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - num_outputs: O + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,O,K,H,W]): Output heatmaps. + target (torch.Tensor[N,O,K,H,W]): Target heatmaps. 
+ target_weight (torch.Tensor[N,O,K,1]): + Weights across different joint types. + """ + + losses = dict() + + assert isinstance(output, list) + assert target.dim() == 5 and target_weight.dim() == 4 + assert target.size(1) == len(output) + + if isinstance(self.loss, nn.Sequential): + assert len(self.loss) == len(output) + for i in range(len(output)): + target_i = target[:, i, :, :, :] + target_weight_i = target_weight[:, i, :, :] + + if isinstance(self.loss, nn.Sequential): + loss_func = self.loss[i] + else: + loss_func = self.loss + + loss_i = loss_func(output[i], target_i, target_weight_i) + if 'heatmap_loss' not in losses: + losses['heatmap_loss'] = loss_i + else: + losses['heatmap_loss'] += loss_i + + return losses + + def get_accuracy(self, output, target, target_weight): + """Calculate accuracy for top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. + """ + + accuracy = dict() + + if self.target_type == 'GaussianHeatmap': + assert isinstance(output, list) + assert target.dim() == 5 and target_weight.dim() == 4 + _, avg_acc, _ = pose_pck_accuracy( + output[-1].detach().cpu().numpy(), + target[:, -1, ...].detach().cpu().numpy(), + target_weight[:, -1, + ...].detach().cpu().numpy().squeeze(-1) > 0) + accuracy['acc_pose'] = float(avg_acc) + + return accuracy + + def forward(self, x): + """Forward function. + + Returns: + out (list[Tensor]): a list of heatmaps from multiple stages + and units. + """ + out = [] + assert isinstance(x, list) + assert len(x) == self.num_stages + assert isinstance(x[0], list) + assert len(x[0]) == self.num_units + assert x[0][0].shape[1] == self.unit_channels + for i in range(self.num_stages): + for j in range(self.num_units): + y = self.predict_layers[i * self.num_units + j](x[i][j]) + out.append(y) + + return out + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_heatmap (np.ndarray): Output heatmaps. + + Args: + x (list[torch.Tensor[N,K,H,W]]): Input features. + flip_pairs (None | list[tuple]): + Pairs of keypoints which are mirrored. + """ + output = self.forward(x) + assert isinstance(output, list) + output = output[-1] + if flip_pairs is not None: + output_heatmap = flip_back( + output.detach().cpu().numpy(), + flip_pairs, + target_type=self.target_type) + # feature is not aligned, shift flipped heatmap for higher accuracy + if self.test_cfg.get('shift_heatmap', False): + output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1] + else: + output_heatmap = output.detach().cpu().numpy() + return output_heatmap + + def init_weights(self): + """Initialize model weights.""" + for m in self.predict_layers.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + elif isinstance(m, nn.Linear): + normal_init(m, std=0.01) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_simple_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_simple_head.py new file mode 100644 index 0000000..9725ab4 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_simple_head.py @@ -0,0 +1,392 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
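The `TopdownHeatmapSimpleHead` defined in this file follows the same recipe as the heads above: a stack of stride-2 deconvolutions upsamples the backbone features, and a final 1x1 convolution emits one heatmap per keypoint. As a rough shape trace, here is a self-contained toy stand-in (hypothetical class name, channel counts, and input size; not the repository's actual configuration) together with a naive per-keypoint argmax decoding.

import torch
import torch.nn as nn

class TinySimpleHead(nn.Module):
    """Toy stand-in for a top-down heatmap head: deconvs + 1x1 conv."""

    def __init__(self, in_channels=256, num_keypoints=17,
                 deconv_filters=(256, 256)):
        super().__init__()
        layers, channels = [], in_channels
        for planes in deconv_filters:
            # kernel 4, stride 2, padding 1 doubles H and W (see sketch above)
            layers += [nn.ConvTranspose2d(channels, planes, kernel_size=4,
                                          stride=2, padding=1, bias=False),
                       nn.BatchNorm2d(planes),
                       nn.ReLU(inplace=True)]
            channels = planes
        self.deconv_layers = nn.Sequential(*layers)
        self.final_layer = nn.Conv2d(channels, num_keypoints, kernel_size=1)

    def forward(self, x):
        return self.final_layer(self.deconv_layers(x))

feat = torch.randn(2, 256, 16, 12)         # hypothetical backbone features
heatmaps = TinySimpleHead()(feat)          # -> (2, 17, 64, 48)
N, K, H, W = heatmaps.shape
flat_idx = heatmaps.flatten(2).argmax(-1)  # per-keypoint argmax over H * W
xy = torch.stack((flat_idx % W, flat_idx // W), dim=-1)  # (N, K, 2), heatmap pixels
print(heatmaps.shape, xy.shape)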
+import torch +import torch.nn as nn +# from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, +# constant_init, normal_init) + +# from mmpose.core.evaluation import pose_pck_accuracy +# from mmpose.core.post_processing import flip_back +# from mmpose.models.builder import build_loss +# from mmpose.models.utils.ops import resize +# from ..builder import HEADS +import torch.nn.functional as F +from .topdown_heatmap_base_head import TopdownHeatmapBaseHead + +def build_conv_layer(cfg, *args, **kwargs) -> nn.Module: + """LICENSE""" + + if cfg is None: + cfg_ = dict(type='Conv2d') + else: + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type !='Conv2d': + raise KeyError(f'Unrecognized layer type {layer_type}') + else: + conv_layer = nn.Conv2d + + layer = conv_layer(*args, **kwargs, **cfg_) + + return layer + +def build_upsample_layer(cfg, *args, **kwargs) -> nn.Module: + + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + raise KeyError( + f'the cfg dict must contain the key "type", but got {cfg}') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type !='deconv': + raise KeyError(f'Unrecognized upsample type {layer_type}') + else: + upsample = nn.ConvTranspose2d + + if upsample is nn.Upsample: + cfg_['mode'] = layer_type + layer = upsample(*args, **kwargs, **cfg_) + return layer + +# @HEADS.register_module() +class TopdownHeatmapSimpleHead(TopdownHeatmapBaseHead): + """Top-down heatmap simple head. paper ref: Bin Xiao et al. ``Simple + Baselines for Human Pose Estimation and Tracking``. + + TopdownHeatmapSimpleHead is consisted of (>=0) number of deconv layers + and a simple conv2d layer. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means + no deconv layers. + num_deconv_filters (list|tuple): Number of filters. + If num_deconv_layers > 0, the length of + num_deconv_kernels (list|tuple): Kernel sizes. + in_index (int|Sequence[int]): Input feature index. Default: 0 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + Default: None. + + - 'resize_concat': Multiple feature maps will be resized to the + same size as the first one and then concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_keypoint (dict): Config for keypoint loss. Default: None. 
+ """ + + def __init__(self, + in_channels, + out_channels, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None, + in_index=0, + input_transform=None, + align_corners=False, + loss_keypoint=None, + train_cfg=None, + test_cfg=None, + upsample=0,): + super().__init__() + + self.in_channels = in_channels + self.loss = None + self.upsample = upsample + + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap') + + self._init_inputs(in_channels, in_index, input_transform) + self.in_index = in_index + self.align_corners = align_corners + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, + ) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0, 1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + elif extra['final_conv_kernel'] == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_final_layer = True + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + if identity_final_layer: + self.final_layer = nn.Identity() + else: + conv_channels = num_deconv_filters[ + -1] if num_deconv_layers > 0 else self.in_channels + + layers = [] + if extra is not None: + num_conv_layers = extra.get('num_conv_layers', 0) + num_conv_kernels = extra.get('num_conv_kernels', + [1] * num_conv_layers) + + for i in range(num_conv_layers): + layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=num_conv_kernels[i], + stride=1, + padding=(num_conv_kernels[i] - 1) // 2)) + layers.append( + nn.BatchNorm2d(conv_channels) +) + layers.append(nn.ReLU(inplace=True)) + + layers.append( + build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding)) + + if len(layers) > 1: + self.final_layer = nn.Sequential(*layers) + else: + self.final_layer = layers[0] + + def get_loss(self, output, target, target_weight): + """Calculate top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. + """ + + losses = dict() + + assert not isinstance(self.loss, nn.Sequential) + assert target.dim() == 4 and target_weight.dim() == 3 + losses['heatmap_loss'] = self.loss(output, target, target_weight) + + return losses + + def get_accuracy(self, output, target, target_weight): + """Calculate accuracy for top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. 
+ """ + + accuracy = dict() + + if self.target_type == 'GaussianHeatmap': + _, avg_acc, _ = pose_pck_accuracy( + output.detach().cpu().numpy(), + target.detach().cpu().numpy(), + target_weight.detach().cpu().numpy().squeeze(-1) > 0) + accuracy['acc_pose'] = float(avg_acc) + + return accuracy + + def forward(self, x): + """Forward function.""" + x = self._transform_inputs(x) + x = self.deconv_layers(x) + x = self.final_layer(x) + return x + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_heatmap (np.ndarray): Output heatmaps. + + Args: + x (torch.Tensor[N,K,H,W]): Input features. + flip_pairs (None | list[tuple]): + Pairs of keypoints which are mirrored. + """ + output = self.forward(x) + + if flip_pairs is not None: + output_heatmap = flip_back( + output.detach().cpu().numpy(), + flip_pairs, + target_type=self.target_type) + # feature is not aligned, shift flipped heatmap for higher accuracy + if self.test_cfg.get('shift_heatmap', False): + output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1] + else: + output_heatmap = output.detach().cpu().numpy() + return output_heatmap + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform is not None, in_channels and in_index must be + list or tuple, with the same length. + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + + - 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor] | Tensor): multi-level img features. 
+ + Returns: + Tensor: The transformed inputs + """ + if not isinstance(inputs, list): + if not isinstance(inputs, list): + if self.upsample > 0: + inputs = resize( + input=F.relu(inputs), + scale_factor=self.upsample, + mode='bilinear', + align_corners=self.align_corners + ) + return inputs + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/vipnas_heatmap_simple_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/vipnas_heatmap_simple_head.py new file mode 100644 index 0000000..5844fd5 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/vipnas_heatmap_simple_head.py @@ -0,0 +1,349 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +# from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, +# constant_init, normal_init) + +# from mmpose.core.evaluation import pose_pck_accuracy +# from mmpose.core.post_processing import flip_back +# from mmpose.models.builder import build_loss +# from mmpose.models.utils.ops import resize +# from ..builder import HEADS +# from .topdown_heatmap_base_head import TopdownHeatmapBaseHead + + +# @HEADS.register_module() +class ViPNASHeatmapSimpleHead(TopdownHeatmapBaseHead): + """ViPNAS heatmap simple head. + + ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search. + More details can be found in the `paper + `__ . + + TopdownHeatmapSimpleHead is consisted of (>=0) number of deconv layers + and a simple conv2d layer. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means + no deconv layers. 
+ num_deconv_filters (list|tuple): Number of filters. + If num_deconv_layers > 0, the length of + num_deconv_kernels (list|tuple): Kernel sizes. + num_deconv_groups (list|tuple): Group number. + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + Default: None. + + - 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_keypoint (dict): Config for keypoint loss. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_deconv_layers=3, + num_deconv_filters=(144, 144, 144), + num_deconv_kernels=(4, 4, 4), + num_deconv_groups=(16, 16, 16), + extra=None, + in_index=0, + input_transform=None, + align_corners=False, + loss_keypoint=None, + train_cfg=None, + test_cfg=None): + super().__init__() + + self.in_channels = in_channels + self.loss = build_loss(loss_keypoint) + + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap') + + self._init_inputs(in_channels, in_index, input_transform) + self.in_index = in_index + self.align_corners = align_corners + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, num_deconv_filters, num_deconv_kernels, + num_deconv_groups) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0, 1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + elif extra['final_conv_kernel'] == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_final_layer = True + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + if identity_final_layer: + self.final_layer = nn.Identity() + else: + conv_channels = num_deconv_filters[ + -1] if num_deconv_layers > 0 else self.in_channels + + layers = [] + if extra is not None: + num_conv_layers = extra.get('num_conv_layers', 0) + num_conv_kernels = extra.get('num_conv_kernels', + [1] * num_conv_layers) + + for i in range(num_conv_layers): + layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=num_conv_kernels[i], + stride=1, + padding=(num_conv_kernels[i] - 1) // 2)) + layers.append( + build_norm_layer(dict(type='BN'), conv_channels)[1]) + layers.append(nn.ReLU(inplace=True)) + + layers.append( + build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding)) + + if len(layers) > 1: + self.final_layer = nn.Sequential(*layers) + else: + self.final_layer = layers[0] + + def get_loss(self, output, target, target_weight): + """Calculate top-down keypoint loss. 
+ + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. + """ + + losses = dict() + + assert not isinstance(self.loss, nn.Sequential) + assert target.dim() == 4 and target_weight.dim() == 3 + losses['heatmap_loss'] = self.loss(output, target, target_weight) + + return losses + + def get_accuracy(self, output, target, target_weight): + """Calculate accuracy for top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. + """ + + accuracy = dict() + + if self.target_type.lower() == 'GaussianHeatmap'.lower(): + _, avg_acc, _ = pose_pck_accuracy( + output.detach().cpu().numpy(), + target.detach().cpu().numpy(), + target_weight.detach().cpu().numpy().squeeze(-1) > 0) + accuracy['acc_pose'] = float(avg_acc) + + return accuracy + + def forward(self, x): + """Forward function.""" + x = self._transform_inputs(x) + x = self.deconv_layers(x) + x = self.final_layer(x) + return x + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_heatmap (np.ndarray): Output heatmaps. + + Args: + x (torch.Tensor[N,K,H,W]): Input features. + flip_pairs (None | list[tuple]): + Pairs of keypoints which are mirrored. + """ + output = self.forward(x) + + if flip_pairs is not None: + output_heatmap = flip_back( + output.detach().cpu().numpy(), + flip_pairs, + target_type=self.target_type) + # feature is not aligned, shift flipped heatmap for higher accuracy + if self.test_cfg.get('shift_heatmap', False): + output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1] + else: + output_heatmap = output.detach().cpu().numpy() + return output_heatmap + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform is not None, in_channels and in_index must be + list or tuple, with the same length. + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + + - 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. 
+ """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor] | Tensor): multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + if not isinstance(inputs, list): + return inputs + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels, + num_groups): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + if num_layers != len(num_groups): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_groups({len(num_groups)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + groups = num_groups[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + groups=groups, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/voxelpose_head.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/voxelpose_head.py new file mode 100644 index 0000000..8799bdc --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/voxelpose_head.py @@ -0,0 +1,167 @@ +# ------------------------------------------------------------------------------ +# Copyright and License Information +# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models +# Original Licence: MIT License +# ------------------------------------------------------------------------------ + +import torch +import torch.nn as nn +import torch.nn.functional 
as F + +from ..builder import HEADS + + +@HEADS.register_module() +class CuboidCenterHead(nn.Module): + """Get results from the 3D human center heatmap. In this module, human 3D + centers are local maximums obtained from the 3D heatmap via NMS (max- + pooling). + + Args: + space_size (list[3]): The size of the 3D space. + cube_size (list[3]): The size of the heatmap volume. + space_center (list[3]): The coordinate of space center. + max_num (int): Maximum of human center detections. + max_pool_kernel (int): Kernel size of the max-pool kernel in nms. + """ + + def __init__(self, + space_size, + space_center, + cube_size, + max_num=10, + max_pool_kernel=3): + super(CuboidCenterHead, self).__init__() + # use register_buffer + self.register_buffer('grid_size', torch.tensor(space_size)) + self.register_buffer('cube_size', torch.tensor(cube_size)) + self.register_buffer('grid_center', torch.tensor(space_center)) + + self.num_candidates = max_num + self.max_pool_kernel = max_pool_kernel + self.loss = nn.MSELoss() + + def _get_real_locations(self, indices): + """ + Args: + indices (torch.Tensor(NXP)): Indices of points in the 3D tensor + + Returns: + real_locations (torch.Tensor(NXPx3)): Locations of points + in the world coordinate system + """ + real_locations = indices.float() / ( + self.cube_size - 1) * self.grid_size + \ + self.grid_center - self.grid_size / 2.0 + return real_locations + + def _nms_by_max_pool(self, heatmap_volumes): + max_num = self.num_candidates + batch_size = heatmap_volumes.shape[0] + root_cubes_nms = self._max_pool(heatmap_volumes) + root_cubes_nms_reshape = root_cubes_nms.reshape(batch_size, -1) + topk_values, topk_index = root_cubes_nms_reshape.topk(max_num) + topk_unravel_index = self._get_3d_indices(topk_index, + heatmap_volumes[0].shape) + + return topk_values, topk_unravel_index + + def _max_pool(self, inputs): + kernel = self.max_pool_kernel + padding = (kernel - 1) // 2 + max = F.max_pool3d( + inputs, kernel_size=kernel, stride=1, padding=padding) + keep = (inputs == max).float() + return keep * inputs + + @staticmethod + def _get_3d_indices(indices, shape): + """Get indices in the 3-D tensor. + + Args: + indices (torch.Tensor(NXp)): Indices of points in the 1D tensor + shape (torch.Size(3)): The shape of the original 3D tensor + + Returns: + indices: Indices of points in the original 3D tensor + """ + batch_size = indices.shape[0] + num_people = indices.shape[1] + indices_x = (indices // + (shape[1] * shape[2])).reshape(batch_size, num_people, -1) + indices_y = ((indices % (shape[1] * shape[2])) // + shape[2]).reshape(batch_size, num_people, -1) + indices_z = (indices % shape[2]).reshape(batch_size, num_people, -1) + indices = torch.cat([indices_x, indices_y, indices_z], dim=2) + return indices + + def forward(self, heatmap_volumes): + """ + + Args: + heatmap_volumes (torch.Tensor(NXLXWXH)): + 3D human center heatmaps predicted by the network. + Returns: + human_centers (torch.Tensor(NXPX5)): + Coordinates of human centers. 
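+                The last dimension stores (x, y, z, unused, score): columns 0:3
+                hold the world-coordinate centers, column 4 holds the heatmap
+                peak values, and column 3 is left at zero.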
+ """ + batch_size = heatmap_volumes.shape[0] + + topk_values, topk_unravel_index = self._nms_by_max_pool( + heatmap_volumes.detach()) + + topk_unravel_index = self._get_real_locations(topk_unravel_index) + + human_centers = torch.zeros( + batch_size, self.num_candidates, 5, device=heatmap_volumes.device) + human_centers[:, :, 0:3] = topk_unravel_index + human_centers[:, :, 4] = topk_values + + return human_centers + + def get_loss(self, pred_cubes, gt): + + return dict(loss_center=self.loss(pred_cubes, gt)) + + +@HEADS.register_module() +class CuboidPoseHead(nn.Module): + + def __init__(self, beta): + """Get results from the 3D human pose heatmap. Instead of obtaining + maximums on the heatmap, this module regresses the coordinates of + keypoints via integral pose regression. Refer to `paper. + + ` for more details. + + Args: + beta: Constant to adjust the magnification of soft-maxed heatmap. + """ + super(CuboidPoseHead, self).__init__() + self.beta = beta + self.loss = nn.L1Loss() + + def forward(self, heatmap_volumes, grid_coordinates): + """ + + Args: + heatmap_volumes (torch.Tensor(NxKxLxWxH)): + 3D human pose heatmaps predicted by the network. + grid_coordinates (torch.Tensor(Nx(LxWxH)x3)): + Coordinates of the grids in the heatmap volumes. + Returns: + human_poses (torch.Tensor(NxKx3)): Coordinates of human poses. + """ + batch_size = heatmap_volumes.size(0) + channel = heatmap_volumes.size(1) + x = heatmap_volumes.reshape(batch_size, channel, -1, 1) + x = F.softmax(self.beta * x, dim=2) + grid_coordinates = grid_coordinates.unsqueeze(1) + x = torch.mul(x, grid_coordinates) + human_poses = torch.sum(x, dim=2) + + return human_poses + + def get_loss(self, preds, targets, weights): + + return dict(loss_pose=self.loss(preds * weights, targets * weights)) diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/model_builder.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/model_builder.py new file mode 100644 index 0000000..724cfcb --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/model_builder.py @@ -0,0 +1,67 @@ +import torch + +# from configs.coco.ViTPose_base_coco_256x192 import model +from .heads.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead + +# import TopdownHeatmapSimpleHead +from .backbones import ViT + +# print(model) +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +from importlib import import_module + + +def build_model(model_name, checkpoint=None): + try: + path = ".configs.coco." 
+ model_name + mod = import_module(path, package="src.vitpose_infer") + + model = getattr(mod, "model") + # from path import model + except: + raise ValueError("not a correct config") + + head = TopdownHeatmapSimpleHead( + in_channels=model["keypoint_head"]["in_channels"], + out_channels=model["keypoint_head"]["out_channels"], + num_deconv_filters=model["keypoint_head"]["num_deconv_filters"], + num_deconv_kernels=model["keypoint_head"]["num_deconv_kernels"], + num_deconv_layers=model["keypoint_head"]["num_deconv_layers"], + extra=model["keypoint_head"]["extra"], + ) + # print(head) + backbone = ViT( + img_size=model["backbone"]["img_size"], + patch_size=model["backbone"]["patch_size"], + embed_dim=model["backbone"]["embed_dim"], + depth=model["backbone"]["depth"], + num_heads=model["backbone"]["num_heads"], + ratio=model["backbone"]["ratio"], + mlp_ratio=model["backbone"]["mlp_ratio"], + qkv_bias=model["backbone"]["qkv_bias"], + drop_path_rate=model["backbone"]["drop_path_rate"], + ) + + class VitPoseModel(nn.Module): + def __init__(self, backbone, keypoint_head): + super(VitPoseModel, self).__init__() + self.backbone = backbone + self.keypoint_head = keypoint_head + + def forward(self, x): + x = self.backbone(x) + x = self.keypoint_head(x) + return x + + pose = VitPoseModel(backbone, head) + if checkpoint is not None: + check = torch.load(checkpoint) + + pose.load_state_dict(check["state_dict"]) + return pose + + +# pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b-multi-coco.pth') diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/model_builder.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/model_builder.py new file mode 100644 index 0000000..1fcb014 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/model_builder.py @@ -0,0 +1,167 @@ +import torch + +# from configs.coco.ViTPose_base_coco_256x192 import model +from .builder.heads.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead + +# import TopdownHeatmapSimpleHead +from .builder.backbones import ViT + +# print(model) +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +from importlib import import_module + +models = { + "ViTPose_huge_coco_256x192": dict( + type="TopDown", + pretrained=None, + backbone=dict( + type="ViT", + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + ), + keypoint_head=dict( + type="TopdownHeatmapSimpleHead", + in_channels=1280, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict( + final_conv_kernel=1, + ), + out_channels=17, + loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True), + ), + train_cfg=dict(), + test_cfg=dict(), + ), + "ViTPose_base_coco_256x192": dict( + type="TopDown", + pretrained=None, + backbone=dict( + type="ViT", + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + ), + keypoint_head=dict( + type="TopdownHeatmapSimpleHead", + in_channels=768, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict( + final_conv_kernel=1, + ), + out_channels=17, + loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True), + ), + train_cfg=dict(), + test_cfg=dict(), + ), + "ViTPose_base_simple_coco_256x192": dict( + type="TopDown", + 
pretrained=None, + backbone=dict( + type="ViT", + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + ), + keypoint_head=dict( + type="TopdownHeatmapSimpleHead", + in_channels=768, + num_deconv_layers=0, + num_deconv_filters=[], + num_deconv_kernels=[], + upsample=4, + extra=dict( + final_conv_kernel=3, + ), + out_channels=17, + loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True), + ), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process="default", + shift_heatmap=False, + target_type="GaussianHeatmap", + modulate_kernel=11, + use_udp=True, + ), + ), +} + + +def build_model(model_name, checkpoint=None): + try: + model = models[model_name] + except: + raise ValueError("not a correct config") + + head = TopdownHeatmapSimpleHead( + in_channels=model["keypoint_head"]["in_channels"], + out_channels=model["keypoint_head"]["out_channels"], + num_deconv_filters=model["keypoint_head"]["num_deconv_filters"], + num_deconv_kernels=model["keypoint_head"]["num_deconv_kernels"], + num_deconv_layers=model["keypoint_head"]["num_deconv_layers"], + extra=model["keypoint_head"]["extra"], + ) + # print(head) + backbone = ViT( + img_size=model["backbone"]["img_size"], + patch_size=model["backbone"]["patch_size"], + embed_dim=model["backbone"]["embed_dim"], + depth=model["backbone"]["depth"], + num_heads=model["backbone"]["num_heads"], + ratio=model["backbone"]["ratio"], + mlp_ratio=model["backbone"]["mlp_ratio"], + qkv_bias=model["backbone"]["qkv_bias"], + drop_path_rate=model["backbone"]["drop_path_rate"], + ) + + class VitPoseModel(nn.Module): + def __init__(self, backbone, keypoint_head): + super(VitPoseModel, self).__init__() + self.backbone = backbone + self.keypoint_head = keypoint_head + + def forward(self, x): + x = self.backbone(x) + x = self.keypoint_head(x) + return x + + pose = VitPoseModel(backbone, head) + if checkpoint is not None: + check = torch.load(checkpoint) + + pose.load_state_dict(check["state_dict"]) + return pose + + +# pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b-multi-coco.pth') diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/ViTPose_trt.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/ViTPose_trt.py new file mode 100644 index 0000000..c0dd6d6 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/ViTPose_trt.py @@ -0,0 +1,102 @@ +import tensorrt as trt +import torch.nn +from collections import OrderedDict, namedtuple +import numpy as np + +def torch_device_from_trt(device): + if device == trt.TensorLocation.DEVICE: + return torch.device("cuda") + elif device == trt.TensorLocation.HOST: + return torch.device("cpu") + else: + return TypeError("%s is not supported by torch" % device) +def torch_dtype_from_trt(dtype): + if dtype == trt.int8: + return torch.int8 + elif trt.__version__ >= '7.0' and dtype == trt.bool: + return torch.bool + elif dtype == trt.int32: + return torch.int32 + elif dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + else: + raise TypeError("%s is not supported by torch" % dtype) +class TRTModule_ViTPose(torch.nn.Module): + def __init__(self, engine=None, input_names=None, output_names=None, input_flattener=None, output_flattener=None,path=None,device=None): + super(TRTModule_ViTPose, self).__init__() + # self._register_state_dict_hook(TRTModule._on_state_dict) + # 
self.engine = engine + logger = trt.Logger(trt.Logger.INFO) + with open(path, 'rb') as f, trt.Runtime(logger) as runtime: + self.engine = runtime.deserialize_cuda_engine(f.read()) + if self.engine is not None: + self.context = self.engine.create_execution_context() + self.input_names = ['images'] + self.output_names = [] + self.input_flattener = input_flattener + self.output_flattener = output_flattener + Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) + + # with open(path, 'rb') as f, trt.Runtime(logger) as runtime: + # self.model = runtime.deserialize_cuda_engine(f.read()) + # self.context = self.model.create_execution_context() + self.bindings = OrderedDict() + # self.output_names = [] + fp16 = False # default updated below + dynamic = False + for i in range(self.engine.num_bindings): + name = self.engine.get_binding_name(i) + dtype = trt.nptype(self.engine.get_binding_dtype(i)) + if self.engine.binding_is_input(i): + if -1 in tuple(self.engine.get_binding_shape(i)): # dynamic + dynamic = True + self.context.set_binding_shape(i, tuple(self.engine.get_profile_shape(0, i)[2])) + if dtype == np.float16: + fp16 = True + else: # output + self.output_names.append(name) + shape = tuple(self.context.get_binding_shape(i)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items()) + self.batch_size = self.bindings['images'].shape[0] + + + + def forward(self, *inputs): + bindings = [None] * (len(self.input_names) + len(self.output_names)) + + if self.input_flattener is not None: + inputs = self.input_flattener.flatten(inputs) + + for i, input_name in enumerate(self.input_names): + idx = self.engine.get_binding_index(input_name) + shape = tuple(inputs[i].shape) + bindings[idx] = inputs[i].contiguous().data_ptr() + self.context.set_binding_shape(idx, shape) + + # create output tensors + outputs = [None] * len(self.output_names) + for i, output_name in enumerate(self.output_names): + idx = self.engine.get_binding_index(output_name) + dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) + shape = tuple(self.context.get_binding_shape(idx)) + device = torch_device_from_trt(self.engine.get_location(idx)) + output = torch.empty(size=shape, dtype=dtype, device=device) + outputs[i] = output + bindings[idx] = output.data_ptr() + + self.context.execute_async_v2( + bindings, torch.cuda.current_stream().cuda_stream + ) + + if self.output_flattener is not None: + outputs = self.output_flattener.unflatten(outputs) + else: + outputs = tuple(outputs) + if len(outputs) == 1: + outputs = outputs[0] + + return outputs \ No newline at end of file diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/__init__.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/convert_to_trt.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/convert_to_trt.py new file mode 100644 index 0000000..8af8bd3 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/convert_to_trt.py @@ -0,0 +1,9 @@ +from torch2trt import TRTModule,torch2trt +from builder import build_model +import torch +pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b.pth') +pose.cuda().eval() + +x = torch.ones(1,3,256,192).cuda() +net_trt = 
torch2trt(pose, [x],max_batch_size=10, fp16_mode=True) +torch.save(net_trt.state_dict(), 'vitpose_trt.pth') \ No newline at end of file diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/general_utils.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/general_utils.py new file mode 100644 index 0000000..2dba4ac --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/general_utils.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jun 15 15:49:22 2022 + +@author: gpastal +""" +import numpy as np +import numpy.ma as ma +# import pika +import json +from collections import OrderedDict +from collections.abc import Iterable +from itertools import chain +import argparse + +def make_parser(): + parser = argparse.ArgumentParser("ByteTrack Demo!") + # exp file + # tracking args + parser.add_argument("--track_thresh", type=float, default=0.2, help="tracking confidence threshold") + parser.add_argument("--track_buffer", type=int, default=240, help="the frames for keep lost tracks") + parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking") + parser.add_argument( + "--aspect_ratio_thresh", type=float, default=1.6, + help="threshold for filtering out boxes of which aspect ratio are above the given value." + ) + parser.add_argument('--min_box_area', type=float, default=10, help='filter out tiny boxes') + parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.") + return parser + +def jitter(tracking,temp,id1): + pass +def jitter2(tracking,temp,id1) : + pass + +def create_json_rabbitmq( FRAME_ID,pose): + pass + +def producer_rabbitmq(): + pass +def fix_head(xyz): + pass + +def flatten_lst(x): + if isinstance(x, Iterable): + return [a for i in x for a in flatten_lst(i)] + else: + return [x] + +def polys_from_pose(pts): + seg=[] + for ind, i in enumerate(pts): + list_=[] + list_sc=[] + # list1 = [i[0][1],i[0][0]] + + # list2 = [i[0][1],i[0][0]] + # print(i) + for j in i: + + temp_ = [j[1],j[0]] + + if j[2]>0.4: + + temp2_ = [1] + else: + temp2_ =[0] + list_.append(temp_) + list_sc.append(temp2_) + # print(list_sc) + # list2 = [i[6][1],i[6][0]] + # list3 = [i[11][1],i[11][0]] + # list4 = [i[12][1],i[12][0]] + + # list_ = flatten_lst(list_) + # print(list_) + list_=fix_list_order(list_,list_sc) + # print(list_) + # list_=list(list_) + # list_ = list_.to_list() + # print(list_) + # temp__=list(chain(*list_)) + seg.append(list_) # temp_ = list(chain(list1,list2,list3,list4,list1)) + return seg +def fix_list_order(list_,list2): + # for index,values in enumerate(list_): + myorder = [0, 2, 4, 6, 8,10,12,14,16,15,13,11,9,7,5,3,1] + cor_list = [list_[i] for i in myorder] + cor_list2 = [list2[i] for i in myorder] + # print(cor_list) + # result = list(set(map(tuple,cor_list)) & set(map(tuple,cor_list2))) + # arr = np.array([x for x in cor_list]) + # print(cor_list) + data = np.asarray(cor_list) + # print(data) + mask = np.column_stack((cor_list2, cor_list2)) + # masked = ma.masked_array(data, mask=np.column_stack((cor_list2, cor_list2)))#[cor_list2,cor_list2]) + # result = list(set(masked[~masked.mask])) + # print(result) + # print(data) + # print(mask) + result2 = [] + for inde,i in enumerate(data): + # print(mask[inde]) + if mask[inde].all()==1: + result2.append(i[0]) + result2.append(i[1]) + # result = [int(result[i] for i in result)] + return result2 + diff --git 
a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/inference_test.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/inference_test.py new file mode 100644 index 0000000..fbb6220 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/inference_test.py @@ -0,0 +1,29 @@ +from builder import build_model +import torch +from ViTPose_trt import TRTModule_ViTPose +# pose = TRTModule_ViTPose(path='pose_higher_hrnet_w32_512.engine',device='cuda:0') +pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b.pth') +pose.cuda().eval() +if pose.training: + print('train') +else: + print('eval') +device = torch.device("cuda") +# pose.to(device) +dummy_input = torch.randn(10, 3,256,192, dtype=torch.float).to(device) +repetitions=100 +total_time = 0 +starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) +with torch.no_grad(): + for rep in range(repetitions): + # starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + # for k in range(10): + _ = pose(dummy_input) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender)/1000 + total_time += curr_time +Throughput = repetitions*10/total_time +print('Final Throughput:',Throughput) +print('Total time',total_time) \ No newline at end of file diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/logger_helper.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/logger_helper.py new file mode 100644 index 0000000..e58d240 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/logger_helper.py @@ -0,0 +1,23 @@ +import logging + +class CustomFormatter(logging.Formatter): + + grey = "\x1b[38;20m" + yellow = "\x1b[33;20m" + red = "\x1b[31;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" + + FORMATS = { + logging.DEBUG: grey + format + reset, + logging.INFO: grey + format + reset, + logging.WARNING: yellow + format + reset, + logging.ERROR: red + format + reset, + logging.CRITICAL: bold_red + format + reset + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) \ No newline at end of file diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_utils.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_utils.py new file mode 100644 index 0000000..0d6cb06 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_utils.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jun 15 15:45:33 2022 + +@author: gpastal +""" +import torch +import torchvision +import torch.nn.functional as F +from torchvision import transforms as TR +import numpy as np +import cv2 +import logging +# from simpleHRNet.models_.hrnet import HRNet +# from torch2trt import torch2trt,TRTModule +logger = logging.getLogger("Tracker !") +from .timerr import Timer +from pathlib import Path +# import gdown +timer_det = Timer() +timer_track = Timer() +timer_pose = Timer() + +def pose_points_yolo5(detector,image,pose,tracker,tensorrt): + timer_det.tic() + # starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + + transform = TR.Compose([ + TR.ToPILImage(), + # Padd(), + TR.Resize((256, 192)), # (height, 
width) + TR.ToTensor(), + TR.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + # image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR) + + detections = detector(image) + timer_det.toc() + logger.info('DET FPS -- %s',1./timer_det.average_time) + # print(detections.shape) + dets = detections.xyxy[0] + dets = dets[dets[:,5] == 0.] + # dets = dets[dets[:,4] > 0.3] + # logger.warning(len(dets)) + + # if len(dets)>0: + # image_gpu = torch.tensor(image).cuda()/255. + # print(image_gpu.size()) + timer_track.tic() + online_targets=tracker.update(dets,[image.shape[0],image.shape[1]],image.shape) + + online_tlwhs = [] + online_ids = [] + online_scores = [] + for t in online_targets: + tlwh = t.tlwh + tid = t.track_id + # vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_threshs + if tlwh[2] * tlwh[3] > 10 :#and not vertical: + online_tlwhs.append(tlwh) + online_ids.append(tid) + online_scores.append(t.score) + # tracker.update() + timer_track.toc() + logger.info('TRACKING FPS --%s',1./timer_track.average_time) + device='cuda' + nof_people = len(online_ids) if online_ids is not None else 0 + # nof_people=1 + # print(dets) + # print(nof_people) + boxes = torch.empty((nof_people, 4), dtype=torch.int32,device= 'cuda') + # boxes = [] + images = torch.empty((nof_people, 3, 256, 192)) # (height, width) + heatmaps = np.zeros((nof_people, 17, 64, 48), + dtype=np.float32) + # starter.record() + # print(online_tlwhs) + if len(online_tlwhs): + for i, (x1, y1, x2, y2) in enumerate(online_tlwhs): + # for i, (x1, y1, x2, y2) in enumerate(np.array([[55,399,424-55,479-399]])): + # if i<1: + x1 = x1.astype(np.int32) + x2 = x1+x2.astype(np.int32) + y1 = y1.astype(np.int32) + y2 = y1+ y2.astype(np.int32) + if x2>image.shape[1]:x2=image.shape[1]-1 + if y2>image.shape[0]:y2=image.shape[0]-1 + if y1<0: y1=0 + if x1<0 : x1=0 + # print([x1,x2,y1,y2]) + # image = cv2.rectangle(image, (x1,y1), (x2,y2), (0,0,0), 1) + # cv2.imwrite('saved.png',image) + # # Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14) + correction_factor = 256 / 192 * (x2 - x1) / (y2 - y1) + if correction_factor > 1: + # increase y side + center = y1 + (y2 - y1) // 2 + length = int(round((y2 - y1) * correction_factor)) + y1_new = int( center - length // 2) + y2_new = int( center + length // 2) + image_crop = image[y1:y2, x1:x2, ::-1] + # print(y1,y2,x1,x2) + pad = (int(abs(y1_new-y1))), int(abs(y2_new-y2)) + image_crop = np.pad(image_crop,((pad), (0, 0), (0, 0))) + images[i] = transform(image_crop) + boxes[i]= torch.tensor([x1, y1_new, x2, y2_new]) + + elif correction_factor < 1: + # increase x side + center = x1 + (x2 - x1) // 2 + length = int(round((x2 - x1) * 1 / correction_factor)) + x1_new = int( center - length // 2) + x2_new = int( center + length // 2) + # images[i] = transform(image[y1:y2, x1:x2, ::-1]) + image_crop = image[y1:y2, x1:x2, ::-1] + pad = (abs(x1_new-x1)), int(abs(x2_new-x2)) + image_crop = np.pad(image_crop,((0, 0), (pad), (0, 0))) + images[i] = transform(image_crop) + boxes[i]= torch.tensor([x1_new, y1, x2_new, y2]) + + + if images.shape[0] > 0: + images = images.to(device) + + if tensorrt: + out = torch.zeros((images.shape[0],17,64,48),device=device) + with torch.no_grad(): + timer_pose.tic() + + for i in range(images.shape[0]): + # timer_pose.tic() + # print(images[i].size()) + + out[i] = pose(images[i].unsqueeze(0)) + timer_pose.toc() + logger.info('POSE FPS -- %s',1./timer_pose.average_time) + else: + with torch.no_grad(): + + timer_pose.tic() + + + + out = pose(images) + 
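+                # out: (num_people, 17, 64, 48) heatmaps. The decoding below takes the
+                # per-keypoint argmax over the heatmap grid, normalizes it by the 64x48
+                # grid size, maps it into the padded person box, and keeps the peak value
+                # as the confidence; the resulting points are (y, x, confidence).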
timer_pose.toc() + logger.info('POSE FPS -- %s',1./timer_pose.average_time) + + + pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32,device=device) + pts2 = np.empty((out.shape[0], out.shape[1], 3), dtype=np.float32) + + (b,indices)=torch.max(out,dim=2) + (b,indices)=torch.max(b,dim=2) + + (c,indicesc)=torch.max(out,dim=3) + (c,indicesc)=torch.max(c,dim=2) + dim1= torch.tensor(1. / 64,device=device) + dim2= torch.tensor(1. / 48,device=device) + + for i in range(0,out.shape[0]): + + pts[i, :, 0] = indicesc[i,:] * dim1 * (boxes[i][3] - boxes[i][1]) + boxes[i][1] + pts[i, :, 1] = indices[i,:] *dim2* (boxes[i][2] - boxes[i][0]) + boxes[i][0] + pts[i, :, 2] = c[i,:] + + pts=pts.cpu().numpy() + # print(pts) + else: + pts = np.empty((0, 0, 3), dtype=np.float32) + online_tlwhs = [] + online_ids = [] + online_scores=[] + res = list() + + res.append(pts) + + + if len(res) > 1: + return res,online_tlwhs,online_ids,online_scores#,pts2 + else: + return res[0],online_tlwhs,online_ids,online_scores#,pts2 + diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_viz.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_viz.py new file mode 100644 index 0000000..ba57581 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_viz.py @@ -0,0 +1,293 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchvision +import ffmpeg + + +def joints_dict(): + joints = { + "coco": { + "keypoints": { + 0: "nose", + 1: "left_eye", + 2: "right_eye", + 3: "left_ear", + 4: "right_ear", + 5: "left_shoulder", + 6: "right_shoulder", + 7: "left_elbow", + 8: "right_elbow", + 9: "left_wrist", + 10: "right_wrist", + 11: "left_hip", + 12: "right_hip", + 13: "left_knee", + 14: "right_knee", + 15: "left_ankle", + 16: "right_ankle" + }, + "skeleton": [ + # # [16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], + # # [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7] + # [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], + # [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6] + [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], + [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], # [3, 5], [4, 6] + [0, 5], [0, 6] + ] + }, + "mpii": { + "keypoints": { + 0: "right_ankle", + 1: "right_knee", + 2: "right_hip", + 3: "left_hip", + 4: "left_knee", + 5: "left_ankle", + 6: "pelvis", + 7: "thorax", + 8: "upper_neck", + 9: "head top", + 10: "right_wrist", + 11: "right_elbow", + 12: "right_shoulder", + 13: "left_shoulder", + 14: "left_elbow", + 15: "left_wrist" + }, + "skeleton": [ + # [5, 4], [4, 3], [0, 1], [1, 2], [3, 2], [13, 3], [12, 2], [13, 12], [13, 14], + # [12, 11], [14, 15], [11, 10], # [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7] + [5, 4], [4, 3], [0, 1], [1, 2], [3, 2], [3, 6], [2, 6], [6, 7], [7, 8], [8, 9], + [13, 7], [12, 7], [13, 14], [12, 11], [14, 15], [11, 10], + ] + }, + } + return joints + + +def draw_points(image, points, color_palette='tab20', palette_samples=16, confidence_threshold=0.5): + """ + Draws `points` on `image`. + + Args: + image: image in opencv format + points: list of points to be drawn. 
+ Shape: (nof_points, 3) + Format: each point should contain (y, x, confidence) + color_palette: name of a matplotlib color palette + Default: 'tab20' + palette_samples: number of different colors sampled from the `color_palette` + Default: 16 + confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1] + Default: 0.5 + + Returns: + A new image with overlaid points + + """ + try: + colors = np.round( + np.array(plt.get_cmap(color_palette).colors) * 255 + ).astype(np.uint8)[:, ::-1].tolist() + except AttributeError: # if palette has not pre-defined colors + colors = np.round( + np.array(plt.get_cmap(color_palette)(np.linspace(0, 1, palette_samples))) * 255 + ).astype(np.uint8)[:, -2::-1].tolist() + + circle_size = max(1, min(image.shape[:2]) // 160) # ToDo Shape it taking into account the size of the detection + # circle_size = max(2, int(np.sqrt(np.max(np.max(points, axis=0) - np.min(points, axis=0)) // 16))) + + for i, pt in enumerate(points): + if pt[2] > confidence_threshold: + image = cv2.circle(image, (int(pt[1]), int(pt[0])), circle_size, tuple(colors[i % len(colors)]), -1) + + return image + + +def draw_skeleton(image, points, skeleton, color_palette='Set2', palette_samples=8, person_index=0, + confidence_threshold=0.5): + """ + Draws a `skeleton` on `image`. + + Args: + image: image in opencv format + points: list of points to be drawn. + Shape: (nof_points, 3) + Format: each point should contain (y, x, confidence) + skeleton: list of joints to be drawn + Shape: (nof_joints, 2) + Format: each joint should contain (point_a, point_b) where `point_a` and `point_b` are an index in `points` + color_palette: name of a matplotlib color palette + Default: 'Set2' + palette_samples: number of different colors sampled from the `color_palette` + Default: 8 + person_index: index of the person in `image` + Default: 0 + confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1] + Default: 0.5 + + Returns: + A new image with overlaid joints + + """ + try: + colors = np.round( + np.array(plt.get_cmap(color_palette).colors) * 255 + ).astype(np.uint8)[:, ::-1].tolist() + except AttributeError: # if palette has not pre-defined colors + colors = np.round( + np.array(plt.get_cmap(color_palette)(np.linspace(0, 1, palette_samples))) * 255 + ).astype(np.uint8)[:, -2::-1].tolist() + + for i, joint in enumerate(skeleton): + pt1, pt2 = points[joint] + if pt1[2] > confidence_threshold and pt2[2] > confidence_threshold: + image = cv2.line( + image, (int(pt1[1]), int(pt1[0])), (int(pt2[1]), int(pt2[0])), + tuple(colors[person_index % len(colors)]), 2 + ) + + return image + + +def draw_points_and_skeleton(image, points, skeleton, points_color_palette='tab20', points_palette_samples=16, + skeleton_color_palette='Set2', skeleton_palette_samples=8, person_index=0, + confidence_threshold=0.5): + """ + Draws `points` and `skeleton` on `image`. + + Args: + image: image in opencv format + points: list of points to be drawn. 
+ Shape: (nof_points, 3) + Format: each point should contain (y, x, confidence) + skeleton: list of joints to be drawn + Shape: (nof_joints, 2) + Format: each joint should contain (point_a, point_b) where `point_a` and `point_b` are an index in `points` + points_color_palette: name of a matplotlib color palette + Default: 'tab20' + points_palette_samples: number of different colors sampled from the `color_palette` + Default: 16 + skeleton_color_palette: name of a matplotlib color palette + Default: 'Set2' + skeleton_palette_samples: number of different colors sampled from the `color_palette` + Default: 8 + person_index: index of the person in `image` + Default: 0 + confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1] + Default: 0.5 + + Returns: + A new image with overlaid joints + + """ + image = draw_skeleton(image, points, skeleton, color_palette=skeleton_color_palette, + palette_samples=skeleton_palette_samples, person_index=person_index, + confidence_threshold=confidence_threshold) + image = draw_points(image, points, color_palette=points_color_palette, palette_samples=points_palette_samples, + confidence_threshold=confidence_threshold) + return image + + +def save_images(images, target, joint_target, output, joint_output, joint_visibility, summary_writer=None, step=0, + prefix=''): + """ + Creates a grid of images with gt joints and a grid with predicted joints. + This is a basic function for debugging purposes only. + + If summary_writer is not None, the grid will be written in that SummaryWriter with name "{prefix}_images" and + "{prefix}_predictions". + + Args: + images (torch.Tensor): a tensor of images with shape (batch x channels x height x width). + target (torch.Tensor): a tensor of gt heatmaps with shape (batch x channels x height x width). + joint_target (torch.Tensor): a tensor of gt joints with shape (batch x joints x 2). + output (torch.Tensor): a tensor of predicted heatmaps with shape (batch x channels x height x width). + joint_output (torch.Tensor): a tensor of predicted joints with shape (batch x joints x 2). + joint_visibility (torch.Tensor): a tensor of joint visibility with shape (batch x joints). + summary_writer (tb.SummaryWriter): a SummaryWriter where write the grids. + Default: None + step (int): summary_writer step. + Default: 0 + prefix (str): summary_writer name prefix. + Default: "" + + Returns: + A pair of images which are built from torchvision.utils.make_grid + """ + # Input images with gt + images_ok = images.detach().clone() + images_ok[:, 0].mul_(0.229).add_(0.485) + images_ok[:, 1].mul_(0.224).add_(0.456) + images_ok[:, 2].mul_(0.225).add_(0.406) + for i in range(images.shape[0]): + joints = joint_target[i] * 4. 
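+            # joints are scaled by 4 to map from the 1/4-resolution heatmap grid back to input-image pixels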
+ joints_vis = joint_visibility[i] + + for joint, joint_vis in zip(joints, joints_vis): + if joint_vis[0]: + a = int(joint[1].item()) + b = int(joint[0].item()) + # images_ok[i][:, a-1:a+1, b-1:b+1] = torch.tensor([1, 0, 0]) + images_ok[i][0, a - 1:a + 1, b - 1:b + 1] = 1 + images_ok[i][1:, a - 1:a + 1, b - 1:b + 1] = 0 + grid_gt = torchvision.utils.make_grid(images_ok, nrow=int(images_ok.shape[0] ** 0.5), padding=2, normalize=False) + if summary_writer is not None: + summary_writer.add_image(prefix + 'images', grid_gt, global_step=step) + + # Input images with prediction + images_ok = images.detach().clone() + images_ok[:, 0].mul_(0.229).add_(0.485) + images_ok[:, 1].mul_(0.224).add_(0.456) + images_ok[:, 2].mul_(0.225).add_(0.406) + for i in range(images.shape[0]): + joints = joint_output[i] * 4. + joints_vis = joint_visibility[i] + + for joint, joint_vis in zip(joints, joints_vis): + if joint_vis[0]: + a = int(joint[1].item()) + b = int(joint[0].item()) + # images_ok[i][:, a-1:a+1, b-1:b+1] = torch.tensor([1, 0, 0]) + images_ok[i][0, a - 1:a + 1, b - 1:b + 1] = 1 + images_ok[i][1:, a - 1:a + 1, b - 1:b + 1] = 0 + grid_pred = torchvision.utils.make_grid(images_ok, nrow=int(images_ok.shape[0] ** 0.5), padding=2, normalize=False) + if summary_writer is not None: + summary_writer.add_image(prefix + 'predictions', grid_pred, global_step=step) + + # Heatmaps + # ToDo + # for h in range(0,17): + # heatmap = torchvision.utils.make_grid(output[h].detach(), nrow=int(np.sqrt(output.shape[0])), + # padding=2, normalize=True, range=(0, 1)) + # summary_writer.add_image('train_heatmap_%d' % h, heatmap, global_step=step + epoch*len_dl_train) + + return grid_gt, grid_pred + + +def check_video_rotation(filename): + # thanks to + # https://stackoverflow.com/questions/53097092/frame-from-video-is-upside-down-after-extracting/55747773#55747773 + + # this returns meta-data of the video file in form of a dictionary + meta_dict = ffmpeg.probe(filename) + + # from the dictionary, meta_dict['streams'][0]['tags']['rotate'] is the key + # we are looking for + rotation_code = None + try: + if int(meta_dict['streams'][0]['tags']['rotate']) == 90: + rotation_code = cv2.ROTATE_90_CLOCKWISE + elif int(meta_dict['streams'][0]['tags']['rotate']) == 180: + rotation_code = cv2.ROTATE_180 + elif int(meta_dict['streams'][0]['tags']['rotate']) == 270: + rotation_code = cv2.ROTATE_90_COUNTERCLOCKWISE + else: + raise ValueError + except KeyError: + pass + + return rotation_code diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/timerr.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/timerr.py new file mode 100644 index 0000000..c9b15fb --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/timerr.py @@ -0,0 +1,37 @@ +import time + + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + + self.duration = 0. + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + self.duration = self.average_time + else: + self.duration = self.diff + return self.duration + + def clear(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. 
+ self.diff = 0. + self.average_time = 0. + self.duration = 0. \ No newline at end of file diff --git a/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/visualizer.py b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/visualizer.py new file mode 100644 index 0000000..d2d2305 --- /dev/null +++ b/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/visualizer.py @@ -0,0 +1,162 @@ +import cv2 +import numpy as np + +__all__ = ["vis"] + + +def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None): + + for i in range(len(boxes)): + box = boxes[i] + cls_id = int(cls_ids[i]) + score = scores[i] + if score < conf: + continue + x0 = int(box[0]) + y0 = int(box[1]) + x1 = int(box[2]) + y1 = int(box[3]) + + color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist() + text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100) + txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255) + font = cv2.FONT_HERSHEY_SIMPLEX + + txt_size = cv2.getTextSize(text, font, 0.4, 1)[0] + cv2.rectangle(img, (x0, y0), (x1, y1), color, 2) + + txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist() + cv2.rectangle( + img, + (x0, y0 + 1), + (x0 + txt_size[0] + 1, y0 + int(1.5*txt_size[1])), + txt_bk_color, + -1 + ) + cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1) + + return img + + +def get_color(idx): + idx = idx * 3 + color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255) + + return color + + +def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None): + im = np.ascontiguousarray(np.copy(image)) + im_h, im_w = im.shape[:2] + + top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255 + + #text_scale = max(1, image.shape[1] / 1600.) 
+ #text_thickness = 2 + #line_thickness = max(1, int(image.shape[1] / 500.)) + text_scale = 2 + text_thickness = 2 + line_thickness = 3 + + radius = max(5, int(im_w/140.)) + cv2.putText(im, 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)), + (0, int(15 * text_scale)), cv2.FONT_HERSHEY_PLAIN, 2, (0, 0, 255), thickness=2) + + for i, tlwh in enumerate(tlwhs): + x1, y1, w, h = tlwh + intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h))) + obj_id = int(obj_ids[i]) + id_text = '{}'.format(int(obj_id)) + if ids2 is not None: + id_text = id_text + ', {}'.format(int(ids2[i])) + color = get_color(abs(obj_id)) + cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness) + cv2.putText(im, id_text, (intbox[0], intbox[1]), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), + thickness=text_thickness) + return im + + +_COLORS = np.array( + [ + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857, + 0.000, 0.447, 0.741, + 0.314, 0.717, 0.741, + 0.50, 0.5, 0 + ] +).astype(np.float32).reshape(-1, 3) \ No newline at end of file diff --git a/hmr4d/utils/pylogger.py b/hmr4d/utils/pylogger.py new file mode 100644 index 0000000..b827475 --- /dev/null +++ b/hmr4d/utils/pylogger.py @@ -0,0 +1,76 @@ +from time import time +import logging +import torch +from colorlog import ColoredFormatter + + +def sync_time(): + torch.cuda.synchronize() + return time() + + +Log = logging.getLogger() +Log.time = time +Log.sync_time = sync_time + +# Set default +Log.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.INFO) +# Use colorlog +formatstring = "[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] %(message)s" +datefmt = "%m/%d %H:%M:%S" +ch.setFormatter(ColoredFormatter(formatstring, datefmt=datefmt)) + +Log.addHandler(ch) +# Log.info("Init-Logger") + + +def 
timer(sync_cuda=False, mem=False, loop=1): + """ + Args: + func: function + sync_cuda: bool, whether to synchronize cuda + mem: bool, whether to log memory + """ + + def decorator(func): + def wrapper(*args, **kwargs): + if mem: + start_mem = torch.cuda.memory_allocated() / 1024**2 + if sync_cuda: + torch.cuda.synchronize() + + start = Log.time() + for _ in range(loop): + result = func(*args, **kwargs) + + if sync_cuda: + torch.cuda.synchronize() + if loop == 1: + message = f"{func.__name__} took {Log.time() - start:.3f} s." + else: + message = f"{func.__name__} took {((Log.time() - start))/loop:.3f} s. (loop={loop})" + + if mem: + end_mem = torch.cuda.memory_allocated() / 1024**2 + end_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + message += f" Start_Mem {start_mem:.1f} Max {end_max_mem:.1f} MB" + Log.info(message) + + return result + + return wrapper + + return decorator + + +def timed(fn): + """example usage: timed(lambda: model(inp))""" + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + result = fn() + end.record() + torch.cuda.synchronize() + return result, start.elapsed_time(end) / 1000 diff --git a/hmr4d/utils/seq_utils.py b/hmr4d/utils/seq_utils.py new file mode 100644 index 0000000..790cf3b --- /dev/null +++ b/hmr4d/utils/seq_utils.py @@ -0,0 +1,184 @@ +import torch +import numpy as np + +# def get_frame_id_list_from_mask(mask): +# """ +# Args: +# mask (F,), bool. +# Return: +# frame_id_list: List of frame_ids. +# """ +# frame_id_list = [] +# i = 0 +# while i < len(mask): +# if not mask[i]: +# i += 1 +# else: +# j = i +# while j < len(mask) and mask[j]: +# j += 1 +# frame_id_list.append(torch.arange(i, j)) +# i = j + +# return frame_id_list + + +# From GPT +def get_frame_id_list_from_mask(mask): + # batch=64, 0.13s + """ + Vectorized approach to get frame id list from a boolean mask. + + Args: + mask (F,), bool tensor: Mask array where `True` indicates a frame to be processed. + + Returns: + frame_id_list: List of torch.Tensors, each tensor containing continuous indices where mask is True. 
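+    Example (illustrative):
+        >>> get_frame_id_list_from_mask(torch.tensor([False, True, True, False, True]))
+        [tensor([1, 2]), tensor([4])]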
+ """ + # Find the indices where the mask changes from False to True and vice versa + padded_mask = torch.cat( + [torch.tensor([False], device=mask.device), mask, torch.tensor([False], device=mask.device)] + ) + diffs = torch.diff(padded_mask.int()) + starts = (diffs == 1).nonzero(as_tuple=False).squeeze() + ends = (diffs == -1).nonzero(as_tuple=False).squeeze() + if starts.numel() == 0: + return [] + if starts.numel() == 1: + starts = starts.reshape(-1) + ends = ends.reshape(-1) + + # Create list of ranges + frame_id_list = [torch.arange(start, end) for start, end in zip(starts, ends)] + return frame_id_list + + +def get_batch_frame_id_lists_from_mask_BLC(masks): + # batch=64, 0.10s + """ + 处理三维掩码数组,为每个批次和通道提取连续True区段的索引列表。 + + 参数: + masks (B, L, C), 布尔张量:每个元素代表一个掩码,True表示需要处理的帧。 + + 返回: + batch_frame_id_lists: 对应于每个批次和每个通道的帧id列表的嵌套列表。 + """ + B, L, C = masks.size() + # 在序列长度两端添加一个False + padded_masks = torch.cat( + [ + torch.zeros((B, 1, C), dtype=torch.bool, device=masks.device), + masks, + torch.zeros((B, 1, C), dtype=torch.bool, device=masks.device), + ], + dim=1, + ) + # 计算差分来找到True区段的起始和结束点 + diffs = torch.diff(padded_masks.int(), dim=1) + starts = (diffs == 1).nonzero(as_tuple=True) + ends = (diffs == -1).nonzero(as_tuple=True) + + # 初始化返回列表 + batch_frame_id_lists = [[[] for _ in range(C)] for _ in range(B)] + for b in range(B): + for c in range(C): + batch_start = starts[0][(starts[0] == b) & (starts[2] == c)] + batch_end = ends[0][(ends[0] == b) & (ends[2] == c)] + # 确保start和end都是1维张量 + batch_frame_id_lists[b][c] = [ + torch.arange(start.item(), end.item()) for start, end in zip(batch_start, batch_end) + ] + + return batch_frame_id_lists + + +def get_frame_id_list_from_frame_id(frame_id): + mask = torch.zeros(frame_id[-1] + 1, dtype=torch.bool) + mask[frame_id] = True + frame_id_list = get_frame_id_list_from_mask(mask) + return frame_id_list + + +def rearrange_by_mask(x, mask): + """ + x (L, *) + mask (M,), M >= L + """ + M = mask.size(0) + L = x.size(0) + if M == L: + return x + assert M > L + assert mask.sum() == L + x_rearranged = torch.zeros((M, *x.size()[1:]), dtype=x.dtype, device=x.device) + x_rearranged[mask] = x + return x_rearranged + + +def frame_id_to_mask(frame_id, max_len): + mask = torch.zeros(max_len, dtype=torch.bool) + mask[frame_id] = True + return mask + + +def mask_to_frame_id(mask): + frame_id = torch.where(mask)[0] + return frame_id + + +def linear_interpolate_frame_ids(data, frame_id_list): + data = data.clone() + for i, invalid_frame_ids in enumerate(frame_id_list): + # interplate between prev, next + # if at beginning or end, use the same value + if invalid_frame_ids[0] - 1 < 0 or invalid_frame_ids[-1] + 1 >= len(data): + if invalid_frame_ids[0] - 1 < 0: + data[invalid_frame_ids] = data[invalid_frame_ids[-1] + 1].clone() + else: + data[invalid_frame_ids] = data[invalid_frame_ids[0] - 1].clone() + else: + prev = data[invalid_frame_ids[0] - 1] + next = data[invalid_frame_ids[-1] + 1] + data[invalid_frame_ids] = ( + torch.linspace(0, 1, len(invalid_frame_ids) + 2)[1:-1][:, None] * (next - prev)[None] + prev[None] + ) + return data + + +def linear_interpolate(data, N_middle_frames): + """ + Args: + data: (2, C) + Returns: + data_interpolated: (1+N+1, C) + """ + prev = data[0] + next = data[1] + middle = torch.linspace(0, 1, N_middle_frames + 2)[1:-1][:, None] * (next - prev)[None] + prev[None] # (N, C) + data_interpolated = torch.cat([data[0][None], middle, data[1][None]], dim=0) # (1+N+1, C) + return data_interpolated + + +def find_top_k_span(mask, k=3): + 
""" + Args: + mask: (L,) + Return: + topk_span: List of tuple, usage: [start, end) + """ + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + if mask.sum() == 0: + return [] + mask = mask.clone().float() + mask = torch.cat([mask.new([0]), mask, mask.new([0])]) + diff = mask[1:] - mask[:-1] + start = torch.where(diff == 1)[0] + end = torch.where(diff == -1)[0] + assert len(start) == len(end) + span_lengths = end - start + span_lengths, idx = span_lengths.sort(descending=True) + start = start[idx] + end = end[idx] + return list(zip(start.tolist(), end.tolist()))[:k] diff --git a/hmr4d/utils/smplx_utils.py b/hmr4d/utils/smplx_utils.py new file mode 100644 index 0000000..9c9bc31 --- /dev/null +++ b/hmr4d/utils/smplx_utils.py @@ -0,0 +1,442 @@ +import torch +import torch.nn.functional as F +import numpy as np +import smplx +import pickle +from smplx import SMPL, SMPLX, SMPLXLayer +from hmr4d.utils.body_model import BodyModelSMPLH, BodyModelSMPLX +from hmr4d.utils.body_model.smplx_lite import SmplxLiteCoco17, SmplxLiteV437Coco17, SmplxLiteSmplN24 +from hmr4d import PROJ_ROOT + +# fmt: off +SMPLH_PARENTS = torch.tensor([-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, + 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, + 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50]) +# fmt: on + + +def make_smplx(type="neu_fullpose", **kwargs): + if type == "neu_fullpose": + model = smplx.create( + model_path="inputs/models/smplx/SMPLX_NEUTRAL.npz", use_pca=False, flat_hand_mean=True, **kwargs + ) + elif type == "supermotion": + # SuperMotion is trained on BEDLAM dataset, the smplx config is the same except only 10 betas are used + bm_kwargs = { + "model_type": "smplx", + "gender": "neutral", + "num_pca_comps": 12, + "flat_hand_mean": False, + } + bm_kwargs.update(kwargs) + model = BodyModelSMPLX(model_path=PROJ_ROOT / "inputs/checkpoints/body_models", **bm_kwargs) + elif type == "supermotion_EVAL3DPW": + # SuperMotion is trained on BEDLAM dataset, the smplx config is the same except only 10 betas are used + bm_kwargs = { + "model_type": "smplx", + "gender": "neutral", + "num_pca_comps": 12, + "flat_hand_mean": True, + } + bm_kwargs.update(kwargs) + model = BodyModelSMPLX(model_path="inputs/checkpoints/body_models", **bm_kwargs) + elif type == "supermotion_coco17": + # Fast but only predicts 17 joints + model = SmplxLiteCoco17() + elif type == "supermotion_v437coco17": + # Predicts 437 verts and 17 joints + model = SmplxLiteV437Coco17() + elif type == "supermotion_smpl24": + model = SmplxLiteSmplN24() + elif type == "rich-smplx": + # https://github.com/paulchhuang/rich_toolkit/blob/main/smplx2images.py + bm_kwargs = { + "model_type": "smplx", + "gender": kwargs.get("gender", "male"), + "num_pca_comps": 12, + "flat_hand_mean": False, + # create_expression=True, create_jaw_pose=Ture + } + # A /smplx folder should exist under the model_path + model = BodyModelSMPLX(model_path="inputs/checkpoints/body_models", **bm_kwargs) + elif type == "rich-smplh": + bm_kwargs = { + "model_type": "smplh", + "gender": kwargs.get("gender", "male"), + "use_pca": False, + "flat_hand_mean": True, + } + model = BodyModelSMPLH(model_path="inputs/checkpoints/body_models", **bm_kwargs) + + elif type in ["smplx-circle", "smplx-groundlink"]: + # don't use hand + bm_kwargs = { + "model_path": "inputs/checkpoints/body_models", + "model_type": "smplx", + "gender": kwargs.get("gender"), + "num_betas": 16, + "num_expression": 0, + } + model = BodyModelSMPLX(**bm_kwargs) + + elif 
type == "smplx-motionx": + layer_args = { + "create_global_orient": False, + "create_body_pose": False, + "create_left_hand_pose": False, + "create_right_hand_pose": False, + "create_jaw_pose": False, + "create_leye_pose": False, + "create_reye_pose": False, + "create_betas": False, + "create_expression": False, + "create_transl": False, + } + + bm_kwargs = { + "model_type": "smplx", + "model_path": "inputs/checkpoints/body_models", + "gender": "neutral", + "use_pca": False, + "use_face_contour": True, + **layer_args, + } + model = smplx.create(**bm_kwargs) + + elif type == "smplx-samp": + # don't use hand + bm_kwargs = { + "model_path": "inputs/checkpoints/body_models", + "model_type": "smplx", + "gender": kwargs.get("gender"), + "num_betas": 10, + "num_expression": 0, + } + model = BodyModelSMPLX(**bm_kwargs) + + elif type == "smplx-bedlam": + # don't use hand + bm_kwargs = { + "model_path": "inputs/checkpoints/body_models", + "model_type": "smplx", + "gender": kwargs.get("gender"), + "num_betas": 11, + "num_expression": 0, + } + model = BodyModelSMPLX(**bm_kwargs) + + elif type in ["smplx-layer", "smplx-fit3d"]: + # Use layer + if type == "smplx-fit3d": + assert ( + kwargs.get("gender") == "neutral" + ), "smplx-fit3d use neutral model: https://github.com/sminchisescu-research/imar_vision_datasets_tools/blob/e8c8f83ffac23cc36adf8ec8d0fd1c55679484ef/util/smplx_util.py#L15C34-L15C34" + + bm_kwargs = { + "model_path": "inputs/checkpoints/body_models/smplx", + "gender": kwargs.get("gender"), + "num_betas": 10, + "num_expression": 10, + } + model = SMPLXLayer(**bm_kwargs) + + elif type == "smpl": + bm_kwargs = { + "model_path": PROJ_ROOT / "inputs/checkpoints/body_models", + "model_type": "smpl", + "gender": "neutral", + "num_betas": 10, + "create_body_pose": False, + "create_betas": False, + "create_global_orient": False, + "create_transl": False, + } + bm_kwargs.update(kwargs) + # model = SMPL(**bm_kwargs) + model = BodyModelSMPLH(**bm_kwargs) + elif type == "smplh": + bm_kwargs = { + "model_type": "smplh", + "gender": kwargs.get("gender", "male"), + "use_pca": False, + "flat_hand_mean": False, + } + model = BodyModelSMPLH(model_path="inputs/checkpoints/body_models", **bm_kwargs) + + else: + raise NotImplementedError + + return model + + +def load_parents(npz_path="models/smplx/SMPLX_NEUTRAL.npz"): + smplx_struct = np.load("models/smplx/SMPLX_NEUTRAL.npz", allow_pickle=True) + parents = smplx_struct["kintree_table"][0].astype(np.long) + parents[0] = -1 + return parents + + +def load_smpl_faces(npz_path="models/smplh/SMPLH_FEMALE.pkl"): + with open(npz_path, "rb") as f: + smpl_model = pickle.load(f, encoding="latin1") + faces = np.array(smpl_model["f"].astype(np.int64)) + return faces + + +def decompose_fullpose(fullpose, model_type="smplx"): + assert model_type == "smplx" + + fullpose_dict = { + "global_orient": fullpose[..., :3], + "body_pose": fullpose[..., 3:66], + "jaw_pose": fullpose[..., 66:69], + "leye_pose": fullpose[..., 69:72], + "reye_pose": fullpose[..., 72:75], + "left_hand_pose": fullpose[..., 75:120], + "right_hand_pose": fullpose[..., 120:165], + } + + return fullpose_dict + + +def compose_fullpose(fullpose_dict, model_type="smplx"): + assert model_type == "smplx" + fullpose = torch.cat( + [ + fullpose_dict[k] + for k in [ + "global_orient", + "body_pose", + "jaw_pose", + "leye_pose", + "reye_pose", + "left_hand_pose", + "right_hand_pose", + ] + ], + dim=-1, + ) + return fullpose + + +def compute_R_from_kinetree(rot_mats, parents): + """operation of 
lbs/batch_rigid_transform, focus on 3x3 R only + Parameters + ---------- + rot_mats: torch.tensor BxNx3x3 + Tensor of rotation matrices + parents : torch.tensor BxN + The kinematic tree of each object + + Returns + ------- + R : torch.tensor BxNx3x3 + Tensor of rotation matrices + """ + rot_mat_chain = [rot_mats[:, 0]] + for i in range(1, parents.shape[0]): + curr_res = torch.matmul(rot_mat_chain[parents[i]], rot_mats[:, i]) + rot_mat_chain.append(curr_res) + + R = torch.stack(rot_mat_chain, dim=1) + return R + + +def compute_relR_from_kinetree(R, parents): + """Inverse operation of lbs/batch_rigid_transform, focus on 3x3 R only + Parameters + ---------- + R : torch.tensor BxNx4x4 or BxNx3x3 + Tensor of rotation matrices + parents : torch.tensor BxN + The kinematic tree of each object + + Returns + ------- + rot_mats: torch.tensor BxNx3x3 + Tensor of rotation matrices + """ + R = R[:, :, :3, :3] + + Rp = R[:, parents] # Rp[:, 0] is invalid + rot_mats = Rp.transpose(2, 3) @ R + rot_mats[:, 0] = R[:, 0] + + return rot_mats + + +def quat_mul(x, y): + """ + Performs quaternion multiplication on arrays of quaternions + + :param x: tensor of quaternions of shape (..., Nb of joints, 4) + :param y: tensor of quaternions of shape (..., Nb of joints, 4) + :return: The resulting quaternions + """ + x0, x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4] + y0, y1, y2, y3 = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4] + + # res = np.concatenate( + # [ + # y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3, + # y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2, + # y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1, + # y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0, + # ], + # axis=-1, + # ) + res = torch.cat( + [ + y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3, + y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2, + y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1, + y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0, + ], + axis=-1, + ) + + return res + + +def quat_inv(q): + """ + Inverts a tensor of quaternions + + :param q: quaternion tensor + :return: tensor of inverted quaternions + """ + # res = np.asarray([1, -1, -1, -1], dtype=np.float32) * q + res = torch.tensor([1, -1, -1, -1], device=q.device).float() * q + return res + + +def quat_mul_vec(q, x): + """ + Performs multiplication of an array of 3D vectors by an array of quaternions (rotation). 
+ + :param q: tensor of quaternions of shape (..., Nb of joints, 4) + :param x: tensor of vectors of shape (..., Nb of joints, 3) + :return: the resulting array of rotated vectors + """ + # t = 2.0 * np.cross(q[..., 1:], x) + t = 2.0 * torch.cross(q[..., 1:], x) + # res = x + q[..., 0][..., np.newaxis] * t + np.cross(q[..., 1:], t) + res = x + q[..., 0][..., None] * t + torch.cross(q[..., 1:], t) + + return res + + +def inverse_kinematics_motion( + global_pos, + global_rot, + parents=SMPLH_PARENTS, +): + """ + Args: + global_pos : (B, T, J-1, 3) + global_rot (q) : (B, T, J-1, 4) + parents : SMPLH_PARENTS + Returns: + local_pos : (B, T, J-1, 3) + local_rot (q) : (B, T, J-1, 4) + """ + J = 22 + local_pos = quat_mul_vec( + quat_inv(global_rot[..., parents[1:J], :]), + global_pos - global_pos[..., parents[1:J], :], + ) + local_rot = (quat_mul(quat_inv(global_rot[..., parents[1:J], :]), global_rot),) + return local_pos, local_rot + + +def transform_mat(R, t): + """Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + """ + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2) + + +def normalize_joints(joints): + """ + Args: + joints: (B, *, J, 3) + """ + LR_hips_xy = joints[..., 2, [0, 1]] - joints[..., 1, [0, 1]] + LR_shoulders_xy = joints[..., 17, [0, 1]] - joints[..., 16, [0, 1]] + LR_xy = (LR_hips_xy + LR_shoulders_xy) / 2 # (B, *, J, 2) + + x_dir = F.pad(F.normalize(LR_xy, 2, -1), (0, 1), "constant", 0) # (B, *, 3) + z_dir = torch.zeros_like(x_dir) # (B, *, 3) + z_dir[..., 2] = 1 + y_dir = torch.cross(z_dir, x_dir, dim=-1) + + joints_normalized = (joints - joints[..., [0], :]) @ torch.stack([x_dir, y_dir, z_dir], dim=-1) + return joints_normalized + + +@torch.no_grad() +def compute_Rt_af2az(joints, inverse=False): + """Assume z coord is upward + Args: + joints: (B, J, 3), in the start-frame + Returns: + R_af2az: (B, 3, 3) + t_af2az: (B, 3) + """ + t_af2az = joints[:, 0, :].detach().clone() + t_af2az[:, 2] = 0 # do not modify z + + LR_xy = joints[:, 2, [0, 1]] - joints[:, 1, [0, 1]] # (B, 2) + I_mask = LR_xy.pow(2).sum(-1) < 1e-4 # do not rotate, when can't decided the face direction + x_dir = F.pad(F.normalize(LR_xy, 2, -1), (0, 1), "constant", 0) # (B, 3) + z_dir = torch.zeros_like(x_dir) + z_dir[..., 2] = 1 + y_dir = torch.cross(z_dir, x_dir, dim=-1) + R_af2az = torch.stack([x_dir, y_dir, z_dir], dim=-1) # (B, 3, 3) + R_af2az[I_mask] = torch.eye(3).to(R_af2az) + + if inverse: + R_az2af = R_af2az.transpose(1, 2) + t_az2af = -(R_az2af @ t_af2az.unsqueeze(2)).squeeze(2) + return R_az2af, t_az2af + else: + return R_af2az, t_af2az + + +def finite_difference_forward(x, dim_t=1, dup_last=True): + if dim_t == 1: + v = x[:, 1:] - x[:, :-1] + if dup_last: + v = torch.cat([v, v[:, [-1]]], dim=1) + else: + raise NotImplementedError + + return v + + +def compute_joints_zero(betas, gender): + """ + Args: + betas: (16) + gender: 'male' or 'female' + Returns: + joints_zero: (22, 3) + """ + body_model = { + "male": make_smplx(type="humor", gender="male"), + "female": make_smplx(type="humor", gender="female"), + } + + smpl_params = { + "root_orient": torch.zeros((1, 3)), + "pose_body": torch.zeros((1, 63)), + "betas": betas[None], + "trans": torch.zeros(1, 3), + } + joints_zero = body_model[gender](**smpl_params).Jtr[0, :22] + return joints_zero diff --git 
a/hmr4d/utils/video_io_utils.py b/hmr4d/utils/video_io_utils.py new file mode 100644 index 0000000..9ccfeaf --- /dev/null +++ b/hmr4d/utils/video_io_utils.py @@ -0,0 +1,113 @@ +import imageio.v3 as iio +import numpy as np +import torch +from pathlib import Path +import shutil +import ffmpeg +from tqdm import tqdm +import cv2 + + +def get_video_lwh(video_path): + L, H, W, _ = iio.improps(video_path, plugin="pyav").shape + return L, W, H + + +def read_video_np(video_path, start_frame=0, end_frame=-1, scale=1.0): + """ + Args: + video_path: str + Returns: + frames: np.array, (N, H, W, 3) RGB, uint8 + """ + # If video path not exists, an error will be raised by ffmpegs + filter_args = [] + should_check_length = False + + # 1. Trim + if not (start_frame == 0 and end_frame == -1): + if end_frame == -1: + filter_args.append(("trim", f"start_frame={start_frame}")) + else: + should_check_length = True + filter_args.append(("trim", f"start_frame={start_frame}:end_frame={end_frame}")) + + # 2. Scale + if scale != 1.0: + filter_args.append(("scale", f"iw*{scale}:ih*{scale}")) + + # Excute then check + frames = iio.imread(video_path, plugin="pyav", filter_sequence=filter_args) + if should_check_length: + assert len(frames) == end_frame - start_frame + + return frames + + +def get_video_reader(video_path): + return iio.imiter(video_path, plugin="pyav") + + +def read_images_np(image_paths, verbose=False): + """ + Args: + image_paths: list of str + Returns: + images: np.array, (N, H, W, 3) RGB, uint8 + """ + if verbose: + images = [cv2.imread(str(img_path))[..., ::-1] for img_path in tqdm(image_paths)] + else: + images = [cv2.imread(str(img_path))[..., ::-1] for img_path in image_paths] + images = np.stack(images, axis=0) + return images + + +def save_video(images, video_path, fps=30, crf=17): + """ + Args: + images: (N, H, W, 3) RGB, uint8 + crf: 17 is visually lossless, 23 is default, +6 results in half the bitrate + 0 is lossless, https://trac.ffmpeg.org/wiki/Encode/H.264#crf + """ + if isinstance(images, torch.Tensor): + images = images.cpu().numpy().astype(np.uint8) + elif isinstance(images, list): + images = np.array(images).astype(np.uint8) + + with iio.imopen(video_path, "w", plugin="pyav") as writer: + writer.init_video_stream("libx264", fps=fps) + writer._video_stream.options = {"crf": str(crf)} + writer.write(images) + + +def get_writer(video_path, fps=30, crf=17): + """remember to .close()""" + writer = iio.imopen(video_path, "w", plugin="pyav") + writer.init_video_stream("libx264", fps=fps) + writer._video_stream.options = {"crf": str(crf)} + return writer + + +def copy_file(video_path, out_video_path, overwrite=True): + if not overwrite and Path(out_video_path).exists(): + return + shutil.copy(video_path, out_video_path) + + +def merge_videos_horizontal(in_video_paths: list, out_video_path: str): + if len(in_video_paths) < 2: + raise ValueError("At least two video paths are required for merging.") + inputs = [ffmpeg.input(path) for path in in_video_paths] + merged_video = ffmpeg.filter(inputs, "hstack", inputs=len(inputs)) + output = ffmpeg.output(merged_video, out_video_path) + ffmpeg.run(output, overwrite_output=True, quiet=True) + + +def merge_videos_vertical(in_video_paths: list, out_video_path: str): + if len(in_video_paths) < 2: + raise ValueError("At least two video paths are required for merging.") + inputs = [ffmpeg.input(path) for path in in_video_paths] + merged_video = ffmpeg.filter(inputs, "vstack", inputs=len(inputs)) + output = ffmpeg.output(merged_video, out_video_path) + 
ffmpeg.run(output, overwrite_output=True, quiet=True) diff --git a/hmr4d/utils/vis/README.md b/hmr4d/utils/vis/README.md new file mode 100644 index 0000000..9582abc --- /dev/null +++ b/hmr4d/utils/vis/README.md @@ -0,0 +1,20 @@ +## Pytorch3D Renderer + +Example: +```python +from hmr4d.utils.vis.renderer import Renderer +import imageio + +fps = 30 +focal_length = data["cam_int"][0][0, 0] +width, height = img_hw +faces = smplh[data["gender"]].bm.faces +renderer = Renderer(width, height, focal_length, "cuda", faces) +writer = imageio.get_writer("tmp_debug.mp4", fps=fps, mode="I", format="FFMPEG", macro_block_size=1) + +for i in tqdm(range(length)): + img = np.zeros((height, width, 3), dtype=np.uint8) + img = renderer.render_mesh(smplh_out.vertices[i].cuda(), img) + writer.append_data(img) +writer.close() +``` \ No newline at end of file diff --git a/hmr4d/utils/vis/cv2_utils.py b/hmr4d/utils/vis/cv2_utils.py new file mode 100644 index 0000000..77ca33a --- /dev/null +++ b/hmr4d/utils/vis/cv2_utils.py @@ -0,0 +1,144 @@ +import torch +import cv2 +import numpy as np +from hmr4d.utils.wis3d_utils import get_colors_by_conf + + +def to_numpy(x): + if isinstance(x, np.ndarray): + return x.copy() + elif isinstance(x, list): + return np.array(x) + return x.clone().cpu().numpy() + + +def draw_bbx_xys_on_image(bbx_xys, image, conf=True): + assert isinstance(bbx_xys, np.ndarray) + assert isinstance(image, np.ndarray) + image = image.copy() + lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int) + rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int) + color = (255, 178, 102) if conf == True else (128, 128, 128) # orange or gray + image = cv2.rectangle(image, lu_point, rd_point, color, 2) + return image + + +def draw_bbx_xys_on_image_batch(bbx_xys_batch, image_batch, conf=None): + """conf: if provided, list of bool""" + use_conf = conf is not None + bbx_xys_batch = to_numpy(bbx_xys_batch) + assert len(bbx_xys_batch) == len(image_batch) + image_batch_out = [] + for i in range(len(bbx_xys_batch)): + if use_conf: + image_batch_out.append(draw_bbx_xys_on_image(bbx_xys_batch[i], image_batch[i], conf[i])) + else: + image_batch_out.append(draw_bbx_xys_on_image(bbx_xys_batch[i], image_batch[i])) + return image_batch_out + + +def draw_bbx_xyxy_on_image(bbx_xys, image, conf=True): + bbx_xys = to_numpy(bbx_xys) + image = to_numpy(image) + color = (255, 178, 102) if conf == True else (128, 128, 128) # orange or gray + image = cv2.rectangle(image, (int(bbx_xys[0]), int(bbx_xys[1])), (int(bbx_xys[2]), int(bbx_xys[3])), color, 2) + return image + + +def draw_bbx_xyxy_on_image_batch(bbx_xyxy_batch, image_batch, mask=None, conf=None): + """ + Args: + conf: if provided, list of bool, mutually exclusive with mask + mask: whether to draw, historically used + """ + if mask is not None: + assert conf is None + if conf is not None: + assert mask is None + use_conf = conf is not None + bbx_xyxy_batch = to_numpy(bbx_xyxy_batch) + image_batch = to_numpy(image_batch) + assert len(bbx_xyxy_batch) == len(image_batch) + image_batch_out = [] + for i in range(len(bbx_xyxy_batch)): + if use_conf: + image_batch_out.append(draw_bbx_xyxy_on_image(bbx_xyxy_batch[i], image_batch[i], conf[i])) + else: + if mask is None or mask[i]: + image_batch_out.append(draw_bbx_xyxy_on_image(bbx_xyxy_batch[i], image_batch[i])) + else: + image_batch_out.append(image_batch[i]) + return image_batch_out + + +def draw_kpts(frame, keypoints, color=(0, 255, 0), thickness=2): + frame_ = frame.copy() + for x, y in keypoints: + cv2.circle(frame_, (int(x), int(y)), 
thickness, color, -1) + return frame_ + + +def draw_kpts_with_conf(frame, kp2d, conf, thickness=2): + """ + Args: + kp2d: (J, 2), + conf: (J,) + """ + frame_ = frame.copy() + conf = conf.reshape(-1) + colors = get_colors_by_conf(conf) # (J, 3) + colors = colors[:, [2, 1, 0]].int().numpy().tolist() + for j in range(kp2d.shape[0]): + x, y = kp2d[j, :2] + c = colors[j] + cv2.circle(frame_, (int(x), int(y)), thickness, c, -1) + return frame_ + + +def draw_kpts_with_conf_batch(frames, kp2d_batch, conf_batch, thickness=2): + """ + Args: + kp2d_batch: (B, J, 2), + conf_batch: (B, J) + """ + assert len(frames) == len(kp2d_batch) + assert len(frames) == len(conf_batch) + frames_ = [] + for i in range(len(frames)): + frames_.append(draw_kpts_with_conf(frames[i], kp2d_batch[i], conf_batch[i], thickness)) + return frames_ + + +def draw_coco17_skeleton(img, keypoints, conf_thr=0): + use_conf_thr = True if keypoints.shape[1] == 3 else False + img = img.copy() + # fmt:off + coco_skel = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]] + # fmt:on + for bone in coco_skel: + if use_conf_thr: + kp1 = keypoints[bone[0]][:2].astype(int) + kp2 = keypoints[bone[1]][:2].astype(int) + kp1_c = keypoints[bone[0]][2] + kp2_c = keypoints[bone[1]][2] + if kp1_c > conf_thr and kp2_c > conf_thr: + img = cv2.line(img, (kp1[0], kp1[1]), (kp2[0], kp2[1]), (0, 255, 0), 4) + if kp1_c > conf_thr: + img = cv2.circle(img, (kp1[0], kp1[1]), 6, (0, 255, 0), -1) + if kp2_c > conf_thr: + img = cv2.circle(img, (kp2[0], kp2[1]), 6, (0, 255, 0), -1) + + else: + kp1 = keypoints[bone[0]][:2].astype(int) + kp2 = keypoints[bone[1]][:2].astype(int) + img = cv2.line(img, (kp1[0], kp1[1]), (kp2[0], kp2[1]), (0, 255, 0), 4) + return img + + +def draw_coco17_skeleton_batch(imgs, keypoints_batch, conf_thr=0): + assert len(imgs) == len(keypoints_batch) + keypoints_batch = to_numpy(keypoints_batch) + imgs_out = [] + for i in range(len(imgs)): + imgs_out.append(draw_coco17_skeleton(imgs[i], keypoints_batch[i], conf_thr)) + return imgs_out diff --git a/hmr4d/utils/vis/renderer.py b/hmr4d/utils/vis/renderer.py new file mode 100644 index 0000000..f4d6232 --- /dev/null +++ b/hmr4d/utils/vis/renderer.py @@ -0,0 +1,356 @@ +import cv2 +import torch +import numpy as np + +from pytorch3d.renderer import ( + PerspectiveCameras, + TexturesVertex, + PointLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, +) +from pytorch3d.structures import Meshes +from pytorch3d.structures.meshes import join_meshes_as_scene +from pytorch3d.renderer.cameras import look_at_rotation +from pytorch3d.transforms import axis_angle_to_matrix + +from .renderer_tools import get_colors, checkerboard_geometry + + +colors_str_map = { + "gray": [0.8, 0.8, 0.8], + "green": [39, 194, 128], +} + + +def overlay_image_onto_background(image, mask, bbox, background): + if isinstance(image, torch.Tensor): + image = image.detach().cpu().numpy() + if isinstance(mask, torch.Tensor): + mask = mask.detach().cpu().numpy() + + out_image = background.copy() + bbox = bbox[0].int().cpu().numpy().copy() + roi_image = out_image[bbox[1] : bbox[3], bbox[0] : bbox[2]] + + roi_image[mask] = image[mask] + out_image[bbox[1] : bbox[3], bbox[0] : bbox[2]] = roi_image + + return out_image + + +def update_intrinsics_from_bbox(K_org, bbox): + device, dtype = K_org.device, K_org.dtype + + K = torch.zeros((K_org.shape[0], 4, 4)).to(device=device, dtype=dtype) + 
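+ # Build the 4x4 screen-space calibration matrix used by PyTorch3D's PerspectiveCameras: + # [[fx, 0, px, 0], [0, fy, py, 0], [0, 0, 0, 1], [0, 0, 1, 0]]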
K[:, :3, :3] = K_org.clone() + K[:, 2, 2] = 0 + K[:, 2, -1] = 1 + K[:, -1, 2] = 1 + + image_sizes = [] + for idx, bbox in enumerate(bbox): + left, upper, right, lower = bbox + cx, cy = K[idx, 0, 2], K[idx, 1, 2] + + new_cx = cx - left + new_cy = cy - upper + new_height = max(lower - upper, 1) + new_width = max(right - left, 1) + new_cx = new_width - new_cx + new_cy = new_height - new_cy + + K[idx, 0, 2] = new_cx + K[idx, 1, 2] = new_cy + image_sizes.append((int(new_height), int(new_width))) + + return K, image_sizes + + +def perspective_projection(x3d, K, R=None, T=None): + if R != None: + x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2) + if T != None: + x3d = x3d + T.transpose(1, 2) + + x2d = torch.div(x3d, x3d[..., 2:]) + x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2] + return x2d + + +def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2): + left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w) + right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w) + top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h) + bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h) + + cx = (left + right) / 2 + cy = (top + bottom) / 2 + width = right - left + height = bottom - top + + new_left = torch.clamp(cx - width / 2 * scaleFactor, min=0, max=img_w - 1) + new_right = torch.clamp(cx + width / 2 * scaleFactor, min=1, max=img_w) + new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h - 1) + new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h) + + bbox = torch.stack((new_left.detach(), new_top.detach(), new_right.detach(), new_bottom.detach())).int().float().T + + return bbox + + +class Renderer: + def __init__(self, width, height, focal_length=None, device="cuda", faces=None, K=None, bin_size=None): + """set bin_size to 0 for no binning""" + self.width = width + self.height = height + self.bin_size = bin_size + assert (focal_length is not None) ^ (K is not None), "focal_length and K are mutually exclusive" + + self.device = device + if faces is not None: + if isinstance(faces, np.ndarray): + faces = torch.from_numpy((faces).astype("int")) + self.faces = faces.unsqueeze(0).to(self.device) + + self.initialize_camera_params(focal_length, K) + self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]]) + self.create_renderer() + + def create_renderer(self): + self.renderer = MeshRenderer( + rasterizer=MeshRasterizer( + raster_settings=RasterizationSettings( + image_size=self.image_sizes[0], blur_radius=1e-5, bin_size=self.bin_size + ), + ), + shader=SoftPhongShader( + device=self.device, + lights=self.lights, + ), + ) + + def create_camera(self, R=None, T=None): + if R is not None: + self.R = R.clone().view(1, 3, 3).to(self.device) + if T is not None: + self.T = T.clone().view(1, 3).to(self.device) + + return PerspectiveCameras( + device=self.device, R=self.R.mT, T=self.T, K=self.K_full, image_size=self.image_sizes, in_ndc=False + ) + + def initialize_camera_params(self, focal_length, K): + # Extrinsics + self.R = torch.diag(torch.tensor([1, 1, 1])).float().to(self.device).unsqueeze(0) + + self.T = torch.tensor([0, 0, 0]).unsqueeze(0).float().to(self.device) + + # Intrinsics + if K is not None: + self.K = K.float().reshape(1, 3, 3).to(self.device) + else: + assert focal_length is not None, "focal_length or K should be provided" + self.K = ( + torch.tensor([[focal_length, 0, self.width / 2], [0, focal_length, self.height / 2], [0, 0, 1]]) + .float() + .reshape(1, 3, 3) + .to(self.device) + ) + self.bboxes = 
torch.tensor([[0, 0, self.width, self.height]]).float() + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes) + self.cameras = self.create_camera() + + def set_intrinsic(self, K): + self.K = K.reshape(1, 3, 3) + + def set_ground(self, length, center_x, center_z): + device = self.device + length, center_x, center_z = map(float, (length, center_x, center_z)) + v, f, vc, fc = map(torch.from_numpy, checkerboard_geometry(length=length, c1=center_x, c2=center_z, up="y")) + v, f, vc = v.to(device), f.to(device), vc.to(device) + self.ground_geometry = [v, f, vc] + + def update_bbox(self, x3d, scale=2.0, mask=None): + """Update bbox of cameras from the given 3d points + + x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3) + """ + + if x3d.size(-1) != 3: + x2d = x3d.unsqueeze(0) + else: + x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1)) + + if mask is not None: + x2d = x2d[:, ~mask] + + bbox = compute_bbox_from_points(x2d, self.width, self.height, scale) + self.bboxes = bbox + + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) + self.cameras = self.create_camera() + self.create_renderer() + + def reset_bbox( + self, + ): + bbox = torch.zeros((1, 4)).float().to(self.device) + bbox[0, 2] = self.width + bbox[0, 3] = self.height + self.bboxes = bbox + + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) + self.cameras = self.create_camera() + self.create_renderer() + + def render_mesh(self, vertices, background=None, colors=[0.8, 0.8, 0.8], VI=50): + self.update_bbox(vertices[::VI], scale=1.2) + vertices = vertices.unsqueeze(0) + + if isinstance(colors, torch.Tensor): + # per-vertex color + verts_features = colors.to(device=vertices.device, dtype=vertices.dtype) + colors = [0.8, 0.8, 0.8] + else: + if colors[0] > 1: + colors = [c / 255.0 for c in colors] + verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype) + verts_features = verts_features.repeat(1, vertices.shape[1], 1) + textures = TexturesVertex(verts_features=verts_features) + + mesh = Meshes( + verts=vertices, + faces=self.faces, + textures=textures, + ) + + materials = Materials(device=self.device, specular_color=(colors,), shininess=0) + + results = torch.flip(self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights), [1, 2]) + image = results[0, ..., :3] * 255 + mask = results[0, ..., -1] > 1e-3 + + if background is None: + background = np.ones((self.height, self.width, 3)).astype(np.uint8) * 255 + + image = overlay_image_onto_background(image, mask, self.bboxes, background.copy()) + self.reset_bbox() + return image + + def render_with_ground(self, verts, colors, cameras, lights, faces=None): + """ + :param verts (N, V, 3), potential multiple people + :param colors (N, 3) or (N, V, 3) + :param faces (N, F, 3), optional, otherwise self.faces is used will be used + """ + # Sanity check of input verts, colors and faces: (B, V, 3), (B, F, 3), (B, V, 3) + N, V, _ = verts.shape + if faces is None: + faces = self.faces.clone().expand(N, -1, -1) + else: + assert len(faces.shape) == 3, "faces should have shape of (N, F, 3)" + + assert len(colors.shape) in [2, 3] + if len(colors.shape) == 2: + assert len(colors) == N, "colors of shape 2 should be (N, 3)" + colors = colors[:, None] + colors = colors.expand(N, V, -1)[..., :3] + + # (V, 3), (F, 3), (V, 3) + gv, gf, gc = self.ground_geometry + verts = list(torch.unbind(verts, dim=0)) + [gv] + faces = 
list(torch.unbind(faces, dim=0)) + [gf] + colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]] + mesh = create_meshes(verts, faces, colors) + + materials = Materials(device=self.device, shininess=0) + + results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials) + image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + + return image + + +def create_meshes(verts, faces, colors): + """ + :param verts (B, V, 3) + :param faces (B, F, 3) + :param colors (B, V, 3) + """ + textures = TexturesVertex(verts_features=colors) + meshes = Meshes(verts=verts, faces=faces, textures=textures) + return join_meshes_as_scene(meshes) + + +def get_global_cameras(verts, device="cuda", distance=5, position=(-5.0, 5.0, 0.0)): + """This always put object at the center of view""" + positions = torch.tensor([position]).repeat(len(verts), 1) + targets = verts.mean(1) + + directions = targets - positions + directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * distance + positions = targets - directions + + rotation = look_at_rotation(positions, targets).mT + translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1) + + lights = PointLights(device=device, location=[position]) + return rotation, translation, lights + + +def get_global_cameras_static( + verts, beta=4.0, cam_height_degree=30, target_center_height=1.0, use_long_axis=False, vec_rot=45, device="cuda" +): + L, V, _ = verts.shape + + # Compute target trajectory, denote as center + scale + targets = verts.mean(1) # (L, 3) + targets[:, 1] = 0 # project to xz-plane + target_center = targets.mean(0) # (3,) + target_scale, target_idx = torch.norm(targets - target_center, dim=-1).max(0) + + # a 45 degree vec from longest axis + if use_long_axis: + long_vec = targets[target_idx] - target_center # (x, 0, z) + long_vec = long_vec / torch.norm(long_vec) + R = axis_angle_to_matrix(torch.tensor([0, np.pi / 4, 0])).to(long_vec) + vec = R @ long_vec + else: + vec_rad = vec_rot / 180 * np.pi + vec = torch.tensor([np.sin(vec_rad), 0, np.cos(vec_rad)]).float() + vec = vec / torch.norm(vec) + + # Compute camera position (center + scale * vec * beta) + y=4 + target_scale = max(target_scale, 1.0) * beta + position = target_center + vec * target_scale + position[1] = target_scale * np.tan(np.pi * cam_height_degree / 180) + target_center_height + + # Compute camera rotation and translation + positions = position.unsqueeze(0).repeat(L, 1) + target_centers = target_center.unsqueeze(0).repeat(L, 1) + target_centers[:, 1] = target_center_height + rotation = look_at_rotation(positions, target_centers).mT + translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1) + + lights = PointLights(device=device, location=[position.tolist()]) + return rotation, translation, lights + + +def get_ground_params_from_points(root_points, vert_points): + """xz-plane is the ground plane + Args: + root_points: (L, 3), to decide center + vert_points: (L, V, 3), to decide scale + """ + root_max = root_points.max(0)[0] # (3,) + root_min = root_points.min(0)[0] # (3,) + cx, _, cz = (root_max + root_min) / 2.0 + + vert_max = vert_points.reshape(-1, 3).max(0)[0] # (L, 3) + vert_min = vert_points.reshape(-1, 3).min(0)[0] # (L, 3) + scale = (vert_max - vert_min)[[0, 2]].max() + return float(scale), float(cx), float(cz) diff --git a/hmr4d/utils/vis/renderer_tools.py b/hmr4d/utils/vis/renderer_tools.py new file mode 100644 index 0000000..68107fc --- /dev/null +++ b/hmr4d/utils/vis/renderer_tools.py @@ -0,0 +1,804 @@ +import os +import cv2 
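+import math  # used by imshow_keypoints below (math.degrees / math.atan2)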
+import numpy as np +import torch +from PIL import Image + + +def read_image(path, scale=1): + im = Image.open(path) + if scale == 1: + return np.array(im) + W, H = im.size + w, h = int(scale * W), int(scale * H) + return np.array(im.resize((w, h), Image.ANTIALIAS)) + + +def transform_torch3d(T_c2w): + """ + :param T_c2w (*, 4, 4) + returns (*, 3, 3), (*, 3) + """ + R1 = torch.tensor( + [ + [-1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 1.0], + ], + device=T_c2w.device, + ) + R2 = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, -1.0], + ], + device=T_c2w.device, + ) + cam_R, cam_t = T_c2w[..., :3, :3], T_c2w[..., :3, 3] + cam_R = torch.einsum("...ij,jk->...ik", cam_R, R1) + cam_t = torch.einsum("ij,...j->...i", R2, cam_t) + return cam_R, cam_t + + +def transform_pyrender(T_c2w): + """ + :param T_c2w (*, 4, 4) + """ + T_vis = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + device=T_c2w.device, + ) + return torch.einsum("...ij,jk->...ik", torch.einsum("ij,...jk->...ik", T_vis, T_c2w), T_vis) + + +def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None): + """ + :param verts (B, T, V, 3) + :param faces (F, 3) + :param vis_mask (optional) (B, T) visibility of each person + :param track_ids (optional) (B,) + returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3) + where B is different depending on the visibility of the people + """ + B, T = verts.shape[:2] + device = verts.device + + # (B, 3) + colors = track_to_colors(track_ids) if track_ids is not None else torch.ones(B, 3, device) * 0.5 + + # list T (B, V, 3), T (B, 3), T (F, 3) + return filter_visible_meshes(verts, colors, faces, vis_mask) + + +def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False): + """ + :param verts (B, T, V, 3) + :param colors (B, 3) + :param faces (F, 3) + :param vis_mask (optional tensor, default None) (B, T) ternary mask + -1 if not in frame + 0 if temporarily occluded + 1 if visible + :param vis_opacity (optional bool, default False) + if True, make occluded people alpha=0.5, otherwise alpha=1 + returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3) + """ + # import ipdb; ipdb.set_trace() + B, T = verts.shape[:2] + faces = [faces for t in range(T)] + if vis_mask is None: + verts = [verts[:, t] for t in range(T)] + colors = [colors for t in range(T)] + return verts, colors, faces + + # render occluded and visible, but not removed + vis_mask = vis_mask >= 0 + if vis_opacity: + alpha = 0.5 * (vis_mask[..., None] + 1) + else: + alpha = (vis_mask[..., None] >= 0).float() + vert_list = [verts[vis_mask[:, t], t] for t in range(T)] + colors = [torch.cat([colors[vis_mask[:, t]], alpha[vis_mask[:, t], t]], dim=-1) for t in range(T)] + bounds = get_bboxes(verts, vis_mask) + return vert_list, colors, faces, bounds + + +def get_bboxes(verts, vis_mask): + """ + return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory + :param verts (B, T, V, 3) + :param vis_mask (B, T) + """ + B, T, *_ = verts.shape + bb_min, bb_max, mean = [], [], [] + for b in range(B): + v = verts[b, vis_mask[b, :T]] # (Tb, V, 3) + bb_min.append(v.amin(dim=(0, 1))) + bb_max.append(v.amax(dim=(0, 1))) + mean.append(v.mean(dim=(0, 1))) + bb_min = torch.stack(bb_min, dim=0) + bb_max = torch.stack(bb_max, dim=0) + mean = torch.stack(mean, dim=0) + # point to a track that's long and close to the camera + zs = mean[:, 2] + counts = vis_mask[:, :T].sum(dim=-1) # (B,) + mask = counts < 
0.8 * T + zs[mask] = torch.inf + sel = torch.argmin(zs) + return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel] + + +def track_to_colors(track_ids): + """ + :param track_ids (B) + """ + color_map = torch.from_numpy(get_colors()).to(track_ids) + return color_map[track_ids] / 255 # (B, 3) + + +def get_colors(): + # color_file = os.path.abspath(os.path.join(__file__, "../colors_phalp.txt")) + color_file = os.path.abspath(os.path.join(__file__, "../colors.txt")) + RGB_tuples = np.vstack( + [ + np.loadtxt(color_file, skiprows=0), + # np.loadtxt(color_file, skiprows=1), + np.random.uniform(0, 255, size=(10000, 3)), + [[0, 0, 0]], + ] + ) + b = np.where(RGB_tuples == 0) + RGB_tuples[b] = 1 + return RGB_tuples.astype(np.float32) + + +def checkerboard_geometry( + length=12.0, + color0=[0.8, 0.9, 0.9], + color1=[0.6, 0.7, 0.7], + tile_width=0.5, + alpha=1.0, + up="y", + c1=0.0, + c2=0.0, +): + assert up == "y" or up == "z" + color0 = np.array(color0 + [alpha]) + color1 = np.array(color1 + [alpha]) + num_rows = num_cols = max(2, int(length / tile_width)) + radius = float(num_rows * tile_width) / 2.0 + vertices = [] + vert_colors = [] + faces = [] + face_colors = [] + for i in range(num_rows): + for j in range(num_cols): + u0, v0 = j * tile_width - radius, i * tile_width - radius + us = np.array([u0, u0, u0 + tile_width, u0 + tile_width]) + vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0]) + zs = np.zeros(4) + if up == "y": + cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 2] += c2 + else: + cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 1] += c2 + + cur_faces = np.array([[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64) + cur_faces += 4 * (i * num_cols + j) # the number of previously added verts + use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1) + cur_color = color0 if use_color0 else color1 + cur_colors = np.array([cur_color, cur_color, cur_color, cur_color]) + + vertices.append(cur_verts) + faces.append(cur_faces) + vert_colors.append(cur_colors) + face_colors.append(cur_colors) + + vertices = np.concatenate(vertices, axis=0).astype(np.float32) + vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32) + faces = np.concatenate(faces, axis=0).astype(np.float32) + face_colors = np.concatenate(face_colors, axis=0).astype(np.float32) + + return vertices, faces, vert_colors, face_colors + + +def camera_marker_geometry(radius, height, up): + assert up == "y" or up == "z" + if up == "y": + vertices = np.array( + [ + [-radius, -radius, 0], + [radius, -radius, 0], + [radius, radius, 0], + [-radius, radius, 0], + [0, 0, height], + ] + ) + else: + vertices = np.array( + [ + [-radius, 0, -radius], + [radius, 0, -radius], + [radius, 0, radius], + [-radius, 0, radius], + [0, -height, 0], + ] + ) + + faces = np.array( + [ + [0, 3, 1], + [1, 3, 2], + [0, 1, 4], + [1, 2, 4], + [2, 3, 4], + [3, 0, 4], + ] + ) + + face_colors = np.array( + [ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + ] + ) + return vertices, faces, face_colors + + +def vis_keypoints( + keypts_list, + img_size, + radius=6, + thickness=3, + kpt_score_thr=0.3, + dataset="TopDownCocoDataset", +): + """ + Visualize keypoints + From ViTPose/mmpose/apis/inference.py + """ + palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 
102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ] + ) + + if dataset in ( + "TopDownCocoDataset", + "BottomUpCocoDataset", + "TopDownOCHumanDataset", + "AnimalMacaqueDataset", + ): + # show the results + skeleton = [ + [15, 13], + [13, 11], + [16, 14], + [14, 12], + [11, 12], + [5, 11], + [6, 12], + [5, 6], + [5, 7], + [6, 8], + [7, 9], + [8, 10], + [1, 2], + [0, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + ] + + pose_link_color = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]] + pose_kpt_color = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]] + + elif dataset == "TopDownCocoWholeBodyDataset": + # show the results + skeleton = [ + [15, 13], + [13, 11], + [16, 14], + [14, 12], + [11, 12], + [5, 11], + [6, 12], + [5, 6], + [5, 7], + [6, 8], + [7, 9], + [8, 10], + [1, 2], + [0, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [15, 17], + [15, 18], + [15, 19], + [16, 20], + [16, 21], + [16, 22], + [91, 92], + [92, 93], + [93, 94], + [94, 95], + [91, 96], + [96, 97], + [97, 98], + [98, 99], + [91, 100], + [100, 101], + [101, 102], + [102, 103], + [91, 104], + [104, 105], + [105, 106], + [106, 107], + [91, 108], + [108, 109], + [109, 110], + [110, 111], + [112, 113], + [113, 114], + [114, 115], + [115, 116], + [112, 117], + [117, 118], + [118, 119], + [119, 120], + [112, 121], + [121, 122], + [122, 123], + [123, 124], + [112, 125], + [125, 126], + [126, 127], + [127, 128], + [112, 129], + [129, 130], + [130, 131], + [131, 132], + ] + + pose_link_color = palette[ + [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16] + + [16, 16, 16, 16, 16, 16] + + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + ] + pose_kpt_color = palette[ + [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [19] * (68 + 42) + ] + + elif dataset == "TopDownAicDataset": + skeleton = [ + [2, 1], + [1, 0], + [0, 13], + [13, 3], + [3, 4], + [4, 5], + [8, 7], + [7, 6], + [6, 9], + [9, 10], + [10, 11], + [12, 13], + [0, 6], + [3, 9], + ] + + pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]] + pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]] + + elif dataset == "TopDownMpiiDataset": + skeleton = [ + [0, 1], + [1, 2], + [2, 6], + [6, 3], + [3, 4], + [4, 5], + [6, 7], + [7, 8], + [8, 9], + [8, 12], + [12, 11], + [11, 10], + [8, 13], + [13, 14], + [14, 15], + ] + + pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]] + pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]] + + elif dataset == "TopDownMpiiTrbDataset": + skeleton = [ + [12, 13], + [13, 0], + [13, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [0, 6], + [1, 7], + [6, 7], + [6, 8], + [7, 9], + [8, 10], + [9, 11], + [14, 15], + [16, 17], + [18, 19], + [20, 21], + [22, 23], + [24, 25], + [26, 27], + [28, 29], + [30, 31], + [32, 33], + [34, 35], + [36, 37], + [38, 39], + ] + + pose_link_color = palette[[16] * 14 + [19] * 13] + pose_kpt_color = palette[[16] * 14 + [0] * 26] + + elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"): + skeleton = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + 
[0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] + + pose_link_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]] + pose_kpt_color = palette[[0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]] + + elif dataset == "InterHand2DDataset": + skeleton = [ + [0, 1], + [1, 2], + [2, 3], + [4, 5], + [5, 6], + [6, 7], + [8, 9], + [9, 10], + [10, 11], + [12, 13], + [13, 14], + [14, 15], + [16, 17], + [17, 18], + [18, 19], + [3, 20], + [7, 20], + [11, 20], + [15, 20], + [19, 20], + ] + + pose_link_color = palette[[0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16]] + pose_kpt_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0]] + + elif dataset == "Face300WDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 68] + kpt_score_thr = 0 + + elif dataset == "FaceAFLWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 19] + kpt_score_thr = 0 + + elif dataset == "FaceCOFWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 29] + kpt_score_thr = 0 + + elif dataset == "FaceWFLWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 98] + kpt_score_thr = 0 + + elif dataset == "AnimalHorse10Dataset": + skeleton = [ + [0, 1], + [1, 12], + [12, 16], + [16, 21], + [21, 17], + [17, 11], + [11, 10], + [10, 8], + [8, 9], + [9, 12], + [2, 3], + [3, 4], + [5, 6], + [6, 7], + [13, 14], + [14, 15], + [18, 19], + [19, 20], + ] + + pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2] + pose_kpt_color = palette[[4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4]] + + elif dataset == "AnimalFlyDataset": + skeleton = [ + [1, 0], + [2, 0], + [3, 0], + [4, 3], + [5, 4], + [7, 6], + [8, 7], + [9, 8], + [11, 10], + [12, 11], + [13, 12], + [15, 14], + [16, 15], + [17, 16], + [19, 18], + [20, 19], + [21, 20], + [23, 22], + [24, 23], + [25, 24], + [27, 26], + [28, 27], + [29, 28], + [30, 3], + [31, 3], + ] + + pose_link_color = palette[[0] * 25] + pose_kpt_color = palette[[0] * 32] + + elif dataset == "AnimalLocustDataset": + skeleton = [ + [1, 0], + [2, 1], + [3, 2], + [4, 3], + [6, 5], + [7, 6], + [9, 8], + [10, 9], + [11, 10], + [13, 12], + [14, 13], + [15, 14], + [17, 16], + [18, 17], + [19, 18], + [21, 20], + [22, 21], + [24, 23], + [25, 24], + [26, 25], + [28, 27], + [29, 28], + [30, 29], + [32, 31], + [33, 32], + [34, 33], + ] + + pose_link_color = palette[[0] * 26] + pose_kpt_color = palette[[0] * 35] + + elif dataset == "AnimalZebraDataset": + skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]] + + pose_link_color = palette[[0] * 8] + pose_kpt_color = palette[[0] * 9] + + elif dataset in "AnimalPoseDataset": + skeleton = [ + [0, 1], + [0, 2], + [1, 3], + [0, 4], + [1, 4], + [4, 5], + [5, 7], + [6, 7], + [5, 8], + [8, 12], + [12, 16], + [5, 9], + [9, 13], + [13, 17], + [6, 10], + [10, 14], + [14, 18], + [6, 11], + [11, 15], + [15, 19], + ] + + pose_link_color = palette[[0] * 20] + pose_kpt_color = palette[[0] * 20] + else: + NotImplementedError() + + img_w, img_h = img_size + img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8) + img = imshow_keypoints( + img, + keypts_list, + skeleton, + kpt_score_thr, + pose_kpt_color, + pose_link_color, + radius, + thickness, + ) + 
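+ # Pixels that were drawn on (no longer pure white) become opaque, so the image is returned as RGBA with a transparent background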
alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8) + return np.concatenate([img, alpha], axis=-1) + + +def imshow_keypoints( + img, + pose_result, + skeleton=None, + kpt_score_thr=0.3, + pose_kpt_color=None, + pose_link_color=None, + radius=4, + thickness=1, + show_keypoint_weight=False, +): + """Draw keypoints and links on an image. + From ViTPose/mmpose/core/visualization/image.py + + Args: + img (H, W, 3) array + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, + the keypoint will not be drawn. + pose_link_color (np.array[Mx3]): Color of M links. If None, the + links will not be drawn. + thickness (int): Thickness of lines. + show_keypoint_weight (bool): If True, opacity indicates keypoint score + """ + img_h, img_w, _ = img.shape + idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11] + for kpts in pose_result: + kpts = np.array(kpts, copy=False)[idcs] + + # draw each point on image + if pose_kpt_color is not None: + assert len(pose_kpt_color) == len(kpts) + for kid, kpt in enumerate(kpts): + x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] + if kpt_score > kpt_score_thr: + color = tuple(int(c) for c in pose_kpt_color[kid]) + if show_keypoint_weight: + img_copy = img.copy() + cv2.circle(img_copy, (int(x_coord), int(y_coord)), radius, color, -1) + transparency = max(0, min(1, kpt_score)) + cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img) + else: + cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1) + + # draw links + if skeleton is not None and pose_link_color is not None: + assert len(pose_link_color) == len(skeleton) + for sk_id, sk in enumerate(skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + if ( + pos1[0] > 0 + and pos1[0] < img_w + and pos1[1] > 0 + and pos1[1] < img_h + and pos2[0] > 0 + and pos2[0] < img_w + and pos2[1] > 0 + and pos2[1] < img_h + and kpts[sk[0], 2] > kpt_score_thr + and kpts[sk[1], 2] > kpt_score_thr + ): + color = tuple(int(c) for c in pose_link_color[sk_id]) + if show_keypoint_weight: + img_copy = img.copy() + X = (pos1[0], pos2[0]) + Y = (pos1[1], pos2[1]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = 2 + polygon = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), + int(angle), + 0, + 360, + 1, + ) + cv2.fillConvexPoly(img_copy, polygon, color) + transparency = max(0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) + cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img) + else: + cv2.line(img, pos1, pos2, color, thickness=thickness) + + return img diff --git a/hmr4d/utils/vis/renderer_utils.py b/hmr4d/utils/vis/renderer_utils.py new file mode 100644 index 0000000..634c84b --- /dev/null +++ b/hmr4d/utils/vis/renderer_utils.py @@ -0,0 +1,38 @@ +from hmr4d.utils.vis.renderer import Renderer +from tqdm import tqdm +import numpy as np + + +def simple_render_mesh(render_dict): + """Render an camera-space mesh, blank background""" + width, height, focal_length = render_dict["whf"] + faces = render_dict["faces"] + verts = render_dict["verts"] + + renderer = 
Renderer(width, height, focal_length, device="cuda", faces=faces) + outputs = [] + for i in tqdm(range(len(verts)), desc=f"Rendering"): + img = renderer.render_mesh(verts[i].cuda(), colors=[0.8, 0.8, 0.8]) + outputs.append(img) + outputs = np.stack(outputs, axis=0) + return outputs + + +def simple_render_mesh_background(render_dict, VI=50, colors=[0.8, 0.8, 0.8]): + """Render an camera-space mesh, blank background""" + K = render_dict["K"] + faces = render_dict["faces"] + verts = render_dict["verts"] + background = render_dict["background"] + N_frames = len(verts) + if len(background.shape) == 3: + background = [background] * N_frames + height, width = background[0].shape[:2] + + renderer = Renderer(width, height, device="cuda", faces=faces, K=K) + outputs = [] + for i in tqdm(range(len(verts)), desc=f"Rendering"): + img = renderer.render_mesh(verts[i].cuda(), colors=colors, background=background[i], VI=VI) + outputs.append(img) + outputs = np.stack(outputs, axis=0) + return outputs diff --git a/hmr4d/utils/vis/rich_logger.py b/hmr4d/utils/vis/rich_logger.py new file mode 100644 index 0000000..80b5513 --- /dev/null +++ b/hmr4d/utils/vis/rich_logger.py @@ -0,0 +1,36 @@ +from pytorch_lightning.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf +import rich +import rich.tree +import rich.syntax +from hmr4d.utils.pylogger import Log + + +@rank_zero_only +def print_cfg(cfg: DictConfig, use_rich: bool = False): + if use_rich: + print_order = ("data", "model", "callbacks", "logger", "pl_trainer") + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + # add fields from `print_order` to queue + # add all the other fields to queue (not specified in `print_order`) + queue = [] + for field in print_order: + queue.append(field) if field in cfg else Log.warn(f"Field '{field}' not found in config. Skipping.") + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=False) + else: + branch_content = str(config_group) + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + rich.print(tree) + else: + Log.info(OmegaConf.to_yaml(cfg, resolve=False)) diff --git a/hmr4d/utils/wis3d_utils.py b/hmr4d/utils/wis3d_utils.py new file mode 100644 index 0000000..df54d8b --- /dev/null +++ b/hmr4d/utils/wis3d_utils.py @@ -0,0 +1,403 @@ +from wis3d import Wis3D +from pathlib import Path +from datetime import datetime +import torch +import numpy as np +from einops import einsum +from pytorch3d.transforms import axis_angle_to_matrix + + +def make_wis3d(output_dir="outputs/wis3d", name="debug", time_postfix=False): + """ + Make a Wis3D instance. 
e.g.: + from hmr4d.utils.wis3d_utils import make_wis3d + wis3d = make_wis3d(time_postfix=True) + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if time_postfix: + time_str = datetime.now().strftime("%m%d-%H%M-%S") + name = f"{name}_{time_str}" + print(f"Creating Wis3D {name}") + wis3d = Wis3D(output_dir.absolute(), name) + return wis3d + + +color_schemes = { + "red": ([255, 168, 154], [153, 17, 1]), + "green": ([183, 255, 191], [0, 171, 8]), + "blue": ([183, 255, 255], [0, 0, 255]), + "cyan": ([183, 255, 255], [0, 255, 255]), + "magenta": ([255, 183, 255], [255, 0, 255]), + "black": ([0, 0, 0], [0, 0, 0]), + "orange": ([255, 183, 0], [255, 128, 0]), + "grey": ([203, 203, 203], [203, 203, 203]), +} + + +def get_gradient_colors(scheme="red", num_points=120, alpha=1.0): + """ + Return a list of colors that are gradient from start to end. + """ + start_rgba = torch.tensor(color_schemes[scheme][0] + [255 * alpha]) / 255 + end_rgba = torch.tensor(color_schemes[scheme][1] + [255 * alpha]) / 255 + colors = torch.stack([torch.linspace(s, e, steps=num_points) for s, e in zip(start_rgba, end_rgba)], dim=-1) + return colors + + +def get_const_colors(name="red", partial_shape=(120, 5), alpha=1.0): + """ + Return colors (partial_shape, 4) + """ + rgba = torch.tensor(color_schemes[name][1] + [255 * alpha]) / 255 + partial_shape = tuple(partial_shape) + colors = rgba[None].repeat(*partial_shape, 1) + return colors + + +def get_colors_by_conf(conf, low="red", high="green"): + colors = torch.stack([conf] * 3, dim=-1) + colors = colors * torch.tensor(color_schemes[high][1]) + (1 - colors) * torch.tensor(color_schemes[low][1]) + return colors + + +# ================== Colored Motion Sequence ================== # + + +KINEMATIC_CHAINS = { + "smpl22": [ + [0, 2, 5, 8, 11], # right-leg + [0, 1, 4, 7, 10], # left-leg + [0, 3, 6, 9, 12, 15], # spine + [9, 14, 17, 19, 21], # right-arm + [9, 13, 16, 18, 20], # left-arm + ], + "h36m17": [ + [0, 1, 2, 3], # right-leg + [0, 4, 5, 6], # left-leg + [0, 7, 8, 9, 10], # spine + [8, 14, 15, 16], # right-arm + [8, 11, 12, 13], # left-arm + ], + "coco17": [ + [12, 14, 16], # right-leg + [11, 13, 15], # left-leg + [4, 2, 0, 1, 3], # replace spine with head + [6, 8, 10], # right-arm + [5, 7, 9], # left-arm + ], +} + + +def convert_motion_as_line_mesh(motion, skeleton_type="smpl22", const_color=None): + if isinstance(motion, np.ndarray): + motion = torch.from_numpy(motion) + motion = motion.detach().cpu() + kinematic_chain = KINEMATIC_CHAINS[skeleton_type] + color_names = ["red", "green", "blue", "cyan", "magenta"] + s_points = [] + e_points = [] + m_colors = [] + length = motion.shape[0] + device = motion.device + for chain, color_name in zip(kinematic_chain, color_names): + num_line = len(chain) - 1 + s_points.append(motion[:, chain[:-1]]) + e_points.append(motion[:, chain[1:]]) + if const_color is not None: + color_name = const_color + color_ = get_const_colors(color_name, partial_shape=(length, num_line), alpha=1.0).to(device) # (L, 4, 4) + m_colors.append(color_[..., :3] * 255) # (L, 4, 3) + + s_points = torch.cat(s_points, dim=1) # (L, ?, 3) + e_points = torch.cat(e_points, dim=1) + m_colors = torch.cat(m_colors, dim=1) + + vertices = [] + for f in range(length): + vertices_, faces, vertex_colors = create_skeleton_mesh(s_points[f], e_points[f], radius=0.02, color=m_colors[f]) + vertices.append(vertices_) + vertices = torch.stack(vertices, dim=0) + return vertices, faces, vertex_colors + + +def add_motion_as_lines(motion, wis3d, 
name="joints22", skeleton_type="smpl22", const_color=None, offset=0): + """ + Args: + motion (tensor): (L, J, 3) + """ + vertices, faces, vertex_colors = convert_motion_as_line_mesh( + motion, skeleton_type=skeleton_type, const_color=const_color + ) + for f in range(len(vertices)): + wis3d.set_scene_id(f + offset) + wis3d.add_mesh(vertices[f], faces, vertex_colors, name=name) # Add skeleton as cylinders + # Old way to add lines, this may cause problems when the number of lines is large + # wis3d.add_lines(s_points[f], e_points[f], m_colors[f], name=name) + + +def add_prog_motion_as_lines(motion, wis3d, name="joints22", skeleton_type="smpl22"): + """ + Args: + motion (tensor): (P, L, J, 3) + """ + if isinstance(motion, np.ndarray): + motion = torch.from_numpy(motion) + P, L, J, _ = motion.shape + device = motion.device + + kinematic_chain = KINEMATIC_CHAINS[skeleton_type] + color_names = ["red", "green", "blue", "cyan", "magenta"] + s_points = [] + e_points = [] + m_colors = [] + for chain, color_name in zip(kinematic_chain, color_names): + num_line = len(chain) - 1 + s_points.append(motion[:, :, chain[:-1]]) + e_points.append(motion[:, :, chain[1:]]) + color_ = get_gradient_colors(color_name, L, alpha=1.0).to(device) # (L, 4) + color_ = color_[None, :, None, :].repeat(P, 1, num_line, 1) # (P, L, num_line, 4) + m_colors.append(color_[..., :3] * 255) # (P, L, num_line, 3) + s_points = torch.cat(s_points, dim=-2) # (L, ?, 3) + e_points = torch.cat(e_points, dim=-2) + m_colors = torch.cat(m_colors, dim=-2) + + s_points = s_points.reshape(P, -1, 3) + e_points = e_points.reshape(P, -1, 3) + m_colors = m_colors.reshape(P, -1, 3) + + for p in range(P): + wis3d.set_scene_id(p) + wis3d.add_lines(s_points[p], e_points[p], m_colors[p], name=name) + + +def add_joints_motion_as_spheres(joints, wis3d, radius=0.05, name="joints", label_each_joint=False): + """Visualize skeleton as spheres to explore the skeleton. + Args: + joints: (NF, NJ, 3) + wis3d + radius: radius of the spheres + name + label_each_joint: if True, each joints will have a label in wis3d (then you can interact with it, but it's slower) + """ + colors = torch.zeros_like(joints).float() + n_frames = joints.shape[0] + n_joints = joints.shape[1] + for i in range(n_joints): + colors[:, i, 1] = 255 / n_joints * i + colors[:, i, 2] = 255 / n_joints * (n_joints - i) + for f in range(n_frames): + wis3d.set_scene_id(f) + if label_each_joint: + for i in range(n_joints): + wis3d.add_spheres( + joints[f, i].float(), + radius=radius, + colors=colors[f, i], + name=f"{name}-j{i}", + ) + else: + wis3d.add_spheres( + joints[f].float(), + radius=radius, + colors=colors[f], + name=f"{name}", + ) + + +def create_skeleton_mesh(p1, p2, radius, color, resolution=4, return_merged=True): + """ + Create mesh between p1 and p2. 
+ Args: + p1 (torch.Tensor): (N, 3), + p2 (torch.Tensor): (N, 3), + radius (float): radius, + color (torch.Tensor): (N, 3) + resolution (int): number of vertices in one circle, denoted as Q + Returns: + vertices (torch.Tensor): (N * 2Q, 3); if return_merged is False, (N, 2Q, 3) + faces (torch.Tensor): (M', 3); if return_merged is False, (N, M, 3) + vertex_colors (torch.Tensor): (N * 2Q, 3); if return_merged is False, (N, 2Q, 3) + """ + N = p1.shape[0] + + # Calculate segment direction + seg_dir = p2 - p1 # (N, 3) + unit_seg_dir = seg_dir / seg_dir.norm(dim=-1, keepdim=True) # (N, 3) + + # Compute an orthogonal vector + x_vec = torch.tensor([1, 0, 0], device=p1.device).float().unsqueeze(0).repeat(N, 1) # (N, 3) + y_vec = torch.tensor([0, 1, 0], device=p1.device).float().unsqueeze(0).repeat(N, 1) + ortho_vec = torch.cross(unit_seg_dir, x_vec, dim=-1) # (N, 3) + ortho_vec_ = torch.cross(unit_seg_dir, y_vec, dim=-1) # (N, 3) backup + ortho_vec = torch.where(ortho_vec.norm(dim=-1, keepdim=True) > 1e-3, ortho_vec, ortho_vec_) + + # Get circle points on two ends + unit_ortho_vec = ortho_vec / ortho_vec.norm(dim=-1, keepdim=True) # (N, 3) + theta = torch.linspace(0, 2 * np.pi, resolution, device=p1.device) + rotation_matrix = axis_angle_to_matrix(unit_seg_dir[:, None] * theta[None, :, None]) # (N, Q, 3, 3) + rotated_points = einsum(rotation_matrix, unit_ortho_vec, "n q i j, n i -> n q j") * radius # (N, Q, 3) + bottom_points = rotated_points + p1.unsqueeze(1) # (N, Q, 3) + top_points = rotated_points + p2.unsqueeze(1) # (N, Q, 3) + + # Combine bottom and top points + vertices = torch.cat([bottom_points, top_points], dim=1) # (N, 2Q, 3) + + # Generate faces + indices = torch.arange(0, resolution, device=p1.device) + bottom_indices = indices + top_indices = indices + resolution + + # cap faces + face_bottom = torch.stack([bottom_indices[:-2], bottom_indices[1:-1], bottom_indices[-1].repeat(resolution - 2)], 1) + face_top = torch.stack([top_indices[1:-1], top_indices[:-2], top_indices[-1].repeat(resolution - 2)], 1) + faces = torch.cat( + [ + torch.stack([bottom_indices[1:], bottom_indices[:-1], top_indices[:-1]], 1), # side face + torch.stack([bottom_indices[1:], top_indices[:-1], top_indices[1:]], 1), # side face + face_bottom, + face_top, + ] + ) + faces = faces.unsqueeze(0).repeat(p1.shape[0], 1, 1) # (N, M, 3) + + # Assign colors + vertex_colors = color.unsqueeze(1).repeat(1, resolution * 2, 1) + + if return_merged: + # manually adjust face ids + N, V = vertices.shape[:2] + faces = faces + torch.arange(0, N, device=p1.device).unsqueeze(1).unsqueeze(1) * V + faces = faces.reshape(-1, 3) + vertices = vertices.reshape(-1, 3) + vertex_colors = vertex_colors.reshape(-1, 3) + + return vertices, faces, vertex_colors + + +def get_lines_of_my_frustum(frustum_points): + """ + frustum_points: (B, 8, 3), in (near {lu ru rd ld}, far {lu ru rd ld}) + """ + start_points = frustum_points[:, [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7]].cpu().numpy() + end_points = frustum_points[:, [4, 5, 6, 7, 1, 2, 3, 0, 5, 6, 7, 4]].cpu().numpy() + return start_points, end_points + + +def draw_colored_vec(wis3d, vec, name, radius=0.02, colors="r", starts=None, l=1.0): + """ + Args: + vec: (3,) or (L, 3); L should match the length of colors, e.g. 'rgb' + """ + if len(vec.shape) == 1: + vec = vec[None] + else: + assert len(vec.shape) == 2 + + assert len(vec) == len(colors) + # split colors, 'rgb' to 'r', 'g', 'b' + color_tensor = torch.zeros((len(colors), 3)) + c2rgb = { + "r": torch.tensor([1, 0, 0]).float(), + "g": torch.tensor([0, 1,
0]).float(), + "b": torch.tensor([0, 0, 1]).float(), + } + for i, c in enumerate(colors): + color_tensor[i] = c2rgb[c] + + if starts is None: + starts = torch.zeros_like(vec) + ends = starts + vec * l + vertices, faces, vertex_colors = create_skeleton_mesh(starts, ends, radius, color_tensor, resolution=10) + wis3d.add_mesh(vertices, faces, vertex_colors, name=name) + + +def draw_T_w2c(wis3d, T_w2c, name, radius=0.01, all_in_one=True, l=0.1): + """ + Draw a camera trajectory in world coordinates. + Args: + T_w2c: (L, 4, 4) + """ + color_tensor = torch.eye(3) + if all_in_one: + starts = -T_w2c[:, :3, :3].mT @ T_w2c[:, :3, [3]] # (L, 3, 1) + starts = starts[:, None, :, 0].expand(-1, 3, -1).reshape(-1, 3) # (L*3, 3) + vec = T_w2c[:, :3, :3].reshape(-1, 3) # (L * 3, 3) + ends = starts + vec * l + color_tensor = color_tensor[None].expand(T_w2c.size(0), -1, -1).reshape(-1, 3) + + vertices, faces, vertex_colors = create_skeleton_mesh(starts, ends, radius, color_tensor, resolution=10) + else: + raise NotImplementedError + wis3d.add_mesh(vertices, faces, vertex_colors, name=name) + + +def create_checkerboard_mesh(y=0.0, grid_size=1.0, bounds=((-3, -3), (3, 3))): + """ + example usage: + vertices, faces, vertex_colors = create_checkerboard_mesh() + wis3d.add_mesh(vertices=vertices, faces=faces, vertex_colors=vertex_colors, name="one") + """ + color1 = np.array([236, 240, 241], np.uint8) # light + color2 = np.array([120, 120, 120], np.uint8) # dark + + # Expand the bounds to align with the grid + min_x, min_z = bounds[0] + max_x, max_z = bounds[1] + min_x = grid_size * np.floor(min_x / grid_size) + min_z = grid_size * np.floor(min_z / grid_size) + max_x = grid_size * np.ceil(max_x / grid_size) + max_z = grid_size * np.ceil(max_z / grid_size) + + vertices = [] + faces = [] + vertex_colors = [] + eps = 1e-4 # HACK: disable smooth color & double-side color artifacts of wis3d + + for i, x in enumerate(np.arange(min_x, max_x, grid_size)): + for j, z in enumerate(np.arange(min_z, max_z, grid_size)): + + # Right-hand rule for normal direction + x += ((i % 2 * 2) - 1) * eps + z += ((j % 2 * 2) - 1) * eps + v1 = np.array([x, y, z]) + v2 = np.array([x, y, z + grid_size]) + v3 = np.array([x + grid_size, y, z + grid_size]) + v4 = np.array([x + grid_size, y, z]) + offset = np.array([0, -eps, 0]) # For visualizing the underside of the mesh + + vertices.extend([v1, v2, v3, v4, v1 + offset, v2 + offset, v3 + offset, v4 + offset]) + idx = len(vertices) - 8 + faces.extend( + [ + [idx, idx + 1, idx + 2], + [idx + 2, idx + 3, idx], + [idx + 4, idx + 7, idx + 6], # double-sided + [idx + 6, idx + 5, idx + 4], # double-sided + ] + ) + vertex_color = color1 if (i + j) % 2 == 0 else color2 + vertex_colors.extend([vertex_color] * 8) + + # Convert to numpy arrays; each should have shape (n, 3) + vertices = np.array(vertices) + faces = np.array(faces) + vertex_colors = np.array(vertex_colors) + assert len(vertices.shape) == 2 and vertices.shape[1] == 3 + assert len(faces.shape) == 2 and faces.shape[1] == 3 + assert len(vertex_colors.shape) == 2 and vertex_colors.shape[1] == 3 and vertex_colors.dtype == np.uint8 + + return vertices, faces, vertex_colors + + +def add_a_trimesh(mesh, wis3d, name): + mesh.apply_transform(wis3d.three_to_world) + + # filename = wis3d.__get_export_file_name("mesh", name) + export_dir = Path(wis3d.out_folder) / wis3d.sequence_name / f"{wis3d.scene_id:05d}" / "meshes" + export_dir.mkdir(parents=True, exist_ok=True) + assert name is not None + filename = export_dir / f"{name}.ply" + wis3d.counters["mesh"] += 1 + + mesh.export(filename) diff --git
a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ff35939 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[tool.black] +line-length = 120 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..8e63feb --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,3 @@ +{ + "exclude": ["./inputs", "./outputs" ] +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0039ea2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,46 @@ +# PyTorch +--extra-index-url https://download.pytorch.org/whl/cu121 +torch==2.3.0+cu121 +torchvision==0.18.0+cu121 +timm==0.9.12 # For HMR2.0a feature extraction + +# Lightning + Hydra +lightning==2.3.0 +hydra-core==1.3 +hydra-zen +hydra_colorlog +rich + +# Common utilities +numpy==1.23.5 +jupyter +matplotlib +ipdb +setuptools>=68.0 +black +tensorboardX +opencv-python +ffmpeg-python +scikit-image +termcolor +einops +imageio==2.34.1 +av # imageio[pyav], improved performance over imageio[ffmpeg] +joblib + +# Diffusion +# diffusers[torch]==0.19.3 +# transformers==4.31.0 + +# 3D-Vision +pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt230/pytorch3d-0.7.6-cp310-cp310-linux_x86_64.whl +trimesh +chumpy +smplx +# open3d==0.17.0 +wis3d + +# 2D-Pose +ultralytics==8.2.42 # YOLO +cython_bbox +lapx \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..24a2868 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + + +setup( + name="gvhmr", + version="1.0.0", + packages=find_packages(), + author="Zehong Shen", + description="GVHMR training and inference", + url="https://github.com/zju3dv/GVHMR", +) diff --git a/tools/demo/demo.py b/tools/demo/demo.py new file mode 100644 index 0000000..472f51b --- /dev/null +++ b/tools/demo/demo.py @@ -0,0 +1,310 @@ +import cv2 +import torch +import pytorch_lightning as pl +import numpy as np +import argparse +from hmr4d.utils.pylogger import Log +import hydra +from hydra import initialize_config_module, compose +from pathlib import Path +from pytorch3d.transforms import quaternion_to_matrix + +from hmr4d.configs import register_store_gvhmr +from hmr4d.utils.video_io_utils import ( + get_video_lwh, + read_video_np, + save_video, + merge_videos_horizontal, + get_writer, + get_video_reader, +) +from hmr4d.utils.vis.cv2_utils import draw_bbx_xyxy_on_image_batch, draw_coco17_skeleton_batch + +from hmr4d.utils.preproc import Tracker, Extractor, VitPoseExtractor, SLAMModel + +from hmr4d.utils.geo.hmr_cam import get_bbx_xys_from_xyxy, estimate_K, convert_K_to_K4, create_camera_sensor +from hmr4d.utils.geo_transform import compute_cam_angvel +from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL +from hmr4d.utils.net_utils import detach_to_cpu, to_cuda +from hmr4d.utils.smplx_utils import make_smplx +from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static, get_ground_params_from_points +from tqdm import tqdm +from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay +from einops import einsum, rearrange + + +CRF = 23 # 17 is visually lossless; every +6 roughly halves the mp4 size + + +def parse_args_to_cfg(): + # Put all args to cfg + parser = argparse.ArgumentParser() + parser.add_argument("--video", type=str, default="inputs/demo/dance_3.mp4") +
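# Illustrative CLI usage for this demo entry point (editor's sketch, not part of the diff);
# the flags referenced here are the ones registered on this parser just above and below:
#   python tools/demo/demo.py --video inputs/demo/dance_3.mp4       # moving camera: DPVO estimates camera motion
#   python tools/demo/demo.py --video my_clip.mp4 -s --verbose      # static camera: skip DPVO, dump intermediate overlays
# ("my_clip.mp4" is a placeholder path.)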
parser.add_argument("--output_root", type=str, default=None, help="by default to outputs/demo") + parser.add_argument("-s", "--static_cam", action="store_true", help="If true, skip DPVO") + parser.add_argument("--verbose", action="store_true", help="If true, draw intermediate results") + args = parser.parse_args() + + # Input + video_path = Path(args.video) + assert video_path.exists(), f"Video not found at {video_path}" + length, width, height = get_video_lwh(video_path) + Log.info(f"[Input]: {video_path}") + Log.info(f"(L, W, H) = ({length}, {width}, {height})") + # Cfg + with initialize_config_module(version_base="1.3", config_module=f"hmr4d.configs"): + overrides = [ + f"video_name={video_path.stem}", + f"static_cam={args.static_cam}", + f"verbose={args.verbose}", + ] + + # Allow to change output root + if args.output_root is not None: + overrides.append(f"output_root={args.output_root}") + register_store_gvhmr() + cfg = compose(config_name="demo", overrides=overrides) + + # Output + Log.info(f"[Output Dir]: {cfg.output_dir}") + Path(cfg.output_dir).mkdir(parents=True, exist_ok=True) + Path(cfg.preprocess_dir).mkdir(parents=True, exist_ok=True) + + # Copy raw-input-video to video_path + Log.info(f"[Copy Video] {video_path} -> {cfg.video_path}") + if not Path(cfg.video_path).exists() or get_video_lwh(video_path)[0] != get_video_lwh(cfg.video_path)[0]: + reader = get_video_reader(video_path) + writer = get_writer(cfg.video_path, fps=30, crf=CRF) + for img in tqdm(reader, total=get_video_lwh(video_path)[0], desc=f"Copy"): + writer.write_frame(img) + writer.close() + reader.close() + + return cfg + + +@torch.no_grad() +def run_preprocess(cfg): + Log.info(f"[Preprocess] Start!") + tic = Log.time() + video_path = cfg.video_path + paths = cfg.paths + static_cam = cfg.static_cam + verbose = cfg.verbose + + # Get bbx tracking result + if not Path(paths.bbx).exists(): + tracker = Tracker() + bbx_xyxy = tracker.get_one_track(video_path).float() # (L, 4) + bbx_xys = get_bbx_xys_from_xyxy(bbx_xyxy, base_enlarge=1.2).float() # (L, 3) apply aspect ratio and enlarge + torch.save({"bbx_xyxy": bbx_xyxy, "bbx_xys": bbx_xys}, paths.bbx) + del tracker + else: + bbx_xys = torch.load(paths.bbx)["bbx_xys"] + Log.info(f"[Preprocess] bbx (xyxy, xys) from {paths.bbx}") + if verbose: + video = read_video_np(video_path) + bbx_xyxy = torch.load(paths.bbx)["bbx_xyxy"] + video_overlay = draw_bbx_xyxy_on_image_batch(bbx_xyxy, video) + save_video(video_overlay, cfg.paths.bbx_xyxy_video_overlay) + + # Get VitPose + if not Path(paths.vitpose).exists(): + vitpose_extractor = VitPoseExtractor() + vitpose = vitpose_extractor.extract(video_path, bbx_xys) + torch.save(vitpose, paths.vitpose) + del vitpose_extractor + else: + vitpose = torch.load(paths.vitpose) + Log.info(f"[Preprocess] vitpose from {paths.vitpose}") + if verbose: + video = read_video_np(video_path) + video_overlay = draw_coco17_skeleton_batch(video, vitpose, 0.5) + save_video(video_overlay, paths.vitpose_video_overlay) + + # Get vit features + if not Path(paths.vit_features).exists(): + extractor = Extractor() + vit_features = extractor.extract_video_features(video_path, bbx_xys) + torch.save(vit_features, paths.vit_features) + del extractor + else: + Log.info(f"[Preprocess] vit_features from {paths.vit_features}") + + # Get DPVO results + if not static_cam: # use slam to get cam rotation + if not Path(paths.slam).exists(): + length, width, height = get_video_lwh(cfg.video_path) + K_fullimg = estimate_K(width, height) + intrinsics = convert_K_to_K4(K_fullimg) 
+ slam = SLAMModel(video_path, width, height, intrinsics, buffer=4000, resize=0.5) + bar = tqdm(total=length, desc="DPVO") + while True: + ret = slam.track() + if ret: + bar.update() + else: + break + slam_results = slam.process() # (L, 7), numpy + torch.save(slam_results, paths.slam) + else: + Log.info(f"[Preprocess] slam results from {paths.slam}") + + Log.info(f"[Preprocess] End. Time elapsed: {Log.time()-tic:.2f}s") + + +def load_data_dict(cfg): + paths = cfg.paths + length, width, height = get_video_lwh(cfg.video_path) + if cfg.static_cam: + R_w2c = torch.eye(3).repeat(length, 1, 1) + else: + traj = torch.load(cfg.paths.slam) + traj_quat = torch.from_numpy(traj[:, [6, 3, 4, 5]]) + R_w2c = quaternion_to_matrix(traj_quat).mT + K_fullimg = estimate_K(width, height).repeat(length, 1, 1) + # K_fullimg = create_camera_sensor(width, height, 26)[2].repeat(length, 1, 1) + + data = { + "length": torch.tensor(length), + "bbx_xys": torch.load(paths.bbx)["bbx_xys"], + "kp2d": torch.load(paths.vitpose), + "K_fullimg": K_fullimg, + "cam_angvel": compute_cam_angvel(R_w2c), + "f_imgseq": torch.load(paths.vit_features), + } + return data + + +def render_incam(cfg): + incam_video_path = Path(cfg.paths.incam_video) + if incam_video_path.exists(): + Log.info(f"[Render Incam] Video already exists at {incam_video_path}") + return + + pred = torch.load(cfg.paths.hmr4d_results) + smplx = make_smplx("supermotion").cuda() + smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda() + faces_smpl = make_smplx("smpl").faces + + # smpl + smplx_out = smplx(**to_cuda(pred["smpl_params_incam"])) + pred_c_verts = torch.stack([torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices]) + + # -- rendering code -- # + video_path = cfg.video_path + length, width, height = get_video_lwh(video_path) + K = pred["K_fullimg"][0] + + # renderer + renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K) + reader = get_video_reader(video_path) # (F, H, W, 3), uint8, numpy + bbx_xys_render = torch.load(cfg.paths.bbx)["bbx_xys"] + + # -- render mesh -- # + verts_incam = pred_c_verts + writer = get_writer(incam_video_path, fps=30, crf=CRF) + for i, img_raw in tqdm(enumerate(reader), total=get_video_lwh(video_path)[0], desc=f"Rendering Incam"): + img = renderer.render_mesh(verts_incam[i].cuda(), img_raw, [0.8, 0.8, 0.8]) + + # # bbx + # bbx_xys_ = bbx_xys_render[i].cpu().numpy() + # lu_point = (bbx_xys_[:2] - bbx_xys_[2:] / 2).astype(int) + # rd_point = (bbx_xys_[:2] + bbx_xys_[2:] / 2).astype(int) + # img = cv2.rectangle(img, lu_point, rd_point, (255, 178, 102), 2) + + writer.write_frame(img) + writer.close() + reader.close() + + +def render_global(cfg): + global_video_path = Path(cfg.paths.global_video) + if global_video_path.exists(): + Log.info(f"[Render Global] Video already exists at {global_video_path}") + return + + debug_cam = False + pred = torch.load(cfg.paths.hmr4d_results) + smplx = make_smplx("supermotion").cuda() + smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda() + faces_smpl = make_smplx("smpl").faces + J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt").cuda() + + # smpl + smplx_out = smplx(**to_cuda(pred["smpl_params_global"])) + pred_ay_verts = torch.stack([torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices]) + + def move_to_start_point_face_z(verts): + "XZ to origin, Start from the ground, Face-Z" + # position + verts = verts.clone() # (L, V, 3) + offset = einsum(J_regressor, verts[0], "j v, v i -> j i")[0] # (3) + 
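# Editor's note: the einsum above is a plain matrix product, J_regressor (J, V) @ verts[0] (V, 3) -> joints (J, 3);
# indexing [0] then selects the first regressed joint (the pelvis in the SMPL convention). Equivalent sketch:
#   joints_f0 = J_regressor @ verts[0]   # (J, 3)
#   offset = joints_f0[0]                # (3,) pelvis at frame 0; its y is replaced by the lowest vertex height below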
offset[1] = verts[:, :, [1]].min() + verts = verts - offset + # face direction + T_ay2ayfz = compute_T_ayfz2ay(einsum(J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True) + verts = apply_T_on_points(verts, T_ay2ayfz) + return verts + + verts_glob = move_to_start_point_face_z(pred_ay_verts) + joints_glob = einsum(J_regressor, verts_glob, "j v, l v i -> l j i") # (L, J, 3) + global_R, global_T, global_lights = get_global_cameras_static( + verts_glob.cpu(), + beta=2.0, + cam_height_degree=20, + target_center_height=1.0, + ) + + # -- rendering code -- # + video_path = cfg.video_path + length, width, height = get_video_lwh(video_path) + _, _, K = create_camera_sensor(width, height, 24) # render as 24mm lens + + # renderer + renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K) + # renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K, bin_size=0) + + # -- render mesh -- # + scale, cx, cz = get_ground_params_from_points(joints_glob[:, 0], verts_glob) + renderer.set_ground(scale * 1.5, cx, cz) + color = torch.ones(3).float().cuda() * 0.8 + + render_length = length if not debug_cam else 8 + writer = get_writer(global_video_path, fps=30, crf=CRF) + for i in tqdm(range(render_length), desc=f"Rendering Global"): + cameras = renderer.create_camera(global_R[i], global_T[i]) + img = renderer.render_with_ground(verts_glob[[i]], color[None], cameras, global_lights) + writer.write_frame(img) + writer.close() + + +if __name__ == "__main__": + cfg = parse_args_to_cfg() + paths = cfg.paths + Log.info(f"[GPU]: {torch.cuda.get_device_name()}") + Log.info(f'[GPU]: {torch.cuda.get_device_properties("cuda")}') + + # ===== Preprocess and save to disk ===== # + run_preprocess(cfg) + data = load_data_dict(cfg) + + # ===== HMR4D ===== # + if not Path(paths.hmr4d_results).exists(): + Log.info("[HMR4D] Predicting") + model: DemoPL = hydra.utils.instantiate(cfg.model, _recursive_=False) + model.load_pretrained_model(cfg.ckpt_path) + model = model.eval().cuda() + tic = Log.sync_time() + pred = model.predict(data, static_cam=cfg.static_cam) + pred = detach_to_cpu(pred) + data_time = data["length"] / 30 + Log.info(f"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s") + torch.save(pred, paths.hmr4d_results) + + # ===== Render ===== # + render_incam(cfg) + render_global(cfg) + if not Path(paths.incam_global_horiz_video).exists(): + Log.info("[Merge Videos]") + merge_videos_horizontal([paths.incam_video, paths.global_video], paths.incam_global_horiz_video) diff --git a/tools/demo/demo_folder.py b/tools/demo/demo_folder.py new file mode 100644 index 0000000..9dff08d --- /dev/null +++ b/tools/demo/demo_folder.py @@ -0,0 +1,29 @@ +import argparse +from pathlib import Path +from tqdm import tqdm +from hmr4d.utils.pylogger import Log +import subprocess +import os + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-f", "--folder", type=str) + parser.add_argument("-d", "--output_root", type=str, default=None) + parser.add_argument("-s", "--static_cam", action="store_true", help="If true, skip DPVO") + args = parser.parse_args() + + folder = Path(args.folder) + output_root = args.output_root + + # Run demo.py for each .mp4 file + mp4_paths = sorted(list(folder.glob("*.mp4")) + list(folder.glob("*.MP4"))) + Log.info(f"Found {len(mp4_paths)} .mp4 files in {folder}") + for mp4_path in tqdm(mp4_paths): + command = ["python", "tools/demo/demo.py", "--video", str(mp4_path)] + if output_root is not None: + command += 
["--output_root", output_root] + if args.static_cam: + command += ["-s"] + Log.info(f"Running: {' '.join(command)}") + subprocess.run(command, env=dict(os.environ), check=True) diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000..7ec8e9b --- /dev/null +++ b/tools/train.py @@ -0,0 +1,87 @@ +import hydra +import pytorch_lightning as pl +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.callbacks.checkpoint import Checkpoint + +from hmr4d.utils.pylogger import Log +from hmr4d.configs import register_store_gvhmr +from hmr4d.utils.vis.rich_logger import print_cfg +from hmr4d.utils.net_utils import load_pretrained_model, get_resume_ckpt_path + + +def get_callbacks(cfg: DictConfig) -> list: + """Parse and instantiate all the callbacks in the config.""" + if not hasattr(cfg, "callbacks") or cfg.callbacks is None: + return None + # Handle special callbacks + enable_checkpointing = cfg.pl_trainer.get("enable_checkpointing", True) + # Instantiate all the callbacks + callbacks = [] + for callback in cfg.callbacks.values(): + if callback is not None: + cb = hydra.utils.instantiate(callback, _recursive_=False) + # skip when disable checkpointing and the callback is Checkpoint + if not enable_checkpointing and isinstance(cb, Checkpoint): + continue + else: + callbacks.append(cb) + return callbacks + + +def train(cfg: DictConfig) -> None: + """Train/Test""" + Log.info(f"[Exp Name]: {cfg.exp_name}") + if cfg.task == "fit": + Log.info(f"[GPU x Batch] = {cfg.pl_trainer.devices} x {cfg.data.loader_opts.train.batch_size}") + pl.seed_everything(cfg.seed) + + # preparation + datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.data, _recursive_=False) + model: pl.LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False) + if cfg.ckpt_path is not None: + load_pretrained_model(model, cfg.ckpt_path) + + # PL callbacks and logger + callbacks = get_callbacks(cfg) + has_ckpt_cb = any([isinstance(cb, Checkpoint) for cb in callbacks]) + if not has_ckpt_cb and cfg.pl_trainer.get("enable_checkpointing", True): + Log.warning("No checkpoint-callback found. 
Disabling PL auto checkpointing.") + cfg.pl_trainer = {**cfg.pl_trainer, "enable_checkpointing": False} + logger = hydra.utils.instantiate(cfg.logger, _recursive_=False) + + # PL-Trainer + if cfg.task == "test": + Log.info("Test mode forces full-precision.") + cfg.pl_trainer = {**cfg.pl_trainer, "precision": 32} + trainer = pl.Trainer( + accelerator="gpu", + logger=logger if logger is not None else False, + callbacks=callbacks, + **cfg.pl_trainer, + ) + + if cfg.task == "fit": + resume_path = None + if cfg.resume_mode is not None: + resume_path = get_resume_ckpt_path(cfg.resume_mode, ckpt_dir=cfg.callbacks.model_checkpoint.dirpath) + Log.info(f"Resume training from {resume_path}") + Log.info("Start Fitting...") + trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader(), ckpt_path=resume_path) + elif cfg.task == "test": + Log.info("Start Testing...") + trainer.test(model, datamodule.test_dataloader()) + else: + raise ValueError(f"Unknown task: {cfg.task}") + + Log.info("End of script.") + + +@hydra.main(version_base="1.3", config_path="../hmr4d/configs", config_name="train") +def main(cfg) -> None: + print_cfg(cfg, use_rich=True) + train(cfg) + + +if __name__ == "__main__": + register_store_gvhmr() + main() diff --git a/tools/unitest/make_hydra_cfg.py b/tools/unitest/make_hydra_cfg.py new file mode 100644 index 0000000..2ecfbab --- /dev/null +++ b/tools/unitest/make_hydra_cfg.py @@ -0,0 +1,7 @@ +from hmr4d.configs import parse_args_to_cfg, register_store_gvhmr +from hmr4d.utils.vis.rich_logger import print_cfg + +if __name__ == "__main__": + register_store_gvhmr() + cfg = parse_args_to_cfg() + print_cfg(cfg, use_rich=True) diff --git a/tools/unitest/run_dataset.py b/tools/unitest/run_dataset.py new file mode 100644 index 0000000..0e1f8af --- /dev/null +++ b/tools/unitest/run_dataset.py @@ -0,0 +1,41 @@ +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def get_dataset(DATA_TYPE): + if DATA_TYPE == "BEDLAM_V2": + from hmr4d.dataset.bedlam.bedlam import BedlamDatasetV2 + + return BedlamDatasetV2() + + if DATA_TYPE == "3DPW_TRAIN": + from hmr4d.dataset.threedpw.threedpw_motion_train import ThreedpwSmplDataset + + return ThreedpwSmplDataset() + +if __name__ == "__main__": + DATA_TYPE = "3DPW_TRAIN" + dataset = get_dataset(DATA_TYPE) + print(len(dataset)) + + data = dataset[0] + + from hmr4d.datamodule.mocap_trainX_testY import collate_fn + + loader = DataLoader( + dataset, + shuffle=False, + num_workers=0, + persistent_workers=False, + pin_memory=False, + batch_size=1, + collate_fn=collate_fn, + ) + i = 0 + for batch in tqdm(loader): + i += 1 + # if i == 20: + # raise AssertionError + # time.sleep(0.2) + pass diff --git a/tools/video/merge_folder.py b/tools/video/merge_folder.py new file mode 100644 index 0000000..e159bd6 --- /dev/null +++ b/tools/video/merge_folder.py @@ -0,0 +1,42 @@ +"""This script globs two folders, checks that the mp4 files match one-to-one, then calls merge_horizontal.py (or merge_vertical.py) to merge them pair by pair""" + +import os +import argparse +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_dir1", type=str) + parser.add_argument("input_dir2", type=str) + parser.add_argument("output_dir", type=str) + parser.add_argument("--vertical", action="store_true") # By default use horizontal + args = parser.parse_args() + + # Check input + input_dir1 = Path(args.input_dir1) + input_dir2 = Path(args.input_dir2) + assert input_dir1.exists() + assert input_dir2.exists() + 
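# Editor's note: the pairing below assumes both folders hold identically named clips, e.g.
#   renders_incam/dance_3.mp4  <->  renders_global/dance_3.mp4
# (folder names are illustrative); the stem-by-stem assert enforces this one-to-one match before merging.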
video_paths1 = sorted(input_dir1.glob("*.mp4")) + video_paths2 = sorted(input_dir2.glob("*.mp4")) + assert len(video_paths1) == len(video_paths2) + for path1, path2 in zip(video_paths1, video_paths2): + assert path1.stem == path2.stem + + # Merge to output + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + for path1, path2 in zip(video_paths1, video_paths2): + out_path = output_dir / f"{path1.stem}.mp4" + in_paths = [str(path1), str(path2)] + print(f"Merging {in_paths} to {out_path}") + if args.vertical: + os.system(f"python tools/video/merge_vertical.py {' '.join(in_paths)} -o {out_path}") + else: + os.system(f"python tools/video/merge_horizontal.py {' '.join(in_paths)} -o {out_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/video/merge_horizontal.py b/tools/video/merge_horizontal.py new file mode 100644 index 0000000..f06da6d --- /dev/null +++ b/tools/video/merge_horizontal.py @@ -0,0 +1,15 @@ +import argparse +from hmr4d.utils.video_io_utils import merge_videos_horizontal + + +def parse_args(): + """python tools/video/merge_horizontal.py a.mp4 b.mp4 c.mp4 -o out.mp4""" + parser = argparse.ArgumentParser() + parser.add_argument("input_videos", nargs="+", help="Input video paths") + parser.add_argument("-o", "--output", type=str, required=True, help="Output video path") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + merge_videos_horizontal(args.input_videos, args.output) diff --git a/tools/video/merge_vertical.py b/tools/video/merge_vertical.py new file mode 100644 index 0000000..8617ec1 --- /dev/null +++ b/tools/video/merge_vertical.py @@ -0,0 +1,15 @@ +import argparse +from hmr4d.utils.video_io_utils import merge_videos_vertical + + +def parse_args(): + """python tools/video/merge_vertical.py a.mp4 b.mp4 c.mp4 -o out.mp4""" + parser = argparse.ArgumentParser() + parser.add_argument("input_videos", nargs="+", help="Input video paths") + parser.add_argument("-o", "--output", type=str, required=True, help="Output video path") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + merge_videos_vertical(args.input_videos, args.output)
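The merge_videos_horizontal / merge_videos_vertical helpers used by the two scripts above come from hmr4d.utils.video_io_utils, which is not part of this diff. As an editor's illustration only, here is a minimal sketch of what such a helper could look like with ffmpeg-python (already listed in requirements.txt); the function name and defaults are hypothetical, not the repository's implementation:

import ffmpeg  # ffmpeg-python

def merge_videos_horizontal_sketch(input_paths, output_path):
    """Hypothetical stand-in: place N same-height videos side by side via ffmpeg's hstack filter."""
    streams = [ffmpeg.input(str(p)) for p in input_paths]
    stacked = ffmpeg.filter(streams, "hstack", inputs=len(streams))
    out = ffmpeg.output(stacked, str(output_path), vcodec="libx264", crf=23)
    out.overwrite_output().run()

# Usage (paths are placeholders):
# merge_videos_horizontal_sketch(["a.mp4", "b.mp4"], "out.mp4")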