Decorator that overrides the default implementation for a TensorFlow API.
```python
tf.experimental.dispatch_for_api(
    api, *signatures
)
```
The decorated function (known as the "dispatch target") will override the
default implementation for the API when the API is called with parameters that
match a specified type signature. Signatures are specified using dictionaries
that map parameter names to type annotations. E.g., in the following example,
masked_add will be called for tf.add if both x and y are
MaskedTensors:
```python
import tensorflow as tf

class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor

@tf.experimental.dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
def masked_add(x, y, name=None):
    # Add the values elementwise; an element stays valid only if both inputs are valid.
    return MaskedTensor(x.values + y.values, x.mask & y.mask)

mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
# values=[11 12], mask=[ True False]
```
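The matching rule described above can be illustrated with a toy, pure-Python model of signature-based dispatch. This is a minimal sketch under stated assumptions, not TensorFlow's implementation: `toy_dispatchable`, `toy_dispatch_for`, `_REGISTRY`, `add`, and the list-based `MaskedTensor` stand-in are all hypothetical, and only plain type objects (checked with `isinstance`) are supported as annotations.

```python
import functools
import inspect

# Hypothetical registry for this toy model: maps each dispatchable API to a
# list of (signatures, target) pairs. This is NOT TensorFlow's internal structure.
_REGISTRY = {}


def toy_dispatchable(api):
    """Hypothetical helper: wraps `api` so each call checks _REGISTRY first."""
    api_sig = inspect.signature(api)

    @functools.wraps(api)
    def wrapper(*args, **kwargs):
        bound = api_sig.bind(*args, **kwargs)
        for signatures, target in _REGISTRY.get(wrapper, []):
            for sig in signatures:
                # A signature matches when every annotated parameter holds an
                # instance of its annotated type.
                if all(isinstance(bound.arguments.get(name), typ)
                       for name, typ in sig.items()):
                    return target(*args, **kwargs)
        # No signature matched: fall back to the default implementation.
        return api(*args, **kwargs)

    return wrapper


def toy_dispatch_for(api, *signatures):
    """Hypothetical analogue of dispatch_for_api: registers the decorated target."""
    def decorator(target):
        _REGISTRY.setdefault(api, []).append((signatures, target))
        return target
    return decorator


class MaskedTensor:
    """Hypothetical stand-in that stores values and mask as Python lists."""
    def __init__(self, values, mask):
        self.values, self.mask = values, mask


@toy_dispatchable
def add(x, y, name=None):
    return x + y


@toy_dispatch_for(add, {'x': MaskedTensor, 'y': MaskedTensor})
def masked_add(x, y, name=None):
    return MaskedTensor([a + b for a, b in zip(x.values, y.values)],
                        [m and n for m, n in zip(x.mask, y.mask)])


print(add(1, 2))  # 3 (default implementation)
mt = add(MaskedTensor([1, 2], [True, False]),
         MaskedTensor([10, 10], [True, True]))
print(mt.values, mt.mask)  # [11, 12] [True, False]
```

In this sketch the registry lives outside the wrapped API, so targets can be registered after the API is defined, and the default implementation runs whenever no signature matches.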
If multiple type signatures are specified, then the dispatch target will be
called if any of the signatures match. For example, the following code
registers masked_add to be called if x is a MaskedTensor or y is
a MaskedTensor.
```python
# Assumes MaskedTensor is defined as in the previous example.
@tf.experimental.dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y': MaskedTensor})
def masked_add(x, y):
    # Treat a plain (non-masked) operand as fully valid.
    x_values = x.values if isinstance(x, MaskedTensor) else x
    x_mask = x.mask if isinstance(x, MaskedTensor) else True
    y_values = y.values if isinstance(y, MaskedTensor) else y
    y_mask = y.mask if isinstance(y, MaskedTensor) else True
    return MaskedTensor(x_values + y_values, x_mask & y_mask)
```
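The "any signature matches" rule can be modeled in a few lines of plain Python. This is a hypothetical sketch rather than TensorFlow's code: `matches_one`, `matches_any`, and the stand-in `MaskedTensor` class are illustrative names, and the call's arguments are assumed to be already bound into a `dict` of parameter name to value.

```python
class MaskedTensor:
    """Hypothetical stand-in for the MaskedTensor ExtensionType."""
    def __init__(self, values, mask):
        self.values, self.mask = values, mask


def matches_one(signature, arguments):
    """True if every annotated parameter holds an instance of its type."""
    return all(isinstance(arguments.get(name), annotation)
               for name, annotation in signature.items())


def matches_any(signatures, arguments):
    """A dispatch target with several signatures fires if any one matches."""
    return any(matches_one(sig, arguments) for sig in signatures)


signatures = ({'x': MaskedTensor}, {'y': MaskedTensor})
mt = MaskedTensor([1, 2], [True, False])

print(matches_any(signatures, {'x': mt, 'y': 3}))  # True: first signature
print(matches_any(signatures, {'x': 3, 'y': mt}))  # True: second signature
print(matches_any(signatures, {'x': 3, 'y': 4}))   # False: neither matches
```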
The type annotations in type signatures may be type objects (e.g.,
MaskedTensor), typing.List values, or typing.Union values. For
example, the following will register masked_concat to be called if values
is a list of MaskedTensor values:
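```python
import typing

@tf.experimental.dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
def masked_concat(values, axis):
    # Concatenate the values and the masks separately along the same axis.
    return MaskedTensor(tf.concat([v.values for v in values], axis),
                        tf.concat([v.mask for v in values], axis))
```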
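How a checker might interpret type objects, `typing.List`, and `typing.Union` annotations can be sketched with `typing.get_origin` and `typing.get_args` (Python 3.8+). This is an illustrative sketch only: `value_matches` and the empty `MaskedTensor` stand-in are hypothetical, and the assumption that `List[T]` accepts only Python `list` objects (not tuples) is this sketch's, not a statement about TensorFlow.

```python
import typing


class MaskedTensor:
    """Hypothetical stand-in for the MaskedTensor ExtensionType."""


def value_matches(value, annotation):
    """Checks a value against a type object, typing.List[...] or typing.Union[...]."""
    origin = typing.get_origin(annotation)
    if origin is typing.Union:
        # A Union matches if the value matches any of its members.
        return any(value_matches(value, arg) for arg in typing.get_args(annotation))
    if origin is list:
        # A List[T] matches a list whose elements all match T.
        (elem_type,) = typing.get_args(annotation)
        return isinstance(value, list) and all(
            value_matches(v, elem_type) for v in value)
    # Plain type object: an ordinary isinstance check.
    return isinstance(value, annotation)


ann = typing.List[MaskedTensor]
print(value_matches([MaskedTensor(), MaskedTensor()], ann))   # True
print(value_matches([MaskedTensor(), 1.0], ann))              # False
print(value_matches([], ann))                                 # True: empty list matches trivially
print(value_matches(1.0, typing.Union[MaskedTensor, float]))  # True
```

The empty-list case shows why a type check alone is not enough to decide whether to dispatch.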
Each type signature must contain at least one subclass of tf.CompositeTensor
(which includes subclasses of tf.ExtensionType), and dispatch will only be
triggered if at least one type-annotated parameter contains a
CompositeTensor value. This rule avoids invoking dispatch in degenerate
cases, such as the following examples:
- `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: will not dispatch to the decorated dispatch target when the user calls `tf.concat([])`.
- `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y': Union[MaskedTensor, Tensor]})`: will not dispatch to the decorated dispatch target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.
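The two degenerate cases above can be reproduced with a toy model that combines a type check with the "at least one CompositeTensor value" rule. This is a hypothetical sketch, not TensorFlow's implementation: `CompositeTensor`, `MaskedTensor`, and `Tensor` are plain-Python stand-ins for the real classes, and `value_matches`, `contains_composite`, and `should_dispatch` are illustrative helpers.

```python
import typing


class CompositeTensor:
    """Hypothetical stand-in for tf.CompositeTensor."""


class MaskedTensor(CompositeTensor):
    """Hypothetical stand-in for a tf.ExtensionType subclass."""


class Tensor:
    """Hypothetical stand-in for a dense tf.Tensor (not composite)."""


def value_matches(value, annotation):
    """Type check supporting type objects, List[...] and Union[...]."""
    origin = typing.get_origin(annotation)
    if origin is typing.Union:
        return any(value_matches(value, a) for a in typing.get_args(annotation))
    if origin is list:
        (elem,) = typing.get_args(annotation)
        return isinstance(value, list) and all(value_matches(v, elem) for v in value)
    return isinstance(value, annotation)


def contains_composite(value):
    """True if value is, or is a list containing, a CompositeTensor."""
    if isinstance(value, list):
        return any(contains_composite(v) for v in value)
    return isinstance(value, CompositeTensor)


def should_dispatch(signature, arguments):
    """Types must match AND at least one annotated argument holds a composite."""
    if not all(value_matches(arguments.get(n), t) for n, t in signature.items()):
        return False
    return any(contains_composite(arguments.get(n)) for n in signature)


concat_sig = {'values': typing.List[MaskedTensor]}
print(should_dispatch(concat_sig, {'values': []}))                # False
print(should_dispatch(concat_sig, {'values': [MaskedTensor()]}))  # True

add_sig = {'x': typing.Union[MaskedTensor, Tensor],
           'y': typing.Union[MaskedTensor, Tensor]}
print(should_dispatch(add_sig, {'x': Tensor(), 'y': Tensor()}))        # False
print(should_dispatch(add_sig, {'x': MaskedTensor(), 'y': Tensor()}))  # True
```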
The dispatch target's signature must match the signature of the API that is
being overridden. In particular, parameters must have the same names, and
must occur in the same order. The dispatch target may optionally elide the
"name" parameter, in which case it will be wrapped with a call to
tf.name_scope when appropriate.
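The parameter-matching requirement can be approximated with `inspect.signature`. This is a simplified sketch under stated assumptions: `check_target_signature` is a hypothetical helper, the `concat` function merely mimics the parameter list of `tf.concat` (`values, axis, name`), and TensorFlow's real check may also compare details this sketch ignores, such as parameter kinds or defaults. It only compares parameter names and their order, allowing the target to omit `name`.

```python
import inspect


def check_target_signature(api, target):
    """Raises ValueError unless target's parameters match api's, in order.

    The target may omit a 'name' parameter that the API declares.
    """
    api_params = list(inspect.signature(api).parameters)
    target_params = list(inspect.signature(target).parameters)
    if 'name' in api_params and 'name' not in target_params:
        api_params.remove('name')
    if api_params != target_params:
        raise ValueError(
            f"Dispatch target parameters {target_params} do not match "
            f"API parameters {api_params}")


def concat(values, axis, name='concat'):  # hypothetical API mimicking tf.concat
    ...


def masked_concat(values, axis):  # elides 'name': accepted
    ...


def bad_concat(axis, values):  # wrong parameter order: rejected
    ...


check_target_signature(concat, masked_concat)  # passes silently
try:
    check_target_signature(concat, bad_concat)
except ValueError as e:
    print(e)
```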
| Returns |
|---|
| A decorator that overrides the default implementation for `api`. |
Registered APIs
The TensorFlow APIs that may be overridden by @dispatch_for_api are: