Keeps track of the object created by invoking trackable_factory_callable.
```python
tft.make_and_track_object(
    trackable_factory_callable: Callable[[], base.Trackable],
    name: Optional[str] = None
) -> base.Trackable
```
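Because trackable_factory_callable must take no arguments (Callable[[], base.Trackable]), a factory that needs parameters has to bind them up front, for example with a lambda or functools.partial. The following is a minimal sketch using core TensorFlow only; make_batched_dataset is a hypothetical helper, and the factory is called directly here just to show what it returns.

```python
import functools

import tensorflow as tf


def make_batched_dataset(values, batch_size):
  # Hypothetical helper: builds a tf.data.Dataset, which is a Trackable object.
  return tf.data.Dataset.from_tensor_slices(values).batch(batch_size)


# functools.partial binds both arguments, producing a zero-argument callable
# that matches Callable[[], base.Trackable].
factory = functools.partial(make_batched_dataset, [1, 2, 3, 4], batch_size=2)

# Calling the factory directly only demonstrates its return value; inside a
# preprocessing_fn it would be handed to tft.make_and_track_object instead.
dataset = factory()
batches = [batch.tolist() for batch in dataset.as_numpy_iterator()]
print(batches)  # [[1, 2], [3, 4]]
```

Inside a preprocessing_fn, the same factory would be passed as tft.make_and_track_object(factory, name='batched_dataset'); the name value here is an arbitrary illustration.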
This API is only for use when Transform APIs are run with TF2 behaviors
enabled and tft_beam.Context.force_tf_compat_v1 is set to False.
Use this API to track TF Trackable objects created in the preprocessing_fn,
such as tf.hub modules and tf.data.Dataset objects. This ensures they are
serialized correctly when exporting to SavedModel.
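The serialization concern is the same one core TensorFlow has: tf.saved_model.save writes out the object graph reachable from the exported root through tracked dependencies. The sketch below uses only core TensorFlow (tf.Module, tf.lookup.StaticHashTable, tf.saved_model) to show a tracked lookup table surviving a save/load round trip. It is an analogy for what tracking provides, not a description of how tft.make_and_track_object is implemented; the class name LookupModule is hypothetical.

```python
import tempfile

import tensorflow as tf


class LookupModule(tf.Module):
  """Hypothetical module used only to illustrate object tracking."""

  def __init__(self):
    super().__init__()
    # Assigning a Trackable (here a lookup table) to an attribute of a
    # tf.Module records it as a dependency of the module.
    self.table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(["a", "b"]),
            values=tf.constant([1, 2], dtype=tf.int64)),
        default_value=-1)

  @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
  def lookup(self, keys):
    return self.table.lookup(keys)


export_dir = tempfile.mkdtemp()
# Because the table is tracked, it is written into the SavedModel together
# with the lookup function that uses it.
tf.saved_model.save(LookupModule(), export_dir)

restored = tf.saved_model.load(export_dir)
values = restored.lookup(tf.constant(["a", "z"])).numpy().tolist()
print(values)  # [1, -1]
```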
Example:
```python
import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam


def preprocessing_fn(inputs):
  # Create the dataset through make_and_track_object so it is tracked and
  # serialized correctly when exporting to SavedModel.
  dataset = tft.make_and_track_object(
      lambda: tf.data.Dataset.from_tensor_slices([1, 2, 3]))
  # Iterate the dataset eagerly, outside the traced preprocessing graph.
  with tf.init_scope():
    dataset_list = list(dataset.as_numpy_iterator())
  return {'x_0': dataset_list[0] + inputs['x']}


raw_data = [dict(x=1), dict(x=2), dict(x=3)]
feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64))
raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec)
with tft_beam.Context(temp_dir=tempfile.mkdtemp(), force_tf_compat_v1=False):
  transformed_dataset, transform_fn = (
      (raw_data, raw_data_metadata)
      | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset
print(transformed_data)
# [{'x_0': 2}, {'x_0': 3}, {'x_0': 4}]
```
| Returns |
|---|
| The object returned when trackable_factory_callable is invoked. The object creation is lifted out to the eager context using tf.init_scope. |
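The lifting relies on tf.init_scope, which exits any function-building graph and enters the outermost context. A minimal sketch in core TensorFlow 2 (no tft involved; the observed dictionary exists only to record what happens during tracing): inside a tf.function the body is traced into a graph, but code under tf.init_scope runs eagerly, so a tf.data.Dataset created there can be iterated with as_numpy_iterator, as the Example above does.

```python
import tensorflow as tf

observed = {}


@tf.function
def traced_fn():
  # While tf.function traces this body, ops are recorded into a graph.
  observed["eager_while_tracing"] = tf.executing_eagerly()
  with tf.init_scope():
    # init_scope lifts execution out to the outermost (eager) context.
    observed["eager_in_init_scope"] = tf.executing_eagerly()
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    first = next(dataset.as_numpy_iterator())
  # The NumPy value computed eagerly is captured as a constant in the graph.
  return tf.constant(first) + 1


result = traced_fn()
print(observed)  # {'eager_while_tracing': False, 'eager_in_init_scope': True}
print(result.numpy())  # 2
```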