[PYTHON] Renamed triton.core -> triton.language (#92)

committed by Philippe Tillet
parent 41410012e8
commit bfc0a7587d
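In short: kernel-side built-ins (program_id, arange, load, store, dot, exp, where, zeros, ...) move from attributes of the top-level triton module to the new triton.language module, conventionally imported as tl. A minimal sketch of a kernel written against the new spelling, using the v1.x-era **meta / kernel[grid](...) launch style of these tutorials (the kernel itself is illustrative, not part of this commit):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _scale(X, Y, N, **meta):
        # each program handles one contiguous block of meta['BLOCK'] elements
        pid = tl.program_id(0)
        offs = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        mask = offs < N  # guard the ragged last block
        x = tl.load(X + offs, mask=mask)
        tl.store(Y + offs, 2 * x, mask=mask)


    x = torch.randn(1000, device='cuda')
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
    _scale[grid](x, y, x.numel(), BLOCK=128)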
@@ -115,6 +115,7 @@ You will specifically learn about:
 import torch
 import triton
+import triton.language as tl

 # %
 # :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
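The comment above is cut off by the hunk boundary; the triton.autotune decorator it refers to consumes a list of triton.Config objects (each bundling meta-parameter values and a num_warps choice) and an autotuning key of argument names whose change triggers re-tuning. A sketch of how it typically sits on top of the kernel shown below, with illustrative configs rather than the tutorial's actual ones:

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=4),
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=2),
        ],
        key=['M', 'N', 'K'],  # re-tune whenever any of these arguments changes
    )
    @triton.jit
    def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
        pass  # kernel body as in the hunks below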
@@ -124,9 +125,9 @@ import triton

 @triton.jit
 def sigmoid(x):
-    ret_true = 1 / (1 + triton.exp(-x))
-    ret_false = triton.exp(x) / (1 + triton.exp(x))
-    return triton.where(x >= 0, ret_true, ret_false)
+    ret_true = 1 / (1 + tl.exp(-x))
+    ret_false = tl.exp(x) / (1 + tl.exp(x))
+    return tl.where(x >= 0, ret_true, ret_false)


 @triton.jit
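Both branches of sigmoid compute the same value: multiplying the numerator and denominator of exp(x) / (1 + exp(x)) by exp(-x) gives 1 / (1 + exp(-x)). The tl.where picks, per element, the form whose exponential cannot overflow: exp(-x) for x >= 0 and exp(x) for x < 0 both lie in (0, 1].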
@@ -151,7 +152,7 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     BLOCK_K = META['BLOCK_K']
     GROUP_M = 8
     # matrix multiplication
-    pid = triton.program_id(0)
+    pid = tl.program_id(0)
     grid_m = (M + BLOCK_M - 1) // BLOCK_M
     grid_n = (N + BLOCK_N - 1) // BLOCK_N
     # re-order program ID for better L2 performance
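Here (M + BLOCK_M - 1) // BLOCK_M is integer ceiling division, i.e. the number of BLOCK_M-sized tiles needed to cover M rows (the same quantity triton.cdiv computes); for example, M = 1000 and BLOCK_M = 128 give (1000 + 127) // 128 = 8 programs along the M axis.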
@@ -161,16 +162,16 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     pid_m = group_id * GROUP_M + (pid % group_size)
     pid_n = (pid % width) // (group_size)
     # do matrix multiplication
-    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
-    rk = triton.arange(0, BLOCK_K)
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
     A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
     B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
-    acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(K, 0, -BLOCK_K):
-        a = triton.load(A)
-        b = triton.load(B)
-        acc += triton.dot(a, b)
+        a = tl.load(A)
+        b = tl.load(B)
+        acc += tl.dot(a, b)
         A += BLOCK_K * stride_ak
         B += BLOCK_K * stride_bk
     # triton can accept arbitrary activation function
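A note on the pointer arithmetic in this hunk: broadcasting turns the 1-D ranges into whole tiles of addresses, e.g. A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) is a (BLOCK_M, BLOCK_K) block of pointers whose entry [i, j] addresses A[rm[i], rk[j]]. Each loop iteration therefore loads one (BLOCK_M, BLOCK_K) tile of A and one (BLOCK_K, BLOCK_N) tile of B, accumulates their product with tl.dot, and then advances both pointer blocks by BLOCK_K along the K dimension.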
@@ -178,11 +179,11 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     if META['ACTIVATION']:
         acc = META['ACTIVATION'](acc)
     # rematerialize rm and rn to save registers
-    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
     mask = (rm[:, None] < M) & (rn[None, :] < N)
-    triton.store(C, acc, mask=mask)
+    tl.store(C, acc, mask=mask)


 # %%
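The mask (rm[:, None] < M) & (rn[None, :] < N) keeps the final tl.store from writing outside C when M or N is not a multiple of the block sizes; masked-off lanes are simply not written.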
@@ -238,9 +239,9 @@ print(triton.testing.allclose(c_0, c_1))
     triton.testing.Benchmark(
         x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
         x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
-        y_name='provider',  # argument name whose value corresponds to a different line in the plot
-        y_vals=['cublas', 'triton'],  # possible keys for `y_name`
-        y_lines=["cuBLAS", "Triton"],  # label name for the lines
+        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
+        line_vals=['cublas', 'triton'],  # possible values for `line_arg``
+        line_names=["cuBLAS", "Triton"],  # label name for the lines
         ylabel="TFLOPS",  # label name for the y-axis
         plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={}
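For context, these renamed arguments are consumed by triton.testing.perf_report, which turns a plain function into a runnable benchmark whose extra argument must match line_arg. A compact, self-contained sketch with illustrative sizes and a stand-in workload (assuming, as the tutorial does further down, that triton.testing.do_bench returns (ms, min_ms, max_ms)):

    import torch
    import triton


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],
            x_vals=[2 ** i for i in range(10, 16)],
            line_arg='provider',
            line_vals=['torch', 'triton'],
            line_names=['Torch', 'Triton'],
            ylabel='GB/s',
            plot_name='toy-performance',
            args={},
        )
    )
    def toy_benchmark(N, provider):
        x = torch.randn(N, device='cuda')
        # both providers time the same stand-in op here; a real benchmark would
        # dispatch on `provider` to different implementations
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.sigmoid(x))
        gbps = lambda t: 2 * x.numel() * x.element_size() * 1e-9 / (t * 1e-3)
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    toy_benchmark.run(show_plots=True, print_data=True)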
@@ -258,4 +259,4 @@ def benchmark(M, N, K, provider):
     return perf(ms), perf(max_ms), perf(min_ms)


-benchmark.run(print_data=True)
+benchmark.run(show_plots=True, print_data=True)
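With show_plots=True, run() now also renders the matmul-performance figure instead of only printing the timing table requested by print_data=True.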