Wraps a python function into a TensorFlow op that executes it eagerly.
```python
tf.py_function(
    func=None, inp=None, Tout=None, name=None
)
```
Using tf.py_function inside a tf.function allows you to run a Python
function using eager execution, inside the tf.function's graph.
This has two main effects:
- It allows you to use non-TensorFlow code inside your tf.function.
- It allows you to run Python control logic in a tf.function without relying on tf.autograph to convert the code to use TensorFlow control logic (tf.cond, tf.while_loop).
Both of these features can be useful for debugging.
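As a minimal sketch of both effects, the wrapped function below converts its input to a plain Python int and loops with ordinary Python control flow, even though it is called from inside a tf.function. The Collatz step counter and the names collatz_steps and count_steps are illustrative choices, not part of TensorFlow.

```python
import tensorflow as tf

def collatz_steps(n):
  # Runs eagerly, so n is a concrete EagerTensor: .numpy() works and the
  # Python while/if below run as-is, with no AutoGraph conversion.
  value = int(n.numpy())
  steps = 0
  while value != 1:
    value = value // 2 if value % 2 == 0 else 3 * value + 1
    steps += 1
  return steps  # converted to a Tensor of dtype Tout

@tf.function
def count_steps(n):
  return tf.py_function(func=collatz_steps, inp=[n], Tout=tf.int32)

print(count_steps(tf.constant(6)).numpy())  # 8
```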
Since tf.py_function operates on Tensors, it is still
differentiable (once).
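A short sketch of a first-order gradient through the in-place form, assuming the wrapped body (the hypothetical cube function) uses only TensorFlow ops on its Tensor input. Only a single derivative is shown here, in line with the "(once)" caveat above.

```python
import tensorflow as tf

def cube(t):
  # TensorFlow ops on the eager input Tensor keep the op differentiable.
  return t ** 3

x = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.py_function(func=cube, inp=[x], Tout=tf.float32)

grad = tape.gradient(y, x)
print(y.numpy())     # 8.0
print(grad.numpy())  # 12.0, i.e. 3 * x**2 at x = 2
```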
There are two ways to use this function:
As a decorator
Use tf.py_function as a decorator to ensure the function always runs
eagerly.
When using tf.py_function as a decorator:
- you must set Tout
- you may set name
- you must not set func or inp
For example, you might use tf.py_function to
implement the log huber function.
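```python
import tensorflow as tf

@tf.py_function(Tout=tf.float32)
def py_log_huber(x, m):
  print('Running with eager execution.')
  # Python control flow on tensor values works because the body runs eagerly.
  if tf.abs(x) <= m:
    return x**2
  else:
    return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))
```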
Under eager execution the function operates normally:
```python
x = tf.constant(1.0)
m = tf.constant(2.0)
print(py_log_huber(x, m).numpy())
# Running with eager execution.
# 1.0
```
Inside a tf.function, the tf.py_function is not converted to a tf.Graph:
```python
@tf.function
def tf_wrapper(x):
  print('Tracing.')
  m = tf.constant(2.0)
  return py_log_huber(x, m)
```
The tf.py_function only executes eagerly, and only when the tf.function
is called:
```python
print(tf_wrapper(x).numpy())
# Tracing.
# Running with eager execution.
# 1.0
print(tf_wrapper(x).numpy())
# Running with eager execution.
# 1.0
```
Gradients work as expected:
```python
with tf.GradientTape() as t:
  t.watch(x)
  y = tf_wrapper(x)
# Running with eager execution.
print(t.gradient(y, x).numpy())
# 2.0
```
In-place
You can also skip the decorator and use tf.py_function in-place.
This form is a useful shortcut if you don't control the function's source,
but it is harder to read.
```python
# No decorator
def log_huber(x, m):
  if tf.abs(x) <= m:
    return x**2
  else:
    return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))

x = tf.constant(1.0)
m = tf.constant(2.0)

print(tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32).numpy())
# 1.0
```
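The in-place form is what you would reach for when the function comes from code you cannot edit. As an illustrative sketch, the block below wraps the standard library's statistics.median; the helper names tf_median and median_of are hypothetical, and the lambda only converts the eager input Tensor to a Python list.

```python
import statistics
import tensorflow as tf

def tf_median(t):
  # statistics.median lives in the standard library, so it cannot be
  # decorated; wrap it in place instead.
  return tf.py_function(
      func=lambda v: statistics.median(v.numpy().tolist()),
      inp=[t],
      Tout=tf.float32)

@tf.function
def median_of(t):
  return tf_median(t)

print(median_of(tf.constant([1.0, 5.0, 2.0, 4.0])).numpy())  # 3.0
```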
More info
You can also use tf.py_function to debug your models at runtime
using Python tools: you can isolate the portions of your code that
you want to debug, wrap them in Python functions, insert pdb breakpoints
or print statements as desired, and wrap those functions in
tf.py_function.
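A minimal debugging sketch, assuming you only want to observe an intermediate value: the hypothetical debug_hook returns None, so Tout is the empty list. Because tf.py_function is a stateful op, it still runs inside the tf.function even though its (empty) result is unused. The list seen and the function scaled_sum are illustrative names.

```python
import tensorflow as tf

seen = []  # hypothetical log collected by the debug hook

def debug_hook(t):
  # Executes eagerly each time the tf.function runs, so print, pdb and
  # plain Python containers all work here.
  seen.append(t.numpy().tolist())
  print('intermediate value:', t.numpy())
  # import pdb; pdb.set_trace()  # uncomment to stop here interactively
  return None  # nothing to return, hence Tout=[]

@tf.function
def scaled_sum(x):
  y = x * 2.0
  tf.py_function(func=debug_hook, inp=[y], Tout=[])
  return tf.reduce_sum(y)

total = scaled_sum(tf.constant([1.0, 2.0, 3.0]))
print(total.numpy())
```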
For more information on eager execution, see the Eager guide.
tf.py_function is similar in spirit to tf.numpy_function, but unlike
the latter, the former lets you use TensorFlow operations in the wrapped
Python function. In particular, while tf.compat.v1.py_func only runs on CPUs
and wraps functions that take NumPy arrays as inputs and return NumPy arrays
as outputs, tf.py_function can be placed on GPUs and wraps functions
that take Tensors as inputs, execute TensorFlow operations in their bodies,
and return Tensors as outputs.
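To make the contrast concrete, here is a small sketch computing softplus both ways (the functions np_softplus and tf_softplus are hypothetical). The tf.numpy_function body receives NumPy arrays and uses NumPy, while the tf.py_function body receives Tensors and can call TensorFlow ops such as tf.math.softplus directly.

```python
import numpy as np
import tensorflow as tf

def np_softplus(a):
  # tf.numpy_function hands the body NumPy arrays.
  assert isinstance(a, np.ndarray)
  return np.log1p(np.exp(a)).astype(np.float32)

def tf_softplus(t):
  # tf.py_function hands the body Tensors, so TensorFlow ops work directly.
  assert isinstance(t, tf.Tensor)
  return tf.math.softplus(t)

x = tf.constant([0.0, 1.0])
via_numpy = tf.numpy_function(np_softplus, [x], tf.float32)
via_tensor = tf.py_function(tf_softplus, [x], tf.float32)
print(np.allclose(via_numpy.numpy(), via_tensor.numpy()))  # True
```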
- Calling tf.py_function will acquire the Python Global Interpreter Lock (GIL), which allows only one thread to run at any point in time. This will preclude efficient parallelization and distribution of the execution of the program.
- The body of the function (i.e. func) will not be serialized in a GraphDef. Therefore, you should not use this function if you need to serialize your model and restore it in a different environment.
- The operation must run in the same address space as the Python program that calls tf.py_function(). If you are using distributed TensorFlow, you must run a tf.distribute.Server in the same process as the program that calls tf.py_function() and you must pin the created operation to a device in that server (e.g. using with tf.device():).
- Currently tf.py_function is not compatible with XLA. Calling tf.py_function inside tf.function(jit_compile=True) will raise an error.
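One way to live with the XLA restriction, sketched under the assumption that your TensorFlow build includes XLA (standard pip builds do): keep the compiled part in its own jit_compile=True function and call tf.py_function only from an outer tf.function that is not jit-compiled. The names xla_part, python_part and pipeline are hypothetical.

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def xla_part(x):
  # Pure TensorFlow ops: safe to compile with XLA.
  return tf.reduce_sum(x * x)

def python_part(s):
  # Arbitrary Python; must stay outside any jit_compile=True function.
  return float(s.numpy()) ** 0.5

@tf.function  # not jit-compiled, so tf.py_function is allowed here
def pipeline(x):
  s = xla_part(x)
  return tf.py_function(func=python_part, inp=[s], Tout=tf.float32)

print(pipeline(tf.constant([3.0, 4.0])).numpy())  # 5.0
```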
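| Args | |
|---|---|
| func | A Python function that accepts inp as arguments, and returns a value (or list of values) whose type is described by Tout. Do not set func when using tf.py_function as a decorator. |
| inp | Input arguments for func. A list whose elements are Tensors or CompositeTensors (such as tf.RaggedTensor); or a single Tensor or CompositeTensor. Do not set inp when using tf.py_function as a decorator. |
| Tout | The type(s) of the value(s) returned by func. One of the following: if func returns a Tensor (or a value that can be converted to a Tensor), the tf.DType for that value; if func returns a CompositeTensor, the tf.TypeSpec for that value; if func returns None, the empty list ([]); if func returns a list of Tensor and CompositeTensor values, a corresponding list of tf.DTypes and tf.TypeSpecs for each value. |
| name | A name for the operation (optional). |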
Returns | |
|---|---|