Move all Stream and Event Python implementation to C++ by mrshenli · Pull Request #15937 · pytorch/pytorch · GitHub

Contributor

@mrshenli mrshenli commented Jan 10, 2019

  1. Added torch/csrc/cuda/Event.h and torch/csrc/cuda/Event.cpp to bind the Python Event class to the C++ implementation.
  2. Moved all CUDA runtime invocations from torch/cuda/streams.py to C++.
  3. Added tests to cover the Stream and Event APIs. (Event IPC handle tests are introduced in Adding CUDA event IPC test #15974.)

@mrshenli mrshenli force-pushed the event branch 4 times, most recently from f8f2a6a to dc2417c Compare January 13, 2019 20:40
Contributor

@facebook-github-bot facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

2. raise error if calling cudaDeviceGetStreamPriorityRange or
cudaStreamGetPriority from ROCm

@mrshenli mrshenli changed the title [WIP][Don't Review Yet] Move all Stream and Event Python implementation to C++ Move all Stream and Event Python implementation to C++ Jan 13, 2019
@mrshenli mrshenli requested a review from colesbury January 14, 2019 01:37
Member

@colesbury colesbury left a comment

Overall, looks great.

I'm a bit concerned about potential bugs in CUDAEvent due to the lazy initialization. Some of these issues predate this PR. I've commented about the ones I saw in new functions.

self.assertFalse(s1.query())

with torch.cuda.device(d1):
# delibrately using a different device
Member

typo: delibrately -> deliberately

PyTypeObject *type, PyObject *args, PyObject *kwargs) {
HANDLE_TH_ERRORS

int current_device;
Member

These two lines don't look like they're used

CUDAEvent(const cudaIpcEventHandle_t* handle)
: CUDAEvent(getCurrentCUDAStream(), handle) { }

CUDAEvent(const CUDAStream& stream, const cudaIpcEventHandle_t* handle) {
Member

I think it might be better to just take an at::Device or at::DeviceIndex, since you don't need a stream, just a device.


bool happened() const {
return (was_recorded_ && cudaEventQuery(event_) == cudaSuccess);
if (was_recorded_) {
Member

I think you should remove happened() since it is the same as query(), but is incorrect if the Event was opened from an IPC handle.
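
One possible reading of the IPC problem, sketched as a toy pure-Python model: it assumes was_recorded_ is a per-object flag (as in the happened() body quoted above) that an imported event never sets, even though the exporting process has recorded the underlying event. `DeviceEvent` and `EventModel` are hypothetical names for illustration only; nothing here calls CUDA.

```python
class DeviceEvent:
    """Toy stand-in for the driver-side event shared between processes."""

    def __init__(self):
        self.pending = False  # True while recorded work is still running


class EventModel:
    """Toy stand-in for CUDAEvent with a per-object was_recorded flag."""

    def __init__(self, device_event=None):
        # An imported event wraps an existing DeviceEvent.
        self.device_event = DeviceEvent() if device_event is None else device_event
        self.was_recorded = False  # local to this object

    def record(self):
        self.was_recorded = True
        self.device_event.pending = True

    def query(self):
        # Mirrors cudaEventQuery(event_) == cudaSuccess in this model.
        return not self.device_event.pending

    def happened(self):
        # Mirrors: was_recorded_ && cudaEventQuery(event_) == cudaSuccess
        return self.was_recorded and self.query()


exporter = EventModel()
exporter.record()
exporter.device_event.pending = False          # recorded work has finished
importer = EventModel(exporter.device_event)   # "opened from an IPC handle"
```

In this model the importer never called record() itself, so happened() stays false while query() reports completion, which is the kind of disagreement the comment warns about.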


float elapsed_time(const CUDAEvent& other) const {
float time_ms = 0;
// raise cudaErrorNotReady if either event is recorded but not yet completed
Member

I think this needs to check that this and other were created. Otherwise, the event_ fields may have undefined values.

bool THCPEvent_init(PyObject *module) {
THCPEventClass = (PyObject*)&THCPEventType;
if (PyType_Ready(&THCPEventType) < 0)
return false;
Member

e.g. throw python_error() instead of return false

if (PyType_Ready(&THCPEventType) < 0)
return false;
Py_INCREF(&THCPEventType);
PyModule_AddObject(module, "_CudaEventBase", (PyObject *)&THCPEventType);
Member

Should also check the return value and throw python_error() if necessary here. PyModule_AddObject returns -1 on failure.

self->cuda_stream.priority_range();

PyTuple_SET_ITEM(tuple.get(), 0, PyLong_FromLong(least_priority));
PyTuple_SET_ITEM(tuple.get(), 1, PyLong_FromLong(greatest_priority));
Member

PyLong_FromLong can also potentially fail. Instead of checking the PyLong_FromLong and PyTuple_New calls, you can combine everything into one Py_BuildValue call, which will simplify the code:

return Py_BuildValue("(ii)", least_priority, greatest_priority);

() means tuple
i means int

https://docs.python.org/3/c-api/arg.html?highlight=buildvalue#c.Py_BuildValue

CUDAEvent(unsigned int flags = DEFAULT_FLAGS)
: flags_{flags} { }

CUDAEvent(const cudaIpcEventHandle_t* handle)
Member

I'm ambivalent about this constructor. I think just keeping the one below which is more explicit about the device is better.

If you do keep this one, it should be marked explicit to avoid allowing implicit conversions in C++. The above CUDAEvent(unsigned int flags) should also be marked explicit.

// by the corresponding THCPEvent python object.
// see https://docs.python.org/3/c-api/arg.html#strings-and-buffers
new (&self->cuda_event) at::cuda::CUDAEvent(
(const cudaIpcEventHandle_t *) handle_bytes);
Member

I'm not sure if this cast violates the strict aliasing rule in C++. It might not because handle_bytes is a char*, but I'm not sure. It may still violate the rule because of the underlying type of the object in the Python API.

To be safe, I think it's slightly preferable to copy into a cudaIpcEventHandle_t here. Something like:

cudaIpcEventHandle_t handle;
std::memcpy(&handle, handle_bytes, sizeof(handle));
new (&self->cuda_event) at::cuda::CUDAEvent(&handle);
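
The same copy-instead-of-cast idea can be sketched on the Python side with ctypes. This is only an illustration: `IpcEventHandle` is a hypothetical stand-in for cudaIpcEventHandle_t, and the 64-byte size is an assumption based on CUDA_IPC_HANDLE_SIZE.

```python
import ctypes

HANDLE_SIZE = 64  # assumed size of cudaIpcEventHandle_t


class IpcEventHandle(ctypes.Structure):
    """Hypothetical stand-in for cudaIpcEventHandle_t: an opaque byte blob."""
    _fields_ = [("reserved", ctypes.c_ubyte * HANDLE_SIZE)]


def copy_handle(handle_bytes):
    """Copy raw bytes into a freshly allocated, correctly typed struct."""
    expected = ctypes.sizeof(IpcEventHandle)
    if len(handle_bytes) != expected:
        raise ValueError("expected %d bytes, got %d" % (expected, len(handle_bytes)))
    # from_buffer_copy allocates new storage, so the struct never aliases
    # the memory owned by the Python bytes object.
    return IpcEventHandle.from_buffer_copy(handle_bytes)
```

The size check is there because a copy of the wrong length would silently read past the input in C++; in this sketch it fails loudly instead.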

AT_CHECK(is_created_,
"Events must be recorded before creating IPC handles.");

AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
Member

I think this also needs to happen on the event's device. From what I've seen it may segfault if called from the wrong device.

(Also, add a test for this please)

Contributor Author

Will do, thanks!

@mrshenli
Contributor Author

@colesbury I found some interesting behavior:

  1. When using a single GPU, reconstructing an event from a handle in a different process blocks until the device wakes up. This is probably because the GPU was busy spinning (not sure, and it looks weird if it needs to access the GPU to reconstruct an event).

  2. The following behavior is even weirder. When using two GPUs, the event reconstruction no longer blocks. However, the sub-process join() blocks even though the sub-process did not try to synchronize on the event.

    def _test_event_handle_consumer(handle):
        d1 = torch.device('cuda:1')
        with torch.cuda.device(d1):
            e1 = torch.cuda.Event(_handle=handle)

    def test_event_handle(self):
        d0 = torch.device('cuda:0')
        with torch.cuda.device(d0):
            e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
            self.assertTrue(e0.query())

            torch.cuda._sleep(5000000000)  # spin for about 5 s
            e0.record()

            ctx = mp.get_context('spawn')
            p = ctx.Process(target=TestMultiprocessing._test_event_handle_consumer,
                            args=(e0.ipc_handle(),))
            p.start()

            self.assertFalse(e0.query())
            p.join()         # blocks until d0 finishes spinning
            self.assertTrue(e0.query())

  3. The third situation could be either a bug in CUDA/PyTorch or need a doc update. In the following code, the exporting event's synchronize() will block forever if the importing event recorded some activity and was then destructed. The CUDA docs say:

Performing operations on the imported event after the exported event has been freed with cudaEventDestroy will result in undefined behavior.

But this experiment seems to indicate that performing operations on the exported event after the imported event has been freed will also result in undefined behavior.

    def _test_event_multi_gpu_consumer(handle):
        d1 = torch.device('cuda:1')
        with torch.cuda.device(d1):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                e1 = torch.cuda.Event(_handle=handle)
                torch.cuda._sleep(5000000000)
                e1.record()

    def test_event_multi_gpu(self):
        d0 = torch.device('cuda:0')

        with torch.cuda.device(d0):
            e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
            manager = mp.Manager()
            ret_list = manager.list()

            ctx = mp.get_context('spawn')
            p = ctx.Process(
                target=TestMultiprocessing._test_event_multi_gpu_consumer,
                args=(e0.ipc_handle(), ))
            p.start()
            p.join()

            self.assertFalse(e0.query())
            e0.synchronize()       # blocks forever
            self.assertTrue(e0.query())

@colesbury
Member

Thanks for investigating this.

For (1) and (2) I think there are synchronizations when CUDA initializes a device in a process and when it destructs. You can avoid (1) in the tests by using a mp.SimpleQueue to pass the event after CUDA is initialized in the child process. (2) is probably unavoidable.

I think (3) is likely because e0 is treated as if it's on device 0 in the parent process and on device 1 in the child process. My guess is that this isn't allowed, even if it sort-of works in some situations. Hopefully, Mike Ruberry gets some more info on this, but I think we should prevent it.

Member

@colesbury colesbury left a comment

This is looking good. There are some API changes that I think should be made to CUDAEvent to bring it in line with the C++ Stream API.

optional<Device> device() const; -- should assume kCUDA device type
DeviceIndex device_index() const; getter for device_index_

The Python Event API should also have a read-only device property. It can be None if the event is not created and doesn't have a device.

The API documentation could use some clarification. I'll add inline comments as I go through it.

// activities. Users need to make sure the last recording event (either
// original or reconstructed) must not be destructed when synchronize() is
// called. Otherwise, the behavior is undefined.
explicit CUDAEvent(const cudaIpcEventHandle_t* handle) {
Member

I think the signature should be:

CUDAEvent(Device device, const cudaIpcEventHandle_t* handle);

It's better to be explicit about the device (instead of assuming the current device) because it's important that the event be reconstructed on the same device that it was created on.

(I'm ambivalent about the choice of Device vs. DeviceIndex, but I think Device will work better with our other APIs).

Contributor Author

I hesitated a bit on this API change. It feels a little inconsistent with Event.record(), which implicitly gets the device from the current stream. If we make the new API explicit, do we also want to make the existing API explicit?

Member

The difference is that the device here is required information. The exported event already has a device, it's just not encoded in a visible way in cudaIpcEventHandle_t. The full information to import an event is both the device and the cudaIpcEventHandle_t.

Contributor Author

Hmm, I am not sure if this is just luck or guaranteed behavior. The test_event_multi_gpu test here seems to suggest that an event can be reconstructed from a handle on any device (it does not have to be the same as the exported event's). Let me run some more tests.

// Note: the original event and the reconstructed event now share recorded
// activities. Users need to make sure the last recording event (either
// original or reconstructed) must not be destructed when synchronize() is
// called. Otherwise, the behavior is undefined.
Member

I'm confused by the description here. I don't see what's special about synchronize(). In general, it's undefined behavior to call any function on an imported event after the exported event has been freed.

Perhaps it's better to paraphrase the CUDA documentation. Something like:

Opens an interprocess event handle exported from another process.

The imported event behaves like a locally created event with the cudaEventDisableTiming flag
specified. Performing operations on the imported event after the exported event has been freed
will result in undefined behavior.

Contributor Author

Sorry, I forgot to clean up this doc before that push. I updated the doc in a subsequent push. I think it is a little bit more complicated than CUDA's official doc. For example, the official doc did not state whether we can call methods on the exported event when the imported event is destructed. From what I see, I think the dependency is symmetric. Whenever a method (except record?) is called, the last recording event (can be either exported or imported event) must not be destructed.

Member

For example, the official doc did not state whether we can call methods on the exported event when the imported event is destructed.

This can be assumed and need not be stated.

From what I see, I think the dependency is symmetric. Whenever a method (except record?) is called, the last recording event (can be either exported or imported event) must not be destructed.

I checked and I don't think this is the case. You can call synchronize(), wait(), and query() on the original event even if it's only recorded (and deleted) in another process.

Contributor Author

Yes, I just verified, you are right. I might have misread some test results yesterday.

CUDAGuard guard(static_cast<int16_t>(stream.device_index()));

if (is_created_) {
AT_ASSERT(device_index_ == stream.device_index());
Member

I think this should be an AT_CHECK with a useful error message instead of AT_ASSERT

unsigned int flags_ = DEFAULT_FLAGS;
bool is_created_ = false;
bool was_recorded_ = false;
int64_t device_index_ = -1;
Member

You can make device_index_ a DeviceIndex instead of int64_t and remove all the casts to int16_t in things like CUDAGuard guard(static_cast<int16_t>(device_index_));

blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
interprocess (bool): if ``True``, the event can be shared between processes
(default: ``False``)
_handle (bytes-like object, optional): acquired by calling Event.ipc_handle()
Member

I don't think we should document _handle. It's intended as a private API. The preferred way to pass an event to another process is to pass the event itself (not to call ipc_handle() and Event(_handle=...)).

Contributor Author

I thought about this a bit, and am a little inclined to think that using the handle might actually be better in some cases. Suppose we have a parent process spawning a bunch of child processes, with each child working on a different device. We have at least two options to synchronize parent and children:

Option 1. Create all events in the parent process, and pass an event (associated with the correct device) to each child. In this case, the parent process needs to do many context switches, and will need to use a loop to synchronize with all children. Moreover, if children also need to synchronize with each other, every child needs to keep all other children's event objects, and will also need to use a loop to synchronize.

Option 2. Create an event on the master, and then pass the same handle to each child. In this case, every process has only a single event, which looks more concise. (Actually, let me add a test for creating two events from the same handle.)

I don't have a strong preference, and am open to discussion.

Member

You can implement both cases in user code without ever calling the Event(_handle=...) constructor or explicitly calling ipc_handle().

Something like:

event = torch.cuda.Event()
p1 = mp.Process(target=target1, args=(event,))
p2 = mp.Process(target=target1, args=(event,))
p3 = mp.Process(target=target1, args=(event,))
p1.start(); p2.start(); p3.start()

It's better to pass the event instead of the handle because then we can encode the device and ensure that it's opened on the correct device.

Contributor Author

I might miss something. I thought if p1, p2, and p3 operate on different devices, they cannot use that event to record, no?

Contributor Author

I think we have an unverified assumption regarding whether an event can be reconstructed from handle on any device. Let me run some more tests.

enable_timing (bool, optional): indicates if the event should measure time
(default: ``False``)
blocking (bool): if ``True``, :meth:`wait` will be blocking (default: ``False``)
blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
Member

This predates your change, but the description is wrong. cudaEventBlockingSync affects the type of synchronization used in synchronize() (busy-wait vs. blocking). It doesn't affect wait().
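
As a rough analogy for the two host-side waiting strategies mentioned above (busy-wait vs. blocking), here is a pure-Python sketch built on threading.Event. It does not touch CUDA; `synchronize` and `demo` are hypothetical names used only for illustration, and a timer thread stands in for the device finishing its work.

```python
import threading


def synchronize(done, blocking=False):
    """Wait for `done` using one of two host-side strategies."""
    if blocking:
        # Blocking: the calling thread sleeps until it is signalled.
        done.wait()
    else:
        # Busy-wait: the calling thread keeps polling and burns CPU.
        while not done.is_set():
            pass
    return True


def demo(blocking):
    done = threading.Event()
    # The timer stands in for the device completing the recorded work.
    timer = threading.Timer(0.05, done.set)
    timer.start()
    result = synchronize(done, blocking)
    timer.join()
    return result
```

Both strategies return once the work completes; they differ only in how the waiting CPU thread spends its time, which is the distinction the flag controls for synchronize().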

* device is acquired from the first recording stream. However, if constructed
* from a handle or ipc_handle() is called before it is ever recorded, the device
* will be acquired from current stream. Later streams that record to the event
* must share this device, but streams on any device can query and wait on the
Member

Remove the added bit about "query". It's not a stream operation.

* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
* device is acquired from the first recording stream. However, if constructed
* from a handle or ipc_handle() is called before it is ever recorded, the device
* will be acquired from current stream. Later streams that record to the event
Member

record to -> record

class Event(object):
r"""Wrapper around CUDA event.
class Event(torch._C._CudaEventBase):
r"""Wrapper around CUDA event. Every event is associated with a device index, which
Member

Suggestion:

r"""Wrapper around a CUDA event.

CUDA events are synchronization markers that can be used to monitor the device's progress,
to accurately measure timing, and to synchronize CUDA streams.

The underlying CUDA events are lazily initialized when the event is first recorded or exported to another process. After creation, only streams on the same device may record the event. However, streams on any device can wait on the event.

   .. _CUDA documentation:
   https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html

"""

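To make the lazy-initialization rules in the suggested docstring concrete, here is a toy pure-Python model. `LazyEventModel` and its integer device indices are illustrative assumptions, not PyTorch code: the device stays None until the event is first recorded or exported, only streams on that device may record afterwards, and any device may wait.

```python
class LazyEventModel:
    """Toy model of a lazily created event; it does not call CUDA."""

    def __init__(self):
        self._device = None  # unknown until the event is created

    @property
    def device(self):
        # Read-only; None while the event has not been created yet.
        return self._device

    def _create(self, device):
        if self._device is None:
            self._device = device

    def record(self, stream_device):
        # The first record fixes the device; later records must match it.
        self._create(stream_device)
        if stream_device != self._device:
            raise RuntimeError(
                "event is on device %d but stream is on device %d"
                % (self._device, stream_device))

    def ipc_handle(self, current_device):
        # Exporting before any record creates the event on the current device.
        self._create(current_device)
        return ("handle", self._device)

    def wait(self, stream_device):
        # Streams on any device may wait on the event.
        return True
```

For example, an event first recorded on device 0 rejects a later record from a device 1 stream in this model, while a wait from device 1 is accepted.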

def record(self, stream=None):
r"""Records the event in a given stream."""
r"""Records the event in a given stream. Use
Member

r"""Records the event in a given stream.

Uses ``torch.cuda.current_stream()`` if no stream is specified. The stream's device must match the event's device."""


def query(self):
r"""Checks if the event has been recorded.
r"""Checks if all work currently captured by event has completed. This
Member

Remove "This can be called on any device.". That's notable for CUDA's cudaEventQuery(), but we make sure that Python methods can be called from any device by setting the current device if necessary.


def wait(self, stream=None):
r"""Makes a given stream wait for the event."""
r"""Makes a given stream wait for the event. Use
Member

check_error(cudart().cudaEventElapsedTime(
ctypes.byref(time_ms), self, end_event))
return time_ms.value
r"""Returns the time elapsed in milliseconds after the event was
Member

The events can be recorded on different devices. Suggestion:

r"""Returns the elapsed time in milliseconds between two events.

Both events must be created with `enable_timing=True` and have already been recorded and completed (that is, `event.query()` must be True).

NOTE: If either event was last recorded in a non-NULL stream, the resulting time may be greater than expected (even if both used the same stream).

        .. note:: This is a wrapper around ``cudaEventElapsedTime()``: see `CUDA
           documentation`_ for more info.

        .. _CUDA documentation:
           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
"""

Contributor Author

I thought I tested two events on different devices, and it hit a CUDA error. Let me double check.

Contributor Author

@mrshenli mrshenli Jan 15, 2019

The following code runs into "RuntimeError: CUDA error: invalid resource handle" when calling e0.elapsed_time(e1). But I can still remove that from the doc, as CUDA does not explicitly say that.

        d0 = torch.device('cuda:0')
        d1 = torch.device('cuda:1')
    
        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            e0 = torch.cuda.Event(enable_timing=True)
            torch.cuda._sleep(10)  # spin briefly on device 0
            s0.record_event(e0)
    
        with torch.cuda.device(d1):
            s1 = torch.cuda.current_stream()
            e1 = torch.cuda.Event(enable_timing=True)
            torch.cuda._sleep(50000000)  # spin for about 50 ms on device1
            s1.record_event(e1)
    
        e0.synchronize()
        e1.synchronize()
        e0.elapsed_time(e1)

Member

Hmmm --- yeah I'm seeing the same thing now. My original test might have been wrong.

def synchronize(self):
r"""Synchronizes with the event."""
check_error(cudart().cudaEventSynchronize(self))
r"""Synchronizes with the event on the event's device."""
Member

The "on the event's device" doesn't make sense in this context. Suggestion:

r"""Waits for the event to complete. 

Waits until the completion of all work currently captured in this event. This prevents the CPU thread from proceeding until the event completes.

         .. note:: This is a wrapper around ``cudaEventSynchronize()``: see `CUDA
           documentation`_ for more info.

        .. _CUDA documentation:
           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html

@mrshenli mrshenli force-pushed the event branch 2 times, most recently from 07c6b46 to 4ecdf19 Compare January 15, 2019 23:24

@mrshenli
Contributor Author

The API change breaks an existing use case. I will revert it and only apply the change to ipc_handle().

was_recorded_ = true;
}

// Note: cudaStreamWaitEvent must be called on the same device as the event.
Member

This isn't accurate. (The original statement below is accurate)

}

// Note: cudaIpcGetEventHandle must be called on the same device as the event
void ipc_handle(cudaIpcEventHandle_t * handle, DeviceIndex device_index) {
Member

Don't take in DeviceIndex here. It unnecessarily complicates the signature with redundant information (i.e. it now allows invalid calls).

void ipc_handle(cudaIpcEventHandle_t * handle);

Contributor Author

How do we handle the case when ipc_handle is called before the event is recorded? Throw an error?

Member

Create an event on the current device.

"""
super(Event, self).synchronize()

def ipc_handle(self, device=None):
Member

I don't think it makes sense to take in a device here.

**kwargs):
with torch.cuda.device(device):
return super(Event, cls).__new__(
cls, device=torch.cuda.current_device(),
Member

Also, this always ignores the function argument device

if hasattr(self, '_as_parameter_'):
check_error(self._cudart.cudaEventDestroy(self._as_parameter_))
del self._as_parameter_
def __new__(cls, device=None,
Member

I don't think it makes sense to take the device as an argument when the event is lazily constructed, since it may just end up getting ignored. It makes sense if you're importing an event through an ipc_handle. Since this constructor is getting complicated -- not all arguments make sense together -- I think it's worth splitting out the construction of an event from an IPC handle into a separate factory function:

Something like

def from_ipc_handle(device, ipc_handle)

(Probably implemented in C++)

Member

(This will require changing rebuild_event and reduce_event in torch/multiprocessing/reductions.py)

static PyObject * THCPEvent_from_ipc_handle(
    PyTypeObject *type, PyObject *args) {
  HANDLE_TH_ERRORS
  int64_t device_index = -1;
Member

This should be declared long long to match L in PyArg_ParseTuple. You can cast it later to int64_t at the use-site (at::cuda::CUDAEvent(device_index, &handle);)


 def reduce_event(event):
-    return (rebuild_event, (event.ipc_handle(),))
+    return (rebuild_event, (torch.cuda.current_device(), event.ipc_handle()))
Member

This should be the event's device, not the current device. You'll also want to call event.ipc_handle() before accessing the event's device to ensure that the event is created (and has a device).
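
To illustrate the ordering concern, here is a small sketch with a mock event. FakeEvent, its hard-coded device 0, and the bytes handle are assumptions made up for the example; only the names reduce_event, rebuild_event, and ipc_handle() come from the code under review.

```python
class FakeEvent:
    """Stand-in for a lazily created CUDA event (hypothetical)."""

    def __init__(self):
        self.device = None
        self._handle = None

    def ipc_handle(self):
        # Requesting the handle forces the event to be created,
        # which is also the moment its device becomes known.
        if self._handle is None:
            self.device = 0              # assumed: created on device 0
            self._handle = b"opaque-handle"
        return self._handle


def rebuild_event(device, handle):
    # Receiving side: reconstruct the event from (device, handle).
    event = FakeEvent()
    event.device = device
    event._handle = handle
    return event


def reduce_event(event):
    # Call ipc_handle() first so that event.device is populated,
    # then report the event's own device rather than the current one.
    handle = event.ipc_handle()
    return (rebuild_event, (event.device, handle))


fn, args = reduce_event(FakeEvent())
rebuilt = fn(*args)
```

Writing the tuple as (event.device, event.ipc_handle()) would read the device before the handle call, because Python evaluates tuple elements left to right, so the handle is fetched on its own line first.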

# create handle on different device from recorded event
e1.ipc_handle()

def _test_event_handle_importer_consumer(handle, p2c, c2p):
Member

Unless you specifically want to test some functionality related to the handle, it's better to pass the event itself to the subprocess, since that will test the reductions code and better matches expected use cases.

Contributor Author

Let me add a test for that. I would prefer to keep these two tests just to make sure that the reconstructed event works as expected.

Contributor

@facebook-github-bot left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Member

@colesbury left a comment

LGTM, assuming all tests pass

Some future work (not necessary in this PR):

  1. Stream.device and Event.device should probably return a torch.device instead of a number to match Tensor API
  2. Many of the Stream and Event calls should probably release the GIL before making potentially long-running calls

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jan 17, 2019
Summary:
1. Added `torch/csrc/cuda/Event.h` and `torch/csrc/cuda/Event.cpp` to bind Python Event class to C++ implementation.
2. Move all CUDA runtime invocations from `torch/cuda/streams.py` to C++
3. Added tests to cover Stream and Event APIs. ~(event IPC handle tests is introduced in #15974)~
Pull Request resolved: pytorch/pytorch#15937

Differential Revision: D13649001

Pulled By: mrshenli

fbshipit-source-id: 84ca58f35f6ba679a4ba33150ceba678d760d240
@mrshenli
Contributor Author

@colesbury thanks a lot for the thorough review and suggestions. I will address the two future work items in separate PRs.

facebook-github-bot pushed a commit that referenced this pull request Jan 20, 2019
Summary:
Addresses one future work item in #15937
Pull Request resolved: #16150

Differential Revision: D13732299

Pulled By: mrshenli

fbshipit-source-id: 4d0b35df573a3bf92dea6e2e7eb42fe8bac77b18
facebook-github-bot pushed a commit that referenced this pull request Jan 22, 2019
Summary:
address the second future work item in #15937
Pull Request resolved: #16182

Differential Revision: D13744972

Pulled By: mrshenli

fbshipit-source-id: e9812e3fd4a5623e99b639d9f334bfc2d1827d92
@ezyang added the merged label Jun 25, 2019