[STYLE] add isort and autopep8 config files and check on CI (#423)

Also fix a few more style issues from the "aggressive" mode of autopep8.
Madeleine Thompson
2022-01-07 13:11:34 -08:00
committed by GitHub
parent 9801aa7b56
commit a70acfec77
11 changed files with 102 additions and 77 deletions
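
As a rough illustration of the "aggressive" autopep8 fixes mentioned above (my own sketch, not part of the commit), the snippet below shows the two rewrites that recur throughout this diff: equality comparisons against None become identity checks, and block comments gain a space after the '#'. The run_pre_hook helper and its arguments are hypothetical, loosely modeled on the Autotuner hunk further down.

# Sketch only -- not part of this commit.
# autopep8 in aggressive mode rewrites comparisons against None into
# identity checks (E711, if I have the code right); comment spacing is
# normalized from '#comment' to '# comment'.

def run_pre_hook(pre_hook, nargs):
    # before the fix this read: if pre_hook != None:
    if pre_hook is not None:
        pre_hook(nargs)

run_pre_hook(None, {})        # no-op when no hook is given
run_pre_hook(print, "hello")  # calls the hook with its argument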

View File

@@ -30,6 +30,12 @@ jobs:
cd python
pip3 install -e '.[tests]'
- name: Check imports
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
- name: Check style
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
- name: Unit tests
run: |
cd python/test/unit

.isort.cfg (new file)
View File

@@ -0,0 +1,4 @@
[settings]
known_local_folder=triton
line_length=88
py_version=36
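
A hedged note on what this new .isort.cfg does, as I read it: known_local_folder=triton makes isort treat the triton package as a local-folder group, sorted after standard-library and third-party imports, which matches the import reshuffling in the test and tutorial hunks below; line_length=88 and py_version=36 appear to set the wrap width and the Python version isort assumes. A minimal sketch of the resulting ordering (the os/subprocess imports are placeholders):

# Import order as `isort ./python` would arrange it under this config
# (sketch; assumes torch and triton are installed).
import os          # standard library
import subprocess

import torch       # third-party

import triton      # local folder, because of known_local_folder=triton
import triton.language as tl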

View File

@@ -1,2 +1,5 @@
[metadata]
description_file = README.md
[pycodestyle]
ignore = E501,E701,E731
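
For context, and as my reading rather than anything stated in the commit: E501 is the line-length check, E701 flags a statement on the same line as its colon (e.g. a one-line if), and E731 flags assigning a lambda instead of defining a function. The hypothetical snippet below shows code each ignore tolerates; the grid = lambda meta: ... callables elsewhere in this diff are the kind of E731 usage being preserved.

# Hypothetical examples (not from the commit) of what each ignored check covers.
def add5(a, b, c, d, e):
    return a + b + c + d + e

# E501: a line longer than the default 79-character limit
total = add5(11111111, 22222222, 33333333, 44444444, 55555555) + add5(1, 2, 3, 4, 5) + add5(6, 7, 8, 9, 10)

# E701: a statement on the same line as its `if`
if total > 0: print("positive total")

# E731: assigning a lambda, the style used for Triton's `grid` callables
grid = lambda meta: [total // meta["TILE_N"]]
print(grid({"TILE_N": 8}))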

View File

@@ -94,7 +94,7 @@ class CMakeBuild(build_ext):
"-DBUILD_PYTHON_MODULE=ON",
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
#'-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
@@ -148,6 +148,8 @@ setup(
],
extras_require={
"tests": [
"autopep8",
"isort",
"numpy",
"pytest",
"scipy>=1.7.1",

View File

@@ -1,4 +1,3 @@
import triton.language as tl
import subprocess
import sys
@@ -7,6 +6,7 @@ import torch
from numpy import record
import triton
import triton.language as tl
#######################
# Utilities

View File

@@ -1,4 +1,4 @@
# version
"""isort:skip_file"""
__version__ = '2.0.0'
# TODO: torch needs to be imported first

View File

@@ -852,7 +852,7 @@ class Autotuner:
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook != None:
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

View File

@@ -293,7 +293,7 @@ class block:
dst_shape = []
curr = 0
for sl in slices:
if sl == None:
if sl is None:
dst_shape.append(1)
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr])

View File

@@ -26,9 +26,9 @@ def _sdd_kernel(
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
#------------#
#- Prologue -#
#------------#
# ------------ #
# - Prologue - #
# ------------ #
block_id = tl.program_id(1) + grid_offset
lut += block_id * 3
# offsets
@@ -39,21 +39,23 @@ def _sdd_kernel(
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + (off_z * stride_za
+ off_h * stride_ha
+ offs_am[:, None] * stride_ma
+ offs_ak[None, :] * stride_ak)
a_ptrs = A \
+ off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + (off_z * stride_zb
+ off_h * stride_hb
+ offs_bn[None, :] * stride_nb
+ offs_bk[:, None] * stride_bk)
## ---------------- ##
## Inner Loop ##
## ---------------- ##
b_ptrs = B \
+ off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K):
if EVEN_K:
@@ -66,15 +68,16 @@ def _sdd_kernel(
a_ptrs += TILE_K * stride_ak
b_ptrs += TILE_K * stride_bk
c = acc.to(C.dtype.element_ty)
## ---------------- ##
## Epilogue ##
## ---------------- ##
# ---------------- #
# Epilogue #
# ---------------- #
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + (off_z * stride_zc
+ block_id * stride_hc
+ offs_cm[:, None] * stride_mc
+ offs_cn[None, :] * stride_nc)
pc = C \
+ off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
tl.store(pc, c, mask=True)
@@ -134,9 +137,9 @@ def _dsd_kernel(
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
#------------#
#- Prologue -#
#------------#
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
@@ -168,9 +171,9 @@ def _dsd_kernel(
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
## ---------------- ##
## Inner Loop ##
## ---------------- ##
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
pinc += 2
inc_a = tl.load(pinc + 1)
@@ -192,7 +195,8 @@ def _dsd_kernel(
# initialize pointers to C
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C + off_h * stride_hc \
pc = C \
+ off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
@@ -224,24 +228,24 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
TILE_N = 128
# compute output
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
# fmt: off
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
# exit()
return c
def dsd_lut(layout, block, step, trans, device):
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = sizes.nonzero(as_tuple=True)
sizes = sizes.flatten()
segments = sizes*step
segments = sizes * step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
@@ -301,13 +305,13 @@ def dsd_lut(layout, block, step, trans, device):
A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
A_incs = A_incs.view(-1)
# create header
width = col_id.size(0)
offsets = offsets*2*div + 4*width
segments = segments*div
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
width = col_id.size(0)
offsets = offsets * 2 * div + 4 * width
segments = segments * div
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
# create increments
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
@@ -317,6 +321,8 @@ def dsd_lut(layout, block, step, trans, device):
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@triton.jit
def _dds_kernel(
A, B, C,
@@ -327,9 +333,9 @@ def _dds_kernel(
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
):
#------------#
#- Prologue -#
#------------#
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
@@ -343,31 +349,31 @@ def _dds_kernel(
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (dense)
offs_am = pid_m*TILE_M + tl.arange(0, TILE_M)
offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
start_ak = tl.load(pinc)
start_ak = tl.multiple_of(start_ak, 8)
offs_ak = start_ak + tl.arange(0, TILE_K)
ptrs_a = A + pid_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ka
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ka
# initialize pointers to B (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8)
offs_bn = tl.arange(0, TILE_N)
offs_bk = tl.arange(0, TILE_K)
ptrs_b = B + pid_z * stride_zb \
+ block_id * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
## ---------------- ##
## Inner Loop ##
## ---------------- ##
+ block_id * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(AS1, 0, -TILE_K):
a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0)
b = tl.load(ptrs_b, mask = True)
a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
b = tl.load(ptrs_b, mask=True)
acc += tl.dot(a, b)
pinc += 2
inc_a = tl.load(pinc)
@@ -377,21 +383,22 @@ def _dds_kernel(
inc_a = inc_a * stride_ka
ptrs_a += inc_a
ptrs_b += inc_b
## ---------------- ##
## Epilogue ##
## ---------------- ##
# ---------------- #
# Epilogue #
# ---------------- #
c = acc.to(C.dtype.element_ty)
# initialize pointers to C (dense)
offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_cn = column * TILE_N + tl.arange(0, TILE_N)
ptrs_c = C + off_h * stride_hc \
+ pid_z * stride_zc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
+ pid_z * stride_zc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
# write back
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
@@ -414,14 +421,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
c = out
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
# fmt: off
_dds_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
AS2, BS2, lut,
TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
return c
@@ -429,6 +435,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@@ -474,8 +482,9 @@ class _matmul(torch.autograd.Function):
)
dout = dc if ctx.has_out else None
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, dout
None, None, None, None,\
None, None, None, None, None, dout
class matmul:
@@ -499,11 +508,11 @@ class matmul:
self.da_lut, self.da_width = sdd_lut(layout, block, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
if self.mode == 'dds':
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out = None):
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,

View File

@@ -52,7 +52,7 @@ def processSassLines(fline, sline, labels):
asm = asm[:-2] + ";"
ctrl = parseCtrl(sline)
# BRA target address
if BRA_RE.match(asm) != None:
if BRA_RE.match(asm) is not None:
target = int(BRA_RE.match(asm).group(2), 16)
if target in labels:
pass
@@ -62,7 +62,7 @@ def processSassLines(fline, sline, labels):
def extract(file_path, fun):
if fun == None:
if fun is None:
sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
else:
sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
@@ -77,7 +77,7 @@ def extract(file_path, fun):
# /*0x...*/
fname_match = FNAME_RE.match(line)
# Looking for new function header (function: <name>)
while FNAME_RE.match(line) == None:
while FNAME_RE.match(line) is None:
line_idx += 1
if line_idx < len(sass_lines):
line = sass_lines[line_idx].decode()
@@ -94,7 +94,7 @@ def extract(file_path, fun):
# store sass asm in buffer and them print them (for labels)
# (ctrl, asm)
asm_buffer = []
while FLINE_RE.match(line) != None:
while FLINE_RE.match(line) is not None:
# First line (Offset ASM Encoding)
fline = sass_lines[line_idx].decode()
line_idx += 1

View File

@@ -16,10 +16,11 @@ You will learn about:
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import triton.language as tl
import triton
import torch
import triton
import triton.language as tl
@torch.jit.script
def naive_softmax(x):