[PYTHON][OPS] Added compiler hints to improve performance of
cross-entropy
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
__global__ void forward(TYPE *logit,
|
__global__ void forward(TYPE *logit __aligned(16),
|
||||||
TYPE *modified_logit,
|
TYPE *modified_logit __aligned(16),
|
||||||
long *indices __readonly,
|
long *indices __readonly,
|
||||||
TYPE *result,
|
TYPE *result __aligned(16),
|
||||||
int n_cols) {
|
int n_cols __multipleof(N_COLS_MULT)) {
|
||||||
int row = get_program_id(0);
|
int row = get_program_id(0);
|
||||||
|
|
||||||
bool check[TILE] = ((0 ... TILE) < n_cols);
|
bool check[TILE] = ((0 ... TILE) < n_cols);
|
||||||
@@ -19,10 +19,10 @@ __global__ void forward(TYPE *logit,
|
|||||||
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
|
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void backward(TYPE *neg_logprobs,
|
__global__ void backward(TYPE *neg_logprobs __aligned(16),
|
||||||
long *indices,
|
long *indices __aligned(16),
|
||||||
TYPE *dneg_logprobs,
|
TYPE *dneg_logprobs __aligned(16),
|
||||||
int n_cols) {
|
int n_cols __multipleof(N_COLS_MULT)) {
|
||||||
|
|
||||||
int row = get_program_id(0);
|
int row = get_program_id(0);
|
||||||
// pointer arithmetic
|
// pointer arithmetic
|
||||||
|
@@ -12,9 +12,16 @@ def next_power_of_2(n):
|
|||||||
n += 1
|
n += 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def largest_pow2_divisor(N):
|
||||||
|
if N % 8 == 0: return 8
|
||||||
|
if N % 4 == 0: return 4
|
||||||
|
if N % 2 == 0: return 2
|
||||||
|
return 1
|
||||||
|
|
||||||
def make_kernel(device, dtype, n_cols, cache, name):
|
def make_kernel(device, dtype, n_cols, cache, name):
|
||||||
rounded = next_power_of_2(n_cols)
|
rounded = next_power_of_2(n_cols)
|
||||||
key = (dtype, rounded)
|
div = largest_pow2_divisor(n_cols)
|
||||||
|
key = (dtype, rounded, div)
|
||||||
if key not in cache:
|
if key not in cache:
|
||||||
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
|
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
|
||||||
src = triton.read(fname, kernel_names=[name])
|
src = triton.read(fname, kernel_names=[name])
|
||||||
@@ -22,12 +29,8 @@ def make_kernel(device, dtype, n_cols, cache, name):
|
|||||||
torch.float16: "F16_INFINITY",
|
torch.float16: "F16_INFINITY",
|
||||||
torch.float32: "F32_INFINITY",
|
torch.float32: "F32_INFINITY",
|
||||||
}
|
}
|
||||||
defines = {
|
defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div}
|
||||||
"TILE": rounded,
|
cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
|
||||||
"TYPE": dtype,
|
|
||||||
"INFINITY": infinities[dtype],
|
|
||||||
}
|
|
||||||
cache[key] = triton.kernel(src, device=device, defines=defines)
|
|
||||||
return cache[key]
|
return cache[key]
|
||||||
|
|
||||||
# forward kernel
|
# forward kernel
|
||||||
|
Reference in New Issue
Block a user