View source on GitHub
Implements standard functionality on top of the AbstractTrainer API.
Inherits From: AbstractTrainer
orbit.StandardTrainer(
    train_dataset,
    options: Optional[orbit.StandardTrainerOptions] = None
)
This class structures the training "inner loop" roughly as follows:
train_loop_begin()
for _ in range(num_steps):
    train_step(train_iterator)
return train_loop_end()
Calls to train_loop_begin and train_loop_end always run in eager
mode, while the loop and train_step may be implemented using tf.while and/or
tf.function, as determined by the options passed to __init__.
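The control flow above can be sketched in plain Python. This is only a stand-in that mirrors the call order, not Orbit's implementation: `SketchTrainer` and `CountingTrainer` are hypothetical names, a list stands in for a `tf.data.Dataset`, and reusing the iterator across `train` calls is an assumption of this sketch. The real class may additionally wrap the loop in `tf.function` or `tf.while`, depending on the options.

```python
import abc


class SketchTrainer(abc.ABC):
    """Hypothetical plain-Python stand-in mirroring the inner-loop call order."""

    def __init__(self, train_dataset):
        self.train_dataset = train_dataset
        self._train_iter = None

    def train_loop_begin(self):
        """Hook called once before the loop; the default does nothing."""

    @abc.abstractmethod
    def train_step(self, iterator):
        """Runs one step of training, pulling data from `iterator`."""

    def train_loop_end(self):
        """Hook called once after the loop; its return value is the result."""
        return None

    def train(self, num_steps):
        self.train_loop_begin()
        # The iterator is created after train_loop_begin.
        # Assumption of this sketch: it is then reused by later train() calls.
        if self._train_iter is None:
            self._train_iter = iter(self.train_dataset)
        for _ in range(num_steps):
            self.train_step(self._train_iter)
        return self.train_loop_end()


class CountingTrainer(SketchTrainer):
    """Toy subclass: averages the numbers it reads from the dataset."""

    def train_loop_begin(self):
        # Reset values that accumulate over the steps of one loop.
        self.total = 0
        self.steps = 0

    def train_step(self, iterator):
        # Plain Python only: mutating attributes on self like this is what
        # the docs advise against inside a tf.function-traced train_step.
        self.total += next(iterator)
        self.steps += 1

    def train_loop_end(self):
        return {"mean": self.total / self.steps}


trainer = CountingTrainer(train_dataset=[1, 2, 3, 4, 5, 6])
first = trainer.train(num_steps=3)   # consumes 1, 2, 3
second = trainer.train(num_steps=3)  # consumes 4, 5, 6
```

Each call to `train` resets the accumulators in `train_loop_begin`, so `first` and `second` report the mean of their own three elements only.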
| Args | |
|---|---|
| train_dataset | A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset. |
| options | An orbit.StandardTrainerOptions instance. |
| Attributes | |
|---|---|
| name | Returns the name of this module as passed or determined in the ctor. |
| name_scope | Returns a tf.name_scope instance for this class. |
| non_trainable_variables | Sequence of non-trainable variables owned by this module and its submodules. |
| submodules | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
| train_dataset | The current training dataset. |
| trainable_variables | Sequence of trainable variables owned by this module and its submodules. |
| variables | Sequence of variables owned by this module and its submodules. |
Methods
create_train_loop_fn
create_train_loop_fn()
Creates a training loop from the current step function and options.
| Returns | |
|---|---|
| The train loop function, i.e. wrapper of multiple train steps. |
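As a rough illustration of "a wrapper of multiple train steps", the sketch below builds a loop function from a step function in plain Python. `create_loop_fn` here is a hypothetical helper written for this example; it ignores the options that would decide whether the real loop is traced with `tf.function`.

```python
def create_loop_fn(step_fn):
    """Returns a function that calls `step_fn(iterator)` `num_steps` times."""

    def loop_fn(iterator, num_steps):
        for _ in range(num_steps):
            step_fn(iterator)

    return loop_fn


seen = []
# The step function here just records the next element it dequeues.
loop_fn = create_loop_fn(lambda iterator: seen.append(next(iterator)))
loop_fn(iter(range(10)), 3)
```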
train
train(
    num_steps: tf.Tensor
) -> Optional[runner.Output]
Implements num_steps steps of training.
| Args | |
|---|---|
| num_steps | The number of training steps to run. This corresponds directly to the number of calls made to train_step. |

| Returns | |
|---|---|
| The output of train_loop_end. |
train_loop_begin
train_loop_begin()
Called once at the beginning of the training loop.
This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.
Note that this method is called before dataset iterator creation.
train_loop_end
train_loop_end() -> Optional[runner.Output]
Called once at the end of the training loop.
This method is always called in eager mode, and is a good place to get
metric results. The value returned from this function will be returned as-is
from the train method implementation provided by StandardTrainer.
| Returns | |
|---|---|
| The function may return a dictionary of Tensors, which will be written to logs and as TensorBoard summaries. It can also be a nested dictionary, yielding a hierarchy of summary directories. |
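To illustrate what a nested return value implies, the sketch below flattens a nested dictionary into path-like keys, one per leaf. `flatten_output` is a hypothetical helper written for this illustration; it is not part of Orbit, and the exact directory names Orbit produces are not specified here. Plain floats stand in for Tensors.

```python
def flatten_output(output, prefix=""):
    """Flattens a nested dict into {"outer/inner": value} pairs (illustrative)."""
    flat = {}
    for key, value in output.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse so that each nesting level adds one path component.
            flat.update(flatten_output(value, path))
        else:
            flat[path] = value
    return flat


# Example of a nested result that train_loop_end might return.
output = {"loss": 0.25, "metrics": {"accuracy": 0.9, "recall": 0.8}}
flat = flatten_output(output)
```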
train_step
@abc.abstractmethod
train_step(
    iterator
)
Implements one step of training.
What a "step" consists of is up to the implementer. When using distribution
strategies, the call to this method takes place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to strategy.run.
Note that if use_tf_function=True, all the code inside train_step should
be compatible with tf.function tracing; in particular, any state
modifications involving self should be avoided. In some cases, code that is
not tf.function-compatible can be moved to train_loop_begin or
train_loop_end, which always execute eagerly.
| Args | |
|---|---|
| iterator | A tf.nest-compatible structure of tf.data.Iterator or DistributedIterator. The structure of this input matches the structure of train_dataset as passed to __init__. |
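The relationship between the dataset structure and the iterator structure can be sketched without TensorFlow. In the example below, a dict of lists stands in for a nested structure of datasets, `make_iterators` is a hypothetical helper that mirrors mapping `iter` over that structure, and this `train_step` is a toy function showing that a single step may dequeue more than once.

```python
def make_iterators(datasets):
    """Maps `iter` over a dict of datasets, preserving its structure."""
    return {name: iter(dataset) for name, dataset in datasets.items()}


def train_step(iterator):
    """One 'step' that dequeues twice from "main" and once from "aux"."""
    first = next(iterator["main"])
    second = next(iterator["main"])
    extra = next(iterator["aux"])
    return (first, second, extra)


# The iterator structure has the same keys as the dataset structure.
datasets = {"main": [10, 20, 30, 40], "aux": ["a", "b"]}
iterators = make_iterators(datasets)
step1 = train_step(iterators)
step2 = train_step(iterators)
```

What counts as a step is left to the implementer, so the number of dequeues per call is a design choice of the subclass rather than something fixed by the trainer.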
with_name_scope
@classmethod
with_name_scope(
    method
)
Decorator to automatically enter the module name scope.
class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
Using the above module would produce tf.Variables and tf.Tensors whose
names included the module name:
mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
numpy=..., dtype=float32)>
| Args | |
|---|---|
| method | The method to wrap. |

| Returns | |
|---|---|
| The original method wrapped such that it enters the module's name scope. |