[RUNTIME] Auto-tuning now works as expected when the values of autotune_key change
Philippe Tillet
2021-01-31 14:17:27 -05:00
parent 52af8cda34
commit 3fde4b8f5b
7 changed files with 53 additions and 15 deletions

View File

@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
 void register_fn(int op_id,
                  int dev_id,
                  const std::string& src,
-                 const rt::options_space_t& opt) {
+                 const rt::options_space_t& opt,
+                 const std::vector<std::string>& autotune_key) {
   if(tt_devices.find(dev_id) == tt_devices.end()) {
     driver::device* device;
     driver::stream* stream;
@@ -61,7 +62,7 @@ void register_fn(int op_id,
     tt_streams[dev_id].reset(stream);
   }
   if(id_fn_map.find(op_id) == id_fn_map.end()){
-    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
  }
  for(const auto& k: id_fn_map[op_id]->get_kernels()){
    const rt::options_t* opt = &k.first;
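
The effect of this change is that the argument names listed in autotune_key are threaded down to rt::function, so the runtime can key its tuning cache on the runtime values of those arguments instead of tuning once and reusing that configuration for every call. A minimal Python sketch of value-keyed autotuning under that reading (autotune, run and candidate_configs are illustrative names, not the actual runtime API):

import time

def autotune(run, args, autotune_key, candidate_configs, _cache={}):
    # Build the cache key from the current values of the arguments named in autotune_key.
    key = tuple(args[name] for name in autotune_key)
    if key not in _cache:
        # First time these values are seen: time every candidate configuration.
        timings = {}
        for config in candidate_configs:
            start = time.perf_counter()
            run(args, config)
            timings[config] = time.perf_counter() - start
        # Remember the fastest configuration for this particular key.
        _cache[key] = min(timings, key=timings.get)
    # Later calls with the same key values reuse the cached configuration.
    return _cache[key]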

View File

@@ -78,21 +78,30 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
     end_event.record()
     th.cuda.synchronize()
     time_ms = start_event.elapsed_time(end_event) / rep
-    return time_ms, flops/time_ms*1e-9, ret
+    return time_ms

 def perf_op(dtype=th.float16, warmup=10, rep=50):
-    AT, BT = False, False
     import pandas as pd
+    AT, BT = False, False
     df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH'])
-    Ns = [128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192]
+    # Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192]
+    Ns = [8192]
     configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
     for AT, BT, M, N, K in configs:
         a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
         b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
         if AT: a = a.t()
         if BT: b = b.t()
-        TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
-        TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
-        df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': TT_TFLOPS, 'TORCH': TH_TFLOPS}, ignore_index=True)
+        # benchmarks
+        torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep)
+        triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep)
+        # store result
+        num_flops = 2*M*N*K
+        torch_tflops = num_flops / torch_ms * 1e-9
+        triton_tflops = num_flops / triton_ms * 1e-9
+        #print(min(alpha*bandwidth*1e-12, max_tflops), triton_tflops)
+        #./tools/profiler/cutlass_profiler --m=8192 --n=8192 --k=8192 --A=f16:column --B=f16:column --C=f16:column --accum=f32 --operation=gemm
+        df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops}, ignore_index=True)
     pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
     print(df)
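
Since do_bench now returns only the mean time per call in milliseconds, callers derive throughput themselves, exactly as perf_op does above. A standalone sketch of that convention (assumes a CUDA device and that it runs in the same module that defines do_bench, with triton imported as tt):

import torch as th
import triton as tt

M = N = K = 8192
a = th.randn((M, K), device='cuda', dtype=th.float16)
b = th.randn((K, N), device='cuda', dtype=th.float16)
# do_bench reports milliseconds per call, averaged over `rep` runs.
triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup=10, rep=50)
# A matmul performs 2*M*N*K floating-point ops; flops / ms * 1e-9 converts to TFLOPS.
print('{:.2f} TFLOPS'.format(2 * M * N * K / triton_ms * 1e-9))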

View File

@@ -49,7 +49,7 @@ def read(path, kernel_names=[]):


 class kernel:
-    def __init__(self, src, device, defines = dict(), num_warps = [4]):
+    def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []):
         # check if src is empty
         if src == '':
             raise ValueError('Kernel source code is empty')
@@ -65,7 +65,7 @@ class kernel:
             self.device = -1
         # C++ function wrapper
         self.op_id = libtriton.make_op_id()
-        libtriton.register_fn(self.op_id, self.device, self.src, self.opt)
+        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key)
         # debug mode
         self.is_debug = 'TRITON_DEBUG' in os.environ
         # signature

View File

@@ -59,7 +59,7 @@ class _matmul(torch.autograd.Function):
                 'TZ'          : _matmul.TZ,
                 'IS_TK_DIV_K' : int(is_tk_div_k)
             }
-            _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines)
+            _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K'])
         kernel = _matmul._kernels[key]
         # # locks for split-k
         if device not in _matmul._locks:
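
Because the matmul kernels are now registered with autotune_key=['M', 'N', 'K'], every distinct shape triple gets its own tuning pass and its own cached configuration. A hedged usage sketch (assumes the op is exposed as triton.ops.matmul, as the benchmark above suggests):

import torch
import triton

a = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)
b = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)
c = triton.ops.matmul(a, b)  # tunes for (M, N, K) = (1024, 1024, 1024) and caches the result

a = torch.randn((4096, 1024), device='cuda', dtype=torch.float16)
c = triton.ops.matmul(a, b)  # (M, N, K) changed, so the kernel is re-tuned instead of reusing the old config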