We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6a19a37 commit 757a10cCopy full SHA for 757a10c
tests/link/jax/test_pad.py
@@ -1,5 +1,6 @@
1
import numpy as np
2
import pytest
3
+from packaging import version
4
5
import pytensor.tensor as pt
6
from pytensor import config
@@ -16,7 +17,14 @@
16
17
"mode, kwargs",
18
[
19
("constant", {"constant_values": 0}),
- ("constant", {"constant_values": (1, 2)}),
20
+ pytest.param(
21
+ "constant",
22
+ {"constant_values": (1, 2)},
23
+ marks=pytest.mark.skipif(
24
+ version.parse(jax.__version__) > version.parse("0.4.35"),
25
+ reason="Bug in JAX: https://github.com/jax-ml/jax/issues/26888",
26
+ ),
27
28
("edge", {}),
29
("linear_ramp", {"end_values": 0}),
30
("linear_ramp", {"end_values": (1, 2)}),
0 commit comments