[PYTHON] Renamed triton.core -> triton.language (#92)

authored by Philippe Tillet on 2021-04-23 17:18:14 -04:00
committed by Philippe Tillet
parent 41410012e8
commit bfc0a7587d
19 changed files with 355 additions and 243 deletions
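
The user-visible effect of the rename is on imports and on every call to a kernel builtin; the conventional alias is tl. A minimal before/after sketch of the new spelling (the fragment is illustrative, adapted from the test-suite changes further down):

    # before: kernel builtins were reached through the top-level module
    #     off = triton.arange(0, meta['SIZE'])
    #     x = triton.load(X + off)
    #     triton.store(Z + off, x)
    # after: they live in triton.language, conventionally aliased as tl
    import triton.language as tl
    #     off = tl.arange(0, meta['SIZE'])
    #     x = tl.load(X + off)
    #     tl.store(Z + off, x)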

View File

@@ -1,7 +1,7 @@
Welcome to Triton's documentation! Welcome to Triton's documentation!
================================== ==================================
Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware. Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
Getting Started Getting Started
--------------- ---------------
@@ -17,18 +17,22 @@ Getting Started
getting-started/installation getting-started/installation
getting-started/tutorials/index getting-started/tutorials/index
Language Reference Python API
------------------- -------------------
- Checkout the :doc:`Python API Documentation <language-reference/python-api/index>` - :doc:`triton <python-api/triton>`
- :doc:`triton.language <python-api/triton.language>`
- :doc:`triton.testing <python-api/triton.testing>`
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:caption: Language Reference :caption: Python API
:hidden: :hidden:
language-reference/python-api/index python-api/triton
python-api/triton.language
python-api/triton.testing
Going Further Going Further

View File

@@ -1,7 +1,7 @@
Python API triton.language
=========== ================
.. currentmodule:: triton .. currentmodule:: triton.language
Programming Model Programming Model

View File

@@ -0,0 +1,10 @@
triton
========
.. currentmodule:: triton
.. autosummary::
:toctree: generated
:nosignatures:
jit

View File

@@ -0,0 +1,12 @@
triton.testing
================
.. currentmodule:: triton.testing
.. autosummary::
:toctree: generated
:nosignatures:
do_bench
Benchmark
perf_report

View File

@@ -10,9 +10,9 @@ square_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['M', 'N', 'K'], x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'block', line_arg = 'block',
y_vals = [16, 32, 64], line_vals = [16, 32, 64],
y_lines = ['Block16', 'Block32', 'Block64'], line_names = ['Block16', 'Block32', 'Block64'],
ylabel = 'TFLOPS', ylabel = 'TFLOPS',
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode, args = {'layout_mode': layout_mode, 'op_mode': op_mode,
@@ -60,9 +60,9 @@ square_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['M', 'N'], x_names = ['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'block', line_arg = 'block',
y_vals = [16, 32, 64], line_vals = [16, 32, 64],
y_lines = ['Block16', 'Block32', 'Block64'], line_names = ['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS', ylabel = 'GBPS',
plot_name = f'{layout_mode}-square', plot_name = f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}

View File

@@ -5,9 +5,9 @@ confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['N'], x_names = ['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
y_name = 'provider', line_arg = 'provider',
y_vals = ['triton', 'torch'], line_vals = ['triton', 'torch'],
y_lines = ['Triton', 'Torch'], line_names = ['Triton', 'Torch'],
ylabel = 'GBPS', ylabel = 'GBPS',
plot_name = f'{mode}-2048', plot_name = f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode} args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}

View File

@@ -16,9 +16,9 @@ square_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["M", "N", "K"], x_names=["M", "N", "K"],
x_vals=rounded_linspace(512, 8192, 32, 128), x_vals=rounded_linspace(512, 8192, 32, 128),
y_name="provider", line_arg="provider",
y_vals=["cublas", "triton", "cutlass"], line_vals=["cublas", "triton", "cutlass"],
y_lines=["cuBLAS", "Triton", "CUTLASS"], line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS", ylabel="TFLOPS",
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}", plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
args={"AT": AT, "BT": BT, "dtype": torch.float16}, args={"AT": AT, "BT": BT, "dtype": torch.float16},
@@ -30,9 +30,9 @@ transformer_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=[x], x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 32, 128), x_vals = rounded_linspace(NK//16, NK, 32, 128),
y_name="provider", line_arg="provider",
y_vals=["cublas", "triton", "cutlass"], line_vals=["cublas", "triton", "cutlass"],
y_lines=["cuBLAS", "Triton", "CUTLASS"], line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS", ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}

View File

@@ -1,5 +1,6 @@
import torch import torch
import triton import triton
import triton.language as tl
import copy import copy
import pytest import pytest
import ast import ast
@@ -37,10 +38,10 @@ def _test_unary(dtype_x, expr, device='cuda'):
# define the kernel / launch-grid # define the kernel / launch-grid
@triton.jit @triton.jit
def kernel(Z, X, **meta): def kernel(Z, X, **meta):
off = triton.arange(0, meta['SIZE']) off = tl.arange(0, meta['SIZE'])
x = triton.load(X + off) x = tl.load(X + off)
z = GENERATE_TEST_HERE z = GENERATE_TEST_HERE
triton.store(Z + off, z) tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs # inputs
@@ -59,11 +60,11 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
# define the kernel / launch-grid # define the kernel / launch-grid
@triton.jit @triton.jit
def kernel(Z, X, Y, **meta): def kernel(Z, X, Y, **meta):
off = triton.arange(0, meta['SIZE']) off = tl.arange(0, meta['SIZE'])
x = triton.load(X + off) x = tl.load(X + off)
y = triton.load(Y + off) y = tl.load(Y + off)
z = GENERATE_TEST_HERE z = GENERATE_TEST_HERE
triton.store(Z + off, z) tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs # inputs
@@ -144,7 +145,7 @@ def make_ptr_str(name, shape):
stride = 1 stride = 1
for i in reversed(range(rank)): for i in reversed(range(rank)):
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}'] offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
stride *= shape[i] stride *= shape[i]
return f"{name} + {' + '.join(offsets)}" return f"{name} + {' + '.join(offsets)}"
@@ -164,11 +165,11 @@ def test_index1d(expr, device='cuda'):
@triton.jit @triton.jit
def kernel(Z, X, **meta): def kernel(Z, X, **meta):
SIZE = meta['SIZE'] SIZE = meta['SIZE']
m = triton.arange(0, SIZE) m = tl.arange(0, SIZE)
n = triton.arange(0, SIZE) n = tl.arange(0, SIZE)
x = triton.load(X_PTR_EXPR) x = tl.load(X_PTR_EXPR)
z = GENERATE_TEST_HERE z = GENERATE_TEST_HERE
triton.store(Z_PTR_EXPR, z) tl.store(Z_PTR_EXPR, z)
to_replace = { to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x), 'X_PTR_EXPR': make_ptr_str('X', shape_x),

View File

@@ -2,9 +2,9 @@
# or pybind11 shows `munmap_chunk(): invalid pointer` # or pybind11 shows `munmap_chunk(): invalid pointer`
import torch import torch
# submodules # submodules
from .code_gen import jit, autotune, heuristics, Config, Autotuner from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
from .core import *
from . import language
from . import code_gen from . import code_gen
from . import testing from . import testing
from . import ops from . import ops

View File

@@ -26,21 +26,21 @@ class CodeGenerator(ast.NodeVisitor):
ret = self.builtins[name] ret = self.builtins[name]
else: else:
raise ValueError(f'{name} is not defined') raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.block): if isinstance(ret, triton.language.block):
handle = self.module.get_value(name) handle = self.module.get_value(name)
return triton.block(handle) return triton.language.block(handle)
return ret return ret
def set_value(self, name, value): def set_value(self, name, value):
if isinstance(value, _triton.ir.value): if isinstance(value, _triton.ir.value):
value = triton.block(value) value = triton.language.block(value)
if isinstance(value, triton.block): if isinstance(value, triton.language.block):
self.module.set_value(name, value.handle) self.module.set_value(name, value.handle)
self.module.scope.set_type(name, value.handle.type) self.module.scope.set_type(name, value.handle.type)
self.lscope[name] = value self.lscope[name] = value
def is_triton_object(self, value): def is_triton_object(self, value):
return isinstance(value, triton.block) return isinstance(value, triton.language.block)
def visit_compound_statement(self, stmts, add_scope=False): def visit_compound_statement(self, stmts, add_scope=False):
if add_scope: if add_scope:
@@ -63,7 +63,14 @@ class CodeGenerator(ast.NodeVisitor):
self.constants = constants self.constants = constants
self.kwargs = kwargs self.kwargs = kwargs
self.last_node = None self.last_node = None
self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr} self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'getattr': getattr,
}
def visit_Module(self, node): def visit_Module(self, node):
self.module.add_new_scope() self.module.add_new_scope()
@@ -303,7 +310,7 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]]) pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)]) pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.where(self.visit(pos_step_node),\ build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\ self.visit(pos_cond_node),\
self.visit(neg_cond_node),\ self.visit(neg_cond_node),\
builder=self.builder) builder=self.builder)
@@ -359,7 +366,7 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(fn, JITFunction): if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws) return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.core: sys.modules[fn.__module__] is triton.language:
return fn(*args, builder=self.builder, **kws) return fn(*args, builder=self.builder, **kws)
return fn(*args, **kws) return fn(*args, **kws)
@@ -613,6 +620,11 @@ class JITFunction:
raise e raise e
raise CompilationError(self.src, node, e) raise CompilationError(self.src, node, e)
def __setattr__(self, name, value):
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
def _init_kernel(self): def _init_kernel(self):
if self.kernel is None: if self.kernel is None:
self.kernel = Kernel(self) self.kernel = Kernel(self)
@@ -659,4 +671,23 @@ def heuristics(values):
def jit(fn): def jit(fn):
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
return JITFunction(fn) return JITFunction(fn)
def cdiv(x, y):
return (x + y - 1) // y
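
The docstring added above spells out the contract of @triton.jit: torch.tensor arguments are implicitly converted to pointers via .data_ptr(), and the body may only use Python primitives, triton.language objects, its own arguments, and other jit'd functions. A sketch of a launch under those rules, mirroring the vector-add tutorial updated later in this commit (kernel name, sizes, and BLOCK value are illustrative):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _add(Z, X, Y, N, **meta):
        pid = tl.program_id(0)
        offs = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        mask = offs < N
        x = tl.load(X + offs, mask=mask)
        y = tl.load(Y + offs, mask=mask)
        tl.store(Z + offs, x + y, mask=mask)

    x = torch.rand(4096, device='cuda')
    y = torch.rand(4096, device='cuda')
    z = torch.empty_like(x)
    # tensors are passed as-is (the launcher calls .data_ptr()); meta-parameters
    # such as BLOCK are passed as keyword arguments, and the newly exposed
    # triton.cdiv is the usual way to size the launch grid
    grid = lambda meta: (triton.cdiv(z.numel(), meta['BLOCK']),)
    _add[grid](z, x, y, z.numel(), BLOCK=1024)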

View File

@@ -1,6 +1,6 @@
import triton
from triton._C.libtriton.triton import ir from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend from triton._C.libtriton.triton import frontend
import triton
from functools import wraps from functools import wraps
@@ -25,7 +25,7 @@ def _patch(fn):
if x.type.is_void(): if x.type.is_void():
return None return None
return block(x) return block(x)
return x return tl
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
builder = args[-1] builder = args[-1]
@@ -547,7 +547,7 @@ def minimum(x, y):
:param other: the second input block :param other: the second input block
:type other: Block :type other: Block
""" """
return triton.where(x < y, x, y) return triton.language.where(x < y, x, y)
@triton.jit @triton.jit
@@ -560,7 +560,7 @@ def maximum(x, y):
:param other: the second input block :param other: the second input block
:type other: Block :type other: Block
""" """
return triton.where(x > y, x, y) return triton.language.where(x > y, x, y)
@triton.jit @triton.jit
@@ -571,7 +571,7 @@ def sigmoid(x):
:param x: the input block :param x: the input block
:type x: Block :type x: Block
""" """
return 1 / (1 + np.exp(-x)) return 1 / (1 + triton.language.exp(-x))
@triton.jit @triton.jit
@@ -582,9 +582,9 @@ def softmax(x):
:param x: the input block :param x: the input block
:type x: Block :type x: Block
""" """
z = x - triton.max(x, 0) z = x - triton.language.max(x, 0)
num = triton.exp(z) num = triton.language.exp(z)
den = triton.sum(num, 0) den = triton.language.sum(num, 0)
return num / den return num / den
@@ -596,8 +596,4 @@ def ravel(x):
:param x: the input block :param x: the input block
:type x: Block :type x: Block
""" """
return triton.reshape(x, [x.type.numel]) return triton.language.reshape(x, [x.type.numel])
def cdiv(x, y):
return (x + y - 1) // y
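
Two details worth noting in this hunk: sigmoid used to evaluate 1 / (1 + np.exp(-x)), which cannot run inside a jit'd kernel (only Python primitives and triton.language objects are visible there), and now dispatches to triton.language.exp; and composites such as softmax are themselves @triton.jit functions, so they are called from inside other kernels. An illustrative use, patterned on the block-sparse softmax kernel below (not part of this commit; assumes import triton and import triton.language as tl):

    @triton.jit
    def _row_softmax(Y, X, N, **meta):
        offs = tl.arange(0, meta['BLOCK'])
        mask = offs < N
        x = tl.load(X + offs, mask=mask, other=-float('inf'))
        # tl.softmax / tl.sigmoid are jit'd helpers built from tl.exp, tl.sum, ...
        tl.store(Y + offs, tl.softmax(x), mask=mask)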

View File

@@ -1,4 +1,5 @@
import triton import triton
import triton.language as tl
import triton._C.libtriton as libtriton import triton._C.libtriton as libtriton
import torch import torch
@@ -16,21 +17,21 @@ def _kernel(
#------------# #------------#
#- Prologue -# #- Prologue -#
#------------# #------------#
pid0 = triton.program_id(0) pid0 = tl.program_id(0)
pid1 = triton.program_id(1) pid1 = tl.program_id(1)
pidz = triton.program_id(2) pidz = tl.program_id(2)
if meta['SDD']: if meta['SDD']:
pid1 = pid1 + SDD_off_width pid1 = pid1 + SDD_off_width
blockidm = triton.arange(0, TM) // BLOCK blockidm = tl.arange(0, TM) // BLOCK
blockidn = triton.arange(0, TN) // BLOCK blockidn = tl.arange(0, TN) // BLOCK
offlutm = blockidm * (TN // BLOCK) * 4 offlutm = blockidm * (TN // BLOCK) * 4
offlutn = blockidn * 4 offlutn = blockidn * 4
header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4 header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
z = triton.load(header + 0) z = tl.load(header + 0)
i = triton.load(header + 1 + offlutm) i = tl.load(header + 1 + offlutm)
j = triton.load(header + 2 + offlutn) j = tl.load(header + 2 + offlutn)
AS1 = SDD_K // TZ AS1 = SDD_K // TZ
lockid = triton.where(TZ > 1, 1, 0) lockid = tl.where(TZ > 1, 1, 0)
offka = pid0 * AS1 offka = pid0 * AS1
offkb = pid0 * AS1 offkb = pid0 * AS1
offmc = 0 offmc = 0
@@ -41,16 +42,16 @@ def _kernel(
offhc = 0 offhc = 0
offha = z offha = z
offhb = z offhb = z
ram = i * BLOCK + (triton.arange(0, TM) % BLOCK) ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK) rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
else: else:
header = lut + pid0 * 6 header = lut + pid0 * 6
offset = triton.load(header + 0) offset = tl.load(header + 0)
AS1 = triton.load(header + 1) AS1 = tl.load(header + 1)
column = triton.load(header + 2) column = tl.load(header + 2)
depth = triton.load(header + 3) depth = tl.load(header + 3)
lockid = triton.load(header + 4) lockid = tl.load(header + 4)
maxid = triton.load(header + 5) maxid = tl.load(header + 5)
pinc = lut + offset pinc = lut + offset
offhc = depth offhc = depth
if meta['DSD']: if meta['DSD']:
@@ -60,14 +61,14 @@ def _kernel(
offpc = 0 offpc = 0
# dense input offset # dense input offset
offnb = pid1 * TN offnb = pid1 * TN
offkb = triton.load(pinc) offkb = tl.load(pinc)
offkb = triton.multiple_of(offkb, 8) # compiler hint offkb = tl.multiple_of(offkb, 8) # compiler hint
offpb = 0 offpb = 0
# sparse input offset # sparse input offset
offma = 0 offma = 0
offka = 0 offka = 0
offpa = triton.load(pinc + 1) offpa = tl.load(pinc + 1)
offpa = triton.multiple_of(offpa, 8) # compiler hint offpa = tl.multiple_of(offpa, 8) # compiler hint
offpa = offpa * BLOCK * BLOCK offpa = offpa * BLOCK * BLOCK
offha = 0 offha = 0
offhb = depth offhb = depth
@@ -78,23 +79,23 @@ def _kernel(
offpc = 0 offpc = 0
# dense input offset # dense input offset
offma = pid1 * TM offma = pid1 * TM
offka = triton.load(pinc) offka = tl.load(pinc)
offka = triton.multiple_of(offka, 8) # compiler hint offka = tl.multiple_of(offka, 8) # compiler hint
offpa = 0 offpa = 0
# sparse input offset # sparse input offset
offnb = 0 offnb = 0
offkb = 0 offkb = 0
offpb = triton.load(pinc + 1) offpb = tl.load(pinc + 1)
offpb = triton.multiple_of(offpb, 8) # compiler hint offpb = tl.multiple_of(offpb, 8) # compiler hint
offpb = offpb * BLOCK * BLOCK offpb = offpb * BLOCK * BLOCK
offha = depth offha = depth
offhb = 0 offhb = 0
ram = offma + triton.arange(0, TM) ram = offma + tl.arange(0, TM)
rbn = offnb + triton.arange(0, TN) rbn = offnb + tl.arange(0, TN)
# initialize a, b pointers # initialize a, b pointers
rka = offka + triton.arange(0, TK) rka = offka + tl.arange(0, TK)
rkb = offkb + triton.arange(0, TK) rkb = offkb + tl.arange(0, TK)
pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
if meta['DDS']: if meta['DDS']:
@@ -105,31 +106,31 @@ def _kernel(
checkbn = rbn[None, :] < DS0 checkbn = rbn[None, :] < DS0
else: else:
checkbn = AS1 > 0 checkbn = AS1 > 0
a = triton.load(pa, mask=checkam, other=0.) a = tl.load(pa, mask=checkam, other=0.)
b = triton.load(pb, mask=checkbn, other=0.) b = tl.load(pb, mask=checkbn, other=0.)
## ---------------- ## ## ---------------- ##
## Inner Loop ## ## Inner Loop ##
## ---------------- ## ## ---------------- ##
acc = triton.zeros((TM, TN), dtype=triton.float32) acc = tl.zeros((TM, TN), dtype=tl.float32)
for k in range(AS1, 0, -TK): for k in range(AS1, 0, -TK):
acc += triton.dot(a, b) acc += tl.dot(a, b)
if meta['SDD']: if meta['SDD']:
inc_a = TK * stride_ka inc_a = TK * stride_ka
inc_b = TK * stride_kb inc_b = TK * stride_kb
else: else:
pinc += 2 pinc += 2
if meta['DSD']: if meta['DSD']:
inc_b = triton.load(pinc) inc_b = tl.load(pinc)
inc_a = triton.load(pinc + 1) inc_a = tl.load(pinc + 1)
inc_b = triton.multiple_of(inc_b, 8) inc_b = tl.multiple_of(inc_b, 8)
inc_a = triton.multiple_of(inc_a, 8) inc_a = tl.multiple_of(inc_a, 8)
inc_b = inc_b * stride_kb inc_b = inc_b * stride_kb
if meta['DDS']: if meta['DDS']:
inc_a = triton.load(pinc) inc_a = tl.load(pinc)
inc_b = triton.load(pinc + 1) inc_b = tl.load(pinc + 1)
inc_a = triton.multiple_of(inc_a, 8) inc_a = tl.multiple_of(inc_a, 8)
inc_b = triton.multiple_of(inc_b, 8) inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka inc_a = inc_a * stride_ka
pa += inc_a pa += inc_a
pb += inc_b pb += inc_b
@@ -138,24 +139,24 @@ def _kernel(
checkbk = k > TK checkbk = k > TK
checka = checkam & checkak checka = checkam & checkak
checkb = checkbn & checkbk checkb = checkbn & checkbk
a = triton.load(pa, mask=checka) a = tl.load(pa, mask=checka)
b = triton.load(pb, mask=checkb) b = tl.load(pb, mask=checkb)
c = acc.to(C.dtype.element_ty) c = acc.to(C.dtype.element_ty)
if meta['SDD']: if meta['SDD']:
checkc = True checkc = True
rr_blockidm = triton.arange(0, TM) // BLOCK rr_blockidm = tl.arange(0, TM) // BLOCK
rr_blockidn = triton.arange(0, TN) // BLOCK rr_blockidn = tl.arange(0, TN) // BLOCK
rr_offlutm = rr_blockidm * (TN // BLOCK) * 4 rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
rr_offlutn = rr_blockidn * 4 rr_offlutn = rr_blockidn * 4
off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :] off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
bkid = triton.load(header + off_bkid) bkid = tl.load(header + off_bkid)
offpc = bkid * BLOCK * BLOCK offpc = bkid * BLOCK * BLOCK
rcm = triton.arange(0, TM) % BLOCK rcm = tl.arange(0, TM) % BLOCK
rcn = triton.arange(0, TN) % BLOCK rcn = tl.arange(0, TN) % BLOCK
else: else:
rcm = offmc + triton.arange(0, TM) rcm = offmc + tl.arange(0, TM)
rcn = offnc + triton.arange(0, TN) rcn = offnc + tl.arange(0, TN)
if meta['DSD']: if meta['DSD']:
checkc = rcn[None, :] < DS0 checkc = rcn[None, :] < DS0
if meta['DDS']: if meta['DDS']:
@@ -164,21 +165,21 @@ def _kernel(
pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
# write-back directly # write-back directly
if lockid == 0: if lockid == 0:
triton.store(pc, c, mask=checkc) tl.store(pc, c, mask=checkc)
# accumulate partial results using spin-locks # accumulate partial results using spin-locks
else: else:
plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1 plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
while triton.atomic_cas(plock, 0, 1) == 1: while tl.atomic_cas(plock, 0, 1) == 1:
pass pass
count = triton.load(pcount) count = tl.load(pcount)
if count == 0: if count == 0:
triton.store(pc, c, mask=checkc) tl.store(pc, c, mask=checkc)
else: else:
d = triton.load(pc, mask=checkc) d = tl.load(pc, mask=checkc)
triton.store(pc, d + c, mask=checkc) tl.store(pc, d + c, mask=checkc)
triton.atomic_xchg(pcount, (count + 1) % maxid) tl.atomic_xchg(pcount, (count + 1) % maxid)
triton.atomic_xchg(plock, 0) tl.atomic_xchg(plock, 0)
############## ##############

View File

@@ -1,3 +1,4 @@
import triton.language as tl
import triton import triton
import torch import torch
import os import os
@@ -31,86 +32,86 @@ def _forward(
): ):
TN = meta['TN'] TN = meta['TN']
BLOCK = meta['BLOCK'] BLOCK = meta['BLOCK']
pidhm = triton.program_id(0) pidhm = tl.program_id(0)
pidz = triton.program_id(1) pidz = tl.program_id(1)
# create index ranges # create index ranges
rxm = pidhm % BLOCK rxm = pidhm % BLOCK
rbm = pidhm // BLOCK rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK rxn = tl.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK rbn = tl.arange(0, TN) // BLOCK
# extract information from LUT # extract information from LUT
header = LUT + rbm * 2 header = LUT + rbm * 2
size = triton.load(header + 0) size = tl.load(header + 0)
offset = triton.load(header + 1) offset = tl.load(header + 1)
check = rbn < size check = rbn < size
rbmn = triton.where(check, rbn, size - 1) rbmn = tl.where(check, rbn, size - 1)
# block id and column id # block id and column id
blockid = triton.load(LUT + offset + rbmn * 4 + 0) blockid = tl.load(LUT + offset + rbmn * 4 + 0)
columnid = triton.load(LUT + offset + rbmn * 4 + 1) columnid = tl.load(LUT + offset + rbmn * 4 + 1)
rowid = triton.load(LUT + offset + rbmn * 4 + 2) rowid = tl.load(LUT + offset + rbmn * 4 + 2)
headid = triton.load(LUT + offset + rbmn * 4 + 3) headid = tl.load(LUT + offset + rbmn * 4 + 3)
# pointers to X # pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = triton.load(px, mask=check, other=-float('inf')) x = tl.load(px, mask=check, other=-float('inf'))
x = x.to(triton.float32) x = x.to(tl.float32)
# apply scale # apply scale
if meta['APPLY_SCALE']: if meta['APPLY_SCALE']:
x = x * scale x = x * scale
# apply RPE # apply RPE
if meta['APPLY_RPE']: if meta['APPLY_RPE']:
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
rpe = triton.load(prpe, mask=check, other=0) rpe = tl.load(prpe, mask=check, other=0)
x = x + rpe x = x + rpe
# apply key-padding mask # apply key-padding mask
if meta['APPLY_KP_MASK']: if meta['APPLY_KP_MASK']:
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
kp_m = triton.load(pkp_m, mask=check, other=-float('inf')) kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
if meta['KP_MASK_MUL']: if meta['KP_MASK_MUL']:
kp_m = triton.where(kp_m == 0, -float('inf'), 0.) kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
x = x + kp_m x = x + kp_m
# apply attention mask # apply attention mask
if meta['APPLY_ATTN_MASK']: if meta['APPLY_ATTN_MASK']:
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
attn_m = triton.load(pattn_m, mask=check, other=-float('inf')) attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
if meta['ATTN_MASK_MUL']: if meta['ATTN_MASK_MUL']:
attn_m = triton.where(attn_m == 0, -float('inf'), 0.) attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m x = x + attn_m
# computation # computation
x = triton.softmax(x) x = tl.softmax(x)
triton.store(px, x, mask=check) tl.store(px, x, mask=check)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']}) @triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit @triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta): def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
pidhm = triton.program_id(0) pidhm = tl.program_id(0)
pidz = triton.program_id(1) pidz = tl.program_id(1)
TN = meta['TN'] TN = meta['TN']
BLOCK = meta['BLOCK'] BLOCK = meta['BLOCK']
# create index ranges # create index ranges
rxm = pidhm % BLOCK rxm = pidhm % BLOCK
rbm = pidhm // BLOCK rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK rxn = tl.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK rbn = tl.arange(0, TN) // BLOCK
# extract information from look-up table # extract information from look-up table
header = LUT + rbm * 2 header = LUT + rbm * 2
size = triton.load(header + 0) size = tl.load(header + 0)
offset = triton.load(header + 1) offset = tl.load(header + 1)
# bounds checking on lut # bounds checking on lut
check = rbn < size check = rbn < size
rbmn = triton.where(check, rbn, size - 1) rbmn = tl.where(check, rbn, size - 1)
# initialize pointers to block-sparse input # initialize pointers to block-sparse input
blockid = triton.load(LUT + offset + rbmn * 4) blockid = tl.load(LUT + offset + rbmn * 4)
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
# compute fused softmax backward # compute fused softmax backward
x = triton.load(X, mask=check, other=0) x = tl.load(X, mask=check, other=0)
dx = triton.load(DX, mask=check, other=0) dx = tl.load(DX, mask=check, other=0)
x = x.to(triton.float32) x = x.to(tl.float32)
dx = dx.to(triton.float32) dx = dx.to(tl.float32)
y = x * (dx - triton.sum(x * dx, 0)) * scale y = x * (dx - tl.sum(x * dx, 0)) * scale
triton.store(DX, y, mask=check) tl.store(DX, y, mask=check)
class _softmax(torch.autograd.Function): class _softmax(torch.autograd.Function):

View File

@@ -1,5 +1,6 @@
import os import os
import triton import triton
import triton.language as tl
import torch import torch
@@ -27,25 +28,25 @@ def num_warps(N):
@triton.jit @triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
BLOCK = meta['BLOCK'] BLOCK = meta['BLOCK']
row = triton.program_id(0) row = tl.program_id(0)
cols = triton.arange(0, BLOCK) cols = tl.arange(0, BLOCK)
idx = triton.load(IDX + row) idx = tl.load(IDX + row)
# pointers to logit and probs # pointers to logit and probs
LOGITS = LOGITS + row * N + cols LOGITS = LOGITS + row * N + cols
WRIT_PROBS = PROBS + row * N + cols WRIT_PROBS = PROBS + row * N + cols
READ_PROBS = PROBS + row * N + idx READ_PROBS = PROBS + row * N + idx
# write-back negative log-probs # write-back negative log-probs
logits = triton.load(LOGITS, mask=cols < N, other=-float('inf')) logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
logits = logits.to(triton.float32) logits = logits.to(tl.float32)
logits = logits - triton.max(logits, 0) logits = logits - tl.max(logits, 0)
probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
triton.store(WRIT_PROBS, probs, mask=cols < N) tl.store(WRIT_PROBS, probs, mask=cols < N)
# There is a bug in the compiler, which fails to insert a barrier here. # There is a bug in the compiler, which fails to insert a barrier here.
# We add it explicitly for now. Will be fixed soon. # We add it explicitly for now. Will be fixed soon.
triton.debug_barrier() tl.debug_barrier()
# write-back loss # write-back loss
probs = triton.load(READ_PROBS) probs = tl.load(READ_PROBS)
triton.store(LOSS + row, probs) tl.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])}) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
@@ -53,20 +54,20 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
@triton.jit @triton.jit
def _backward(PROBS, IDX, DPROBS, N, **meta): def _backward(PROBS, IDX, DPROBS, N, **meta):
BLOCK = meta['BLOCK'] BLOCK = meta['BLOCK']
row = triton.program_id(0) row = tl.program_id(0)
cols = triton.arange(0, BLOCK) cols = tl.arange(0, BLOCK)
idx = triton.load(IDX + row) idx = tl.load(IDX + row)
# pointers to probs # pointers to probs
PROBS = PROBS + row * N + cols PROBS = PROBS + row * N + cols
# We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
# and we have -log(p[k]) stored in PROBS, so this is easy # and we have -log(p[k]) stored in PROBS, so this is easy
probs = -triton.load(PROBS, mask=cols < N, other=float('inf')) probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
probs = triton.exp(probs.to(triton.float32)) probs = tl.exp(probs.to(tl.float32))
delta = cols == idx delta = cols == idx
# write result in-place in PROBS # write result in-place in PROBS
dout = triton.load(DPROBS + row) dout = tl.load(DPROBS + row)
din = (probs - delta) * dout din = (probs - delta) * dout
triton.store(PROBS, din.to(triton.float16), mask=cols < N) tl.store(PROBS, din.to(tl.float16), mask=cols < N)
class _cross_entropy(torch.autograd.Function): class _cross_entropy(torch.autograd.Function):

View File

@@ -1,4 +1,5 @@
import torch import torch
import triton.language as tl
import triton import triton
@@ -27,8 +28,8 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
GROUP_M = META['GROUP_M'] GROUP_M = META['GROUP_M']
SPLIT_K = META['SPLIT_K'] SPLIT_K = META['SPLIT_K']
# matrix multiplication # matrix multiplication
pid = triton.program_id(0) pid = tl.program_id(0)
pid_z = triton.program_id(1) pid_z = tl.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance # re-order program ID for better L2 performance
@@ -38,46 +39,46 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
pid_m = group_id * GROUP_M + (pid % group_size) pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size) pid_n = (pid % width) // (group_size)
# do matrix multiplication # do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K) rk = tl.arange(0, BLOCK_K)
# pointers # pointers
K = K // SPLIT_K K = K // SPLIT_K
A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak) A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn) B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K): for k in range(K, 0, -BLOCK_K):
if META['EVEN_K']: if META['EVEN_K']:
a = triton.load(A) a = tl.load(A)
b = triton.load(B) b = tl.load(B)
else: else:
a = triton.load(A, mask=rk[None, :] < k, other=0.) a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = triton.load(B, mask=rk[:, None] < k, other=0.) b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += triton.dot(a, b) acc += tl.dot(a, b)
A += BLOCK_K * stride_ak A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk B += BLOCK_K * stride_bk
acc = acc.to(triton.float16) acc = acc.to(tl.float16)
# rematerialize rm and rn to save registers # rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :] mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting # handles write-back with reduction-splitting
if SPLIT_K == 1: if SPLIT_K == 1:
triton.store(C, acc, mask=mask) tl.store(C, acc, mask=mask)
else: else:
LOCKS = LOCKS + triton.program_id(0) LOCKS = LOCKS + tl.program_id(0)
COUNT = LOCKS + triton.num_programs(0) COUNT = LOCKS + tl.num_programs(0)
while triton.atomic_cas(LOCKS, 0, 1) == 1: while tl.atomic_cas(LOCKS, 0, 1) == 1:
pass pass
count = triton.load(COUNT) count = tl.load(COUNT)
if count == 0: if count == 0:
triton.store(C, acc, mask=mask) tl.store(C, acc, mask=mask)
else: else:
curr = triton.load(C, mask=mask, other=0.) curr = tl.load(C, mask=mask, other=0.)
triton.store(C, acc + curr, mask=mask) tl.store(C, acc + curr, mask=mask)
triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K) tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
triton.atomic_xchg(LOCKS, 0) tl.atomic_xchg(LOCKS, 0)
class _matmul(torch.autograd.Function): class _matmul(torch.autograd.Function):

View File

@@ -81,6 +81,22 @@ def random(shape, dtype, device):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]): def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20th and 80th performance percentiles.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param percentiles: Performance percentile to return in addition to the median.
:type percentiles: list[float]
"""
# Estimate the runtime of the function # Estimate the runtime of the function
fn() fn()
torch.cuda.synchronize() torch.cuda.synchronize()
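
With the docstring above, do_bench is the documented entry point for timing a callable; by default it reports the median along with the 20th and 80th percentiles, in milliseconds. A usage sketch (the workload and sizes are illustrative):

    import torch
    import triton

    a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    # median, then 20th and 80th percentile, all in ms
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))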
@@ -125,13 +141,16 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
class Benchmark: class Benchmark:
"""
This class is used by the :code:`perf_report` function to generate line plots with a concise API.
"""
def __init__( def __init__(
self, self,
x_names, x_names,
x_vals, x_vals,
y_name, line_arg,
y_vals, line_vals,
y_lines, line_names,
plot_name, plot_name,
args, args,
xlabel='', xlabel='',
@@ -139,12 +158,38 @@ class Benchmark:
x_log=False, x_log=False,
y_log=False, y_log=False,
): ):
"""
Constructor
:param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
:type x_names: List[str]
:param x_vals: List of values to use for the arguments in :code:`x_names`.
:type x_vals: List[Any]
:param line_arg: Argument name for which different values correspond to different lines in the plot.
:type line_arg: str
:param line_vals: List of values to use for the arguments in :code:`line_arg`.
:type line_vals: List[str]
:param line_names: Label names for the different lines.
:type line_names: List[str]
:param plot_name: Name of the plot.
:type plot_name: str
:param args: List of arguments to remain fixed throughout the benchmark.
:type args: List[str]
:param xlabel: Label for the x axis of the plot.
:type xlabel: str, optional
:param ylabel: Label for the y axis of the plot.
:type ylabel: str, optional
:param x_log: Whether the x axis should be log scale.
:type x_log: bool, optional
:param y_log: Whether the y axis should be log scale.
:type y_log: bool, optional
"""
self.x_names = x_names self.x_names = x_names
self.x_vals = x_vals self.x_vals = x_vals
self.x_log = x_log self.x_log = x_log
self.y_name = y_name self.line_arg = line_arg
self.y_vals = y_vals self.line_vals = line_vals
self.y_lines = y_lines self.line_names = line_names
self.y_log = y_log self.y_log = y_log
# plot info # plot info
self.xlabel = xlabel self.xlabel = xlabel
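
For reference, the renamed constructor arguments in context (y_name -> line_arg, y_vals -> line_vals, y_lines -> line_names); the concrete values are only an illustration and assume import triton:

    bench = triton.testing.Benchmark(
        x_names=['N'],                         # argument(s) swept along the x axis
        x_vals=[2**i for i in range(12, 20)],
        line_arg='provider',                   # was y_name
        line_vals=['torch', 'triton'],         # was y_vals
        line_names=['Torch', 'Triton'],        # was y_lines
        ylabel='GB/s',
        plot_name='copy-performance',
        args={},
    )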
@@ -162,15 +207,15 @@ class Mark:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import os import os
y_mean = bench.y_lines y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.y_lines] y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.y_lines] y_max = [f'{x}-max' for x in bench.line_names]
df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max) df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
for x in bench.x_vals: for x in bench.x_vals:
x_args = {x_name: x for x_name in bench.x_names} x_args = {x_name: x for x_name in bench.x_names}
row_mean, row_min, row_max = [], [], [] row_mean, row_min, row_max = [], [], []
for y in bench.y_vals: for y in bench.line_vals:
ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args) ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
try: try:
y_mean, y_min, y_max = ret y_mean, y_min, y_max = ret
except TypeError: except TypeError:
@@ -183,7 +228,7 @@ class Mark:
plt.figure() plt.figure()
ax = plt.subplot() ax = plt.subplot()
x = bench.x_names[0] x = bench.x_names[0]
for y in bench.y_lines: for y in bench.line_names:
y_min, y_max = df[y + '-min'], df[y + '-max'] y_min, y_max = df[y + '-min'], df[y + '-max']
ax.plot(df[x], df[y], label=y) ax.plot(df[x], df[y], label=y)
if y_min is not None and y_max is not None: if y_min is not None and y_max is not None:
@@ -199,7 +244,7 @@ class Mark:
plt.show() plt.show()
if save_path: if save_path:
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
df = df[[bench.x_names[0]] + bench.y_lines] df = df[[bench.x_names[0]] + bench.line_names]
if print_data: if print_data:
print(df) print(df)
if save_path: if save_path:
@@ -220,5 +265,11 @@ class Mark:
def perf_report(benchmarks): def perf_report(benchmarks):
"""
Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
:param benchmarks: Benchmarking configurations.
:type benchmarks: List of :class:`Benchmark`
"""
wrapper = lambda fn: Mark(fn, benchmarks) wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper return wrapper
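
And how a benchmark is declared and executed once perf_report is applied; bench refers to the Benchmark sketched above, the two workloads are placeholders standing in for real providers, and torch/triton are assumed imported:

    @triton.testing.perf_report(bench)  # a list of Benchmark objects also works
    def benchmark(N, provider):
        x = torch.randn(N, device='cuda')
        op = torch.relu if provider == 'torch' else torch.sigmoid  # placeholder workloads
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x))
        gbps = lambda t: 2 * x.numel() * x.element_size() * 1e-9 / (t * 1e-3)  # one read + one write
        return gbps(ms), gbps(max_ms), gbps(min_ms)

    benchmark.run(show_plots=False, print_data=True)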

View File

@@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# -------------------------- # --------------------------
import torch import torch
import triton.language as tl
import triton import triton
@@ -24,19 +25,19 @@ def _add(
N, # Size of the vector N, # Size of the vector
**meta # Optional meta-parameters for the kernel **meta # Optional meta-parameters for the kernel
): ):
pid = triton.program_id(0) pid = tl.program_id(0)
# Create an offset for the blocks of pointers to be # Create an offset for the blocks of pointers to be
# processed by this program instance # processed by this program instance
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK']) offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
# Create a mask to guard memory operations against # Create a mask to guard memory operations against
# out-of-bounds accesses # out-of-bounds accesses
mask = offsets < N mask = offsets < N
# Load x # Load x
x = triton.load(X + offsets, mask=mask) x = tl.load(X + offsets, mask=mask)
y = triton.load(Y + offsets, mask=mask) y = tl.load(Y + offsets, mask=mask)
# Write back x + y # Write back x + y
z = x + y z = x + y
triton.store(Z + offsets, z) tl.store(Z + offsets, z)
# %% # %%
@@ -89,9 +90,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
x_names=['size'], # argument names to use as an x-axis for the plot x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name` x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
x_log=True, # x axis is logarithmic x_log=True, # x axis is logarithmic
y_name='provider', # argument name whose value corresponds to a different line in the plot line_arg='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton'], # possible keys for `y_name` line_vals=['torch', 'triton'], # possible values for `line_arg`
y_lines=["Torch", "Triton"], # label name for the lines line_names=["Torch", "Triton"], # label name for the lines
ylabel="GB/s", # label name for the y-axis ylabel="GB/s", # label name for the y-axis
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot. plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
args={} # values for function arguments not in `x_names` and `y_name` args={} # values for function arguments not in `x_names` and `y_name`

View File

@@ -46,28 +46,29 @@ def naive_softmax(x):
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes: # so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
import triton import triton
import triton.language as tl
@triton.jit @triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta): def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index # row index
m = triton.program_id(0) m = tl.program_id(0)
# col indices # col indices
n = triton.arange(0, meta['BLOCK']) n = tl.arange(0, meta['BLOCK'])
# the memory address of all the elements # the memory address of all the elements
# that we want to load can be computed as follows # that we want to load can be computed as follows
X = X + m * stride_xm + n X = X + m * stride_xm + n
x = triton.load(X, mask=n < N, other=-float('inf')) x = tl.load(X, mask=n < N, other=-float('inf'))
# Subtract maximum for numerical stability # Subtract maximum for numerical stability
z = x - triton.max(x, axis=0) z = x - tl.max(x, axis=0)
# Note that exponentials in Triton are fast # Note that exponentials in Triton are fast
# but approximate (i.e., think __expf in CUDA) # but approximate (i.e., think __expf in CUDA)
num = triton.exp(z) num = tl.exp(z)
denom = triton.sum(num, axis=0) denom = tl.sum(num, axis=0)
y = num / denom y = num / denom
# Write back to Y # Write back to Y
Y = Y + m * stride_ym + n Y = Y + m * stride_ym + n
triton.store(Y, y, mask=n < N) tl.store(Y, y, mask=n < N)
# %% # %%
@@ -132,9 +133,9 @@ print(torch.allclose(y_tri, y_ref))
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name` x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot line_arg='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton', 'naive'], # possible keys for `y_name` line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg`
y_lines=["Torch", "Triton", 'Naive'], # label name for the lines line_names=["Torch", "Triton", 'Naive'], # label name for the lines
ylabel="GB/s", # label name for the y-axis ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096} # values for function arguments not in `x_names` and `y_name` args={'M': 4096} # values for function arguments not in `x_names` and `y_name`

View File

@@ -115,6 +115,7 @@ You will specifically learn about:
import torch import torch
import triton import triton
import triton.language as tl
# % # %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: # :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
@@ -124,9 +125,9 @@ import triton
@triton.jit @triton.jit
def sigmoid(x): def sigmoid(x):
ret_true = 1 / (1 + triton.exp(-x)) ret_true = 1 / (1 + tl.exp(-x))
ret_false = triton.exp(x) / (1 + triton.exp(x)) ret_false = tl.exp(x) / (1 + tl.exp(x))
return triton.where(x >= 0, ret_true, ret_false) return tl.where(x >= 0, ret_true, ret_false)
@triton.jit @triton.jit
@@ -151,7 +152,7 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
BLOCK_K = META['BLOCK_K'] BLOCK_K = META['BLOCK_K']
GROUP_M = 8 GROUP_M = 8
# matrix multiplication # matrix multiplication
pid = triton.program_id(0) pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance # re-order program ID for better L2 performance
@@ -161,16 +162,16 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
pid_m = group_id * GROUP_M + (pid % group_size) pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size) pid_n = (pid % width) // (group_size)
# do matrix multiplication # do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K) rk = tl.arange(0, BLOCK_K)
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K): for k in range(K, 0, -BLOCK_K):
a = triton.load(A) a = tl.load(A)
b = triton.load(B) b = tl.load(B)
acc += triton.dot(a, b) acc += tl.dot(a, b)
A += BLOCK_K * stride_ak A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk B += BLOCK_K * stride_bk
# triton can accept arbitrary activation function # triton can accept arbitrary activation function
@@ -178,11 +179,11 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
if META['ACTIVATION']: if META['ACTIVATION']:
acc = META['ACTIVATION'](acc) acc = META['ACTIVATION'](acc)
# rematerialize rm and rn to save registers # rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm[:, None] < M) & (rn[None, :] < N) mask = (rm[:, None] < M) & (rn[None, :] < N)
triton.store(C, acc, mask=mask) tl.store(C, acc, mask=mask)
# %% # %%
@@ -238,9 +239,9 @@ print(triton.testing.allclose(c_0, c_1))
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name` x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot line_arg='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['cublas', 'triton'], # possible keys for `y_name` line_vals=['cublas', 'triton'], # possible values for `line_arg`
y_lines=["cuBLAS", "Triton"], # label name for the lines line_names=["cuBLAS", "Triton"], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot. plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={} args={}
@@ -258,4 +259,4 @@ def benchmark(M, N, K, provider):
return perf(ms), perf(max_ms), perf(min_ms) return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(print_data=True) benchmark.run(show_plots=True, print_data=True)