Skip to content

Commit

Permalink
Added optika.transforms module and corresponding tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
byrdie committed Aug 1, 2023
1 parent 4cde209 commit d381637
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 0 deletions.
1 change: 1 addition & 0 deletions optika/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import transforms
Empty file added optika/_tests/__init__.py
Empty file.
117 changes: 117 additions & 0 deletions optika/_tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import pytest
import astropy.units as u
import numpy as np
import named_arrays as na
import optika


class AbstractTestAbstractTransform:

def test_matrix(self, transform: optika.transforms.AbstractTransform):
assert isinstance(transform.matrix, na.AbstractCartesian3dMatrixArray)
assert isinstance(transform.matrix.to(u.dimensionless_unscaled), na.AbstractCartesian3dMatrixArray)

def test_vector(self, transform: optika.transforms.AbstractTransform):
assert isinstance(transform.vector, na.AbstractCartesian3dVectorArray)
assert isinstance(transform.vector.to(u.mm), na.AbstractCartesian3dVectorArray)

def test__call__(self, transform: optika.transforms.AbstractTransform):
x = na.Cartesian3dVectorArray(x=1, y=-2, z=3) * u.m
y = transform(x)
assert isinstance(y, na.AbstractCartesian3dVectorArray)

def test_inverse(self, transform: optika.transforms.AbstractTransform):
x = na.Cartesian3dVectorArray(x=1, y=-2, z=3) * u.m
y = transform(x)
z = transform.inverse(y)
assert np.allclose(x, z)


@pytest.mark.parametrize(
argnames="transform",
argvalues=[
optika.transforms.Translation(na.Cartesian3dVectorArray() * u.mm),
optika.transforms.Translation(na.Cartesian3dVectorArray(1, -2, 3) * u.mm)
]
)
class TestTranslation(
AbstractTestAbstractTransform,
):
pass


class AbstractTestAbstractRotation(
AbstractTestAbstractTransform,
):
pass


@pytest.mark.parametrize(
argnames="transform",
argvalues=[
optika.transforms.RotationX(0 * u.deg),
optika.transforms.RotationX(45 * u.deg),
optika.transforms.RotationX(90 * u.deg),
optika.transforms.RotationX(223 * u.deg),
]
)
class TestRotationX(
AbstractTestAbstractRotation
):
pass


@pytest.mark.parametrize(
argnames="transform",
argvalues=[
optika.transforms.RotationY(0 * u.deg),
optika.transforms.RotationY(45 * u.deg),
optika.transforms.RotationY(90 * u.deg),
optika.transforms.RotationY(223 * u.deg),
]
)
class TestRotationY(
AbstractTestAbstractRotation
):
pass


@pytest.mark.parametrize(
argnames="transform",
argvalues=[
optika.transforms.RotationZ(0 * u.deg),
optika.transforms.RotationZ(45 * u.deg),
optika.transforms.RotationZ(90 * u.deg),
optika.transforms.RotationZ(223 * u.deg),
]
)
class TestRotationZ(
AbstractTestAbstractRotation
):
pass


@pytest.mark.parametrize(
argnames='transform',
argvalues=[
optika.transforms.TransformList([
optika.transforms.Translation(na.Cartesian3dVectorArray(x=2) * u.m),
optika.transforms.RotationZ(90 * u.deg),
optika.transforms.Translation(na.Cartesian3dVectorArray(x=2) * u.m),
optika.transforms.RotationY(90 * u.deg),
optika.transforms.Translation(na.Cartesian3dVectorArray(x=2) * u.m),
])
]
)
class TestTransformList(
AbstractTestAbstractTransform,
):

def test__call__(self, transform: optika.transforms.TransformList): # type: ignore[override]
super().test__call__(transform=transform)
x = na.Cartesian3dVectorArray() * u.m
b = transform(x)
c = x
for t in transform.transforms:
c = t(c)
assert b == c
154 changes: 154 additions & 0 deletions optika/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from __future__ import annotations
from typing import Iterator
from typing_extensions import Self
import abc
import dataclasses
import copy
import astropy.units as u
import named_arrays as na
import optika.mixins

__all__ = [
'AbstractTransform',
'Translation',
'AbstractRotation',
'RotationX',
'RotationY',
'RotationZ',
'TransformList',
'Transformable',
]


class AbstractTransform(
abc.ABC,
):

@property
def matrix(self) -> na.AbstractCartesian3dMatrixArray:
return na.Cartesian3dIdentityMatrixArray()

@property
def vector(self) -> na.AbstractCartesian3dVectorArray:
return na.Cartesian3dVectorArray() * u.mm

def __call__(
self,
value: na.AbstractCartesian3dVectorArray,
rotate: bool = True,
translate: bool = True,
) -> na.AbstractCartesian3dVectorArray:
if rotate:
value = self.matrix @ value
if translate:
value = value + self.vector
return value

@abc.abstractmethod
def __invert__(self: Self) -> Self:
pass

@property
def inverse(self: Self) -> Self:
return self.__invert__()


@dataclasses.dataclass
class Translation(AbstractTransform):
displacement: na.Cartesian3dVectorArray = dataclasses.MISSING

@property
def vector(self) -> na.Cartesian3dVectorArray:
return self.displacement

def __invert__(self: Self) -> Self:
return type(self)(displacement=-self.displacement)


@dataclasses.dataclass
class AbstractRotation(AbstractTransform):
angle: na.ScalarLike

def __invert__(self: Self) -> Self:
return type(self)(angle=-self.angle)


@dataclasses.dataclass
class RotationX(AbstractRotation):

@property
def matrix(self) -> na.Cartesian3dXRotationMatrixArray:
return na.Cartesian3dXRotationMatrixArray(self.angle)


@dataclasses.dataclass
class RotationY(AbstractRotation):

@property
def matrix(self) -> na.Cartesian3dYRotationMatrixArray:
return na.Cartesian3dYRotationMatrixArray(self.angle)


@dataclasses.dataclass
class RotationZ(AbstractRotation):

@property
def matrix(self) -> na.Cartesian3dZRotationMatrixArray:
return na.Cartesian3dZRotationMatrixArray(self.angle)


@dataclasses.dataclass
class TransformList(
AbstractTransform,
optika.mixins.DataclassList,
):

intrinsic: bool = True

@property
def extrinsic(self) -> bool:
return not self.intrinsic

@property
def transforms(self) -> Iterator[AbstractTransform]:
if self.intrinsic:
return reversed(list(self))
else:
return iter(self)

@property
def matrix(self) -> na.Cartesian3dMatrixArray:
rotation = na.Cartesian3dIdentityMatrixArray()

for transform in reversed(list(self.transforms)):
if transform is not None:
rotation = rotation @ transform.matrix

return rotation

@property
def vector(self) -> na.Cartesian3dVectorArray:
rotation = na.Cartesian3dIdentityMatrixArray()
translation = 0

for transform in reversed(list(self.transforms)):
if transform is not None:
rotation = rotation @ transform.matrix
translation = rotation @ transform.vector + translation

return translation

def __invert__(self: Self) -> Self:
other = copy.copy(self)
other.data = []
for transform in self:
if transform is not None:
transform = transform.__invert__()
other.append(transform)
other.reverse()
return other


@dataclasses.dataclass
class Transformable:
transform: TransformList = dataclasses.field(default_factory=TransformList)

0 comments on commit d381637

Please sign in to comment.