diff --git a/dmsuite/poly_diff.py b/dmsuite/poly_diff.py index 7c9ec1e..a30fc41 100644 --- a/dmsuite/poly_diff.py +++ b/dmsuite/poly_diff.py @@ -23,7 +23,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from functools import cached_property +from functools import cached_property, lru_cache import numpy as np from numpy.typing import NDArray @@ -335,3 +335,27 @@ def _dmat(self) -> GeneralPoly: def at_order(self, order: int) -> NDArray: return self.scale**order * self._dmat.at_order(order) + + +@dataclass(frozen=True) +class DiffMatOnDomain(DiffMatrices): + """Differentiation matrices stretched and shifted to a different domain. + + The stretching and shifting is done linearly between xmin and xmax. + """ + + xmin: float + xmax: float + dmat: DiffMatrices + + @cached_property + def stretching(self) -> NDArray: + return (self.dmat.nodes[-1] - self.dmat.nodes[0]) / (self.xmax - self.xmin) + + @cached_property + def nodes(self) -> NDArray: + return (self.dmat.nodes - self.dmat.nodes[0]) / self.stretching + self.xmin + + @lru_cache + def at_order(self, order: int) -> NDArray: + return self.stretching**order * self.dmat.at_order(order) diff --git a/tests/test_chebdif.py b/tests/test_chebdif.py index 5b87b2f..f8d0795 100644 --- a/tests/test_chebdif.py +++ b/tests/test_chebdif.py @@ -1,6 +1,6 @@ import numpy as np -from dmsuite.poly_diff import Chebyshev +from dmsuite.poly_diff import Chebyshev, DiffMatOnDomain def test_chebdif4() -> None: @@ -12,3 +12,19 @@ def test_chebdif4() -> None: computed[order - 1] = cheb.at_order(order) assert np.allclose(cheb.nodes, expected[0]) assert np.allclose(computed, expected[1]) + + +def test_cheb_scaled() -> None: + dmat = DiffMatOnDomain(xmin=1.0, xmax=5.0, dmat=Chebyshev(degree=64)) + nodes = dmat.nodes + assert np.allclose(nodes[0], dmat.xmin) + assert np.allclose(nodes[-1], dmat.xmax) + func = nodes**2 + dfunc = 2 * nodes + d2func = 2.0 + d1_cheb = dmat.at_order(1) @ func + d2_cheb = dmat.at_order(2) @ func + d3_cheb = dmat.at_order(3) @ func + assert np.allclose(d1_cheb, dfunc) + assert np.allclose(d2_cheb, d2func) + assert np.allclose(d3_cheb, 0.0, atol=1e-6)