[PYTHON][OPS] Added compiler hints to improve performance of
cross-entropy
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
__global__ void forward(TYPE *logit,
|
||||
TYPE *modified_logit,
|
||||
__global__ void forward(TYPE *logit __aligned(16),
|
||||
TYPE *modified_logit __aligned(16),
|
||||
long *indices __readonly,
|
||||
TYPE *result,
|
||||
int n_cols) {
|
||||
TYPE *result __aligned(16),
|
||||
int n_cols __multipleof(N_COLS_MULT)) {
|
||||
int row = get_program_id(0);
|
||||
|
||||
bool check[TILE] = ((0 ... TILE) < n_cols);
|
||||
@@ -19,10 +19,10 @@ __global__ void forward(TYPE *logit,
|
||||
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
|
||||
}
|
||||
|
||||
__global__ void backward(TYPE *neg_logprobs,
|
||||
long *indices,
|
||||
TYPE *dneg_logprobs,
|
||||
int n_cols) {
|
||||
__global__ void backward(TYPE *neg_logprobs __aligned(16),
|
||||
long *indices __aligned(16),
|
||||
TYPE *dneg_logprobs __aligned(16),
|
||||
int n_cols __multipleof(N_COLS_MULT)) {
|
||||
|
||||
int row = get_program_id(0);
|
||||
// pointer arithmetic
|
||||
|
@@ -12,9 +12,16 @@ def next_power_of_2(n):
|
||||
n += 1
|
||||
return n
|
||||
|
||||
def largest_pow2_divisor(N):
|
||||
if N % 8 == 0: return 8
|
||||
if N % 4 == 0: return 4
|
||||
if N % 2 == 0: return 2
|
||||
return 1
|
||||
|
||||
def make_kernel(device, dtype, n_cols, cache, name):
|
||||
rounded = next_power_of_2(n_cols)
|
||||
key = (dtype, rounded)
|
||||
div = largest_pow2_divisor(n_cols)
|
||||
key = (dtype, rounded, div)
|
||||
if key not in cache:
|
||||
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
|
||||
src = triton.read(fname, kernel_names=[name])
|
||||
@@ -22,12 +29,8 @@ def make_kernel(device, dtype, n_cols, cache, name):
|
||||
torch.float16: "F16_INFINITY",
|
||||
torch.float32: "F32_INFINITY",
|
||||
}
|
||||
defines = {
|
||||
"TILE": rounded,
|
||||
"TYPE": dtype,
|
||||
"INFINITY": infinities[dtype],
|
||||
}
|
||||
cache[key] = triton.kernel(src, device=device, defines=defines)
|
||||
defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div}
|
||||
cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
|
||||
return cache[key]
|
||||
|
||||
# forward kernel
|
||||
|
Reference in New Issue
Block a user