Add optimized TBE training forward by sryap · Pull Request #1641 · pytorch/FBGEMM · GitHub

@sryap sryap commented Mar 13, 2023

Summary:
This diff adds an optimized implementation of TBE training forward,
namely
split_embedding_codegen_forward_[weighted|unweighted]_v2_kernel.
The implementation currently supports only a subset of TBE use cases,
including:

  • Split TBE (SplitTableBatchedEmbeddingBagsCodegen)
  • Pooled TBE (pooling_mode: PoolingMode.SUM, PoolingMode.MEAN)
  • Weighted and unweighted TBE (per_sample_weights: Tensor, None)
  • FP32 and FP16 weight types (weights_precision: SparseType.FP32,
    SparseType.FP16)
  • FP32 and FP16 output types (output_dtype: SparseType.FP32,
    SparseType.FP16)
  • Device, managed, and managed caching embedding locations
    (EmbeddingLocation: EmbeddingLocation.DEVICE,
    EmbeddingLocation.MANAGED,
    EmbeddingLocation.MANAGED_CACHING)

Cases that the new implementation does NOT support:

  • Dense TBE (DenseTableBatchedEmbeddingBagsCodegen)
  • Sequence TBE (pooling_mode: PoolingMode.NONE)
  • FP8, INT8, INT4, INT2, and BF16 weight types (weights_precision:
    SparseType.FP8, SparseType.INT8, SparseType.INT4,
    SparseType.INT2, SparseType.BF16)
  • FP8, INT8, INT4, INT2, and BF16 output types (output_dtype:
    SparseType.FP8, SparseType.INT8, SparseType.INT4,
    SparseType.INT2, SparseType.BF16)
  • Host embedding locations (EmbeddingLocation:
    EmbeddingLocation.HOST)
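
The two lists above amount to a predicate over the TBE configuration. A minimal sketch of that check follows. The enums are local stand-ins that only mirror the names in the lists (the real ones live in fbgemm_gpu and are not imported here), and `TbeConfig` and `v2_forward_supported` are hypothetical names introduced for illustration, not part of FBGEMM.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Tuple


# Stand-in enums mirroring the names used in the lists above.
class PoolingMode(Enum):
    SUM = auto()
    MEAN = auto()
    NONE = auto()


class SparseType(Enum):
    FP32 = auto()
    FP16 = auto()
    BF16 = auto()
    FP8 = auto()
    INT8 = auto()
    INT4 = auto()
    INT2 = auto()


class EmbeddingLocation(Enum):
    DEVICE = auto()
    MANAGED = auto()
    MANAGED_CACHING = auto()
    HOST = auto()


@dataclass(frozen=True)
class TbeConfig:  # hypothetical container, not an FBGEMM type
    is_dense: bool
    pooling_mode: PoolingMode
    weights_precision: SparseType
    output_dtype: SparseType
    locations: Tuple[EmbeddingLocation, ...]


def v2_forward_supported(cfg: TbeConfig) -> bool:
    """Return True if cfg falls inside the supported subset listed above."""
    fp_types = {SparseType.FP32, SparseType.FP16}
    return (
        not cfg.is_dense  # split TBE only
        and cfg.pooling_mode in {PoolingMode.SUM, PoolingMode.MEAN}
        and cfg.weights_precision in fp_types
        and cfg.output_dtype in fp_types
        and all(loc is not EmbeddingLocation.HOST for loc in cfg.locations)
    )
```

Weighted and unweighted lookups are both supported, so `per_sample_weights` does not appear in the check.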

The IS_EXPERIMENTAL environment variable is added to enable or
disable the new implementation at runtime. If IS_EXPERIMENTAL is not
set (the default), TBE uses the original implementation. If
IS_EXPERIMENTAL=1, TBE uses the new implementation, falling back to
the original implementation for use cases that the new one does not
support.
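
The selection rule described above can be sketched in a few lines. This is an illustration only: `choose_forward` is a hypothetical helper, and treating any value other than "1" as "off" is an assumption, since the description only covers the unset and "=1" cases.

```python
import os
from typing import Mapping, Optional


def choose_forward(
    supported: bool,
    env: Optional[Mapping[str, str]] = None,
    flag: str = "IS_EXPERIMENTAL",
) -> str:
    """Pick "v2" only when the flag is "1" and the use case is supported."""
    env = os.environ if env is None else env
    # Assumption of this sketch: only the exact value "1" turns the flag on.
    requested = env.get(flag) == "1"
    # Unsupported use cases fall back to the original implementation.
    return "v2" if requested and supported else "original"
```

Passing `env` explicitly keeps the rule testable without touching the process environment.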

The new implementation contains the following optimizations:

  • Use multiple warps per bag for D > 128 to maintain a constant
    number of registers per thread
  • Use subwarps to process subsets of input rows in a bag if D < 128
  • Cooperatively compute weight pointers and store them in shared
    memory
  • Save state variables in shared memory instead of registers to free
    registers for compiler optimizations
  • Use the upper bound number of warps for all tables to avoid complex
    warp offset computation
  • Process multiple samples (up to kWarpSize samples) in a warp for
    small Ls

Note: D = embedding dimension, L = pooling factor
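
To make the D = 128 threshold concrete, here is a rough sketch of the arithmetic behind the first two optimizations. It assumes a 32-thread warp and 4 elements per thread, so that one warp covers 128 columns. These constants and the `thread_layout` helper are assumptions for illustration, not values taken from the kernel.

```python
K_WARP_SIZE = 32      # assumed warp width
ELEMS_PER_THREAD = 4  # assumed vector width; 32 * 4 = 128 matches the D threshold


def thread_layout(D: int) -> dict:
    """Illustrative split of one bag's work for embedding dimension D."""
    elems_per_warp = K_WARP_SIZE * ELEMS_PER_THREAD
    if D >= elems_per_warp:
        # Wide rows: several warps share a bag, each covering up to 128
        # columns, so per-thread register use does not grow with D.
        return {
            "warps_per_bag": -(-D // elems_per_warp),  # ceiling division
            "subwarp_size": K_WARP_SIZE,
            "rows_in_flight": 1,
        }
    # Narrow rows: a power-of-two slice of the warp covers one row, so the
    # warp can process several input rows of the bag at the same time.
    threads_needed = -(-D // ELEMS_PER_THREAD)
    subwarp = 1
    while subwarp < threads_needed:
        subwarp *= 2
    return {
        "warps_per_bag": 1,
        "subwarp_size": subwarp,
        "rows_in_flight": K_WARP_SIZE // subwarp,
    }
```

Under these assumptions, D = 256 needs two warps per bag, while D = 64 lets a single warp handle two rows at a time with 16-thread subwarps.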

Differential Revision: D43634651

@netlify

netlify bot commented Mar 13, 2023

Deploy Preview for pytorch-fbgemm-docs canceled.

Name Link
🔨 Latest commit ce68548
🔍 Latest deploy log https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/6481811cca0ba100086b5f29

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D43634651

sryap added a commit to sryap/FBGEMM that referenced this pull request Mar 15, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: 036d8d984fc4a0fbc4ac9e5a7fb746bf783dd80f

@sryap sryap force-pushed the export-D43634651 branch from e0422bb to 9450a20 Compare March 15, 2023 03:56
sryap added a commit to sryap/FBGEMM that referenced this pull request Mar 15, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: 9991218d631aac96f7c54f8ecadf8e46de402a66
@sryap sryap force-pushed the export-D43634651 branch from 9450a20 to 1b88367 Compare March 15, 2023 04:00

sryap added a commit to sryap/FBGEMM that referenced this pull request May 17, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: dbecfc183d4d93bb186a9db64c7cee81775c73aa
@sryap sryap force-pushed the export-D43634651 branch from 1b88367 to 68c23c9 Compare May 17, 2023 23:58

sryap added a commit to sryap/FBGEMM that referenced this pull request May 19, 2023
Summary:
Pull Request resolved: pytorch#1641

The `FBGEMM_EXPERIMENTAL_TBE` environment variable is added to enable
or disable the new implementation at runtime.  If
`FBGEMM_EXPERIMENTAL_TBE` is not set (the default), TBE uses the
original implementation.  If `FBGEMM_EXPERIMENTAL_TBE=1`, TBE uses the
new implementation, falling back to the original implementation for
use cases that the new one does not support.

fbshipit-source-id: 3ec1ef98e408bb8d06c44fc74a098f6b483833b2
@sryap sryap force-pushed the export-D43634651 branch from 68c23c9 to 03cb369 Compare May 19, 2023 00:50

sryap added a commit to sryap/FBGEMM that referenced this pull request May 19, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: 42b8c5b853dd30df9bb3b2f808668d1ebf0db9a7
@sryap sryap force-pushed the export-D43634651 branch from 03cb369 to b163a5e Compare May 19, 2023 17:17

sryap added a commit to sryap/FBGEMM that referenced this pull request May 19, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: 615518b61a305ae63e5cb6e9010ba4e9f7b689b9
@sryap sryap force-pushed the export-D43634651 branch from b163a5e to fcd6a27 Compare May 19, 2023 17:28

sryap added a commit to sryap/FBGEMM that referenced this pull request May 19, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: ec11b9e8c665207200f0a8699a414757d4bd005e
@sryap sryap force-pushed the export-D43634651 branch from fcd6a27 to 19c296d Compare May 19, 2023 17:34

sryap added a commit to sryap/FBGEMM that referenced this pull request May 19, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: 2ede3008140d8f0d1c33d34867d0c3aaf3c98ce0
@sryap sryap force-pushed the export-D43634651 branch from 19c296d to 7f51b6a Compare May 19, 2023 23:33

sryap added a commit to sryap/FBGEMM that referenced this pull request May 19, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: b36d4982c245f22187d27f150faaa93f541f7a5a
@sryap sryap force-pushed the export-D43634651 branch from 7f51b6a to 77ef167 Compare May 19, 2023 23:38

sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 6, 2023
Summary:
Pull Request resolved: pytorch#1641

fbshipit-source-id: 0f34b1e02932a71d225e19a44c45f18f29fc5e7c
sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 6, 2023
Summary:
Pull Request resolved: pytorch#1641

This diff adds an optimized implementation of TBE training forward,
namely
`split_embedding_codegen_forward_[weighted|unweighted]_v2_kernel`.
The implementation currently supports only a subset of usecases of TBE
including:

- Split TBE (`SplitTableBatchedEmbeddingBagsCodegen`)
- Pooled TBE (`pooling_mode`: `PoolingMode.SUM`, `PoolingMode.MEAN`)
- Weighted and unweighted TBE (`per_sample_weights`: `Tensor`, `None`)
- FP32 and FP16 weight types (`weights_precision`: `SparseType.FP32`,
  `SparseType.FP16`)
- FP32 and FP16 output types (`output_dtype`: `SparseType.FP32`,
  `SparseType.FP16`)
- Device, managed, and managed caching embedding locations
  (`EmbeddingLocation`: `EmbeddingLocation.DEVICE`,
  `EmbeddingLocation.MANAGED`,
  `EmbeddingLocation.MANAGED_CACHING`)
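
For readers unfamiliar with pooled TBE, the computation that the supported cases above describe can be sketched in plain Python. This is only a reference illustration for a single table, not the kernel added by this PR: `pooled_forward` and its parameters are hypothetical names, and applying `per_sample_weights` before the MEAN division is an assumption of the sketch.

```python
from typing import List, Optional, Sequence


def pooled_forward(
    table: Sequence[Sequence[float]],
    indices: Sequence[int],
    offsets: Sequence[int],
    per_sample_weights: Optional[Sequence[float]] = None,
    mean: bool = False,
) -> List[List[float]]:
    """Reference pooled lookup for one table (hypothetical helper).

    Bag b covers indices[offsets[b]:offsets[b + 1]]; its length is the
    pooling factor L, and len(table[0]) is the embedding dimension D.
    """
    dim = len(table[0])
    out: List[List[float]] = []
    for b in range(len(offsets) - 1):
        start, end = offsets[b], offsets[b + 1]
        acc = [0.0] * dim
        for i in range(start, end):
            # Unweighted TBE behaves like a weight of 1.0 per looked-up row.
            w = 1.0 if per_sample_weights is None else per_sample_weights[i]
            row = table[indices[i]]
            for d in range(dim):
                acc[d] += w * row[d]
        if mean and end > start:
            # Assumption: MEAN divides the (weighted) sum by the bag length.
            acc = [v / (end - start) for v in acc]
        out.append(acc)
    return out
```

With `mean=False` this corresponds to `PoolingMode.SUM`; passing `per_sample_weights` corresponds to the weighted variant.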

Cases that the new implementation does **NOT** support:

- Dense TBE (`DenseTableBatchedEmbeddingBagsCodegen`)
- Sequence TBE (`pooling_mode`: `PoolingMode.NONE`)
- FP8, INT8, INT4, INT2, and BF16 weight types (`weights_precision`:
  `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
  `SparseType.INT2`, `SparseType.BF16`)
- FP8, INT8, INT4, INT2, and BF16 output types (`output_dtype`:
  `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
  `SparseType.INT2`, `SparseType.BF16`)
- Host embedding locations (`EmbeddingLocation`:
  `EmbeddingLocation.HOST`)

Note that this optimization is enabled for NVIDIA GPUs, but **not**
enabled for AMD GPUs.
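
The supported and unsupported cases listed above amount to a simple predicate. The sketch below restates them in Python; the enums are local stand-ins that only mirror the names used in this description (they are not imported from the library), and `v2_forward_supported` and `is_nvidia_gpu` are hypothetical names introduced for illustration.

```python
from enum import Enum
from typing import Iterable


# Local stand-ins mirroring the names in the description (not the real enums).
class PoolingMode(Enum):
    SUM = 0
    MEAN = 1
    NONE = 2


class SparseType(Enum):
    FP32 = 0
    FP16 = 1
    BF16 = 2
    FP8 = 3
    INT8 = 4
    INT4 = 5
    INT2 = 6


class EmbeddingLocation(Enum):
    DEVICE = 0
    MANAGED = 1
    MANAGED_CACHING = 2
    HOST = 3


SUPPORTED_POOLING = {PoolingMode.SUM, PoolingMode.MEAN}
SUPPORTED_PRECISIONS = {SparseType.FP32, SparseType.FP16}
SUPPORTED_LOCATIONS = {
    EmbeddingLocation.DEVICE,
    EmbeddingLocation.MANAGED,
    EmbeddingLocation.MANAGED_CACHING,
}


def v2_forward_supported(
    is_dense: bool,
    pooling_mode: PoolingMode,
    weights_precision: SparseType,
    output_dtype: SparseType,
    locations: Iterable[EmbeddingLocation],
    is_nvidia_gpu: bool = True,
) -> bool:
    """Return True only for configurations listed as supported above."""
    return (
        is_nvidia_gpu
        and not is_dense
        and pooling_mode in SUPPORTED_POOLING
        and weights_precision in SUPPORTED_PRECISIONS
        and output_dtype in SUPPORTED_PRECISIONS
        and all(loc in SUPPORTED_LOCATIONS for loc in locations)
    )
```

How the actual implementation performs this check is not shown in this description; the predicate only encodes the two lists and the NVIDIA/AMD note.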

**Usage**

The frontend changes are in D44479772

The `FBGEMM_EXPERIMENTAL_TBE` environment variable is added to enable
or disable the new implementation at runtime.  If
`FBGEMM_EXPERIMENTAL_TBE` is not set (the default), TBE uses the
original implementation.  If `FBGEMM_EXPERIMENTAL_TBE=1`, TBE uses the
new implementation, falling back to the original implementation for
any TBE use case that the new implementation does not support.

The new implementation can also be enabled by passing
`use_experimental_tbe=True` when instantiating the TBE operator.

```
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=...,
    ...,
    use_experimental_tbe=True,
)
```
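
The enable-then-fall-back behaviour described in this section can be sketched as follows. This is an illustration, not the library's code: `experimental_tbe_requested` and `select_forward` are hypothetical helpers, and treating only the exact value `"1"` as enabled is an assumption, since the description does not say how other values are handled.

```python
import os
from typing import Mapping, Optional


def experimental_tbe_requested(
    use_experimental_tbe: bool = False,
    env: Optional[Mapping[str, str]] = None,
) -> bool:
    """True if either the constructor flag or the environment variable asks for v2."""
    env = os.environ if env is None else env
    # Assumption: only the literal value "1" enables the new implementation.
    return use_experimental_tbe or env.get("FBGEMM_EXPERIMENTAL_TBE") == "1"


def select_forward(requested: bool, supported: bool) -> str:
    """Pick the forward implementation, falling back when v2 is unsupported."""
    return "v2" if requested and supported else "original"
```

For example, `select_forward(experimental_tbe_requested(env={"FBGEMM_EXPERIMENTAL_TBE": "1"}), supported=False)` yields `"original"`, matching the fallback described above.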

**Optimization**

The new implementation contains the following optimizations:

- Use multiple warps per bag for D > 128 to maintain a constant
  number of registers per thread
- Use subwarps to process subsets of input rows in a bag if D < 128
- Cooperatively compute weight pointers and store them in shared
  memory
- Save state variables in shared memory instead of registers to free
  registers for compiler optimizations
- Use the upper bound number of warps for all tables to avoid complex
  warp offset computation
- Process multiple samples (up to kWarpSize samples) in a warp for
  small Ls

Note: D = embedding dimension, L = pooling factor
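
To make the first two optimizations concrete, the arithmetic below shows one way the warp and subwarp counts could follow from D. It is a sketch under stated assumptions, not the kernel's actual code: a warp of 32 threads and 4 row elements per thread are assumed so that one warp covers 128 elements, which matches the D > 128 threshold above; the helper names are hypothetical, and rounding the subwarp size up to a power of two is also an assumption.

```python
K_WARP_SIZE = 32      # threads per warp on NVIDIA GPUs
ELEMS_PER_THREAD = 4  # assumption: each thread handles 4 elements of a row


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def warps_per_bag(D: int) -> int:
    """Warps needed so each thread still holds ELEMS_PER_THREAD elements."""
    return max(1, ceil_div(D, K_WARP_SIZE * ELEMS_PER_THREAD))


def subwarp_size(D: int) -> int:
    """Threads needed to cover one row, rounded up to a power of two."""
    threads = max(1, ceil_div(D, ELEMS_PER_THREAD))
    size = 1
    while size < threads:
        size *= 2
    return min(size, K_WARP_SIZE)


def rows_per_warp(D: int) -> int:
    """Input rows of a bag that one warp can process at the same time."""
    return K_WARP_SIZE // subwarp_size(D)
```

Under these assumptions, D = 256 uses two warps per bag, while D = 32 needs only 8 threads per row, so a single warp can process 4 rows of a bag concurrently.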

Reviewed By: jianyuh

Differential Revision: D43634651

fbshipit-source-id: 96ad56f0e5567959fd28c72a649f862e1f5dd307
@sryap sryap force-pushed the export-D43634651 branch from 4ad7dc0 to 75318b0 Compare June 6, 2023 08:38
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D43634651

sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 6, 2023
Summary:
Pull Request resolved: pytorch#1641

Differential Revision: D43634651

fbshipit-source-id: 42fea6790c0fef1e60bae3d57c247ca61da46ec0
sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 6, 2023
Summary:
Pull Request resolved: pytorch#1641

Reviewed By: jianyuh

Differential Revision: D43634651

fbshipit-source-id: 64d0d0752fc2689dae75ea1064a7c80551d3a15f
@sryap sryap force-pushed the export-D43634651 branch from 75318b0 to 670a498 Compare June 6, 2023 08:44
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D43634651

sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 7, 2023
Summary:
Pull Request resolved: pytorch#1641

Differential Revision: D43634651

fbshipit-source-id: 8be8753cba190ffde4c3be3a9e016cf09a99b5d4
sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 7, 2023
Summary:
Pull Request resolved: pytorch#1641

Differential Revision: D43634651

fbshipit-source-id: 3a3c3ce39c6a1deb1e217581e3717b98e7629e04
sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 7, 2023
Summary:
Pull Request resolved: pytorch#1641

Differential Revision: D43634651

fbshipit-source-id: d03445efde04e978e8f5bb8853452a5c85ed9236
sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 7, 2023
Summary:
Pull Request resolved: pytorch#1641

Differential Revision: D43634651

fbshipit-source-id: 27876f466229cf1fd6a0aeb66e3d35bd6b43f930
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D43634651

sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 7, 2023
Summary:
Pull Request resolved: pytorch#1641

Reviewed By: jianyuh

Differential Revision: D43634651

fbshipit-source-id: 30f9fc00c306515400e89d2f7c78063b75630722
@sryap sryap force-pushed the export-D43634651 branch from 670a498 to 12fd5ce Compare June 7, 2023 17:50
sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 8, 2023
Summary:
Pull Request resolved: pytorch#1641

Reviewed By: jianyuh

Differential Revision: D43634651

fbshipit-source-id: 3d5c90de057af284014a4a916f8aac1e0361750b
@sryap sryap force-pushed the export-D43634651 branch from 12fd5ce to c6d7a7f Compare June 8, 2023 07:12
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D43634651

sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 8, 2023
Summary:
Pull Request resolved: pytorch#1641

Differential Revision: D43634651

fbshipit-source-id: 1c07c1e62f9fabde9ca4a5b166d666d8d01b1cf3
Summary:
Pull Request resolved: pytorch#1641

Reviewed By: jianyuh

Differential Revision: D43634651

fbshipit-source-id: 0e72b4809d2a7e26a8db88d8639c3d329ddd34ec
@sryap sryap force-pushed the export-D43634651 branch from c6d7a7f to ce68548 Compare June 8, 2023 07:19
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D43634651

sryap added a commit to sryap/FBGEMM that referenced this pull request Jun 8, 2023
Summary:
Pull Request resolved: pytorch#1641

This diff adds an optimized implementation of TBE training forward,
namely
`split_embedding_codegen_forward_[weighted|unweighted]_v2_kernel`.
The implementation currently supports only a subset of TBE use cases,
including:

- Split TBE (`SplitTableBatchedEmbeddingBagsCodegen`)
- Pooled TBE (`pooling_mode`: `PoolingMode.SUM`, `PoolingMode.MEAN`)
- Weighted and unweighted TBE (`per_sample_weights`: `Tensor`, `None`)
- FP32 and FP16 weight types (`weights_precision`: `SparseType.FP32`,
  `SparseType.FP16`)
- FP32 and FP16 output types (`output_dtype`: `SparseType.FP32`,
  `SparseType.FP16`)
- Device, managed, and managed caching embedding locations
  (`EmbeddingLocation`: `EmbeddingLocation.DEVICE`,
  `EmbeddingLocation.MANAGED`,
  `EmbeddingLocation.MANAGED_CACHING`)

Cases that the new implementation does **NOT** support:

- Dense TBE (`DenseTableBatchedEmbeddingBagsCodegen`)
- Sequence TBE (`pooling_mode`: `PoolingMode.NONE`)
- FP8, INT8, INT4, INT2, and BF16 weight types (`weights_precision`:
  `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
  `SparseType.INT2`, `SparseType.BF16`)
- FP8, INT8, INT4, INT2, and BF16 output types (`output_dtype`:
  `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
  `SparseType.INT2`, `SparseType.BF16`)
- Host embedding locations (`EmbeddingLocation`:
  `EmbeddingLocation.HOST`)

The `IS_EXPERIMENTAL` environment variable flag is added for
enabling/disabling the new implementation at runtime.  If
`IS_EXPERIMENTAL` is not set, TBE will use the original implementation.
If `IS_EXPERIMENTAL=1`, TBE will use the new implementation.  If a
TBE use case is not supported by the new implementation, TBE will
fall back to the original implementation.  By default,
`IS_EXPERIMENTAL` is not set.

The new implementation contains the following optimizations:

- Use multiple warps per bag for D > 128 to maintain a constant
  number of registers per thread
- Use subwarps to process subsets of input rows in a bag if D < 128
- Cooperatively compute weight pointers and store them in shared
  memory
- Save state variables in shared memory instead of registers to free
  registers for compiler optimizations
- Use the upper bound number of warps for all tables to avoid complex
  warp offset computation
- Process multiple samples (up to kWarpSize samples) in a warp for
  small Ls

Note: D = embedding dimension, L = pooling factor

Differential Revision: D43634651

fbshipit-source-id: 6953f7f8c9fd3a415d1ea5ed2af771ea85eb1d84
@facebook-github-bot
Contributor

This pull request has been merged in d1c4a6f.

@sryap
Contributor Author

sryap commented Jun 13, 2023

@liligwu FYI, we currently disable this functionality on ROCm due to various compilation errors. This is the optimized table batched embedding implementation. Currently it is not used by default but this might change in the future. We are considering replacing the old implementation with the new one.

@liligwu
Contributor

liligwu commented Jun 13, 2023

Hi @sryap , thank you for letting us know the changes.
