Skip to content

Commit

Permalink
[feat] util for torch to numpy pose conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
pdell-kitware committed Sep 7, 2021
1 parent 51f8ed8 commit d71af5b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
47 changes: 46 additions & 1 deletion slam/common/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional, Type

import torch
from torchvision.transforms.functional import to_tensor
Expand All @@ -7,6 +7,8 @@

import numpy as np

from slam.common.utils import assert_debug, check_sizes


def custom_to_tensor(data: Union[torch.Tensor, np.ndarray, dict],
device: Union[str, torch.device] = "cuda",
Expand Down Expand Up @@ -79,6 +81,49 @@ def send_to_device(data: Union[dict, torch.Tensor, np.ndarray],
return data


def convert_pose_transform(pose: Union[torch.Tensor, np.ndarray],
dest: type = torch.Tensor,
device: Optional[torch.device] = None,
dtype: Optional[Union[torch.dtype, np.number, Type]] = None):
"""Converts a [4, 4] pose tensor to the desired type
Returns a tensor (either a numpy.ndarray or torch.Tensor depending on dest type)
>>> check_sizes(convert_pose_transform(torch.eye(4).reshape(4, 4), np.ndarray), [4, 4])
>>> check_sizes(convert_pose_transform(torch.eye(4).reshape(1, 4, 4), np.ndarray, dtype=np.float32), [4, 4])
>>> check_sizes(convert_pose_transform(torch.eye(4).reshape(1, 4, 4), torch.Tensor, dtype=torch.float32), [1, 4, 4])
>>> check_sizes(convert_pose_transform(torch.eye(4).reshape(4, 4), torch.Tensor, dtype=torch.float32), [4, 4])
>>> check_sizes(convert_pose_transform(np.eye(4).reshape(4, 4), torch.Tensor, dtype=torch.float32), [4, 4])
>>> check_sizes(convert_pose_transform(np.eye(4).reshape(4, 4), np.ndarray, dtype=np.float32), [4, 4])
>>> check_sizes(convert_pose_transform(np.eye(4).reshape(4, 4), np.ndarray), [4, 4])
"""
# Check size
if isinstance(pose, torch.Tensor):
assert_debug(list(pose.shape) == [1, 4, 4] or list(pose.shape) == [4, 4],
f"Wrong tensor shape, expected [(1), 4, 4], got {pose.shape}")
if dest == torch.Tensor:
assert_debug(isinstance(dtype, torch.dtype), f"The dtype {dtype} is not a torch.dtype")
return pose.to(device=device if device is not None else pose.device,
dtype=dtype if dtype is not None else pose.dtype)
else:
assert_debug(dest == np.ndarray, "Only numpy.ndarray and torch.Tensor are supported as destination tensor")
np_array = pose.detach().cpu().numpy()
if dtype is not None:
assert_debug(issubclass(dtype, np.number), f"Expected a numpy.dtype, got {dtype}")
np_array = np_array.astype(dtype)
return np_array.reshape(4, 4)
else:
assert_debug(isinstance(pose, np.ndarray), f"Only numpy.ndarray and torch.Tensor are supported. Got {pose}.")
check_sizes(pose, [4, 4])
if dest == torch.Tensor:
tensor = torch.from_numpy(pose).to(dtype=dtype, device=device)
return tensor
if dtype is not None:
assert_debug(issubclass(dtype, np.number), f"Expected numpy.dtype, got {dtype}")
new_pose = pose.astype(dtype)
return new_pose
return pose


# ----------------------------------------------------------------------------------------------------------------------
# Collate Function
def collate_fun(batch) -> object:
Expand Down
6 changes: 5 additions & 1 deletion slam/odometry/icp_odometry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Project Imports
from typing import Optional

import numpy as np

from slam.common.geometry import projection_map_to_points, mask_not_null
from slam.common.pose import Pose
from slam.common.projection import Projector
from slam.common.torch_utils import convert_pose_transform
from slam.common.utils import check_sizes, remove_nan, modify_nan_pmap
from slam.common.modules import _with_viz3d
from slam.dataset import DatasetLoader
Expand Down Expand Up @@ -186,7 +189,8 @@ def do_process_next_frame(self, data_dict: dict):
if self._has_window:
# Add Ground truth poses (mainly for visualization purposes)
if DatasetLoader.absolute_gt_key() in data_dict:
pose_gt = data_dict[DatasetLoader.absolute_gt_key()].reshape(1, 4, 4).cpu().numpy()
pose_gt = convert_pose_transform(data_dict[DatasetLoader.absolute_gt_key()],
np.ndarray, dtype=np.float64)
self.gt_poses = pose_gt if self.gt_poses is None else np.concatenate(
[self.gt_poses, pose_gt], axis=0)

Expand Down

0 comments on commit d71af5b

Please sign in to comment.