[STYLE] add isort and autopep8 config files and check on CI (#423)

Also fixes a few more style issues flagged by autopep8's "aggressive" mode.
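For contributors, both checks can be reproduced locally with the commands the CI error messages suggest: `isort ./python` rewrites import order in place, and `autopep8 -a -r -i ./python` applies the aggressive fixes recursively. CI runs the read-only variants (`isort -c`, `autopep8 -d --exit-code`) and fails the build if either would make changes.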
Madeleine Thompson
2022-01-07 13:11:34 -08:00
committed by GitHub
parent 9801aa7b56
commit a70acfec77
11 changed files with 102 additions and 77 deletions

View File

@@ -30,6 +30,12 @@ jobs:
          cd python
          pip3 install -e '.[tests]'
+    - name: Check imports
+      run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
+    - name: Check style
+      run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
     - name: Unit tests
       run: |
          cd python/test/unit
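The `cmd || ( echo '::error ...' ; exit 1 )` construct above is a small shell idiom: if the check fails, the `echo` emits a GitHub Actions `::error` workflow command, which surfaces the message as an annotation on the PR, and `exit 1` then fails the step.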

.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
+[settings]
+known_local_folder=triton
+line_length=88
+py_version=36
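For context on what these settings do: `known_local_folder=triton` makes isort place `triton` imports in their own local group after third-party packages, `line_length=88` sets the wrap width, and `py_version=36` targets Python 3.6 syntax. A hypothetical module sorted under this config would look like:

    # stdlib first, then third-party, then the local "triton" folder,
    # each group separated by one blank line (hypothetical example)
    import os
    import sys

    import torch

    import triton
    import triton.language as tl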

View File

@@ -1,2 +1,5 @@
 [metadata]
 description_file = README.md
+
+[pycodestyle]
+ignore = E501,E701,E731
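For reference, the suppressed codes are E501 (line too long), E701 (statement on the same line as a colon), and E731 (assignment of a lambda expression). The codebase relies on all three; a sketch of the patterns they permit, using names from later in this diff:

    # E731: lambda assigned to a name -- used for launch grids in this diff
    grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
    # E701: statement on the same line as its condition
    if EVEN_K: a = tl.load(a_ptrs)
    # E501 simply permits lines longer than 79 characters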

View File

@@ -94,7 +94,7 @@ class CMakeBuild(build_ext):
             "-DBUILD_PYTHON_MODULE=ON",
             "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
             "-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
-            #'-DPYTHON_EXECUTABLE=' + sys.executable,
+            # '-DPYTHON_EXECUTABLE=' + sys.executable,
             # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
             "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
             "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
@@ -148,6 +148,8 @@ setup(
     ],
     extras_require={
         "tests": [
+            "autopep8",
+            "isort",
             "numpy",
             "pytest",
             "scipy>=1.7.1",

View File

@@ -1,4 +1,3 @@
-import triton.language as tl
 import subprocess
 import sys
@@ -7,6 +6,7 @@ import torch
 from numpy import record
 
 import triton
+import triton.language as tl
 
 #######################
 # Utilities

View File

@@ -1,4 +1,4 @@
-# version
+"""isort:skip_file"""
 __version__ = '2.0.0'
 # TODO: torch needs to be imported first
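`"""isort:skip_file"""` is an isort action comment: placed in the module docstring, it tells isort to leave the whole file's import order alone, which is needed here because `torch` must be imported before the rest of `triton` (per the TODO). The per-line form also exists:

    import torch  # isort:skip (pins a single import in place instead)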

View File

@@ -852,7 +852,7 @@ class Autotuner:
         else:
             config = self.configs[0]
         self.best_config = config
-        if config.pre_hook != None:
+        if config.pre_hook is not None:
             config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
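The `!= None` to `is not None` fixes in this commit (pycodestyle's E711, applied by autopep8 in aggressive mode) are not purely cosmetic: `==`/`!=` can be overloaded, so comparing against `None` may not return a plain bool, while `is` always performs an identity check. A quick illustration with numpy, which is already a test dependency:

    import numpy as np

    a = np.zeros(2)
    print(a == None)   # array([False, False]) -- an array, not a bool
    print(a is None)   # False -- identity test, always a plain bool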

View File

@@ -293,7 +293,7 @@ class block:
         dst_shape = []
         curr = 0
         for sl in slices:
-            if sl == None:
+            if sl is None:
                 dst_shape.append(1)
             elif sl == slice(None, None, None):
                 dst_shape.append(src_shape[curr])
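Note that only the `None` comparison becomes `is`; the `slice` comparison on the next line must stay `==`, because slice objects are compared by value, not identity:

    print(slice(None) == slice(None, None, None))  # True  (equal by value)
    print(slice(None) is slice(None, None, None))  # False (distinct objects)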

View File

@@ -26,9 +26,9 @@ def _sdd_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     BLOCK: tl.constexpr, EVEN_K: tl.constexpr
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     block_id = tl.program_id(1) + grid_offset
     lut += block_id * 3
     # offsets
@@ -39,21 +39,23 @@ def _sdd_kernel(
     start_am = tl.load(lut + 1)
     offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
     offs_ak = tl.arange(0, TILE_K)
-    a_ptrs = A + (off_z * stride_za
-                  + off_h * stride_ha
-                  + offs_am[:, None] * stride_ma
-                  + offs_ak[None, :] * stride_ak)
+    a_ptrs = A \
+        + off_z * stride_za \
+        + off_h * stride_ha \
+        + offs_am[:, None] * stride_ma \
+        + offs_ak[None, :] * stride_ak
     # initialize pointers to B
     start_bn = tl.load(lut + 2)
     offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
     offs_bk = tl.arange(0, TILE_K)
-    b_ptrs = B + (off_z * stride_zb
-                  + off_h * stride_hb
-                  + offs_bn[None, :] * stride_nb
-                  + offs_bk[:, None] * stride_bk)
-    ## ---------------- ##
-    ##    Inner Loop    ##
-    ## ---------------- ##
+    b_ptrs = B \
+        + off_z * stride_zb \
+        + off_h * stride_hb \
+        + offs_bn[None, :] * stride_nb \
+        + offs_bk[:, None] * stride_bk
+    # ---------------- #
+    #    Inner Loop    #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     for k in range(K, 0, -TILE_K):
         if EVEN_K:
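Both versions of the pointer setup above compute the same thing: a 2-D grid of addresses built by broadcasting a column of row offsets against a row of column offsets, yielding one flat offset per tile element. A minimal numpy sketch of that computation (illustrative sizes; Triton does this on GPU pointer values):

    import numpy as np

    TILE_M, TILE_K = 4, 8
    stride_ma, stride_ak = 8, 1          # hypothetical row/column strides
    offs_am = np.arange(TILE_M)
    offs_ak = np.arange(TILE_K)
    # (TILE_M, 1) + (1, TILE_K) broadcasts to a (TILE_M, TILE_K) offset grid
    offsets = offs_am[:, None] * stride_ma + offs_ak[None, :] * stride_ak
    print(offsets.shape)                 # (4, 8)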
@@ -66,15 +68,16 @@ def _sdd_kernel(
         a_ptrs += TILE_K * stride_ak
         b_ptrs += TILE_K * stride_bk
     c = acc.to(C.dtype.element_ty)
-    ## ---------------- ##
-    ##     Epilogue     ##
-    ## ---------------- ##
+    # ---------------- #
+    #     Epilogue     #
+    # ---------------- #
     offs_cm = tl.arange(0, TILE_M) % BLOCK
     offs_cn = tl.arange(0, TILE_N) % BLOCK
-    pc = C + (off_z * stride_zc
-              + block_id * stride_hc
-              + offs_cm[:, None] * stride_mc
-              + offs_cn[None, :] * stride_nc)
+    pc = C \
+        + off_z * stride_zc \
+        + block_id * stride_hc \
+        + offs_cm[:, None] * stride_mc \
+        + offs_cn[None, :] * stride_nc
     tl.store(pc, c, mask=True)
@@ -134,9 +137,9 @@ def _dsd_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
     num_pid_m = tl.num_programs(0)
@@ -168,9 +171,9 @@ def _dsd_kernel(
         + off_h * stride_hb \
         + offs_bn[None, :] * stride_bn \
         + offs_bk[:, None] * stride_bk
-    ## ---------------- ##
-    ##    Inner Loop    ##
-    ## ---------------- ##
+    # ---------------- #
+    #    Inner Loop    #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     pinc += 2
     inc_a = tl.load(pinc + 1)
@@ -192,7 +195,8 @@ def _dsd_kernel(
     # initialize pointers to C
     offs_cm = column * TILE_M + tl.arange(0, TILE_M)
     offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
-    pc = C + off_h * stride_hc \
+    pc = C \
+        + off_h * stride_hc \
         + pidz * stride_zc \
         + offs_cm[:, None] * stride_cm \
         + offs_cn[None, :] * stride_cn
@@ -224,24 +228,24 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
     TILE_N = 128
     # compute output
     grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
-    # fmt: off
     _dsd_kernel[grid](
         a, b, c,
         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         BS3, AS1, lut,
-        TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
         num_warps=4, GROUP_SIZE_M=4,
     )
     # exit()
     return c
 
+
 def dsd_lut(layout, block, step, trans, device):
     sizes = torch.sum(layout, 2 if trans else 1)
     head_id, col_id = sizes.nonzero(as_tuple=True)
     sizes = sizes.flatten()
-    segments = sizes*step
+    segments = sizes * step
     # pointer increments
     if trans:
         nnz = layout.nonzero(as_tuple=False)
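The `grid` lambda above is evaluated by Triton with the kernel's meta-parameters, so the launch grid adapts to the tile size chosen at the call site; `triton.cdiv` is ceiling division. A self-contained sketch of what it evaluates to (hypothetical sizes):

    def cdiv(a, b):
        # ceiling division, as triton.cdiv computes it
        return (a + b - 1) // b

    BS3, width, BS0 = 1000, 12, 4        # hypothetical problem sizes
    grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
    print(grid({'TILE_N': 128}))         # [8, 12, 4] -> one program per tile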
@@ -302,8 +306,8 @@ def dsd_lut(layout, block, step, trans, device):
     A_incs = A_incs.view(-1)
     # create header
     width = col_id.size(0)
-    offsets = offsets*2*div + 4*width
-    segments = segments*div
+    offsets = offsets * 2 * div + 4 * width
+    segments = segments * div
     header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
     # create increments
     incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
@@ -317,6 +321,8 @@ def dsd_lut(layout, block, step, trans, device):
 # -----------------------------
 #  Dense = Dense x Sparse (DDS)
 # -----------------------------
+
+
 @triton.jit
 def _dds_kernel(
     A, B, C,
@@ -327,9 +333,9 @@ def _dds_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
     num_pid_m = tl.num_programs(0)
@@ -343,7 +349,7 @@ def _dds_kernel(
     off_h = tl.load(header + 3)
     pinc = lut + offset
     # initialize pointers to A (dense)
-    offs_am = pid_m*TILE_M + tl.arange(0, TILE_M)
+    offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
     offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
     start_ak = tl.load(pinc)
     start_ak = tl.multiple_of(start_ak, 8)
@@ -361,13 +367,13 @@ def _dds_kernel(
         + block_id * stride_hb \
         + offs_bn[None, :] * stride_bn \
         + offs_bk[:, None] * stride_bk
-    ## ---------------- ##
-    ##    Inner Loop    ##
-    ## ---------------- ##
+    # ---------------- #
+    #    Inner Loop    #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     for k in range(AS1, 0, -TILE_K):
-        a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0)
-        b = tl.load(ptrs_b, mask = True)
+        a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
+        b = tl.load(ptrs_b, mask=True)
         acc += tl.dot(a, b)
         pinc += 2
         inc_a = tl.load(pinc)
@@ -377,9 +383,9 @@ def _dds_kernel(
         inc_a = inc_a * stride_ka
         ptrs_a += inc_a
         ptrs_b += inc_b
-    ## ---------------- ##
-    ##     Epilogue     ##
-    ## ---------------- ##
+    # ---------------- #
+    #     Epilogue     #
+    # ---------------- #
     c = acc.to(C.dtype.element_ty)
     # initialize pointers to C (dense)
     offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
@@ -389,9 +395,10 @@ def _dds_kernel(
         + offs_cm[:, None] * stride_mc \
         + offs_cn[None, :] * stride_nc
     # write back
-    tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
+    tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
 
-def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
+
+def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
     if a.stride(2) != 1 and a.stride(3) != 1:
         a = a.contiguous()
     if b.stride(2) != 1 and b.stride(3) != 1:
@@ -414,14 +421,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
         c = out
     TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
     grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
-    # fmt: off
     _dds_kernel[grid](
         a, b, c,
         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         AS2, BS2, lut,
-        TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
         num_warps=4, GROUP_SIZE_M=4,
     )
     return c
@@ -429,6 +435,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
 ##############
 #  MAIN API  #
 ##############
+
+
 class _matmul(torch.autograd.Function):
     fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@@ -477,6 +485,7 @@ class _matmul(torch.autograd.Function):
             None, None, None, None,\
             None, None, None, None, None, dout
 
+
 class matmul:
 
     def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
@@ -503,7 +512,7 @@ class matmul:
         self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
         self.db_lut, self.db_width = sdd_lut(layout, block, device)
 
-    def __call__(self, a, b, out = None):
+    def __call__(self, a, b, out=None):
         c = _matmul.apply(
             a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
             self.c_lut, self.c_width,

View File

@@ -52,7 +52,7 @@ def processSassLines(fline, sline, labels):
         asm = asm[:-2] + ";"
     ctrl = parseCtrl(sline)
     # BRA target address
-    if BRA_RE.match(asm) != None:
+    if BRA_RE.match(asm) is not None:
         target = int(BRA_RE.match(asm).group(2), 16)
         if target in labels:
             pass
@@ -62,7 +62,7 @@ def processSassLines(fline, sline, labels):
 def extract(file_path, fun):
-    if fun == None:
+    if fun is None:
         sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
     else:
         sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
@@ -77,7 +77,7 @@ def extract(file_path, fun):
     # /*0x...*/
     fname_match = FNAME_RE.match(line)
     # Looking for new function header (function: <name>)
-    while FNAME_RE.match(line) == None:
+    while FNAME_RE.match(line) is None:
         line_idx += 1
         if line_idx < len(sass_lines):
             line = sass_lines[line_idx].decode()
@@ -94,7 +94,7 @@ def extract(file_path, fun):
     # store sass asm in buffer and them print them (for labels)
     # (ctrl, asm)
     asm_buffer = []
-    while FLINE_RE.match(line) != None:
+    while FLINE_RE.match(line) is not None:
         # First line (Offset ASM Encoding)
         fline = sass_lines[line_idx].decode()
         line_idx += 1
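The same E711 cleanup applies throughout this file: `re.match` returns either a `Match` object or `None`, and `is not None` is the explicit way to test that (a `Match` is also truthy, so `if m:` would work too). A small self-contained example using a simplified, hypothetical stand-in for `BRA_RE`:

    import re

    BRA_DEMO = re.compile(r'\s*(@!?P\d )?BRA (0x[0-9a-fA-F]+);')  # hypothetical
    m = BRA_DEMO.match('        BRA 0x70;')
    if m is not None:
        print(int(m.group(2), 16))   # 112, the branch target address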

View File

@@ -16,10 +16,11 @@ You will learn about:
 # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
 # Let us consider instead the case of a simple (numerically stabilized) softmax operation:
 
-import triton.language as tl
-import triton
 import torch
+
+import triton
+import triton.language as tl
 
 @torch.jit.script
 def naive_softmax(x):