[GENERAL] Various bugfixes
This commit is contained in:
committed by
Philippe Tillet
parent
50587bbf4b
commit
8f8d36c7a4
@@ -3,16 +3,16 @@ import triton
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void dot(TYPE *A __noalias __readonly __aligned(16),
|
||||
TYPE *B __noalias __readonly __aligned(16),
|
||||
TYPE *C __noalias __aligned(16),
|
||||
float alpha,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int K __retune,
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8)) {
|
||||
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __aligned(16),
|
||||
float alpha,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int K __retune __multipleof(16),
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
@@ -95,11 +95,12 @@ class _dot(torch.autograd.Function):
|
||||
if dtype not in _dot.kernel:
|
||||
defines = {
|
||||
'TYPE' : dtype,
|
||||
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
|
||||
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
|
||||
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
|
||||
'TM' : [64, 128],
|
||||
'TN' : [64, 128],
|
||||
'TK' : [8, 16],
|
||||
'TM' : [128],
|
||||
'TN' : [128],
|
||||
'TK' : [16],
|
||||
'TZ' : [1]
|
||||
}
|
||||
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
|
||||
@@ -120,7 +121,7 @@ dot = _dot.apply
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
M, N, K = 2048, 2048, 2048
|
||||
M, N, K = 4096, 4096, 4096
|
||||
a = torch.rand((M, K)).cuda().half()
|
||||
b = torch.rand((K, N)).cuda().half()
|
||||
|
||||
@@ -130,4 +131,5 @@ b = torch.rand((K, N)).cuda().half()
|
||||
zc = torch.matmul(a,b)
|
||||
zc_ = dot(a,b)
|
||||
|
||||
|
||||
print(torch.allclose(zc, zc_))
|
||||
|
@@ -51,11 +51,6 @@ std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt)
|
||||
return id_fn_map[key]->ptx(&stream, opt);
|
||||
}
|
||||
|
||||
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
|
||||
pybind11::buffer_info info = data.request();
|
||||
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
|
||||
}
|
||||
|
||||
void cleanup() {
|
||||
id_grid_map.clear();
|
||||
id_fn_map.clear();
|
||||
@@ -134,7 +129,6 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("delete_grid", &delete_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
m.def("register_cst", ®ister_cst);
|
||||
m.def("delete_fn", &delete_fn);
|
||||
m.def("make_op_id", &make_op_id);
|
||||
m.def("cleanup", &cleanup);
|
||||
|
@@ -31,19 +31,25 @@ CUstream torch_get_cuda_stream(int64_t dev_id) {
|
||||
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
}
|
||||
|
||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
|
||||
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
|
||||
for(size_t n = 0; n < constant_names.size(); n++){
|
||||
const torch::Tensor& x = constant_vals[n];
|
||||
fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
|
||||
}
|
||||
if(dev_id == -1){
|
||||
if(!host_stream){
|
||||
host_device.reset(new drv::host_device());
|
||||
host_context.reset(drv::context::create(&*host_device));
|
||||
host_stream.reset(drv::stream::create(&*host_context));
|
||||
}
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
}
|
||||
else{
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -63,9 +63,6 @@ class kernel:
|
||||
size = sum([sizes[x] for x in arg_types])
|
||||
self.tys = ''.join([codes[x] for x in arg_types])
|
||||
|
||||
def set_constant(self, device, name, value):
|
||||
libtriton.register_cst((self.op_id, device), name, value)
|
||||
|
||||
def ptx(self, device, **kwargs):
|
||||
dev_id = device.index
|
||||
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
|
||||
@@ -103,5 +100,7 @@ class kernel:
|
||||
if 'autotune_buf' in kwargs:
|
||||
pass
|
||||
# launch
|
||||
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
|
||||
torch.ops.triton.launch_kernel(self.op_id, device, params)
|
||||
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
|
||||
names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
|
||||
constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
|
||||
torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
|
Reference in New Issue
Block a user