diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index dd9f0592c..7e580251c 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -362,6 +362,30 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
 }
 
 void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
+  if(!x->get_type()->is_tile_ty()){
+    Value *ptr = vmap_.at(x->get_pointer_operand());
+    Value *mask = vmap_.at(x->get_mask_operand());
+    BasicBlock *current_bb = builder_->GetInsertBlock();
+    Function *parent = builder_->GetInsertBlock()->getParent();
+    BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
+    BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
+    builder_->CreateCondBr(mask, mask_then_bb, mask_done_bb);
+    builder_->SetInsertPoint(mask_then_bb);
+    Value *result_then = builder_->CreateLoad(ptr);
+    builder_->CreateBr(mask_done_bb);
+    builder_->SetInsertPoint(mask_done_bb);
+    Value *result = nullptr;
+    if(x->get_false_value_operand()){
+      Value *result_false = vmap_.at(x->get_false_value_operand());
+      result = builder_->CreatePHI(result_then->getType(), 2);
+      ((PHINode*)result)->addIncoming(result_then, mask_then_bb);
+      ((PHINode*)result)->addIncoming(result_false, current_bb);
+    }
+    else
+      result = result_then;
+    vmap_[x] = result;
+    return;
+  }
   // find vector size
   ir::value *ptr = x->get_pointer_operand();
   auto order = layouts_->get(ptr)->get_order();
@@ -677,6 +701,8 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
 }
 
 void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
+
+  if(add->get_type()->is_tile_ty()){
   ir::value* ptr = add->get_operand(0);
   ir::value* val = add->get_operand(1);
@@ -684,21 +710,36 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
   distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
   distributed_tile* vals = (distributed_tile*)tmap_.at(val);
   distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
+
   for_each(ptr, [&](indices_t idx){
     Value *rmw_ptr = ptrs->get_value(idx);
     Value *rmw_val = vals->get_value(idx);
     Value *rmw_msk = msks->get_value(idx);
-    BasicBlock *current_bb = builder_->GetInsertBlock();
-    Function *parent = builder_->GetInsertBlock()->getParent();
-    BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
-    BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
-    builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
-    builder_->SetInsertPoint(mask_then_bb);
-    builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
-                              AtomicOrdering::Unordered,
-                              SyncScope::System);
-    builder_->CreateBr(mask_done_bb);
-    builder_->SetInsertPoint(mask_done_bb);
+    // num bytes
+    Type* ty = rmw_val->getType();
+    size_t nbits = ty->getScalarSizeInBits();
+    // extract pointer offset
+    std::string offset = "";
+    if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
+      if(gep->getNumIndices() == 1)
+        if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
+          offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
+          rmw_ptr = gep->getPointerOperand();
+        }
+    rmw_ptr = builder_->CreateBitCast(rmw_ptr, ty->getPointerTo(1));
+    // asm argument type
+    std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
+    // asm function type
+    FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
+    // asm string
+    std::string mod = nbits == 32 ? "" : ".noftz";
"" : ".noftz"; + std::string asm_str = "@$0 atom.global.sys.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2" + offset + "], $3;"; + std::string ty_id = nbits == 32 ? "f" : "h"; + std::string constraint = "b,=" + ty_id + ",l," + ty_id; + // create inline asm + InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); + // call asm + builder_->CreateCall(iasm, {rmw_msk, rmw_ptr, rmw_val}); }); } else{ @@ -803,6 +844,7 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile * indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i}; idx_a.insert(idx_a.end(), x.first.begin(), x.first.end()); idx_b.insert(idx_b.end(), x.first.begin(), x.first.end()); + Value *ha = TA->get_value(idx_a); Value *hb = TB->get_value(idx_b); for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++) diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 526f93c8d..c3eeb5f27 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -255,7 +255,6 @@ cu_module::cu_module(driver::context * context, std::unique_ptr ll cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_context::context_switcher ctx(*context); - // std::cout << source << std::endl; // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; unsigned int errbufsize = 8096; @@ -264,10 +263,11 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo try{ dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval); }catch(exception::cuda::base const &){ -#ifdef TRITON_LOG_PTX_ERROR - std::cerr << "Compilation Failed! Log: " << std::endl; +//#ifdef TRITON_LOG_PTX_ERROR + std::cout << source << std::endl; + std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; std::cerr << errbuf << std::endl; -#endif +//#endif throw; } } diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 94be4e639..a05a7a123 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -231,7 +231,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) { VisitExpr(condOp->exprFalse_); ir::value* false_val = ret_; if(ir::unmasked_load_inst* ld = dynamic_cast(true_val)) { - if(!false_val->get_type()->is_tile_ty()) + if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty()) false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes()); ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 2e6bcfc2c..66d8723dc 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -238,8 +238,8 @@ std::unique_ptr function::make_bin(ir::module &module, if(allocation.allocated_size() > context->device()->max_shared_memory()) throw std::runtime_error("using too much shared memory"); barriers.run(module); + //ir::print(module, std::cout); isel.visit(module, *llvm); - // ir::print(module, std::cout); std::unique_ptr res(driver::module::create(context, std::move(llvm))); return res; } @@ -364,6 +364,7 @@ std::string function::preheader() { DECLARATION(float, 64, 64); DECLARATION(half , 64, 64); +DECLARATION(half , 128, 128); extern int atomic_cas(int*, int, int); extern int atomic_xchg(int*, int); diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/mat_mul.py index 6bd3a1495..4acbebb11 100644 --- a/python/examples/tutorials/mat_mul.py +++ b/python/examples/tutorials/mat_mul.py 
@@ -3,16 +3,16 @@ import triton
 
 class _dot(torch.autograd.Function):
     src = """
-    __global__ void dot(TYPE *A __noalias __readonly __aligned(16),
-                        TYPE *B __noalias __readonly __aligned(16),
-                        TYPE *C __noalias __aligned(16),
-                        float alpha,
-                        int M __retune,
-                        int N __retune,
-                        int K __retune,
-                        int lda __multipleof(8),
-                        int ldb __multipleof(8),
-                        int ldc __multipleof(8)) {
+__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
+                    TYPE * B __noalias __readonly __aligned(16),
+                    TYPE * C __noalias __aligned(16),
+                    float alpha,
+                    int M __retune,
+                    int N __retune,
+                    int K __retune __multipleof(16),
+                    int lda __multipleof(8),
+                    int ldb __multipleof(8),
+                    int ldc __multipleof(8)) {
       // prologue
       int ridx = get_program_id(0);
       int ridy = get_program_id(1);
@@ -95,11 +95,12 @@ class _dot(torch.autograd.Function):
         if dtype not in _dot.kernel:
             defines = {
                 'TYPE' : dtype,
+                'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
                 'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
                 'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
-                'TM'   : [64, 128],
-                'TN'   : [64, 128],
-                'TK'   : [8, 16],
+                'TM'   : [128],
+                'TN'   : [128],
+                'TK'   : [16],
                 'TZ'   : [1]
             }
             _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
@@ -120,7 +121,7 @@ dot = _dot.apply
 torch.manual_seed(0)
 
-M, N, K = 2048, 2048, 2048
+M, N, K = 4096, 4096, 4096
 a = torch.rand((M, K)).cuda().half()
 b = torch.rand((K, N)).cuda().half()
 
@@ -130,4 +131,5 @@ b = torch.rand((K, N)).cuda().half()
 
 zc = torch.matmul(a,b)
 zc_ = dot(a,b)
+
 print(torch.allclose(zc, zc_))
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index 0fcae9d31..040bcaa7e 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -51,11 +51,6 @@ std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt)
   return id_fn_map[key]->ptx(&stream, opt);
 }
 
-void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
-  pybind11::buffer_info info = data.request();
-  id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
-}
-
 void cleanup() {
   id_grid_map.clear();
   id_fn_map.clear();
@@ -134,7 +129,6 @@ PYBIND11_MODULE(libtriton, m) {
     m.def("register_grid", &register_grid);
     m.def("delete_grid", &delete_grid);
     m.def("register_fn", &register_fn);
-    m.def("register_cst", &register_cst);
     m.def("delete_fn", &delete_fn);
     m.def("make_op_id", &make_op_id);
     m.def("cleanup", &cleanup);
diff --git a/python/src/launch.cc b/python/src/launch.cc
index 995f9e1ea..3883aba46 100644
--- a/python/src/launch.cc
+++ b/python/src/launch.cc
@@ -31,19 +31,25 @@ CUstream torch_get_cuda_stream(int64_t dev_id) {
   return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
 }
 
-void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
+void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
+                   const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
+  rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
+  for(size_t n = 0; n < constant_names.size(); n++){
+    const torch::Tensor& x = constant_vals[n];
+    fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
+  }
   if(dev_id == -1){
     if(!host_stream){
       host_device.reset(new drv::host_device());
       host_context.reset(drv::context::create(&*host_device));
       host_stream.reset(drv::stream::create(&*host_context));
     }
-    (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
+    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
   }
   else{
     triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
     triton::driver::context* ctx = stream.context();
-    (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
+    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
   }
 }
 
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index e37f82340..e4f83aa41 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -63,9 +63,6 @@ class kernel:
     size = sum([sizes[x] for x in arg_types])
     self.tys = ''.join([codes[x] for x in arg_types])
 
-  def set_constant(self, device, name, value):
-    libtriton.register_cst((self.op_id, device), name, value)
-
   def ptx(self, device, **kwargs):
     dev_id = device.index
     libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
@@ -103,5 +100,7 @@ class kernel:
     if 'autotune_buf' in kwargs:
       pass
     # launch
-    params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
-    torch.ops.triton.launch_kernel(self.op_id, device, params)
\ No newline at end of file
+    params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
+    names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
+    constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
+    torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
\ No newline at end of file
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index 7204e288b..6ec66ecff 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -9,7 +9,7 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(auto ord: std::vector<std::vector<int>>{{0, 1}})
+  for(auto ord: std::vector<std::vector<int>>{{1, 0}})
   for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
     std::vector<config_t> tmp = {
//      config_t{ord, x[0], x[1], 128, 128, 128},
@@ -21,7 +21,7 @@ int main() {
//      config_t{ord, x[0], x[1], 1280, 1280, 1280},
//      config_t{ord, x[0], x[1], 1536, 1536, 1536},
//      config_t{ord, x[0], x[1], 2048, 2048, 2048},
-      config_t{ord, x[0], x[1], 8192, 8192, 8192},
+      config_t{ord, x[0], x[1], 4096, 4096, 4096},
 
//      config_t{ord, x[0], x[1], 256, 16, 256},
//      config_t{ord, x[0], x[1], 512, 16, 512},
diff --git a/tests/common/cuda/cublas.h b/tests/common/cuda/cublas.h
index db1f2a360..1d403c413 100644
--- a/tests/common/cuda/cublas.h
+++ b/tests/common/cuda/cublas.h
@@ -147,7 +147,7 @@ inline cublasGemmAlgo_t cublasGemmFastest(
                                          M, N, K,
                                          alpha, (const void*)A, cudt, lda,
                                          (const void*)B, cudt, ldb,
-                                         beta, (void*)C, cudt, ldc, cudt,
+                                         beta, (void*)C, cudt, ldc, CUDA_R_32F,
                                          a); }, stream);
     if(status != CUBLAS_STATUS_SUCCESS)
       nanosec = INFINITY;
@@ -216,6 +216,6 @@ inline void cublasGemm(cublasDataType_t dtype,
   cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K,
                                                alpha, (const void*)*A->cu(), dtype, lda,
                                                (const void*)*B->cu(), dtype, ldb,
-                                               beta, (void*)*C->cu(), dtype, ldc, dtype, algo);
+                                               beta, (void*)*C->cu(), dtype, ldc, CUDA_R_32F, algo);
 }
 }
diff --git a/tests/common/dot.h b/tests/common/dot.h
index d4bafa22b..6d46add14 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -152,16 +152,16 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
   bench.push_back(tflops(triton_ns));
 
   // cublas
-//  if(cublas::cublasinit()){
-//    T alpha(static_cast<double>(1));
-//    T beta(static_cast<double>(0));
-//    cublasGemmAlgo_t fastest;
-//    cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-//    double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
-//                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
-//                                                               ldc, nullptr, fastest); }, stream);
-//    bench.push_back(tflops(cublas_ms));
-//  }
+  if(cublas::cublasinit()){
+    T alpha(static_cast<double>(1));
+    T beta(static_cast<double>(0));
+    cublasGemmAlgo_t fastest;
+    cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
+    double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
+                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+                                                               ldc, nullptr, fastest); }, stream);
+    bench.push_back(tflops(cublas_ms));
+  }
 }
 
 // test triton