View source on GitHub
A Trimmer that allocates a length budget to segments via round robin.
Inherits From: Trimmer
text.RoundRobinTrimmer(
max_seq_length, axis=-1
)
A Trimmer that allocates a length budget to segments using a round robin
strategy, then drops elements outside of the segment's allocated budget.
See generate_mask() for more details.
| Args | |
|---|---|
| max_seq_length | A scalar int32 Tensor giving the maximum number of elements allowed in a batch. |
| axis | Axis to apply trimming on. |
Methods
generate_mask
generate_mask(
segments
)
Calculates a truncation mask given a per-batch budget.
Calculates a truncation mask given a budget of the maximum number of items, either for each batch row or for all batch rows. The budget is allocated using a 'round robin' algorithm, which hands out quota to each bucket in turn, left to right, repeating until the budget is spent or all the buckets are filled.
For example, if the budget is [5] and the segments have sizes [3, 4, 2], the truncation budget is allocated as [2, 2, 1].
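The allocation in the example above can be sketched in plain Python. This is only an illustration of the round robin idea for a single batch row with a scalar budget, not the library's implementation; the function name `round_robin_budget` is hypothetical.

```python
def round_robin_budget(budget, segment_sizes):
    """Split `budget` across segments one item at a time, left to right.

    Hypothetical helper for illustration; handles one batch row only.
    """
    allocated = [0] * len(segment_sizes)
    remaining = budget
    # Keep cycling while budget is left and some segment can still take an item.
    while remaining > 0 and any(a < s for a, s in zip(allocated, segment_sizes)):
        for i, size in enumerate(segment_sizes):
            if remaining == 0:
                break
            if allocated[i] < size:
                allocated[i] += 1
                remaining -= 1
    return allocated


print(round_robin_budget(5, [3, 4, 2]))  # [2, 2, 1]
```

The first pass gives one item to each segment ([1, 1, 1]); the second pass spends the remaining two items on the first two segments, which yields [2, 2, 1].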
| Args | |
|---|---|
| segments | A list of RaggedTensors, each with a shape of [num_batch, (num_items)]. |

| Returns | |
|---|---|
| A list of len(segments) RaggedTensors; see superclass for details. | |
trim
trim(
segments
)
Truncate the list of segments.
Truncates the list of segments using the 'round robin' strategy, which
hands out quota to each bucket in turn, left to right, repeating until the
budget is spent or all buckets are filled.
For example, if the budget is [5] and the segments have sizes [3, 4, 2], the truncation budget is allocated as [2, 2, 1].
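As a rough sketch of what this means for the data, the following plain-Python function trims one batch row represented as a list of Python lists. It is not the library's code: the name `trim_row` is hypothetical, and it assumes that the elements kept in each segment are the leading ones.

```python
def trim_row(segments, max_seq_length):
    """Trim one batch row; `segments` is a list of Python lists.

    Hypothetical helper for illustration. Assumes the leading items
    of each segment are the ones that are kept.
    """
    keep = [0] * len(segments)
    remaining = max_seq_length
    progressed = True
    # Hand out one slot per segment per pass until the budget is spent
    # or no segment has items left to keep.
    while remaining > 0 and progressed:
        progressed = False
        for i, seg in enumerate(segments):
            if remaining > 0 and keep[i] < len(seg):
                keep[i] += 1
                remaining -= 1
                progressed = True
    # Keep the first keep[i] items of each segment and drop the rest.
    return [seg[:k] for seg, k in zip(segments, keep)]


row = [["a", "b", "c"], ["d", "e", "f", "g"], ["h", "i"]]
print(trim_row(row, 5))  # [['a', 'b'], ['d', 'e'], ['h']]
```

With a budget of 5 and segment sizes [3, 4, 2], the segments keep 2, 2, and 1 items respectively, matching the allocation described above.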
| Args | |
|---|---|
| segments | A list of RaggedTensors, each with a shape of [num_batch, (num_items)]. |

| Returns | |
|---|---|
| A list of len(segments) RaggedTensors; see superclass for details. | |