[STYLE] add isort and autopep8 config files and check on CI (#423)
Also fixes a few more style issues surfaced by the "aggressive" mode of autopep8.
commit a70acfec77
parent 9801aa7b56
committed via GitHub
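The two CI checks added below can be reproduced locally before pushing. This is a sketch of the equivalent commands, assuming a checkout containing the python/ source tree:

    cd python && pip3 install -e '.[tests]' && cd ..   # the test extras now pull in isort and autopep8
    isort -c ./python                                  # check-only; "isort ./python" rewrites files
    autopep8 -a -r -d --exit-code ./python             # diff-only; "autopep8 -a -r -i ./python" fixes in place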
.github/workflows/integration-tests.yml (vendored): 6 additions
@@ -30,6 +30,12 @@ jobs:
           cd python
           pip3 install -e '.[tests]'
+
+      - name: Check imports
+        run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
+
+      - name: Check style
+        run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
 
       - name: Unit tests
         run: |
           cd python/test/unit
.isort.cfg (new file): 4 additions
@@ -0,0 +1,4 @@
+[settings]
+known_local_folder=triton
+line_length=88
+py_version=36
@@ -1,2 +1,5 @@
 [metadata]
 description_file = README.md
+
+[pycodestyle]
+ignore = E501,E701,E731
@@ -94,7 +94,7 @@ class CMakeBuild(build_ext):
             "-DBUILD_PYTHON_MODULE=ON",
             "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
             "-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
-            #'-DPYTHON_EXECUTABLE=' + sys.executable,
+            # '-DPYTHON_EXECUTABLE=' + sys.executable,
             # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
             "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
             "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
@@ -148,6 +148,8 @@ setup(
     ],
     extras_require={
         "tests": [
+            "autopep8",
+            "isort",
             "numpy",
             "pytest",
             "scipy>=1.7.1",
@@ -1,4 +1,3 @@
-import triton.language as tl
 import subprocess
 import sys
 
@@ -7,6 +6,7 @@ import torch
 from numpy import record
 
 import triton
+import triton.language as tl
 
 #######################
 # Utilities
@@ -1,4 +1,4 @@
-# version
+"""isort:skip_file"""
 __version__ = '2.0.0'
 
 # TODO: torch needs to be imported first
@@ -852,7 +852,7 @@ class Autotuner:
         else:
             config = self.configs[0]
         self.best_config = config
-        if config.pre_hook != None:
+        if config.pre_hook is not None:
             config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
 
@@ -293,7 +293,7 @@ class block:
         dst_shape = []
         curr = 0
         for sl in slices:
-            if sl == None:
+            if sl is None:
                 dst_shape.append(1)
             elif sl == slice(None, None, None):
                 dst_shape.append(src_shape[curr])
@@ -26,9 +26,9 @@ def _sdd_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     BLOCK: tl.constexpr, EVEN_K: tl.constexpr
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     block_id = tl.program_id(1) + grid_offset
     lut += block_id * 3
     # offsets
@@ -39,21 +39,23 @@ def _sdd_kernel(
     start_am = tl.load(lut + 1)
     offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
     offs_ak = tl.arange(0, TILE_K)
-    a_ptrs = A + (off_z * stride_za
-                  + off_h * stride_ha
-                  + offs_am[:, None] * stride_ma
-                  + offs_ak[None, :] * stride_ak)
+    a_ptrs = A \
+        + off_z * stride_za \
+        + off_h * stride_ha \
+        + offs_am[:, None] * stride_ma \
+        + offs_ak[None, :] * stride_ak
     # initialize pointers to B
     start_bn = tl.load(lut + 2)
     offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
     offs_bk = tl.arange(0, TILE_K)
-    b_ptrs = B + (off_z * stride_zb
-                  + off_h * stride_hb
-                  + offs_bn[None, :] * stride_nb
-                  + offs_bk[:, None] * stride_bk)
-    ## ---------------- ##
-    ## Inner Loop ##
-    ## ---------------- ##
+    b_ptrs = B \
+        + off_z * stride_zb \
+        + off_h * stride_hb \
+        + offs_bn[None, :] * stride_nb \
+        + offs_bk[:, None] * stride_bk
+    # ---------------- #
+    # Inner Loop #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     for k in range(K, 0, -TILE_K):
         if EVEN_K:
@@ -66,15 +68,16 @@ def _sdd_kernel(
         a_ptrs += TILE_K * stride_ak
         b_ptrs += TILE_K * stride_bk
     c = acc.to(C.dtype.element_ty)
-    ## ---------------- ##
-    ## Epilogue ##
-    ## ---------------- ##
+    # ---------------- #
+    # Epilogue #
+    # ---------------- #
     offs_cm = tl.arange(0, TILE_M) % BLOCK
     offs_cn = tl.arange(0, TILE_N) % BLOCK
-    pc = C + (off_z * stride_zc
-              + block_id * stride_hc
-              + offs_cm[:, None] * stride_mc
-              + offs_cn[None, :] * stride_nc)
+    pc = C \
+        + off_z * stride_zc \
+        + block_id * stride_hc \
+        + offs_cm[:, None] * stride_mc \
+        + offs_cn[None, :] * stride_nc
     tl.store(pc, c, mask=True)
 
 
@@ -134,9 +137,9 @@ def _dsd_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
     num_pid_m = tl.num_programs(0)
@@ -168,9 +171,9 @@ def _dsd_kernel(
         + off_h * stride_hb \
         + offs_bn[None, :] * stride_bn \
         + offs_bk[:, None] * stride_bk
-    ## ---------------- ##
-    ## Inner Loop ##
-    ## ---------------- ##
+    # ---------------- #
+    # Inner Loop #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     pinc += 2
     inc_a = tl.load(pinc + 1)
@@ -192,7 +195,8 @@ def _dsd_kernel(
     # initialize pointers to C
     offs_cm = column * TILE_M + tl.arange(0, TILE_M)
     offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
-    pc = C + off_h * stride_hc \
+    pc = C \
+        + off_h * stride_hc \
         + pidz * stride_zc \
         + offs_cm[:, None] * stride_cm \
         + offs_cn[None, :] * stride_cn
@@ -224,24 +228,24 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
     TILE_N = 128
     # compute output
     grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
-    # fmt: off
     _dsd_kernel[grid](
         a, b, c,
         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         BS3, AS1, lut,
-        TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
         num_warps=4, GROUP_SIZE_M=4,
     )
     # exit()
     return c
 
+
 def dsd_lut(layout, block, step, trans, device):
     sizes = torch.sum(layout, 2 if trans else 1)
     head_id, col_id = sizes.nonzero(as_tuple=True)
     sizes = sizes.flatten()
-    segments = sizes*step
+    segments = sizes * step
     # pointer increments
     if trans:
         nnz = layout.nonzero(as_tuple=False)
@@ -302,8 +306,8 @@ def dsd_lut(layout, block, step, trans, device):
     A_incs = A_incs.view(-1)
     # create header
     width = col_id.size(0)
-    offsets = offsets*2*div + 4*width
-    segments = segments*div
+    offsets = offsets * 2 * div + 4 * width
+    segments = segments * div
     header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
     # create increments
     incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
@@ -317,6 +321,8 @@ def dsd_lut(layout, block, step, trans, device):
 # -----------------------------
 # Dense = Dense x Sparse (DDS)
 # -----------------------------
+
+
 @triton.jit
 def _dds_kernel(
     A, B, C,
@@ -327,9 +333,9 @@ def _dds_kernel(
     TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
 ):
-    #------------#
-    #- Prologue -#
-    #------------#
+    # ------------ #
+    # - Prologue - #
+    # ------------ #
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
     num_pid_m = tl.num_programs(0)
@@ -343,7 +349,7 @@ def _dds_kernel(
     off_h = tl.load(header + 3)
     pinc = lut + offset
     # initialize pointers to A (dense)
-    offs_am = pid_m*TILE_M + tl.arange(0, TILE_M)
+    offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
     offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
     start_ak = tl.load(pinc)
     start_ak = tl.multiple_of(start_ak, 8)
@@ -361,13 +367,13 @@ def _dds_kernel(
         + block_id * stride_hb \
         + offs_bn[None, :] * stride_bn \
         + offs_bk[:, None] * stride_bk
-    ## ---------------- ##
-    ## Inner Loop ##
-    ## ---------------- ##
+    # ---------------- #
+    # Inner Loop #
+    # ---------------- #
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     for k in range(AS1, 0, -TILE_K):
-        a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0)
-        b = tl.load(ptrs_b, mask = True)
+        a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
+        b = tl.load(ptrs_b, mask=True)
         acc += tl.dot(a, b)
         pinc += 2
         inc_a = tl.load(pinc)
@@ -377,9 +383,9 @@ def _dds_kernel(
         inc_a = inc_a * stride_ka
         ptrs_a += inc_a
         ptrs_b += inc_b
-    ## ---------------- ##
-    ## Epilogue ##
-    ## ---------------- ##
+    # ---------------- #
+    # Epilogue #
+    # ---------------- #
     c = acc.to(C.dtype.element_ty)
     # initialize pointers to C (dense)
     offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
@@ -389,9 +395,10 @@ def _dds_kernel(
         + offs_cm[:, None] * stride_mc \
         + offs_cn[None, :] * stride_nc
     # write back
-    tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
+    tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
 
-def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
+
+def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
     if a.stride(2) != 1 and a.stride(3) != 1:
         a = a.contiguous()
     if b.stride(2) != 1 and b.stride(3) != 1:
@@ -414,14 +421,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
         c = out
     TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
     grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
-    # fmt: off
     _dds_kernel[grid](
         a, b, c,
         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         AS2, BS2, lut,
-        TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
         num_warps=4, GROUP_SIZE_M=4,
     )
     return c
@@ -429,6 +435,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
 ##############
 # MAIN API #
 ##############
+
+
 class _matmul(torch.autograd.Function):
 
     fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@@ -477,6 +485,7 @@ class _matmul(torch.autograd.Function):
             None, None, None, None,\
             None, None, None, None, None, dout
 
+
 class matmul:
 
     def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
@@ -503,7 +512,7 @@ class matmul:
         self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
         self.db_lut, self.db_width = sdd_lut(layout, block, device)
 
-    def __call__(self, a, b, out = None):
+    def __call__(self, a, b, out=None):
         c = _matmul.apply(
             a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
             self.c_lut, self.c_width,
@@ -52,7 +52,7 @@ def processSassLines(fline, sline, labels):
     asm = asm[:-2] + ";"
     ctrl = parseCtrl(sline)
     # BRA target address
-    if BRA_RE.match(asm) != None:
+    if BRA_RE.match(asm) is not None:
         target = int(BRA_RE.match(asm).group(2), 16)
         if target in labels:
             pass
@@ -62,7 +62,7 @@
 
 
 def extract(file_path, fun):
-    if fun == None:
+    if fun is None:
         sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
     else:
         sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
@@ -77,7 +77,7 @@ def extract(file_path, fun):
     # /*0x...*/
     fname_match = FNAME_RE.match(line)
     # Looking for new function header (function: <name>)
-    while FNAME_RE.match(line) == None:
+    while FNAME_RE.match(line) is None:
         line_idx += 1
         if line_idx < len(sass_lines):
             line = sass_lines[line_idx].decode()
@@ -94,7 +94,7 @@ def extract(file_path, fun):
     # store sass asm in buffer and them print them (for labels)
     # (ctrl, asm)
     asm_buffer = []
-    while FLINE_RE.match(line) != None:
+    while FLINE_RE.match(line) is not None:
         # First line (Offset ASM Encoding)
         fline = sass_lines[line_idx].decode()
         line_idx += 1
@@ -16,10 +16,11 @@ You will learn about:
 # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
 # Let us consider instead the case of a simple (numerically stabilized) softmax operation:
 
-import triton.language as tl
-import triton
 import torch
+
+import triton
+import triton.language as tl
 
 
 @torch.jit.script
 def naive_softmax(x):