github google/flax v0.12.1

4 hours ago

Deprecations

Variable.value

Variable.value is now deprecated. Consider the following example:

import jax.numpy as jnp
import jax
from flax import nnx

my_param = nnx.Param({'a': 0.0})

@nnx.jit
def f(m):
    m.value['a'] = 1.0  # deprecated: m.value is a copy, so this write is lost
    return m

Running f(my_param) now produces Param(value={'a': 0.0}) rather than Param(value={'a': 1.0}) as before. This is because reading the value attribute now returns a copy of pytree values (such as dict or list), so the in-place write lands on the copy and is discarded. Instead, use the __setitem__ method to update the value:

@nnx.jit
def f(m):
    m['a'] = 1.0
    return m
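
The copy-on-read behavior described above can be illustrated without Flax. The following is a minimal plain-Python sketch, not the library's implementation: CopyOnReadBox is a hypothetical stand-in class, and its use of copy.deepcopy is an assumption made for the illustration. It shows why a write through a copy-returning getter is lost while a write through __setitem__ is kept.

```python
import copy


class CopyOnReadBox:
    """Hypothetical stand-in for a container whose getter returns a copy."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        # Reads hand back a copy (deepcopy is an assumption of this sketch),
        # so callers never hold a reference to the stored object.
        return copy.deepcopy(self._value)

    def __setitem__(self, key, item):
        # Writes go straight to the stored object.
        self._value[key] = item


box = CopyOnReadBox({'a': 0.0})

box.value['a'] = 1.0   # mutates a temporary copy, which is then discarded
lost = box.value['a']  # the stored value is unchanged

box['a'] = 1.0         # updates the stored dict directly
kept = box.value['a']  # the update is visible on the next read
```

The same reasoning is why the updated example writes m['a'] = 1.0 instead of going through m.value.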

nnx.Data and nnx.Static

The nnx.Data and nnx.Static annotations are now deprecated. To create nnx.Pytree or nnx.Module dataclasses, use the new nnx.dataclass decorator with nnx.data and nnx.static as field descriptors.

import dataclasses

# old (deprecated)
@dataclasses.dataclass
class Foo(nnx.Pytree):
  a: nnx.Data[int]
  b: nnx.Static[str]

# new
@nnx.dataclass
class Foo(nnx.Pytree):
  a: int = nnx.data()
  b: str = nnx.static()

Pull Requests

New Contributors

Full Changelog: v0.12.0...v0.12.1
