Skip to content

Commit f00ac41

Browse files
authored
[JAX FE]: Support jax.lax.eq and jax.lax.ne operation for JAX (openvinotoolkit#26719)
### Details: - support jax.lax.eq and jax.lax.ne operation - enable unit tests for new operations ### Tickets: - [None](openvinotoolkit#26571)
1 parent 94749db commit f00ac41

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

src/frontends/jax/src/op/binary_op.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
//
44

55
#include "openvino/frontend/jax/node_context.hpp"
6+
#include "openvino/op/equal.hpp"
67
#include "openvino/op/greater.hpp"
78
#include "openvino/op/greater_eq.hpp"
9+
#include "openvino/op/not_equal.hpp"
810
#include "utils.hpp"
911

1012
using namespace std;
@@ -25,8 +27,10 @@ OutputVector translate_binary_op(const NodeContext& context) {
2527
return {binary_op};
2628
}
2729

30+
template OutputVector translate_binary_op<v1::Equal>(const NodeContext& context);
2831
template OutputVector translate_binary_op<v1::GreaterEqual>(const NodeContext& context);
2932
template OutputVector translate_binary_op<v1::Greater>(const NodeContext& context);
33+
template OutputVector translate_binary_op<v1::NotEqual>(const NodeContext& context);
3034

3135
} // namespace op
3236
} // namespace jax

src/frontends/jax/src/op_table.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
#include "openvino/op/add.hpp"
88
#include "openvino/op/divide.hpp"
9+
#include "openvino/op/equal.hpp"
910
#include "openvino/op/erf.hpp"
1011
#include "openvino/op/exp.hpp"
1112
#include "openvino/op/greater.hpp"
1213
#include "openvino/op/greater_eq.hpp"
1314
#include "openvino/op/maximum.hpp"
1415
#include "openvino/op/multiply.hpp"
16+
#include "openvino/op/not_equal.hpp"
1517
#include "openvino/op/reduce_max.hpp"
1618
#include "openvino/op/reduce_sum.hpp"
1719
#include "openvino/op/sqrt.hpp"
@@ -63,13 +65,15 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
6365
{"device_put", op::skip_node},
6466
{"div", op::translate_1to1_match_2_inputs<v1::Divide>},
6567
{"dot_general", op::translate_dot_general},
68+
{"eq", op::translate_binary_op<v1::Equal>},
6669
{"erf", op::translate_1to1_match_1_input<v0::Erf>},
6770
{"exp", op::translate_1to1_match_1_input<v0::Exp>},
6871
{"ge", op::translate_binary_op<v1::GreaterEqual>},
6972
{"gt", op::translate_binary_op<v1::Greater>},
7073
{"integer_pow", op::translate_integer_pow},
7174
{"max", op::translate_1to1_match_2_inputs<v1::Maximum>},
7275
{"mul", op::translate_1to1_match_2_inputs<v1::Multiply>},
76+
{"ne", op::translate_binary_op<v1::NotEqual>},
7377
{"reduce_max", op::translate_reduce_op<v1::ReduceMax>},
7478
{"reduce_sum", op::translate_reduce_op<v1::ReduceSum>},
7579
{"reduce_window_max", op::translate_reduce_window_max},

tests/layer_tests/jax_tests/test_binary_comparison.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ def _prepare_input(self):
2727

2828
def create_model(self, input_shapes, binary_op, input_type):
2929
reduce_map = {
30+
'eq': lax.eq,
3031
'ge': lax.ge,
31-
'gt': lax.gt
32+
'gt': lax.gt,
33+
'ne': lax.ne
3234
}
3335

3436
self.input_shapes = input_shapes
@@ -42,7 +44,7 @@ def jax_binary(x, y):
4244
@pytest.mark.parametrize('input_shapes', [[[5], [1]], [[1], [5]], [[2, 2, 4], [1, 1, 4]],
4345
[[5, 10], [5, 10]], [[2, 4, 6], [1, 4, 6]],
4446
[[5, 8, 10, 128], [5, 1, 10, 128]]])
45-
@pytest.mark.parametrize('binary_op', ['ge', 'gt'])
47+
@pytest.mark.parametrize('binary_op', ['eq', 'ge', 'gt', 'ne'])
4648
@pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
4749
np.int32, np.uint32, np.int64, np.uint64,
4850
np.float16, np.float32, np.float64])

0 commit comments

Comments
 (0)