Keras Layer that wraps a JAX model.
Inherits From: Layer, Operation
```python
tf.keras.layers.JaxLayer(
    call_fn, init_fn=None, params=None, state=None, seed=None, **kwargs
)
```
This layer enables the use of JAX components within Keras when using JAX as the backend for Keras.
Model function
This layer accepts JAX models in the form of a function, call_fn, which
must take the following arguments with these exact names:
- params: trainable parameters of the model.
- state (optional): non-trainable state of the model. Can be omitted if the model has no non-trainable state.
- rng (optional): a jax.random.PRNGKey instance. Can be omitted if the model does not need RNGs, neither during training nor during inference.
- inputs: inputs to the model, a JAX array or a PyTree of arrays.
- training (optional): an argument specifying if we're in training mode or inference mode; True is passed in training mode. Can be omitted if the model behaves the same in training mode and inference mode.
The inputs argument is mandatory. Inputs to the model must be provided via
a single argument. If the JAX model takes multiple inputs as separate
arguments, they must be combined into a single structure, for instance in a
tuple or a dict.
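For illustration, here is a minimal call_fn that uses only the required params and inputs arguments; the dense-style parameters and their shapes are hypothetical, not part of the API:

```python
import jax.numpy as jnp
from keras.layers import JaxLayer  # requires Keras running on the JAX backend

def my_call_fn(params, inputs):
    # A hypothetical dense computation: params holds "w" and "b".
    return jnp.dot(inputs, params["w"]) + params["b"]

my_dense_params = {
    "w": jnp.zeros((4, 2)),
    "b": jnp.zeros((2,)),
}
layer = JaxLayer(my_call_fn, params=my_dense_params)
```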
Model weights initialization
The initialization of the params and state of the model can be handled
by this layer, in which case the init_fn argument must be provided. This
allows the model to be initialized dynamically with the right shape.
Alternatively, and if the shape is known, the params argument and
optionally the state argument can be used to create an already initialized
model.
The init_fn function, if provided, must take the following arguments with
these exact names:
- rng: a jax.random.PRNGKey instance.
- inputs: a JAX array or a PyTree of arrays with placeholder values to provide the shape of the inputs.
- training (optional): an argument specifying if we're in training mode or inference mode. True is always passed to init_fn. Can be omitted regardless of whether call_fn has a training argument.
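As a sketch of the two initialization paths, an init_fn matching the hypothetical my_call_fn above can derive the parameter shapes from the placeholder inputs:

```python
import jax
import jax.numpy as jnp

def my_init_fn(rng, inputs):
    # Derive parameter shapes from the placeholder inputs.
    in_dim = inputs.shape[-1]
    return {
        "w": jax.random.normal(rng, (in_dim, 2)),
        "b": jnp.zeros((2,)),
    }

# Option 1: dynamic initialization, the layer calls init_fn itself.
layer = JaxLayer(my_call_fn, init_fn=my_init_fn)

# Option 2: the shapes are known, so pass already initialized params.
params = my_init_fn(jax.random.PRNGKey(0), jnp.zeros((1, 4)))
layer = JaxLayer(my_call_fn, params=params)
```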
Models with non-trainable state
For JAX models that have non-trainable state:
- call_fn must have a state argument.
- call_fn must return a tuple containing the outputs of the model and the new non-trainable state of the model.
- init_fn must return a tuple containing the initial trainable params of the model and the initial non-trainable state of the model.
This code shows a possible combination of call_fn and init_fn signatures
for a model with non-trainable state. In this example, the model has a
training argument and an rng argument in call_fn.
```python
def stateful_call(params, state, rng, inputs, training):
    outputs = ...
    new_state = ...
    return outputs, new_state

def stateful_init(rng, inputs):
    initial_params = ...
    initial_state = ...
    return initial_params, initial_state
```
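Given these signatures, the corresponding layer can be created with both functions (a sketch; the bodies above are placeholders):

```python
keras_layer = JaxLayer(call_fn=stateful_call, init_fn=stateful_init)
```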
Models without non-trainable state
For JAX models with no non-trainable state:
- call_fn must not have a state argument.
- call_fn must return only the outputs of the model.
- init_fn must return only the initial trainable params of the model.
This code shows a possible combination of call_fn and init_fn signatures
for a model without non-trainable state. In this example, the model does not
have a training argument and does not have an rng argument in call_fn.
```python
def stateless_call(params, inputs):
    outputs = ...
    return outputs

def stateless_init(rng, inputs):
    initial_params = ...
    return initial_params
```
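As a concrete, if simplistic, sketch of these stateless signatures, here is a hypothetical per-feature scaling model wrapped in a JaxLayer:

```python
import jax.numpy as jnp

def scale_call(params, inputs):
    return inputs * params["scale"]

def scale_init(rng, inputs):
    # One scale factor per feature, derived from the placeholder shape.
    return {"scale": jnp.ones(inputs.shape[-1:])}

keras_layer = JaxLayer(call_fn=scale_call, init_fn=scale_init)
```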
Conforming to the required signature
If a model has a different signature than the one required by JaxLayer,
one can easily write a wrapper method to adapt the arguments. This example
shows a model that has multiple inputs as separate arguments, expects
multiple RNGs in a dict, and has a deterministic argument with the
opposite meaning of training. To conform, the inputs are combined in a
single structure using a tuple, the RNG is split and used to populate the
expected dict, and the Boolean flag is negated:
```python
def my_model_fn(params, rngs, input1, input2, deterministic):
    ...
    if not deterministic:
        dropout_rng = rngs["dropout"]
        keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape)
        x = jax.numpy.where(keep, x / dropout_rate, 0)
        ...
    ...
    return outputs

def my_model_wrapper_fn(params, rng, inputs, training):
    input1, input2 = inputs
    rng1, rng2 = jax.random.split(rng)
    rngs = {"dropout": rng1, "preprocessing": rng2}
    deterministic = not training
    return my_model_fn(params, rngs, input1, input2, deterministic)

keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params)
```
Usage with Haiku modules
JaxLayer enables the use of Haiku
components in the form of
haiku.Module.
This is achieved by transforming the module per the Haiku pattern and then
passing module.apply in the call_fn parameter and module.init in the
init_fn parameter if needed.
If the model has non-trainable state, it should be transformed with
haiku.transform_with_state.
If the model has no non-trainable state, it should be transformed with
haiku.transform.
Additionally, and optionally, if the module does not use RNGs in "apply", it
can be transformed with
haiku.without_apply_rng.
The following example shows how to create a JaxLayer from a Haiku module
that uses random number generators via hk.next_rng_key() and takes a
training positional argument:
```python
class MyHaikuModule(hk.Module):

    def __call__(self, x, training):
        x = hk.Conv2D(32, (3, 3))(x)
        x = jax.nn.relu(x)
        x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x)
        x = hk.Flatten()(x)
        x = hk.Linear(200)(x)
        if training:
            x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x)
        x = jax.nn.relu(x)
        x = hk.Linear(10)(x)
        x = jax.nn.softmax(x)
        return x

def my_haiku_module_fn(inputs, training):
    module = MyHaikuModule()
    return module(inputs, training)

transformed_module = hk.transform(my_haiku_module_fn)

keras_layer = JaxLayer(
    call_fn=transformed_module.apply,
    init_fn=transformed_module.init,
)
```
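For a module with non-trainable state, the same pattern applies with haiku.transform_with_state. The sketch below assumes a hypothetical module using hk.BatchNorm, whose moving averages are non-trainable state; the transformed apply and init then return the tuples that the stateful call_fn and init_fn contracts require:

```python
def my_stateful_module_fn(inputs, training):
    x = hk.Linear(64)(inputs)
    # hk.BatchNorm keeps its moving averages as non-trainable state.
    x = hk.BatchNorm(
        create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    return jax.nn.relu(x)

transformed_module = hk.transform_with_state(my_stateful_module_fn)

keras_layer = JaxLayer(
    call_fn=transformed_module.apply,  # returns (outputs, new_state)
    init_fn=transformed_module.init,   # returns (params, state)
)
```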
Methods
from_config
```python
@classmethod
from_config(config)
```
Creates a layer from its config.
This method is the reverse of get_config,
capable of instantiating the same layer from the config
dictionary. It does not handle layer connectivity
(handled by Network), nor weights (handled by set_weights).
| Args | |
|---|---|
| config | A Python dictionary, typically the output of get_config. |

| Returns | |
|---|---|
| A layer instance. |
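A minimal round-trip sketch, assuming the layer's config, including its functions, is serializable in your setup:

```python
config = keras_layer.get_config()
clone = JaxLayer.from_config(config)
```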
symbolic_call

```python
symbolic_call(*args, **kwargs)
```