View source on GitHub
|
Returns the single element of the dataset as a nested structure of tensors. (deprecated)
tf.data.experimental.get_single_element(
dataset
)
The function enables you to use a tf.data.Dataset in a stateless
"tensor-in tensor-out" expression, without creating an iterator.
This facilitates the ease of data transformation on tensors using the
optimized tf.data.Dataset abstraction on top of them.
For example, lets consider a preprocessing_fn which would take as an
input the raw features and returns the processed feature along with
it's label.
def preprocessing_fn(raw_feature):
# ... the raw_feature is preprocessed as per the use-case
return feature
raw_features = ... # input batch of BATCH_SIZE elements.
dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
.batch(BATCH_SIZE))
processed_features = tf.data.experimental.get_single_element(dataset)
In the above example, the raw_features tensor of length=BATCH_SIZE
was converted to a tf.data.Dataset. Next, each of the raw_feature was
mapped using the preprocessing_fn and the processed features were
grouped into a single batch. The final dataset contains only one element
which is a batch of all the processed features.
Now, instead of creating an iterator for the dataset and retrieving the
batch of features, the tf.data.experimental.get_single_element() function
is used to skip the iterator creation process and directly output the batch
of features.
This can be particularly useful when your tensor transformations are
expressed as tf.data.Dataset operations, and you want to use those
transformations while serving your model.
Keras
model = ... # A pre-built or custom model
class PreprocessingModel(tf.keras.Model):
def __init__(self, model):
super().__init__(self)
self.model = model
@tf.function(input_signature=[...])
def serving_fn(self, data):
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
ds = ds.batch(batch_size=BATCH_SIZE)
return tf.argmax(
self.model(tf.data.experimental.get_single_element(ds)),
axis=-1
)
preprocessing_model = PreprocessingModel(model)
your_exported_model_dir = ... # save the model to this path.
tf.saved_model.save(preprocessing_model, your_exported_model_dir,
signatures={'serving_default': preprocessing_model.serving_fn})
Args | |
|---|---|
dataset
|
A tf.data.Dataset object containing a single element.
|
Returns | |
|---|---|
A nested structure of tf.Tensor objects, corresponding to the single
element of dataset.
|
Raises | |
|---|---|
TypeError
|
if dataset is not a tf.data.Dataset object.
|
InvalidArgumentError
|
(at runtime) if dataset does not contain exactly
one element.
|
View source on GitHub