An ExtensionType that can be batched and unbatched.
Inherits From: ExtensionType
tf.experimental.BatchableExtensionType(
    *args, **kwargs
)
BatchableExtensionTypes can be used with APIs that require batching or
unbatching, including Keras, tf.data.Dataset, and tf.map_fn. E.g.:
class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

batch = Vehicle([120, 150, 80], [30, 40, 12])
tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
          fn_output_signature=tf.int32).numpy()
array([3600, 6000, 960], dtype=int32)
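The same batching support carries over to tf.data. The following is a small self-contained sketch (not part of the example above, reusing the hypothetical Vehicle type): tf.data.Dataset.from_tensor_slices unbatches an extension type value into per-element values, and Dataset.batch stacks them back together.

import tensorflow as tf

class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

# A "batch" of three vehicles, stored as a single extension type value.
batch = Vehicle([120, 150, 80], [30, 40, 12])

# from_tensor_slices unbatches the value; Dataset.batch re-batches it.
dataset = tf.data.Dataset.from_tensor_slices(batch).batch(2)
for vehicles in dataset:
  print(vehicles.top_speed)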
An ExtensionTypeBatchEncoder is used by these APIs to encode ExtensionType
values. The default encoder assumes that values can be stacked, unstacked, or
concatenated by simply stacking, unstacking, or concatenating every nested
Tensor, ExtensionType, CompositeTensor, or TensorShape field.
Extension types where this is not the case will need to override
__batch_encoder__ with a custom ExtensionTypeBatchEncoder. See
tf.experimental.ExtensionTypeBatchEncoder for more details.
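A custom encoder is attached by assigning an ExtensionTypeBatchEncoder subclass instance to the __batch_encoder__ class attribute. The skeleton below is a minimal sketch, not a working custom encoder: the ShiftedValue and ShiftedEncoder names are hypothetical, and every method simply delegates to the default encoder while its comment describes what a real override would be responsible for.

import tensorflow as tf

class ShiftedEncoder(tf.experimental.ExtensionTypeBatchEncoder):
  # Sketch only: each override delegates to the default field-by-field
  # encoder and marks where custom layout logic would go.

  def encode(self, spec, value, minimum_rank=0):
    # Convert `value` into a nest of batchable Tensors / CompositeTensors.
    return super().encode(spec, value, minimum_rank)

  def decode(self, spec, encoded_value):
    # Invert `encode`: rebuild the extension type value from its encoding.
    return super().decode(spec, encoded_value)

  def encoding_specs(self, spec):
    # TypeSpecs describing the tensors that `encode` returns for `spec`.
    return super().encoding_specs(spec)

  def batch(self, spec, batch_size):
    # TypeSpec describing a batch of `batch_size` values matching `spec`.
    return super().batch(spec, batch_size)

  def unbatch(self, spec):
    # TypeSpec describing a single element unbatched from `spec`.
    return super().unbatch(spec)

class ShiftedValue(tf.experimental.BatchableExtensionType):
  __batch_encoder__ = ShiftedEncoder()
  values: tf.Tensor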
Methods
__eq__
__eq__(
    other
)
Return self==value.
__ne__
__ne__(
    other
)
Return self!=value.