Common utilities for TF-Agents.
Classes
class AggregatedLosses: Named tuple AggregatedLosses(total_loss, weighted, regularization) holding the aggregated loss components.
class Checkpointer: Checkpoints training state, policy state, and replay_buffer state.
class EagerPeriodically: EagerPeriodically performs the ops defined in body.
class OUProcess: A zero-mean Ornstein-Uhlenbeck process.
class Periodically: Periodically performs the ops defined in body.
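OUProcess and ornstein_uhlenbeck_process generate temporally correlated, zero-mean noise. This kind of noise is commonly used for exploration in continuous-control agents such as DDPG. The following is a minimal pure-Python sketch, not the TF-Agents implementation. It assumes the discrete-time update x ← (1 − damping)·x + stddev·ε with ε ~ N(0, 1). The class name `SimpleOUProcess` is hypothetical, and its parameter defaults are illustrative only.

```python
import random


class SimpleOUProcess:
    """Hypothetical sketch of a zero-mean Ornstein-Uhlenbeck noise process.

    Assumes the discrete-time update
        x <- (1 - damping) * x + stddev * N(0, 1),
    which pulls x back toward zero while adding Gaussian noise.
    """

    def __init__(self, initial_value=0.0, damping=0.15, stddev=0.2, seed=None):
        self._x = float(initial_value)
        self._damping = damping
        self._stddev = stddev
        self._rng = random.Random(seed)  # private RNG so runs are reproducible

    def __call__(self):
        noise = self._rng.gauss(0.0, 1.0)
        # Mean reversion toward zero plus a Gaussian perturbation.
        self._x = (1.0 - self._damping) * self._x + self._stddev * noise
        return self._x


# Successive samples are correlated because each one depends on the previous one.
ou = SimpleOUProcess(initial_value=1.0, damping=0.15, stddev=0.2, seed=0)
samples = [ou() for _ in range(5)]
```

With stddev set to 0 the process reduces to a deterministic geometric decay toward zero. This makes the role of damping easy to see.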
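Periodically and EagerPeriodically gate a body so that it runs only on some calls. The plain-Python sketch below illustrates that gating pattern. It is not the TF-Agents class. The name `SimplePeriodically` is hypothetical, and it assumes the following rules:

* the body fires on the first call and then every `period` calls;
* the body never fires when `period` is None or 0.

```python
class SimplePeriodically:
    """Hypothetical sketch: call `body` once every `period` invocations.

    Assumes the body fires on the first call and then every `period` calls,
    and never fires when `period` is None or 0.
    """

    def __init__(self, body, period):
        if not callable(body):
            raise TypeError("body must be callable")
        if period is not None and period < 0:
            raise ValueError("period must be non-negative or None")
        self._body = body
        self._period = period
        self._counter = 0

    def __call__(self):
        result = None
        # A falsy period (None or 0) disables the body entirely.
        if self._period and self._counter % self._period == 0:
            result = self._body()
        self._counter += 1
        return result


calls = []
log_every_3 = SimplePeriodically(lambda: calls.append("logged"), period=3)
for _ in range(7):
    log_every_3()
# The body fired on calls 0, 3 and 6, so `calls` has three entries.
```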
Functions
aggregate_losses(...): Aggregates and scales per example loss and regularization losses.
assert_members_are_not_overridden(...): Asserts public members of base_cls are not overridden in instance.
check_matching_networks(...): Check that two networks have matching input specs and variables.
check_no_shared_variables(...): Checks that there are no shared trainable variables in the two networks.
check_tf1_allowed(...): Raises an error if running in TF1 (non-eager) mode and this is disabled.
clip_to_spec(...): Clips value to a given bounded tensor spec.
compute_returns(...): Compute the return from each index in an episode.
convert_q_logits_to_values(...): Converts a set of Q-value logits into Q-values using the provided support.
create_variable(...): Create a variable.
deduped_network_variables(...): Returns a list of variables in net1 that are not in any other nets.
discounted_future_sum(...): Discounted future sum of batch-major values.
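discounted_future_sum_masked(...): Discounted future sum of batch-major values, masked so that steps beyond each episode's length are excluded.
element_wise_squared_loss(...): Computes the squared error between two tensors element-wise, without reduction.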
entropy(...): Computes total entropy of distribution.
extract_shared_variables(...): Separates shared variables from the given collections.
function(...): Wrapper for tf.function with TF Agents-specific customizations.
function_in_tf1(...): Wrapper that returns common.function if using TF1.
generate_tensor_summaries(...): Generates various summaries of tensor such as histogram, max, min, etc.
get_contiguous_sub_episodes(...): Computes mask on sub-episodes which includes only contiguous components.
get_episode_mask(...): Create a mask that is 0.0 for all final steps, 1.0 elsewhere.
has_eager_been_enabled(...): Returns true iff in TF2 or in TF1 with eager execution enabled.
index_with_actions(...): Index into q_values using actions.
initialize_uninitialized_variables(...): Initialize any pending variables that are uninitialized.
join_scope(...): Joins a parent and child scope using /, handling empty or None scopes.
load_spec(...): Loads a data spec from a file.
log_probability(...): Computes log probability of actions given distribution.
maybe_copy_target_network_with_checks(...): Copies the network into target if target is None, and checks that the two networks share no variables.
ornstein_uhlenbeck_process(...): An op for generating noise from a zero-mean Ornstein-Uhlenbeck process.
periodically(...): Periodically performs the TensorFlow op in body.
replicate(...): Replicates a tensor so as to match the given outer shape.
resource_variables_enabled(...): Returns whether resource variables are enabled.
safe_has_state(...): Safely checks state not in (None, (), []).
save_spec(...): Saves the given spec nest as a StructProto.
scale_to_spec(...): Shapes and scales a batch into the given spec bounds.
set_default_tf_function_parameters(...): Generates a decorator that sets default parameters for tf.function.
shift_values(...): Shifts batch-major values in time by some amount.
soft_device_placement(...): Context manager for soft device placement, allowing summaries on CPU.
soft_variables_update(...): Performs a soft/hard update of variables from the source to the target.
spec_means_and_magnitudes(...): Get the center and magnitude of the ranges in action spec.
summarize_tensor_dict(...): Generates summaries of all tensors in tensor_dict.
transpose_batch_time(...): Transposes the batch and time dimensions of a Tensor.
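compute_returns and discounted_future_sum both accumulate discounted rewards. The pure-Python sketch below shows the arithmetic behind each and is not the TF-Agents code. The function names are hypothetical. It assumes two recurrences:

* compute_returns uses return[t] = rewards[t] + discounts[t] · return[t + 1], with the return after the final step taken as 0.
* discounted_future_sum sums up to `num_steps` future values weighted by powers of `gamma`, truncated at the end of each row.

```python
def compute_returns_sketch(rewards, discounts):
    """Hypothetical sketch: return[t] = rewards[t] + discounts[t] * return[t + 1].

    Works backward from the end of the episode, treating the return after
    the final step as 0.
    """
    returns = [0.0] * len(rewards)
    next_return = 0.0
    for t in reversed(range(len(rewards))):
        next_return = rewards[t] + discounts[t] * next_return
        returns[t] = next_return
    return returns


def discounted_future_sum_sketch(values, gamma, num_steps):
    """Hypothetical sketch for batch-major values of shape [batch, time].

    out[b][t] = sum over k in [0, num_steps) of gamma**k * values[b][t + k],
    truncated where fewer than num_steps steps remain.
    """
    out = []
    for row in values:
        total_steps = len(row)
        out.append([
            sum(gamma ** k * row[t + k]
                for k in range(min(num_steps, total_steps - t)))
            for t in range(total_steps)
        ])
    return out


# Three rewards of 1.0 with discount 0.5 give returns [1.75, 1.5, 1.0].
episode_returns = compute_returns_sketch([1.0, 1.0, 1.0], [0.5, 0.5, 0.5])
```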
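index_with_actions selects, for each batch element, the Q-value of the action that was taken. The sketch below is a hypothetical plain-Python illustration, not the TF-Agents function. It covers only the simplest case: `q_values` of shape [batch, num_actions] and one integer action per batch element.

```python
def index_with_actions_sketch(q_values, actions):
    """Hypothetical sketch: out[b] = q_values[b][actions[b]].

    q_values has shape [batch, num_actions]; actions has shape [batch].
    """
    if len(q_values) != len(actions):
        raise ValueError("q_values and actions must share the batch dimension")
    return [row[a] for row, a in zip(q_values, actions)]


# Picks Q(s0, a=1) = 2.0 and Q(s1, a=0) = 3.0.
selected = index_with_actions_sketch([[1.0, 2.0], [3.0, 4.0]], [1, 0])
```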
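soft_variables_update performs a soft or hard update from source to target variables. The usual rule is target ← (1 − τ)·target + τ·source, where τ = 1 gives a hard copy. The sketch below applies that rule to plain lists of floats. It assumes a single mixing parameter `tau` and is not the TF-Agents signature. The function name is hypothetical.

```python
def soft_update_sketch(source, target, tau=1.0):
    """Hypothetical sketch: target <- (1 - tau) * target + tau * source.

    tau = 1.0 is a hard copy; a small tau (e.g. 0.005) makes the target
    track the source slowly, as in Polyak averaging of target networks.
    """
    if len(source) != len(target):
        raise ValueError("source and target must have the same length")
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    return [(1.0 - tau) * t + tau * s for s, t in zip(source, target)]


# Half-way blend of the source [1.0, 2.0] into a zero target.
blended = soft_update_sketch([1.0, 2.0], [0.0, 0.0], tau=0.5)
```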
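get_episode_mask produces 0.0 at final steps and 1.0 elsewhere. That lets a loss ignore transitions that cross an episode boundary. The sketch below works on a plain list of step types rather than TF-Agents time steps. It assumes the encoding FIRST = 0, MID = 1, LAST = 2, intended to mirror `StepType` in `tf_agents.trajectories.time_step`. The function name is hypothetical.

```python
# Assumed step-type encoding (FIRST, MID, LAST).
FIRST, MID, LAST = 0, 1, 2


def episode_mask_sketch(step_types):
    """Hypothetical sketch: 0.0 where a step is final (LAST), 1.0 elsewhere."""
    return [0.0 if s == LAST else 1.0 for s in step_types]


# Two episodes back to back: the LAST step of each is masked out.
mask = episode_mask_sketch([FIRST, MID, LAST, FIRST, LAST])
```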
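transpose_batch_time swaps the two leading axes, converting batch-major [batch, time, ...] data to time-major [time, batch, ...] or back. The sketch below does the same on nested Python lists. It is a hypothetical illustration, not the TF-Agents function, and leaves any inner axes untouched.

```python
def transpose_batch_time_sketch(x):
    """Hypothetical sketch: swap the first two axes of a nested list.

    [batch][time][...] becomes [time][batch][...]; inner axes are untouched.
    """
    if not x:
        return []
    num_time = len(x[0])
    if any(len(row) != num_time for row in x):
        raise ValueError("all rows must have the same time length")
    return [[x[b][t] for b in range(len(x))] for t in range(num_time)]


# A batch of 2 sequences of length 3 becomes 3 time steps of batch 2.
time_major = transpose_batch_time_sketch([[1, 2, 3], [4, 5, 6]])
```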
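join_scope builds a scope name such as "agent/actor" and must cope with a missing parent or child. A minimal sketch of that behavior follows. It assumes an empty string and None are both treated as "no scope". The function name is hypothetical.

```python
def join_scope_sketch(parent_scope, child_scope):
    """Hypothetical sketch: join two scope names with '/', skipping empty/None parts."""
    if not parent_scope:
        return child_scope
    if not child_scope:
        return parent_scope
    return "/".join([parent_scope, child_scope])


full_name = join_scope_sketch("agent", "actor")  # "agent/actor"
```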
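spec_means_and_magnitudes, scale_to_spec and clip_to_spec all work with the minimum and maximum of a bounded spec. The sketch below uses plain lists for the bounds and makes two assumptions:

* the "mean" is the midpoint of each range and the "magnitude" is its half-width;
* scale_to_spec maps inputs in [-1, 1], such as tanh outputs, onto the bounds.

It omits the reshaping that scale_to_spec also performs. All function names are hypothetical.

```python
def spec_means_and_magnitudes_sketch(minimum, maximum):
    """Hypothetical sketch: center and half-width of each [minimum, maximum] range."""
    means = [(hi + lo) / 2.0 for lo, hi in zip(minimum, maximum)]
    magnitudes = [(hi - lo) / 2.0 for lo, hi in zip(minimum, maximum)]
    return means, magnitudes


def scale_to_spec_sketch(values, minimum, maximum):
    """Hypothetical sketch: map values in [-1, 1] onto [minimum, maximum]."""
    means, magnitudes = spec_means_and_magnitudes_sketch(minimum, maximum)
    return [m + g * v for v, m, g in zip(values, means, magnitudes)]


def clip_to_spec_sketch(values, minimum, maximum):
    """Hypothetical sketch: clamp each element into its [minimum, maximum] bound."""
    return [min(max(v, lo), hi) for v, lo, hi in zip(values, minimum, maximum)]


# -1 maps to the lower bound and +1 to the upper bound of each dimension.
scaled = scale_to_spec_sketch([-1.0, 1.0], [0.0, -2.0], [1.0, 2.0])
```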
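Judging by its name, element_wise_squared_loss computes the squared error per element and leaves reduction to the caller. The sketch below shows that computation on plain lists under this assumption. The function name is hypothetical.

```python
def element_wise_squared_loss_sketch(x, y):
    """Hypothetical sketch: (x - y) ** 2 per element, with no reduction."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    return [(a - b) ** 2 for a, b in zip(x, y)]


# One loss per element; averaging or summing is left to the caller.
losses = element_wise_squared_loss_sketch([1.0, 2.0, 3.0], [1.0, 0.0, 5.0])
```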
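aggregate_losses scales per-example losses and combines them with regularization losses. The result is returned as AggregatedLosses(total_loss, weighted, regularization). The sketch below reuses those field names but makes its own assumptions about the scaling:

* weighted = sum(loss_i · weight_i) / batch_size, where batch_size defaults to the number of examples;
* regularization is the plain sum of the regularization terms;
* total_loss = weighted + regularization.

The function and tuple names are hypothetical.

```python
from collections import namedtuple

# Field names follow the AggregatedLosses entry in this listing.
AggregatedLossesSketch = namedtuple(
    "AggregatedLossesSketch", ["total_loss", "weighted", "regularization"])


def aggregate_losses_sketch(per_example_loss, sample_weight=None,
                            global_batch_size=None, regularization_losses=()):
    """Hypothetical sketch of combining per-example and regularization losses."""
    if sample_weight is None:
        sample_weight = [1.0] * len(per_example_loss)
    # Divide by the global batch size if given, else by the local example count.
    batch_size = global_batch_size or len(per_example_loss)
    weighted = sum(loss * w for loss, w in zip(per_example_loss, sample_weight)) / batch_size
    regularization = float(sum(regularization_losses))
    return AggregatedLossesSketch(total_loss=weighted + regularization,
                                  weighted=weighted,
                                  regularization=regularization)


result = aggregate_losses_sketch([1.0, 3.0], regularization_losses=[0.5])
```

Dividing by a global batch size rather than the local one is what keeps the scale consistent when a batch is split across replicas. That is why the sketch accepts it as an optional override.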
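convert_q_logits_to_values turns logits over a fixed support into Q-values. This is the categorical (C51-style) DQN setup: the Q-value is the expectation of the support under the softmax of the logits, Q = Σᵢ softmax(logits)ᵢ · supportᵢ. The sketch below computes that expectation for a single action in plain Python. It is not the TF-Agents function, and its name is hypothetical.

```python
import math


def convert_q_logits_to_values_sketch(logits, support):
    """Hypothetical sketch for one action: Q = sum_i softmax(logits)_i * support_i.

    `logits` scores each atom of a categorical return distribution and
    `support` holds the return value each atom represents.
    """
    if len(logits) != len(support):
        raise ValueError("logits and support must have the same length")
    shift = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return sum(e / total * z for e, z in zip(exps, support))


# Equal logits put probability 0.5 on each atom, so Q = 0.5 * -1 + 0.5 * 1 = 0.
q = convert_q_logits_to_values_sketch([0.0, 0.0], [-1.0, 1.0])
```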
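replicate adds leading dimensions so that a tensor matches a given outer shape. The sketch below imitates that on nested lists by repeating independent copies of the value. It assumes `outer_shape` is a list of non-negative integers. The function name is hypothetical.

```python
import copy


def replicate_sketch(value, outer_shape):
    """Hypothetical sketch: nest copies of `value` so outer_shape becomes the leading dims."""
    result = value
    # Build from the innermost outer dimension outward.
    for dim in reversed(outer_shape):
        result = [copy.deepcopy(result) for _ in range(dim)]
    return result


# A scalar replicated to outer shape [2, 3] gives a 2 x 3 nested list.
grid = replicate_sketch(7, [2, 3])
```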