
Slowdown due to thread specific model caching #27902

@CoderHam

🐛 Bug

PyTorch appears to have a threading issue: a model that has been warmed up in one thread is not warm when it is subsequently executed from a different thread. As a result, running multiple instances of the model on the same GPU provides no speedup and is in some cases even slower than a single instance.

To Reproduce

Steps to reproduce the behavior:

I have a simple libtorch code snippet that reproduces this issue. In short, all it does is run the same model (a ResNet-50 from torchvision, traced to produce a TorchScript version). The code was run on a Titan V with a batch size of 1.

  1. Create the model in the main thread
  2. Launch a thread that runs the model N times in a loop, reporting the runtime each time
  3. Launch N threads that each run the model once, reporting the runtime each time

The snippet below implements these steps:

#include <torch/script.h>

#include <pthread.h>
#include <unistd.h>
#include <chrono>
#include <iostream>

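// Shared state handed to each worker thread: the traced module, its inputs,
// and a slot for the output tensor.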
struct inferStruct {
  at::Tensor output_;
  torch::jit::script::Module module_;
  std::vector<torch::jit::IValue> inputs_;
};

// Run model 5 times
void*
NForwardPass(void* run)
{
  inferStruct* run_tmp = reinterpret_cast<inferStruct*>(run);
  for (int i = 0; i < 5; i++) {
    auto t1 = std::chrono::high_resolution_clock::now();
    run_tmp->output_ = run_tmp->module_.forward(run_tmp->inputs_).toTensor();
    auto t2 = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
    std::cout << duration / 1000.0
              << " ms to run model in process thread (NForwardPass)"
              << std::endl;
  }
  return nullptr;
}

// Run model 1 time
void*
ForwardPass(void* run)
{
  inferStruct* run_tmp = reinterpret_cast<inferStruct*>(run);
  auto t1 = std::chrono::high_resolution_clock::now();
  run_tmp->output_ = run_tmp->module_.forward(run_tmp->inputs_).toTensor();
  auto t2 = std::chrono::high_resolution_clock::now();
  auto duration =
      std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
  std::cout << duration / 1000.0
            << " ms to run model in process thread (ForwardPass)" << std::endl;
  return nullptr;
}

int
main()
{
  // Model loaded in main thread
  torch::Device device_ = torch::Device(torch::kCUDA, 0);
  std::string model_path = "resnet50_libtorch.pt";
  torch::jit::script::Module module1 = torch::jit::load(model_path, device_);
  std::vector<torch::jit::IValue> inputs_(1);
  inputs_[0] = torch::zeros({1, 3, 224, 224}).to(device_);
  // pre-warm model
  at::Tensor output = module1.forward(inputs_).toTensor();

  inferStruct run1;
  pthread_t thread_id1;

  run1.module_ = module1;
  run1.inputs_ = inputs_;

  // Run model in main thread N times and report runtime each time
  for (int i = 0; i < 5; i++) {
    auto t1 = std::chrono::high_resolution_clock::now();
    output = module1.forward(inputs_).toTensor();
    auto t2 = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
    std::cout << duration / 1000.0 << " ms to run model in main thread"
              << std::endl;
  }
  sleep(1);

  // Launch N threads; each runs the model once and reports the runtime
  std::cout << std::endl;
  for (int i = 0; i < 5; i++) {
    pthread_create(&thread_id1, NULL, ForwardPass, &run1);
    pthread_join(thread_id1, NULL);
  }
  sleep(1);

  // Run 1 thread and in it run the model N times, reporting the runtime each time
  std::cout << std::endl;
  pthread_create(&thread_id1, NULL, NForwardPass, &run1);
  pthread_join(thread_id1, NULL);

  return 0;
}
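
If thread-local warm-up is really what is happening, a single untimed forward pass at the start of each spawned thread should restore full speed for the timed passes. A minimal variant of ForwardPass for that check (a sketch, not part of the original report):

// Hypothetical check: warm the model up inside the new thread before timing.
// If caching is per-thread, the timed pass below should match main-thread speed.
void*
WarmForwardPass(void* run)
{
  inferStruct* run_tmp = reinterpret_cast<inferStruct*>(run);
  // Untimed warm-up pass issued from this thread.
  run_tmp->output_ = run_tmp->module_.forward(run_tmp->inputs_).toTensor();
  // Timed pass.
  auto t1 = std::chrono::high_resolution_clock::now();
  run_tmp->output_ = run_tmp->module_.forward(run_tmp->inputs_).toTensor();
  auto t2 = std::chrono::high_resolution_clock::now();
  auto duration =
      std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
  std::cout << duration / 1000.0
            << " ms to run model after in-thread warm-up (WarmForwardPass)"
            << std::endl;
  return nullptr;
}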

Expected behavior

After the model has been warmed up once, subsequent forward passes should run at full speed regardless of which thread issues them; newly spawned threads should not each pay the warm-up cost again.

Environment

Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.14.0

Python version: 3.6
Is CUDA available: No
CUDA runtime version: 10.1.243
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.3

Versions of relevant libraries:
[pip] msgpack-numpy==0.4.3.2
[pip] numpy==1.16.4
[pip] torch==1.2.0a0+afb7a16
[pip] torchtext==0.4.0
[pip] torchvision==0.3.0a0
[conda] magma-cuda100 2.5.0 1 local
[conda] mkl 2019.1 144
[conda] mkl-include 2019.1 144
[conda] nomkl 3.0 0
[conda] torch 1.2.0a0+afb7a16 pypi_0 pypi
[conda] torchtext 0.4.0 pypi_0 pypi
[conda] torchvision 0.3.0a0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @suo

Labels: high priority, oncall: jit, triage review, triaged