Improvements w/ Auto-Tuning and standard benchmarks (#57)

[PYTHON] Bug fixes in the auto-tuning module and improvements to its API
Philippe Tillet
2021-02-03 13:37:21 -08:00
committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
12 changed files with 215 additions and 149 deletions
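In short: instead of passing lists of define values and warp counts (an option space whose Cartesian product the runtime expanded), callers now pass one set of compile options plus an explicit list of auto-tuning configurations and an auto-tuning key. A minimal sketch of the new Python-side call, assuming a hypothetical Triton-C source file my_kernel.c and using the triton.kernel wrapper changed in python/triton/kernel.py below:

    import torch
    import triton

    src = triton.read('my_kernel.c')             # hypothetical source file
    configs = [                                  # each entry: (preprocessor defines, num_warps)
        ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
        ({'TM': '64',  'TN': '64',  'TK': '64', 'TZ': '1'}, 2),
    ]
    kernel = triton.kernel(src, torch.device('cuda'),
                           defines={'TYPE': torch.float16},  # scalar values, no longer lists
                           num_warps=4,                       # default warp count, a plain int
                           autotune_vals=configs,             # replaces the old options_space
                           autotune_key=['M', 'N', 'K'])      # re-tune when these arguments change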

View File

@@ -54,19 +54,13 @@ enum asm_mode_t {
ASM_NV_SASS
};
struct options_space_t {
typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines;
std::vector<int> num_warps;
};
struct options_t {
template<class T>
T D(const std::string& name) const {
return convert<T>(defines.at(name));
}
std::unordered_map<std::string, std::string> defines;
size_t num_warps;
int num_warps;
};
@@ -111,12 +105,14 @@ public:
typedef std::function<grid_t(const options_t&)> grid_fn_ty;
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
typedef std::vector<std::pair<std::map<std::string, std::string>, int>> autotune_vals_t;
private:
static void do_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f);
public:
function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
function(const std::string& src, const options_t& opt, driver::device *device,
const autotune_vals_t& autotune_vals = {}, const std::vector<std::string> &autotune_key = {});
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
// auto-tuning
@@ -126,7 +122,7 @@ public:
const std::vector<kernel_pair_t> get_kernels() { return kernels_; }
private:
void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device);
void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device);
private:
std::vector<kernel_pair_t> kernels_;
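For reference: autotune_vals_t lists each candidate configuration explicitly as a (defines, num_warps) pair, whereas the removed options_space_t held per-define value lists and warp counts whose Cartesian product the runtime then enumerated. On the Python side (python/triton/kernel.py below) an entry is a plain (dict, int) tuple; the define names here are illustrative only:

    # old options_space_t, roughly: 2 defines x 2 values x 2 warp counts = 8 candidate kernels
    #   defines   = [('TM', ['64', '128']), ('TN', ['64', '128'])]
    #   num_warps = [2, 4]
    # new autotune_vals_t: every candidate spelled out as (defines, num_warps)
    autotune_vals = [
        ({'TM': '64',  'TN': '64'},  2),
        ({'TM': '128', 'TN': '128'}, 4),
    ]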

View File

@@ -224,7 +224,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
for(size_t i = 0; i < 3; i++)
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
// enqueue
stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size);
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
}
std::string kernel::get_asm(asm_mode_t mode) {
@@ -282,35 +282,35 @@ void function::do_loop_nest(std::vector<size_t> const & ranges,
return;
values[i--] = 0;
}
i = D - 1;
}
}
void function::init_kernels(const std::string& src, const options_space_t& opts, driver::device *device) {
// all ranges
std::vector<size_t> ranges;
ranges.push_back(opts.num_warps.size());
for(const auto& x: opts.defines)
ranges.push_back(x.second.size());
// functor for source with given option
void function::init_kernels(const std::string& src, const options_t& opt,
const autotune_vals_t& confs, driver::device *device) {
// list of all possible configs
// just augment `opt` with each define of `confs`
// and override warp count
size_t num_opts = std::max(confs.size(), (size_t)1);
std::vector<options_t> opts(num_opts, opt);
for(size_t i = 0; i < confs.size(); i++){
opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end());
opts[i].num_warps = confs[i].second;
}
// compile all possible configs
// compilation errors (e.g., too much shared mem)
// will populate `err`
std::vector<std::pair<options_t, std::string>> err;
auto do_make = [&](std::vector<size_t> params) {
// compilation options
unsigned i = 0;
options_t opt;
opt.num_warps = opts.num_warps[params[i++]];
for(auto D: opts.defines)
opt.defines[D.first] = D.second[params[i++]];
// compile
for(const options_t& opt: opts) {
try{
kernels_.push_back({opt, std::make_shared<kernel>(src, opt, device)});
}catch(const exception::base& e){
err.push_back({opt, e.what()});
}
};
// multi-threaded compilation
do_loop_nest(ranges, do_make);
}
// if every config failed to compile, throw an exception reporting `err`
if(kernels_.empty()){
std::ostringstream dbg;
dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
@@ -357,9 +357,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
return it->second;
}
function::function(const std::string& src, const options_space_t& opt,
driver::device *device, const std::vector<std::string>& autotune_key) {
init_kernels(src, opt, device);
function::function(const std::string& src, const options_t &opt, driver::device *device,
const autotune_vals_t& autotune_vals, const std::vector<std::string>& autotune_key) {
// pre-compile all kernels
init_kernels(src, opt, autotune_vals, device);
// find indices of autotune keys
auto arg_names = kernels_.at(0).second->get_arg_names();
for(const std::string& name: autotune_key){
auto it = std::find(arg_names.begin(), arg_names.end(), name);
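The rewritten init_kernels no longer walks an option space with do_loop_nest: it copies the base options once per configuration, merges that configuration's defines, overrides the warp count, and compiles every candidate, collecting failures in err and raising only if nothing compiled. A short Python analogue of the expansion step (a sketch, not the actual implementation):

    def expand_configs(base_defines, base_num_warps, autotune_vals):
        # mirror of the new init_kernels: one copy of the base options per configuration
        if not autotune_vals:
            return [(dict(base_defines), base_num_warps)]
        configs = []
        for extra_defines, num_warps in autotune_vals:
            # std::map::insert never overwrites existing keys, so defines already
            # present in the base options keep their values here as well
            defines = dict(extra_defines)
            defines.update(base_defines)
            configs.append((defines, num_warps))
        return configs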

View File

@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
void register_fn(int op_id,
int dev_id,
const std::string& src,
const rt::options_space_t& opt,
const rt::options_t& opt,
const rt::function::autotune_vals_t& autotune_vals,
const std::vector<std::string>& autotune_key) {
if(tt_devices.find(dev_id) == tt_devices.end()) {
driver::device* device;
@@ -62,7 +63,7 @@ void register_fn(int op_id,
tt_streams[dev_id].reset(stream);
}
if(id_fn_map.find(op_id) == id_fn_map.end()){
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
}
for(const auto& k: id_fn_map[op_id]->get_kernels()){
const rt::options_t* opt = &k.first;
@@ -197,13 +198,9 @@ PYBIND11_MODULE(libtriton, m) {
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def_readwrite("defines" , &rt::options_t::defines);
pybind11::class_<rt::options_space_t>(m, "options_space")
.def(pybind11::init<>())
.def_readwrite("num_warps", &rt::options_space_t::num_warps)
.def_readwrite("defines" , &rt::options_space_t::defines);
.def_readwrite("defines" , &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
// hooks into triton constructs since frameworks may not use pybind11
m.def("extract_kernels", &extract_kernels);
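With options_space gone from the bindings, Python code builds a plain options object and passes the configuration list straight through register_fn. A sketch of the binding-level objects (libtriton is the pybind11 module defined here; its import path is taken from python/triton/kernel.py and the define values are illustrative):

    opt = libtriton.options()
    opt.defines = {'TYPE': 'half', 'TM': '128'}   # plain {str: str} mapping
    opt.num_warps = 4                             # a single integer, no option space
    # register_fn(op_id, dev_id, src, opt, autotune_vals, autotune_key) now also
    # receives the explicit configuration list, as kernel.py below shows.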

View File

@@ -15,6 +15,12 @@ def mask_tensor(x, mask, block, value = 0):
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
return ret
## -----------------------------------------------------------------------------
## Unit Tests
## -----------------------------------------------------------------------------
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
@@ -87,3 +93,68 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(ry , ty, rtol=rtol, atol=atol)
## -----------------------------------------------------------------------------
## Performance Tests
## -----------------------------------------------------------------------------
def do_bench(fn, warmup = 10, rep = 50):
import torch as th
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
ret = fn()
for i in range(warmup):
fn()
th.cuda.synchronize()
start_event.record()
for i in range(rep):
fn()
end_event.record()
th.cuda.synchronize()
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
Z, H = 1, 1
K = 512
make_layout = {
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[LAYOUT_MODE]
for N in [128, 256, 512, 1024, 2048, 4096]:
# create layout
M, N, K = N, N, N
shape = {'sdd': (M, N),
'dsd': (K, M) if TRANS_A else (M, K),
'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
# create op
op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
# inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
triton_tflops = num_flops / ms * 1e3
def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
Z, H = 1, 1
K = 512
make_layout = {
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[LAYOUT_MODE]
for N in [128, 256, 512, 1024, 2048, 4096]:
layout = make_layout(H, N//BLOCK, N//BLOCK)
a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
a = sparsify_tensor(a, layout, BLOCK)
op = tt.ops.blocksparse.softmax(layout, BLOCK)
ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
nbytes = 2 * a.numel() * a.element_size()
triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
print(triton_gbyps)
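do_bench warms the callable up, then averages rep timed iterations between two CUDA events. A usage sketch outside the blocksparse tests (the dense matmul and shapes are illustrative):

    import torch

    a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=50)
    tflops = 2 * 1024**3 / (ms * 1e-3) * 1e-12    # 2*M*N*K FLOPs over the measured time
    print(f'{ms:.3f} ms, {tflops:.1f} TFLOP/s')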

View File

@@ -3,57 +3,58 @@ import itertools
import triton as tt
import torch as th
@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
(128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
]
for DTYPE in ['float16']
for AT in [False, True]
for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
th.manual_seed(0)
tt.ops._matmul.kernel = dict()
tt.ops._matmul.TM = [TM]
tt.ops._matmul.TN = [TN]
tt.ops._matmul.TK = [TK]
tt.ops._matmul.num_warps = [NWARP]
tt.ops._matmul._kernels = dict()
tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK
if K is None: K = TK*TZ
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
a = a.t() if AT else a
@@ -81,13 +82,13 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
return time_ms
def perf_op(dtype=th.float16, warmup=10, rep=50):
def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50):
import pandas as pd
import matplotlib.pyplot as plt
import os
AT, BT = False, False
has_cutlass = 'CUTLASS_PROFILER' in os.environ
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS'])
Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144]
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
for AT, BT, M, N, K in configs:
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
@@ -120,6 +121,10 @@ def perf_op(dtype=th.float16, warmup=10, rep=50):
cutlass_tflops = max(df_c['GFLOPs'])/1e3
else:
cutlass_tflops = None
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
print(df)
df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
# name
AT = {True: 'T', False: 'N'}[AT]
BT = {True: 'T', False: 'N'}[BT]
name = f'{AT}{BT}'
df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green'])
plt.savefig(f'matmul-{MODE}-{name}.pdf')
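The parametrization now carries a TZ (split-k) factor, and test_op pins a single configuration by overriding _matmul._CONFIGS. A standalone sketch exercising one split-k configuration, assuming triton/ops/matmul.py exposes the op as tt.ops.matmul (= _matmul.apply) and borrowing the fp16 tolerances used above:

    import torch as th
    import triton as tt

    tt.ops._matmul._kernels = dict()
    tt.ops._matmul._CONFIGS = [({'TM': '64', 'TN': '64', 'TK': '16', 'TZ': '4'}, 4)]

    M, N, K = 64, 64, 16 * 4    # K = TK * TZ: one TK-step per k-slice
    a = th.randn((M, K), device='cuda', dtype=th.float16) / K**.5
    b = th.randn((K, N), device='cuda', dtype=th.float16) / K**.5
    c = tt.ops.matmul(a, b)
    assert th.allclose(c, th.matmul(a, b), rtol=1e-2, atol=1e-3)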

View File

@@ -26,10 +26,8 @@ def th_to_triton(obj):
torch.float64: 'double'
}
if isinstance(obj, torch.dtype):
return [tys[obj]]
if isinstance(obj, list):
return [th_to_triton(x)[0] for x in obj]
return [str(obj)]
return tys[obj]
return str(obj)
def cdiv(a, b):
return libtriton.cdiv(a, b)
@@ -45,17 +43,15 @@ def read(path, kernel_names=[]):
source = libtriton.extract_kernels(source, kernel_names)
return source
class kernel:
def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []):
def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []):
# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')
self.src = src
self.opt = libtriton.options_space()
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
self.opt = libtriton.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# device
assert device.type in ['cuda', 'cpu']
@@ -65,7 +61,7 @@ class kernel:
self.device = -1
# C++ function wrapper
self.op_id = libtriton.make_op_id()
libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key)
libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
# debug mode
self.is_debug = 'TRITON_DEBUG' in os.environ
# signature
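Note that th_to_triton (first hunk of this file) now maps one Python value to one preprocessor string instead of a list. Two small examples, following the mapping shown above:

    # taking th_to_triton from the module above
    assert th_to_triton(torch.float64) == 'double'   # torch dtype -> Triton-C type name
    assert th_to_triton(64) == '64'                  # anything else -> str(obj)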

View File

@@ -81,7 +81,7 @@ class _matmul(torch.autograd.Function):
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
start_width = 64 // block
start_width = 128 // block
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
@@ -126,22 +126,18 @@ class _matmul(torch.autograd.Function):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
F32TK = [8, 16]
#F16TK = [16]
#F16TK += [32] if is_32_multiple else []
#F16TK += [64] if is_64_multiple else []
F16TK = [64]
TK = {torch.float32: F32TK,
torch.float16: F16TK}[dtype]
defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
'TK': TK, 'TYPE': dtype,
defines = {'TM': block*pack, 'TN': block*pack,
'TMN': block*block*pack*pack,
'BLOCK': block,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
kernel = _matmul.sdd_cache[key]
# create output
@@ -270,9 +266,9 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': TM, 'TN': block, 'TK': TK,
defines = {'TM': 128,
'TN': block,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
@@ -283,7 +279,7 @@ class _matmul(torch.autograd.Function):
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
# output
CS0 = AS0
@@ -315,9 +311,9 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
TN = [64, 128] if dtype == torch.float32 else [64, 128]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': block, 'TN': TN, 'TK': TK,
defines = {'TM': block,
'TN': 128,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
@@ -328,7 +324,7 @@ class _matmul(torch.autograd.Function):
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
# output
CS0 = BS0

View File

@@ -48,7 +48,7 @@ class _softmax(torch.autograd.Function):
# just-in-time compile kernel
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
if key not in cache:
defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block,
'INFINITY': {torch.float32: 'F32_INFINITY',
torch.float16: 'F16_INFINITY'}[dtype]}
if apply_scale:
@@ -63,7 +63,7 @@ class _softmax(torch.autograd.Function):
defines['APPLY_ATTN_MASK'] = True
if attn_mask_mode == 'mul':
defines['ATTN_MASK_MUL'] = True
kernel = triton.kernel(src, device=device, defines=defines, num_warps=[num_warps])
kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
cache[key] = kernel
return cache[key]

View File

@@ -29,10 +29,10 @@ class _conv(torch.autograd.Function):
TK = 16
defines = {
'TYPE' : dtype,
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'TM' : 64,
'TN' : 64,
'TK' : TK,
'TZ' : 1,
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
@@ -40,7 +40,7 @@ class _conv(torch.autograd.Function):
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)

View File

@@ -83,8 +83,8 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
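The split-k epilogue serializes the TZ programs that share an output tile through a lock and a counter stored in the locks buffer. With the new grid [num_tiles, 1, TZ] (see python/triton/ops/matmul.py below) and assuming pid is the axis-0 program id, as the surrounding kernel suggests, the indexing above amounts to (names in this sketch are hypothetical):

    # locks[pid]             -> spin-lock shared by the TZ programs of one output tile
    # locks[num_tiles + pid] -> counter of partial results accumulated so far
    def lock_and_counter(pid, num_tiles):
        return pid, num_tiles + pid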

View File

@@ -5,11 +5,21 @@ import os
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
TM = [128]
TN = [128]
TK = [32]
TZ = 1
num_warps = [4]
_DEFAULT_CONFIGS = [
({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '64' , 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 4),
({'TM': '32' , 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '32' , 'TK': '64', 'TZ': '1'}, 4),
({'TM': '64' , 'TN': '32' , 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
]
_CONFIGS = _DEFAULT_CONFIGS
@staticmethod
def largest_pow2_divisor(N):
@@ -41,7 +51,7 @@ class _matmul(torch.autograd.Function):
lda_pow2_div = _matmul.largest_pow2_divisor(lda)
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 32 == 0
is_tk_div_k = K % 64 == 0
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
if key not in _matmul._kernels:
defines = {
@@ -53,13 +63,10 @@ class _matmul(torch.autograd.Function):
'LDA_POW2_DIV': lda_pow2_div,
'LDB_POW2_DIV': ldb_pow2_div,
'LDC_POW2_DIV': ldc_pow2_div,
'TM' : _matmul.TM,
'TN' : _matmul.TN,
'TK' : _matmul.TK,
'TZ' : _matmul.TZ,
'IS_TK_DIV_K' : int(is_tk_div_k)
}
_matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K'])
_matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines,
autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K'])
kernel = _matmul._kernels[key]
# locks for split-k
if device not in _matmul._locks:
@@ -68,7 +75,7 @@ class _matmul(torch.autograd.Function):
# enqueue
alpha = 1.
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
kernel(*args, grid=grid)
return c
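The launch grid now spans opt.TZ programs along the third axis, so a split-k configuration launches TZ cooperating programs per output tile. For example, with M = N = 1024, TM = TN = 128 and TZ = 4:

    import triton
    M, N = 1024, 1024
    TM, TN, TZ = 128, 128, 4
    grid = [triton.cdiv(M, TM) * triton.cdiv(N, TN), 1, TZ]
    assert grid == [64, 1, 4]    # 8 x 8 output tiles, 4 k-slices each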

View File

@@ -158,21 +158,17 @@ float triton_dot(drv::context* context, drv::stream* stream,
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);
// macros
rt::options_space_t opts;
// A access patterns
opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda" }});
opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1" }});
// B access patterns
opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1" }});
opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb" }});
// data-type
opts.defines.push_back({"TYPE", {ty}});
// tile sizes
opts.defines.push_back({"TM", {"128"}});
opts.defines.push_back({"TN", {"128"}});
opts.defines.push_back({"TK", {"32"}});
opts.defines.push_back({"TZ", {"1"}});
opts.num_warps = {4};
rt::options_t opt;
opt.defines["STRIDE_AK"] = AT? "1" : "lda";
opt.defines["STRIDE_AM"] = AT? "lda" : "1";
opt.defines["STRIDE_BK"] = BT? "ldb" : "1";
opt.defines["STRIDE_BN"] = BT? "1" : "ldb";
opt.defines["TYPE"] = ty;
opt.defines["TM"] = "128";
opt.defines["TN"] = "128";
opt.defines["TK"] = "32" ;
opt.defines["TZ"] = "1";
opt.num_warps = 4;
// arguments
std::stringstream oss;
rt::add_arg(oss, *da->cu());
@@ -187,7 +183,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
rt::add_arg(oss, ldc);
rt::add_arg(oss, *dlocks->cu());
// function
rt::function function(src::dot, opts, device);
rt::function function(src::dot, opt, device);
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
// grid
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };