{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "MhoQ0WE77laV"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2024-08-16T09:05:47.809834Z",
"iopub.status.busy": "2024-08-16T09:05:47.809423Z",
"iopub.status.idle": "2024-08-16T09:05:47.813099Z",
"shell.execute_reply": "2024-08-16T09:05:47.812505Z"
},
"id": "_ckMIh7O7s6D"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jYysdyb-CaWM"
},
"source": [
"# Distributed Input"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FbVhjPpzn6BM"
},
"source": [
"The [tf.distribute](https://www.tensorflow.org/guide/distributed_training) APIs provide an easy way for users to scale their training from a single machine to multiple machines. When scaling their model, users also have to distribute their input across multiple devices. `tf.distribute` provides APIs using which you can automatically distribute your input across devices.\n",
"\n",
"This guide will show you the different ways in which you can create distributed dataset and iterators using `tf.distribute` APIs. Additionally, the following topics will be covered:\n",
"- Usage, sharding and batching options when using `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function`.\n",
"- Different ways in which you can iterate over the distributed dataset.\n",
"- Differences between `tf.distribute.Strategy.experimental_distribute_dataset`/`tf.distribute.Strategy.distribute_datasets_from_function` APIs and `tf.data` APIs as well as any limitations that users may come across in their usage.\n",
"\n",
"This guide does not cover usage of distributed input with Keras APIs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MM6W__qraV55"
},
"source": [
"## Distributed datasets"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lNy9GxjSlMKQ"
},
"source": [
"To use `tf.distribute` APIs to scale, use `tf.data.Dataset` to represent their input. `tf.distribute` works efficiently with `tf.data.Dataset`—for example, via automatic prefetching onto each accelerator device and regular performance updates. If you have a use case for using something other than `tf.data.Dataset`, please refer to the [Tensor inputs section](#tensorinputs) in this guide.\n",
"In a non-distributed training loop, first create a `tf.data.Dataset` instance and then iterate over the elements. For example:\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:47.816839Z",
"iopub.status.busy": "2024-08-16T09:05:47.816404Z",
"iopub.status.idle": "2024-08-16T09:05:50.183211Z",
"shell.execute_reply": "2024-08-16T09:05:50.182525Z"
},
"id": "pCu2Jj-21AEf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.17.0\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"# Helper libraries\n",
"import numpy as np\n",
"import os\n",
"\n",
"print(tf.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:50.186466Z",
"iopub.status.busy": "2024-08-16T09:05:50.186068Z",
"iopub.status.idle": "2024-08-16T09:05:50.698171Z",
"shell.execute_reply": "2024-08-16T09:05:50.697321Z"
},
"id": "6cnilUtmKwpa"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1723799150.647793 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.651622 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.654832 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.658024 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.669227 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.672853 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.675654 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.678590 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.681440 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.685030 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723799150.687948 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0000 00:00:1723799150.690847 616320 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n"
]
}
],
"source": [
"# Simulate multiple CPUs with virtual devices\n",
"N_VIRTUAL_DEVICES = 2\n",
"physical_devices = tf.config.list_physical_devices(\"CPU\")\n",
"tf.config.set_logical_device_configuration(\n",
" physical_devices[0], [tf.config.LogicalDeviceConfiguration() for _ in range(N_VIRTUAL_DEVICES)])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:50.702696Z",
"iopub.status.busy": "2024-08-16T09:05:50.702446Z",
"iopub.status.idle": "2024-08-16T09:05:52.021674Z",
"shell.execute_reply": "2024-08-16T09:05:52.020959Z"
},
"id": "zd4l1ySeLRk1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available devices:\n",
"0) LogicalDevice(name='/device:CPU:0', device_type='CPU')\n",
"1) LogicalDevice(name='/device:CPU:1', device_type='CPU')\n",
"2) LogicalDevice(name='/device:GPU:0', device_type='GPU')\n",
"3) LogicalDevice(name='/device:GPU:1', device_type='GPU')\n",
"4) LogicalDevice(name='/device:GPU:2', device_type='GPU')\n",
"5) LogicalDevice(name='/device:GPU:3', device_type='GPU')\n"
]
}
],
"source": [
"print(\"Available devices:\")\n",
"for i, device in enumerate(tf.config.list_logical_devices()):\n",
" print(\"%d) %s\" % (i, device))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:52.026782Z",
"iopub.status.busy": "2024-08-16T09:05:52.026191Z",
"iopub.status.idle": "2024-08-16T09:05:52.451986Z",
"shell.execute_reply": "2024-08-16T09:05:52.450930Z"
},
"id": "dzLKpmZICaWN"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(16, 1), dtype=float32)\n",
"tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(16, 1), dtype=float32)\n",
"tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(16, 1), dtype=float32)\n",
"tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(16, 1), dtype=float32)\n",
"tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(16, 1), dtype=float32)\n",
"tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(16, 1), dtype=float32)\n",
"tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n"
]
}
],
"source": [
"global_batch_size = 16\n",
"# Create a tf.data.Dataset object.\n",
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n",
"\n",
"@tf.function\n",
"def train_step(inputs):\n",
" features, labels = inputs\n",
" return labels - 0.3 * features\n",
"\n",
"# Iterate over the dataset using the for..in construct.\n",
"for inputs in dataset:\n",
" print(train_step(inputs))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ihrhYDYRrVLH"
},
"source": [
"To allow users to use `tf.distribute` strategy with minimal changes to a user’s existing code, two APIs were introduced which would distribute a `tf.data.Dataset` instance and return a distributed dataset object. A user could then iterate over this distributed dataset instance and train their model as before. Let us now look at the two APIs - `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function` in more detail:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4AXoHhrsbdF3"
},
"source": [
"### `tf.distribute.Strategy.experimental_distribute_dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5mVuLZhbem8d"
},
"source": [
"#### Usage\n",
"\n",
"This API takes a `tf.data.Dataset` instance as input and returns a `tf.distribute.DistributedDataset` instance. You should batch the input dataset with a value that is equal to the global batch size. This global batch size is the number of samples that you want to process across all devices in 1 step. You can iterate over this distributed dataset in a Pythonic fashion or create an iterator using `iter`. The returned object is not a `tf.data.Dataset` instance and does not support any other APIs that transform or inspect the dataset in any way.\n",
"This is the recommended API if you don’t have specific ways in which you want to shard your input over different replicas.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:52.456326Z",
"iopub.status.busy": "2024-08-16T09:05:52.455671Z",
"iopub.status.idle": "2024-08-16T09:05:53.708579Z",
"shell.execute_reply": "2024-08-16T09:05:53.707900Z"
},
"id": "F2VeZUWUj5S4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(PerReplica:{\n",
" 0:
,\n",
" 1: ,\n",
" 2: ,\n",
" 3: \n",
"}, PerReplica:{\n",
" 0: ,\n",
" 1: ,\n",
" 2: ,\n",
" 3: \n",
"})\n"
]
}
],
"source": [
"global_batch_size = 16\n",
"mirrored_strategy = tf.distribute.MirroredStrategy()\n",
"\n",
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n",
"# Distribute input using the `experimental_distribute_dataset`.\n",
"dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n",
"# 1 global batch of data fed to the model in 1 step.\n",
"print(next(iter(dist_dataset)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QPceDmRht54F"
},
"source": [
"#### Properties"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Qb6nDgxiN_n"
},
"source": [
"##### Batching\n",
"\n",
"`tf.distribute` rebatches the input `tf.data.Dataset` instance with a new batch size that is equal to the global batch size divided by the number of replicas in sync. The number of replicas in sync is equal to the number of devices that are taking part in the gradient allreduce during training. When a user calls `next` on the distributed iterator, a per replica batch size of data is returned on each replica. The rebatched dataset cardinality will always be a multiple of the number of replicas. Here are a couple of\n",
"examples:\n",
"* `tf.data.Dataset.range(6).batch(4, drop_remainder=False)`\n",
" * Without distribution:\n",
" * Batch 1: [0, 1, 2, 3]\n",
" * Batch 2: [4, 5]\n",
" * With distribution over 2 replicas.\n",
" The last batch ([4, 5]) is split between 2 replicas.\n",
"\n",
" * Batch 1:\n",
" * Replica 1:[0, 1]\n",
" * Replica 2:[2, 3]\n",
" * Batch 2:\n",
" * Replica 1: [4]\n",
" * Replica 2: [5]\n",
"\n",
"\n",
"\n",
"* `tf.data.Dataset.range(4).batch(4)`\n",
" * Without distribution:\n",
" * Batch 1: [0, 1, 2, 3]\n",
" * With distribution over 5 replicas:\n",
" * Batch 1:\n",
" * Replica 1: [0]\n",
" * Replica 2: [1]\n",
" * Replica 3: [2]\n",
" * Replica 4: [3]\n",
" * Replica 5: []\n",
"\n",
"* `tf.data.Dataset.range(8).batch(4)`\n",
" * Without distribution:\n",
" * Batch 1: [0, 1, 2, 3]\n",
" * Batch 2: [4, 5, 6, 7]\n",
" * With distribution over 3 replicas:\n",
" * Batch 1:\n",
" * Replica 1: [0, 1]\n",
" * Replica 2: [2, 3]\n",
" * Replica 3: []\n",
" * Batch 2:\n",
" * Replica 1: [4, 5]\n",
" * Replica 2: [6, 7]\n",
" * Replica 3: []\n",
"\n",
"Note: The above examples only illustrate how a global batch is split on different replicas. It is not advisable to depend on the actual values that might end up on each replica as it can change depending on the implementation.\n",
"\n",
"Rebatching the dataset has a space complexity that increases linearly with the number of replicas. This means that for the multi-worker training use case the input pipeline can run into OOM errors. "
]
},
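{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a minimal, runnable sketch of the rebatching behavior described above, assuming the 4-replica `mirrored_strategy` created earlier in this guide. A global batch of 8 is rebatched into per-replica batches of 2 (the names `sketch_dataset` and `sketch_dist_dataset` are used purely for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A global batch of 8 is rebatched into per-replica batches of\n",
"# 8 // num_replicas_in_sync elements each.\n",
"sketch_dataset = tf.data.Dataset.range(16).batch(8)\n",
"sketch_dist_dataset = mirrored_strategy.experimental_distribute_dataset(sketch_dataset)\n",
"for per_replica_batch in sketch_dist_dataset:\n",
"  # `experimental_local_results` unpacks a per-replica value into a tuple\n",
"  # with one tensor per replica.\n",
"  print(mirrored_strategy.experimental_local_results(per_replica_batch))"
]
},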
{
"cell_type": "markdown",
"metadata": {
"id": "IszBuubdtydp"
},
"source": [
"##### Sharding\n",
"\n",
"`tf.distribute` also autoshards the input dataset in multi-worker training with `MultiWorkerMirroredStrategy` and `TPUStrategy`. Each dataset is created on the CPU device of the worker. Autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to ensure that at each step, a global batch size of non-overlapping dataset elements will be processed by each worker. Autosharding has a couple of different options that can be specified using `tf.data.experimental.DistributeOptions`. Note that there is no autosharding in multi-worker training with `ParameterServerStrategy`, and more information on dataset creation with this strategy can be found in the [ParameterServerStrategy tutorial](parameter_server_training.ipynb). "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:53.712364Z",
"iopub.status.busy": "2024-08-16T09:05:53.712109Z",
"iopub.status.idle": "2024-08-16T09:05:53.721737Z",
"shell.execute_reply": "2024-08-16T09:05:53.721119Z"
},
"id": "jwJtsCQhHK-E"
},
"outputs": [],
"source": [
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16)\n",
"options = tf.data.Options()\n",
"options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA\n",
"dataset = dataset.with_options(options)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J7fj3GskHC8g"
},
"source": [
"There are three different options that you can set for the `tf.data.experimental.AutoShardPolicy`:\n",
"\n",
"* AUTO: This is the default option which means an attempt will be made to shard by FILE. The attempt to shard by FILE fails if a file-based dataset is not detected. `tf.distribute` will then fall back to sharding by DATA. Note that if the input dataset is file-based but the number of files is less than the number of workers, an `InvalidArgumentError` will be raised. If this happens, explicitly set the policy to `AutoShardPolicy.DATA`, or split your input source into smaller files such that number of files is greater than number of workers.\n",
"* FILE: This is the option if you want to shard the input files over all the workers. You should use this option if the number of input files is much larger than the number of workers and the data in the files is evenly distributed. The downside of this option is having idle workers if the data in the files is not evenly distributed. If the number of files is less than the number of workers, an `InvalidArgumentError` will be raised. If this happens, explicitly set the policy to `AutoShardPolicy.DATA`.\n",
"For example, let us distribute 2 files over 2 workers with 1 replica each. File 1 contains [0, 1, 2, 3, 4, 5] and\n",
"File 2 contains [6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2 and global batch size be 4.\n",
"\n",
" * Worker 0:\n",
" * Batch 1 = Replica 1: [0, 1]\n",
" * Batch 2 = Replica 1: [2, 3]\n",
" * Batch 3 = Replica 1: [4]\n",
" * Batch 4 = Replica 1: [5]\n",
" * Worker 1:\n",
" * Batch 1 = Replica 2: [6, 7]\n",
" * Batch 2 = Replica 2: [8, 9]\n",
" * Batch 3 = Replica 2: [10]\n",
" * Batch 4 = Replica 2: [11]\n",
"\n",
"* DATA: This will autoshard the elements across all the workers. Each of the workers will read the entire dataset and only process the shard assigned to it. All other shards will be discarded. This is generally used if the number of input files is less than the number of workers and you want better sharding of data across all workers. The downside is that the entire dataset will be read on each worker.\n",
"For example, let us distribute 1 files over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2.\n",
"\n",
" * Worker 0:\n",
" * Batch 1 = Replica 1: [0, 1]\n",
" * Batch 2 = Replica 1: [4, 5]\n",
" * Batch 3 = Replica 1: [8, 9]\n",
" * Worker 1:\n",
" * Batch 1 = Replica 2: [2, 3]\n",
" * Batch 2 = Replica 2: [6, 7]\n",
" * Batch 3 = Replica 2: [10, 11]\n",
"\n",
"* OFF: If you turn off autosharding, each worker will process all the data.\n",
"For example, let us distribute 1 files over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. Then each worker will see the following distribution:\n",
"\n",
" * Worker 0:\n",
" * Batch 1 = Replica 1: [0, 1]\n",
" * Batch 2 = Replica 1: [2, 3]\n",
" * Batch 3 = Replica 1: [4, 5]\n",
" * Batch 4 = Replica 1: [6, 7]\n",
" * Batch 5 = Replica 1: [8, 9]\n",
" * Batch 6 = Replica 1: [10, 11]\n",
"\n",
" * Worker 1:\n",
" * Batch 1 = Replica 2: [0, 1]\n",
" * Batch 2 = Replica 2: [2, 3]\n",
" * Batch 3 = Replica 2: [4, 5]\n",
" * Batch 4 = Replica 2: [6, 7]\n",
" * Batch 5 = Replica 2: [8, 9]\n",
" * Batch 6 = Replica 2: [10, 11] "
]
},
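{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a hedged, runnable sketch of opting into the FILE policy. The two tiny TFRecord shards written to a temporary directory here are purely for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"\n",
"# Write two tiny TFRecord shards purely for illustration.\n",
"tmp_dir = tempfile.mkdtemp()\n",
"for shard_id in range(2):\n",
"  shard_path = os.path.join(tmp_dir, \"data-%d.tfrecord\" % shard_id)\n",
"  with tf.io.TFRecordWriter(shard_path) as writer:\n",
"    for value in range(6):\n",
"      writer.write((\"record %d\" % (shard_id * 6 + value)).encode())\n",
"\n",
"# Opt into FILE-based autosharding for this file-backed dataset.\n",
"file_dataset = tf.data.Dataset.list_files(os.path.join(tmp_dir, \"data-*.tfrecord\"))\n",
"file_dataset = file_dataset.interleave(tf.data.TFRecordDataset)\n",
"options = tf.data.Options()\n",
"options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.FILE\n",
"file_dataset = file_dataset.with_options(options)"
]
},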
{
"cell_type": "markdown",
"metadata": {
"id": "OK46ZJGPH5H2"
},
"source": [
"##### Prefetching\n",
"\n",
"By default, `tf.distribute` adds a prefetch transformation at the end of the user provided `tf.data.Dataset` instance. The argument to the prefetch transformation which is `buffer_size` is equal to the number of replicas in sync."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PjiGSY3gtr6_"
},
"source": [
"### `tf.distribute.Strategy.distribute_datasets_from_function`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bAXAo_wWbWSb"
},
"source": [
"#### Usage\n",
"\n",
"This API takes an input function and returns a `tf.distribute.DistributedDataset` instance. The input function that users pass in has a `tf.distribute.InputContext` argument and should return a `tf.data.Dataset` instance. With this API, `tf.distribute` does not make any further changes to the user’s `tf.data.Dataset` instance returned from the input function. It is the responsibility of the user to batch and shard the dataset. `tf.distribute` calls the input function on the CPU device of each of the workers. Apart from allowing users to specify their own batching and sharding logic, this API also demonstrates better scalability and performance compared to `tf.distribute.Strategy.experimental_distribute_dataset` when used for multi-worker training."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:53.725458Z",
"iopub.status.busy": "2024-08-16T09:05:53.725188Z",
"iopub.status.idle": "2024-08-16T09:05:53.740593Z",
"shell.execute_reply": "2024-08-16T09:05:53.739983Z"
},
"id": "9ODch-OFCaW4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
]
}
],
"source": [
"mirrored_strategy = tf.distribute.MirroredStrategy()\n",
"\n",
"def dataset_fn(input_context):\n",
" batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n",
" dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16)\n",
" dataset = dataset.shard(\n",
" input_context.num_input_pipelines, input_context.input_pipeline_id)\n",
" dataset = dataset.batch(batch_size)\n",
" dataset = dataset.prefetch(2) # This prefetches 2 batches per device.\n",
" return dataset\n",
"\n",
"dist_dataset = mirrored_strategy.distribute_datasets_from_function(dataset_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M1bpzPYzt_R7"
},
"source": [
"#### Properties"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7cgzhwiiuBvO"
},
"source": [
"##### Batching\n",
"\n",
"The `tf.data.Dataset` instance that is the return value of the input function should be batched using the per replica batch size. The per replica batch size is the global batch size divided by the number of replicas that are taking part in sync training. This is because `tf.distribute` calls the input function on the CPU device of each of the workers. The dataset that is created on a given worker should be ready to use by all the replicas on that worker. "
]
},
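{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, assuming the 4-replica `mirrored_strategy` and the global batch size of 16 used earlier in this guide, the per-replica batch size works out to 4:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-replica batch size = global batch size / number of replicas in sync.\n",
"per_replica_batch_size = global_batch_size // mirrored_strategy.num_replicas_in_sync\n",
"print(per_replica_batch_size)  # 4 with the 4-replica strategy above."
]
},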
{
"cell_type": "markdown",
"metadata": {
"id": "e-wlFFZbP33n"
},
"source": [
"##### Sharding\n",
"\n",
"The `tf.distribute.InputContext` object that is implicitly passed as an argument to the user’s input function is created by `tf.distribute` under the hood. It has information about the number of workers, current worker ID etc. This input function can handle sharding as per policies set by the user using these properties that are part of the `tf.distribute.InputContext` object.\n"
]
},
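{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a minimal sketch of reading these properties inside an input function (the name `show_context_fn` is used purely for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show_context_fn(input_context):\n",
"  # Properties that `tf.distribute` fills in for each input pipeline.\n",
"  print(\"num_input_pipelines:\", input_context.num_input_pipelines)\n",
"  print(\"input_pipeline_id:\", input_context.input_pipeline_id)\n",
"  print(\"num_replicas_in_sync:\", input_context.num_replicas_in_sync)\n",
"  per_replica_batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n",
"  return tf.data.Dataset.range(16).batch(per_replica_batch_size)\n",
"\n",
"dist_dataset_from_context = mirrored_strategy.distribute_datasets_from_function(show_context_fn)"
]
},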
{
"cell_type": "markdown",
"metadata": {
"id": "7TGwnDM-ICHf"
},
"source": [
"##### Prefetching\n",
"\n",
"`tf.distribute` does not add a prefetch transformation at the end of the `tf.data.Dataset` returned by the user-provided input function, so you explicitly call `Dataset.prefetch` in the example above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iOMsf8kyZZpv"
},
"source": [
"Note:\n",
"Both `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function` return **`tf.distribute.DistributedDataset` instances that are not of type `tf.data.Dataset`**. You can iterate over these instances (as shown in the Distributed Iterators section) and use the `element_spec`\n",
"property. "
]
},
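{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, you can inspect the structure of the elements of the `dist_dataset` created above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `element_spec` describes the structure of one element of the distributed dataset.\n",
"print(dist_dataset.element_spec)"
]
},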
{
"cell_type": "markdown",
"metadata": {
"id": "dL3XbI1gzEjO"
},
"source": [
"## Distributed iterators"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w8y54-o9T2Ni"
},
"source": [
"Similar to non-distributed `tf.data.Dataset` instances, you will need to create an iterator on the `tf.distribute.DistributedDataset` instances to iterate over it and access the elements in the `tf.distribute.DistributedDataset`.\n",
"The following are the ways in which you can create a `tf.distribute.DistributedIterator` and use it to train your model:\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FlKh8NV0uOtZ"
},
"source": [
"### Usages"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eSZz6EqOuSlB"
},
"source": [
"#### Use a Pythonic for loop construct\n",
"\n",
"You can use a user friendly Pythonic loop to iterate over the `tf.distribute.DistributedDataset`. The elements returned from the `tf.distribute.DistributedIterator` can be a single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains a value per replica. Placing the loop inside a `tf.function` will give a performance boost. However, `break` and `return` are currently not supported for a loop over a `tf.distribute.DistributedDataset` that is placed inside of a `tf.function`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:53.744795Z",
"iopub.status.busy": "2024-08-16T09:05:53.744556Z",
"iopub.status.idle": "2024-08-16T09:05:54.230393Z",
"shell.execute_reply": "2024-08-16T09:05:54.229656Z"
},
"id": "zt3AHb46Tr3w"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n",
" 1: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n",
" 2: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n",
" 3: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32)\n",
"}\n"
]
}
],
"source": [
"global_batch_size = 16\n",
"mirrored_strategy = tf.distribute.MirroredStrategy()\n",
"\n",
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n",
"dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n",
"\n",
"@tf.function\n",
"def train_step(inputs):\n",
" features, labels = inputs\n",
" return labels - 0.3 * features\n",
"\n",
"for x in dist_dataset:\n",
" # train_step trains the model using the dataset elements\n",
" loss = mirrored_strategy.run(train_step, args=(x,))\n",
" print(\"Loss is \", loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NchPwTEiuSqb"
},
"source": [
"#### Use `iter` to create an explicit iterator\n",
"To iterate over the elements in a `tf.distribute.DistributedDataset` instance, you can create a `tf.distribute.DistributedIterator` using the `iter` API on it. With an explicit iterator, you can iterate for a fixed number of steps. In order to get the next element from an `tf.distribute.DistributedIterator` instance `dist_iterator`, you can call `next(dist_iterator)`, `dist_iterator.get_next()`, or `dist_iterator.get_next_as_optional()`. The former two are essentially the same:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:54.234338Z",
"iopub.status.busy": "2024-08-16T09:05:54.233633Z",
"iopub.status.idle": "2024-08-16T09:05:57.405457Z",
"shell.execute_reply": "2024-08-16T09:05:57.404767Z"
},
"id": "OrMmakq5EqeQ"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n",
"Loss is PerReplica:{\n",
" 0: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 1: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 2: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32),\n",
" 3: tf.Tensor(\n",
"[[0.7]\n",
" [0.7]\n",
" [0.7]\n",
" [0.7]], shape=(4, 1), dtype=float32)\n",
"}\n"
]
}
],
"source": [
"num_epochs = 10\n",
"steps_per_epoch = 5\n",
"for epoch in range(num_epochs):\n",
" dist_iterator = iter(dist_dataset)\n",
" for step in range(steps_per_epoch):\n",
" # train_step trains the model using the dataset elements\n",
" loss = mirrored_strategy.run(train_step, args=(next(dist_iterator),))\n",
" # which is the same as\n",
" # loss = mirrored_strategy.run(train_step, args=(dist_iterator.get_next(),))\n",
" print(\"Loss is \", loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UpJXIlxjqPYg"
},
"source": [
"With `next` or `tf.distribute.DistributedIterator.get_next`, if the `tf.distribute.DistributedIterator` has reached its end, an OutOfRange error will be thrown. The client can catch the error on python side and continue doing other work such as checkpointing and evaluation. However, this will not work if you are using a host training loop (i.e., run multiple steps per `tf.function`), which looks like:\n",
"\n",
"```\n",
"@tf.function\n",
"def train_fn(iterator):\n",
" for _ in tf.range(steps_per_loop):\n",
" strategy.run(step_fn, args=(next(iterator),))\n",
"```\n",
"\n",
"This example `train_fn` contains multiple steps by wrapping the step body inside a `tf.range`. In this case, different iterations in the loop with no dependency could start in parallel, so an OutOfRange error can be triggered in later iterations before the computation of previous iterations finishes. Once an OutOfRange error is thrown, all the ops in the function will be terminated right away. If this is some case that you would like to avoid, an alternative that does not throw an OutOfRange error is `tf.distribute.DistributedIterator.get_next_as_optional`. `get_next_as_optional` returns a `tf.experimental.Optional` which contains the next element or no value if the `tf.distribute.DistributedIterator` has reached an end."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:57.409179Z",
"iopub.status.busy": "2024-08-16T09:05:57.408695Z",
"iopub.status.idle": "2024-08-16T09:05:58.028547Z",
"shell.execute_reply": "2024-08-16T09:05:58.027808Z"
},
"id": "Iyjao96Vqwyz"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"([0], [1], [2], [3])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"([4], [5], [6], [7])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"([8], [], [], [])\n"
]
}
],
"source": [
"# You can break the loop with `get_next_as_optional` by checking if the `Optional` contains a value\n",
"global_batch_size = 4\n",
"steps_per_loop = 5\n",
"strategy = tf.distribute.MirroredStrategy()\n",
"\n",
"dataset = tf.data.Dataset.range(9).batch(global_batch_size)\n",
"distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))\n",
"\n",
"@tf.function\n",
"def train_fn(distributed_iterator):\n",
" for _ in tf.range(steps_per_loop):\n",
" optional_data = distributed_iterator.get_next_as_optional()\n",
" if not optional_data.has_value():\n",
" break\n",
" per_replica_results = strategy.run(lambda x: x, args=(optional_data.get_value(),))\n",
" tf.print(strategy.experimental_local_results(per_replica_results))\n",
"train_fn(distributed_iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LaclbKnqzLjf"
},
"source": [
"## Using the `element_spec` property"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z1YvXqOpwy08"
},
"source": [
"If you pass the elements of a distributed dataset to a `tf.function` and want a `tf.TypeSpec` guarantee, you can specify the `input_signature` argument of the `tf.function`. The output of a distributed dataset is `tf.distribute.DistributedValues` which can represent the input to a single device or multiple devices. To get the `tf.TypeSpec` corresponding to this distributed value, you can use `tf.distribute.DistributedDataset.element_spec` or `tf.distribute.DistributedIterator.element_spec`."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-16T09:05:58.032268Z",
"iopub.status.busy": "2024-08-16T09:05:58.031841Z",
"iopub.status.idle": "2024-08-16T09:05:59.971322Z",
"shell.execute_reply": "2024-08-16T09:05:59.970640Z"
},
"id": "pg3B-Cw_cn3a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(PerReplica:{\n",
" 0: ,\n",
" 1: ,\n",
" 2: ,\n",
" 3: \n",
"},\n",
" PerReplica:{\n",
" 0: ,\n",
" 1: