[RUNTIME] Auto-tuning now works as expected when the values of
autotune_key change
This commit is contained in:
@@ -82,6 +82,7 @@ public:
|
||||
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
|
||||
// getters
|
||||
const std::vector<arg_type>& get_sig() const { return sig_; }
|
||||
const std::vector<std::string>& get_arg_names() const { return arg_names_; }
|
||||
std::string get_asm(asm_mode_t mode);
|
||||
|
||||
private:
|
||||
@@ -96,6 +97,7 @@ private:
|
||||
driver::device* dev_;
|
||||
// signature
|
||||
std::vector<arg_type> sig_;
|
||||
std::vector<std::string> arg_names_;
|
||||
// triton context for parsing
|
||||
ir::context ctx_;
|
||||
// handles
|
||||
@@ -114,7 +116,7 @@ private:
|
||||
static void do_loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f);
|
||||
public:
|
||||
function(const std::string& src, const options_space_t& opt, driver::device *device);
|
||||
function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
|
||||
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
|
||||
// auto-tuning
|
||||
@@ -129,6 +131,9 @@ private:
|
||||
private:
|
||||
std::vector<kernel_pair_t> kernels_;
|
||||
std::map<std::vector<uint64_t>, kernel*> cache_;
|
||||
std::vector<int> key_idxs_;
|
||||
std::vector<int> arg_size_;
|
||||
std::vector<int> arg_off_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -211,6 +211,8 @@ kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev
|
||||
init_ir(preheader() + src);
|
||||
init_ker();
|
||||
init_sig();
|
||||
for(auto arg: ir_->get_function_list()[0]->args())
|
||||
arg_names_.push_back(arg->get_name());
|
||||
}
|
||||
|
||||
void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
|
||||
@@ -328,7 +330,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
|
||||
if(kernels_.size() == 1)
|
||||
return &*kernels_.begin()->second;
|
||||
// auto-tuning key
|
||||
std::vector<uint64_t> key;
|
||||
std::vector<uint64_t> key(key_idxs_.size());
|
||||
for(size_t i = 0; i < key.size(); i++){
|
||||
int idx = key_idxs_[i];
|
||||
std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]);
|
||||
}
|
||||
auto it = cache_.find(key);
|
||||
if(it != cache_.end())
|
||||
return it->second;
|
||||
@@ -350,8 +356,26 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
|
||||
return it->second;
|
||||
}
|
||||
|
||||
function::function(const std::string& src, const options_space_t& opt, driver::device *device) {
|
||||
function::function(const std::string& src, const options_space_t& opt,
|
||||
driver::device *device, const std::vector<std::string>& autotune_key) {
|
||||
init_kernels(src, opt, device);
|
||||
auto arg_names = kernels_.at(0).second->get_arg_names();
|
||||
for(const std::string& name: autotune_key){
|
||||
auto it = std::find(arg_names.begin(), arg_names.end(), name);
|
||||
if(it == arg_names.end())
|
||||
throw std::runtime_error(name + " is not a valid argument name");
|
||||
key_idxs_.push_back(std::distance(arg_names.begin(), it));
|
||||
}
|
||||
// argument size and offset
|
||||
auto tys = kernels_.at(0).second->get_sig();
|
||||
size_t curr = 0;
|
||||
for(arg_type ty: tys){
|
||||
arg_size_.push_back(size_of(ty));
|
||||
arg_off_.push_back(curr);
|
||||
curr += arg_size_.back();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
|
@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
|
||||
void register_fn(int op_id,
|
||||
int dev_id,
|
||||
const std::string& src,
|
||||
const rt::options_space_t& opt) {
|
||||
const rt::options_space_t& opt,
|
||||
const std::vector<std::string>& autotune_key) {
|
||||
if(tt_devices.find(dev_id) == tt_devices.end()) {
|
||||
driver::device* device;
|
||||
driver::stream* stream;
|
||||
@@ -61,7 +62,7 @@ void register_fn(int op_id,
|
||||
tt_streams[dev_id].reset(stream);
|
||||
}
|
||||
if(id_fn_map.find(op_id) == id_fn_map.end()){
|
||||
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
|
||||
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
|
||||
}
|
||||
for(const auto& k: id_fn_map[op_id]->get_kernels()){
|
||||
const rt::options_t* opt = &k.first;
|
||||
|
@@ -78,21 +78,30 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
|
||||
end_event.record()
|
||||
th.cuda.synchronize()
|
||||
time_ms = start_event.elapsed_time(end_event) / rep
|
||||
return time_ms, flops/time_ms*1e-9, ret
|
||||
return time_ms
|
||||
|
||||
|
||||
def perf_op(dtype=th.float16, warmup=10, rep=50):
|
||||
AT, BT = False, False
|
||||
import pandas as pd
|
||||
AT, BT = False, False
|
||||
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH'])
|
||||
Ns = [128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192]
|
||||
# Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192]
|
||||
Ns = [8192]
|
||||
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
|
||||
for AT, BT, M, N, K in configs:
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
|
||||
if AT: a = a.t()
|
||||
if BT: b = b.t()
|
||||
TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
|
||||
TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
|
||||
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': TT_TFLOPS, 'TORCH': TH_TFLOPS}, ignore_index=True)
|
||||
# benchmarks
|
||||
torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep)
|
||||
triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep)
|
||||
# store result
|
||||
num_flops = 2*M*N*K
|
||||
torch_tflops = num_flops / torch_ms * 1e-9
|
||||
triton_tflops = num_flops / triton_ms * 1e-9
|
||||
#print(min(alpha*bandwidth*1e-12, max_tflops), triton_tflops)
|
||||
#./tools/profiler/cutlass_profiler --m=8192 --n=8192 --k=8192 --A=f16:column --B=f16:column --C=f16:column --accum=f32 --operation=gemm
|
||||
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops}, ignore_index=True)
|
||||
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
|
||||
print(df)
|
@@ -49,7 +49,7 @@ def read(path, kernel_names=[]):
|
||||
|
||||
class kernel:
|
||||
|
||||
def __init__(self, src, device, defines = dict(), num_warps = [4]):
|
||||
def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []):
|
||||
# check if src is empty
|
||||
if src == '':
|
||||
raise ValueError('Kernel source code is empty')
|
||||
@@ -65,7 +65,7 @@ class kernel:
|
||||
self.device = -1
|
||||
# C++ function wrapper
|
||||
self.op_id = libtriton.make_op_id()
|
||||
libtriton.register_fn(self.op_id, self.device, self.src, self.opt)
|
||||
libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key)
|
||||
# debug mode
|
||||
self.is_debug = 'TRITON_DEBUG' in os.environ
|
||||
# signature
|
||||
|
@@ -59,7 +59,7 @@ class _matmul(torch.autograd.Function):
|
||||
'TZ' : _matmul.TZ,
|
||||
'IS_TK_DIV_K' : int(is_tk_div_k)
|
||||
}
|
||||
_matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines)
|
||||
_matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K'])
|
||||
kernel = _matmul._kernels[key]
|
||||
# # locks for split-k
|
||||
if device not in _matmul._locks:
|
||||
|
@@ -173,7 +173,6 @@ float triton_dot(drv::context* context, drv::stream* stream,
|
||||
opts.defines.push_back({"TK", {"32"}});
|
||||
opts.defines.push_back({"TZ", {"1"}});
|
||||
opts.num_warps = {4};
|
||||
|
||||
// arguments
|
||||
std::stringstream oss;
|
||||
rt::add_arg(oss, *da->cu());
|
||||
|
Reference in New Issue
Block a user