
Commit 3380b3c

Fix tf-addons for upcoming keras 3 default. (#2858)
Keras 3.0 will become the default in TF 2.16 (and is already the default in tf-nightly), which breaks the tf-addons package. Here we make minimal changes to restore functionality in a backward-compatible way.
1 parent 062a7aa commit 3380b3c

14 files changed, +757 -38 lines
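
The backward-compatible path described above hinges on detecting which Keras the installed TensorFlow loads. A minimal sketch of that detection pattern (editor's illustration, not part of the commit; `_is_keras_3` is a hypothetical helper name):

import tensorflow as tf
from packaging.version import Version

def _is_keras_3() -> bool:
    # Keras 3 exposes tf.keras.version(); older stacks lack the attribute,
    # so its presence plus the reported version identifies the loaded Keras.
    return hasattr(tf.keras, "version") and Version(
        tf.keras.version()
    ).release >= Version("3.0").release

# Internal symbols are then imported from `keras.src` (Keras 3) or from
# `tf_keras.src` (the Keras 2 compatibility package), as in the diffs below.
print("Keras 3 loaded:", _is_keras_3())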

tensorflow_addons/image/tests/distort_image_ops_test.py

+3 -3

@@ -94,7 +94,7 @@ def test_adjust_random_hue_in_yiq(shape, style, dtype):
     y_np = _adjust_hue_in_yiq_np(x_np, delta_h)
     y_tf = _adjust_hue_in_yiq_tf(x_np, delta_h)
     test_utils.assert_allclose_according_to_type(
-        y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=0.8
+        y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=1.1
     )


@@ -121,11 +121,11 @@ def test_invalid_channels_hsv():

 def test_adjust_hsv_in_yiq_unknown_shape():
     fn = tf.function(distort_image_ops.adjust_hsv_in_yiq).get_concrete_function(
-        tf.TensorSpec(shape=None, dtype=tf.float64)
+        tf.TensorSpec(shape=None, dtype=tf.float32)
     )
     for shape in (2, 3, 3), (4, 2, 3, 3):
         image_np = np.random.rand(*shape) * 255.0
-        image_tf = tf.constant(image_np)
+        image_tf = tf.constant(image_np, dtype=tf.float32)
         np.testing.assert_allclose(
             _adjust_hue_in_yiq_np(image_np, 0), fn(image_tf), rtol=2e-4, atol=1e-4
         )

tensorflow_addons/optimizers/discriminative_layer_training.py

+14 -3

@@ -22,9 +22,20 @@
 from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typeguard import typechecked

-if Version(tf.__version__).release >= Version("2.13").release:
-    # New versions of Keras require importing from `keras.src` when
-    # importing internal symbols.
+if Version(tf.__version__).release >= Version("2.16").release:
+    # Determine if loading keras 2 or 3.
+    if (
+        hasattr(tf.keras, "version")
+        and Version(tf.keras.version()).release >= Version("3.0").release
+    ):
+        # New versions of Keras require importing from `keras.src` when
+        # importing internal symbols.
+        from keras.src import backend
+        from keras.src.utils import tf_utils
+    else:
+        from tf_keras.src import backend
+        from tf_keras.src.utils import tf_utils
+elif Version(tf.__version__).release >= Version("2.13").release:
     from keras.src import backend
     from keras.src.utils import tf_utils
 else:

tensorflow_addons/optimizers/lazy_adam.py

+3

@@ -149,3 +149,6 @@ def _resource_scatter_operate(self, resource, indices, update, resource_scatter_
         }

         return resource_scatter_op(**resource_update_kwargs)
+
+    def get_config(self):
+        return super().get_config()

tensorflow_addons/rnn/abstract_rnn_cell.py

+133

@@ -0,0 +1,133 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for RNN cells.
+
+Adapted from legacy github.com/keras-team/tf-keras.
+"""
+
+import tensorflow as tf
+
+
+def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
+    if inputs is not None:
+        batch_size = tf.shape(inputs)[0]
+        dtype = inputs.dtype
+    return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
+
+
+def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
+    """Generate a zero filled tensor with shape [batch_size, state_size]."""
+    if batch_size_tensor is None or dtype is None:
+        raise ValueError(
+            "batch_size and dtype cannot be None while constructing initial state: "
+            "batch_size={}, dtype={}".format(batch_size_tensor, dtype)
+        )
+
+    def create_zeros(unnested_state_size):
+        flat_dims = tf.TensorShape(unnested_state_size).as_list()
+        init_state_size = [batch_size_tensor] + flat_dims
+        return tf.zeros(init_state_size, dtype=dtype)
+
+    if tf.nest.is_nested(state_size):
+        return tf.nest.map_structure(create_zeros, state_size)
+    else:
+        return create_zeros(state_size)
+
+
+class AbstractRNNCell(tf.keras.layers.Layer):
+    """Abstract object representing an RNN cell.
+
+    This is a base class for implementing RNN cells with custom behavior.
+
+    Every `RNNCell` must have the properties below and implement `call` with
+    the signature `(output, next_state) = call(input, state)`.
+
+    Examples:
+
+    ```python
+    class MinimalRNNCell(AbstractRNNCell):
+
+        def __init__(self, units, **kwargs):
+            self.units = units
+            super(MinimalRNNCell, self).__init__(**kwargs)
+
+        @property
+        def state_size(self):
+            return self.units
+
+        def build(self, input_shape):
+            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
+                                          initializer='uniform',
+                                          name='kernel')
+            self.recurrent_kernel = self.add_weight(
+                shape=(self.units, self.units),
+                initializer='uniform',
+                name='recurrent_kernel')
+            self.built = True
+
+        def call(self, inputs, states):
+            prev_output = states[0]
+            h = backend.dot(inputs, self.kernel)
+            output = h + backend.dot(prev_output, self.recurrent_kernel)
+            return output, output
+    ```
+
+    This definition of cell differs from the definition used in the literature.
+    In the literature, 'cell' refers to an object with a single scalar output.
+    This definition refers to a horizontal array of such units.
+
+    An RNN cell, in the most abstract setting, is anything that has
+    a state and performs some operation that takes a matrix of inputs.
+    This operation results in an output matrix with `self.output_size` columns.
+    If `self.state_size` is an integer, this operation also results in a new
+    state matrix with `self.state_size` columns. If `self.state_size` is a
+    (possibly nested tuple of) TensorShape object(s), then it should return a
+    matching structure of Tensors having shape `[batch_size].concatenate(s)`
+    for each `s` in `self.batch_size`.
+    """
+
+    def call(self, inputs, states):
+        """The function that contains the logic for one RNN step calculation.
+
+        Args:
+          inputs: the input tensor, which is a slide from the overall RNN input by
+            the time dimension (usually the second dimension).
+          states: the state tensor from previous step, which has the same shape
+            as `(batch, state_size)`. In the case of timestep 0, it will be the
+            initial state user specified, or zero filled tensor otherwise.
+
+        Returns:
+          A tuple of two tensors:
+            1. output tensor for the current timestep, with size `output_size`.
+            2. state tensor for next step, which has the shape of `state_size`.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def state_size(self):
+        """size(s) of state(s) used by this cell.
+
+        It can be represented by an Integer, a TensorShape or a tuple of Integers
+        or TensorShapes.
+        """
+        raise NotImplementedError("Abstract method")
+
+    @property
+    def output_size(self):
+        """Integer or TensorShape: size of outputs produced by this cell."""
+        raise NotImplementedError("Abstract method")
+
+    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+        return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
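
Not part of the commit, but a quick sanity sketch of how the new module's zero-filled initial state behaves for an integer `state_size` (`ToyCell` is a hypothetical subclass written only for illustration):

import tensorflow as tf
from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell

class ToyCell(AbstractRNNCell):
    @property
    def state_size(self):
        return 4

    @property
    def output_size(self):
        return 4

    def call(self, inputs, states):
        # Identity step, just enough to satisfy the abstract interface.
        return inputs, states

# get_initial_state delegates to _generate_zero_filled_state_for_cell, so an
# integer state_size of 4 with batch_size=2 yields a (2, 4) tensor of zeros.
state = ToyCell().get_initial_state(batch_size=2, dtype=tf.float32)
print(state.shape)  # (2, 4)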

tensorflow_addons/rnn/esn_cell.py

+2 -2

@@ -15,17 +15,17 @@
 """Implements ESN Cell."""

 import tensorflow as tf
-import tensorflow.keras as keras
 from typeguard import typechecked

+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils.types import (
     Activation,
     Initializer,
 )


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class ESNCell(keras.layers.AbstractRNNCell):
+class ESNCell(AbstractRNNCell):
     """Echo State recurrent Network (ESN) cell.
     This implements the recurrent cell from the paper:
         H. Jaeger

tensorflow_addons/rnn/nas_cell.py

+2 -2

@@ -15,9 +15,9 @@
 """Implements NAS Cell."""

 import tensorflow as tf
-import tensorflow.keras as keras
 from typeguard import typechecked

+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils.types import (
     FloatTensorLike,
     TensorLike,
@@ -27,7 +27,7 @@


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class NASCell(keras.layers.AbstractRNNCell):
+class NASCell(AbstractRNNCell):
     """Neural Architecture Search (NAS) recurrent network cell.

     This implements the recurrent cell from the paper:

tensorflow_addons/seq2seq/BUILD

+1

@@ -10,6 +10,7 @@ py_library(
         "//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so",
     ],
     deps = [
+        "//tensorflow_addons/rnn",
         "//tensorflow_addons/testing",
         "//tensorflow_addons/utils",
     ],

tensorflow_addons/seq2seq/attention_wrapper.py

+2 -1

@@ -23,6 +23,7 @@

 import tensorflow as tf

+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.types import (
     AcceptableDTypes,
@@ -1577,7 +1578,7 @@ def _compute_attention(
     return attention, alignments, next_attention_state


-class AttentionWrapper(tf.keras.layers.AbstractRNNCell):
+class AttentionWrapper(AbstractRNNCell):
     """Wraps another RNN cell with attention.

     Example:

tensorflow_addons/text/BUILD

+7 -9

@@ -7,17 +7,15 @@ package(default_visibility = ["//visibility:public"])
 py_library(
     name = "text",
     srcs = glob(["*.py"]),
-    data = select({
-        "//tensorflow_addons:windows": [
-            "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
-            "//tensorflow_addons/testing",
-            "//tensorflow_addons/utils",
-        ],
+    data = [
+        "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
+        "//tensorflow_addons/rnn",
+        "//tensorflow_addons/testing",
+        "//tensorflow_addons/utils",
+    ] + select({
+        "//tensorflow_addons:windows": [],
         "//conditions:default": [
             "//tensorflow_addons/custom_ops/text:_parse_time_op.so",
-            "//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
-            "//tensorflow_addons/testing",
-            "//tensorflow_addons/utils",
         ],
     }),
 )

tensorflow_addons/text/crf.py

+2 -1

@@ -17,6 +17,7 @@
 import numpy as np
 import tensorflow as tf

+from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
 from tensorflow_addons.utils.types import TensorLike
 from typeguard import typechecked
 from typing import Optional, Tuple
@@ -403,7 +404,7 @@ def viterbi_decode(score: TensorLike, transition_params: TensorLike) -> tf.Tenso
     return viterbi, viterbi_score


-class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
+class CrfDecodeForwardRnnCell(AbstractRNNCell):
     """Computes the forward decoding in a linear-chain CRF."""

     @typechecked

tensorflow_addons/utils/test_utils.py

+1 -9

@@ -22,18 +22,10 @@
 import pytest
 import tensorflow as tf

-from packaging.version import Version
 from tensorflow_addons import options
 from tensorflow_addons.utils import resource_loader

-if Version(tf.__version__).release >= Version("2.13").release:
-    # New versions of Keras require importing from `keras.src` when
-    # importing internal symbols.
-    from keras.src.testing_infra.test_utils import layer_test  # noqa: F401
-elif Version(tf.__version__) >= Version("2.9"):
-    from keras.testing_infra.test_utils import layer_test  # noqa: F401
-else:
-    from keras.testing_utils import layer_test  # noqa: F401
+from tensorflow_addons.utils.tf_test_utils import layer_test  # noqa

 NUMBER_OF_WORKERS = int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1"))
 WORKER_ID = int(os.environ.get("PYTEST_XDIST_WORKER", "gw0")[2])
