Improvements w/ Auto-Tuning and standard benchmarks (#57)

[PYTHON] Bug-fixes in the auto-tuning module and improvement of the existing API for it

Committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
@@ -54,19 +54,13 @@ enum asm_mode_t {
   ASM_NV_SASS
 };
 
-struct options_space_t {
-  typedef std::pair<std::string, std::vector<std::string>> define_t;
-  std::vector<define_t> defines;
-  std::vector<int> num_warps;
-};
-
 struct options_t {
   template<class T>
   T D(const std::string& name) const {
     return convert<T>(defines.at(name));
   }
   std::unordered_map<std::string, std::string> defines;
-  size_t num_warps;
+  int num_warps;
 };
 
 
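
For illustration (not part of the diff): options_space_t, which held lists of candidate values per macro plus a list of warp counts, is removed; options_t now describes exactly one compilation configuration. A minimal sketch of what such a configuration looks like from Python, assuming the `libtriton.options` binding registered further down in this commit:

    # Illustration only: one compilation configuration, as exposed to Python
    # by the `options` binding added later in this commit.
    import libtriton                 # assumption: the pybind11 module built from this code
    opt = libtriton.options()
    opt.defines   = {'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}  # one value per macro
    opt.num_warps = 4                # a plain int, no longer a list of candidates
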
@@ -111,12 +105,14 @@ public:
   typedef std::function<grid_t(const options_t&)> grid_fn_ty;
   typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
   typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
+  typedef std::vector<std::pair<std::map<std::string, std::string>, int>> autotune_vals_t;
 
 private:
   static void do_loop_nest(std::vector<size_t> const & ranges,
                            std::function<void(std::vector<size_t> const &)> const & f);
 public:
-  function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
+  function(const std::string& src, const options_t& opt, driver::device *device,
+           const autotune_vals_t& autotune_vals = {}, const std::vector<std::string> &autotune_key = {});
   void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
   void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
   // auto-tuning
@@ -126,7 +122,7 @@ public:
   const std::vector<kernel_pair_t> get_kernels() { return kernels_; }
 
 private:
-  void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device);
+  void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device);
 
 private:
   std::vector<kernel_pair_t> kernels_;
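
For illustration (not part of the diff): autotune_vals_t is simply a list of candidate configurations, each a (defines-map, num_warps) pair, and `function` takes a base options_t plus such a list. The Python-side shape of that value mirrors the `_DEFAULT_CONFIGS` list added to the dense matmul op later in this commit:

    # Illustration only: the Python-side shape of an `autotune_vals_t`
    autotune_vals = [
        ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),   # (extra #defines, num_warps)
        ({'TM': '64',  'TN': '64',  'TK': '64', 'TZ': '1'}, 2),
    ]
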
@@ -224,7 +224,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size);
+  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
 }
 
 std::string kernel::get_asm(asm_mode_t mode) {
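
For illustration (not part of the diff): the explicit (size_t) cast is needed because num_warps is now an int; the launch geometry itself is unchanged, one program per grid cell with num_warps warps of 32 threads each:

    # Illustration only: relation between num_warps and the CUDA block size used above
    num_warps = 4
    threads_per_block = num_warps * 32   # 32 threads per warp -> 128 threads here
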
@@ -282,35 +282,35 @@ void function::do_loop_nest(std::vector<size_t> const & ranges,
         return;
       values[i--] = 0;
     }
-    i = D - 1;
+    i = D - 1; options_t opt;
 
   }
 }
 
 
-void function::init_kernels(const std::string& src, const options_space_t& opts, driver::device *device) {
-  // all ranges
-  std::vector<size_t> ranges;
-  ranges.push_back(opts.num_warps.size());
-  for(const auto& x: opts.defines)
-    ranges.push_back(x.second.size());
-  // functor for source with given option
+void function::init_kernels(const std::string& src, const options_t& opt,
+                            const autotune_vals_t& confs, driver::device *device) {
+  // list of all possible configs
+  // just augment `opt` with each define of `confs`
+  // and override warp count
+  size_t num_opts = std::max(confs.size(), (size_t)1);
+  std::vector<options_t> opts(num_opts, opt);
+  for(size_t i = 0; i < confs.size(); i++){
+    opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end());
+    opts[i].num_warps = confs[i].second;
+  }
+  // compile all possible configs
+  // compilation errors (e.g., too much shared mem)
+  // will populate `err`
   std::vector<std::pair<options_t, std::string>> err;
-  auto do_make = [&](std::vector<size_t> params) {
-    // compilation options
-    unsigned i = 0;
-    options_t opt;
-    opt.num_warps = opts.num_warps[params[i++]];
-    for(auto D: opts.defines)
-      opt.defines[D.first] = D.second[params[i++]];
-    // compile
+  for(const options_t& opt: opts) {
     try{
      kernels_.push_back({opt, std::make_shared<kernel>(src, opt, device)});
    }catch(const exception::base& e){
      err.push_back({opt, e.what()});
    }
-  };
-  // multi-threaded compilation
-  do_loop_nest(ranges, do_make);
+  }
+  // throw an exception if `err` is not empty
  if(kernels_.empty()){
    std::ostringstream dbg;
    dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
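
For illustration (not part of the diff): the new init_kernels no longer enumerates a cartesian product through do_loop_nest. It duplicates the base options once per candidate config, merges that config's defines into the copy, overrides its warp count, and compiles every copy. A rough Python rendering of the same expansion step (names are made up here; the C++ code uses std::unordered_map::insert, which keeps the base value on a key collision, so the sketch assumes the two macro sets are disjoint):

    # Illustration only: the config-expansion performed by the new init_kernels
    def expand_configs(base_defines, base_num_warps, autotune_vals):
        if not autotune_vals:                           # no configs: compile the base options as-is
            return [(dict(base_defines), base_num_warps)]
        expanded = []
        for conf_defines, conf_num_warps in autotune_vals:
            defines = dict(base_defines)
            defines.update(conf_defines)                # augment the base defines with this config
            expanded.append((defines, conf_num_warps))  # and override the warp count
        return expanded
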
@@ -357,9 +357,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
     return it->second;
 }
 
-function::function(const std::string& src, const options_space_t& opt,
-                   driver::device *device, const std::vector<std::string>& autotune_key) {
-  init_kernels(src, opt, device);
+function::function(const std::string& src, const options_t &opt, driver::device *device,
+                   const autotune_vals_t& autotune_vals, const std::vector<std::string>& autotune_key) {
+  // pre-compile all kernels
+  init_kernels(src, opt, autotune_vals, device);
+  // find indices of autotune keys
   auto arg_names = kernels_.at(0).second->get_arg_names();
   for(const std::string& name: autotune_key){
     auto it = std::find(arg_names.begin(), arg_names.end(), name);
@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
 void register_fn(int op_id,
                  int dev_id,
                  const std::string& src,
-                 const rt::options_space_t& opt,
+                 const rt::options_t& opt,
+                 const rt::function::autotune_vals_t& autotune_vals,
                  const std::vector<std::string>& autotune_key) {
   if(tt_devices.find(dev_id) == tt_devices.end()) {
     driver::device* device;
@@ -62,7 +63,7 @@ void register_fn(int op_id,
     tt_streams[dev_id].reset(stream);
   }
   if(id_fn_map.find(op_id) == id_fn_map.end()){
-    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
   }
   for(const auto& k: id_fn_map[op_id]->get_kernels()){
     const rt::options_t* opt = &k.first;
@@ -197,13 +198,9 @@ PYBIND11_MODULE(libtriton, m) {
       .value("sass", rt::ASM_NV_SASS);
 
   pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
-      .def_readwrite("num_warps", &rt::options_t::num_warps)
-      .def_readwrite("defines" , &rt::options_t::defines);
-
-  pybind11::class_<rt::options_space_t>(m, "options_space")
       .def(pybind11::init<>())
-      .def_readwrite("num_warps", &rt::options_space_t::num_warps)
-      .def_readwrite("defines" , &rt::options_space_t::defines);
+      .def_readwrite("defines" , &rt::options_t::defines)
+      .def_readwrite("num_warps", &rt::options_t::num_warps);
 
   // hooks into triton constructs since frameworks may not use pybind11
   m.def("extract_kernels", &extract_kernels);
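
For illustration (not part of the diff): with the options_space binding gone, Python builds a single options object and passes the candidate configurations separately. A minimal sketch of the registration flow through the bindings above (register_fn, options, make_op_id), with placeholder source and device id:

    # Illustration only: registering a Triton function through the updated bindings
    import libtriton                                    # assumption: the module built above

    src = '...'                                         # Triton-C source string (placeholder)
    opt = libtriton.options()
    opt.defines   = {'TYPE': 'float', 'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}
    opt.num_warps = 4

    autotune_vals = [({'TM': '64', 'TN': '64'}, 2)]     # optional candidate configs
    autotune_key  = ['M', 'N', 'K']                     # re-tune when these arguments change

    op_id  = libtriton.make_op_id()
    dev_id = 0                                          # assumption: CUDA device index
    libtriton.register_fn(op_id, dev_id, src, opt, autotune_vals, autotune_key)
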
@@ -15,6 +15,12 @@ def mask_tensor(x, mask, block, value = 0):
             ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
     return ret
 
+
+## -----------------------------------------------------------------------------
+## Unit Tests
+## -----------------------------------------------------------------------------
+
+
 @pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
                          [
                              (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
@@ -87,3 +93,68 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
     rtol, atol = {torch.float32: (1e-4, 1e-5),
                   torch.float16: (1e-2, 1e-3)}[DTYPE]
     assert torch.allclose(ry , ty, rtol=rtol, atol=atol)
+
+
+## -----------------------------------------------------------------------------
+## Performance Tests
+## -----------------------------------------------------------------------------
+
+def do_bench(fn, warmup = 10, rep = 50):
+    import torch as th
+    start_event = th.cuda.Event(enable_timing=True)
+    end_event = th.cuda.Event(enable_timing=True)
+    ret = fn()
+    for i in range(warmup):
+        fn()
+    th.cuda.synchronize()
+    start_event.record()
+    for i in range(rep):
+        fn()
+    end_event.record()
+    th.cuda.synchronize()
+    time_ms = start_event.elapsed_time(end_event) / rep
+    return time_ms
+
+def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
+    Z, H = 1, 1
+    K = 512
+    make_layout = {
+        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
+        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
+    }[LAYOUT_MODE]
+    for N in [128, 256, 512, 1024, 2048, 4096]:
+        # create layout
+        M, N, K = N, N, N
+        shape = {'sdd': (M, N),
+                 'dsd': (K, M) if TRANS_A else (M, K),
+                 'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
+        layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
+        # create op
+        op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
+        # inputs
+        a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
+        b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
+        a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
+        b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
+        ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
+        num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
+                     'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
+                     'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
+        triton_tflops = num_flops / ms * 1e3
+
+def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
+    Z, H = 1, 1
+    K = 512
+    make_layout = {
+        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
+        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
+    }[LAYOUT_MODE]
+    for N in [128, 256, 512, 1024, 2048, 4096]:
+        layout = make_layout(H, N//BLOCK, N//BLOCK)
+        a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
+        a = sparsify_tensor(a, layout, BLOCK)
+        op = tt.ops.blocksparse.softmax(layout, BLOCK)
+        ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
+        nbytes = 2 * a.numel() * a.element_size()
+        triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
+        print(triton_gbyps)
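
For illustration (not part of the diff): do_bench times a callable with CUDA events after a warm-up phase and returns milliseconds per call; the perf_* routines above derive TFLOP/s and GB/s from it. A hedged usage sketch on a plain dense matmul:

    # Illustration only: timing an arbitrary CUDA callable with do_bench
    import torch
    a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=50)
    tflops = 2 * 1024**3 / (ms * 1e-3) / 1e12   # 2*M*N*K flops for a square matmul
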
@@ -3,57 +3,58 @@ import itertools
 import triton as tt
 import torch as th
 
-@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
+@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
     [
         # 1 warp
-        (16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
-        (32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
-        (16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
-        (16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
-        (32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
-        (16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
-        (16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
-        (64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
-        (16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
+        (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
+        (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
+        (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
+        (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
+        (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
+        (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
+        (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
+        (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
+        (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
         # 2 warp
-        (64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
-        (32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
-        (64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
-        (32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
-        (128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
-        (32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
+        (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
+        (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
+        (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
+        (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
+        (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
+        (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
         # 4 warp
-        (128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
-        (64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
-        (128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
-        (32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
-        (128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
-        (32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
+        (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
+        (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
+        (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
+        (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
+        (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
+        (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
         # 8 warp
-        (128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
-        (256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
-        (256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
+        (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
+        (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
+        (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
+        # split-k
+        (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
+        (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
+        (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
         # variable input
-        (128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
-        (128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
-        (128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
-        (128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
+        (128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
+        (128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
+        (128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
+        (128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
     ]
     for DTYPE in ['float16']
     for AT in [False, True]
     for BT in [False, True]
 ]))
-def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
+def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
     DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
     th.manual_seed(0)
-    tt.ops._matmul.kernel = dict()
-    tt.ops._matmul.TM = [TM]
-    tt.ops._matmul.TN = [TN]
-    tt.ops._matmul.TK = [TK]
-    tt.ops._matmul.num_warps = [NWARP]
+    tt.ops._matmul._kernels = dict()
+    tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
     if M is None: M = TM
     if N is None: N = TN
-    if K is None: K = TK
+    if K is None: K = TK*TZ
     a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
     b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
     a = a.t() if AT else a
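
For illustration (not part of the diff): each test case now carries a TZ (split-K) factor next to the tile shape and warp count, and the default K becomes TK*TZ so that a split-K configuration still reduces one full TK-wide tile per slice. A small arithmetic sketch of what the new split-k cases exercise:

    # Illustration only: work partitioning for one of the split-k cases above
    TM, TN, TK, TZ, NWARP = 64, 64, 16, 4, 4
    M, N, K = TM, TN, TK * TZ        # default sizes when M/N/K are None
    assert K // TZ == TK             # each of the TZ grid slices reduces one TK-wide chunk;
                                     # partial results are combined by the spin-lock path
                                     # in the Triton-C matmul source further down
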
@@ -81,13 +82,13 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
     return time_ms
 
 
-def perf_op(dtype=th.float16, warmup=10, rep=50):
+def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50):
     import pandas as pd
+    import matplotlib.pyplot as plt
     import os
-    AT, BT = False, False
     has_cutlass = 'CUTLASS_PROFILER' in os.environ
-    df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
-    Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
+    df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS'])
+    Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144]
     configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
     for AT, BT, M, N, K in configs:
         a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
@@ -120,6 +121,10 @@ def perf_op(dtype=th.float16, warmup=10, rep=50):
             cutlass_tflops = max(df_c['GFLOPs'])/1e3
         else:
             cutlass_tflops = None
-        df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
-    pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
-    print(df)
+        df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
+    # name
+    AT = {True: 'T', False: 'N'}[AT]
+    BT = {True: 'T', False: 'N'}[BT]
+    name = f'{AT}{BT}'
+    df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green'])
+    plt.savefig(f'matmul-{mode}-{name}.pdf')
@@ -26,10 +26,8 @@ def th_to_triton(obj):
         torch.float64: 'double'
     }
     if isinstance(obj, torch.dtype):
-        return [tys[obj]]
-    if isinstance(obj, list):
-        return [th_to_triton(x)[0] for x in obj]
-    return [str(obj)]
+        return tys[obj]
+    return str(obj)
 
 def cdiv(a, b):
     return libtriton.cdiv(a, b)
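
For illustration (not part of the diff): th_to_triton now returns a single string rather than a list of candidate values, since each defines entry maps to exactly one value:

    # Illustration only: behaviour of the simplified helper above
    import torch
    th_to_triton(torch.float64)   # -> 'double'  (looked up in the tys table)
    th_to_triton(64)              # -> '64'      (anything else falls through to str(obj))
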
@@ -45,17 +43,15 @@ def read(path, kernel_names=[]):
         source = libtriton.extract_kernels(source, kernel_names)
     return source
 
-
-
 class kernel:
 
-    def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []):
+    def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []):
         # check if src is empty
         if src == '':
             raise ValueError('Kernel source code is empty')
         self.src = src
-        self.opt = libtriton.options_space()
-        self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
+        self.opt = libtriton.options()
+        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
         self.opt.num_warps = num_warps
         # device
         assert device.type in ['cuda', 'cpu']
@@ -65,7 +61,7 @@ class kernel:
             self.device = -1
         # C++ function wrapper
         self.op_id = libtriton.make_op_id()
-        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key)
+        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
        # debug mode
        self.is_debug = 'TRITON_DEBUG' in os.environ
        # signature
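
For illustration (not part of the diff): the Python-level kernel wrapper now takes scalar define values, a single num_warps, and optional autotune_vals / autotune_key lists that it forwards to register_fn. A hedged construction sketch (the source path here is hypothetical):

    # Illustration only: constructing a kernel with the new wrapper signature
    import torch
    import triton

    src = triton.read('my_kernel.c')                # hypothetical Triton-C source file
    k = triton.kernel(src,
                      device=torch.device('cuda:0'),
                      defines={'TYPE': torch.float16, 'TM': 128, 'TN': 128, 'TK': 32},
                      num_warps=4,
                      autotune_vals=[({'TM': '64', 'TN': '64'}, 2)],
                      autotune_key=['M', 'N', 'K'])
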
@@ -81,7 +81,7 @@ class _matmul(torch.autograd.Function):
 
     @staticmethod
     def make_sdd_lut(layout, block, dtype, device):
-        start_width = 64 // block
+        start_width = 128 // block
         superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
         luts, widths, packs = [], [], []
         for size, nnz in superblocks:
@@ -126,22 +126,18 @@ class _matmul(torch.autograd.Function):
         num_lock = 1
         key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
         if key not in _matmul.sdd_cache:
-            F32TK = [8, 16]
-            #F16TK = [16]
-            #F16TK += [32] if is_32_multiple else []
-            #F16TK += [64] if is_64_multiple else []
-            F16TK = [64]
-            TK = {torch.float32: F32TK,
-                  torch.float16: F16TK}[dtype]
-            defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
-                       'TK': TK, 'TYPE': dtype,
+            defines = {'TM': block*pack, 'TN': block*pack,
+                       'TMN': block*block*pack*pack,
+                       'BLOCK': block,
+                       'TK': 32,
+                       'TYPE': dtype,
                        'STRIDE_AM': '1' if trans_a else 'lda',
                        'STRIDE_AK': 'lda' if trans_a else '1',
                        'STRIDE_BN': 'ldb' if trans_b else '1',
                        'STRIDE_BK': '1' if trans_b else 'ldb',
                        'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                        'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
-            _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])
+            _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
 
         kernel = _matmul.sdd_cache[key]
         # create output
@@ -270,9 +266,9 @@ class _matmul(torch.autograd.Function):
         # kernel
         key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
         if key not in _matmul.dds_cache:
-            TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
-            TK = [8] if dtype == torch.float32 else [16]
-            defines = {'TM': TM, 'TN': block, 'TK': TK,
+            defines = {'TM': 128,
+                       'TN': block,
+                       'TK': 16,
                        'BLOCK': block,
                        'TYPE': dtype,
                        'STRIDE_AM': 1 if trans_a else 'lda',
@@ -283,7 +279,7 @@ class _matmul(torch.autograd.Function):
                        'STRIDE_CN': 'ldc' if trans_c else '1',
                        'NAME': 'dds_kernel',
                        'DDS': True}
-            _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
+            _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
         kernel = _matmul.dds_cache[key]
         # output
         CS0 = AS0
@@ -315,9 +311,9 @@ class _matmul(torch.autograd.Function):
         # kernel
         key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
         if key not in _matmul.dsd_cache:
-            TN = [64, 128] if dtype == torch.float32 else [64, 128]
-            TK = [8] if dtype == torch.float32 else [16]
-            defines = {'TM': block, 'TN': TN, 'TK': TK,
+            defines = {'TM': block,
+                       'TN': 128,
+                       'TK': 16,
                        'BLOCK': block,
                        'TYPE': dtype,
                        'STRIDE_AM': 1 if trans_a else block,
@@ -328,7 +324,7 @@ class _matmul(torch.autograd.Function):
                        'STRIDE_CN': 'ldc' if trans_c else '1',
                        'NAME': 'dsd_kernel',
                        'DSD': True}
-            _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
+            _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
         kernel = _matmul.dsd_cache[key]
         # output
         CS0 = BS0
@@ -48,7 +48,7 @@ class _softmax(torch.autograd.Function):
         # just-in-time compile kernel
         key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
         if key not in cache:
-            defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
+            defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block,
                        'INFINITY': {torch.float32: 'F32_INFINITY',
                                     torch.float16: 'F16_INFINITY'}[dtype]}
             if apply_scale:
@@ -63,7 +63,7 @@ class _softmax(torch.autograd.Function):
                 defines['APPLY_ATTN_MASK'] = True
                 if attn_mask_mode == 'mul':
                     defines['ATTN_MASK_MUL'] = True
-            kernel = triton.kernel(src, device=device, defines=defines, num_warps=[num_warps])
+            kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
             cache[key] = kernel
         return cache[key]
 
@@ -29,10 +29,10 @@ class _conv(torch.autograd.Function):
             TK = 16
             defines = {
                 'TYPE' : dtype,
-                'TM' : [32, 64, 128],
-                'TN' : [32, 64, 128],
-                'TK' : [TK],
-                'TZ' : [1],
+                'TM' : 64,
+                'TN' : 64,
+                'TK' : TK,
+                'TZ' : 1,
                 'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
             }
             idx = torch.arange(CI*R*S)
@@ -40,7 +40,7 @@ class _conv(torch.autograd.Function):
             nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
             delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
             delta = delta.type(torch.int32).cuda()
-            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
+            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
         delta, kernel = _conv.kernel[dtype]
         # allocate output
         c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
@@ -83,8 +83,8 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
   *?(checkc) pc = c;
 #else
   // accumulate partial result using spin-locks
-  int *plock = locks + rid;
-  int *pcount = plock + get_num_programs(0) * get_num_programs(1);
+  int *plock = locks + pid;
+  int *pcount = plock + get_num_programs(0);
   for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
   int count = *pcount;
   if(count == 0)
@@ -5,11 +5,21 @@ import os
 class _matmul(torch.autograd.Function):
     src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
 
-    TM = [128]
-    TN = [128]
-    TK = [32]
-    TZ = 1
-    num_warps = [4]
+    _DEFAULT_CONFIGS = [
+        ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
+        ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
+        ({'TM': '128', 'TN': '64' , 'TK': '32', 'TZ': '1'}, 4),
+        ({'TM': '64' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '32' , 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '128', 'TN': '32' , 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '64' , 'TN': '32' , 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+    ]
+    _CONFIGS = _DEFAULT_CONFIGS
 
     @staticmethod
     def largest_pow2_divisor(N):
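
For illustration (not part of the diff): `_DEFAULT_CONFIGS` is the candidate list handed to the autotuner as autotune_vals, and `_CONFIGS` is the indirection that lets callers substitute their own list (the updated matmul test does exactly that). One might narrow or extend the search space before the first call compiles any kernel; a hedged sketch:

    # Illustration only: extending the dense-matmul autotuning space with a custom candidate
    import triton as tt
    tt.ops._matmul._CONFIGS = tt.ops._matmul._DEFAULT_CONFIGS + [
        ({'TM': '256', 'TN': '64', 'TK': '32', 'TZ': '1'}, 8),   # hypothetical extra config
    ]
    tt.ops._matmul._kernels = dict()   # drop any kernels compiled with the old config list
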
@@ -41,7 +51,7 @@ class _matmul(torch.autograd.Function):
         lda_pow2_div = _matmul.largest_pow2_divisor(lda)
         ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
         ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
-        is_tk_div_k = K % 32 == 0
+        is_tk_div_k = K % 64 == 0
         key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
         if key not in _matmul._kernels:
             defines = {
@@ -53,13 +63,10 @@ class _matmul(torch.autograd.Function):
                 'LDA_POW2_DIV': lda_pow2_div,
                 'LDB_POW2_DIV': ldb_pow2_div,
                 'LDC_POW2_DIV': ldc_pow2_div,
-                'TM' : _matmul.TM,
-                'TN' : _matmul.TN,
-                'TK' : _matmul.TK,
-                'TZ' : _matmul.TZ,
                 'IS_TK_DIV_K' : int(is_tk_div_k)
             }
-            _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K'])
+            _matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines,
+                                                  autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K'])
         kernel = _matmul._kernels[key]
         # # locks for split-k
         if device not in _matmul._locks:
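
For illustration (not part of the diff): because autotune_key=['M', 'N', 'K'], the runtime benchmarks every entry of `_CONFIGS` the first time a given problem shape is seen and caches the winning kernel for that shape. A hedged sketch of the resulting behaviour, assuming the module exposes the usual `matmul = _matmul.apply` entry point:

    # Illustration only: autotuning is re-run per distinct (M, N, K) triple
    import torch
    import triton as tt
    a = torch.randn(256, 256, device='cuda', dtype=torch.float16)
    b = torch.randn(256, 256, device='cuda', dtype=torch.float16)
    c = tt.ops.matmul(a, b)     # first (M, N, K): benchmarks all _CONFIGS, caches the winner
    c = tt.ops.matmul(a, b)     # same shape: reuses the cached kernel/config
    a = torch.randn(512, 256, device='cuda', dtype=torch.float16)
    c = tt.ops.matmul(a, b)     # new shape: triggers another tuning pass
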
@@ -68,7 +75,7 @@ class _matmul(torch.autograd.Function):
         # enqueue
         alpha = 1.
         args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
-        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1]
+        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
         kernel(*args, grid=grid)
         return c
 
@@ -158,21 +158,17 @@ float triton_dot(drv::context* context, drv::stream* stream,
   stream->write(&*da, true, 0, ha);
   stream->write(&*db, true, 0, hb);
   // macros
-  rt::options_space_t opts;
-  // A access patterns
-  opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda" }});
-  opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1" }});
-  // B access patterns
-  opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1" }});
-  opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb" }});
-  // data-type
-  opts.defines.push_back({"TYPE", {ty}});
-  // tile sizes
-  opts.defines.push_back({"TM", {"128"}});
-  opts.defines.push_back({"TN", {"128"}});
-  opts.defines.push_back({"TK", {"32"}});
-  opts.defines.push_back({"TZ", {"1"}});
-  opts.num_warps = {4};
+  rt::options_t opt;
+  opt.defines["STRIDE_AK"] = AT? "1" : "lda";
+  opt.defines["STRIDE_AM"] = AT? "lda" : "1";
+  opt.defines["STRIDE_BK"] = BT? "ldb" : "1";
+  opt.defines["STRIDE_BN"] = BT? "1" : "ldb";
+  opt.defines["TYPE"] = ty;
+  opt.defines["TM"] = "128";
+  opt.defines["TN"] = "128";
+  opt.defines["TK"] = "32" ;
+  opt.defines["TZ"] = "1";
+  opt.num_warps = 4;
   // arguments
   std::stringstream oss;
   rt::add_arg(oss, *da->cu());
@@ -187,7 +183,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
   rt::add_arg(oss, ldc);
   rt::add_arg(oss, *dlocks->cu());
   // function
-  rt::function function(src::dot, opts, device);
+  rt::function function(src::dot, opt, device);
   // std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
   // grid
   auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };