[PYTHON] Renamed triton.core -> triton.language (#92)

Author: Philippe Tillet
Date: 2021-04-23 17:18:14 -04:00
Committed by: Philippe Tillet
Parent: 41410012e8
Commit: bfc0a7587d
19 changed files with 355 additions and 243 deletions


@@ -10,9 +10,9 @@ square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'block',
y_vals = [16, 32, 64],
y_lines = ['Block16', 'Block32', 'Block64'],
line_arg = 'block',
line_vals = [16, 32, 64],
line_names = ['Block16', 'Block32', 'Block64'],
ylabel = 'TFLOPS',
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode,
@@ -60,9 +60,9 @@ square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'block',
y_vals = [16, 32, 64],
y_lines = ['Block16', 'Block32', 'Block64'],
line_arg = 'block',
line_vals = [16, 32, 64],
line_names = ['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS',
plot_name = f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
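
For reference, the renamed keyword arguments map one-to-one onto the old ones (`y_name` -> `line_arg`, `y_vals` -> `line_vals`, `y_lines` -> `line_names`). A minimal, purely illustrative configuration using the new names might look like this (sizes and labels below are made up, not taken from this commit):

import torch
import triton

conf = triton.testing.Benchmark(
    x_names=['N'],                   # argument(s) swept along the x axis
    x_vals=[128, 256, 512, 1024],    # values taken by the x-axis argument
    line_arg='provider',             # argument whose values index the plot lines (was `y_name`)
    line_vals=['torch', 'triton'],   # values taken by `line_arg`                 (was `y_vals`)
    line_names=['Torch', 'Triton'],  # legend labels for those lines              (was `y_lines`)
    ylabel='GBPS',
    plot_name='example-square',
    args={'dtype': torch.float16},   # arguments held fixed for the whole benchmark
)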


@@ -5,9 +5,9 @@ confs = [
triton.testing.Benchmark(
x_names = ['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
y_name = 'provider',
y_vals = ['triton', 'torch'],
y_lines = ['Triton', 'Torch'],
line_arg = 'provider',
line_vals = ['triton', 'torch'],
line_names = ['Triton', 'Torch'],
ylabel = 'GBPS',
plot_name = f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}


@@ -16,9 +16,9 @@ square_confs = [
triton.testing.Benchmark(
x_names=["M", "N", "K"],
x_vals=rounded_linspace(512, 8192, 32, 128),
y_name="provider",
y_vals=["cublas", "triton", "cutlass"],
y_lines=["cuBLAS", "Triton", "CUTLASS"],
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
args={"AT": AT, "BT": BT, "dtype": torch.float16},
@@ -30,9 +30,9 @@ transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 32, 128),
y_name="provider",
y_vals=["cublas", "triton", "cutlass"],
y_lines=["cuBLAS", "Triton", "CUTLASS"],
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}


@@ -1,5 +1,6 @@
import torch
import triton
import triton.language as tl
import copy
import pytest
import ast
@@ -37,10 +38,10 @@ def _test_unary(dtype_x, expr, device='cuda'):
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
@@ -59,11 +60,11 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
y = triton.load(Y + off)
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
y = tl.load(Y + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
@@ -144,7 +145,7 @@ def make_ptr_str(name, shape):
stride = 1
for i in reversed(range(rank)):
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
stride *= shape[i]
return f"{name} + {' + '.join(offsets)}"
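
For illustration (not part of the diff), here is the helper reproduced with its elided setup lines, plus a worked expansion for a 4x8 row-major block; the resulting string is what `patch_kernel` splices into the kernel source in place of `X_PTR_EXPR` / `Z_PTR_EXPR`:

def make_ptr_str(name, shape):
    rank = len(shape)
    offsets = []
    stride = 1
    # innermost dimension first: each dim contributes arange(0, dim) * stride
    for i in reversed(range(rank)):
        idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
        offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
        stride *= shape[i]
    return f"{name} + {' + '.join(offsets)}"

# worked example: pointer expression for a contiguous 4x8 block
assert make_ptr_str('X', [4, 8]) == \
    'X + tl.arange(0, 8)[None, :]*1 + tl.arange(0, 4)[:, None]*8'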
@@ -164,11 +165,11 @@ def test_index1d(expr, device='cuda'):
@triton.jit
def kernel(Z, X, **meta):
SIZE = meta['SIZE']
m = triton.arange(0, SIZE)
n = triton.arange(0, SIZE)
x = triton.load(X_PTR_EXPR)
m = tl.arange(0, SIZE)
n = tl.arange(0, SIZE)
x = tl.load(X_PTR_EXPR)
z = GENERATE_TEST_HERE
triton.store(Z_PTR_EXPR, z)
tl.store(Z_PTR_EXPR, z)
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),


@@ -2,9 +2,9 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from .code_gen import jit, autotune, heuristics, Config, Autotuner
from .core import *
from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
from . import language
from . import code_gen
from . import testing
from . import ops
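
After this change, device-side builtins are reached through the `triton.language` submodule (conventionally aliased as `tl`) rather than the top-level `triton` namespace, which keeps the host-side utilities (`jit`, `autotune`, `cdiv`, ...) separate. A minimal sketch of the new convention (the kernel below is illustrative, not code from this commit):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(Z, X, N, **meta):
    # device-side builtins now live under the `tl` alias
    offs = tl.arange(0, meta['BLOCK'])
    mask = offs < N
    tl.store(Z + offs, tl.load(X + offs, mask=mask), mask=mask)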


@@ -26,21 +26,21 @@ class CodeGenerator(ast.NodeVisitor):
ret = self.builtins[name]
else:
raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.block):
if isinstance(ret, triton.language.block):
handle = self.module.get_value(name)
return triton.block(handle)
return triton.language.block(handle)
return ret
def set_value(self, name, value):
if isinstance(value, _triton.ir.value):
value = triton.block(value)
if isinstance(value, triton.block):
value = triton.language.block(value)
if isinstance(value, triton.language.block):
self.module.set_value(name, value.handle)
self.module.scope.set_type(name, value.handle.type)
self.lscope[name] = value
def is_triton_object(self, value):
return isinstance(value, triton.block)
return isinstance(value, triton.language.block)
def visit_compound_statement(self, stmts, add_scope=False):
if add_scope:
@@ -63,7 +63,14 @@ class CodeGenerator(ast.NodeVisitor):
self.constants = constants
self.kwargs = kwargs
self.last_node = None
self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'getattr': getattr,
}
def visit_Module(self, node):
self.module.add_new_scope()
@@ -303,7 +310,7 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.where(self.visit(pos_step_node),\
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
builder=self.builder)
@@ -359,7 +366,7 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.core:
sys.modules[fn.__module__] is triton.language:
return fn(*args, builder=self.builder, **kws)
return fn(*args, **kws)
@@ -613,6 +620,11 @@ class JITFunction:
raise e
raise CompilationError(self.src, node, e)
def __setattr__(self, name, value):
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
def _init_kernel(self):
if self.kernel is None:
self.kernel = Kernel(self)
@@ -659,4 +671,23 @@ def heuristics(values):
def jit(fn):
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
return JITFunction(fn)
def cdiv(x, y):
return (x + y - 1) // y
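
`cdiv` moves into `code_gen.py` (and is re-exported from the top-level package, see the `__init__.py` hunk above) because it is a host-side helper rather than a device builtin. A typical, illustrative use is sizing a launch grid:

import triton

N = 10000
# ceiling division: number of program instances needed to cover N elements
assert triton.cdiv(N, 1024) == (N + 1024 - 1) // 1024 == 10
# grids are usually expressed as a function of the kernel's meta-parameters
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)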


@@ -1,6 +1,6 @@
import triton
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
import triton
from functools import wraps
@@ -25,7 +25,7 @@ def _patch(fn):
if x.type.is_void():
return None
return block(x)
return x
return tl
def wrapper(*args, **kwargs):
builder = args[-1]
@@ -547,7 +547,7 @@ def minimum(x, y):
:param other: the second input block
:type other: Block
"""
return triton.where(x < y, x, y)
return triton.language.where(x < y, x, y)
@triton.jit
@@ -560,7 +560,7 @@ def maximum(x, y):
:param other: the second input block
:type other: Block
"""
return triton.where(x > y, x, y)
return triton.language.where(x > y, x, y)
@triton.jit
@@ -571,7 +571,7 @@ def sigmoid(x):
:param x: the input block
:type x: Block
"""
return 1 / (1 + np.exp(-x))
return 1 / (1 + triton.language.exp(-x))
@triton.jit
@@ -582,9 +582,9 @@ def softmax(x):
:param x: the input block
:type x: Block
"""
z = x - triton.max(x, 0)
num = triton.exp(z)
den = triton.sum(num, 0)
z = x - triton.language.max(x, 0)
num = triton.language.exp(z)
den = triton.language.sum(num, 0)
return num / den
@@ -596,8 +596,4 @@ def ravel(x):
:param x: the input block
:type x: Block
"""
return triton.reshape(x, [x.type.numel])
def cdiv(x, y):
return (x + y - 1) // y
return triton.language.reshape(x, [x.type.numel])
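
The composite helpers defined here (`minimum`, `maximum`, `sigmoid`, `softmax`, `ravel`) are themselves `@triton.jit` functions, so after the rename user kernels call them as `tl.softmax`, `tl.maximum`, and so on. An illustrative sketch (not part of this commit) of a row-wise softmax built on the composite helper:

import triton
import triton.language as tl

@triton.jit
def _row_softmax(Y, X, stride, N, **meta):
    # one program instance normalizes one row of an (M, N) matrix
    row = tl.program_id(0)
    cols = tl.arange(0, meta['BLOCK'])
    # out-of-bounds lanes read -inf so they contribute nothing to the softmax
    x = tl.load(X + row * stride + cols, mask=cols < N, other=-float('inf'))
    tl.store(Y + row * stride + cols, tl.softmax(x), mask=cols < N)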


@@ -1,4 +1,5 @@
import triton
import triton.language as tl
import triton._C.libtriton as libtriton
import torch
@@ -16,21 +17,21 @@ def _kernel(
#------------#
#- Prologue -#
#------------#
pid0 = triton.program_id(0)
pid1 = triton.program_id(1)
pidz = triton.program_id(2)
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
pidz = tl.program_id(2)
if meta['SDD']:
pid1 = pid1 + SDD_off_width
blockidm = triton.arange(0, TM) // BLOCK
blockidn = triton.arange(0, TN) // BLOCK
blockidm = tl.arange(0, TM) // BLOCK
blockidn = tl.arange(0, TN) // BLOCK
offlutm = blockidm * (TN // BLOCK) * 4
offlutn = blockidn * 4
header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
z = triton.load(header + 0)
i = triton.load(header + 1 + offlutm)
j = triton.load(header + 2 + offlutn)
z = tl.load(header + 0)
i = tl.load(header + 1 + offlutm)
j = tl.load(header + 2 + offlutn)
AS1 = SDD_K // TZ
lockid = triton.where(TZ > 1, 1, 0)
lockid = tl.where(TZ > 1, 1, 0)
offka = pid0 * AS1
offkb = pid0 * AS1
offmc = 0
@@ -41,16 +42,16 @@ def _kernel(
offhc = 0
offha = z
offhb = z
ram = i * BLOCK + (triton.arange(0, TM) % BLOCK)
rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK)
ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
else:
header = lut + pid0 * 6
offset = triton.load(header + 0)
AS1 = triton.load(header + 1)
column = triton.load(header + 2)
depth = triton.load(header + 3)
lockid = triton.load(header + 4)
maxid = triton.load(header + 5)
offset = tl.load(header + 0)
AS1 = tl.load(header + 1)
column = tl.load(header + 2)
depth = tl.load(header + 3)
lockid = tl.load(header + 4)
maxid = tl.load(header + 5)
pinc = lut + offset
offhc = depth
if meta['DSD']:
@@ -60,14 +61,14 @@ def _kernel(
offpc = 0
# dense input offset
offnb = pid1 * TN
offkb = triton.load(pinc)
offkb = triton.multiple_of(offkb, 8) # compiler hint
offkb = tl.load(pinc)
offkb = tl.multiple_of(offkb, 8) # compiler hint
offpb = 0
# sparse input offset
offma = 0
offka = 0
offpa = triton.load(pinc + 1)
offpa = triton.multiple_of(offpa, 8) # compiler hint
offpa = tl.load(pinc + 1)
offpa = tl.multiple_of(offpa, 8) # compiler hint
offpa = offpa * BLOCK * BLOCK
offha = 0
offhb = depth
@@ -78,23 +79,23 @@ def _kernel(
offpc = 0
# dense input offset
offma = pid1 * TM
offka = triton.load(pinc)
offka = triton.multiple_of(offka, 8) # compiler hint
offka = tl.load(pinc)
offka = tl.multiple_of(offka, 8) # compiler hint
offpa = 0
# sparse input offset
offnb = 0
offkb = 0
offpb = triton.load(pinc + 1)
offpb = triton.multiple_of(offpb, 8) # compiler hint
offpb = tl.load(pinc + 1)
offpb = tl.multiple_of(offpb, 8) # compiler hint
offpb = offpb * BLOCK * BLOCK
offha = depth
offhb = 0
ram = offma + triton.arange(0, TM)
rbn = offnb + triton.arange(0, TN)
ram = offma + tl.arange(0, TM)
rbn = offnb + tl.arange(0, TN)
# initialize a, b pointers
rka = offka + triton.arange(0, TK)
rkb = offkb + triton.arange(0, TK)
rka = offka + tl.arange(0, TK)
rkb = offkb + tl.arange(0, TK)
pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
if meta['DDS']:
@@ -105,31 +106,31 @@ def _kernel(
checkbn = rbn[None, :] < DS0
else:
checkbn = AS1 > 0
a = triton.load(pa, mask=checkam, other=0.)
b = triton.load(pb, mask=checkbn, other=0.)
a = tl.load(pa, mask=checkam, other=0.)
b = tl.load(pb, mask=checkbn, other=0.)
## ---------------- ##
## Inner Loop ##
## ---------------- ##
acc = triton.zeros((TM, TN), dtype=triton.float32)
acc = tl.zeros((TM, TN), dtype=tl.float32)
for k in range(AS1, 0, -TK):
acc += triton.dot(a, b)
acc += tl.dot(a, b)
if meta['SDD']:
inc_a = TK * stride_ka
inc_b = TK * stride_kb
else:
pinc += 2
if meta['DSD']:
inc_b = triton.load(pinc)
inc_a = triton.load(pinc + 1)
inc_b = triton.multiple_of(inc_b, 8)
inc_a = triton.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_a = tl.load(pinc + 1)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = inc_b * stride_kb
if meta['DDS']:
inc_a = triton.load(pinc)
inc_b = triton.load(pinc + 1)
inc_a = triton.multiple_of(inc_a, 8)
inc_b = triton.multiple_of(inc_b, 8)
inc_a = tl.load(pinc)
inc_b = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
pa += inc_a
pb += inc_b
@@ -138,24 +139,24 @@ def _kernel(
checkbk = k > TK
checka = checkam & checkak
checkb = checkbn & checkbk
a = triton.load(pa, mask=checka)
b = triton.load(pb, mask=checkb)
a = tl.load(pa, mask=checka)
b = tl.load(pb, mask=checkb)
c = acc.to(C.dtype.element_ty)
if meta['SDD']:
checkc = True
rr_blockidm = triton.arange(0, TM) // BLOCK
rr_blockidn = triton.arange(0, TN) // BLOCK
rr_blockidm = tl.arange(0, TM) // BLOCK
rr_blockidn = tl.arange(0, TN) // BLOCK
rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
rr_offlutn = rr_blockidn * 4
off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
bkid = triton.load(header + off_bkid)
bkid = tl.load(header + off_bkid)
offpc = bkid * BLOCK * BLOCK
rcm = triton.arange(0, TM) % BLOCK
rcn = triton.arange(0, TN) % BLOCK
rcm = tl.arange(0, TM) % BLOCK
rcn = tl.arange(0, TN) % BLOCK
else:
rcm = offmc + triton.arange(0, TM)
rcn = offnc + triton.arange(0, TN)
rcm = offmc + tl.arange(0, TM)
rcn = offnc + tl.arange(0, TN)
if meta['DSD']:
checkc = rcn[None, :] < DS0
if meta['DDS']:
@@ -164,21 +165,21 @@ def _kernel(
pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
# write-back directly
if lockid == 0:
triton.store(pc, c, mask=checkc)
tl.store(pc, c, mask=checkc)
# accumulate partial results using spin-locks
else:
plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1
pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks
while triton.atomic_cas(plock, 0, 1) == 1:
plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
while tl.atomic_cas(plock, 0, 1) == 1:
pass
count = triton.load(pcount)
count = tl.load(pcount)
if count == 0:
triton.store(pc, c, mask=checkc)
tl.store(pc, c, mask=checkc)
else:
d = triton.load(pc, mask=checkc)
triton.store(pc, d + c, mask=checkc)
triton.atomic_xchg(pcount, (count + 1) % maxid)
triton.atomic_xchg(plock, 0)
d = tl.load(pc, mask=checkc)
tl.store(pc, d + c, mask=checkc)
tl.atomic_xchg(pcount, (count + 1) % maxid)
tl.atomic_xchg(plock, 0)
##############


@@ -1,3 +1,4 @@
import triton.language as tl
import triton
import torch
import os
@@ -31,86 +32,86 @@ def _forward(
):
TN = meta['TN']
BLOCK = meta['BLOCK']
pidhm = triton.program_id(0)
pidz = triton.program_id(1)
pidhm = tl.program_id(0)
pidz = tl.program_id(1)
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK
rxn = tl.arange(0, TN) % BLOCK
rbn = tl.arange(0, TN) // BLOCK
# extract information from LUT
header = LUT + rbm * 2
size = triton.load(header + 0)
offset = triton.load(header + 1)
size = tl.load(header + 0)
offset = tl.load(header + 1)
check = rbn < size
rbmn = triton.where(check, rbn, size - 1)
rbmn = tl.where(check, rbn, size - 1)
# block id and column id
blockid = triton.load(LUT + offset + rbmn * 4 + 0)
columnid = triton.load(LUT + offset + rbmn * 4 + 1)
rowid = triton.load(LUT + offset + rbmn * 4 + 2)
headid = triton.load(LUT + offset + rbmn * 4 + 3)
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
columnid = tl.load(LUT + offset + rbmn * 4 + 1)
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
headid = tl.load(LUT + offset + rbmn * 4 + 3)
# pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = triton.load(px, mask=check, other=-float('inf'))
x = x.to(triton.float32)
x = tl.load(px, mask=check, other=-float('inf'))
x = x.to(tl.float32)
# apply scale
if meta['APPLY_SCALE']:
x = x * scale
# apply RPE
if meta['APPLY_RPE']:
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
rpe = triton.load(prpe, mask=check, other=0)
rpe = tl.load(prpe, mask=check, other=0)
x = x + rpe
# apply key-padding mask
if meta['APPLY_KP_MASK']:
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
kp_m = triton.load(pkp_m, mask=check, other=-float('inf'))
kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
if meta['KP_MASK_MUL']:
kp_m = triton.where(kp_m == 0, -float('inf'), 0.)
kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
x = x + kp_m
# apply attention mask
if meta['APPLY_ATTN_MASK']:
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
attn_m = triton.load(pattn_m, mask=check, other=-float('inf'))
attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
if meta['ATTN_MASK_MUL']:
attn_m = triton.where(attn_m == 0, -float('inf'), 0.)
attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m
# computation
x = triton.softmax(x)
triton.store(px, x, mask=check)
x = tl.softmax(x)
tl.store(px, x, mask=check)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
pidhm = triton.program_id(0)
pidz = triton.program_id(1)
pidhm = tl.program_id(0)
pidz = tl.program_id(1)
TN = meta['TN']
BLOCK = meta['BLOCK']
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK
rxn = tl.arange(0, TN) % BLOCK
rbn = tl.arange(0, TN) // BLOCK
# extract information from look-up table
header = LUT + rbm * 2
size = triton.load(header + 0)
offset = triton.load(header + 1)
size = tl.load(header + 0)
offset = tl.load(header + 1)
# bounds checking on lut
check = rbn < size
rbmn = triton.where(check, rbn, size - 1)
rbmn = tl.where(check, rbn, size - 1)
# initialize pointers to block-sparse input
blockid = triton.load(LUT + offset + rbmn * 4)
blockid = tl.load(LUT + offset + rbmn * 4)
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
# compute fused softmax backward
x = triton.load(X, mask=check, other=0)
dx = triton.load(DX, mask=check, other=0)
x = x.to(triton.float32)
dx = dx.to(triton.float32)
y = x * (dx - triton.sum(x * dx, 0)) * scale
triton.store(DX, y, mask=check)
x = tl.load(X, mask=check, other=0)
dx = tl.load(DX, mask=check, other=0)
x = x.to(tl.float32)
dx = dx.to(tl.float32)
y = x * (dx - tl.sum(x * dx, 0)) * scale
tl.store(DX, y, mask=check)
class _softmax(torch.autograd.Function):


@@ -1,5 +1,6 @@
import os
import triton
import triton.language as tl
import torch
@@ -27,25 +28,25 @@ def num_warps(N):
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
BLOCK = meta['BLOCK']
row = triton.program_id(0)
cols = triton.arange(0, BLOCK)
idx = triton.load(IDX + row)
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)
# pointers to logit and probs
LOGITS = LOGITS + row * N + cols
WRIT_PROBS = PROBS + row * N + cols
READ_PROBS = PROBS + row * N + idx
# write-back negative log-probs
logits = triton.load(LOGITS, mask=cols < N, other=-float('inf'))
logits = logits.to(triton.float32)
logits = logits - triton.max(logits, 0)
probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits
triton.store(WRIT_PROBS, probs, mask=cols < N)
logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
logits = logits.to(tl.float32)
logits = logits - tl.max(logits, 0)
probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
tl.store(WRIT_PROBS, probs, mask=cols < N)
# There is a bug in the compiler, which fails to insert a barrier here.
# We add it explicitly for now. Will be fixed soon.
triton.debug_barrier()
tl.debug_barrier()
# write-back loss
probs = triton.load(READ_PROBS)
triton.store(LOSS + row, probs)
probs = tl.load(READ_PROBS)
tl.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
@@ -53,20 +54,20 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, **meta):
BLOCK = meta['BLOCK']
row = triton.program_id(0)
cols = triton.arange(0, BLOCK)
idx = triton.load(IDX + row)
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)
# pointers to probs
PROBS = PROBS + row * N + cols
# We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
# and we have -log(p[k]) stored in PROBS, so this is easy
probs = -triton.load(PROBS, mask=cols < N, other=float('inf'))
probs = triton.exp(probs.to(triton.float32))
probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
probs = tl.exp(probs.to(tl.float32))
delta = cols == idx
# write result in-place in PROBS
dout = triton.load(DPROBS + row)
dout = tl.load(DPROBS + row)
din = (probs - delta) * dout
triton.store(PROBS, din.to(triton.float16), mask=cols < N)
tl.store(PROBS, din.to(tl.float16), mask=cols < N)
class _cross_entropy(torch.autograd.Function):
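
As a host-side sanity check on the identity used by the forward kernel (illustrative, not part of this commit): `log(sum(exp(logits - max))) - (logits - max)` is exactly the negative log-probability, which is the quantity stored into `PROBS`:

import torch

logits = torch.randn(8)
shifted = logits - logits.max()                       # subtract max for numerical stability
neg_log_p = torch.log(torch.exp(shifted).sum()) - shifted
assert torch.allclose(neg_log_p, -torch.log_softmax(logits, dim=0))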


@@ -1,4 +1,5 @@
import torch
import triton.language as tl
import triton
@@ -27,8 +28,8 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
GROUP_M = META['GROUP_M']
SPLIT_K = META['SPLIT_K']
# matrix multiplication
pid = triton.program_id(0)
pid_z = triton.program_id(1)
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
@@ -38,46 +39,46 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
# pointers
K = K // SPLIT_K
A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K):
if META['EVEN_K']:
a = triton.load(A)
b = triton.load(B)
a = tl.load(A)
b = tl.load(B)
else:
a = triton.load(A, mask=rk[None, :] < k, other=0.)
b = triton.load(B, mask=rk[:, None] < k, other=0.)
acc += triton.dot(a, b)
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(triton.float16)
acc = acc.to(tl.float16)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
triton.store(C, acc, mask=mask)
tl.store(C, acc, mask=mask)
else:
LOCKS = LOCKS + triton.program_id(0)
COUNT = LOCKS + triton.num_programs(0)
while triton.atomic_cas(LOCKS, 0, 1) == 1:
LOCKS = LOCKS + tl.program_id(0)
COUNT = LOCKS + tl.num_programs(0)
while tl.atomic_cas(LOCKS, 0, 1) == 1:
pass
count = triton.load(COUNT)
count = tl.load(COUNT)
if count == 0:
triton.store(C, acc, mask=mask)
tl.store(C, acc, mask=mask)
else:
curr = triton.load(C, mask=mask, other=0.)
triton.store(C, acc + curr, mask=mask)
triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
triton.atomic_xchg(LOCKS, 0)
curr = tl.load(C, mask=mask, other=0.)
tl.store(C, acc + curr, mask=mask)
tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
tl.atomic_xchg(LOCKS, 0)
class _matmul(torch.autograd.Function):
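
Both matmul kernels serialize their reduction-splitting write-back with the same spin-lock idiom built from `tl.atomic_cas` / `tl.atomic_xchg`. Stripped of the surrounding matmul, the pattern looks roughly like this (hypothetical kernel, for illustration only):

import triton
import triton.language as tl

@triton.jit
def _locked_accumulate(C, LOCK, COUNT, N, SPLIT, **meta):
    # hypothetical kernel: every program adds a partial result into C under a spin-lock
    offs = tl.arange(0, meta['BLOCK'])
    mask = offs < N
    partial = tl.zeros((meta['BLOCK'],), dtype=tl.float32)  # stand-in for a real partial sum
    # spin until the lock is acquired
    while tl.atomic_cas(LOCK, 0, 1) == 1:
        pass
    count = tl.load(COUNT)
    if count == 0:
        tl.store(C + offs, partial, mask=mask)               # first writer stores directly
    else:
        curr = tl.load(C + offs, mask=mask, other=0.)
        tl.store(C + offs, curr + partial, mask=mask)        # later writers accumulate
    tl.atomic_xchg(COUNT, (count + 1) % SPLIT)               # last writer resets the counter
    tl.atomic_xchg(LOCK, 0)                                  # release the lock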


@@ -81,6 +81,22 @@ def random(shape, dtype, device):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param percentiles: Performance percentiles to return in addition to the median.
:type percentiles: list[float]
"""
# Estimate the runtime of the function
fn()
torch.cuda.synchronize()
@@ -125,13 +141,16 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
class Benchmark:
"""
This class is used by the :code:`perf_report` function to generate line plots with a concise API.
"""
def __init__(
self,
x_names,
x_vals,
y_name,
y_vals,
y_lines,
line_arg,
line_vals,
line_names,
plot_name,
args,
xlabel='',
@@ -139,12 +158,38 @@ class Benchmark:
x_log=False,
y_log=False,
):
"""
Constructor
:param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
:type x_names: List[str]
:param x_vals: List of values to use for the arguments in :code:`x_names`.
:type x_vals: List[Any]
:param line_arg: Argument name for which different values correspond to different lines in the plot.
:type line_arg: str
:param line_vals: List of values to use for the arguments in :code:`line_arg`.
:type line_vals: List[str]
:param line_names: Label names for the different lines.
:type line_names: List[str]
:param plot_name: Name of the plot.
:type plot_name: str
:param args: List of arguments to remain fixed throughout the benchmark.
:type args: List[str]
:param xlabel: Label for the x axis of the plot.
:type xlabel: str, optional
:param ylabel: Label for the y axis of the plot.
:type ylabel: str, optional
:param x_log: Whether the x axis should be log scale.
:type x_log: bool, optional
:param y_log: Whether the y axis should be log scale.
:type y_log: bool, optional
"""
self.x_names = x_names
self.x_vals = x_vals
self.x_log = x_log
self.y_name = y_name
self.y_vals = y_vals
self.y_lines = y_lines
self.line_arg = line_arg
self.line_vals = line_vals
self.line_names = line_names
self.y_log = y_log
# plot info
self.xlabel = xlabel
@@ -162,15 +207,15 @@ class Mark:
import matplotlib.pyplot as plt
import pandas as pd
import os
y_mean = bench.y_lines
y_min = [f'{x}-min' for x in bench.y_lines]
y_max = [f'{x}-max' for x in bench.y_lines]
y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names]
df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
for x in bench.x_vals:
x_args = {x_name: x for x_name in bench.x_names}
row_mean, row_min, row_max = [], [], []
for y in bench.y_vals:
ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args)
for y in bench.line_vals:
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
try:
y_mean, y_min, y_max = ret
except TypeError:
@@ -183,7 +228,7 @@ class Mark:
plt.figure()
ax = plt.subplot()
x = bench.x_names[0]
for y in bench.y_lines:
for y in bench.line_names:
y_min, y_max = df[y + '-min'], df[y + '-max']
ax.plot(df[x], df[y], label=y)
if y_min is not None and y_max is not None:
@@ -199,7 +244,7 @@ class Mark:
plt.show()
if save_path:
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
df = df[[bench.x_names[0]] + bench.y_lines]
df = df[[bench.x_names[0]] + bench.line_names]
if print_data:
print(df)
if save_path:
@@ -220,5 +265,11 @@ class Mark:
def perf_report(benchmarks):
"""
Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
:param benchmarks: Benchmarking configurations.
:type benchmarks: List of :class:`Benchmark`
"""
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper
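
Putting the renamed pieces together, a benchmark script written against the new API could look roughly as follows (the providers below just compare two host-side PyTorch expressions; everything here is illustrative, not code from this commit):

import torch
import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(10, 16)],
        line_arg='provider',
        line_vals=['clone', 'mul'],
        line_names=['Clone', 'Multiply'],
        ylabel='GB/s',
        plot_name='copy-performance',
        args={},
    )
)
def benchmark(N, provider):
    x = torch.randn(N, device='cuda', dtype=torch.float32)
    fn = (lambda: x.clone()) if provider == 'clone' else (lambda: x * 1.0)
    ms, min_ms, max_ms = triton.testing.do_bench(fn)
    # rough effective bandwidth: one read + one write per element
    gbps = lambda t: 2 * x.numel() * x.element_size() * 1e-9 / (t * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(print_data=True)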


@@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# --------------------------
import torch
import triton.language as tl
import triton
@@ -24,19 +25,19 @@ def _add(
N, # Size of the vector
**meta # Optional meta-parameters for the kernel
):
pid = triton.program_id(0)
pid = tl.program_id(0)
# Create an offset for the blocks of pointers to be
# processed by this program instance
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
# Create a mask to guard memory operations against
# out-of-bounds accesses
mask = offsets < N
# Load x
x = triton.load(X + offsets, mask=mask)
y = triton.load(Y + offsets, mask=mask)
x = tl.load(X + offsets, mask=mask)
y = tl.load(Y + offsets, mask=mask)
# Write back x + y
z = x + y
triton.store(Z + offsets, z)
tl.store(Z + offsets, z)
# %%
@@ -89,9 +90,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton'], # possible keys for `y_name`
y_lines=["Torch", "Triton"], # label name for the lines
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['torch', 'triton'], # possible values for `line_arg`
line_names=["Torch", "Triton"], # label name for the lines
ylabel="GB/s", # label name for the y-axis
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
args={} # values for function arguments not in `x_names` and `y_name`
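
On the host side nothing changes besides the device-side namespace; a sketch of a wrapper around the `_add` kernel above (assuming the `X, Y, Z, N` argument order used by the tutorial, which the hunk only shows in part):

import torch
import triton

def add(x: torch.Tensor, y: torch.Tensor):
    z = torch.empty_like(x)
    N = z.numel()
    # one program instance per BLOCK elements; BLOCK is passed as a meta-parameter
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
    _add[grid](x, y, z, N, BLOCK=1024)
    return z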


@@ -46,28 +46,29 @@ def naive_softmax(x):
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index
m = triton.program_id(0)
m = tl.program_id(0)
# col indices
n = triton.arange(0, meta['BLOCK'])
n = tl.arange(0, meta['BLOCK'])
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n
x = triton.load(X, mask=n < N, other=-float('inf'))
x = tl.load(X, mask=n < N, other=-float('inf'))
# Subtract maximum for numerical stability
z = x - triton.max(x, axis=0)
z = x - tl.max(x, axis=0)
# Note that exponentials in Triton are fast
# but approximate (i.e., think __expf in CUDA)
num = triton.exp(z)
denom = triton.sum(num, axis=0)
num = tl.exp(z)
denom = tl.sum(num, axis=0)
y = num / denom
# Write back to Y
Y = Y + m * stride_ym + n
triton.store(Y, y, mask=n < N)
tl.store(Y, y, mask=n < N)
# %%
@@ -132,9 +133,9 @@ print(torch.allclose(y_tri, y_ref))
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton', 'naive'], # possible keys for `y_name`
y_lines=["Torch", "Triton", 'Naive'], # label name for the lines
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg`
line_names=["Torch", "Triton", 'Naive'], # label name for the lines
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
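
A host-side wrapper for `_softmax` follows the same pattern, with `BLOCK` rounded up to a power of two so that a single program instance covers an entire row (a sketch assuming the `(Y, X, stride_xm, stride_ym, M, N)` signature shown above):

import torch
import triton

def next_power_of_2(n):
    # smallest power of two greater than or equal to n
    return 1 << (n - 1).bit_length()

def softmax(x: torch.Tensor):
    M, N = x.shape
    y = torch.empty_like(x)
    # one program instance per row; the row must fit in a single block
    _softmax[(M,)](y, x, x.stride(0), y.stride(0), M, N, BLOCK=next_power_of_2(N))
    return y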


@@ -115,6 +115,7 @@ You will specifically learn about:
import torch
import triton
import triton.language as tl
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
@@ -124,9 +125,9 @@ import triton
@triton.jit
def sigmoid(x):
ret_true = 1 / (1 + triton.exp(-x))
ret_false = triton.exp(x) / (1 + triton.exp(x))
return triton.where(x >= 0, ret_true, ret_false)
ret_true = 1 / (1 + tl.exp(-x))
ret_false = tl.exp(x) / (1 + tl.exp(x))
return tl.where(x >= 0, ret_true, ret_false)
@triton.jit
@@ -151,7 +152,7 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
BLOCK_K = META['BLOCK_K']
GROUP_M = 8
# matrix multiplication
pid = triton.program_id(0)
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
@@ -161,16 +162,16 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K):
a = triton.load(A)
b = triton.load(B)
acc += triton.dot(a, b)
a = tl.load(A)
b = tl.load(B)
acc += tl.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# triton can accept arbitrary activation function
@@ -178,11 +179,11 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
if META['ACTIVATION']:
acc = META['ACTIVATION'](acc)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm[:, None] < M) & (rn[None, :] < N)
triton.store(C, acc, mask=mask)
tl.store(C, acc, mask=mask)
# %%
@@ -238,9 +239,9 @@ print(triton.testing.allclose(c_0, c_1))
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['cublas', 'triton'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton"], # label name for the lines
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['cublas', 'triton'], # possible values for `line_arg`
line_names=["cuBLAS", "Triton"], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
@@ -258,4 +259,4 @@ def benchmark(M, N, K, provider):
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(print_data=True)
benchmark.run(show_plots=True, print_data=True)
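
Finally, the `ACTIVATION` meta-parameter lets the epilogue fuse an elementwise `@triton.jit` function such as the `sigmoid` defined above. A sketch of a host wrapper passing it through (illustrative; it assumes the two truncated stride arguments are `stride_cm, stride_cn` and that the autotuner supplies the `BLOCK_*` meta-parameters):

import torch
import triton

def matmul(a: torch.Tensor, b: torch.Tensor, activation=None):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch grid: one program per (BLOCK_M, BLOCK_N) tile of C
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    _matmul[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        ACTIVATION=activation,
    )
    return c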