[GENERAL] Various bugfixes

This commit is contained in:
Philippe Tillet
2020-11-11 14:44:56 -05:00
committed by Philippe Tillet
parent 50587bbf4b
commit 8f8d36c7a4
11 changed files with 103 additions and 59 deletions

View File

@@ -362,6 +362,30 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
} }
void generator::visit_masked_load_inst(ir::masked_load_inst* x) { void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
if(!x->get_type()->is_tile_ty()){
Value *ptr = vmap_.at(x->get_pointer_operand());
Value *mask = vmap_.at(x->get_mask_operand());
BasicBlock *current_bb = builder_->GetInsertBlock();
Function *parent = builder_->GetInsertBlock()->getParent();
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
builder_->CreateCondBr(mask, mask_then_bb, mask_done_bb);
builder_->SetInsertPoint(mask_then_bb);
Value *result_then = builder_->CreateLoad(ptr);
builder_->CreateBr(mask_done_bb);
builder_->SetInsertPoint(mask_done_bb);
Value *result = nullptr;
if(x->get_false_value_operand()){
Value *result_false = vmap_.at(x->get_false_value_operand());
result = builder_->CreatePHI(result_then->getType(), 2);
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
((PHINode*)result)->addIncoming(result_false, current_bb);
}
else
result = result_then;
vmap_[x] = result;
return;
}
// find vector size // find vector size
ir::value *ptr = x->get_pointer_operand(); ir::value *ptr = x->get_pointer_operand();
auto order = layouts_->get(ptr)->get_order(); auto order = layouts_->get(ptr)->get_order();
@@ -677,6 +701,8 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
} }
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
if(add->get_type()->is_tile_ty()){ if(add->get_type()->is_tile_ty()){
ir::value* ptr = add->get_operand(0); ir::value* ptr = add->get_operand(0);
ir::value* val = add->get_operand(1); ir::value* val = add->get_operand(1);
@@ -684,21 +710,36 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr); distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
distributed_tile* vals = (distributed_tile*)tmap_.at(val); distributed_tile* vals = (distributed_tile*)tmap_.at(val);
distributed_tile* msks = (distributed_tile*)tmap_.at(msk); distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
for_each(ptr, [&](indices_t idx){ for_each(ptr, [&](indices_t idx){
Value *rmw_ptr = ptrs->get_value(idx); Value *rmw_ptr = ptrs->get_value(idx);
Value *rmw_val = vals->get_value(idx); Value *rmw_val = vals->get_value(idx);
Value *rmw_msk = msks->get_value(idx); Value *rmw_msk = msks->get_value(idx);
BasicBlock *current_bb = builder_->GetInsertBlock(); // num bytes
Function *parent = builder_->GetInsertBlock()->getParent(); Type* ty = rmw_val->getType();
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent); size_t nbits = ty->getScalarSizeInBits();
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent); // extract pointer offset
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb); std::string offset = "";
builder_->SetInsertPoint(mask_then_bb); if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val, if(gep->getNumIndices() == 1)
AtomicOrdering::Unordered, if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
SyncScope::System); offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
builder_->CreateBr(mask_done_bb); rmw_ptr = gep->getPointerOperand();
builder_->SetInsertPoint(mask_done_bb); }
rmw_ptr = builder_->CreateBitCast(rmw_ptr, ty->getPointerTo(1));
// asm argument type
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
// asm function type
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
// asm string
std::string mod = nbits == 32 ? "" : ".noftz";
std::string asm_str = "@$0 atom.global.sys.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2" + offset + "], $3;";
std::string ty_id = nbits == 32 ? "f" : "h";
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
// call asm
builder_->CreateCall(iasm, {rmw_msk, rmw_ptr, rmw_val});
}); });
} }
else{ else{
@@ -803,6 +844,7 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i}; indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i};
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end()); idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end()); idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a); Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b); Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++) for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++)

View File

@@ -255,7 +255,6 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
cu_context::context_switcher ctx(*context); cu_context::context_switcher ctx(*context);
// std::cout << source << std::endl;
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096; unsigned int errbufsize = 8096;
@@ -264,10 +263,11 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
try{ try{
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval); dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
}catch(exception::cuda::base const &){ }catch(exception::cuda::base const &){
#ifdef TRITON_LOG_PTX_ERROR //#ifdef TRITON_LOG_PTX_ERROR
std::cerr << "Compilation Failed! Log: " << std::endl; std::cout << source << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
std::cerr << errbuf << std::endl; std::cerr << errbuf << std::endl;
#endif //#endif
throw; throw;
} }
} }

View File

@@ -231,7 +231,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
VisitExpr(condOp->exprFalse_); VisitExpr(condOp->exprFalse_);
ir::value* false_val = ret_; ir::value* false_val = ret_;
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) { if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
if(!false_val->get_type()->is_tile_ty()) if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes()); false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
cond, cond,

View File

@@ -238,8 +238,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
if(allocation.allocated_size() > context->device()->max_shared_memory()) if(allocation.allocated_size() > context->device()->max_shared_memory())
throw std::runtime_error("using too much shared memory"); throw std::runtime_error("using too much shared memory");
barriers.run(module); barriers.run(module);
isel.visit(module, *llvm);
//ir::print(module, std::cout); //ir::print(module, std::cout);
isel.visit(module, *llvm);
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm))); std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
return res; return res;
} }
@@ -364,6 +364,7 @@ std::string function::preheader() {
DECLARATION(float, 64, 64); DECLARATION(float, 64, 64);
DECLARATION(half , 64, 64); DECLARATION(half , 64, 64);
DECLARATION(half , 128, 128);
extern int atomic_cas(int*, int, int); extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int); extern int atomic_xchg(int*, int);

View File

@@ -9,7 +9,7 @@ class _dot(torch.autograd.Function):
float alpha, float alpha,
int M __retune, int M __retune,
int N __retune, int N __retune,
int K __retune, int K __retune __multipleof(16),
int lda __multipleof(8), int lda __multipleof(8),
int ldb __multipleof(8), int ldb __multipleof(8),
int ldc __multipleof(8)) { int ldc __multipleof(8)) {
@@ -95,11 +95,12 @@ class _dot(torch.autograd.Function):
if dtype not in _dot.kernel: if dtype not in _dot.kernel:
defines = { defines = {
'TYPE' : dtype, 'TYPE' : dtype,
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
'STRIDE_AM': 'lda', 'STRIDE_AK': '1', 'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb', 'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [64, 128], 'TM' : [128],
'TN' : [64, 128], 'TN' : [128],
'TK' : [8, 16], 'TK' : [16],
'TZ' : [1] 'TZ' : [1]
} }
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines) _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
@@ -120,7 +121,7 @@ dot = _dot.apply
torch.manual_seed(0) torch.manual_seed(0)
M, N, K = 2048, 2048, 2048 M, N, K = 4096, 4096, 4096
a = torch.rand((M, K)).cuda().half() a = torch.rand((M, K)).cuda().half()
b = torch.rand((K, N)).cuda().half() b = torch.rand((K, N)).cuda().half()
@@ -130,4 +131,5 @@ b = torch.rand((K, N)).cuda().half()
zc = torch.matmul(a,b) zc = torch.matmul(a,b)
zc_ = dot(a,b) zc_ = dot(a,b)
print(torch.allclose(zc, zc_)) print(torch.allclose(zc, zc_))

View File

@@ -51,11 +51,6 @@ std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt)
return id_fn_map[key]->ptx(&stream, opt); return id_fn_map[key]->ptx(&stream, opt);
} }
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
pybind11::buffer_info info = data.request();
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
}
void cleanup() { void cleanup() {
id_grid_map.clear(); id_grid_map.clear();
id_fn_map.clear(); id_fn_map.clear();
@@ -134,7 +129,6 @@ PYBIND11_MODULE(libtriton, m) {
m.def("register_grid", &register_grid); m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid); m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn); m.def("register_fn", &register_fn);
m.def("register_cst", &register_cst);
m.def("delete_fn", &delete_fn); m.def("delete_fn", &delete_fn);
m.def("make_op_id", &make_op_id); m.def("make_op_id", &make_op_id);
m.def("cleanup", &cleanup); m.def("cleanup", &cleanup);

View File

@@ -31,19 +31,25 @@ CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream(); return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
} }
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
for(size_t n = 0; n < constant_names.size(); n++){
const torch::Tensor& x = constant_vals[n];
fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
}
if(dev_id == -1){ if(dev_id == -1){
if(!host_stream){ if(!host_stream){
host_device.reset(new drv::host_device()); host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device)); host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(&*host_context)); host_stream.reset(drv::stream::create(&*host_context));
} }
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream); (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
} }
else{ else{
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false); triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
triton::driver::context* ctx = stream.context(); triton::driver::context* ctx = stream.context();
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream); (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
} }
} }

View File

@@ -63,9 +63,6 @@ class kernel:
size = sum([sizes[x] for x in arg_types]) size = sum([sizes[x] for x in arg_types])
self.tys = ''.join([codes[x] for x in arg_types]) self.tys = ''.join([codes[x] for x in arg_types])
def set_constant(self, device, name, value):
libtriton.register_cst((self.op_id, device), name, value)
def ptx(self, device, **kwargs): def ptx(self, device, **kwargs):
dev_id = device.index dev_id = device.index
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt) libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
@@ -104,4 +101,6 @@ class kernel:
pass pass
# launch # launch
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args]) params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
torch.ops.triton.launch_kernel(self.op_id, device, params) names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)

View File

@@ -9,7 +9,7 @@ int main() {
// shapes to benchmark // shapes to benchmark
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t; typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
std::vector<config_t> configs; std::vector<config_t> configs;
for(auto ord: std::vector<std::vector<int>>{{0, 1}}) for(auto ord: std::vector<std::vector<int>>{{1, 0}})
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){ for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
std::vector<config_t> tmp = { std::vector<config_t> tmp = {
// config_t{ord, x[0], x[1], 128, 128, 128}, // config_t{ord, x[0], x[1], 128, 128, 128},
@@ -21,7 +21,7 @@ int main() {
// config_t{ord, x[0], x[1], 1280, 1280, 1280}, // config_t{ord, x[0], x[1], 1280, 1280, 1280},
// config_t{ord, x[0], x[1], 1536, 1536, 1536}, // config_t{ord, x[0], x[1], 1536, 1536, 1536},
// config_t{ord, x[0], x[1], 2048, 2048, 2048}, // config_t{ord, x[0], x[1], 2048, 2048, 2048},
config_t{ord, x[0], x[1], 8192, 8192, 8192}, config_t{ord, x[0], x[1], 4096, 4096, 4096},
// config_t{ord, x[0], x[1], 256, 16, 256}, // config_t{ord, x[0], x[1], 256, 16, 256},
// config_t{ord, x[0], x[1], 512, 16, 512}, // config_t{ord, x[0], x[1], 512, 16, 512},

View File

@@ -147,7 +147,7 @@ inline cublasGemmAlgo_t cublasGemmFastest(
M, N, K, M, N, K,
alpha, (const void*)A, cudt, lda, alpha, (const void*)A, cudt, lda,
(const void*)B, cudt, ldb, (const void*)B, cudt, ldb,
beta, (void*)C, cudt, ldc, cudt, beta, (void*)C, cudt, ldc, CUDA_R_32F,
a); }, stream); a); }, stream);
if(status != CUBLAS_STATUS_SUCCESS) if(status != CUBLAS_STATUS_SUCCESS)
nanosec = INFINITY; nanosec = INFINITY;
@@ -216,6 +216,6 @@ inline void cublasGemm(cublasDataType_t dtype,
cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K, cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K,
alpha, (const void*)*A->cu(), dtype, lda, alpha, (const void*)*A->cu(), dtype, lda,
(const void*)*B->cu(), dtype, ldb, (const void*)*B->cu(), dtype, ldb,
beta, (void*)*C->cu(), dtype, ldc, dtype, algo); beta, (void*)*C->cu(), dtype, ldc, CUDA_R_32F, algo);
} }
} }

View File

@@ -152,16 +152,16 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
bench.push_back(tflops(triton_ns)); bench.push_back(tflops(triton_ns));
// cublas // cublas
// if(cublas::cublasinit()){ if(cublas::cublasinit()){
// T alpha(static_cast<double>(1)); T alpha(static_cast<double>(1));
// T beta(static_cast<double>(0)); T beta(static_cast<double>(0));
// cublasGemmAlgo_t fastest; cublasGemmAlgo_t fastest;
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc, &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
// ldc, nullptr, fastest); }, stream); ldc, nullptr, fastest); }, stream);
// bench.push_back(tflops(cublas_ms)); bench.push_back(tflops(cublas_ms));
// } }
} }
// test triton // test triton