diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2dc53ca
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. Here, the more nuclear
+# option of ignoring the entire .idea folder is used.
+.idea/
diff --git a/.gitmodules b/.gitmodules
index c6109a3..7be36f9 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
[submodule "models/detectors/yolo"]
- path = models/detectors/yolo
+ path = models_/detectors/yolo
url = https://github.com/eriklindernoren/PyTorch-YOLOv3
diff --git a/README.md b/README.md
index 4ca78dc..36aac97 100644
--- a/README.md
+++ b/README.md
@@ -15,17 +15,20 @@ This repository provides:
- A simple ``HRNet`` implementation in PyTorch (>=1.0) - compatible with official weights (``pose_hrnet_*``).
- A simple class (``SimpleHRNet``) that loads the HRNet network for the human pose estimation, loads the pre-trained weights,
and make human predictions on a single image or a batch of images.
-- **NEW** Support for "SimpleBaselines" model based on ResNet - compatible with official weights (``pose_resnet_*``).
-- **NEW** Support for multi-GPU inference.
-- **NEW** Add option for using YOLOv3-tiny (faster, but less accurate person detection).
-- **NEW** Add options for retrieving yolo bounding boxes and HRNet heatmaps.
-- Multi-person support with
+- Support for "SimpleBaselines" model based on ResNet - compatible with official weights (``pose_resnet_*``).
+- Support for multi-GPU inference.
+- Options for retrieving yolo bounding boxes and HRNet heatmaps.
+- **NEW** Multi-person support with
[YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3/tree/47b7c912877ca69db35b8af3a38d6522681b3bb3)
- (enabled by default).
+ (enabled by default), YOLOv3-tiny, or [YOLOv5](https://github.com/ultralytics/yolov5) by Ultralytics.
- A reference code that runs a live demo reading frames from a webcam or a video file.
- A relatively-simple code for training and testing the HRNet network.
- A specific script for training the network on the COCO dataset.
-- **NEW** A [Google Colab notebook](https://github.com/stefanopini/simple-HRNet/issues/84#issuecomment-908199736) showcasing how to use this repository - Sincere thanks to [@basicvisual](https://github.com/basicvisual) and [@wuyenlin](https://github.com/wuyenlin).
+- **NEW** An updated [Jupyter Notebook](https://github.com/stefanopini/simple-HRNet/blob/master/SimpleHRNet_notebook.ipynb) compatible with Google Colab showcasing how to use this repository.
-  [Click here](https://colab.research.google.com/github/stefanopini/simple-HRNet/blob/master/SimpleHRNet_notebook.ipynb) to open the notebook on Colab!
+  - [Click here](https://colab.research.google.com/github/stefanopini/simple-HRNet/blob/master/SimpleHRNet_notebook.ipynb) to open the notebook on Colab!
+ - Thanks to [@basicvisual](https://github.com/basicvisual) and [@wuyenlin](https://github.com/wuyenlin) for the initial notebook.
+- **NEW** Support for TensorRT (thanks to [@gpastal24](https://github.com/gpastal24), see [#99](https://github.com/stefanopini/simple-HRNet/pull/99) and [#100](https://github.com/stefanopini/simple-HRNet/pull/100)).
+
If you are interested in **HigherHRNet**, please look at [*simple-HigherHRNet*](https://github.com/stefanopini/simple-HigherHRNet)
@@ -113,6 +116,24 @@ For help:
python scripts/extract-keypoints.py --help
```
+### Converting the model to TensorRT
+
+Warning: this requires the installation of TensorRT (see the Nvidia website) and onnx, as well as
+[torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) (installed from source in the provided notebook).
+On some platforms, TensorRT and onnx can be installed with
+```
+pip install tensorrt onnx
+```
+
+Converting the model in FP16 (half precision):
+```
+python scripts/export-tensorrt-model.py --half
+```
+
+For help:
+```
+python scripts/export-tensorrt-model.py --help
+```
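+
+Running the live demo with the converted model (a minimal example, assuming the engine was exported with the
+default options above to `./weights/hrnet_trt.engine`; add `--hrnet_c` and `--image_resolution` if you exported
+a different variant):
+```
+python scripts/live-demo.py --enable_tensorrt --hrnet_weights ./weights/hrnet_trt.engine
+```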
+
### Running the training script
```
diff --git a/SimpleHRNet.py b/SimpleHRNet.py
index bd4e407..cdca44f 100644
--- a/SimpleHRNet.py
+++ b/SimpleHRNet.py
@@ -1,11 +1,12 @@
+import os
+
import cv2
import numpy as np
import torch
from torchvision.transforms import transforms
-from models.hrnet import HRNet
-from models.poseresnet import PoseResNet
-# from models.detectors.YOLOv3 import YOLOv3 # import only when multi-person is enabled
+from models_.hrnet import HRNet
+from models_.poseresnet import PoseResNet
class SimpleHRNet:
@@ -28,10 +29,12 @@ def __init__(self,
return_heatmaps=False,
return_bounding_boxes=False,
max_batch_size=32,
- yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
- yolo_class_path="./models/detectors/yolo/data/coco.names",
- yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
- device=torch.device("cpu")):
+ yolo_version='v3',
+ yolo_model_def="./models_/detectors/yolo/config/yolov3.cfg",
+ yolo_class_path="./models_/detectors/yolo/data/coco.names",
+ yolo_weights_path="./models_/detectors/yolo/weights/yolov3.weights",
+ device=torch.device("cpu"),
+ enable_tensorrt=False):
"""
Initializes a new SimpleHRNet object.
HRNet (and YOLOv3) are initialized on the torch.device("device") and
@@ -59,14 +62,23 @@ def __init__(self,
max_batch_size (int): maximum batch size used in hrnet inference.
Useless without multiperson=True.
Default: 16
- yolo_model_def (str): path to yolo model definition file.
- Default: "./models/detectors/yolo/config/yolov3.cfg"
- yolo_class_path (str): path to yolo class definition file.
- Default: "./models/detectors/yolo/data/coco.names"
- yolo_weights_path (str): path to yolo pretrained weights file.
- Default: "./models/detectors/yolo/weights/yolov3.weights.cfg"
+ yolo_version (str): version of YOLO. Supported versions: `v3`, `v5`. Used when multiperson is True.
+ Default: "v3"
+ yolo_model_def (str): path to yolo model definition file. Recommended values:
+ - `./models_/detectors/yolo/config/yolov3.cfg` if yolo_version is 'v3'
+ - `./models_/detectors/yolo/config/yolov3-tiny.cfg` if yolo_version is 'v3', to use tiny yolo
+ - yolov5 model name if yolo_version is 'v5', e.g. `yolov5m` (medium), `yolov5n` (nano)
+            - `yolov5m.engine` if yolo_version is 'v5', to load a custom model (e.g. a TensorRT engine)
+ Default: "./models_/detectors/yolo/config/yolov3.cfg"
+ yolo_class_path (str): path to yolov3 class definition file.
+ Default: "./models_/detectors/yolo/data/coco.names"
+ yolo_weights_path (str): path to yolov3 pretrained weights file.
+            Default: "./models_/detectors/yolo/weights/yolov3.weights"
device (:class:`torch.device`): the hrnet (and yolo) inference will be run on this device.
Default: torch.device("cpu")
+        enable_tensorrt (bool): Enables TensorRT inference for HRNet.
+ If enabled, a `.engine` file is expected as `checkpoint_path`.
+ Default: False
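+
+        Example (a minimal usage sketch mirroring the Colab notebook; it assumes the COCO pre-trained
+        weights have been downloaded to ./weights/):
+            model = SimpleHRNet(48, 17, "./weights/pose_hrnet_w48_384x288.pth",
+                                yolo_version='v5', yolo_model_def='yolov5m', device=torch.device("cuda"))
+            joints = model.predict(image)  # image: ndarray with shape (height, width, 3)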
"""
self.c = c
@@ -79,13 +91,20 @@ def __init__(self,
self.return_heatmaps = return_heatmaps
self.return_bounding_boxes = return_bounding_boxes
self.max_batch_size = max_batch_size
+ self.yolo_version = yolo_version
self.yolo_model_def = yolo_model_def
self.yolo_class_path = yolo_class_path
self.yolo_weights_path = yolo_weights_path
self.device = device
+ self.enable_tensorrt = enable_tensorrt
if self.multiperson:
- from models.detectors.YOLOv3 import YOLOv3
+ if self.yolo_version == 'v3':
+ from models_.detectors.YOLOv3 import YOLOv3
+ elif self.yolo_version == 'v5':
+ from models_.detectors.YOLOv5 import YOLOv5
+ else:
+                raise ValueError('Unsupported YOLO version.')
if model_name in ('HRNet', 'hrnet'):
self.model = HRNet(c=c, nof_joints=nof_joints)
@@ -94,32 +113,38 @@ def __init__(self,
else:
raise ValueError('Wrong model name.')
- checkpoint = torch.load(checkpoint_path, map_location=self.device)
- if 'model' in checkpoint:
- self.model.load_state_dict(checkpoint['model'])
- else:
- self.model.load_state_dict(checkpoint)
+ if not self.enable_tensorrt:
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
+ if 'model' in checkpoint:
+ self.model.load_state_dict(checkpoint['model'])
+ else:
+ self.model.load_state_dict(checkpoint)
- if 'cuda' in str(self.device):
- print("device: 'cuda' - ", end="")
+ if 'cuda' in str(self.device):
+ print("device: 'cuda' - ", end="")
- if 'cuda' == str(self.device):
- # if device is set to 'cuda', all available GPUs will be used
- print("%d GPU(s) will be used" % torch.cuda.device_count())
- device_ids = None
+ if 'cuda' == str(self.device):
+ # if device is set to 'cuda', all available GPUs will be used
+ print("%d GPU(s) will be used" % torch.cuda.device_count())
+ device_ids = None
+ else:
+ # if device is set to 'cuda:IDS', only that/those device(s) will be used
+ print("GPU(s) '%s' will be used" % str(self.device))
+ device_ids = [int(x) for x in str(self.device)[5:].split(',')]
+
+ self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
+ elif 'cpu' == str(self.device):
+ print("device: 'cpu'")
else:
- # if device is set to 'cuda:IDS', only that/those device(s) will be used
- print("GPU(s) '%s' will be used" % str(self.device))
- device_ids = [int(x) for x in str(self.device)[5:].split(',')]
+ raise ValueError('Wrong device name.')
- self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
- elif 'cpu' == str(self.device):
- print("device: 'cpu'")
+ self.model = self.model.to(device)
+ self.model.eval()
else:
- raise ValueError('Wrong device name.')
-
- self.model = self.model.to(device)
- self.model.eval()
+ from torch2trt import TRTModule
+ self.model = TRTModule()
+ self.model.load_state_dict(torch.load(checkpoint_path))
+ self.model.cuda().eval()
if not self.multiperson:
self.transform = transforms.Compose([
@@ -128,12 +153,17 @@ def __init__(self,
])
else:
- self.detector = YOLOv3(model_def=yolo_model_def,
- class_path=yolo_class_path,
- weights_path=yolo_weights_path,
- classes=('person',),
- max_batch_size=self.max_batch_size,
- device=device)
+ if self.yolo_version == 'v3':
+ self.detector = YOLOv3(model_def=yolo_model_def,
+ class_path=yolo_class_path,
+ weights_path=yolo_weights_path,
+ classes=('person',),
+ max_batch_size=self.max_batch_size,
+ device=device)
+ else:
+ self.detector = YOLOv5(model_def=yolo_model_def,
+ device=device)
+
self.transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((self.resolution[0], self.resolution[1])), # (height, width)
@@ -196,10 +226,10 @@ def _predict_single(self, image):
else:
detections = self.detector.predict_single(image)
-
nof_people = len(detections) if detections is not None else 0
boxes = np.empty((nof_people, 4), dtype=np.int32)
- images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1])) # (height, width)
+ # boxes = torch.empty((nof_people, 4),device=self.device)
+ images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1]), device=self.device) # (height, width)
heatmaps = np.zeros((nof_people, self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4),
dtype=np.float32)
@@ -212,21 +242,41 @@ def _predict_single(self, image):
# Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14)
correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1)
+
+ # Using padding instead of bbox enlargement, this should reduce cross-person keypoint detection
if correction_factor > 1:
# increase y side
center = y1 + (y2 - y1) // 2
length = int(round((y2 - y1) * correction_factor))
- y1 = max(0, center - length // 2)
- y2 = min(image.shape[0], center + length // 2)
+ x1_new = x1
+ x2_new = x2
+ y1_new = int(center - length // 2)
+ y2_new = int(center + length // 2)
+ pad = (int(abs(y1_new - y1))), int(abs(y2_new - y2))
+ pad_tuple = (pad, (0, 0), (0, 0))
+
elif correction_factor < 1:
- # increase x side
center = x1 + (x2 - x1) // 2
length = int(round((x2 - x1) * 1 / correction_factor))
- x1 = max(0, center - length // 2)
- x2 = min(image.shape[1], center + length // 2)
-
- boxes[i] = [x1, y1, x2, y2]
- images[i] = self.transform(image[y1:y2, x1:x2, ::-1])
+ x1_new = int(center - length // 2)
+ x2_new = int(center + length // 2)
+ y1_new = y1
+ y2_new = y2
+ pad = (abs(x1_new - x1)), int(abs(x2_new - x2))
+ pad_tuple = ((0, 0), pad, (0, 0))
+ else:
+ x1_new = x1
+ x2_new = x2
+ y1_new = y1
+ y2_new = y2
+ pad_tuple = None
+
+ image_crop = image[y1:y2, x1:x2, ::-1]
+ if pad_tuple is not None:
+ image_crop = np.pad(image_crop, pad_tuple)
+ images[i] = self.transform(image_crop)
+ boxes[i] = [x1_new, y1_new, x2_new, y2_new]
+ # boxes[i] = torch.tensor([x1_new, y1_new, x2_new, y2_new])
if images.shape[0] > 0:
images = images.to(self.device)
@@ -257,6 +307,26 @@ def _predict_single(self, image):
pts[i, j, 1] = pt[1] * 1. / (self.resolution[1] // 4) * (boxes[i][2] - boxes[i][0]) + boxes[i][0]
pts[i, j, 2] = joint[pt]
+ # # Torch alternative, it could be faster
+ # pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32,device=self.device)
+ # # For each human, for each joint: y, x, confidence
+ # (b, indices) = torch.max(out, dim=2)
+ # (b, indices) = torch.max(b, dim=2)
+ #
+ # (c, indicesc) = torch.max(out, dim=3)
+ # (c, indicesc) = torch.max(c, dim=2)
+ # dims = (self.resolution[0]//4, self.resolution[1]//4)
+ # dim1 = torch.tensor(1. / dims[0], device=self.device)
+ # dim2 = torch.tensor(1. / dims[1], device=self.device)
+ #
+ # for i in range(0, out.shape[0]):
+ # pts[i, :, 0] = indicesc[i, :] * dim1 * (boxes[i][3] - boxes[i][1]) + boxes[i][1]
+ # pts[i, :, 1] = indices[i, :] * dim2 * (boxes[i][2] - boxes[i][0]) + boxes[i][0]
+ # pts[i, :, 2] = c[i, :]
+ #
+ # pts = pts.cpu().numpy()
+ # boxes = boxes.cpu().numpy()
+
else:
pts = np.empty((0, 0, 3), dtype=np.float32)
@@ -321,6 +391,8 @@ def _predict_batch(self, images):
# Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14)
correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1)
+
+ # TODO Use padding instead of bbox enlargement here too
if correction_factor > 1:
# increase y side
center = y1 + (y2 - y1) // 2
diff --git a/SimpleHRNet_notebook.ipynb b/SimpleHRNet_notebook.ipynb
new file mode 100644
index 0000000..6833c63
--- /dev/null
+++ b/SimpleHRNet_notebook.ipynb
@@ -0,0 +1,537 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "collapsed_sections": [
+ "x1P13nZeR3Xj",
+ "HqHg_VATg6CO",
+ "ZWUN1C5RgGYS"
+ ]
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "gpuClass": "standard"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Simple HRNet\n",
+ "This is a light Google Colab notebook showing how to use the [simple-HRNet](https://github.com/stefanopini/simple-HRNet) repository.\n",
+ "\n",
+ "It includes the conversion to TensorRT and a test of the converted model.\n",
+ "Please skip the section \"TensorRT\" if not interested.\n",
+ "\n",
+ "Initial idea of running on Google Colab by @basicvisual, initial implementation by @wuyenlin (see [issue #84](https://github.com/stefanopini/simple-HRNet/issues/84))."
+ ],
+ "metadata": {
+ "id": "xZqqnmmNfX1d"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Pytorch"
+ ],
+ "metadata": {
+ "id": "ZFihjwzqhA04"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Clone the repo and install the dependencies"
+ ],
+ "metadata": {
+ "id": "X_ugGAxdd6Hu"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# clone the repo\n",
+ "!git clone https://github.com/stefanopini/simple-HRNet.git"
+ ],
+ "metadata": {
+ "id": "FIecXpzEY7IJ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd simple-HRNet\n",
+ "!pwd"
+ ],
+ "metadata": {
+ "id": "JDNRl8a8dl7Z"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# install requirements\n",
+ "!pip install -r requirements.txt"
+ ],
+ "metadata": {
+ "id": "FGsHqGPNdbHt"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# install vlc to get video codecs\n",
+ "!apt install vlc"
+ ],
+ "metadata": {
+ "id": "qMynH2IPebr8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Add yolov3\n",
+        "Clone yolov3 for multi-person support. This can be skipped for single-person applications or if you plan to use YOLO v5 by Ultralytics."
+ ],
+ "metadata": {
+ "id": "x1P13nZeR3Xj"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# download git submodules\n",
+ "!git submodule update --init --recursive"
+ ],
+ "metadata": {
+ "id": "yqf7BRGWRtUV"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content/simple-HRNet/models_/detectors/yolo\n",
+ "!pip install -q -r requirements.txt\n",
+ "\n",
+ "%cd /content/simple-HRNet"
+ ],
+ "metadata": {
+ "id": "vS9cz49gSJeG"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content/simple-HRNet/models_/detectors/yolo/weights\n",
+ "!sh download_weights.sh\n",
+ "\n",
+ "%cd /content/simple-HRNet"
+ ],
+ "metadata": {
+ "id": "8v-RpWGwSM7V"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Download HRNet pre-trained weights and test video\n",
+ "\n",
+ "Download any of the supported official weights listed [here](https://github.com/stefanopini/simple-HRNet/#installation-instructions).\n",
+ "\n",
+ "In the following, we download the weights `pose_hrnet_w48_384x288.pth` from the official Drive link.\n",
+ "Download of other weights (e.g. `pose_hrnet_w32_256x192.pth`) as well as weights from private Drives is supported too."
+ ],
+ "metadata": {
+ "id": "HqHg_VATg6CO"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade --no-cache-dir gdown"
+ ],
+ "metadata": {
+ "id": "pKFdWLLUXyZu"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# download weights\n",
+ "\n",
+ "# create weights folder\n",
+ "%cd /content/simple-HRNet\n",
+ "!mkdir weights\n",
+ "%cd /content/simple-HRNet/weights\n",
+ "\n",
+ "# download weights pose_hrnet_w48_384x288.pth\n",
+ "!gdown 1UoJhTtjHNByZSm96W3yFTfU5upJnsKiS\n",
+ "\n",
+ "# download weights pose_hrnet_w32_256x192.pth\n",
+ "!gdown 1zYC7go9EV0XaSlSBjMaiyE_4TcHc_S38\n",
+ "\n",
+ "# download weights pose_hrnet_w32_256x256.pth\n",
+ "!gdown 1_wn2ifmoQprBrFvUCDedjPON4Y6jsN-v\n",
+ "\n",
+ "# # download weights from your own Google Drive\n",
+ "# from glob import glob\n",
+ "# from google.colab import drive\n",
+ "# drive.mount('/content/drive')\n",
+ "# w_list = glob(\"/content/drive//*.pth\")\n",
+ "# if not w_list:\n",
+ "# raise FileNotFoundError(\"You haven't downloaded any pre-trained weights!\")\n",
+ "\n",
+ "%cd /content/simple-HRNet"
+ ],
+ "metadata": {
+ "id": "3LURZ12cfCcU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# download a publicly available video (or just get your own)\n",
+ "!wget https://commondatastorage.googleapis.com/gtv-videos-bucket/sample/WeAreGoingOnBullrun.mp4"
+ ],
+ "metadata": {
+ "id": "OLIrIc14eUPM"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Test the API\n"
+ ],
+ "metadata": {
+ "id": "vcv0B2P7UTxT"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import cv2\n",
+ "import requests\n",
+ "import matplotlib.pyplot as plt\n",
+ "import torch\n",
+ "from skimage import io\n",
+ "from PIL import Image\n",
+ "from SimpleHRNet import SimpleHRNet\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "# # singleperson, COCO weights\n",
+ "# model = SimpleHRNet(48, 17, \"./weights/pose_hrnet_w48_384x288.pth\", multiperson=False, device=device)\n",
+ "\n",
+ "# # multiperson w/ YOLOv3, COCO weights\n",
+ "# model = SimpleHRNet(48, 17, \"./weights/pose_hrnet_w48_384x288.pth\", device=device)\n",
+ "\n",
+ "# # multiperson w/ YOLOv3, COCO weights, small model\n",
+ "# model = SimpleHRNet(32, 17, \"./weights/pose_hrnet_w32_256x192.pth\", device=device)\n",
+ "\n",
+ "# # multiperson w/ YOLOv3, MPII weights\n",
+ "# model = SimpleHRNet(32, 16, \"./weights/pose_hrnet_w32_256x256.pth\", device=device)\n",
+ "\n",
+ "# # multiperson w/ YOLOv5 (medium), COCO weights\n",
+ "# model = SimpleHRNet(48, 17, \"./weights/pose_hrnet_w48_384x288.pth\", yolo_version='v5', yolo_model_def='yolov5m', device=device)\n",
+ "\n",
+ "# multiperson w/ YOLOv5 nano, COCO weights, small model\n",
+ "model = SimpleHRNet(32, 17, \"./weights/pose_hrnet_w32_256x192.pth\", yolo_version='v5', yolo_model_def='yolov5n', device=device)\n",
+ "\n",
+ "url = 'http://images.cocodataset.org/val2017/000000097278.jpg'\n",
+ "im = Image.open(requests.get(url, stream=True).raw)\n",
+ "image = io.imread(url)\n",
+ "\n",
+ "joints = model.predict(image)"
+ ],
+ "metadata": {
+ "id": "xCXrjhfJUR5C"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline\n",
+ "from misc.visualization import joints_dict\n",
+ "\n",
+ "def plot_joints(ax, output):\n",
+ " bones = joints_dict()[\"coco\"][\"skeleton\"]\n",
+ " # bones = joints_dict()[\"mpii\"][\"skeleton\"]\n",
+ "\n",
+ " for bone in bones:\n",
+ " xS = [output[:,bone[0],1], output[:,bone[1],1]]\n",
+ " yS = [output[:,bone[0],0], output[:,bone[1],0]]\n",
+ " ax.plot(xS, yS, linewidth=3, c=(0,0.3,0.7))\n",
+ " ax.scatter(joints[:,:,1],joints[:,:,0], s=20, c='r')\n",
+ "\n",
+ "fig = plt.figure(figsize=(60/2.54, 30/2.54))\n",
+ "ax = fig.add_subplot(121)\n",
+ "ax.imshow(Image.open(requests.get(url, stream=True).raw))\n",
+ "ax = fig.add_subplot(122)\n",
+ "ax.imshow(Image.open(requests.get(url, stream=True).raw))\n",
+ "plot_joints(ax, joints)\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "aYNkSzCGUqMF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Test the live script\n",
+ "This step can be skipped if interested in the TensorRT conversion."
+ ],
+ "metadata": {
+ "id": "ZWUN1C5RgGYS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# # test the live script with default params (multiperson with yolo v3)\n",
+ "# !python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video\n",
+ "\n",
+ "# # test the live script with tiny yolo (v3)\n",
+ "# !python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video --use_tiny_yolo\n",
+ "\n",
+ "# # test the live script with yolo v5\n",
+ "# !python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video --yolo_version v5\n",
+ "\n",
+ "# test the live script with tiny yolo v5 (tensorrt yolo v5)\n",
+ "!python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video --yolo_version v5 --use_tiny_yolo"
+ ],
+ "metadata": {
+ "id": "VEPfVe2bg1dS"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Now check out the video output.avi\n"
+ ],
+ "metadata": {
+ "id": "RsTTv7A5gGvF"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## TensorRT\n",
+        "This section installs TensorRT 8.5, converts the model to TensorRT (.engine) and tests the converted model.\n",
+ "\n",
+        "Tested with TensorRT 8.5.1-1+cuda11.8 and the python package tensorrt 8.5.1.7."
+ ],
+ "metadata": {
+ "id": "YHj3FQEyf1yD"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Install TensorRT\n",
+ "A GPU is needed for this step. Please change the runtime type to \"GPU\".\n"
+ ],
+ "metadata": {
+ "id": "VsFWYxaNc-gl"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LsAlxRGVXhrt"
+ },
+ "outputs": [],
+ "source": [
+ "# check a GPU runtime is selected\n",
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%bash\n",
+ "wget https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/nvidia-machine-learning-repo-ubuntu1804_1.0.0-1_amd64.deb\n",
+ "\n",
+ "dpkg -i nvidia-machine-learning-repo-*.deb\n",
+ "apt-get update\n",
+ "\n",
+ "sudo apt-get install libnvinfer8 python3-libnvinfer"
+ ],
+ "metadata": {
+ "id": "9vZ35qN5XkHE"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# check TensorRT version\n",
+ "print(\"TensorRT version: \")\n",
+ "!dpkg -l | grep nvinfer"
+ ],
+ "metadata": {
+ "id": "GlGh_J2WYH8u"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# install TensorRT for python\n",
+ "!pip install tensorrt"
+ ],
+ "metadata": {
+ "id": "nhzVoykoYAWJ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "# clone the conversion tool torch2trt\n",
+ "%cd /content\n",
+ "!git clone https://github.com/NVIDIA-AI-IOT/torch2trt"
+ ],
+ "metadata": {
+ "id": "NUR0P_HklFbz"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# install torch2trt\n",
+ "%cd /content/torch2trt\n",
+ "!python setup.py install"
+ ],
+ "metadata": {
+ "id": "Y97nln2AX35c"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content/simple-HRNet"
+ ],
+ "metadata": {
+ "id": "UC-xqiy5X5vk"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+        "### Export the model with TensorRT"
+ ],
+ "metadata": {
+ "id": "I2u6Xn72eEBE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Convert the smaller HRNet model to TensorRT - it may take a while...\n",
+ "!python scripts/export-tensorrt-model.py --half \\\n",
+ " --weights \"./weights/pose_hrnet_w32_256x192.pth\" --hrnet_c 32 --image_resolution '(256, 192)'"
+ ],
+ "metadata": {
+ "id": "S57JsLacdnoF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### [Optional] Export yolov5 with TensorRT"
+ ],
+ "metadata": {
+ "id": "ckdDXNJzmxt_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Optional - Convert yolov5 (nano) to tensorrt too\n",
+ "!python /root/.cache/torch/hub/ultralytics_yolov5_master/export.py --weights yolov5n.pt --include engine --device 0 --half"
+ ],
+ "metadata": {
+ "id": "3Hls1HlCl44F"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+        "### Test the TensorRT model"
+ ],
+ "metadata": {
+ "id": "npgGj4cGemZd"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Run inference with the converted TensorRT model\n",
+ "!python scripts/live-demo.py --enable_tensorrt --filename=WeAreGoingOnBullrun.mp4 --hrnet_weights='weights/hrnet_trt.engine' \\\n",
+ " --hrnet_c 32 --image_resolution \"(256, 192)\" --yolo_version v5 --use_tiny_yolo --save_video\n"
+ ],
+ "metadata": {
+ "id": "LnIpbqV0fVps"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Now check out the video output.avi\n"
+ ],
+ "metadata": {
+ "id": "WbQk0PeNnM5-"
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/datasets/LiveCamera.py b/datasets/LiveCamera.py
index fe8daed..3c2dc03 100644
--- a/datasets/LiveCamera.py
+++ b/datasets/LiveCamera.py
@@ -3,7 +3,7 @@
import torch
from torch.utils.data import Dataset
from torchvision import transforms
-from models.detectors.YOLOv3 import YOLOv3
+from models_.detectors.YOLOv3 import YOLOv3
class LiveCameraDataset(Dataset):
@@ -27,9 +27,9 @@ def __init__(self, camera_id=0, epoch_length=1, resolution=(384, 288), interpola
])
else:
- self.detector = YOLOv3(model_def="./models/detectors/yolo/config/yolov3.cfg",
- class_path="./models/detectors/yolo/data/coco.names",
- weights_path="./models/detectors/yolo/weights/yolov3.weights",
+ self.detector = YOLOv3(model_def="./models_/detectors/yolo/config/yolov3.cfg",
+ class_path="./models_/detectors/yolo/data/coco.names",
+ weights_path="./models_/detectors/yolo/weights/yolov3.weights",
classes=('person',), device=device)
self.transform = transforms.Compose([
diff --git a/misc/utils.py b/misc/utils.py
index b425da9..08fff48 100644
--- a/misc/utils.py
+++ b/misc/utils.py
@@ -360,7 +360,11 @@ def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
if in_vis_thre is not None:
ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
e = e[ind]
+
+        e = e[e <= 2**32 - 1]
+
ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
+
return ious
diff --git a/models/__init__.py b/models_/__init__.py
similarity index 100%
rename from models/__init__.py
rename to models_/__init__.py
diff --git a/models/detectors/YOLOv3.py b/models_/detectors/YOLOv3.py
similarity index 97%
rename from models/detectors/YOLOv3.py
rename to models_/detectors/YOLOv3.py
index 63380f3..62d7959 100644
--- a/models/detectors/YOLOv3.py
+++ b/models_/detectors/YOLOv3.py
@@ -7,7 +7,7 @@
import torch
from torchvision.transforms import transforms
-sys.path.append(os.path.join(os.getcwd(), 'models', 'detectors', 'yolo'))
+sys.path.append(os.path.join(os.getcwd(), 'models_', 'detectors', 'yolo'))
from .yolo.models import Darknet
from .yolo.utils.utils import load_classes, non_max_suppression
@@ -29,10 +29,10 @@ def letterbox(img, new_shape=416, color=(127.5, 127.5, 127.5), mode='auto'):
ratio = max(new_shape) / max(shape) # ratio = new / old
new_unpad = (int(round(shape[1] * ratio)), int(round(shape[0] * ratio)))
- if mode is 'auto': # minimum rectangle
+ if mode == 'auto': # minimum rectangle
dw = np.mod(new_shape - new_unpad[0], 32) / 2 # width padding
dh = np.mod(new_shape - new_unpad[1], 32) / 2 # height padding
- elif mode is 'square': # square
+ elif mode == 'square': # square
dw = (new_shape - new_unpad[0]) / 2 # width padding
dh = (new_shape - new_unpad[1]) / 2 # height padding
else:
diff --git a/models_/detectors/YOLOv5.py b/models_/detectors/YOLOv5.py
new file mode 100644
index 0000000..68e13f0
--- /dev/null
+++ b/models_/detectors/YOLOv5.py
@@ -0,0 +1,103 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+
+
+# from https://github.com/ultralytics/yolov5
+def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+ # Resize and pad image while meeting stride-multiple constraints
+ shape = im.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
+ elif scaleFill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return im, ratio, (dw, dh)
+
+
+class YOLOv5:
+ def __init__(self,
+ model_def='',
+ model_folder='./models_/detectors/yolov5',
+ image_resolution=(640, 640),
+ conf_thres=0.3,
+ device=torch.device('cpu')):
+
+ self.model_def = model_def
+ self.model_folder = model_folder
+ self.image_resolution = image_resolution
+ self.conf_thres = conf_thres
+ self.device = device
+ self.trt_model = self.model_def.endswith('.engine')
+
+ # Set up model
+ if self.trt_model:
+ # if the yolo model ends with 'engine', it is loaded as a custom YOLOv5 pre-trained model
+ print(f"Loading custom yolov5 model {self.model_def}")
+ self.model = torch.hub.load('ultralytics/yolov5', 'custom', self.model_def)
+ else:
+ # load the pre-trained YOLOv5 in a pre-defined folder
+ if not os.path.exists(self.model_folder):
+ os.makedirs(self.model_folder)
+ self.model = torch.hub.load('ultralytics/yolov5', self.model_def, pretrained=True)
+
+ self.model = self.model.to(self.device)
+ self.model.eval() # Set in evaluation mode
+
+ def predict_single(self, image, color_mode='BGR'):
+ image = image.copy()
+ if self.trt_model:
+ # when running with TensorRT, the image must have fixed size
+ image, (ratiow, ratioh), (dw, dh) = letterbox(image, self.image_resolution, stride=self.model.stride,
+ auto=False, scaleFill=False) # padded resize
+
+ if color_mode == 'BGR':
+ # all YOLO models expect RGB
+ # See https://github.com/ultralytics/yolov5/issues/9913#issuecomment-1290736061 and
+ # https://github.com/ultralytics/yolov5/blob/8ca182613499c323a411f559b7b5ea072122c897/models/common.py#L662
+ image = image[..., ::-1]
+
+ with torch.no_grad():
+ detections = self.model(image)
+
+ detections = detections.xyxy[0]
+ detections = detections[detections[:, 4] >= self.conf_thres]
+
+ detections = detections[detections[:, 5] == 0.] # person
+
+ # adding a fake class confidence to maintain compatibility with YOLOv3
+ detections = torch.cat((detections[:, :5], detections[:, 4:5], detections[:, 5:]), dim=1)
+
+ if self.trt_model:
+ # account for the image resize fixing the xyxy locations
+ detections[:, [0, 2]] = (detections[:, [0, 2]] - dw) / ratiow
+ detections[:, [1, 3]] = (detections[:, [1, 3]] - dh) / ratioh
+
+ return detections
+
+ def predict(self, images, color_mode='BGR'):
+ raise NotImplementedError("Not currently supported.")
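+
+
+# Minimal usage sketch (illustrative only; assumes network access for torch.hub and an OpenCV-style BGR image.
+# The file name below is just a placeholder):
+#   detector = YOLOv5(model_def='yolov5n', device=torch.device('cuda'))
+#   detections = detector.predict_single(cv2.imread('people.jpg'))
+#   # each row: x1, y1, x2, y2, confidence, (duplicated) class confidence, class id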
diff --git a/models/detectors/yolo b/models_/detectors/yolo
similarity index 100%
rename from models/detectors/yolo
rename to models_/detectors/yolo
diff --git a/models/hrnet.py b/models_/hrnet.py
similarity index 99%
rename from models/hrnet.py
rename to models_/hrnet.py
index 3a079ce..830992d 100644
--- a/models/hrnet.py
+++ b/models_/hrnet.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
-from models.modules import BasicBlock, Bottleneck
+from models_.modules import BasicBlock, Bottleneck
class StageModule(nn.Module):
diff --git a/models/modules.py b/models_/modules.py
similarity index 100%
rename from models/modules.py
rename to models_/modules.py
diff --git a/models/poseresnet.py b/models_/poseresnet.py
similarity index 98%
rename from models/poseresnet.py
rename to models_/poseresnet.py
index 96b04c5..6e4dd70 100644
--- a/models/poseresnet.py
+++ b/models_/poseresnet.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
-from models.modules import BasicBlock, Bottleneck
+from models_.modules import BasicBlock, Bottleneck
resnet_spec = {
diff --git a/scripts/export-tensorrt-model.py b/scripts/export-tensorrt-model.py
new file mode 100644
index 0000000..e2d68f5
--- /dev/null
+++ b/scripts/export-tensorrt-model.py
@@ -0,0 +1,53 @@
+import argparse
+import ast
+import os
+import sys
+
+import torch
+from torch2trt import torch2trt, TRTModule
+
+sys.path.insert(1, os.getcwd())
+from models_.hrnet import HRNet
+
+
+def convert_to_trt(args):
+ """
+ TensorRT conversion function for the HRNet models using torch2trt.
+ Requires the definition of the image resolution and the max batch size, supports FP16 mode (half precision).
+ """
+ pose = HRNet(args.hrnet_c, 17)
+
+ pose.load_state_dict(torch.load(args.weights))
+ pose.cuda().eval()
+
+ image_resolution = ast.literal_eval(args.image_resolution)
+ x = torch.ones(1, 3, image_resolution[0], image_resolution[1]).cuda()
+ print("Starting conversion to TensorRT with torch2trt...")
+ net_trt = torch2trt(pose, [x], max_batch_size=args.batch_size, fp16_mode=args.half)
+ torch.save(net_trt.state_dict(), args.output_path)
+ print(f"Conversion to TensorRT completed! Model saved at {args.output_path}")
+
+
+def parse_opt():
+ """Parses the arguments for the trt conversion."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--weights", "-w", help="the model weights file", type=str,
+ default='./weights/pose_hrnet_w48_384x288.pth')
+ parser.add_argument("--hrnet_c", "-c", help="HRNet channels, either 32 or 48 (default)", type=int, default=48)
+ parser.add_argument("--hrnet_j", "-j", help="HRNet number of joints, 17 (default)", type=int, default=17)
+ parser.add_argument("--image_resolution", "-r", help="image resolution, 256x192 or 384x288 (default)", type=str,
+ default="(384, 288)")
+ parser.add_argument("--batch_size", "-b", help="maximum batch size for trt", type=int, default=16)
+ parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
+ parser.add_argument("--output_path", help="output path, default ./weights/hrnet_trt.engine", type=str,
+ default="./weights/hrnet_trt.engine")
+ return parser.parse_args()
+
+
+def main():
+ args = parse_opt()
+ convert_to_trt(args)
+
+
+if __name__ == '__main__':
+ main()
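+
+
+# Example invocation (see also the README section "Converting the model to TensorRT"; it assumes the
+# pre-trained weights have been downloaded to ./weights/):
+#   python scripts/export-tensorrt-model.py --half --weights ./weights/pose_hrnet_w32_256x192.pth \
+#       --hrnet_c 32 --image_resolution "(256, 192)"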
diff --git a/scripts/extract-keypoints.py b/scripts/extract-keypoints.py
index 2a4a04c..0cd5361 100644
--- a/scripts/extract-keypoints.py
+++ b/scripts/extract-keypoints.py
@@ -13,8 +13,9 @@
from misc.visualization import check_video_rotation
-def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resolution, single_person, use_tiny_yolo,
- max_batch_size, csv_output_filename, csv_delimiter, json_output_filename, device):
+def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resolution, single_person, yolo_version,
+ use_tiny_yolo, max_batch_size, csv_output_filename, csv_delimiter, json_output_filename, device,
+ enable_tensorrt):
if device is not None:
device = torch.device(device)
else:
@@ -43,14 +44,28 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol
fd = open(json_output_filename, 'wt')
json_data = {}
- if use_tiny_yolo:
- yolo_model_def = "./models/detectors/yolo/config/yolov3-tiny.cfg"
- yolo_class_path = "./models/detectors/yolo/data/coco.names"
- yolo_weights_path = "./models/detectors/yolo/weights/yolov3-tiny.weights"
+ if yolo_version == 'v3':
+ if use_tiny_yolo:
+ yolo_model_def = "./models_/detectors/yolo/config/yolov3-tiny.cfg"
+ yolo_weights_path = "./models_/detectors/yolo/weights/yolov3-tiny.weights"
+ else:
+ yolo_model_def = "./models_/detectors/yolo/config/yolov3.cfg"
+ yolo_weights_path = "./models_/detectors/yolo/weights/yolov3.weights"
+ yolo_class_path = "./models_/detectors/yolo/data/coco.names"
+ elif yolo_version == 'v5':
+ # YOLOv5 comes in different sizes: n(ano), s(mall), m(edium), l(arge), x(large)
+ if use_tiny_yolo:
+ yolo_model_def = "yolov5n" # this is the nano version
+ else:
+ yolo_model_def = "yolov5m" # this is the medium version
+ if enable_tensorrt:
+ yolo_trt_filename = yolo_model_def + ".engine"
+ if os.path.exists(yolo_trt_filename):
+ yolo_model_def = yolo_trt_filename
+ yolo_class_path = ""
+ yolo_weights_path = ""
else:
- yolo_model_def = "./models/detectors/yolo/config/yolov3.cfg"
- yolo_class_path = "./models/detectors/yolo/data/coco.names"
- yolo_weights_path = "./models/detectors/yolo/weights/yolov3.weights"
+        raise ValueError('Unsupported YOLO version.')
model = SimpleHRNet(
hrnet_c,
@@ -60,18 +75,23 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol
resolution=image_resolution,
multiperson=not single_person,
max_batch_size=max_batch_size,
+ yolo_version=yolo_version,
yolo_model_def=yolo_model_def,
yolo_class_path=yolo_class_path,
yolo_weights_path=yolo_weights_path,
- device=device
+ device=device,
+ enable_tensorrt=enable_tensorrt
)
index = 0
+ t_start = time.time()
while True:
t = time.time()
ret, frame = video.read()
if not ret:
+ t_end = time.time()
+ print("\n Total Time: ", t_end - t_start)
break
if rotation_code is not None:
frame = cv2.rotate(frame, rotation_code)
@@ -123,7 +143,7 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol
type=str, default=None)
parser.add_argument("--filename", "-f", help="open the specified video",
type=str, default=None)
- parser.add_argument("--hrnet_m", "-m", help="network model - HRNet or PoseResNet", type=str, default='HRNet')
+ parser.add_argument("--hrnet_m", "-m", help="network model - 'HRNet' or 'PoseResNet'", type=str, default='HRNet')
parser.add_argument("--hrnet_c", "-c", help="hrnet parameters - number of channels (if model is HRNet), "
"resnet size (if model is PoseResNet)", type=int, default=48)
parser.add_argument("--hrnet_j", "-j", help="hrnet parameters - number of joints", type=int, default=17)
@@ -134,8 +154,13 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol
help="disable the multiperson detection (YOLOv3 or an equivalen detector is required for"
"multiperson detection)",
action="store_true")
+ parser.add_argument("--yolo_version",
+ help="Use the specified version of YOLO. Supported versions: `v3` (default), `v5`.",
+ type=str, default="v3")
parser.add_argument("--use_tiny_yolo",
- help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection). Ignored if --single_person",
+                        help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection) if `yolo_version` is `v3`. "
+                             "Use YOLOv5n(ano) in place of YOLOv5m(edium) if `yolo_version` is `v5`. "
+ "Ignored if --single_person",
action="store_true")
parser.add_argument("--max_batch_size", help="maximum batch size used for inference", type=int, default=16)
parser.add_argument("--csv_output_filename", help="filename of the csv that will be written.",
@@ -148,5 +173,11 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol
"set to `cuda:IDS` to use one or more specific GPUs "
"(e.g. `cuda:0` `cuda:1,2`); "
"set to `cpu` to run on cpu.", type=str, default=None)
+ parser.add_argument("--enable_tensorrt",
+                        help="Enables TensorRT inference for HRNet. If enabled, a `.engine` file is expected as "
+ "weights (`--hrnet_weights`). This option should be used only after the HRNet engine "
+ "file has been generated using the script `scripts/export-tensorrt-model.py`.",
+ action='store_true')
+
args = parser.parse_args()
main(**args.__dict__)
diff --git a/scripts/live-demo.py b/scripts/live-demo.py
index da4bca1..6bc6c34 100644
--- a/scripts/live-demo.py
+++ b/scripts/live-demo.py
@@ -13,9 +13,10 @@
from misc.visualization import draw_points, draw_skeleton, draw_points_and_skeleton, joints_dict, check_video_rotation
from misc.utils import find_person_id_associations
+
def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_joints_set, image_resolution,
- single_person, use_tiny_yolo, disable_tracking, max_batch_size, disable_vidgear, save_video, video_format,
- video_framerate, device):
+ single_person, yolo_version, use_tiny_yolo, disable_tracking, max_batch_size, disable_vidgear, save_video,
+ video_format, video_framerate, device, enable_tensorrt):
if device is not None:
device = torch.device(device)
else:
@@ -43,14 +44,28 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
else:
video = CamGear(camera_id).start()
- if use_tiny_yolo:
- yolo_model_def="./models/detectors/yolo/config/yolov3-tiny.cfg"
- yolo_class_path="./models/detectors/yolo/data/coco.names"
- yolo_weights_path="./models/detectors/yolo/weights/yolov3-tiny.weights"
+ if yolo_version == 'v3':
+ if use_tiny_yolo:
+ yolo_model_def = "./models_/detectors/yolo/config/yolov3-tiny.cfg"
+ yolo_weights_path = "./models_/detectors/yolo/weights/yolov3-tiny.weights"
+ else:
+ yolo_model_def = "./models_/detectors/yolo/config/yolov3.cfg"
+ yolo_weights_path = "./models_/detectors/yolo/weights/yolov3.weights"
+ yolo_class_path = "./models_/detectors/yolo/data/coco.names"
+ elif yolo_version == 'v5':
+ # YOLOv5 comes in different sizes: n(ano), s(mall), m(edium), l(arge), x(large)
+ if use_tiny_yolo:
+ yolo_model_def = "yolov5n" # this is the nano version
+ else:
+ yolo_model_def = "yolov5m" # this is the medium version
+ if enable_tensorrt:
+ yolo_trt_filename = yolo_model_def + ".engine"
+ if os.path.exists(yolo_trt_filename):
+ yolo_model_def = yolo_trt_filename
+ yolo_class_path = ""
+ yolo_weights_path = ""
else:
- yolo_model_def="./models/detectors/yolo/config/yolov3.cfg"
- yolo_class_path="./models/detectors/yolo/data/coco.names"
- yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights"
+        raise ValueError('Unsupported YOLO version.')
model = SimpleHRNet(
hrnet_c,
@@ -61,10 +76,12 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
multiperson=not single_person,
return_bounding_boxes=not disable_tracking,
max_batch_size=max_batch_size,
+ yolo_version=yolo_version,
yolo_model_def=yolo_model_def,
yolo_class_path=yolo_class_path,
yolo_weights_path=yolo_weights_path,
- device=device
+ device=device,
+ enable_tensorrt=enable_tensorrt
)
if not disable_tracking:
@@ -72,13 +89,15 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
prev_pts = None
prev_person_ids = None
next_person_id = 0
-
+ t_start = time.time()
while True:
t = time.time()
if filename is not None or disable_vidgear:
ret, frame = video.read()
if not ret:
+ t_end = time.time()
+ print("\n Total Time: ", t_end - t_start)
break
if rotation_code is not None:
frame = cv2.rotate(frame, rotation_code)
@@ -118,8 +137,11 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
points_color_palette='gist_rainbow', skeleton_color_palette='jet',
points_palette_samples=10)
+ # for box in boxes:
+ # cv2.rectangle(frame,(box[0],box[1]),(box[2],box[3]),(255,255,255),2)
+
fps = 1. / (time.time() - t)
- print('\rframerate: %f fps' % fps, end='')
+ print('\rframerate: %f fps, for %d person(s) ' % (fps,len(pts)), end='')
if has_display:
cv2.imshow('frame.png', frame)
@@ -162,8 +184,13 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
help="disable the multiperson detection (YOLOv3 or an equivalen detector is required for"
"multiperson detection)",
action="store_true")
+ parser.add_argument("--yolo_version",
+ help="Use the specified version of YOLO. Supported versions: `v3` (default), `v5`.",
+ type=str, default="v3")
parser.add_argument("--use_tiny_yolo",
- help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection). Ignored if --single_person",
+                        help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection) if `yolo_version` is `v3`. "
+                             "Use YOLOv5n(ano) in place of YOLOv5m(edium) if `yolo_version` is `v5`. "
+ "Ignored if --single_person",
action="store_true")
parser.add_argument("--disable_tracking",
help="disable the skeleton tracking and temporal smoothing functionality",
@@ -174,12 +201,18 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
action="store_true") # see https://pypi.org/project/vidgear/
parser.add_argument("--save_video", help="save output frames into a video.", action="store_true")
parser.add_argument("--video_format", help="fourcc video format. Common formats: `MJPG`, `XVID`, `X264`."
- "See http://www.fourcc.org/codecs.php", type=str, default='MJPG')
+ "See http://www.fourcc.org/codecs.php", type=str, default='MJPG')
parser.add_argument("--video_framerate", help="video framerate", type=float, default=30)
parser.add_argument("--device", help="device to be used (default: cuda, if available)."
"Set to `cuda` to use all available GPUs (default); "
"set to `cuda:IDS` to use one or more specific GPUs "
"(e.g. `cuda:0` `cuda:1,2`); "
"set to `cpu` to run on cpu.", type=str, default=None)
+ parser.add_argument("--enable_tensorrt",
+                        help="Enables TensorRT inference for HRNet. If enabled, a `.engine` file is expected as "
+ "weights (`--hrnet_weights`). This option should be used only after the HRNet engine "
+ "file has been generated using the script `scripts/export-tensorrt-model.py`.",
+ action='store_true')
+
args = parser.parse_args()
main(**args.__dict__)
diff --git a/testing/Test.py b/testing/Test.py
index 263e3b5..6d3ad60 100644
--- a/testing/Test.py
+++ b/testing/Test.py
@@ -10,7 +10,7 @@
from misc.checkpoint import load_checkpoint
from misc.utils import flip_tensor, flip_back
from misc.visualization import save_images
-from models.hrnet import HRNet
+from models_.hrnet import HRNet
class Test(object):
diff --git a/training/Train.py b/training/Train.py
index badbd03..2c4ea00 100644
--- a/training/Train.py
+++ b/training/Train.py
@@ -13,7 +13,7 @@
from misc.checkpoint import save_checkpoint, load_checkpoint
from misc.utils import flip_tensor, flip_back
from misc.visualization import save_images
-from models.hrnet import HRNet
+from models_.hrnet import HRNet
class Train(object):