triton/python/triton/kernel.py

import triton._C.libtriton as libtriton
import os
import time
from struct import pack
import torch

# struct.pack format characters for each Triton argument type; these are used
# to serialize kernel arguments into the parameter buffer passed to libtriton
codes = {
    libtriton.arg_type.int1: 'B',
    libtriton.arg_type.int8: 'B',
    libtriton.arg_type.int32: 'I',
    libtriton.arg_type.int64: 'Q',
    libtriton.arg_type.half: 'H',
    libtriton.arg_type.float: 'f',
    libtriton.arg_type.double: 'd',
    libtriton.arg_type.buffer: 'P'
}


def th_to_triton(obj):
    """Map a torch.dtype to its Triton-C type name; any other object is stringified."""
    tys = {
        torch.int8: 'char',
        torch.int16: 'short',
        torch.int32: 'int',
        torch.int64: 'long',
        torch.float16: 'half',
        torch.float32: 'float',
        torch.float64: 'double'
    }
    if isinstance(obj, torch.dtype):
        return tys[obj]
    return str(obj)


def cdiv(a, b):
    """Ceiling division of a by b."""
    return libtriton.cdiv(a, b)


def synchronize(device):
    """Synchronize the given torch device (-1 when the device carries no index)."""
    dev_id = device.index
    dev_id = -1 if dev_id is None else dev_id
    libtriton.synchronize(dev_id)


def read(path, kernel_names=[]):
    """Read a source file and extract the kernels listed in kernel_names."""
    with open(path, 'r') as f:
        source = f.read()
    source = libtriton.extract_kernels(source, kernel_names)
    return source


class kernel:
    """A Triton-C kernel compiled and launched through libtriton, with optional auto-tuning."""

    def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
        # check if src is empty
        if src == '':
            raise ValueError('Kernel source code is empty')
        self.src = src
        self.opt = libtriton.options()
        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
        self.opt.num_warps = num_warps
        # device
        assert device.type in ['cuda', 'cpu']
        if device.type == 'cuda':
            self.device = torch.cuda.current_device() if device.index is None else device.index
        if device.type == 'cpu':
            self.device = -1
        # C++ function wrapper
        self.op_id = libtriton.make_op_id()
        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
        # debug mode
        self.is_debug = 'TRITON_DEBUG' in os.environ
        # signature
        arg_types = libtriton.get_fn_signature(self.op_id)
        self.tys = ''.join([codes[x] for x in arg_types])

    def __call__(self, *args, grid):
        # debug mode (initialize): run the kernel on fresh copies of the tensor arguments
        if self.is_debug:
            _args = args
            args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
            for i in range(len(args)):
                if isinstance(args[i], torch.Tensor):
                    args[i] = libtriton.cuda_empty_like(args[i])
                    args[i].copy_(_args[i])
        # initialize cuda device if necessary
        libtriton.cuda_set_device(self.device)
        # pack parameters into a byte buffer
        params = pack(self.tys, *args)
        # auto-tune if necessary
        opt = libtriton.autotune(self.op_id, self.device, params, grid)
        # run kernel
        grid = grid(opt)
        grid_0 = grid[0]
        grid_1 = 1 if len(grid) < 2 else grid[1]
        grid_2 = 1 if len(grid) < 3 else grid[2]
        libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
        # debug mode (finalize): copy the results back into the caller's tensors
        if self.is_debug:
            for i in range(len(args)):
                if isinstance(args[i], torch.Tensor):
                    _args[i].copy_(args[i].clone())
            args = _args
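

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a `kernel` instance might be built and
# launched through this module. The source file name ('add.c'), kernel name
# ('add') and the TILE define are hypothetical placeholders, not part of this
# module. Tensor arguments are passed as raw addresses via .data_ptr() because
# the 'P' pack code expects an integer pointer, and the grid callable receives
# the options object chosen by the auto-tuner and must return a 1- to
# 3-dimensional launch grid.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = torch.device('cuda')
    src = read('add.c', kernel_names=['add'])  # hypothetical Triton-C source file
    add = kernel(src, device, defines={'TILE': 1024}, num_warps=4)
    N = 4096
    x = torch.randn(N, device=device)
    y = torch.randn(N, device=device)
    z = torch.empty_like(x)
    # launch on a 1D grid; `opt` carries the meta-parameters picked by the auto-tuner
    add(z.data_ptr(), x.data_ptr(), y.data_ptr(), N,
        grid=lambda opt: [cdiv(N, 1024)])
    synchronize(device)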