-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy path: test_transpose.py
119 lines (93 loc) · 4.48 KB
/
test_transpose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class TestTranspose(PytorchLayerTest):
    """Conversion tests for aten::transpose and aten::swapaxes."""

    def _prepare_input(self):
        # One 4-D float32 tensor, so every dim index in [-4, 3] is valid.
        return (np.random.randn(2, 3, 4, 5).astype(np.float32),)

    def create_model(self, dim0, dim1, op_type):
        """Build a module swapping `dim0` and `dim1` via the op named by `op_type`.

        Returns the (model, ref_net, expected_op_name) triple consumed by
        `PytorchLayerTest._test`.
        """

        class _SwapaxesModule(torch.nn.Module):
            # To reproduce aten::swapaxes in the graph, the call has to live in
            # a separate submodule and the model has to be traced.
            def __init__(self, dim0, dim1):
                super().__init__()
                self.dim0 = dim0
                self.dim1 = dim1

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.swapaxes(x, self.dim0, self.dim1)

        class aten_transpose(torch.nn.Module):
            def __init__(self, dim0, dim1, op_type):
                super().__init__()
                self.dim0 = dim0
                self.dim1 = dim1
                self.swapaxes = _SwapaxesModule(dim0, dim1)
                # Dispatch table: bind the requested op as this instance's forward.
                dispatch = {
                    "transpose": self.forward_transpose,
                    "swapaxes": self.forward_swapaxes,
                }
                self.forward = dispatch.get(op_type)

            def forward_transpose(self, x):
                return torch.transpose(x, self.dim0, self.dim1)

            def forward_swapaxes(self, x: torch.Tensor) -> torch.Tensor:
                return self.swapaxes(x)

        return aten_transpose(dim0, dim1, op_type), None, f"aten::{op_type}"

    @pytest.mark.parametrize("dim0", [0, 1, 2, 3, -1, -2, -3, -4])
    @pytest.mark.parametrize("dim1", [0, 1, 2, 3, -1, -2, -3, -4])
    @pytest.mark.parametrize("op_type", ["transpose", "swapaxes"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_transpose(self, dim0, dim1, op_type, ie_device, precision, ir_version):
        self._test(*self.create_model(dim0, dim1, op_type), ie_device, precision, ir_version, trace_model=True)
class TestMoveDim(PytorchLayerTest):
    """Conversion tests for aten::movedim (single dims and dim lists)."""

    def _prepare_input(self):
        # One 4-D float32 tensor; all parametrized dim values index into it.
        return (np.random.randn(2, 3, 4, 5).astype(np.float32),)

    def create_model(self, dim0, dim1):
        """Build a module calling torch.movedim(x, dim0, dim1).

        `dim0`/`dim1` may be single ints or lists of ints, matching the
        movedim API. Returns (model, ref_net, expected_op_name) for `_test`.
        """

        class aten_move_dim(torch.nn.Module):
            def __init__(self, dim0, dim1):
                super(aten_move_dim, self).__init__()
                self.dim0 = dim0
                self.dim1 = dim1

            def forward(self, x):
                return torch.movedim(x, self.dim0, self.dim1)

        ref_net = None
        # Fixed op name; the original used an f-string with no placeholders (F541).
        return aten_move_dim(dim0, dim1), ref_net, "aten::movedim"

    @pytest.mark.parametrize(("dim0", "dim1"), [[0, 1], [-1, 0], [2, -2], [3, 1], [3, 3], [[1, 2], [3, 0]], [[-4, 1], [1, -1]], [[1, 3, 2], [0, 1, 2 ]]])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_move_dim(self, dim0, dim1, ie_device, precision, ir_version):
        self._test(*self.create_model(dim0, dim1), ie_device, precision, ir_version, trace_model=True)
class TestTSmall(PytorchLayerTest):
    """Conversion tests for aten::t, aten::t_ (in-place) and aten::numpy_T."""

    def _prepare_input(self, num_dims=2, input_dtype="float32"):
        base_shape = (2, 3)
        if num_dims == 0:
            # 0-d case: a scalar array (its value happens to equal num_dims, i.e. 0).
            return (np.array(num_dims).astype(input_dtype),)
        return (np.random.randn(*base_shape[:num_dims]).astype(input_dtype),)

    def create_model(self, mode):
        """Build a module exercising the transpose variant selected by `mode`.

        mode: None -> x.t(), "inplace" -> x.t_(), "numpy" -> x.T.
        Returns (model, ref_net, expected_op_name) for `_test`.
        """

        class aten_transpose(torch.nn.Module):
            def __init__(self, mode):
                super().__init__()
                # Rebind forward per mode; default forward covers mode=None.
                if mode == "inplace":
                    self.forward = self.forward_inplace
                elif mode == "numpy":
                    self.forward = self.forward_numpy_t

            def forward(self, x):
                return x.t(), x

            def forward_inplace(self, x):
                return x.t_(), x

            def forward_numpy_t(self, x):
                return x.T, x

        if mode == "inplace":
            op_name = "aten::t_"
        elif mode == "numpy":
            op_name = "aten::numpy_T"
        else:
            op_name = "aten::t"
        return aten_transpose(mode), None, op_name

    @pytest.mark.parametrize("num_dims", [0, 1, 2])
    @pytest.mark.parametrize("input_dtype", ["float32", "int32"])
    @pytest.mark.parametrize("mode", [None, "inplace", "numpy"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_t_small(self, num_dims, input_dtype, mode, ie_device, precision, ir_version):
        self._test(
            *self.create_model(mode),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"num_dims": num_dims, "input_dtype": input_dtype},
            use_convert_model=True,
        )