Skip to content
Snippets Groups Projects
Commit 87ff3e9f authored by wusize's avatar wusize
Browse files

add comments

parent 05ad4b21
No related branches found
No related tags found
No related merge requests found
......@@ -130,11 +130,12 @@ class ProjectLayer(nn.Module):
@POSENETS.register_module()
class DetectAndRegress(BasePose):
"""VoxelPose Please refer to the `paper <https://arxiv.org/abs/2004.06239>`
for details.
"""DetectAndRegress approach for multiview human pose detection.
Args:
backbone (ConfigDict): Dictionary to construct the 2D pose detector
human_detector (ConfigDict): Dictionary to construct the human detector.
pose_regressor (ConfigDict): Dictionary to construct the pose regressor.
train_cfg (ConfigDict): Config for training. Default: None.
test_cfg (ConfigDict): Config for testing. Default: None.
pretrained (str): Path to the pretrained 2D model. Default: None.
......@@ -437,6 +438,19 @@ class DetectAndRegress(BasePose):
@POSENETS.register_module()
class VoxelSinglePose(BasePose):
"""VoxelPose Please refer to the `paper <https://arxiv.org/abs/2004.06239>`
for details.
Args:
image_size (list): input size of the 2D model.
heatmap_size (list): output size of the 2D model.
sub_space_size (list): Size of the cuboid human proposal.
sub_cube_size (list): Size of the input volume to the pose net.
pose_net (ConfigDict): Dictionary to construct the pose net.
pose_head (ConfigDict): Dictionary to construct the pose head.
train_cfg (ConfigDict): Config for training. Default: None.
test_cfg (ConfigDict): Config for testing. Default: None.
"""
def __init__(
self,
......@@ -483,6 +497,8 @@ class VoxelSinglePose(BasePose):
volume_height: cubeH
Args:
img (list(torch.Tensor[NxCximgHximgW])):
Multi-camera input images to the 2D model.
feature_maps (list(torch.Tensor[NxCxHxW])):
Multi-camera input feature_maps.
img_metas (list(dict)):
......@@ -492,9 +508,6 @@ class VoxelSinglePose(BasePose):
return_loss (bool): Option to `return_loss`. `return_loss=True`
for training, `return_loss=False` for validation & test.
Returns:
dict: losses.
"""
if return_loss:
return self.forward_train(img, img_metas, feature_maps,
......@@ -510,8 +523,7 @@ class VoxelSinglePose(BasePose):
human_candidates=None,
return_preds=False,
**kwargs):
"""Defines the computation performed at training."""
"""
"""Defines the computation performed at training.
Note:
batch_size: N
num_keypoints: K
......@@ -525,14 +537,15 @@ class VoxelSinglePose(BasePose):
volume_height: cubeH
Args:
img (list(torch.Tensor[NxCximgHximgW])):
Multi-camera input images to the 2D model.
feature_maps (list(torch.Tensor[NxCxHxW])):
Multi-camera input feature_maps.
img_metas (list(dict)):
Information about image, 3D groundtruth and camera parameters.
human_candidates (torch.Tensor[NxPx5]):
Human candidates.
return_loss: Option to `return loss`. `return loss=True`
for training, `return loss=False` for validation & test.
return_preds (bool): Whether to return prediction results
Returns:
dict: losses.
......@@ -611,8 +624,7 @@ class VoxelSinglePose(BasePose):
feature_maps=None,
human_candidates=None,
**kwargs):
"""Defines the computation performed at training."""
"""
"""Defines the computation performed at testing.
Note:
batch_size: N
num_keypoints: K
......@@ -626,14 +638,14 @@ class VoxelSinglePose(BasePose):
volume_height: cubeH
Args:
img (list(torch.Tensor[NxCximgHximgW])):
Multi-camera input images to the 2D model.
feature_maps (list(torch.Tensor[NxCxHxW])):
Multi-camera input feature_maps.
img_metas (list(dict)):
Information about image, 3D groundtruth and camera parameters.
human_candidates (torch.Tensor[NxPx5]):
Human candidates.
return_loss: Option to `return loss`. `return loss=True`
for training, `return loss=False` for validation & test.
Returns:
dict: predicted poses, human centers and sample_id
......@@ -677,6 +689,21 @@ class VoxelSinglePose(BasePose):
@POSENETS.register_module()
class VoxelCenterDetector(BasePose):
"""Detect human center by 3D CNN on voxels.
Please refer to the
`paper <https://arxiv.org/abs/2004.06239>` for details.
Args:
image_size (list): input size of the 2D model.
heatmap_size (list): output size of the 2D model.
space_size (list): Size of the 3D space.
cube_size (list): Size of the input volume to the 3D CNN.
space_center (list): Coordinate of the center of the 3D space.
center_net (ConfigDict): Dictionary to construct the center net.
center_head (ConfigDict): Dictionary to construct the center head.
train_cfg (ConfigDict): Config for training. Default: None.
test_cfg (ConfigDict): Config for testing. Default: None.
"""
def __init__(
self,
......@@ -729,6 +756,30 @@ class VoxelCenterDetector(BasePose):
return_loss=True,
feature_maps=None,
targets_3d=None):
"""
Note:
batch_size: N
num_keypoints: K
num_img_channel: C
img_width: imgW
img_height: imgH
heatmaps width: W
heatmaps height: H
Args:
img (list(torch.Tensor[NxCximgHximgW])):
Multi-camera input images to the 2D model.
img_metas (list(dict)):
Information about image, 3D groundtruth and camera parameters.
return_loss (bool): Option to `return_loss`. `return_loss=True`
for training, `return_loss=False` for validation & test.
targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
Ground-truth 3D heatmap of human centers.
feature_maps (list(torch.Tensor[NxKxHxW])):
Multi-camera feature_maps.
Returns:
dict: if 'return_loss' is true, then return losses.
Otherwise, return predicted human centers.
"""
if return_loss:
return self.forward_train(img, img_metas, feature_maps, targets_3d)
else:
......@@ -740,6 +791,29 @@ class VoxelCenterDetector(BasePose):
feature_maps=None,
targets_3d=None,
return_preds=False):
"""
Note:
batch_size: N
num_keypoints: K
num_img_channel: C
img_width: imgW
img_height: imgH
heatmaps width: W
heatmaps height: H
Args:
img (list(torch.Tensor[NxCximgHximgW])):
Multi-camera input images to the 2D model.
img_metas (list(dict)):
Information about image, 3D groundtruth and camera parameters.
targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
Ground-truth 3D heatmap of human centers.
feature_maps (list(torch.Tensor[NxKxHxW])):
Multi-camera feature_maps.
return_preds (bool): Whether to return prediction results
Returns:
dict: if 'return_preds' is true, then return losses
and human centers. Otherwise, return losses only.
"""
initial_cubes, _ = self.project_layer(feature_maps, img_metas,
self.space_size,
[self.space_center],
......@@ -771,6 +845,25 @@ class VoxelCenterDetector(BasePose):
return losses
def forward_test(self, img, img_metas, feature_maps=None):
"""
Note:
batch_size: N
num_keypoints: K
num_img_channel: C
img_width: imgW
img_height: imgH
heatmaps width: W
heatmaps height: H
Args:
img (list(torch.Tensor[NxCximgHximgW])):
Multi-camera input images to the 2D model.
img_metas (list(dict)):
Information about image, 3D groundtruth and camera parameters.
feature_maps (list(torch.Tensor[NxKxHxW])):
Multi-camera feature_maps.
Returns:
human centers
"""
initial_cubes, _ = self.project_layer(feature_maps, img_metas,
self.space_size,
[self.space_center],
......
......@@ -33,12 +33,9 @@ class CuboidCenterHead(nn.Module):
max_pool_kernel=3):
super(CuboidCenterHead, self).__init__()
# use register_buffer
self.register_buffer(
'grid_size', torch.tensor(space_size), persistent=False)
self.register_buffer(
'cube_size', torch.tensor(cube_size), persistent=False)
self.register_buffer(
'grid_center', torch.tensor(space_center), persistent=False)
self.register_buffer('grid_size', torch.tensor(space_size))
self.register_buffer('cube_size', torch.tensor(cube_size))
self.register_buffer('grid_center', torch.tensor(space_center))
self.num_candidates = max_num
self.max_pool_kernel = max_pool_kernel
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment