Linear#
NNX linear layer classes.
- class flax.nnx.Conv(self, in_features, out_features, kernel_size, strides=1, *, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=<function conv_general_dilated>, promote_dtype=<function promote_dtype>, preferred_element_type=None, rngs, kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))[source]#
Convolution Module wrapping lax.conv_general_dilated.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer.kernel.shape
(3, 3, 4)
>>> layer.bias.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 6, 4)

>>> # circular padding with stride 2
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3),
...                  strides=2, padding='CIRCULAR', rngs=rngs)
>>> layer.kernel.shape
(3, 3, 3, 4)
>>> layer.bias.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 4, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
- Parameters:
in_features – int or tuple with number of input features.
out_features – int or tuple with number of output features.
kernel_size – shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.
strides – an integer or a sequence of n integers, representing the inter-window strides (default: 1).
padding – either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), the string 'REFLECT' (reflection across the padding boundary), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and passing a single int in a sequence causes the same padding to be used on both sides. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.
input_dilation – an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
kernel_dilation – an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'.
feature_group_count – integer, default 1. If specified, divides the input features into groups (see the sketch after this parameter list).
use_bias – whether to add a bias to the output (default: True).
mask – Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
dtype – the dtype of the computation (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
precision – numerical precision of the computation; see jax.lax.Precision for details.
kernel_init – initializer for the convolutional kernel.
bias_init – initializer for the bias.
promote_dtype – function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
preferred_element_type – Optional parameter that controls the data type output by the convolution. This argument is passed to the conv_general_dilated function. See jax.lax.conv_general_dilated for details.
rngs – rng key.
kernel_metadata – Optional metadata dictionary to set when initializing the weight matrix.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
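A minimal sketch (not from the original examples) showing feature_group_count splitting the input features into groups and 'CAUSAL' padding preserving the spatial length; the printed kernel shape assumes the channels-last kernel layout used in the examples above:

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 6))
>>> # two feature groups: the kernel's input-feature axis holds 6 // 2 = 3 features
>>> layer = nnx.Conv(in_features=6, out_features=6, kernel_size=(3,),
...                  feature_group_count=2, rngs=rngs)
>>> layer.kernel.shape
(3, 3, 6)
>>> layer(x).shape
(1, 8, 6)
>>> # causal padding left-pads the convolution axis, so the output length matches the input
>>> layer = nnx.Conv(in_features=6, out_features=4, kernel_size=(3,),
...                  padding='CAUSAL', rngs=rngs)
>>> layer(x).shape
(1, 8, 4)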
- __call__(inputs)[source]#
Applies a (potentially unshared) convolution to the inputs.
- Parameters:
inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns:
The convolved data.
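A minimal sketch (not from the original examples), reusing the imports from the example above, of the batch-dimension handling just described:

>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=nnx.Rngs(0))
>>> layer(jnp.ones((8, 3))).shape        # no batch dimension: added and removed internally
(6, 4)
>>> layer(jnp.ones((2, 5, 8, 3))).shape  # two batch dimensions: flattened and restored
(2, 5, 6, 4)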
Methods
- class flax.nnx.ConvTranspose(self, in_features, out_features, kernel_size, strides=None, *, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, promote_dtype=<function promote_dtype>, preferred_element_type=None, rngs, kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))[source]#
Convolution Module wrapping lax.conv_transpose.
Note: The padding argument behaves differently from PyTorch; see the argument description below.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           padding='VALID', rngs=rngs)
>>> layer.kernel.shape
(3, 3, 4)
>>> layer.bias.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6),
...                           strides=(2, 2), padding='CIRCULAR',
...                           transpose_kernel=True, rngs=rngs)
>>> layer.kernel.shape
(6, 6, 4, 3)
>>> layer.bias.shape
(4,)
>>> out = layer(jnp.ones((1, 15, 15, 3)))
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
- Parameters:
in_features – int or tuple with number of input features.
out_features – int or tuple with number of output features.
kernel_size – shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.
strides – an integer or a sequence of n integers, representing the inter-window strides (default: 1).
padding – either a string indicating a specialized padding mode, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and a single int in a sequence causes the same padding to be used on both sides. Note that this behavior is different from PyTorch. In PyTorch, the padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to the input instead. This is set so that when torch.Conv2d and torch.ConvTranspose2d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. nnx.Conv and nnx.ConvTranspose do not have this behavior; if you want a nnx.ConvTranspose layer to invert the shape change produced by a nnx.Conv layer with a given padding and dilation, you should explicitly pass dilation * (kernel_size - 1) - padding as the padding argument to the nnx.ConvTranspose layer. Strings for specifying padding modes can be one of the following: 'VALID' adds dilation * (kernel_size - 1) padding to all dimensions, so that a nnx.Conv layer with 'VALID' padding produces the inverse shape transformation (see the sketch after this parameter list); 'SAME' pads the input so that the output shape is the same as the input shape; 'CIRCULAR' pads the input with periodic boundary conditions; 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.
kernel_dilation – an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'.
use_bias – whether to add a bias to the output (default: True).
mask – Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
dtype – the dtype of the computation (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
precision – numerical precision of the computation; see jax.lax.Precision for details.
kernel_init – initializer for the convolutional kernel.
bias_init – initializer for the bias.
transpose_kernel – if True, flips spatial axes and swaps the input/output channel axes of the kernel.
promote_dtype – function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
preferred_element_type – Optional parameter that controls the data type output by the transposed convolution. This argument is passed to the jax.lax.conv_transpose function. See jax.lax.conv_transpose for details.
rngs – rng key.
kernel_metadata – Optional metadata dictionary to set when initializing the weight matrix.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
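A minimal sketch (not from the original examples) of the 'VALID' inverse-shape behaviour noted above: a nnx.ConvTranspose with 'VALID' padding undoes the spatial shrinkage of a nnx.Conv with 'VALID' padding (shapes only; the values themselves are not inverted):

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))
>>> down = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                 padding='VALID', rngs=rngs)
>>> up = nnx.ConvTranspose(in_features=4, out_features=3, kernel_size=(3,),
...                        padding='VALID', rngs=rngs)
>>> h = down(x)
>>> h.shape
(1, 6, 4)
>>> up(h).shape
(1, 8, 3)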
- __call__(inputs)[source]#
Applies a transposed convolution to the inputs.
Behaviour mirrors that of jax.lax.conv_transpose.
- Parameters:
inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns:
The convolved data.
Methods
- class flax.nnx.Embed(self, num_embeddings, features, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, promote_dtype=<function promote_dtype>, rngs, embedding_metadata=mappingproxy({}))[source]#
Embedding Module.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'embedding': Param( # 15 (60 B)
    value=Array([[ 0.57966787, -0.523274  , -0.43195742],
           [-0.676289  , -0.50300646,  0.33996582],
           [ 0.41796115, -0.59212935,  0.95934135],
           [-1.0917838 , -0.7441663 ,  0.07713798],
           [-0.66570747,  0.13815777,  1.007365  ]], dtype=float32)
  )
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[ 0.57966787, -0.523274  , -0.43195742],
        [-0.676289  , -0.50300646,  0.33996582],
        [ 0.41796115, -0.59212935,  0.95934135]],

       [[-0.66570747,  0.13815777,  1.007365  ],
        [-1.0917838 , -0.7441663 ,  0.07713798],
        [ 0.41796115, -0.59212935,  0.95934135]]], dtype=float32)
A parameterized function from integers [0, num_embeddings) to features-dimensional vectors. This Module will create an embedding matrix with shape (num_embeddings, features). When calling this layer, the input values will be used to 0-index into the embedding matrix. Indexing on a value greater than or equal to num_embeddings will result in nan values. When num_embeddings equals 1, it will broadcast the embedding matrix to the input shape with the features dimension appended.
- Parameters:
num_embeddings – number of embeddings / vocab size.
features – number of feature dimensions for each embedding.
dtype – the dtype of the embedding vectors (default: same as embedding).
param_dtype – the dtype passed to parameter initializers (default: float32).
embedding_init – embedding initializer.
promote_dtype – function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (embedding,) during __call__ or (query, embedding) during attend, and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
rngs – rng key.
embedding_metadata – Optional metadata dictionary to set when initializing the embedding matrix.
- __call__(inputs)[source]#
Embeds the inputs along the last dimension.
- Parameters:
inputs – input data, all dimensions are considered batch dimensions. Values in the input array must be integers.
- Returns:
Output which is embedded input data. The output shape follows the input, with an additional features dimension appended.
- attend(query)[source]#
Attend over the embedding using a query array.
- Parameters:
query – array with last dimension equal to the feature depth features of the embedding.
- Returns:
An array with final dim num_embeddings corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models.
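A minimal sketch (not from the original examples), reusing the imports from the example above, of using attend to tie input embeddings to output logits:

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> tokens = jnp.array([[0, 1, 2]])
>>> hidden = layer(tokens)         # (1, 3, 3): embedded inputs
>>> logits = layer.attend(hidden)  # inner product against all 5 embeddings
>>> logits.shape
(1, 3, 5)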
Methods
attend(query) – Attend over the embedding using a query array.
- class flax.nnx.Linear(self, in_features, out_features, *, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=<function dot_general>, promote_dtype=<function promote_dtype>, preferred_element_type=None, rngs, kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))[source]#
A linear transformation applied over the last dimension of the input.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(4,)
  ),
  'kernel': Param(
    value=(3, 4)
  )
})
- Parameters:
in_features – the number of input features.
out_features – the number of output features.
use_bias – whether to add a bias to the output (default: True).
dtype – the dtype of the computation (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
precision – numerical precision of the computation; see jax.lax.Precision for details.
kernel_init – initializer function for the weight matrix.
bias_init – initializer function for the bias.
dot_general – dot product function.
promote_dtype – function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
preferred_element_type – Optional parameter that controls the data type output by the dot product. This argument is passed to the dot_general function. See jax.lax.dot for details.
rngs – rng key.
kernel_metadata – Optional metadata dictionary to set when initializing the weight matrix.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
- __call__(inputs)[source]#
Applies a linear transformation to the inputs along the last dimension.
- Parameters:
inputs – The nd-array to be transformed.
- Returns:
The transformed input.
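A minimal sketch (not from the original examples), reusing the layer and imports from the example above: only the last dimension is transformed, so leading batch dimensions pass through unchanged:

>>> y = layer(jnp.ones((2, 7, 3)))
>>> y.shape
(2, 7, 4)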
Methods
- class flax.nnx.LinearGeneral(self, in_features, out_features, *, axis=-1, batch_axis=FrozenDict({}), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, promote_dtype=<function promote_dtype>, dot_general=None, dot_general_cls=None, preferred_element_type=None, rngs, kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))[source]#
A linear transformation with flexible axes.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> # equivalent to `nnx.Linear(2, 4)`
>>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0))
>>> layer.kernel.shape
(2, 4)
>>> # output features (4, 5)
>>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0))
>>> layer.kernel.shape
(2, 4, 5)
>>> layer.bias.shape
(4, 5)
>>> # apply transformation on the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.shape
(2, 3, 4, 5)
>>> layer.bias.shape
(4, 5)
>>> y = layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)
- Parameters:
in_features – int or tuple with number of input features.
out_features – int or tuple with number of output features.
axis – int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.
batch_axis – mapping of batch axis indices to axis size.
use_bias – whether to add a bias to the output (default: True).
dtype – the dtype of the computation (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
kernel_init – initializer function for the weight matrix.
bias_init – initializer function for the bias.
precision – numerical precision of the computation; see jax.lax.Precision for details.
promote_dtype – function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
preferred_element_type – Optional parameter that controls the data type output by the dot product. This argument is passed to the dot_general function. See jax.lax.dot for details.
rngs – rng key.
kernel_metadata – Optional metadata dictionary to set when initializing the weight matrix.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
- __call__(inputs)[source]#
Applies a linear transformation to the inputs along multiple dimensions.
- Parameters:
inputs – The nd-array to be transformed.
- Returns:
The transformed input.
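A minimal sketch (not from the original examples), reusing the imports from the example above, of contracting the last two input axes into a single output axis via the axis argument:

>>> layer = nnx.LinearGeneral((2, 3), 4, axis=(-2, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.shape
(2, 3, 4)
>>> layer(jnp.ones((16, 2, 3))).shape
(16, 4)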
Methods
- class flax.nnx.Einsum(self, einsum_str, kernel_shape, bias_shape=None, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, promote_dtype=<function promote_dtype>, einsum_op=<function einsum>, preferred_element_type=None, rngs, kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))[source]#
An einsum transformation with learnable kernel and bias.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer.kernel.shape
(8, 2, 4)
>>> layer.bias.shape
(8, 4)
>>> y = layer(jnp.ones((16, 11, 2)))
>>> y.shape
(16, 11, 8, 4)
- Parameters:
einsum_str – a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of the constructor argument einsum_str and the __call__ argument einsum_str must be not None, while the other must be None.
kernel_shape – the shape of the kernel.
bias_shape – the shape of the bias. If this is None, a bias won’t be used.
dtype – the dtype of the computation (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
precision – numerical precision of the computation; see jax.lax.Precision for details.
kernel_init – initializer function for the weight matrix.
bias_init – initializer function for the bias.
promote_dtype – function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
einsum_op – an injectable alternative to jnp.einsum to do the computation. Should support the same signature as jnp.einsum.
preferred_element_type – Optional parameter that controls the data type output by the dot product. This argument is passed to the dot_general function. See jax.lax.dot for details.
rngs – rng key.
kernel_metadata – Optional metadata dictionary to set when initializing the weight matrix.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
- __call__(inputs, einsum_str=None)[source]#
Applies a linear transformation to the inputs along the last dimension.
- Parameters:
inputs – The nd-array to be transformed.
einsum_str – a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of the constructor argument einsum_str and this call argument must be not None, while the other must be None.
- Returns:
The transformed input.
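A minimal sketch (not from the original examples), reusing the imports from the example above and assuming, as the parameter description states, that einsum_str may be left as None in the constructor and supplied at call time instead:

>>> layer = nnx.Einsum(None, (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> y = layer(jnp.ones((16, 11, 2)), einsum_str='nta,hab->nthb')
>>> y.shape
(16, 11, 8, 4)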
Methods