[PYTHON] Renamed triton.core -> triton.language (#92)
This commit is contained in:
committed by
Philippe Tillet
parent
41410012e8
commit
bfc0a7587d
@@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
|
||||
# --------------------------
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
import triton
|
||||
|
||||
|
||||
@@ -24,19 +25,19 @@ def _add(
|
||||
N, # Size of the vector
|
||||
**meta # Optional meta-parameters for the kernel
|
||||
):
|
||||
pid = triton.program_id(0)
|
||||
pid = tl.program_id(0)
|
||||
# Create an offset for the blocks of pointers to be
|
||||
# processed by this program instance
|
||||
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
|
||||
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
|
||||
# Create a mask to guard memory operations against
|
||||
# out-of-bounds accesses
|
||||
mask = offsets < N
|
||||
# Load x
|
||||
x = triton.load(X + offsets, mask=mask)
|
||||
y = triton.load(Y + offsets, mask=mask)
|
||||
x = tl.load(X + offsets, mask=mask)
|
||||
y = tl.load(Y + offsets, mask=mask)
|
||||
# Write back x + y
|
||||
z = x + y
|
||||
triton.store(Z + offsets, z)
|
||||
tl.store(Z + offsets, z)
|
||||
|
||||
|
||||
# %%
|
||||
@@ -89,9 +90,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
|
||||
x_names=['size'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
|
||||
x_log=True, # x axis is logarithmic
|
||||
y_name='provider', # argument name whose value corresponds to a different line in the plot
|
||||
y_vals=['torch', 'triton'], # possible keys for `y_name`
|
||||
y_lines=["Torch", "Triton"], # label name for the lines
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['torch', 'triton'], # possible values for `line_arg`
|
||||
line_names=["Torch", "Triton"], # label name for the lines
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={} # values for function arguments not in `x_names` and `y_name`
|
||||
|
Reference in New Issue
Block a user