[PYTHON] Renamed triton.core -> triton.language (#92)

This commit is contained in:
Philippe Tillet
2021-04-23 17:18:14 -04:00
committed by Philippe Tillet
parent 41410012e8
commit bfc0a7587d
19 changed files with 355 additions and 243 deletions

View File

@@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# --------------------------
import torch
import triton.language as tl
import triton
@@ -24,19 +25,19 @@ def _add(
N, # Size of the vector
**meta # Optional meta-parameters for the kernel
):
pid = triton.program_id(0)
pid = tl.program_id(0)
# Create an offset for the blocks of pointers to be
# processed by this program instance
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
# Create a mask to guard memory operations against
# out-of-bounds accesses
mask = offsets < N
# Load x
x = triton.load(X + offsets, mask=mask)
y = triton.load(Y + offsets, mask=mask)
x = tl.load(X + offsets, mask=mask)
y = tl.load(Y + offsets, mask=mask)
# Write back x + y
z = x + y
triton.store(Z + offsets, z)
tl.store(Z + offsets, z)
# %%
@@ -89,9 +90,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton'], # possible keys for `y_name`
y_lines=["Torch", "Triton"], # label name for the lines
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['torch', 'triton'], # possible values for `line_arg`
line_names=["Torch", "Triton"], # label name for the lines
ylabel="GB/s", # label name for the y-axis
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
args={} # values for function arguments not in `x_names` and `y_name`