[PYTHON][OPS] Added compiler hints to improve performance of

cross-entropy
This commit is contained in:
Philippe Tillet
2021-02-10 16:47:50 -05:00
parent b301c2d199
commit d190285d89
2 changed files with 18 additions and 15 deletions

View File

@@ -1,8 +1,8 @@
__global__ void forward(TYPE *logit,
TYPE *modified_logit,
__global__ void forward(TYPE *logit __aligned(16),
TYPE *modified_logit __aligned(16),
long *indices __readonly,
TYPE *result,
int n_cols) {
TYPE *result __aligned(16),
int n_cols __multipleof(N_COLS_MULT)) {
int row = get_program_id(0);
bool check[TILE] = ((0 ... TILE) < n_cols);
@@ -19,10 +19,10 @@ __global__ void forward(TYPE *logit,
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
}
__global__ void backward(TYPE *neg_logprobs,
long *indices,
TYPE *dneg_logprobs,
int n_cols) {
__global__ void backward(TYPE *neg_logprobs __aligned(16),
long *indices __aligned(16),
TYPE *dneg_logprobs __aligned(16),
int n_cols __multipleof(N_COLS_MULT)) {
int row = get_program_id(0);
// pointer arithmetic

View File

@@ -12,9 +12,16 @@ def next_power_of_2(n):
n += 1
return n
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
return 1
def make_kernel(device, dtype, n_cols, cache, name):
rounded = next_power_of_2(n_cols)
key = (dtype, rounded)
div = largest_pow2_divisor(n_cols)
key = (dtype, rounded, div)
if key not in cache:
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
src = triton.read(fname, kernel_names=[name])
@@ -22,12 +29,8 @@ def make_kernel(device, dtype, n_cols, cache, name):
torch.float16: "F16_INFINITY",
torch.float32: "F32_INFINITY",
}
defines = {
"TILE": rounded,
"TYPE": dtype,
"INFINITY": infinities[dtype],
}
cache[key] = triton.kernel(src, device=device, defines=defines)
defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div}
cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
return cache[key]
# forward kernel