[Opt] Using filter and kernel level pipeline to optimize lookup kernels by jiashuy · Pull Request #136 · NVIDIA-Merlin/HierarchicalKV · GitHub

Conversation

jiashuy
Collaborator

@jiashuy jiashuy commented Jun 4, 2023

On pure HBM mode

  1. Use digests (some bits of the hashed keys) as a filter to reduce memory traffic (a sketch of the idea follows this list).
  2. Use a kernel-level pipeline to overlap memory accesses and hide latency.
  3. Add unit tests for the lookup kernels that use the filter and the pipeline.
  4. Make the maximum dim supported by the pipelined lookup kernel configurable.
  5. Move common kernels into the core_kernels folder, and update the BUILD file used for the Bazel build.
  6. Change the way digests are addressed.
  7. When initializing the hash table, check bucket_max_size so that keys and scores align with the cache line size.
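A minimal sketch of the digest-filter idea from item 1 (illustrative device code, not the PR's actual kernels; Murmur3HashDevice is the repository's hash helper, and the loop is shown sequentially rather than cooperatively):

  // Illustrative only: check the 1-byte digest before touching the full key, so most
  // mismatching slots cost one byte of memory traffic instead of a full key read.
  template <typename K>
  __forceinline__ __device__ bool probe_bucket(const K* keys, const uint8_t* digests,
                                               int bucket_max_size, K target_key,
                                               int* out_pos) {
    const K hashed_key = Murmur3HashDevice(target_key);
    const uint8_t target_digest = static_cast<uint8_t>(hashed_key >> 32);
    for (int i = 0; i < bucket_max_size; ++i) {
      // Full key comparison only on a digest hit.
      if (digests[i] == target_digest && keys[i] == target_key) {
        *out_pos = i;
        return true;
      }
    }
    return false;
  }

In the PR's kernels the digest comparison is done cooperatively across a thread group (see the __popc(find_result) snippet further down), but the filtering principle is the same.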

@jiashuy jiashuy requested a review from rhdong June 4, 2023 22:35
@github-actions

github-actions bot commented Jun 4, 2023

README.md Outdated
## Benchmark & Performance(W.I.P)

* GPU: 1 x NVIDIA A100 80GB PCIe: 8.0
* GPU: 1 x NVIDIA A100-SXM4-80GB: 8.0
Member

Better to keep PCIe as our benchmark baseline.

// Only when bucket_size = 128.
// On A100, the maximum dim the pipeline supports is 224 floats.
if (options_.max_bucket_size == 128 &&
    value_size <= (224 * sizeof(float))) {
Member

We'd better avoid the magic number and make it a private member of HashTable, or some better form. If the 224 depends on the GPU hardware setting, we need to calculate it at the initialization phase.

Collaborator Author

OK, I will pass the arch information as a template type to select the kernel-related config at compile time.
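A minimal sketch of that compile-time selection, with hypothetical tag and member names (the value for the older arch is purely illustrative):

  // Hypothetical: select pipeline limits by architecture tag at compile time.
  struct Sm80 {};  // e.g. A100
  struct Sm70 {};  // e.g. V100

  template <typename Arch>
  struct PipelineConfig;

  template <>
  struct PipelineConfig<Sm80> {
    static constexpr int kMaxDimFloats = 224;  // value used in this PR for A100
  };

  template <>
  struct PipelineConfig<Sm70> {
    static constexpr int kMaxDimFloats = 128;  // illustrative value only
  };

  // The host-side dispatch could then read PipelineConfig<Arch>::kMaxDimFloats instead of 224.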

for (size_t i = 0; i < bucket_max_size; i++)
new (buckets[start + tid].keys(i))
AtomicKey<K>{static_cast<K>(EMPTY_KEY)};
K hashed_key = Murmur3HashDevice(static_cast<K>(EMPTY_KEY));
Member

const K

new (buckets[start + tid].keys(i))
AtomicKey<K>{static_cast<K>(EMPTY_KEY)};
K hashed_key = Murmur3HashDevice(static_cast<K>(EMPTY_KEY));
uint8_t digest = static_cast<uint8_t>(hashed_key >> 32);
Member

const uint8_t

new (buckets[start + tid].keys(i))
AtomicKey<K>{static_cast<K>(EMPTY_KEY)};
K hashed_key = Murmur3HashDevice(static_cast<K>(EMPTY_KEY));
uint8_t digest = static_cast<uint8_t>(hashed_key >> 32);
Member
@rhdong rhdong Jun 5, 2023

Could you confirm which header we should use for uint8_t?
sys/types.h
or <stdint.h>
or https://nvidia.github.io/cutlass/structcutlass_1_1TypeTraits_3_01uint8__t_01_4.html

I mean, do we need to add a special header explicitly for uint8_t?

Collaborator Author
@jiashuy jiashuy Jun 5, 2023

I think it's <cstdint> or <stdint.h>, which is already included in our code via the header types.cuh.

local_size = buckets_size[new_bkt_idx];
if (rank == src_lane) {
K hashed_key = Murmur3HashDevice(key);
uint8_t target_digest = static_cast<uint8_t>(hashed_key >> 32);
Member

const

Collaborator Author

OK, I will find all of these and add const. Thanks for your review.

if (rank == 0) {
K hashed_key = Murmur3HashDevice(static_cast<K>(EMPTY_KEY));
uint8_t target_digest = static_cast<uint8_t>(hashed_key >> 32);
bucket->digests[key_idx] = target_digest;
Member

For EMPTY_KEY, we'd better define a separate macro for its corresponding digest.
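One possible way to centralize it (a sketch only; the review below suggests an EMPTY_DIGEST constant, which would serve the same purpose):

  // Sketch: derive the digest of EMPTY_KEY in one helper instead of repeating the hash.
  template <typename K>
  __forceinline__ __device__ uint8_t empty_digest() {
    const K hashed_key = Murmur3HashDevice(static_cast<K>(EMPTY_KEY));
    return static_cast<uint8_t>(hashed_key >> 32);
  }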

if (g.thread_rank() == src_lane) {
const int key_pos =
(start_idx + tile_offset + src_lane) & (bucket_max_size - 1);
K hashed_key = Murmur3HashDevice(static_cast<K>(EMPTY_KEY));
Member

EMPTY_DIGEST

#include "../utils.cuh"

// if i % 2 == 0, select buffer 0, else buffer 1
#define SAME_BUF(i) (((i)&0x01) ^ 0)
Member

Unused?

Collaborator Author

No, it is used to select the buffer in the pipeline kernel. For example:
V* v_src = sm_vector[SAME_BUF(i)][groupID]; in the kernel lookup_kernel_with_io_pipeline_v2.

Member
@rhdong rhdong Jun 5, 2023

Sorry, I didn't expand lookup_kernels.cuh.
So there is a potential issue: the macro name is so generic that it may pollute end users' name scope. If there is no performance loss, can we change them to a __forceinline__ __device__ func(..)?
Or, at least, #undef them after the last reference in this file.
Or use a special prefix like 'MERLIN_xxx'.

Collaborator Author

OK, I get it.
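For reference, a sketch of the two options being discussed (names are illustrative):

  // Option 1: replace the macro with a device function (same i % 2 double-buffer selection).
  __forceinline__ __device__ int same_buf(int i) { return i & 0x01; }

  // Option 2: keep a macro but prefix it and undef it after the last use in this file.
  // #define MERLIN_SAME_BUF(i) ((i) & 0x01)
  // ... kernels ...
  // #undef MERLIN_SAME_BUF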


__forceinline__ __device__ static S lgs(S* src) { return src[0]; }

__forceinline__ __device__ static void stg(S* dst, S score_) {
Member

stg(const S* dst, const S score_)

__pipeline_memcpy_async(dst, src, sizeof(S));
}

__forceinline__ __device__ static S lgs(S* src) { return src[0]; }
Member

Using const may help the compiler optimize the code.


using namespace cooperative_groups;
namespace cg = cooperative_groups;
#include "core_kernels/kernel_utils.cuh"
Member
@rhdong rhdong Jun 5, 2023

Please modify the Bazel build file at the same time; its location is ./include/merlin/BUILD. Please try to build with Bazel when you are done (there are no CI cases for it currently).

int idx_block = groupID * GROUP_SIZE + rank;
K target_key = keys[key_idx_base + rank];
sm_target_keys[idx_block] = target_key;
K hashed_key = Murmur3HashDevice(target_key);
Member

Try to use const wherever possible.

@rhdong rhdong requested review from minseokl and zehuanw June 5, 2023 07:52
template <typename K = uint64_t, typename V = float, typename S = uint64_t,
typename CopyScore = CopyScoreEmpty<S, K, 128>,
typename CopyValue = CopyValueTwoGroup<float, float4, 32>>
__global__ void lookup_kernel_with_io_pipeline_v1(
Member

Hi @minseokl, could you help review here for @jiashuy? It's too complex for me to handle it alone. :(

@rhdong
Member

rhdong commented Jun 5, 2023

/blossom-ci

int find_number = __popc(find_result);
int group_base = 0;
if (find_number > 0) {
group_base = atomicAdd(sm_counts + key_idx_block, find_number);
Member
@rhdong rhdong Jun 5, 2023

It looks like atomicAdd_block is enough here.

Collaborator Author

Yes, I agree with you; __atomicAdd_block is more appropriate.

Collaborator Author
@jiashuy jiashuy Jun 5, 2023

However, when I use __atomicAdd_block, I get the error: identifier "__atomicAdd_block" is undefined.
I think it's related to CMakeLists.txt: set_target_properties(xxx PROPERTIES CUDA_ARCHITECTURES OFF).
Also, according to the CUDA documentation, atomicAdd supports shared memory, so I think using atomicAdd is the cheapest option at present.

Collaborator

atomicAdd works for both global and shared memory
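A small illustrative kernel (not from the PR) showing the pattern under discussion: atomicAdd accepts both shared- and global-memory addresses.

  // Count set flags per block in shared memory, then fold the block totals into a global counter.
  __global__ void count_flags(const int* flags, int n, int* total) {
    __shared__ int sm_count;
    if (threadIdx.x == 0) sm_count = 0;
    __syncthreads();
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n && flags[i] != 0) atomicAdd(&sm_count, 1);  // shared-memory atomic
    __syncthreads();
    if (threadIdx.x == 0) atomicAdd(total, sm_count);     // global-memory atomic
  }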

CUDA_CHECK(cudaMalloc(&((*table)->buckets[i].keys_), bucket_memory_size));
(*table)->buckets[i].scores_ = reinterpret_cast<AtomicScore<S>*>(
(*table)->buckets[i].keys_ + bucket_max_size);
(*table)->buckets[i].digests = reinterpret_cast<uint8_t*>(
Member

If we bring the digests ahead of keys_, the find should always read the digests first.

Collaborator Author

OK, that's reasonable.

Collaborator Author
@jiashuy jiashuy Jun 5, 2023

And I think we can store just the key and value addresses, since the addresses of the scores and digests can be inferred from keys_ and bucket_max_size.
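A sketch of that layout-based addressing (the struct name and the exact ordering of the regions are assumptions, not the PR's final layout):

  // Assumed single allocation per bucket, laid out as [keys | scores | digests]; only keys_ is stored.
  template <typename K, typename S>
  struct BucketView {
    AtomicKey<K>* keys_;
    size_t bucket_max_size;

    __forceinline__ __device__ AtomicScore<S>* scores() const {
      return reinterpret_cast<AtomicScore<S>*>(keys_ + bucket_max_size);
    }
    __forceinline__ __device__ uint8_t* digests() const {
      return reinterpret_cast<uint8_t*>(scores() + bucket_max_size);
    }
  };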

S score_ = CopyScore::lgs(sm_target_scores + key_idx_block);
CopyValue::lds_stg(rank, v_dst, v_src, dim);
founds[key_idx_grid] = true;
CopyScore::stg(scores + key_idx_grid, score_);
Collaborator Author

The score and found results are staged in shared memory and written to global memory at the end of the kernel, so that the memory accesses are coalesced and memory traffic is reduced.
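A toy sketch of that pattern (not the PR's kernel): stage per-key results in shared memory, then let consecutive threads write consecutive global addresses in the epilogue.

  template <int BLOCK_SIZE>
  __global__ void write_back_example(bool* founds, float* scores) {
    __shared__ bool  sm_founds[BLOCK_SIZE];
    __shared__ float sm_scores[BLOCK_SIZE];

    // ... the lookup phase would fill sm_founds / sm_scores here ...
    sm_founds[threadIdx.x] = false;  // placeholder
    sm_scores[threadIdx.x] = 0.0f;   // placeholder
    __syncthreads();

    // Coalesced epilogue: thread i writes element i of this block's output range.
    const int key_idx_grid = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    founds[key_idx_grid] = sm_founds[threadIdx.x];
    scores[key_idx_grid] = sm_scores[threadIdx.x];
  }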

/// TODO: compute the pointer of scores and digests using bucket_max_size
AtomicScore<S>* scores_;
/// @brief not visible to users
uint8_t* digests;
Member

Inspired by a discussion with you: potentially, we can reduce the memory consumption of the Bucket struct by dropping the separate pointers for keys_, scores_, and digests_, because we only need one start pointer for all three.
So could you switch the digests to a function like this?

  __forceinline__ __device__ uint8_t* digests(int index) const {
    return digests_ + index;
  }

This will benefit the future refactoring I mentioned.

Comment on lines 32 to 36
constexpr int GROUP_SIZE = 32;
constexpr int RESERVE = 16;
constexpr int DIM_BUF = 224;
constexpr int BLOCK_SIZE = 128;
constexpr int BUCKET_SIZE = 128;
Collaborator

@jiashuy Are they configurable? How do you decide their values?

Collaborator Author

At present:

  1. BUCKET_SIZE is fixed at 128. This value is commonly used by users, which @rhdong can confirm.
  2. I think BLOCK_SIZE should be as small as possible to reduce uneven workloads, but making it too small makes the grid size too large, so I chose 128.
  3. GROUP_SIZE is set according to the profiler. When dim is small, using 16 threads to handle one key cooperatively is more effective (using 8 would consume more registers); when dim is large, we use 32 threads per key so that larger values fit in shared memory (fewer groups means less shared memory for the double buffer).
    The only difference between kernel v1 and v2 is the GROUP_SIZE.
  4. DIM_BUF is configurable according to the shared memory size of the SM (which differs by arch). I've already finished this and will commit it today.
  5. RESERVE is the reserved size for candidate keys (digest == target digest).
    From statistics over contiguous keys, 16 is enough for RESERVE, but I use 8 in lookup_kernel_with_io_pipeline_v2 to reduce shared memory usage, preserving correctness by trading time (latency) for space.
    The frequency of the reserve size that is actually needed follows a power-law distribution.

Collaborator Author
@jiashuy jiashuy Jun 7, 2023

In summary, BUCKET_SIZE, BLOCK_SIZE, GROUP_SIZE, and RESERVE are fixed for a specific kernel.
BLOCK_SIZE is set according to subjective experience;
RESERVE and GROUP_SIZE are set from profiler results and measured performance.
DIM_BUF is configurable and has already been implemented.
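A quick back-of-the-envelope check of the value-staging budget these constants imply (assumed formula; the real kernels also stage digests, keys, and scores in shared memory):

  constexpr int GROUP_SIZE = 32;
  constexpr int BLOCK_SIZE = 128;
  constexpr int DIM_BUF    = 224;
  constexpr int GROUPS_PER_BLOCK = BLOCK_SIZE / GROUP_SIZE;  // 4 groups per block
  // Double-buffered value staging: 2 * 4 * 224 * 4 B = 7168 B per block.
  constexpr size_t VALUE_STAGING_BYTES =
      2 * GROUPS_PER_BLOCK * DIM_BUF * sizeof(float);
  static_assert(VALUE_STAGING_BYTES <= 48 * 1024,
                "fits within the default 48 KB static shared memory limit");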

Comment on lines 87 to 88
__pipeline_commit(); // padding
__pipeline_commit(); // padding
Collaborator

@jiashuy Why do you need these paddings?

Collaborator Author

I think __pipeline_wait_prior(x) waits for the (x+1)-th most recent __pipeline_commit() to complete.
You can see the __pipeline_wait_prior(3) at line 109.
In the first iteration, we need to wait for sm_probing_digests to be written back during the pipeline's loading stage.
So I pad with __pipeline_commit() to avoid an extra check inside the loop on every iteration.
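For readers unfamiliar with these primitives, a toy double-buffered copy (not the PR's kernel) showing how the commit/wait counting works and why a padding __pipeline_commit() keeps it uniform across iterations:

  #include <cuda_pipeline.h>

  // Launch with blockDim.x == TILE; dst and src hold n_tiles * TILE floats.
  __global__ void double_buffer_copy(float* dst, const float* src, int n_tiles) {
    constexpr int TILE = 128;
    __shared__ float buf[2][TILE];

    // Prologue: issue and commit the async load of tile 0 (commit #0).
    __pipeline_memcpy_async(&buf[0][threadIdx.x], &src[threadIdx.x], sizeof(float));
    __pipeline_commit();

    for (int t = 0; t < n_tiles; ++t) {
      const int next = t + 1;
      if (next < n_tiles) {  // prefetch the next tile into the other buffer
        __pipeline_memcpy_async(&buf[next & 1][threadIdx.x],
                                &src[next * TILE + threadIdx.x], sizeof(float));
      }
      __pipeline_commit();       // committed even when nothing was issued: the "padding" commit
      __pipeline_wait_prior(1);  // all but the newest commit have completed, so tile t is ready
      dst[t * TILE + threadIdx.x] = buf[t & 1][threadIdx.x];
    }
  }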

@jiashuy jiashuy force-pushed the master branch 2 times, most recently from 2b7e893 to 9055d9a on June 8, 2023 12:47
@rhdong
Member

rhdong commented Jun 11, 2023

/blossom-ci

@rhdong
Member

rhdong commented Jun 11, 2023

/blossom-ci

@rhdong rhdong merged commit 921e9b8 into NVIDIA-Merlin:master Jun 11, 2023