Skip to content

Commit

Permalink
refactor: remove torch dependency, only install torch for image metrics (
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlao committed Apr 14, 2024
1 parent 070a42a commit 593242b
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 125 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,21 @@ cd camtools
# Installation mode, if you want to use camtools only.
pip install .

# Dev mode, if you want to modify camtools on the fly.
# Editable mode, if you want to modify camtools on the fly.
pip install -e .

# Dev mode and dev dependencies, if you want to modify camtools and run tests.
# Editable mode and dev dependencies.
pip install -e .[dev]

# Help VSCode resolve imports when installed with editable mode.
# https://stackoverflow.com/a/76897706/1255535
pip install -e .[dev] --config-settings editable_mode=strict

# Enable torch-related features (e.g. computing image metrics)
pip install camtools[torch]

# Enable torch-related features in editable mode
pip install -e .[torch]
```

## Camera coordinate system
Expand Down
80 changes: 22 additions & 58 deletions camtools/convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import cv2
import numpy as np
import torch

from . import sanity
from . import convert
Expand Down Expand Up @@ -29,26 +28,15 @@ def pad_0001(array):
f"Expected array of shape (3, 4) or (N, 3, 4), but got {array.shape}."
)

if torch.is_tensor(array):
if array.ndim == 2:
bottom = torch.tensor([0, 0, 0, 1], dtype=array.dtype, device=array.device)
return torch.cat([array, bottom[None, :]], dim=0)
elif array.ndim == 3:
bottom_single = torch.tensor(
[0, 0, 0, 1], dtype=array.dtype, device=array.device
)
bottom = bottom_single[None, None, :].expand(array.shape[0], 1, 4)
return torch.cat([array, bottom], dim=-2)
if array.ndim == 2:
bottom = np.array([0, 0, 0, 1], dtype=array.dtype)
return np.concatenate([array, bottom[None, :]], axis=0)
elif array.ndim == 3:
bottom_single = np.array([0, 0, 0, 1], dtype=array.dtype)
bottom = np.broadcast_to(bottom_single, (array.shape[0], 1, 4))
return np.concatenate([array, bottom], axis=-2)
else:
if array.ndim == 2:
bottom = np.array([0, 0, 0, 1], dtype=array.dtype)
return np.concatenate([array, bottom[None, :]], axis=0)
elif array.ndim == 3:
bottom_single = np.array([0, 0, 0, 1], dtype=array.dtype)
bottom = np.broadcast_to(bottom_single, (array.shape[0], 1, 4))
return np.concatenate([array, bottom], axis=-2)
else:
raise ValueError("Should not reach here.")
raise ValueError("Should not reach here.")


def rm_pad_0001(array, check_vals=False):
Expand Down Expand Up @@ -78,42 +66,21 @@ def rm_pad_0001(array, check_vals=False):

# Check vals.
if check_vals:
if torch.is_tensor(array):
if array.ndim == 2:
bottom = array[3, :]
if not torch.allclose(
bottom, torch.tensor([0, 0, 0, 1], dtype=array.dtype)
):
raise ValueError(
f"Expected bottom row to be [0, 0, 0, 1], but got {bottom}."
)
elif array.ndim == 3:
bottom = array[:, 3:4, :]
expected_bottom = torch.tensor([0, 0, 0, 1], dtype=array.dtype).expand(
array.shape[0], 1, 4
if array.ndim == 2:
bottom = array[3, :]
if not np.allclose(bottom, [0, 0, 0, 1]):
raise ValueError(
f"Expected bottom row to be [0, 0, 0, 1], but got {bottom}."
)
elif array.ndim == 3:
bottom = array[:, 3:4, :]
expected_bottom = np.broadcast_to([0, 0, 0, 1], (array.shape[0], 1, 4))
if not np.allclose(bottom, expected_bottom):
raise ValueError(
f"Expected bottom row to be {expected_bottom}, but got {bottom}."
)
if not torch.allclose(bottom, expected_bottom):
raise ValueError(
f"Expected bottom row to be {expected_bottom}, but got {bottom}."
)
else:
raise ValueError("Should not reach here.")
else:
if array.ndim == 2:
bottom = array[3, :]
if not np.allclose(bottom, [0, 0, 0, 1]):
raise ValueError(
f"Expected bottom row to be [0, 0, 0, 1], but got {bottom}."
)
elif array.ndim == 3:
bottom = array[:, 3:4, :]
expected_bottom = np.broadcast_to([0, 0, 0, 1], (array.shape[0], 1, 4))
if not np.allclose(bottom, expected_bottom):
raise ValueError(
f"Expected bottom row to be {expected_bottom}, but got {bottom}."
)
else:
raise ValueError("Should not reach here.")
raise ValueError("Should not reach here.")

return array[..., :3, :]

Expand Down Expand Up @@ -363,10 +330,7 @@ def roll_pitch_yaw_to_R(roll, pitch, yaw):

def R_t_to_T(R, t):
sanity.assert_same_device(R, t)
if torch.is_tensor(R):
T = torch.eye(4, device=R.device, dtype=R.dtype)
else:
T = np.eye(4)
T = np.eye(4)
T[:3, :3] = R
T[:3, 3] = t
return T
Expand Down
5 changes: 3 additions & 2 deletions camtools/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity
import torch
import lpips
from pathlib import Path
from typing import Tuple

Expand Down Expand Up @@ -96,6 +94,9 @@ def image_lpips(
Returns:
LPIPS value in float.
"""
import torch
import lpips

if im_mask is None:
h, w = im_pd.shape[:2]
im_mask = np.ones((h, w), dtype=np.float32)
Expand Down
50 changes: 2 additions & 48 deletions camtools/sanity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import torch


def assert_numpy(x, name=None):
Expand All @@ -8,12 +7,6 @@ def assert_numpy(x, name=None):
raise ValueError(f"Expected{maybe_name} to be numpy array, but got {type(x)}.")


def assert_torch(x, name=None):
if not torch.is_tensor(x):
maybe_name = f" {name}" if name is not None else ""
raise ValueError(f"Expected{maybe_name} to be torch tensor, but got {type(x)}.")


def assert_K(K):
if K.shape != (3, 3):
raise ValueError(f"K must has shape (3, 3), but got {K} of shape {K.shape}.")
Expand All @@ -22,12 +15,7 @@ def assert_K(K):
def assert_T(T):
if T.shape != (4, 4):
raise ValueError(f"T must has shape (4, 4), but got {T} of shape {T.shape}.")
if torch.is_tensor(T):
is_valid = torch.allclose(
T[3, :], torch.tensor([0, 0, 0, 1], dtype=T.dtype, device=T.device)
)
else:
is_valid = np.allclose(T[3, :], np.array([0, 0, 0, 1]))
is_valid = np.allclose(T[3, :], np.array([0, 0, 0, 1]))
if not is_valid:
raise ValueError(f"T must has [0, 0, 0, 1] the bottom row, but got {T}.")

Expand All @@ -37,12 +25,7 @@ def assert_pose(pose):
raise ValueError(
f"pose must has shape (4, 4), but got {pose} of shape {pose.shape}."
)
if torch.is_tensor(pose):
is_valid = torch.allclose(
pose[3, :], torch.tensor([0, 0, 0, 1], dtype=pose.dtype, device=pose.device)
)
else:
is_valid = np.allclose(pose[3, :], np.array([0, 0, 0, 1]))
is_valid = np.allclose(pose[3, :], np.array([0, 0, 0, 1]))
if not is_valid:
raise ValueError(f"pose must has [0, 0, 0, 1] the bottom row, but got {pose}.")

Expand Down Expand Up @@ -101,32 +84,3 @@ def assert_shape_3x3(x, name=None):

def assert_shape_3(x, name=None):
assert_shape(x, (3,), name=name)


def assert_same_device(*tensors):
"""
Args:
tensors: list of tensors
"""
if not isinstance(tensors, tuple):
raise ValueError(f"Unknown input type: {type(tensors)}.")
if len(tensors) == 0:
return
if len(tensors) == 1:
if torch.is_tensor(tensors[0]) or isinstance(tensors[0], np.ndarray):
return
else:
raise ValueError(f"Unknown input type: {type(tensors)}.")

all_are_torch = all(torch.is_tensor(t) for t in tensors)
all_are_numpy = all(isinstance(t, np.ndarray) for t in tensors)

if not all_are_torch and not all_are_numpy:
raise ValueError(f"All tensors must be torch tensors or numpy arrays.")

if all_are_torch:
devices = [t.device for t in tensors]
if not all(devices[0] == d for d in devices):
raise ValueError(
f"All tensors must be on the same device, bui got {devices}."
)
19 changes: 6 additions & 13 deletions camtools/solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import torch

from camtools import sanity

Expand Down Expand Up @@ -115,15 +114,9 @@ def closest_points_of_line_pairs(src_os, src_ds, dst_os, dst_ds):
sanity.assert_shape_nx3(dst_os, "dst_os")
sanity.assert_shape_nx3(dst_ds, "dst_ds")

is_torch = torch.is_tensor(src_ds) and torch.is_tensor(dst_ds)
cross = torch.cross if is_torch else np.cross
norm = torch.linalg.norm if is_torch else np.linalg.norm
solve = torch.linalg.solve if is_torch else np.linalg.solve
stack = torch.stack if is_torch else np.stack

# Normalize direction vectors.
src_ds = src_ds / norm(src_ds, axis=1, keepdims=True)
dst_ds = dst_ds / norm(dst_ds, axis=1, keepdims=True)
src_ds = src_ds / np.linalg.norm(src_ds, axis=1, keepdims=True)
dst_ds = dst_ds / np.linalg.norm(dst_ds, axis=1, keepdims=True)

# Find the closest points of the two lines.
# - src_p = src_o + src_t * src_d is the closest point in src line.
Expand All @@ -139,12 +132,12 @@ def closest_points_of_line_pairs(src_os, src_ds, dst_os, dst_ds):
# │src_d -dst_d mid_d│ │ dst_t │ = │ dst_o │ - │ src_o │
# │ │ │ │ │ │ mid_t │ │ │ │ │ │ │
# └ ┘ └ ┘ └ ┘ └ ┘
mid_ds = cross(src_ds, dst_ds)
mid_ds = mid_ds / norm(mid_ds, axis=1, keepdims=True)
mid_ds = np.cross(src_ds, dst_ds)
mid_ds = mid_ds / np.linalg.norm(mid_ds, axis=1, keepdims=True)

lhs = stack((src_ds, -dst_ds, mid_ds), axis=-1)
lhs = np.stack((src_ds, -dst_ds, mid_ds), axis=-1)
rhs = dst_os - src_os
results = solve(lhs, rhs)
results = np.linalg.solve(lhs, rhs)
src_ts, dst_ts, mid_ts = results[:, 0], results[:, 1], results[:, 2]
src_ps = src_os + src_ts.reshape((-1, 1)) * src_ds
dst_ps = dst_os + dst_ts.reshape((-1, 1)) * dst_ds
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"numpy>=1.15.0",
"open3d>=0.16.0",
"opencv-python>=4.5.1.48",
"matplotlib>=3.3.4",
"scikit-image>=0.16.2",
"torch>=1.8.0",
"lpips>=0.1.4",
"tqdm>=4.60.0",
]
description = "CamTools: Camera Tools for Computer Vision."
Expand All @@ -37,6 +36,10 @@ dev = [
"pytest>=6.2.2",
"ipdb",
]
torch = [
"torch>=1.8.0",
"lpips>=0.1.4",
]

[tool.setuptools]
packages = ["camtools", "camtools.tools"]

0 comments on commit 593242b

Please sign in to comment.