|
1 |
| -# Copyright (C) 2018-2025 Intel Corporation |
2 |
| -# SPDX-License-Identifier: Apache-2.0 |
3 |
| - |
4 | 1 | import pytest
|
5 | 2 | import torch
|
6 | 3 | import numpy as np
|
7 | 4 | from pytorch_layer_test_class import PytorchLayerTest, flattenize_inputs
|
8 |
| -from copy import deepcopy |
9 | 5 |
|
10 | 6 | class TestRandperm(PytorchLayerTest):
|
11 | 7 | def _prepare_input(self):
|
12 |
| - return () |
| 8 | + return (np.array([self.n], dtype=np.int64),) |
13 | 9 |
|
14 |
| - def create_model(self, n): |
15 |
| - class AtenRandperm(torch.nn.Module): |
16 |
| - def __init__(self, n): |
| 10 | + def create_model(self, n, num_inputs, dtype_value=None): |
| 11 | + class aten_randperm(torch.nn.Module): |
| 12 | + def __init__(self, n, num_inputs, dtype_value): |
17 | 13 | super().__init__()
|
18 |
| - self.n = n |
19 |
| - |
20 |
| - def forward(self): |
21 |
| - return torch.randperm(self.n, dtype=torch.int64) |
22 |
| - |
23 |
| - return AtenRandperm(n), None, "aten::randperm" |
24 |
| - |
25 |
| - def is_valid_permutation(self, output, n): |
26 |
| - if hasattr(output, 'detach'): |
27 |
| - arr = output.detach().cpu().numpy().astype(np.int64) |
28 |
| - else: |
29 |
| - arr = np.array(output, dtype=np.int64) |
30 |
| - sorted_arr = np.sort(arr.flatten()) |
31 |
| - expected = np.arange(n, dtype=np.int64) |
32 |
| - return np.array_equal(sorted_arr, expected) |
33 |
| - |
34 |
| - @pytest.mark.parametrize("n", [1, 5, 10]) |
| 14 | + self.n = torch.tensor(n, dtype=torch.int64) |
| 15 | + self.num_inputs = num_inputs |
| 16 | + self.dtype = torch.int64 if dtype_value == 4 else None |
| 17 | + |
| 18 | + def forward(self, x): |
| 19 | + if self.num_inputs == 1: |
| 20 | + return torch.randperm(self.n) |
| 21 | + elif self.num_inputs == 2: |
| 22 | + return torch.randperm(self.n, dtype=self.dtype) |
| 23 | + elif self.num_inputs == 5: |
| 24 | + return torch.randperm(self.n, dtype=self.dtype, layout=torch.strided, |
| 25 | + device=x.device, pin_memory=False) |
| 26 | + raise ValueError("Invalid num_inputs") |
| 27 | + |
| 28 | + return aten_randperm(n, num_inputs, dtype_value), None, "aten::randperm" |
| 29 | + |
| 30 | + @pytest.mark.parametrize(("n", "num_inputs", "dtype_value"), [ |
| 31 | + (0, 1, None), |
| 32 | + (1, 1, None), |
| 33 | + (5, 1, None), |
| 34 | + (5, 2, 4), |
| 35 | + (5, 5, 4), |
| 36 | + ]) |
35 | 37 | @pytest.mark.nightly
|
36 | 38 | @pytest.mark.precommit
|
37 |
| - def test_randperm_custom(self, n, ie_device, precision, ir_version): |
38 |
| - model, ref_net, op = self.create_model(n) |
| 39 | + def test_randperm(self, n, num_inputs, dtype_value, ie_device, precision, ir_version): |
| 40 | + self.n = n |
| 41 | + model, ref_net, op = self.create_model(n, num_inputs, dtype_value) |
39 | 42 | inputs = self._prepare_input()
|
40 | 43 | torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs]
|
41 | 44 | ov_inputs = flattenize_inputs(inputs)
|
42 |
| - trace_model = True |
43 |
| - dynamic_shapes = True |
44 |
| - freeze_model = True |
45 |
| - |
46 |
| - with torch.no_grad(): |
47 |
| - smodel, converted_model = self.convert_directly_via_frontend( |
48 |
| - model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model |
49 |
| - ) |
50 |
| - |
51 |
| - from openvino import Core |
52 |
| - core = Core() |
53 |
| - compiled_model = core.compile_model(converted_model, ie_device). |
54 |
| - ov_output_dict = compiled_model(()) |
55 |
| - ov_output_tensor = list(ov_output_dict.values())[0] |
56 |
| - |
57 |
| - assert ov_output_tensor.shape[0] == n, f"Output shape {ov_output_tensor.shape} does not match expected ({n},)" |
58 |
| - assert self.is_valid_permutation(ov_output_tensor, n), ( |
59 |
| - f"Output {ov_output_tensor} is not a valid permutation of [0, 1, ..., {n-1}]" |
| 45 | + smodel, converted_model = self.convert_directly_via_frontend( |
| 46 | + model, torch_inputs, trace_model=True, dynamic_shapes=False, ov_inputs=ov_inputs, freeze_model=True |
60 | 47 | )
|
61 |
| - |
62 |
| - @pytest.mark.xfail(reason="OpenVINO doesn't support empty tensors for randperm") |
63 |
| - def test_randperm_zero(self, ie_device, precision, ir_version): |
64 |
| - model, ref_net, op = self.create_model(0) |
65 |
| - inputs = self._prepare_input() |
66 |
| - torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs] |
67 |
| - ov_inputs = flattenize_inputs(inputs) |
68 |
| - trace_model = True |
69 |
| - dynamic_shapes = True |
70 |
| - freeze_model = True |
71 |
| - |
72 |
| - with torch.no_grad(): |
73 |
| - smodel, converted_model = self.convert_directly_via_frontend( |
74 |
| - model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model |
75 |
| - ) |
76 | 48 | from openvino import Core
|
77 | 49 | core = Core()
|
78 | 50 | compiled_model = core.compile_model(converted_model, ie_device)
|
79 |
| - _ = compiled_model(()) |
| 51 | + |
| 52 | + ov_output = compiled_model(ov_inputs)[0] |
| 53 | + if n > 0: |
| 54 | + assert ov_output.shape[0] == n, f"Output shape {ov_output.shape} does not match expected ({n},)" |
| 55 | + assert np.array_equal(np.sort(ov_output), np.arange(n)), f"Output is not a valid permutation of [0, ..., {n-1}]" |
| 56 | + else: |
| 57 | + assert ov_output.shape[0] == 0, f"Output shape for n=0 should be (0,), got {ov_output.shape}" |
0 commit comments