View source on GitHub
|
Gathers and serializes a checkpoint view.
tf.train.CheckpointView(
save_path
)
This is for loading specific portions of a module from a checkpoint, and be able to compare two modules by matching components.
Example usage:
class SimpleModule(tf.Module):def __init__(self, name=None):super().__init__(name=name)self.a_var = tf.Variable(5.0)self.b_var = tf.Variable(4.0)self.vars = [tf.Variable(1.0), tf.Variable(2.0)]
root = SimpleModule(name="root")root.leaf = SimpleModule(name="leaf")ckpt = tf.train.Checkpoint(root)save_path = ckpt.save('/tmp/tf_ckpts')checkpoint_view = tf.train.CheckpointView(save_path)
Pass node_id=0 to tf.train.CheckpointView.children() to get the dictionary
of all children directly linked to the checkpoint root.
for name, node_id in checkpoint_view.children(0).items():print(f"- name: '{name}', node_id: {node_id}")- name: 'a_var', node_id: 1- name: 'b_var', node_id: 2- name: 'vars', node_id: 3- name: 'leaf', node_id: 4- name: 'root', node_id: 0- name: 'save_counter', node_id: 5
Args | |
|---|---|
save_path
|
The path to the checkpoint. |
Raises | |
|---|---|
ValueError
|
If the save_path does not lead to a TF2 checkpoint. |
Methods
children
children(
node_id
)
Returns all child trackables attached to obj.
| Args | |
|---|---|
node_id
|
Id of the node to return its children. |
| Returns | |
|---|---|
| Dictionary of all children attached to the object with name to node_id. |
descendants
descendants()
Returns a list of trackables by node_id attached to obj.
diff
diff(
obj
)
Returns diff between CheckpointView and Trackable.
This method is intended to be used to compare the object stored in a
checkpoint vs a live model in Python. For example, if checkpoint
restoration fails the assert_consumed() or
assert_existing_objects_matched() checks, you can use this to list out
the objects/checkpoint nodes which were not restored.
Example Usage:
class SimpleModule(tf.Module):def __init__(self, name=None):super().__init__(name=name)self.a_var = tf.Variable(5.0)self.b_var = tf.Variable(4.0)self.vars = [tf.Variable(1.0), tf.Variable(2.0)]
root = SimpleModule(name="root")leaf = root.leaf = SimpleModule(name="leaf")leaf.leaf3 = tf.Variable(6.0, name="leaf3")leaf.leaf4 = tf.Variable(7.0, name="leaf4")ckpt = tf.train.Checkpoint(root)save_path = ckpt.save('/tmp/tf_ckpts')checkpoint_view = tf.train.CheckpointView(save_path)
root2 = SimpleModule(name="root")leaf2 = root2.leaf2 = SimpleModule(name="leaf2")leaf2.leaf3 = tf.Variable(6.0)leaf2.leaf4 = tf.Variable(7.0)
Pass node_id=0 to tf.train.CheckpointView.children() to get the
dictionary of all children directly linked to the checkpoint root.
checkpoint_view_diff = checkpoint_view.diff(root2)checkpoint_view_match = checkpoint_view_diff[0].items()for item in checkpoint_view_match:print(item)(0, ...)(1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)(2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)(3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))(6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)(7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)
only_in_checkpoint_view = checkpoint_view_diff[1]print(only_in_checkpoint_view)[4, 5, 8, 9, 10, 11, 12, 13, 14]
only_in_trackable = checkpoint_view_diff[2]print(only_in_trackable)[..., <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>,ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]),<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=6.0>,<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>,<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]
| Args | |
|---|---|
obj
|
Trackable root.
|
| Returns | |
|---|---|
Tuple of (
|
match
match(
obj
)
Returns all matching trackables between CheckpointView and Trackable.
Matching trackables represents trackables with the same name and position in graph.
| Args | |
|---|---|
obj
|
Trackable root.
|
| Returns | |
|---|---|
Dictionary containing all overlapping trackables that maps node_id to
Trackable.
|
Example usage:
class SimpleModule(tf.Module):def __init__(self, name=None):super().__init__(name=name)self.a_var = tf.Variable(5.0)self.b_var = tf.Variable(4.0)self.vars = [tf.Variable(1.0), tf.Variable(2.0)]
root = SimpleModule(name="root")leaf = root.leaf = SimpleModule(name="leaf")leaf.leaf3 = tf.Variable(6.0, name="leaf3")leaf.leaf4 = tf.Variable(7.0, name="leaf4")ckpt = tf.train.Checkpoint(root)save_path = ckpt.save('/tmp/tf_ckpts')checkpoint_view = tf.train.CheckpointView(save_path)
root2 = SimpleModule(name="root")leaf2 = root2.leaf2 = SimpleModule(name="leaf2")leaf2.leaf3 = tf.Variable(6.0)leaf2.leaf4 = tf.Variable(7.0)
Pass node_id=0 to tf.train.CheckpointView.children() to get the
dictionary of all children directly linked to the checkpoint root.
checkpoint_view_match = checkpoint_view.match(root2).items()for item in checkpoint_view_match:print(item)(0, ...)(1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)(2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)(3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))(6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)(7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)
View source on GitHub