From d190285d8917155fb603d758ad556f3399d6d0d8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 10 Feb 2021 16:47:50 -0500 Subject: [PATCH] [PYTHON][OPS] Added compiler hints to improve performance of cross-entropy --- python/triton/ops/cross_entropy.c | 16 ++++++++-------- python/triton/ops/cross_entropy.py | 17 ++++++++++------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/python/triton/ops/cross_entropy.c b/python/triton/ops/cross_entropy.c index 118f5d145..2767fae1a 100644 --- a/python/triton/ops/cross_entropy.c +++ b/python/triton/ops/cross_entropy.c @@ -1,8 +1,8 @@ -__global__ void forward(TYPE *logit, - TYPE *modified_logit, +__global__ void forward(TYPE *logit __aligned(16), + TYPE *modified_logit __aligned(16), long *indices __readonly, - TYPE *result, - int n_cols) { + TYPE *result __aligned(16), + int n_cols __multipleof(N_COLS_MULT)) { int row = get_program_id(0); bool check[TILE] = ((0 ... TILE) < n_cols); @@ -19,10 +19,10 @@ __global__ void forward(TYPE *logit, *(result + row) = *(modified_logit + (local_ind + n_cols * row)); } -__global__ void backward(TYPE *neg_logprobs, - long *indices, - TYPE *dneg_logprobs, - int n_cols) { +__global__ void backward(TYPE *neg_logprobs __aligned(16), + long *indices __aligned(16), + TYPE *dneg_logprobs __aligned(16), + int n_cols __multipleof(N_COLS_MULT)) { int row = get_program_id(0); // pointer arithmetic diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 05a512238..75ee03e59 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -12,9 +12,16 @@ def next_power_of_2(n): n += 1 return n +def largest_pow2_divisor(N): + if N % 8 == 0: return 8 + if N % 4 == 0: return 4 + if N % 2 == 0: return 2 + return 1 + def make_kernel(device, dtype, n_cols, cache, name): rounded = next_power_of_2(n_cols) - key = (dtype, rounded) + div = largest_pow2_divisor(n_cols) + key = (dtype, rounded, div) if key not in cache: fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c") src = triton.read(fname, kernel_names=[name]) @@ -22,12 +29,8 @@ def make_kernel(device, dtype, n_cols, cache, name): torch.float16: "F16_INFINITY", torch.float32: "F32_INFINITY", } - defines = { - "TILE": rounded, - "TYPE": dtype, - "INFINITY": infinities[dtype], - } - cache[key] = triton.kernel(src, device=device, defines=defines) + defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div} + cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4) return cache[key] # forward kernel