
2023-06-21

mariogeiger released this 22 Jun 02:47

Highlights

There are two main changes in this release: Linear and IrrepsArray.

Change in Linear

The classes e3nn.flax.Linear and e3nn.haiku.Linear now discard, by default, the output irreps that are not reachable from the input:

linear = e3nn.flax.Linear("2x0e + 1o + 2e")
x = e3nn.normal("0e + 1o")  # an input without 2e
w = linear.init(jax.random.PRNGKey(0), x)
linear.apply(w, x).irreps
# 2x0e+1x1o, no 2e because it is not reachable from the input
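Why is 2e unreachable? An equivariant linear map can only connect copies of the same irrep (Schur's lemma), so an output irrep that never appears in the input could only ever receive zeros. The plain-Python sketch below mimics that filtering rule on irreps strings. `parse_irreps` and `reachable_irreps_out` are hypothetical helpers written for illustration, not part of e3nn; the `force_irreps_out` argument only imitates the flag named in the changelog, and the parser handles simple terms such as "2x0e" or "1o" only.

```python
import re


def parse_irreps(s):
    """Parse a simple irreps string like "2x0e + 1o" into (mul, irrep) pairs."""
    out = []
    for term in s.split("+"):
        term = term.strip()
        m = re.fullmatch(r"(?:(\d+)x)?(\d+[eo])", term)
        if m is None:
            raise ValueError(f"cannot parse irrep term {term!r}")
        mul = int(m.group(1)) if m.group(1) else 1
        out.append((mul, m.group(2)))
    return out


def reachable_irreps_out(irreps_out, irreps_in, force_irreps_out=False):
    """Hypothetical helper: keep only output irreps that also occur in the input."""
    available = {ir for _, ir in parse_irreps(irreps_in)}
    kept = [
        (mul, ir)
        for mul, ir in parse_irreps(irreps_out)
        if force_irreps_out or ir in available  # same irrep on both sides
    ]
    return "+".join(f"{mul}x{ir}" for mul, ir in kept)


print(reachable_irreps_out("2x0e + 1o + 2e", "0e + 1o"))  # 2x0e+1x1o
```

The same rule explains the changelog example: requesting "0e + 1o" from a "0e" input keeps only "0e".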

Change in IrrepsArray

Before this release

IrrepsArray had its data stored twice, both in .array and in .list.

x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
x.list  # a list with the scalars and the vectors in two separate arrays
x.array  # an array with all the data contiguous

The motivation was to allow chunks to be stored as None, meaning they are strictly zero:

e3nn.IrrepsArray.from_list("0e + 1o", [jnp.array([[1.0]]), None], ())

This led to confusing situations with the jax.tree_util module:

jax.tree_util.tree_leaves(x)
# a list of 3 arrays with repeated data

Because of that, jax.vmap could not be used with a negative axis:

jax.vmap(lambda x: x, in_axes=-2)(x)  # ERROR!!

And the gradient was propagated through only one of the two attributes:

g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array  # is zero
g.list  # is not zero
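A toy model makes these three problems concrete. The hypothetical DoubleStored class below (not e3nn's actual implementation) registers both the contiguous array and the per-chunk list as pytree leaves, which is an assumed stand-in for the old .array/.list layout. It shows the duplicated leaves; why in_axes=-2 fails (axis -2 of the contiguous array is the batch axis, while axis -2 of each (..., mul, ir_dim) chunk is the multiplicity axis, so vmap sees inconsistent sizes); and that a gradient only reaches the copy the computation actually reads. In this toy the function reads .array, so .array receives the gradient; which copy gets it in practice depends on which one the operations use.

```python
import jax
import jax.numpy as jnp


class DoubleStored:
    """Hypothetical toy container (not e3nn's class) that keeps the same data twice."""

    def __init__(self, array, chunks):
        self.array = array    # contiguous data, shape (..., 4)
        self.chunks = chunks  # [scalars (..., 1, 1), vectors (..., 1, 3)]


jax.tree_util.register_pytree_node(
    DoubleStored,
    lambda x: ((x.array, x.chunks), None),  # both copies become pytree leaves
    lambda _, children: DoubleStored(*children),
)

data = jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
x = DoubleStored(data, [data[:, :1].reshape(2, 1, 1), data[:, 1:].reshape(2, 1, 3)])

# 1) Three leaves for two chunks: the data is duplicated.
print(len(jax.tree_util.tree_leaves(x)))  # 3

# 2) in_axes=-2 applies to every leaf: size 2 for `array`, size 1 for each chunk.
try:
    jax.vmap(lambda x: x, in_axes=-2)(x)
    vmap_failed = False
except ValueError:
    vmap_failed = True
print(vmap_failed)  # True

# 3) The gradient only lands on the copy the function reads.
g = jax.grad(lambda x: x.array[:, 0].sum())(x)
print(g.array[:, 0])                       # [1. 1.]
print([float(c.sum()) for c in g.chunks])  # [0.0, 0.0]
```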

In this release

We refactored IrrepsArray: it now has only .array as a data attribute, and the .list attribute is gone.
Therefore jax.tree_util.tree_leaves returns just one array:

x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
jax.tree_util.tree_leaves(x)  # [jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])]

Instead, there is a new attribute .zero_flags, a tuple of booleans indicating whether each corresponding chunk is zero:

y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
y.chunks  # [jnp.array([[1.0]]), None]
y.zero_flags  # (False, True)
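Since .array is now the only storage, a None chunk passed to e3nn.from_chunks must still occupy space in the array. The sketch below illustrates that idea under the assumption that a missing chunk is materialized as zeros and remembered in a tuple of flags. from_chunks_sketch is hypothetical, and its (mul, ir_dim) layout argument is a simplified stand-in for the real irreps string.

```python
import jax.numpy as jnp


def from_chunks_sketch(layout, chunks, leading_shape):
    """Hypothetical illustration of the idea behind e3nn.from_chunks.

    layout: (mul, ir_dim) pairs, e.g. [(1, 1), (1, 3)] for "0e + 1o".
    chunks: per entry, an array of shape leading_shape + (mul, ir_dim), or None.
    Returns the contiguous array and a tuple of zero flags.
    """
    parts, zero_flags = [], []
    for (mul, dim), chunk in zip(layout, chunks):
        if chunk is None:
            chunk = jnp.zeros(leading_shape + (mul, dim))  # materialize the zeros
            zero_flags.append(True)
        else:
            zero_flags.append(False)
        # Flatten (mul, ir_dim) into one axis so chunks can be concatenated.
        parts.append(chunk.reshape(leading_shape + (mul * dim,)))
    return jnp.concatenate(parts, axis=-1), tuple(zero_flags)


array, zero_flags = from_chunks_sketch([(1, 1), (1, 3)], [jnp.array([[1.0]]), None], ())
print(array)       # [1. 0. 0. 0.]
print(zero_flags)  # (False, True)
```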

The new .chunks attribute replaces .list (now deprecated).
The name is more consistent with the existing .slice_by_chunk:

x.chunks  # list of the two chunks
x.slice_by_chunk[:1]  # get the first chunk

jax.vmap can now be used with a negative axis:

jax.vmap(lambda x: x, in_axes=-2)(x)

And the gradient now behaves as expected:

g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array  # expected value
g.chunks  # expected value
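The same toy approach shows why a single storage fixes both the vmap and the gradient issues. In the hypothetical SingleStored class below, only .array is a pytree leaf, the layout is static aux data, and .chunks is a property that slices and reshapes .array on demand. It is a sketch of the design described above, not e3nn's source.

```python
import jax
import jax.numpy as jnp


class SingleStored:
    """Hypothetical toy container: one contiguous array, chunks derived on demand."""

    def __init__(self, layout, array):
        self.layout = layout  # tuple of (mul, ir_dim), e.g. ((1, 1), (1, 3)) for "0e + 1o"
        self.array = array

    @property
    def chunks(self):
        # Slices of .array reshaped to (..., mul, ir_dim); nothing is stored twice.
        out, start = [], 0
        for mul, dim in self.layout:
            stop = start + mul * dim
            out.append(self.array[..., start:stop].reshape(self.array.shape[:-1] + (mul, dim)))
            start = stop
        return out


jax.tree_util.register_pytree_node(
    SingleStored,
    lambda x: ((x.array,), x.layout),  # a single leaf; the layout is static aux data
    lambda layout, children: SingleStored(layout, children[0]),
)

x = SingleStored(((1, 1), (1, 3)), jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
print(len(jax.tree_util.tree_leaves(x)))  # 1

# With a single leaf, axis -2 is unambiguous: it maps over the 2 rows.
y = jax.vmap(lambda x: x, in_axes=-2)(x)
print(y.array.shape)  # (2, 4)

# A gradient through a derived chunk flows back into .array.
g = jax.grad(lambda x: x.chunks[0].sum())(x)
print(g.array)
# [[1. 0. 0. 0.]
#  [1. 0. 0. 0.]]
```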

To avoid any trouble that .zero_flags might cause in JAX transformations, the flags are dropped whenever a transformation is applied, so every chunk is then reported as non-zero:

y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
print(y.zero_flags)  # (False, True)

z = jax.jit(lambda x: x)(y)
print(z.zero_flags)  # (False, False)

z = jax.tree_util.tree_map(lambda x: x, z)
print(z.zero_flags)  # (False, False)

z = jax.vmap(lambda x: x)(z[None, ...])
print(z.zero_flags)  # (False, False)
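One way to obtain this behaviour is sketched below with a hypothetical Flagged class. The sketch assumes the flags are kept outside both the pytree leaves and the aux data. Transformations such as jax.jit and jax.tree_util.tree_map flatten their inputs and rebuild outputs through the unflatten function, so the flags fall back to their all-False default.

```python
import jax
import jax.numpy as jnp


class Flagged:
    """Hypothetical toy container carrying per-chunk zero flags outside the pytree."""

    def __init__(self, layout, array, zero_flags=None):
        self.layout = layout
        self.array = array
        # Default when nothing is known: assume no chunk is zero.
        self.zero_flags = zero_flags if zero_flags is not None else (False,) * len(layout)


jax.tree_util.register_pytree_node(
    Flagged,
    lambda x: ((x.array,), x.layout),  # flags deliberately left out of the aux data
    lambda layout, children: Flagged(layout, children[0]),
)

y = Flagged(((1, 1), (1, 3)), jnp.array([1.0, 0.0, 0.0, 0.0]), zero_flags=(False, True))
print(y.zero_flags)  # (False, True)

z = jax.jit(lambda x: x)(y)
print(z.zero_flags)  # (False, False)

z = jax.tree_util.tree_map(lambda x: x, y)
print(z.zero_flags)  # (False, False)
```

Leaving the flags out of the aux data also seems the safer choice: aux data is part of the tree structure, so arrays that differ only in their flags would have different structures, which breaks multi-tree tree_map calls and causes separate jit traces.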

Changelog

Changed

  • [BREAKING] e3nn.flax.Linear and e3nn.haiku.Linear no longer output irreps that are unreachable from the input. To force the output of all irreps, use force_irreps_out=True. For instance, e3nn.flax.Linear("0e + 1o") applied to a "0e" input now returns "0e" instead of "0e + 1o".
  • [BREAKING] e3nn.utils.assert_equivariant now has the same signature as e3nn.utils.equivariance_test
  • [BREAKING] Move as_irreps_array, zeros and zeros_like from e3nn.IrrepsArray to e3nn
  • [BREAKING] Move IrrepsArray.from_list to e3nn.from_chunks
  • [BREAKING] Rename IrrepsArray.list to IrrepsArray.chunks
  • [BREAKING] Rename IrrepsArray.remove_nones to IrrepsArray.remove_zero_chunks
  • e3nn.IrrepsArray now has only .array as its data attribute.

Added

  • e3nn.IrrepsArray.rechunk
  • e3nn.IrrepsArray.zero_flags: a tuple of bools indicating which chunks are zero