[PYTHON] Renamed triton.core -> triton.language (#92)

This commit is contained in:
Philippe Tillet
2021-04-23 17:18:14 -04:00
committed by Philippe Tillet
parent 41410012e8
commit bfc0a7587d
19 changed files with 355 additions and 243 deletions

View File

@@ -46,28 +46,29 @@ def naive_softmax(x):
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
# Row-wise softmax kernel: each program instance picks one row of X via
# program_id(0), normalizes it, and writes the result to the same row of Y.
# NOTE(review): every `triton.*` statement below is immediately followed by an
# otherwise identical `tl.*` statement -- this looks like a rendered diff whose
# +/- markers and indentation were stripped (old line, then its replacement).
# Confirm which line of each pair belongs in the real file.
@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index
m = triton.program_id(0)
m = tl.program_id(0)
# col indices
n = triton.arange(0, meta['BLOCK'])
n = tl.arange(0, meta['BLOCK'])
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n
# Columns with n >= N are masked off and take the value -inf instead
# (presumably so they neither win the max nor add to the sum, since
# exp(-inf) == 0 -- confirm against the `other` semantics of load).
x = triton.load(X, mask=n < N, other=-float('inf'))
x = tl.load(X, mask=n < N, other=-float('inf'))
# Subtract maximum for numerical stability
z = x - triton.max(x, axis=0)
z = x - tl.max(x, axis=0)
# Note that exponentials in Triton are fast
# but approximate (i.e., think __expf in CUDA)
num = triton.exp(z)
denom = triton.sum(num, axis=0)
num = tl.exp(z)
denom = tl.sum(num, axis=0)
y = num / denom
# Write back to Y, using the same column mask as the load
Y = Y + m * stride_ym + n
triton.store(Y, y, mask=n < N)
tl.store(Y, y, mask=n < N)
# %%
@@ -132,9 +133,9 @@ print(torch.allclose(y_tri, y_ref))
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_names`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton', 'naive'], # possible keys for `y_name`
y_lines=["Torch", "Triton", 'Naive'], # label name for the lines
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg`
line_names=["Torch", "Triton", 'Naive'], # label name for the lines
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096} # values for function arguments not in `x_names` and `line_arg`