Improvements w/ Auto-Tuning and standard benchmarks (#57)
[PYTHON] Bug fixes in the auto-tuning module and improvements to its existing API
Committed by Philippe Tillet · parent ad005d49ac · commit 6fb4800f57
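For context, a minimal sketch of how the revised Python auto-tuning API is used after this change (the source path, macro names, and argument list below are illustrative placeholders, not part of the commit): `num_warps` becomes a scalar, candidate configurations are passed explicitly as `(defines, num_warps)` pairs through `autotune_vals`, and `autotune_key` names the kernel arguments whose values trigger re-tuning.

```python
import torch
import triton

# Placeholder path: the real sources (e.g. matmul.c) live next to the ops that use them.
src = triton.read('matmul.c')
device = torch.device('cuda')

# Candidate configurations are explicit (defines, num_warps) pairs,
# like _matmul._DEFAULT_CONFIGS in this commit.
configs = [
    ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
    ({'TM': '64',  'TN': '64',  'TK': '64', 'TZ': '1'}, 2),
]

kernel = triton.kernel(src, device,
                       defines={'TYPE': 'float'},     # fixed macros shared by every config
                       num_warps=4,                   # now a scalar, not a list
                       autotune_vals=configs,         # replaces the old options_space
                       autotune_key=['M', 'N', 'K'])  # re-tune when these arguments change

# Launch with a grid that can depend on the winning configuration:
# grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
# kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), ..., grid=grid)
```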
@@ -54,19 +54,13 @@ enum asm_mode_t {
  ASM_NV_SASS
};

struct options_space_t {
  typedef std::pair<std::string, std::vector<std::string>> define_t;
  std::vector<define_t> defines;
  std::vector<int> num_warps;
};

struct options_t {
  template<class T>
  T D(const std::string& name) const {
    return convert<T>(defines.at(name));
  }
  std::unordered_map<std::string, std::string> defines;
  size_t num_warps;
  int num_warps;
};

@@ -111,12 +105,14 @@ public:
  typedef std::function<grid_t(const options_t&)> grid_fn_ty;
  typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
  typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
  typedef std::vector<std::pair<std::map<std::string, std::string>, int>> autotune_vals_t;

private:
  static void do_loop_nest(std::vector<size_t> const & ranges,
                           std::function<void(std::vector<size_t> const &)> const & f);
public:
  function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
  function(const std::string& src, const options_t& opt, driver::device *device,
           const autotune_vals_t& autotune_vals = {}, const std::vector<std::string> &autotune_key = {});
  void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
  void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
  // auto-tuning
@@ -126,7 +122,7 @@ public:
  const std::vector<kernel_pair_t> get_kernels() { return kernels_; }

private:
  void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device);
  void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device);

private:
  std::vector<kernel_pair_t> kernels_;

@@ -224,7 +224,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
  for(size_t i = 0; i < 3; i++)
    grid[i] = (i < _grid.size()) ? _grid[i] : 1;
  // enqueue
  stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size);
  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
}

std::string kernel::get_asm(asm_mode_t mode) {
@@ -282,35 +282,35 @@ void function::do_loop_nest(std::vector<size_t> const & ranges,
        return;
      values[i--] = 0;
    }
    i = D - 1;
    i = D - 1; options_t opt;

  }
}

void function::init_kernels(const std::string& src, const options_space_t& opts, driver::device *device) {
  // all ranges
  std::vector<size_t> ranges;
  ranges.push_back(opts.num_warps.size());
  for(const auto& x: opts.defines)
    ranges.push_back(x.second.size());
  // functor for source with given option
void function::init_kernels(const std::string& src, const options_t& opt,
                            const autotune_vals_t& confs, driver::device *device) {
  // list of all possible configs
  // just augment `opt` with each define of `confs`
  // and override warp count
  size_t num_opts = std::max(confs.size(), (size_t)1);
  std::vector<options_t> opts(num_opts, opt);
  for(size_t i = 0; i < confs.size(); i++){
    opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end());
    opts[i].num_warps = confs[i].second;
  }
  // compile all possible configs
  // compilation errors (e.g., too much shared mem)
  // will populate `err`
  std::vector<std::pair<options_t, std::string>> err;
  auto do_make = [&](std::vector<size_t> params) {
    // compilation options
    unsigned i = 0;
    options_t opt;
    opt.num_warps = opts.num_warps[params[i++]];
    for(auto D: opts.defines)
      opt.defines[D.first] = D.second[params[i++]];
    // compile
  for(const options_t& opt: opts) {
    try{
      kernels_.push_back({opt, std::make_shared<kernel>(src, opt, device)});
    }catch(const exception::base& e){
      err.push_back({opt, e.what()});
    }
  };
  // multi-threaded compilation
  do_loop_nest(ranges, do_make);
}
  // throw an exception if `err` is not empty
  if(kernels_.empty()){
    std::ostringstream dbg;
    dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
@@ -357,9 +357,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
    return it->second;
}

function::function(const std::string& src, const options_space_t& opt,
                   driver::device *device, const std::vector<std::string>& autotune_key) {
  init_kernels(src, opt, device);
function::function(const std::string& src, const options_t &opt, driver::device *device,
                   const autotune_vals_t& autotune_vals, const std::vector<std::string>& autotune_key) {
  // pre-compile all kernels
  init_kernels(src, opt, autotune_vals, device);
  // find indices of autotune keys
  auto arg_names = kernels_.at(0).second->get_arg_names();
  for(const std::string& name: autotune_key){
    auto it = std::find(arg_names.begin(), arg_names.end(), name);

@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
void register_fn(int op_id,
                 int dev_id,
                 const std::string& src,
                 const rt::options_space_t& opt,
                 const rt::options_t& opt,
                 const rt::function::autotune_vals_t& autotune_vals,
                 const std::vector<std::string>& autotune_key) {
  if(tt_devices.find(dev_id) == tt_devices.end()) {
    driver::device* device;
@@ -62,7 +63,7 @@ void register_fn(int op_id,
    tt_streams[dev_id].reset(stream);
  }
  if(id_fn_map.find(op_id) == id_fn_map.end()){
    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
  }
  for(const auto& k: id_fn_map[op_id]->get_kernels()){
    const rt::options_t* opt = &k.first;
@@ -197,13 +198,9 @@ PYBIND11_MODULE(libtriton, m) {
      .value("sass", rt::ASM_NV_SASS);

  pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
      .def_readwrite("num_warps", &rt::options_t::num_warps)
      .def_readwrite("defines" , &rt::options_t::defines);

  pybind11::class_<rt::options_space_t>(m, "options_space")
      .def(pybind11::init<>())
      .def_readwrite("num_warps", &rt::options_space_t::num_warps)
      .def_readwrite("defines" , &rt::options_space_t::defines);
      .def_readwrite("defines" , &rt::options_t::defines)
      .def_readwrite("num_warps", &rt::options_t::num_warps);

  // hooks into triton constructs since frameworks may not use pybind11
  m.def("extract_kernels", &extract_kernels);

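The `options_space` binding is gone; Python now fills a single `options` object and passes the auto-tuning configurations separately to `register_fn`, as the wrapper in `kernel.py` does later in this diff. A rough sketch of driving the binding directly (the import path and device-id handling are assumptions):

```python
import libtriton  # assumed import path for the compiled extension

src = "..."  # Triton-C source string

opt = libtriton.options()                  # replaces libtriton.options_space()
opt.defines = {'TM': '128', 'TN': '128'}   # plain str -> str map, no lists of candidates
opt.num_warps = 4                          # a single integer, not a list

op_id = libtriton.make_op_id()
dev_id = 0  # CUDA device index (the Python wrapper uses -1 for CPU)
libtriton.register_fn(op_id, dev_id, src, opt,
                      [({'TK': '32', 'TZ': '1'}, 4)],  # autotune_vals
                      ['M', 'N', 'K'])                 # autotune_key
```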
@@ -15,6 +15,12 @@ def mask_tensor(x, mask, block, value = 0):
            ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
    return ret


## -----------------------------------------------------------------------------
## Unit Tests
## -----------------------------------------------------------------------------

@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
    [
        (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
@@ -87,3 +93,68 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(ry , ty, rtol=rtol, atol=atol)


## -----------------------------------------------------------------------------
## Performance Tests
## -----------------------------------------------------------------------------

def do_bench(fn, warmup = 10, rep = 50):
    import torch as th
    start_event = th.cuda.Event(enable_timing=True)
    end_event = th.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    th.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    th.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms

def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        # create layout
        M, N, K = N, N, N
        shape = {'sdd': (M, N),
                 'dsd': (K, M) if TRANS_A else (M, K),
                 'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
        layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
        # create op
        op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
        # inputs
        a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
        b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
        b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
        ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
        triton_tflops = num_flops / ms * 1e3

def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        layout = make_layout(H, N//BLOCK, N//BLOCK)
        a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK)
        op = tt.ops.blocksparse.softmax(layout, BLOCK)
        ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
        nbytes = 2 * a.numel() * a.element_size()
        triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
        print(triton_gbyps)

@@ -3,57 +3,58 @@ import itertools
import triton as tt
import torch as th

@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
    [
        # 1 warp
        (16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
        (64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
        (16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
        (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
        # 2 warp
        (64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
        (64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
        (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
        (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
        # 4 warp
        (128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
        (64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
        (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
        (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
        # 8 warp
        (128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
        (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
        # split-k
        (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
        (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
        (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
        # variable input
        (128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
        (128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
        (128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
        (128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
        (128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
        (128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
        (128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
        (128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
    ]
    for DTYPE in ['float16']
    for AT in [False, True]
    for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
    th.manual_seed(0)
    tt.ops._matmul.kernel = dict()
    tt.ops._matmul.TM = [TM]
    tt.ops._matmul.TN = [TN]
    tt.ops._matmul.TK = [TK]
    tt.ops._matmul.num_warps = [NWARP]
    tt.ops._matmul._kernels = dict()
    tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
    if M is None: M = TM
    if N is None: N = TN
    if K is None: K = TK
    if K is None: K = TK*TZ
    a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
    b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
    a = a.t() if AT else a
@@ -81,13 +82,13 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
    return time_ms


def perf_op(dtype=th.float16, warmup=10, rep=50):
def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50):
    import pandas as pd
    import matplotlib.pyplot as plt
    import os
    AT, BT = False, False
    has_cutlass = 'CUTLASS_PROFILER' in os.environ
    df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
    Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
    df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS'])
    Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144]
    configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
    for AT, BT, M, N, K in configs:
        a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
@@ -120,6 +121,10 @@ def perf_op(dtype=th.float16, warmup=10, rep=50):
            cutlass_tflops = max(df_c['GFLOPs'])/1e3
        else:
            cutlass_tflops = None
        df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
    pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
    print(df)
        df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
    # name
    AT = {True: 'T', False: 'N'}[AT]
    BT = {True: 'T', False: 'N'}[BT]
    name = f'{AT}{BT}'
    df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green'])
    plt.savefig(f'matmul-{mode}-{name}.pdf')

@@ -26,10 +26,8 @@ def th_to_triton(obj):
        torch.float64: 'double'
    }
    if isinstance(obj, torch.dtype):
        return [tys[obj]]
    if isinstance(obj, list):
        return [th_to_triton(x)[0] for x in obj]
    return [str(obj)]
        return tys[obj]
    return str(obj)

def cdiv(a, b):
    return libtriton.cdiv(a, b)
@@ -45,17 +43,15 @@ def read(path, kernel_names=[]):
        source = libtriton.extract_kernels(source, kernel_names)
    return source


class kernel:

    def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []):
    def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []):
        # check if src is empty
        if src == '':
            raise ValueError('Kernel source code is empty')
        self.src = src
        self.opt = libtriton.options_space()
        self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
        self.opt = libtriton.options()
        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
        self.opt.num_warps = num_warps
        # device
        assert device.type in ['cuda', 'cpu']
@@ -65,7 +61,7 @@ class kernel:
        self.device = -1
        # C++ function wrapper
        self.op_id = libtriton.make_op_id()
        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key)
        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
        # debug mode
        self.is_debug = 'TRITON_DEBUG' in os.environ
        # signature

@@ -81,7 +81,7 @@ class _matmul(torch.autograd.Function):

    @staticmethod
    def make_sdd_lut(layout, block, dtype, device):
        start_width = 64 // block
        start_width = 128 // block
        superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
        luts, widths, packs = [], [], []
        for size, nnz in superblocks:
@@ -126,22 +126,18 @@ class _matmul(torch.autograd.Function):
            num_lock = 1
            key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
            if key not in _matmul.sdd_cache:
                F32TK = [8, 16]
                #F16TK = [16]
                #F16TK += [32] if is_32_multiple else []
                #F16TK += [64] if is_64_multiple else []
                F16TK = [64]
                TK = {torch.float32: F32TK,
                      torch.float16: F16TK}[dtype]
                defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
                           'TK': TK, 'TYPE': dtype,
                defines = {'TM': block*pack, 'TN': block*pack,
                           'TMN': block*block*pack*pack,
                           'BLOCK': block,
                           'TK': 32,
                           'TYPE': dtype,
                           'STRIDE_AM': '1' if trans_a else 'lda',
                           'STRIDE_AK': 'lda' if trans_a else '1',
                           'STRIDE_BN': 'ldb' if trans_b else '1',
                           'STRIDE_BK': '1' if trans_b else 'ldb',
                           'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                           'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
                _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])
                _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)

            kernel = _matmul.sdd_cache[key]
            # create output
@@ -270,9 +266,9 @@ class _matmul(torch.autograd.Function):
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dds_cache:
            TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
            TK = [8] if dtype == torch.float32 else [16]
            defines = {'TM': TM, 'TN': block, 'TK': TK,
            defines = {'TM': 128,
                       'TN': block,
                       'TK': 16,
                       'BLOCK': block,
                       'TYPE': dtype,
                       'STRIDE_AM': 1 if trans_a else 'lda',
@@ -283,7 +279,7 @@ class _matmul(torch.autograd.Function):
                       'STRIDE_CN': 'ldc' if trans_c else '1',
                       'NAME': 'dds_kernel',
                       'DDS': True}
            _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
            _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
        kernel = _matmul.dds_cache[key]
        # output
        CS0 = AS0
@@ -315,9 +311,9 @@ class _matmul(torch.autograd.Function):
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dsd_cache:
            TN = [64, 128] if dtype == torch.float32 else [64, 128]
            TK = [8] if dtype == torch.float32 else [16]
            defines = {'TM': block, 'TN': TN, 'TK': TK,
            defines = {'TM': block,
                       'TN': 128,
                       'TK': 16,
                       'BLOCK': block,
                       'TYPE': dtype,
                       'STRIDE_AM': 1 if trans_a else block,
@@ -328,7 +324,7 @@ class _matmul(torch.autograd.Function):
                       'STRIDE_CN': 'ldc' if trans_c else '1',
                       'NAME': 'dsd_kernel',
                       'DSD': True}
            _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
            _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
        kernel = _matmul.dsd_cache[key]
        # output
        CS0 = BS0

@@ -48,7 +48,7 @@ class _softmax(torch.autograd.Function):
        # just-in-time compile kernel
        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
        if key not in cache:
            defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
            defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block,
                       'INFINITY': {torch.float32: 'F32_INFINITY',
                                    torch.float16: 'F16_INFINITY'}[dtype]}
            if apply_scale:
@@ -63,7 +63,7 @@ class _softmax(torch.autograd.Function):
                defines['APPLY_ATTN_MASK'] = True
            if attn_mask_mode == 'mul':
                defines['ATTN_MASK_MUL'] = True
            kernel = triton.kernel(src, device=device, defines=defines, num_warps=[num_warps])
            kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
            cache[key] = kernel
        return cache[key]

@@ -29,10 +29,10 @@ class _conv(torch.autograd.Function):
            TK = 16
            defines = {
                'TYPE' : dtype,
                'TM'   : [32, 64, 128],
                'TN'   : [32, 64, 128],
                'TK'   : [TK],
                'TZ'   : [1],
                'TM'   : 64,
                'TN'   : 64,
                'TK'   : TK,
                'TZ'   : 1,
                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
            }
            idx = torch.arange(CI*R*S)
@@ -40,7 +40,7 @@ class _conv(torch.autograd.Function):
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
            delta = delta.type(torch.int32).cuda()
            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
        delta, kernel = _conv.kernel[dtype]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)

@@ -83,8 +83,8 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
  *?(checkc) pc = c;
#else
  // accumulate partial result using spin-locks
  int *plock = locks + rid;
  int *pcount = plock + get_num_programs(0) * get_num_programs(1);
  int *plock = locks + pid;
  int *pcount = plock + get_num_programs(0);
  for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
  int count = *pcount;
  if(count == 0)

@@ -5,11 +5,21 @@ import os
class _matmul(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))

    TM = [128]
    TN = [128]
    TK = [32]
    TZ = 1
    num_warps = [4]
    _DEFAULT_CONFIGS = [
        ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
        ({'TM': '64',  'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
        ({'TM': '128', 'TN': '64',  'TK': '32', 'TZ': '1'}, 4),
        ({'TM': '64',  'TN': '64',  'TK': '64', 'TZ': '1'}, 4),
        ({'TM': '32',  'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
        ({'TM': '128', 'TN': '32',  'TK': '64', 'TZ': '1'}, 4),
        ({'TM': '64',  'TN': '32',  'TK': '64', 'TZ': '1'}, 2),
        ({'TM': '32',  'TN': '64',  'TK': '64', 'TZ': '1'}, 2),
        ({'TM': '32',  'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
        ({'TM': '32',  'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
        ({'TM': '128', 'TN': '32',  'TK': '32', 'TZ': '4'}, 4),
        ({'TM': '128', 'TN': '32',  'TK': '32', 'TZ': '4'}, 4),
    ]
    _CONFIGS = _DEFAULT_CONFIGS

    @staticmethod
    def largest_pow2_divisor(N):
@@ -41,7 +51,7 @@ class _matmul(torch.autograd.Function):
        lda_pow2_div = _matmul.largest_pow2_divisor(lda)
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 32 == 0
        is_tk_div_k = K % 64 == 0
        key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
        if key not in _matmul._kernels:
            defines = {
@@ -53,13 +63,10 @@ class _matmul(torch.autograd.Function):
                'LDA_POW2_DIV': lda_pow2_div,
                'LDB_POW2_DIV': ldb_pow2_div,
                'LDC_POW2_DIV': ldc_pow2_div,
                'TM'          : _matmul.TM,
                'TN'          : _matmul.TN,
                'TK'          : _matmul.TK,
                'TZ'          : _matmul.TZ,
                'IS_TK_DIV_K' : int(is_tk_div_k)
            }
            _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K'])
            _matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines,
                                                  autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K'])
        kernel = _matmul._kernels[key]
        # # locks for split-k
        if device not in _matmul._locks:
@@ -68,7 +75,7 @@ class _matmul(torch.autograd.Function):
        # enqueue
        alpha = 1.
        args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1]
        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
        kernel(*args, grid=grid)
        return c

@@ -158,21 +158,17 @@ float triton_dot(drv::context* context, drv::stream* stream,
  stream->write(&*da, true, 0, ha);
  stream->write(&*db, true, 0, hb);
  // macros
  rt::options_space_t opts;
  // A access patterns
  opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda"}});
  opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1"}});
  // B access patterns
  opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1"}});
  opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb"}});
  // data-type
  opts.defines.push_back({"TYPE", {ty}});
  // tile sizes
  opts.defines.push_back({"TM", {"128"}});
  opts.defines.push_back({"TN", {"128"}});
  opts.defines.push_back({"TK", {"32"}});
  opts.defines.push_back({"TZ", {"1"}});
  opts.num_warps = {4};
  rt::options_t opt;
  opt.defines["STRIDE_AK"] = AT? "1" : "lda";
  opt.defines["STRIDE_AM"] = AT? "lda" : "1";
  opt.defines["STRIDE_BK"] = BT? "ldb" : "1";
  opt.defines["STRIDE_BN"] = BT? "1" : "ldb";
  opt.defines["TYPE"] = ty;
  opt.defines["TM"] = "128";
  opt.defines["TN"] = "128";
  opt.defines["TK"] = "32";
  opt.defines["TZ"] = "1";
  opt.num_warps = 4;
  // arguments
  std::stringstream oss;
  rt::add_arg(oss, *da->cu());
@@ -187,7 +183,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
  rt::add_arg(oss, ldc);
  rt::add_arg(oss, *dlocks->cu());
  // function
  rt::function function(src::dot, opts, device);
  rt::function function(src::dot, opt, device);
  // std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
  // grid
  auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };