import pytest import numpy as np from numpy.testing import assert_array_equal def test_matrix_transpose_raises_error_for_1d(): msg = "matrix transpose with ndim < 2 is undefined" arr = np.arange(48) with pytest.raises(ValueError, match=msg): arr.mT def test_matrix_transpose_equals_transpose_2d(): arr = np.arange(48).reshape((6, 8)) assert_array_equal(arr.T, arr.mT) ARRAY_SHAPES_TO_TEST = ( (5, 2), (5, 2, 3), (5, 2, 3, 4), ) @pytest.mark.parametrize("shape", ARRAY_SHAPES_TO_TEST) def test_matrix_transpose_equals_swapaxes(shape): num_of_axes = len(shape) vec = np.arange(shape[-1]) arr = np.broadcast_to(vec, shape) tgt = np.swapaxes(arr, num_of_axes - 2, num_of_axes - 1) mT = arr.mT assert_array_equal(tgt, mT)