2023-06-21
Highlight
There are two main changes in this release: Linear and IrrepsArray.
Change in Linear
The classes e3nn.flax.Linear and e3nn.haiku.Linear now discard, by default, the output irreps that are not reachable from the input.
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
linear = e3nn.flax.Linear("2x0e + 1o + 2e")
x = e3nn.normal("0e + 1o") # an input without 2e
w = linear.init(jax.random.PRNGKey(0), x)
linear.apply(w, x).irreps
# 2x0e+1x1o, no 2e because it is not reachable from the input
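The filtering rule can be sketched in plain Python: a linear equivariant map can only connect identical irreps, so an output irrep is kept only if the same irrep appears in the input. The helpers parse_irreps and reachable_outputs below are hypothetical names used for illustration; they are not part of e3nn and do not reflect how the library implements the check.

```python
def parse_irreps(spec):
    """Parse a string such as "2x0e + 1o" into (multiplicity, irrep) pairs.

    Hypothetical helper for illustration only, not the e3nn parser.
    """
    pairs = []
    for term in spec.split("+"):
        term = term.strip()
        if "x" in term:
            mul, ir = term.split("x")
            pairs.append((int(mul), ir.strip()))
        else:
            # A bare irrep such as "1o" has multiplicity 1.
            pairs.append((1, term))
    return pairs


def reachable_outputs(irreps_in, irreps_out):
    """Keep only the output irreps that also appear in the input."""
    available = {ir for mul, ir in parse_irreps(irreps_in) if mul > 0}
    return [(mul, ir) for mul, ir in parse_irreps(irreps_out) if ir in available]


kept = reachable_outputs("0e + 1o", "2x0e + 1o + 2e")
print(kept)  # [(2, '0e'), (1, '1o')]
```

The 2e output is dropped because the input contains no 2e, which matches the 2x0e+1x1o result shown above.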
Change in IrrepsArray
Before this release
IrrepsArray had its data stored twice, both in .array and in .list.
x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
x.list # a list with the scalars and the vectors in two separate arrays
x.array # an array with all the data contiguous
The motivation was to allow a chunk to be stored as None to mark it as strictly zero:
e3nn.IrrepsArray.from_list("0e + 1o", [jnp.array([[1.0]]), None], ())
This led to confusing situations with the jax.tree_util module:
jax.tree_util.tree_leaves(x)
# a list of 3 arrays with repeated data
Because of that, jax.vmap could not be used with a negative axis:
jax.vmap(lambda x: x, in_axes=-2)(x) # ERROR!!
In addition, the gradient was propagated through only one of the two attributes:
g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array # is zero
g.list # is not zero
In this release
We refactored IrrepsArray. It now has only .array as a data attribute; the .list data attribute is gone.
Therefore jax.tree_util.tree_leaves returns just one array.
x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
jax.tree_util.tree_leaves(x) # [jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])]
Instead, there is a new attribute .zero_flags, a tuple of booleans indicating whether each chunk is zero.
y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
y.chunks # [jnp.array([[1.0]]), None]
y.zero_flags # (False, True)
.chunks is the new attribute that replaces .list (now deprecated).
It has a better name because we already have .slice_by_chunk.
x.chunks # list of the two chunks
x.slice_by_chunk[:1] # get the first chunk
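The relation between the contiguous .array and the per-irrep chunks can be sketched with plain jax.numpy. This is a minimal illustration, assuming each chunk has shape leading_shape + (mul, dim), which is consistent with the from_chunks example above. The function split_into_chunks and its layout argument are hypothetical and not part of e3nn.

```python
import jax.numpy as jnp


def split_into_chunks(array, layout):
    """Split the last axis of `array` into one chunk per (mul, dim) entry.

    Hypothetical helper: `layout` lists (multiplicity, irrep dimension) pairs,
    e.g. [(1, 1), (1, 3)] for "0e + 1o".
    """
    chunks = []
    start = 0
    for mul, dim in layout:
        stop = start + mul * dim
        # Slice the contiguous data, then expose multiplicity and irrep axes.
        chunk = array[..., start:stop].reshape(array.shape[:-1] + (mul, dim))
        chunks.append(chunk)
        start = stop
    return chunks


data = jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
scalars, vectors = split_into_chunks(data, [(1, 1), (1, 3)])
print(scalars.shape)  # (2, 1, 1)
print(vectors.shape)  # (2, 1, 3)
```

Because the chunks are derived from the single contiguous array on demand, there is no second copy of the data to keep in sync.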
jax.vmap can now be used with a negative axis:
jax.vmap(lambda x: x, in_axes=-2)(x)
And the gradient behaves as expected
g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array # expected value
g.chunks # expected value
To avoid any trouble that .zero_flags might cause in jax transformations, it is dropped whenever a transformation is applied: every flag is reset to False.
y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
print(y.zero_flags) # (False, True)
z = jax.jit(lambda x: x)(y)
print(z.zero_flags) # (False, False)
z = jax.tree_util.tree_map(lambda x: x, z)
print(z.zero_flags) # (False, False)
z = jax.vmap(lambda x: x)(z[None, ...])
print(z.zero_flags) # (False, False)
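The behaviour above can be reproduced with a toy pytree. The sketch below assumes a design in which only the array is registered as a leaf and the flags are deliberately left out of the flattened representation; ToyChunkedArray and its chunk_dims field are hypothetical stand-ins, not the actual IrrepsArray implementation.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyChunkedArray:
    """Hypothetical stand-in for IrrepsArray: one data leaf plus static metadata."""

    def __init__(self, chunk_dims, array, zero_flags=None):
        self.chunk_dims = tuple(chunk_dims)
        self.array = array
        if zero_flags is None:
            # Without explicit information, no chunk is assumed to be zero.
            zero_flags = (False,) * len(self.chunk_dims)
        self.zero_flags = tuple(zero_flags)

    def tree_flatten(self):
        # Only the array is a leaf; zero_flags is intentionally not carried along.
        return (self.array,), self.chunk_dims

    @classmethod
    def tree_unflatten(cls, chunk_dims, children):
        (array,) = children
        return cls(chunk_dims, array)


y = ToyChunkedArray((1, 3), jnp.array([1.0, 0.0, 0.0, 0.0]), zero_flags=(False, True))
leaves = jax.tree_util.tree_leaves(y)
print(len(leaves))  # 1
z = jax.jit(lambda t: t)(y)
print(z.zero_flags)  # (False, False)

batched = ToyChunkedArray((1, 3), jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
w = jax.vmap(lambda t: t, in_axes=-2)(batched)
print(w.array.shape)  # (2, 4)
```

With a single leaf, tree utilities, jit and vmap with a negative axis all see one array, at the cost of losing the zero-chunk information across transformations.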
Changelog
Changed
- [BREAKING] e3nn.flax.Linear and e3nn.haiku.Linear no longer output the impossible irreps. To force the output of all irreps, use force_irreps_out = True. For instance, e3nn.flax.Linear("0e + 1o")("0e") will now return "0e" instead of "0e + 1o".
- [BREAKING] e3nn.utils.assert_equivariant has the same signature as e3nn.utils.equivariance_test
- [BREAKING] Move as_irreps_array, zeros and zeros_like from e3nn.IrrepsArray to e3nn
- [BREAKING] Move IrrepsArray.from_list to e3nn.from_chunks
- [BREAKING] Rename IrrepsArray.list into IrrepsArray.chunks
- [BREAKING] Rename IrrepsArray.remove_nones into IrrepsArray.remove_zero_chunks
- e3nn.IrrepsArray now has only .array as data attribute.
Added
- e3nn.IrrepsArray.rechunk
- e3nn.IrrepsArray.zero_flags: a tuple of bools that indicates which chunks are zero