[STYLE] add isort and autopep8 config files and check on CI (#423)

Also fixes a few more style issues flagged by autopep8's "aggressive" mode.
Madeleine Thompson
2022-01-07 13:11:34 -08:00
committed by GitHub
parent 9801aa7b56
commit a70acfec77
11 changed files with 102 additions and 77 deletions

.github/workflows (CI workflow)

@@ -30,6 +30,12 @@ jobs:
         cd python
         pip3 install -e '.[tests]'
+    - name: Check imports
+      run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
+    - name: Check style
+      run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
     - name: Unit tests
       run: |
         cd python/test/unit

.isort.cfg (new file, +4)

@@ -0,0 +1,4 @@
+[settings]
+known_local_folder=triton
+line_length=88
+py_version=36
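For context, a minimal sketch of the import ordering this configuration enforces (module names other than `triton` are illustrative): standard-library imports first, then third-party packages, then `triton` last as the local folder.

```python
# Ordering isort produces under this config (illustrative modules).
import os          # standard library
import subprocess  # standard library

import torch       # third-party

import triton      # local folder, per known_local_folder=triton
import triton.language as tl
```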

setup.cfg

@@ -1,2 +1,5 @@
[metadata] [metadata]
description_file = README.md description_file = README.md
[pycodestyle]
ignore = E501,E701,E731
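The ignored codes are pycodestyle's E501 (line too long), E701 (multiple statements on one line after a colon), and E731 (assignment of a lambda expression). A minimal sketch of code each ignore tolerates; ignoring E731 in particular preserves the `grid = lambda meta: ...` launch-grid idiom used later in this diff.

```python
# E731: a lambda bound to a name (mirrors the launch-grid idiom in this repo).
cdiv = lambda x, y: (x + y - 1) // y

# E701: a compound statement on one line.
x = 10
if x > 0: x -= 1

# E501: a line longer than the default 79-character limit is also tolerated, like this comment.
print(cdiv(x, 4))
```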

python/setup.py

@@ -94,7 +94,7 @@ class CMakeBuild(build_ext):
             "-DBUILD_PYTHON_MODULE=ON",
             "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
             "-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
-            #'-DPYTHON_EXECUTABLE=' + sys.executable,
+            # '-DPYTHON_EXECUTABLE=' + sys.executable,
             # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
             "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
             "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
@@ -148,6 +148,8 @@ setup(
         ],
         extras_require={
             "tests": [
+                "autopep8",
+                "isort",
                 "numpy",
                 "pytest",
                 "scipy>=1.7.1",

python/test (unit test file)

@@ -1,4 +1,3 @@
-import triton.language as tl
 import subprocess
 import sys
@@ -7,6 +6,7 @@ import torch
 from numpy import record
 import triton
+import triton.language as tl
 #######################
 # Utilities

python/triton/__init__.py

@@ -1,4 +1,4 @@
-# version
+"""isort:skip_file"""
 __version__ = '2.0.0'
 # TODO: torch needs to be imported first

python/triton (autotuner)

@@ -852,7 +852,7 @@ class Autotuner:
         else:
             config = self.configs[0]
         self.best_config = config
-        if config.pre_hook != None:
+        if config.pre_hook is not None:
             config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
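This is pycodestyle's E711: `!= None` goes through `__eq__`, which a class may override arbitrarily, while `is not None` checks identity and cannot be fooled. A minimal sketch:

```python
class AlwaysEqual:
    # Pathological __eq__ that claims equality with everything, None included.
    def __eq__(self, other):
        return True

hook = AlwaysEqual()
print(hook != None)      # False -- __eq__ lies, so the old check misfires
print(hook is not None)  # True  -- the identity comparison is reliable
```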

python/triton/language (block)

@@ -293,7 +293,7 @@ class block:
         dst_shape = []
         curr = 0
         for sl in slices:
-            if sl == None:
+            if sl is None:
                 dst_shape.append(1)
             elif sl == slice(None, None, None):
                 dst_shape.append(src_shape[curr])

python/triton/ops/blocksparse (matmul)

@@ -26,9 +26,9 @@ def _sdd_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     BLOCK: tl.constexpr, EVEN_K: tl.constexpr
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     block_id = tl.program_id(1) + grid_offset
     lut += block_id * 3
     # offsets
@@ -39,21 +39,23 @@ def _sdd_kernel(
     start_am = tl.load(lut + 1)
     offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
     offs_ak = tl.arange(0, TILE_K)
-    a_ptrs = A + (off_z * stride_za
-                  + off_h * stride_ha
-                  + offs_am[:, None] * stride_ma
-                  + offs_ak[None, :] * stride_ak)
+    a_ptrs = A \
+        + off_z * stride_za \
+        + off_h * stride_ha \
+        + offs_am[:, None] * stride_ma \
+        + offs_ak[None, :] * stride_ak
     # initialize pointers to B
     start_bn = tl.load(lut + 2)
     offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
     offs_bk = tl.arange(0, TILE_K)
-    b_ptrs = B + (off_z * stride_zb
-                  + off_h * stride_hb
-                  + offs_bn[None, :] * stride_nb
-                  + offs_bk[:, None] * stride_bk)
-    ## ---------------- ##
-    ##    Inner Loop    ##
-    ## ---------------- ##
+    b_ptrs = B \
+        + off_z * stride_zb \
+        + off_h * stride_hb \
+        + offs_bn[None, :] * stride_nb \
+        + offs_bk[:, None] * stride_bk
+    # ---------------- #
+    #    Inner Loop    #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     for k in range(K, 0, -TILE_K):
         if EVEN_K:
@@ -66,15 +68,16 @@ def _sdd_kernel(
         a_ptrs += TILE_K * stride_ak
         b_ptrs += TILE_K * stride_bk
     c = acc.to(C.dtype.element_ty)
-    ## ---------------- ##
-    ##     Epilogue     ##
-    ## ---------------- ##
+    # ---------------- #
+    #     Epilogue     #
+    # ---------------- #
     offs_cm = tl.arange(0, TILE_M) % BLOCK
     offs_cn = tl.arange(0, TILE_N) % BLOCK
-    pc = C + (off_z * stride_zc
-              + block_id * stride_hc
-              + offs_cm[:, None] * stride_mc
-              + offs_cn[None, :] * stride_nc)
+    pc = C \
+        + off_z * stride_zc \
+        + block_id * stride_hc \
+        + offs_cm[:, None] * stride_mc \
+        + offs_cn[None, :] * stride_nc
     tl.store(pc, c, mask=True)
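The pointer arithmetic above was reflowed from parenthesized continuations to backslash continuations, matching the style already used by `_dsd_kernel` and `_dds_kernel` below. A minimal sketch of the two equivalent forms (names are illustrative):

```python
stride_z, stride_h = 64, 8
off_z, off_h = 2, 3

# Parenthesized continuation (the old form): the parentheses group the
# expression, so no continuation characters are needed.
offset_paren = (off_z * stride_z
                + off_h * stride_h)

# Backslash continuation (the new form): each wrapped line ends in `\`,
# consistent with the other kernels in this file.
offset_bs = off_z * stride_z \
    + off_h * stride_h

assert offset_paren == offset_bs
```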
@@ -134,9 +137,9 @@ def _dsd_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
     num_pid_m = tl.num_programs(0)
@@ -168,9 +171,9 @@ def _dsd_kernel(
         + off_h * stride_hb \
         + offs_bn[None, :] * stride_bn \
         + offs_bk[:, None] * stride_bk
-    ## ---------------- ##
-    ##    Inner Loop    ##
-    ## ---------------- ##
+    # ---------------- #
+    #    Inner Loop    #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     pinc += 2
     inc_a = tl.load(pinc + 1)
@@ -192,7 +195,8 @@ def _dsd_kernel(
     # initialize pointers to C
     offs_cm = column * TILE_M + tl.arange(0, TILE_M)
     offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
-    pc = C + off_h * stride_hc \
+    pc = C \
+        + off_h * stride_hc \
         + pidz * stride_zc \
         + offs_cm[:, None] * stride_cm \
         + offs_cn[None, :] * stride_cn
@@ -224,24 +228,24 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
     TILE_N = 128
     # compute output
     grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
-    # fmt: off
     _dsd_kernel[grid](
         a, b, c,
         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         BS3, AS1, lut,
-        TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
         num_warps=4, GROUP_SIZE_M=4,
     )
     # exit()
     return c
+
 
 def dsd_lut(layout, block, step, trans, device):
     sizes = torch.sum(layout, 2 if trans else 1)
     head_id, col_id = sizes.nonzero(as_tuple=True)
     sizes = sizes.flatten()
-    segments = sizes*step
+    segments = sizes * step
     # pointer increments
     if trans:
         nnz = layout.nonzero(as_tuple=False)
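This hunk drops the `# fmt: off` comment, a formatter directive presumably left over from an earlier tool, and normalizes keyword-argument spacing: pycodestyle's E251 forbids spaces around `=` in keyword arguments. A minimal sketch (illustrative function):

```python
def launch(tile_m=64, tile_k=32):
    return tile_m * tile_k

launch(tile_m = 128, tile_k = 32)  # E251: spaces around keyword '='
launch(tile_m=128, tile_k=32)      # clean under pycodestyle
```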
@@ -301,13 +305,13 @@ def dsd_lut(layout, block, step, trans, device):
     A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
     A_incs = A_incs.view(-1)
     # create header
     width = col_id.size(0)
-    offsets = offsets*2*div + 4*width
-    segments = segments*div
+    offsets = offsets * 2 * div + 4 * width
+    segments = segments * div
     header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
     # create increments
     incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
     incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
     # create lut
     lut = torch.cat((header, incs))
     lut = lut.type(torch.int32).to(device)
@@ -317,6 +321,8 @@ def dsd_lut(layout, block, step, trans, device):
 # -----------------------------
 # Dense = Dense x Sparse (DDS)
 # -----------------------------
+
+
 @triton.jit
 def _dds_kernel(
     A, B, C,
@@ -327,9 +333,9 @@ def _dds_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
     num_pid_m = tl.num_programs(0)
@@ -343,31 +349,31 @@ def _dds_kernel(
     off_h = tl.load(header + 3)
     pinc = lut + offset
     # initialize pointers to A (dense)
-    offs_am = pid_m*TILE_M + tl.arange(0, TILE_M)
+    offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
     offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
     start_ak = tl.load(pinc)
     start_ak = tl.multiple_of(start_ak, 8)
     offs_ak = start_ak + tl.arange(0, TILE_K)
     ptrs_a = A + pid_z * stride_za \
         + off_h * stride_ha \
         + offs_am[:, None] * stride_ma \
         + offs_ak[None, :] * stride_ka
     # initialize pointers to B (sparse)
     block_id = tl.load(pinc + 1)
     block_id = tl.multiple_of(block_id, 8)
     offs_bn = tl.arange(0, TILE_N)
     offs_bk = tl.arange(0, TILE_K)
     ptrs_b = B + pid_z * stride_zb \
         + block_id * stride_hb \
         + offs_bn[None, :] * stride_bn \
         + offs_bk[:, None] * stride_bk
-    ## ---------------- ##
-    ##    Inner Loop    ##
-    ## ---------------- ##
+    # ---------------- #
+    #    Inner Loop    #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     for k in range(AS1, 0, -TILE_K):
-        a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0)
-        b = tl.load(ptrs_b, mask = True)
+        a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
+        b = tl.load(ptrs_b, mask=True)
         acc += tl.dot(a, b)
         pinc += 2
         inc_a = tl.load(pinc)
@@ -377,21 +383,22 @@ def _dds_kernel(
         inc_a = inc_a * stride_ka
         ptrs_a += inc_a
         ptrs_b += inc_b
-    ## ---------------- ##
-    ##     Epilogue     ##
-    ## ---------------- ##
+    # ---------------- #
+    #     Epilogue     #
+    # ---------------- #
     c = acc.to(C.dtype.element_ty)
     # initialize pointers to C (dense)
     offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
     offs_cn = column * TILE_N + tl.arange(0, TILE_N)
     ptrs_c = C + off_h * stride_hc \
         + pid_z * stride_zc \
         + offs_cm[:, None] * stride_mc \
         + offs_cn[None, :] * stride_nc
     # write back
-    tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
+    tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
+
 
-def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
+def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
     if a.stride(2) != 1 and a.stride(3) != 1:
         a = a.contiguous()
     if b.stride(2) != 1 and b.stride(3) != 1:
@@ -414,14 +421,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
         c = out
     TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
     grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
-    # fmt: off
     _dds_kernel[grid](
         a, b, c,
         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         AS2, BS2, lut,
-        TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
         num_warps=4, GROUP_SIZE_M=4,
     )
     return c
@@ -429,6 +435,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
 ##############
 # MAIN API   #
 ##############
+
+
 class _matmul(torch.autograd.Function):
     fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@@ -474,8 +482,9 @@ class _matmul(torch.autograd.Function):
         )
         dout = dc if ctx.has_out else None
         return da, db, None, None, None,\
             None, None, None, None,\
             None, None, None, None, None, dout
+
 
 class matmul:
@@ -499,11 +508,11 @@ class matmul:
             self.da_lut, self.da_width = sdd_lut(layout, block, device)
             self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
         if self.mode == 'dds':
             self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
             self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
             self.db_lut, self.db_width = sdd_lut(layout, block, device)

-    def __call__(self, a, b, out = None):
+    def __call__(self, a, b, out=None):
         c = _matmul.apply(
             a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
             self.c_lut, self.c_width,

python/triton/tools (SASS disassembly helper)

@@ -52,7 +52,7 @@ def processSassLines(fline, sline, labels):
     asm = asm[:-2] + ";"
     ctrl = parseCtrl(sline)
     # BRA target address
-    if BRA_RE.match(asm) != None:
+    if BRA_RE.match(asm) is not None:
         target = int(BRA_RE.match(asm).group(2), 16)
         if target in labels:
             pass
@@ -62,7 +62,7 @@ def processSassLines(fline, sline, labels):
 def extract(file_path, fun):
-    if fun == None:
+    if fun is None:
         sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
     else:
         sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
@@ -77,7 +77,7 @@ def extract(file_path, fun):
     # /*0x...*/
     fname_match = FNAME_RE.match(line)
     # Looking for new function header (function: <name>)
-    while FNAME_RE.match(line) == None:
+    while FNAME_RE.match(line) is None:
         line_idx += 1
         if line_idx < len(sass_lines):
             line = sass_lines[line_idx].decode()
@@ -94,7 +94,7 @@ def extract(file_path, fun):
     # store sass asm in buffer and them print them (for labels)
     # (ctrl, asm)
     asm_buffer = []
-    while FLINE_RE.match(line) != None:
+    while FLINE_RE.match(line) is not None:
         # First line (Offset ASM Encoding)
         fline = sass_lines[line_idx].decode()
         line_idx += 1
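Beyond the `is not None` fix, note that the branch-target code runs `BRA_RE.match(asm)` twice. A hedged sketch of the usual refinement, binding the match object once (the regex and input here are illustrative stand-ins, not the real `BRA_RE`):

```python
import re

# Illustrative stand-in for the real BRA_RE pattern.
BRA_RE = re.compile(r"(BRA\s+)(0x[0-9a-fA-F]+);")
asm = "BRA 0x1a0;"

m = BRA_RE.match(asm)  # match once, reuse the result
if m is not None:
    target = int(m.group(2), 16)  # group index depends on the real pattern
    print(hex(target))            # 0x1a0
```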

python/tutorials (softmax tutorial)

@@ -16,10 +16,11 @@ You will learn about:
 # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
 # Let us consider instead the case of a simple (numerically stabilized) softmax operation:
-import triton.language as tl
-import triton
 import torch
+
+import triton
+import triton.language as tl


 @torch.jit.script
 def naive_softmax(x):