Skip to content

Commit 362f073

Browse files
11happyrkazants
andauthored
[JAX]: Support jax.lax.select_n operation for JAX (#28025)
**Overview:** This pull request fixes #26570. **Testing:** - Tested the updated code. - Verified that other functionalities remain unaffected. ![Screenshot from 2024-12-12 00-09-17](https://github.com/user-attachments/assets/ae118efa-2047-4bac-a90d-8318396e5f63) **Dependencies:** - No dependencies on other pull requests. **CC:** - @rkazants --------- Signed-off-by: 11happy <soni5happy@gmail.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
1 parent ca501ca commit 362f073

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

src/frontends/jax/src/op/select_n.cpp

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/frontend/jax/node_context.hpp"
6+
#include "openvino/op/concat.hpp"
7+
#include "openvino/op/constant.hpp"
8+
#include "openvino/op/convert.hpp"
9+
#include "openvino/op/gather_elements.hpp"
10+
#include "openvino/op/unsqueeze.hpp"
11+
#include "utils.hpp"
12+
13+
using namespace ov::op;
14+
15+
namespace ov {
16+
namespace frontend {
17+
namespace jax {
18+
namespace op {
19+
20+
OutputVector translate_select_n(const NodeContext& context) {
21+
num_inputs_check(context, 2);
22+
auto num_inputs = static_cast<int>(context.get_input_size());
23+
Output<Node> which = context.get_input(0);
24+
if (which.get_element_type() == element::boolean) {
25+
which = std::make_shared<v0::Convert>(which, element::i32);
26+
}
27+
auto const_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0});
28+
OutputVector unsqueezed_cases(num_inputs - 1);
29+
unsqueezed_cases.reserve(num_inputs - 1);
30+
for (int ind = 1; ind < num_inputs; ++ind) {
31+
auto case_input = context.get_input(ind);
32+
auto unsqueeze = std::make_shared<v0::Unsqueeze>(case_input, const_axis);
33+
unsqueezed_cases[ind - 1] = unsqueeze;
34+
}
35+
Output<Node> cases = std::make_shared<v0::Concat>(unsqueezed_cases, 0);
36+
which =
37+
std::make_shared<v0::Unsqueeze>(which,
38+
ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0}));
39+
Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0);
40+
return {result};
41+
};
42+
43+
} // namespace op
44+
} // namespace jax
45+
} // namespace frontend
46+
} // namespace ov

src/frontends/jax/src/op_table.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ OP_CONVERTER(translate_reduce_window_max);
5252
OP_CONVERTER(translate_reduce_window_sum);
5353
OP_CONVERTER(translate_reshape);
5454
OP_CONVERTER(translate_rsqrt);
55+
OP_CONVERTER(translate_select_n);
5556
OP_CONVERTER(translate_slice);
5657
OP_CONVERTER(translate_square);
5758
OP_CONVERTER(translate_squeeze);
@@ -92,6 +93,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
9293
{"transpose", op::translate_transpose},
9394
{"rsqrt", op::translate_rsqrt},
9495
{"reshape", op::translate_reshape},
96+
{"select_n", op::translate_select_n},
9597
{"slice", op::translate_slice},
9698
{"square", op::translate_square},
9799
{"sqrt", op::translate_1to1_match_1_input<v0::Sqrt>},
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (C) 2018-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import jax
5+
import numpy as np
6+
import pytest
7+
from jax import numpy as jnp
8+
9+
from jax_layer_test_class import JaxLayerTest
10+
11+
rng = np.random.default_rng(5402)
12+
13+
14+
class TestSelectN(JaxLayerTest):
15+
def _prepare_input(self):
16+
cases = []
17+
if (self.case_num == 2):
18+
which = rng.choice([True, False], self.input_shape)
19+
else:
20+
which = rng.uniform(0, self.case_num, self.input_shape).astype(self.input_type)
21+
which = np.array(which)
22+
for i in range(self.case_num):
23+
cases.append(jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type)))
24+
cases = np.array(cases)
25+
return (which, cases)
26+
27+
def create_model(self, input_shape, input_type, case_num):
28+
self.input_shape = input_shape
29+
self.input_type = input_type
30+
self.case_num = case_num
31+
32+
def jax_select_n(which, cases):
33+
return jax.lax.select_n(which, *cases)
34+
35+
return jax_select_n, None, 'select_n'
36+
37+
@pytest.mark.parametrize("input_shape", [[], [1], [2, 3], [4, 5, 6], [7, 8, 9, 10]])
38+
@pytest.mark.parametrize("input_type", [np.int32, np.int64])
39+
@pytest.mark.parametrize("case_num", [2, 3, 4])
40+
@pytest.mark.nightly
41+
@pytest.mark.precommit_jax_fe
42+
def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type, case_num):
43+
self._test(*self.create_model(input_shape, input_type, case_num),
44+
ie_device, precision,
45+
ir_version)

0 commit comments

Comments
 (0)