[PYTHON] Renamed triton.core -> triton.language (#92)

Philippe Tillet
2021-04-23 17:18:14 -04:00
committed by Philippe Tillet
parent 41410012e8
commit bfc0a7587d
19 changed files with 355 additions and 243 deletions

@@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# --------------------------
import torch
+import triton.language as tl
import triton
@@ -24,19 +25,19 @@ def _add(
N, # Size of the vector
**meta # Optional meta-parameters for the kernel
):
-pid = triton.program_id(0)
+pid = tl.program_id(0)
# Create an offset for the blocks of pointers to be
# processed by this program instance
-offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
+offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
# Create a mask to guard memory operations against
# out-of-bounds accesses
mask = offsets < N
# Load x
-x = triton.load(X + offsets, mask=mask)
-y = triton.load(Y + offsets, mask=mask)
+x = tl.load(X + offsets, mask=mask)
+y = tl.load(Y + offsets, mask=mask)
# Write back x + y
z = x + y
-triton.store(Z + offsets, z)
+tl.store(Z + offsets, z)
# %%
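For orientation, the renamed kernel is driven from Python roughly as follows. This wrapper is a sketch rather than part of the diff: it assumes the `_add(X, Y, Z, N, **meta)` signature and the imports shown above, and the `BLOCK=1024` value and grid computation are illustrative choices.

# Sketch of a launch wrapper for the kernel above (illustrative, not part of the commit).
def add(x, y):
    z = torch.empty_like(x)
    N = z.shape[0]
    # one program instance per BLOCK-element chunk of the output
    grid = lambda meta: ((N + meta['BLOCK'] - 1) // meta['BLOCK'], )
    # meta-parameters such as BLOCK are passed as keyword arguments at launch time
    _add[grid](x, y, z, N, BLOCK=1024)
    return z

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
print(torch.max(torch.abs(add(x, y) - (x + y))))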
@@ -89,9 +90,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
-y_name='provider', # argument name whose value corresponds to a different line in the plot
-y_vals=['torch', 'triton'], # possible keys for `y_name`
-y_lines=["Torch", "Triton"], # label name for the lines
+line_arg='provider', # argument name whose value corresponds to a different line in the plot
+line_vals=['torch', 'triton'], # possible values for `line_arg`
+line_names=["Torch", "Triton"], # label name for the lines
ylabel="GB/s", # label name for the y-axis
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
args={} # values for function arguments not in `x_names` and `y_name`
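Putting the renamed fields together, the full benchmarking declaration reads roughly as below. This is a sketch: it assumes the `add` wrapper sketched earlier and the `triton.testing.do_bench` helper used by the tutorials of this era; only the `line_arg`, `line_vals` and `line_names` keywords come from the diff itself.

# Sketch of the renamed benchmark declaration (illustrative, not part of the commit).
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],
        x_vals=[2**i for i in range(12, 28, 1)],
        x_log=True,
        line_arg='provider',             # was `y_name`
        line_vals=['torch', 'triton'],   # was `y_vals`
        line_names=['Torch', 'Triton'],  # was `y_lines`
        ylabel='GB/s',
        plot_name='vector-add-performance',
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    op = (lambda: x + y) if provider == 'torch' else (lambda: add(x, y))
    ms, min_ms, max_ms = triton.testing.do_bench(op)
    gbps = lambda ms: 12 * size / ms * 1e-6  # read x, read y, write z: 3 * 4 bytes per element
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(show_plots=True, print_data=True)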

@@ -46,28 +46,29 @@ def naive_softmax(x):
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
import triton
+import triton.language as tl
@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index
-m = triton.program_id(0)
+m = tl.program_id(0)
# col indices
-n = triton.arange(0, meta['BLOCK'])
+n = tl.arange(0, meta['BLOCK'])
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n
-x = triton.load(X, mask=n < N, other=-float('inf'))
+x = tl.load(X, mask=n < N, other=-float('inf'))
# Subtract maximum for numerical stability
-z = x - triton.max(x, axis=0)
+z = x - tl.max(x, axis=0)
# Note that exponentials in Triton are fast
# but approximate (i.e., think __expf in CUDA)
-num = triton.exp(z)
-denom = triton.sum(num, axis=0)
+num = tl.exp(z)
+denom = tl.sum(num, axis=0)
y = num / denom
# Write back to Y
Y = Y + m * stride_ym + n
-triton.store(Y, y, mask=n < N)
+tl.store(Y, y, mask=n < N)
# %%
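A sketch of the Python-side launch for the kernel above, assuming the `_softmax(Y, X, stride_xm, stride_ym, M, N, **meta)` signature from this diff; the power-of-two rounding of BLOCK and the `num_warps` heuristic are illustrative choices, not part of the commit.

# Sketch of a launch wrapper for _softmax (illustrative, not part of the commit).
def softmax(x):
    M, N = x.shape
    # tl.arange needs a power-of-two extent, so round the row length up
    BLOCK = 1 << (N - 1).bit_length()
    # wider rows benefit from more warps (illustrative heuristic)
    num_warps = 4 if BLOCK < 2048 else (8 if BLOCK < 4096 else 16)
    y = torch.empty_like(x)
    # one program instance per row
    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N,
                    BLOCK=BLOCK, num_warps=num_warps)
    return y

x = torch.randn(1823, 781, device='cuda')
print(torch.allclose(softmax(x), torch.softmax(x, dim=1)))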
@@ -132,9 +133,9 @@ print(torch.allclose(y_tri, y_ref))
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
-y_name='provider', # argument name whose value corresponds to a different line in the plot
-y_vals=['torch', 'triton', 'naive'], # possible keys for `y_name`
-y_lines=["Torch", "Triton", 'Naive'], # label name for the lines
+line_arg='provider', # argument name whose value corresponds to a different line in the plot
+line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg`
+line_names=["Torch", "Triton", 'Naive'], # label name for the lines
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`

@@ -115,6 +115,7 @@ You will specifically learn about:
import torch
import triton
+import triton.language as tl
# %%
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
@@ -124,9 +125,9 @@ import triton
@triton.jit
def sigmoid(x):
-ret_true = 1 / (1 + triton.exp(-x))
-ret_false = triton.exp(x) / (1 + triton.exp(x))
-return triton.where(x >= 0, ret_true, ret_false)
+ret_true = 1 / (1 + tl.exp(-x))
+ret_false = tl.exp(x) / (1 + tl.exp(x))
+return tl.where(x >= 0, ret_true, ret_false)
@triton.jit
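The two branches exist because, in float32, `exp(-x)` overflows to infinity once `x` is sufficiently negative, collapsing the naive `1 / (1 + exp(-x))` to zero; `exp(x) / (1 + exp(x))` stays accurate there, and the roles swap for positive `x`, so `tl.where` selects the safe form per element. A quick PyTorch check of the identity, illustrative only:

# Illustrative check of the numerically stable sigmoid identity (not part of the commit).
import torch
x = torch.tensor([-100.0, 0.0, 100.0])  # float32; exp(100) overflows to inf
naive = 1 / (1 + torch.exp(-x))         # first entry collapses to 0.0
stable = torch.where(x >= 0,
                     1 / (1 + torch.exp(-x)),            # accurate for x >= 0
                     torch.exp(x) / (1 + torch.exp(x)))  # accurate for x <  0
print(naive)             # the first entry has collapsed to 0
print(stable)            # the first entry keeps a tiny nonzero (subnormal) value
print(torch.sigmoid(x))  # reference for comparison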
@@ -151,7 +152,7 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
BLOCK_K = META['BLOCK_K']
GROUP_M = 8
# matrix multiplication
-pid = triton.program_id(0)
+pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
@@ -161,16 +162,16 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
-rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
-rk = triton.arange(0, BLOCK_K)
+rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+rk = tl.arange(0, BLOCK_K)
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
-acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
+acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K):
-a = triton.load(A)
-b = triton.load(B)
-acc += triton.dot(a, b)
+a = tl.load(A)
+b = tl.load(B)
+acc += tl.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# Triton can accept an arbitrary activation function
@@ -178,11 +179,11 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
if META['ACTIVATION']:
acc = META['ACTIVATION'](acc)
# rematerialize rm and rn to save registers
-rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
+rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm[:, None] < M) & (rn[None, :] < N)
-triton.store(C, acc, mask=mask)
+tl.store(C, acc, mask=mask)
# %%
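For completeness, a sketch of the Python wrapper that would launch `_matmul` after the rename. The grid computation, the `ACTIVATION` meta-parameter, and the test values are illustrative assumptions modeled on the tutorial; `BLOCK_M`, `BLOCK_N` and `BLOCK_K` are assumed to be supplied by the `triton.autotune` configurations mentioned above.

# Sketch of a launch wrapper for _matmul (illustrative, not part of the commit).
def matmul(a, b, activation=None):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch grid: one program per (BLOCK_M x BLOCK_N) tile of C;
    # the kernel itself re-orders program IDs for better L2 reuse
    grid = lambda META: (((M + META['BLOCK_M'] - 1) // META['BLOCK_M'])
                         * ((N + META['BLOCK_N'] - 1) // META['BLOCK_N']), )
    _matmul[grid](a, b, c, M, N, K,
                  a.stride(0), a.stride(1),
                  b.stride(0), b.stride(1),
                  c.stride(0), c.stride(1),
                  ACTIVATION=activation)
    return c

a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
print(triton.testing.allclose(matmul(a, b), torch.matmul(a, b)))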
@@ -238,9 +239,9 @@ print(triton.testing.allclose(c_0, c_1))
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
-y_name='provider', # argument name whose value corresponds to a different line in the plot
-y_vals=['cublas', 'triton'], # possible keys for `y_name`
-y_lines=["cuBLAS", "Triton"], # label name for the lines
+line_arg='provider', # argument name whose value corresponds to a different line in the plot
+line_vals=['cublas', 'triton'], # possible values for `line_arg`
+line_names=["cuBLAS", "Triton"], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
@@ -258,4 +259,4 @@ def benchmark(M, N, K, provider):
return perf(ms), perf(max_ms), perf(min_ms)
-benchmark.run(print_data=True)
+benchmark.run(show_plots=True, print_data=True)