[PYTHON][OPS] Added compiler hints to improve performance of cross-entropy
This commit is contained in:
Philippe Tillet
2021-02-10 16:47:50 -05:00
parent b301c2d199
commit d190285d89
2 changed files with 18 additions and 15 deletions

View File

@@ -1,8 +1,8 @@
__global__ void forward(TYPE *logit, __global__ void forward(TYPE *logit __aligned(16),
TYPE *modified_logit, TYPE *modified_logit __aligned(16),
long *indices __readonly, long *indices __readonly,
TYPE *result, TYPE *result __aligned(16),
int n_cols) { int n_cols __multipleof(N_COLS_MULT)) {
int row = get_program_id(0); int row = get_program_id(0);
bool check[TILE] = ((0 ... TILE) < n_cols); bool check[TILE] = ((0 ... TILE) < n_cols);
@@ -19,10 +19,10 @@ __global__ void forward(TYPE *logit,
*(result + row) = *(modified_logit + (local_ind + n_cols * row)); *(result + row) = *(modified_logit + (local_ind + n_cols * row));
} }
__global__ void backward(TYPE *neg_logprobs, __global__ void backward(TYPE *neg_logprobs __aligned(16),
long *indices, long *indices __aligned(16),
TYPE *dneg_logprobs, TYPE *dneg_logprobs __aligned(16),
int n_cols) { int n_cols __multipleof(N_COLS_MULT)) {
int row = get_program_id(0); int row = get_program_id(0);
// pointer arithmetic // pointer arithmetic

View File

@@ -12,9 +12,16 @@ def next_power_of_2(n):
n += 1 n += 1
return n return n
def largest_pow2_divisor(N):
    """Return the largest of 8, 4 or 2 that evenly divides N, else 1.

    The search is deliberately capped at 8, so the result is always one of
    1, 2, 4 or 8 even when N is divisible by a larger power of two.
    """
    # Try candidates from largest to smallest so the first hit is the answer.
    for divisor in (8, 4, 2):
        if N % divisor == 0:
            return divisor
    return 1
def make_kernel(device, dtype, n_cols, cache, name): def make_kernel(device, dtype, n_cols, cache, name):
rounded = next_power_of_2(n_cols) rounded = next_power_of_2(n_cols)
key = (dtype, rounded) div = largest_pow2_divisor(n_cols)
key = (dtype, rounded, div)
if key not in cache: if key not in cache:
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c") fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
src = triton.read(fname, kernel_names=[name]) src = triton.read(fname, kernel_names=[name])
@@ -22,12 +29,8 @@ def make_kernel(device, dtype, n_cols, cache, name):
torch.float16: "F16_INFINITY", torch.float16: "F16_INFINITY",
torch.float32: "F32_INFINITY", torch.float32: "F32_INFINITY",
} }
defines = { defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div}
"TILE": rounded, cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
"TYPE": dtype,
"INFINITY": infinities[dtype],
}
cache[key] = triton.kernel(src, device=device, defines=defines)
return cache[key] return cache[key]
# forward kernel # forward kernel