Skip to content

Commit

Permalink
Text fun
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jun 27, 2024
1 parent a55bce2 commit d05224e
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions tests/tests_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""Tests for the direct.utils module."""

import pathlib
Expand All @@ -9,7 +9,7 @@
import pytest
import torch

from direct.utils import is_power_of_two, normalize_image, remove_keys, set_all_seeds
from direct.utils import is_power_of_two, normalize_image, remove_keys, set_all_seeds, reshape_array_to_shape
from direct.utils.asserts import assert_complex
from direct.utils.bbox import crop_to_largest
from direct.utils.dataset import get_filenames_for_datasets_from_config
Expand Down Expand Up @@ -126,3 +126,20 @@ def test_normalize_image(shape, eps):
img = np.random.randn(*shape)
normalized_img = normalize_image(img, eps)
assert normalized_img.min() >= 0.0 and normalized_img.max() <= 1.0


@pytest.mark.parametrize(
"array, requested_shape, expected_shape",
[
(np.random.rand(4, 5), (4, 5, 1), (4, 5, 1)),
(np.random.rand(4, 5), (1, 4, 5, 1), (1, 4, 5, 1)),
(np.random.rand(2, 4, 5), (2, 4, 5, 1), (2, 4, 5, 1)),
(np.random.rand(3, 3), (1, 3, 1, 3, 1), (1, 3, 1, 3, 1)),
(np.random.rand(2, 3), (2, 1, 3), (2, 1, 3)),
(np.random.rand(4), (1, 1, 4, 1), (1, 1, 4, 1)),
(np.random.rand(6), (1, 6, 1), (1, 6, 1)),
]
)
def test_reshape_array_to_shape(array, requested_shape, expected_shape):
result = reshape_array_to_shape(array, requested_shape)
assert result.shape == expected_shape

0 comments on commit d05224e

Please sign in to comment.