From a70acfec771d5e1ad9b4df8baa50166e99954e32 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Fri, 7 Jan 2022 13:11:34 -0800 Subject: [PATCH] [STYLE] add isort and autopep8 config files and check on CI (#423) Also a fix a few more style issues from the "aggressive" mode of autopep8. --- .github/workflows/integration-tests.yml | 6 + .isort.cfg | 4 + python/setup.cfg | 3 + python/setup.py | 4 +- python/test/regression/test_performance.py | 2 +- python/triton/__init__.py | 2 +- python/triton/code_gen.py | 2 +- python/triton/language/core.py | 2 +- python/triton/ops/blocksparse/matmul.py | 141 +++++++++++---------- python/triton/tools/disasm.py | 8 +- python/tutorials/02-fused-softmax.py | 5 +- 11 files changed, 102 insertions(+), 77 deletions(-) create mode 100644 .isort.cfg diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index d99e95dc7..c01a16de1 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -30,6 +30,12 @@ jobs: cd python pip3 install -e '.[tests]' + - name: Check imports + run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )" + + - name: Check style + run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )" + - name: Unit tests run: | cd python/test/unit diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 000000000..833801cca --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,4 @@ +[settings] +known_local_folder=triton +line_length=88 +py_version=36 diff --git a/python/setup.cfg b/python/setup.cfg index 08aedd7e6..9d24c7de7 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -1,2 +1,5 @@ [metadata] description_file = README.md + +[pycodestyle] +ignore = E501,E701,E731 diff --git a/python/setup.py b/python/setup.py index 1171ad0a8..1cc2ea103 100644 --- a/python/setup.py +++ b/python/setup.py @@ -94,7 +94,7 @@ class CMakeBuild(build_ext): "-DBUILD_PYTHON_MODULE=ON", "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir, "-DLLVM_LIBRARY_DIR=" + llvm_library_dir, - #'-DPYTHON_EXECUTABLE=' + sys.executable, + # '-DPYTHON_EXECUTABLE=' + sys.executable, # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) @@ -148,6 +148,8 @@ setup( ], extras_require={ "tests": [ + "autopep8", + "isort", "numpy", "pytest", "scipy>=1.7.1", diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index f6e7ec237..84e829aa8 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -1,4 +1,3 @@ -import triton.language as tl import subprocess import sys @@ -7,6 +6,7 @@ import torch from numpy import record import triton +import triton.language as tl ####################### # Utilities diff --git a/python/triton/__init__.py b/python/triton/__init__.py index c079880e9..b4a92a8f8 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,4 +1,4 @@ -# version +"""isort:skip_file""" __version__ = '2.0.0' # TODO: torch needs to be imported first diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 439c1798e..af95bf280 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -852,7 +852,7 @@ class Autotuner: else: config = self.configs[0] self.best_config = config - if config.pre_hook != None: + if config.pre_hook is not None: 
config.pre_hook(self.nargs) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 6895c101c..4f63b33bc 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -293,7 +293,7 @@ class block: dst_shape = [] curr = 0 for sl in slices: - if sl == None: + if sl is None: dst_shape.append(1) elif sl == slice(None, None, None): dst_shape.append(src_shape[curr]) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 15e6c0523..9a04ded66 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -26,9 +26,9 @@ def _sdd_kernel( TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, BLOCK: tl.constexpr, EVEN_K: tl.constexpr ): - #------------# - #- Prologue -# - #------------# + # ------------ # + # - Prologue - # + # ------------ # block_id = tl.program_id(1) + grid_offset lut += block_id * 3 # offsets @@ -39,21 +39,23 @@ def _sdd_kernel( start_am = tl.load(lut + 1) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_ak = tl.arange(0, TILE_K) - a_ptrs = A + (off_z * stride_za - + off_h * stride_ha - + offs_am[:, None] * stride_ma - + offs_ak[None, :] * stride_ak) + a_ptrs = A \ + + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak # initialize pointers to B start_bn = tl.load(lut + 2) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bk = tl.arange(0, TILE_K) - b_ptrs = B + (off_z * stride_zb - + off_h * stride_hb - + offs_bn[None, :] * stride_nb - + offs_bk[:, None] * stride_bk) - ## ---------------- ## - ## Inner Loop ## - ## ---------------- ## + b_ptrs = B \ + + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) for k in range(K, 0, -TILE_K): if EVEN_K: @@ -66,15 +68,16 @@ def _sdd_kernel( a_ptrs += TILE_K * stride_ak b_ptrs += TILE_K * stride_bk c = acc.to(C.dtype.element_ty) - ## ---------------- ## - ## Epilogue ## - ## ---------------- ## + # ---------------- # + # Epilogue # + # ---------------- # offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK - pc = C + (off_z * stride_zc - + block_id * stride_hc - + offs_cm[:, None] * stride_mc - + offs_cn[None, :] * stride_nc) + pc = C \ + + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc tl.store(pc, c, mask=True) @@ -134,9 +137,9 @@ def _dsd_kernel( TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr ): - #------------# - #- Prologue -# - #------------# + # ------------ # + # - Prologue - # + # ------------ # pid_m = tl.program_id(0) pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) @@ -168,9 +171,9 @@ def _dsd_kernel( + off_h * stride_hb \ + offs_bn[None, :] * stride_bn \ + offs_bk[:, None] * stride_bk - ## ---------------- ## - ## Inner Loop ## - ## ---------------- ## + # ---------------- # + # Inner Loop # + # ---------------- # acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) pinc += 2 inc_a = tl.load(pinc + 1) @@ -192,7 +195,8 @@ def _dsd_kernel( # initialize pointers to C offs_cm = column * TILE_M + tl.arange(0, TILE_M) offs_cn = pid_m * TILE_N + tl.arange(0, 
TILE_N) - pc = C + off_h * stride_hc \ + pc = C \ + + off_h * stride_hc \ + pidz * stride_zc \ + offs_cm[:, None] * stride_cm \ + offs_cn[None, :] * stride_cn @@ -224,24 +228,24 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N TILE_N = 128 # compute output grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0] - # fmt: off _dsd_kernel[grid]( a, b, c, a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), BS3, AS1, lut, - TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, num_warps=4, GROUP_SIZE_M=4, ) # exit() return c + def dsd_lut(layout, block, step, trans, device): sizes = torch.sum(layout, 2 if trans else 1) head_id, col_id = sizes.nonzero(as_tuple=True) sizes = sizes.flatten() - segments = sizes*step + segments = sizes * step # pointer increments if trans: nnz = layout.nonzero(as_tuple=False) @@ -301,13 +305,13 @@ def dsd_lut(layout, block, step, trans, device): A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] A_incs = A_incs.view(-1) # create header - width = col_id.size(0) - offsets = offsets*2*div + 4*width - segments = segments*div - header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + width = col_id.size(0) + offsets = offsets * 2 * div + 4 * width + segments = segments * div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() # create increments - incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() - incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) # create lut lut = torch.cat((header, incs)) lut = lut.type(torch.int32).to(device) @@ -317,6 +321,8 @@ def dsd_lut(layout, block, step, trans, device): # ----------------------------- # Dense = Dense x Sparse (DDS) # ----------------------------- + + @triton.jit def _dds_kernel( A, B, C, @@ -327,9 +333,9 @@ def _dds_kernel( TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, ): - #------------# - #- Prologue -# - #------------# + # ------------ # + # - Prologue - # + # ------------ # pid_m = tl.program_id(0) pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) @@ -343,31 +349,31 @@ def _dds_kernel( off_h = tl.load(header + 3) pinc = lut + offset # initialize pointers to A (dense) - offs_am = pid_m*TILE_M + tl.arange(0, TILE_M) + offs_am = pid_m * TILE_M + tl.arange(0, TILE_M) offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M) start_ak = tl.load(pinc) start_ak = tl.multiple_of(start_ak, 8) offs_ak = start_ak + tl.arange(0, TILE_K) ptrs_a = A + pid_z * stride_za \ - + off_h * stride_ha \ - + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ka + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ka # initialize pointers to B (sparse) block_id = tl.load(pinc + 1) block_id = tl.multiple_of(block_id, 8) offs_bn = tl.arange(0, TILE_N) offs_bk = tl.arange(0, TILE_K) ptrs_b = B + pid_z * stride_zb \ - + block_id * stride_hb \ - + offs_bn[None, :] * stride_bn \ - + 
offs_bk[:, None] * stride_bk - ## ---------------- ## - ## Inner Loop ## - ## ---------------- ## + + block_id * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) for k in range(AS1, 0, -TILE_K): - a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0) - b = tl.load(ptrs_b, mask = True) + a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0) + b = tl.load(ptrs_b, mask=True) acc += tl.dot(a, b) pinc += 2 inc_a = tl.load(pinc) @@ -377,21 +383,22 @@ def _dds_kernel( inc_a = inc_a * stride_ka ptrs_a += inc_a ptrs_b += inc_b - ## ---------------- ## - ## Epilogue ## - ## ---------------- ## + # ---------------- # + # Epilogue # + # ---------------- # c = acc.to(C.dtype.element_ty) # initialize pointers to C (dense) offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M) offs_cn = column * TILE_N + tl.arange(0, TILE_N) ptrs_c = C + off_h * stride_hc \ - + pid_z * stride_zc \ - + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc + + pid_z * stride_zc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc # write back - tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) + tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0) -def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): if a.stride(2) != 1 and a.stride(3) != 1: a = a.contiguous() if b.stride(2) != 1 and b.stride(3) != 1: @@ -414,14 +421,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = c = out TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block] grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0] - # fmt: off _dds_kernel[grid]( a, b, c, a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), AS2, BS2, lut, - TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4, num_warps=4, GROUP_SIZE_M=4, ) return c @@ -429,6 +435,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = ############## # MAIN API # ############## + + class _matmul(torch.autograd.Function): fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} @@ -474,8 +482,9 @@ class _matmul(torch.autograd.Function): ) dout = dc if ctx.has_out else None return da, db, None, None, None,\ - None, None, None, None,\ - None, None, None, None, None, dout + None, None, None, None,\ + None, None, None, None, None, dout + class matmul: @@ -499,11 +508,11 @@ class matmul: self.da_lut, self.da_width = sdd_lut(layout, block, device) self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) if self.mode == 'dds': - self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) self.db_lut, self.db_width = sdd_lut(layout, block, device) - def __call__(self, a, b, out = None): + def __call__(self, a, b, out=None): c = _matmul.apply( a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, self.c_lut, 
self.c_width, diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 3b443c690..b030e72ec 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -52,7 +52,7 @@ def processSassLines(fline, sline, labels): asm = asm[:-2] + ";" ctrl = parseCtrl(sline) # BRA target address - if BRA_RE.match(asm) != None: + if BRA_RE.match(asm) is not None: target = int(BRA_RE.match(asm).group(2), 16) if target in labels: pass @@ -62,7 +62,7 @@ def processSassLines(fline, sline, labels): def extract(file_path, fun): - if fun == None: + if fun is None: sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path]) else: sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path]) @@ -77,7 +77,7 @@ def extract(file_path, fun): # /*0x...*/ fname_match = FNAME_RE.match(line) # Looking for new function header (function: ) - while FNAME_RE.match(line) == None: + while FNAME_RE.match(line) is None: line_idx += 1 if line_idx < len(sass_lines): line = sass_lines[line_idx].decode() @@ -94,7 +94,7 @@ def extract(file_path, fun): # store sass asm in buffer and them print them (for labels) # (ctrl, asm) asm_buffer = [] - while FLINE_RE.match(line) != None: + while FLINE_RE.match(line) is not None: # First line (Offset ASM Encoding) fline = sass_lines[line_idx].decode() line_idx += 1 diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index e5559ca7f..7af24e18d 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -16,10 +16,11 @@ You will learn about: # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. # Let us consider instead the case of a simple (numerically stabilized) softmax operation: -import triton.language as tl -import triton import torch +import triton +import triton.language as tl + @torch.jit.script def naive_softmax(x):
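
Note: the new checks can be reproduced locally from the repository root using the same commands the added CI steps and their error hints invoke (a minimal sketch, assuming a working Python 3 environment; the subshell install mirrors the existing "Install Triton" step):

    # install autopep8 and isort together with the other test dependencies
    (cd python && pip3 install -e '.[tests]')
    # the two checks added to CI
    isort -c ./python
    autopep8 -a -r -d --exit-code ./python
    # auto-fix commands suggested by the CI error messages
    isort ./python
    autopep8 -a -r -i ./python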