[ROCm] BF16 Context Parallelism MI300X Not Numerically Accurate · Issue #156012 · pytorch/pytorch · GitHub


@functionstackx

Description


🐛 Describe the bug

I am trying out Context Parallelism on MI300X with rocm/pytorch-nightly:2025-06-12-rocm6.4, and the maximum absolute difference from the non-CP output exceeds the atol threshold:

Max abs diff: 0.05859375, atol threshold: 0.004

On H100 with nvcr.io/nvidia/pytorch:25.05-py3, the max abs diff is 15x lower and stays within the threshold:

Max abs diff: 0.00390625, atol threshold: 0.004

The repro script and atol threshold are taken from the PyTorch Context Parallel tutorial https://docs.pytorch.org/tutorials/prototype/context_parallel.html#enable-context-parallel, which sets atol to 1e-03 * world_size (0.004 for the 4 processes used here).
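The threshold and the two reported diffs can be checked with a little arithmetic. This is a pure-Python sketch; `BF16_EPS` and `ULP_HALF_TO_ONE` are names introduced here for illustration, and the only assumption is the standard bfloat16 format (8 significant bits).

```python
# Worked arithmetic for the numbers reported in this issue (pure Python).
WORLD_SIZE = 4                 # --nproc-per-node=4 in the repro command
ATOL = 1e-03 * WORLD_SIZE      # tutorial threshold: 0.004

# bfloat16 keeps 8 significant bits, so the gap between adjacent values
# just above 1.0 is 2**-7 (the same value as torch.finfo(torch.bfloat16).eps).
BF16_EPS = 2.0 ** -7
ULP_HALF_TO_ONE = BF16_EPS / 2  # gap between adjacent bf16 values in [0.5, 1.0)

reported = {"MI300X": 0.05859375, "H100": 0.00390625}
for gpu, diff in reported.items():
    print(f"{gpu}: diff/atol = {diff / ATOL:.2f}, "
          f"diff/2**-8 = {diff / ULP_HALF_TO_ONE:.0f}")
```

With the reported values, diff/atol is about 14.65 on MI300X and 0.98 on H100. Measured in steps of 2**-8 (the bf16 spacing for values in [0.5, 1.0)), the H100 diff is exactly one step and the MI300X diff is exactly 15, which would correspond to a single rounding step on H100 if the largest diff falls on an output in that range.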

@hliuca, can you take a look when you get a chance?

Repro Command

torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py

Repro Script

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard
from torch.nn.attention import sdpa_kernel, SDPBackend


def context_parallel_sdpa_example(world_size: int, rank: int):
    assert torch.cuda.is_available()
    assert dist.is_nccl_available()
    torch.cuda.set_device(f"cuda:{rank}")
    torch.cuda.manual_seed(0)

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
    )

    batch = 8
    nheads = 8
    qkv_len = 64
    dim = 32
    backend = SDPBackend.FLASH_ATTENTION
    dtype = (
        torch.bfloat16
        if backend == SDPBackend.FLASH_ATTENTION
        or backend == SDPBackend.CUDNN_ATTENTION
        else torch.float32
    )

    qkv = [
        torch.rand(
            (batch, nheads, qkv_len, dim),
            dtype=dtype,
            requires_grad=True,
            device='cuda',
        )
        for _ in range(3)
    ]
    # specify the SDPBackend to use
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(*qkv, is_causal=True)

    # make a clean copy of QKV for output comparison
    cp_qkv = [t.detach().clone() for t in qkv]

    with sdpa_kernel(backend):
        # This `context_parallel()` performs two actions:
        # 1. Shard the tensor objects in `buffers` in-place along the dimension
        #    specified in `buffer_seq_dims`, the tensors in `buffers` and their
        #    sharding dims in `buffer_seq_dims` are organized in the same order.
        # 2. Replace the execution of `F.scaled_dot_product_attention` with a
        #    context-parallel-enabled Ring Attention.
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

        # The output `cp_out` is still sharded in the same way as QKV.
        # The `context_parallel_unshard` API allows users to easily
        # unshard it to recover the full tensor.
        (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])

    atol = 1e-08 if dtype == torch.float32 else 1e-03 * world_size
    abs_diff = (cp_out - out).abs().max()
    print(f"Max abs diff: {abs_diff.item()}, atol threshold: {atol}")
    print("allclose =", torch.allclose(cp_out, out, atol=atol))


if __name__ == "__main__":
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    try:
        context_parallel_sdpa_example(world_size, rank)
    finally:
        dist.barrier()
        dist.destroy_process_group()
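Ring Attention splits K/V across ranks and combines the per-rank partial outputs using each block's log-sum-exp (LSE), so a bf16 CP result can pass through rounding steps that the single-device kernel does not. The sketch below is a simplified, single-process CPU model of that merge (non-causal for brevity, with `attn_with_lse`, `merge_blocks`, and `ring_style_attention` as hypothetical helpers). It is not the code in `torch.distributed.tensor.experimental._attention`, whose precision choices may differ. It also measures both bf16 paths against a float64 reference, which may help tell whether the gap comes from the CP path or from the baseline `out`.

```python
import torch
import torch.nn.functional as F


def attn_with_lse(q, k, v, scale):
    # Softmax attention over one key/value block, also returning the
    # log-sum-exp of the scores so blocks can be merged afterwards.
    scores = (q @ k.transpose(-2, -1)) * scale
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    return torch.softmax(scores, dim=-1) @ v, lse


def merge_blocks(out_a, lse_a, out_b, lse_b):
    # Rescale two partial outputs by their share of the total softmax mass.
    m = torch.maximum(lse_a, lse_b)
    lse = m + torch.log(torch.exp(lse_a - m) + torch.exp(lse_b - m))
    return torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b, lse


def ring_style_attention(q, k, v, world_size):
    # Hypothetical simulation: K/V are split into world_size blocks (one per
    # "rank"); each block is attended to in q's dtype and the partial results
    # are merged in that same dtype.
    scale = q.shape[-1] ** -0.5  # same default scale as SDPA
    out, lse = None, None
    for k_blk, v_blk in zip(k.chunk(world_size, dim=-2), v.chunk(world_size, dim=-2)):
        o, l = attn_with_lse(q, k_blk, v_blk, scale)
        out, lse = (o, l) if out is None else merge_blocks(out, lse, o, l)
    return out


torch.manual_seed(0)
# Same shapes as the repro script: (batch, nheads, qkv_len, dim).
q, k, v = (torch.rand(8, 8, 64, 32, dtype=torch.float64) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v)  # float64 reference

# In float64 the blockwise merge reproduces full attention up to rounding.
ring64 = ring_style_attention(q, k, v, world_size=4)
print("float64 ring vs ref:", (ring64 - ref).abs().max().item())

# In bfloat16, compare the single-device and ring-style results against the
# float64 reference to see which path drifts further.
qb, kb, vb = (t.to(torch.bfloat16) for t in (q, k, v))
full_bf16 = F.scaled_dot_product_attention(qb, kb, vb)
ring_bf16 = ring_style_attention(qb, kb, vb, world_size=4)
print("bf16 full vs ref:", (full_bf16.double() - ref).abs().max().item())
print("bf16 ring vs ref:", (ring_bf16.double() - ref).abs().max().item())
```

The same float64-reference idea could be applied to `out` and `cp_out` in the repro script itself, so each platform's error is measured against ground truth rather than against another bf16 result.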

Versions

MI300X: nightly Docker image rocm/pytorch-nightly:2025-06-12-rocm6.4
H100: nvcr.io/nvidia/pytorch:25.05-py3
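PyTorch's `python -m torch.utils.collect_env` prints a full environment report. For a quick side-by-side of the two containers, a small helper like the hypothetical `describe_runtime` below could be run in each; it only reads version attributes that exist on every PyTorch build (`torch.version.hip` is None on CUDA builds, `torch.version.cuda` is None on ROCm builds).

```python
import platform

import torch


def describe_runtime() -> dict:
    # Collect the version details that usually matter when comparing
    # ROCm and CUDA results for the same script.
    info = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "hip": torch.version.hip,    # None on CUDA and CPU-only builds
        "cuda": torch.version.cuda,  # None on ROCm and CPU-only builds
        "gpu": None,
    }
    if torch.cuda.is_available():
        # On ROCm builds torch.cuda also reports AMD GPUs.
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in describe_runtime().items():
        print(f"{key}: {value}")
```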

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

Labels: "module: rocm", "oncall: distributed"

Status: Done