Distribution for data parallelism.
tf.keras.distribution.DataParallel(
device_mesh=None, devices=None
)
Create this instance by specifying either the device_mesh argument or the devices argument, but not both.
The device_mesh argument is expected to be a 1D DeviceMesh instance. If the mesh has multiple axes, the first axis is treated as the data-parallel dimension and a warning is raised.
When a list of devices is provided, it is used to construct a 1D mesh.
When neither device_mesh nor devices is provided, list_devices() is used to detect the available devices and create a 1D mesh from them.
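The constructor rules above can be modeled in plain Python. This is a hypothetical sketch, not the Keras source: SimpleMesh, detect_devices, resolve_data_parallel_mesh, and the default axis name "batch" are illustrative assumptions, and detect_devices merely stands in for list_devices().

```python
import warnings
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class SimpleMesh:
    """Hypothetical stand-in for a DeviceMesh: shape, axis names, and devices."""
    shape: Tuple[int, ...]
    axis_names: Tuple[str, ...]
    devices: Tuple[str, ...]


def detect_devices() -> Tuple[str, ...]:
    """Hypothetical placeholder for list_devices(); pretends two CPUs exist."""
    return ("cpu:0", "cpu:1")


def resolve_data_parallel_mesh(
    device_mesh: Optional[SimpleMesh] = None,
    devices: Optional[Sequence[str]] = None,
) -> Tuple[SimpleMesh, str]:
    """Return the mesh to use and the name of its data-parallel (batch) axis."""
    if device_mesh is not None and devices is not None:
        raise ValueError("Specify either device_mesh or devices, not both.")
    if device_mesh is not None:
        if len(device_mesh.shape) > 1:
            # Multi-axis mesh: the first axis is treated as the data-parallel axis.
            warnings.warn("Mesh has multiple axes; using the first as the batch axis.")
        return device_mesh, device_mesh.axis_names[0]
    if devices is None:
        # Neither argument given: auto-detect the available devices.
        devices = detect_devices()
    devices = tuple(devices)
    # A flat device list becomes a 1D mesh with one axis (name assumed here).
    mesh = SimpleMesh(shape=(len(devices),), axis_names=("batch",), devices=devices)
    return mesh, "batch"
```

For example, passing four GPU names yields a mesh of shape (4,) whose single axis is the data-parallel one, while passing a 2x2 mesh keeps the mesh as-is, warns, and uses its first axis.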
| Args | |
|---|---|
| device_mesh | Optional DeviceMesh instance. |
| devices | Optional list of devices. |
| Attributes | |
|---|---|
| device_mesh | The DeviceMesh used by this distribution. |
Methods
distribute_dataset
distribute_dataset(
dataset
)
Create a distributed dataset instance from the original user dataset.
| Args | |
|---|---|
| dataset | The original global dataset instance. Only tf.data.Dataset is supported at the moment. |
| Returns |
|---|
| A sharded tf.data.Dataset instance, which will produce data for the current local worker/process. |
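Conceptually, each process should see only its share of every global batch. The sketch below assumes an even, contiguous split of each global batch across processes; local_slice_of_global_batch is a hypothetical helper on plain Python lists and does not reflect how distribute_dataset is actually implemented on tf.data.Dataset.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def local_slice_of_global_batch(
    global_batch: Sequence[T], num_processes: int, process_id: int
) -> List[T]:
    """Return the contiguous slice of a global batch owned by one process.

    Hypothetical helper: models an even split of each global batch across
    processes. It is not the Keras implementation of distribute_dataset.
    """
    if not 0 <= process_id < num_processes:
        raise ValueError("process_id must be in [0, num_processes).")
    if len(global_batch) % num_processes != 0:
        raise ValueError("Global batch size must be divisible by num_processes.")
    per_process = len(global_batch) // num_processes
    start = process_id * per_process
    # Each process keeps only its own block of consecutive examples.
    return list(global_batch[start:start + per_process])
```

With a global batch of 8 examples and 2 processes, process 0 would receive examples 0 to 3 and process 1 examples 4 to 7. The divisibility check encodes the assumption that the global batch splits evenly.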
get_data_layout
get_data_layout(
data_shape
)
Retrieve the TensorLayout for the input data.
| Args | |
|---|---|
| data_shape | Shape for the input data, in list or tuple format. |
| Returns |
|---|
| The TensorLayout for the data, which can be used by backend.distribute_value() to redistribute input data. |
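For data parallelism, a natural data layout splits the first (batch) dimension across the mesh's data-parallel axis and replicates every other dimension. The sketch below assumes exactly that; data_layout_axes and per_device_shape are hypothetical helpers that model the shard spec as a tuple of axis names (None meaning replicated), not the TensorLayout class itself.

```python
from typing import Optional, Sequence, Tuple


def data_layout_axes(
    data_shape: Sequence[int], batch_axis: str
) -> Tuple[Optional[str], ...]:
    """Shard spec for input data: dim 0 maps to the batch axis, the rest are replicated."""
    if len(data_shape) == 0:
        raise ValueError("Input data must have at least a batch dimension.")
    return (batch_axis,) + (None,) * (len(data_shape) - 1)


def per_device_shape(data_shape: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Shape each device holds when dim 0 is split evenly over num_devices."""
    batch = data_shape[0]
    if batch % num_devices != 0:
        raise ValueError("Batch size must be divisible by the number of devices.")
    return (batch // num_devices,) + tuple(data_shape[1:])
```

Under these assumptions, a batch of shape (32, 28, 28, 1) on a 4-device mesh gets the spec ("batch", None, None, None), and each device holds an (8, 28, 28, 1) slice.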
get_tensor_layout
get_tensor_layout(
path
)
Retrieve the TensorLayout for the intermediate tensor.
| Args | |
|---|---|
| path | A string path for the corresponding tensor. |
| Returns |
|---|
| The TensorLayout for the intermediate tensor, which can be used by backend.relayout() to reshard the tensor. May also return None. |
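Because get_tensor_layout may return None, a caller has to treat None as "leave the tensor unchanged". A minimal sketch of that contract, assuming a hypothetical maybe_relayout helper whose relayout argument stands in for backend.relayout():

```python
from typing import Callable, Optional, TypeVar

TensorT = TypeVar("TensorT")
LayoutT = TypeVar("LayoutT")


def maybe_relayout(
    tensor: TensorT,
    layout: Optional[LayoutT],
    relayout: Callable[[TensorT, LayoutT], TensorT],
) -> TensorT:
    """Reshard only when a layout was returned; None means keep the tensor as-is."""
    if layout is None:
        return tensor
    return relayout(tensor, layout)
```

Keeping the None check at the call site means a distribution that has no opinion about an intermediate tensor costs nothing extra.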
get_variable_layout
get_variable_layout(
variable
)
Retrieve the TensorLayout for the variable.
| Args | |
|---|---|
| variable | A KerasVariable instance. |
| Returns |
|---|
| The TensorLayout for the variable, which can be used by backend.distribute_value() to redistribute a variable. |
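Data parallelism conventionally keeps a full copy of every weight on each device, so a plausible variable layout shards no dimension at all. This page does not state that explicitly, so treat the sketch as an assumption; replicated_layout_axes is a hypothetical helper using a tuple-of-axis-names convention in which None means replicated.

```python
from typing import Optional, Sequence, Tuple


def replicated_layout_axes(variable_shape: Sequence[int]) -> Tuple[Optional[str], ...]:
    """Shard spec with every dimension replicated: each device holds a full copy."""
    return (None,) * len(variable_shape)
```

For a (784, 128) kernel this gives (None, None), meaning every device stores all 784 x 128 values; only the input batch is split.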
scope
@contextlib.contextmanager
scope()
Context manager to make the Distribution current.
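The "make current" behavior of scope() can be sketched with contextlib.contextmanager: set a current distribution on entry and restore the previous one on exit, even if the body raises. The global slot, current_distribution, and distribution_scope below are hypothetical stand-ins, not Keras internals.

```python
import contextlib
from typing import Iterator, Optional

# Hypothetical global slot holding the "current" distribution.
_current_distribution: Optional[object] = None


def current_distribution() -> Optional[object]:
    """Return whichever distribution is current, or None."""
    return _current_distribution


@contextlib.contextmanager
def distribution_scope(distribution: object) -> Iterator[object]:
    """Make `distribution` current inside the block, then restore the previous one."""
    global _current_distribution
    previous = _current_distribution
    _current_distribution = distribution
    try:
        yield distribution
    finally:
        # Restoring (rather than clearing) lets scopes nest correctly.
        _current_distribution = previous
```

With the real class, the corresponding pattern is to build the model inside `with data_parallel.scope():`, where data_parallel is a DataParallel instance.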