diff --git a/CMakeLists.txt b/CMakeLists.txt index f44c35aa7..94808a586 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,8 @@ if(NOT TRITON_LLVM_BUILD_DIR) set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR}) endif() +set(TRITON_USE_ROCM "$ENV{TRITON_USE_ROCM}") +set(TRITON_ROCM_DEBUG "$ENV{TRITON_ROCM_DEBUG}") project(triton) include(CTest) @@ -37,7 +39,11 @@ if(WIN32) add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32) endif() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17") +if (TRITON_USE_ROCM) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17") +endif() ########## @@ -110,6 +116,15 @@ libLLVMBinaryFormat.a libLLVMAMDGPUInfo.a libLLVMSupport.a libLLVMDemangle.a +libLLVMPasses.a +libLLVMAnalysis.a +libLLVMTransformUtils.a +libLLVMScalarOpts.a +libLLVMTransformUtils.a +libLLVMipo.a +libLLVMObjCARCOpts.a +libLLVMCoroutines.a +libLLVMAnalysis.a ) endif() include_directories("${LLVM_INCLUDE_DIRS}") @@ -128,6 +143,13 @@ if(BUILD_PYTHON_MODULE) endif() include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR}) link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR}) + if (TRITON_USE_ROCM) + add_definitions(-DUSE_ROCM) + endif() + if (TRITON_ROCM_DEBUG) + add_definitions(-DDEBUG_ROCM) + endif() + set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC}) endif() diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h index 5503bacaf..ccecf604a 100755 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -11,7 +11,6 @@ #include "triton/external/CUDA/nvml.h" //// HIP backend -//#define __HIP_PLATFORM_AMD__ #include "triton/external/hip.h" //Exceptions diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index a80bc471f..0b9d254e8 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -42,6 +42,7 @@ public: value *get_int64(int64_t val); value *get_float16(float val); value *get_float32(float val); + value *get_float64(float val); value *get_range(int32_t lo, int32_t hi); // Types type *get_void_ty(); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 6ea0dd219..b2f1bc1c6 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -145,7 +145,7 @@ mma_layout::mma_layout(size_t num_warps, shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) { /* fragments per warp */ // try to make things as square as possible to maximize data re-use - if(tgt->as_nvidia()->sm() < 80){ + if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){ fpw_ = {2, 2, 1}; auto ord_a = layout_a->get_order(); auto ord_b = layout_b->get_order(); diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 845e2e36d..ad8a1e851 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -26,6 +26,7 @@ namespace codegen { // There should be a proper pass manager there! std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target, int cc, int num_warps, int num_stages, int& shared_static) { + // generate llvm code std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 9253dd319..71a585fa2 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -14,7 +14,13 @@ #include "triton/ir/type.h" #include "llvm/IR/Module.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IR/Type.h" +#ifdef USE_ROCM +#include "llvm/IR/IntrinsicsAMDGPU.h" +#else #include "llvm/IR/IntrinsicsNVPTX.h" +#endif #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/InlineAsm.h" @@ -86,6 +92,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() #define f32_ty builder_->getFloatTy() +#define f64_ty builder_->getDoubleTy() #define i8_ty builder_->getInt8Ty() #define i32_ty builder_->getInt32Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) @@ -464,7 +471,7 @@ Value* generator::bf16_to_fp32(Value *in0){ } Value* generator::fp32_to_bf16(Value *in0){ - if(tgt_->as_nvidia()->sm() >= 80){ + if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){ InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false), "cvt.rn.bf16.f32 $0, $1;", "=h,r", false); return call(ptx, {in0}); @@ -584,6 +591,22 @@ void generator::visit_load_inst(ir::load_inst* x){ ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); + +#ifdef USE_ROCM + // code generation + auto idxs = idxs_.at(x); + for(size_t i = 0; i CreateLoad(ty, ptr); + + // upload to global vals map + vals_[x][idx] = _ret; + } +#else // compute vector width size_t vec = 1; if(op->get_type()->is_block_ty()){ @@ -715,6 +738,7 @@ void generator::visit_load_inst(ir::load_inst* x){ for(size_t ii = 0; ii < vec; ii++) vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp); } +#endif } void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { @@ -733,6 +757,23 @@ void generator::visit_store_inst(ir::store_inst * x){ // operands ir::value *ptr_op = x->get_pointer_operand(); ir::value *val_op = x->get_value_operand(); +#ifdef USE_ROCM + auto idxs = idxs_.at(val_op); + Type *ty = cvt(val_op->get_type()->get_scalar_ty()); + + for (size_t i = 0; i < idxs.size(); i += 1) + { + auto idx = idxs[i]; + // pointer + Value *ptr = vals_[ptr_op][idx]; + + // value + Value *val = vals_.at(val_op)[idxs[i]]; + + // store value at pointer + store(val, ptr); + } +#else // vector size size_t vec = 1; if(val_op->get_type()->is_block_ty()){ @@ -766,6 +807,7 @@ void generator::visit_store_inst(ir::store_inst * x){ else store(val, ptr); } +#endif } void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) { visit_store_inst(x); @@ -858,7 +900,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){ Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634); std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); +#ifdef USE_ROCM + llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys); +#else InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false); +#endif + for(auto idx: idxs_.at(x)){ Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e); vals_[x][idx] = call(ex2, std::vector{ex2arg}); @@ -871,7 +918,11 @@ void generator::visit_exp_inst(ir::exp_inst* x){ void generator::visit_cos_inst(ir::cos_inst* x){ std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); +#ifdef USE_ROCM + llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys); +#else InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false); +#endif for(auto idx: idxs_.at(x)){ vals_[x][idx] = call(cos, std::vector{vals_[x->get_operand(0)][idx]}); } @@ -897,7 +948,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){ void generator::visit_sin_inst(ir::sin_inst* x){ std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); +#ifdef USE_ROCM + llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys); +#else InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false); +#endif for(auto idx: idxs_.at(x)){ vals_[x][idx] = call(sin, std::vector{vals_[x->get_operand(0)][idx]}); } @@ -910,7 +965,11 @@ void generator::visit_log_inst(ir::log_inst* x){ Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453); std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); +#ifdef USE_ROCM + llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys); +#else InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false); +#endif for(auto idx: idxs_.at(x)){ Value *lg2arg = call(lg2, std::vector{vals_[x->get_operand(0)][idx]}); vals_[x][idx] = fmul(lg2arg, rcplog2e); @@ -1701,10 +1760,14 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { size_t red_axis = 1; unsigned NK = A_shapes[red_axis]; bool is_outer = NK == 1; +#ifdef USE_ROCM bool is_mma = layouts_->get(dot)->to_mma(); - if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80) +#else + bool is_mma = false; +#endif + if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) return visit_mma884(dot, A, B, D, NK); - if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80) + if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) return visit_mma16816(dot, A, B, D, NK); return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); } @@ -1739,8 +1802,14 @@ Value* generator::shared_off(const std::vector& shapes, const std::vec inline Value* generator::shfl_sync(Value* acc, int32_t i){ Type* ty = acc->getType(); +#ifdef USE_ROCM std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;"; InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false); +#else + std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;"; + InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false); +#endif + if(ty->getPrimitiveSizeInBits() <= 32) return call(shfl, {acc, i32(i)}); acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2)); @@ -1902,8 +1971,14 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); - if(arg->get_type()->get_tile_rank() == 1) + if (arg->get_type()->get_tile_rank() == 1) + { +#ifdef USE_ROCM + visit_reducend_inst(x, do_acc, neutral); +#else visit_reduce1d_inst(x, do_acc, neutral); +#endif + } else visit_reducend_inst(x, do_acc, neutral); } @@ -2286,12 +2361,14 @@ void generator::visit_function(ir::function* fn) { // set metadata if(tgt_->is_gpu()){ tgt_->set_kernel(*builder_, ctx, mod_, ret); + #ifndef USE_ROCM Metadata *md_args[] = { ValueAsMetadata::get(ret), MDString::get(ctx, "maxntidx"), ValueAsMetadata::get(i32(num_warps_*32)) }; mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); + #endif } // set arguments for(unsigned i = 0; i < fn->args().size(); i++) @@ -2311,6 +2388,9 @@ void generator::visit_function(ir::function* fn) { visit_basic_block(block); // finalize finalize_function(fn); + + // verifyFunction + llvm::verifyFunction(*ret); } @@ -2334,7 +2414,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { Value *_8 = i32(8); Value *_16 = i32(16); Value *_32 = i32(32); +#ifdef USE_ROCM + int cc = 1; // generate ir for older CUDA cards +#else int cc = tgt_->as_nvidia()->sm(); +#endif std::vector idx_m; std::vector idx_n; std::vector idx_z; diff --git a/lib/codegen/target.cc b/lib/codegen/target.cc index 82ebbe649..caa8d72f9 100644 --- a/lib/codegen/target.cc +++ b/lib/codegen/target.cc @@ -41,7 +41,7 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un } Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) { - throw std::runtime_error("not implemented"); + throw std::runtime_error("not implemented on AMD"); } @@ -156,7 +156,7 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi } Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { - throw std::runtime_error("not implemented"); + throw std::runtime_error("not implemented on CPU"); } diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 0a355d106..3288bf9f1 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -25,6 +25,7 @@ #endif #include #include +#include #include "triton/driver/llvm.h" #include "triton/driver/dispatch.h" #include "triton/driver/error.h" @@ -56,6 +57,8 @@ #include "llvm/Support/ToolOutputFile.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/Intrinsics.h" // end AMD stuff namespace triton{ @@ -264,8 +267,13 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { llvm::SmallVector buffer; std::string triple = "amdgcn-amd-amdhsa"; std::string layout = ""; - std::string features; + std::string features="+sramecc,-xnack"; std::string proc = "gfx908"; + // name kernel + auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + std::stringstream cur_time; + cur_time << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S"); + std::string kernel_name = module->getModuleIdentifier() + "_" + cur_time.str(); // verify and store llvm llvm::legacy::PassManager pm; pm.add(llvm::createVerifierPass()); @@ -281,7 +289,7 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { opt.NoNaNsFPMath = true; llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, llvm::None, - llvm::CodeGenOpt::Aggressive); + llvm::CodeGenOpt::None); // set data layout if(layout.empty()) module->setDataLayout(machine->createDataLayout()); @@ -294,11 +302,10 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { llvm::raw_svector_ostream stream(buffer); // create dump files - std::string module_name = module->getModuleIdentifier(); std::error_code ec; // Save GCN ISA binary. - std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o"); + std::string isabin_path = std::string("/tmp/") + kernel_name + std::string(".o"); std::unique_ptr isabin_fs( new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text)); if (ec) @@ -309,15 +316,25 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { // emit machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile); pass.run(*module); + +#ifdef DEBUG_ROCM + std::cout << "Generating GCN ISA file" << std::endl; + llvm::SmallVector debugBuffer; + llvm::legacy::PassManager debugPass; + llvm::raw_svector_ostream debugStream(debugBuffer); + machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO:cause segfault on REM ops also cause @llvm.amdgcn.if bug + debugPass.run(*module); + // Save GCN ISA. - std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn"); - std::string result(buffer.begin(), buffer.end()); + std::string amdgcn_path = std::string("/tmp/") + kernel_name + std::string(".gcn"); + std::string result(debugBuffer.begin(), debugBuffer.end()); std::ofstream amdgcn(amdgcn_path); amdgcn << result; amdgcn.close(); +#endif // generate HASCO file - std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco"); + std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco"); std::string error_message; int lld_result = llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld", diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index cc1d354ee..7d6dfdfc8 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -60,6 +60,9 @@ value *builder::get_float16(float val) value *builder::get_float32(float val) { return constant_fp::get(type::get_fp32_ty(ctx_), val); } +value *builder::get_float64(float val) +{ return constant_fp::get(type::get_fp64_ty(ctx_), val); } + value *builder::get_range(int32_t _lo, int32_t _hi) { constant_int* lo = static_cast(get_int32(_lo)); constant_int* hi = static_cast(get_int32(_hi)); diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 811e5c819..e1d51856e 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -559,7 +559,27 @@ ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir:: ptr = dispatch::cast(ptr, ptr_ty, builder); } // cast to target data-type +#ifdef USE_ROCM + ir::type *src_ty = val->get_type(); + ir::type *dst_ty = elt_ty; + if (src_ty->is_block_ty()) + dst_ty = ir::block_type::get(dst_ty, src_ty->get_block_shapes()); + + ir::type *src_sca_ty = src_ty->get_scalar_ty(); + ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); + // check if truncation is need + bool truncate_fp = src_sca_ty->is_floating_point_ty() && + dst_sca_ty->is_floating_point_ty() && + src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); + if (truncate_fp && elt_ty->is_fp16_ty()) + { + std::cout << "WARNING: "<<"casting down to fp16 is broken on ROCM" << std::endl; + } + val = dispatch::cast(val, elt_ty, builder); +#else + val = dispatch::cast(val, elt_ty, builder); +#endif if (!mask) return builder->create_store(ptr, val); if(!mask->get_type()->get_scalar_ty()->is_bool_ty()) @@ -659,6 +679,8 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask //===----------------------------------------------------------------------===// ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) { + lhs = dispatch::cast(lhs, builder->get_float_ty(), builder); + rhs = dispatch::cast(rhs, builder->get_float_ty(), builder); ir::value *_0 = builder->get_float32(0); unsigned M = lhs->get_type()->get_block_shapes()[0]; unsigned N = rhs->get_type()->get_block_shapes()[1]; diff --git a/python/setup.py b/python/setup.py index d57ea96ed..9488a5734 100644 --- a/python/setup.py +++ b/python/setup.py @@ -15,10 +15,11 @@ from setuptools.command.test import test as TestCommand import distutils.spawn import urllib.request import tarfile +import torch def get_llvm(): # tries to find system LLVM - versions = ['-11.0', '-11', '-11-64'] + versions = ['-13.0', '-13', '-13-64'] supported = ['llvm-config{v}'.format(v=v) for v in versions] paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [p for p in paths if p is not None] @@ -27,7 +28,7 @@ def get_llvm(): if platform.system() == "Windows": return '', '' # download if nothing is installed - name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04' + name = 'clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04' dir = '/tmp' llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name) llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name) @@ -36,7 +37,7 @@ def get_llvm(): shutil.rmtree(os.path.join(dir, name)) except: pass - url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name) + url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/{name}.tar.xz".format(name=name) print('downloading and extracting ' + url + '...') ftpstream = urllib.request.urlopen(url) file = tarfile.open(fileobj=ftpstream, mode="r|xz") @@ -80,7 +81,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - # self.debug = True + self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 'release' @@ -90,7 +91,10 @@ class CMakeBuild(build_ext): if not os.path.exists(llvm_build_dir): os.makedirs(llvm_build_dir) # python directories - python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include'] + if torch.version.hip is not None: + python_include_dirs= [distutils.sysconfig.get_python_inc()] +['/opt/rocm/include'] + else: + python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include'] cmake_args = [ "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DBUILD_TUTORIALS=OFF", @@ -117,6 +121,9 @@ class CMakeBuild(build_ext): build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())] env = os.environ.copy() + + if torch.version.hip is not None: + env["TRITON_USE_ROCM"] = "ON" subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 25054f0dc..f99a632a9 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -45,6 +45,35 @@ def test_empty_kernel(dtype_x, device='cuda'): x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) kernel[(1, )](x, SIZE=SIZE, num_warps=4) +# --------------- +# test load and store op +# --------------- +@pytest.mark.parametrize("dtype,size", [ + (dtype, size) + for dtype in dtypes + for size in [128, 256, 512, 1024, 2048, 4096] +]) +def test_load_and_store_op(dtype, size, device='cuda'): + SIZE = size + # define the kernel / launch-grid + @triton.jit + def kernel(Z, X, **meta): + off = tl.arange(0, meta['SIZE']) + x = tl.load(X + off) + tl.store(Z + off, x) + # inputs + x = triton.testing.random(SIZE, dtype=cvt[dtype], device=device) + + # output tensors + z_ref = x.clone() # reference result + z_tri = torch.empty_like(x) # triton result + + # run load and store kernel + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + + # compare + triton.testing.assert_almost_equal(z_ref, z_tri) + # generic test functions def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): SIZE = 128 @@ -340,18 +369,23 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): ('float32', 'int32', True) ]) def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): - x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device) + if torch.version.hip is not None: + assert 'bfloat' not in dtype_x + assert 'bfloat' not in dtype_z + SIZE = 1024 + x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device) # triton kernel @triton.jit def kernel(X, Z, **meta): - x = tl.load(X) + off = tl.arange(0, meta['SIZE']) + x = tl.load(X+ off) z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST']) - tl.store(Z, z) + tl.store(Z+ off, z) # triton result - z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device) - kernel[(1, )](x, z_tri, BITCAST=bitcast) + z_tri = torch.empty((SIZE, ), dtype=cvt[dtype_z], device=device) + kernel[(1, )](x, z_tri, SIZE=SIZE, BITCAST=bitcast) # torch result if bitcast: import numpy as np @@ -359,7 +393,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): z_ref = torch.from_numpy(z_ref).to(device) else: z_ref = x.to(z_tri.dtype) - assert z_tri == z_ref + triton.testing.assert_almost_equal(z_ref, z_tri) # --------------- # test reduce @@ -448,17 +482,23 @@ def test_permute(dtype, shape, perm, device='cuda'): z_ref = x.permute(*perm).contiguous() # compare triton.testing.assert_almost_equal(z_tri, z_ref) - # parse ptx to make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - assert 'ld.global.v4' in ptx - assert 'st.global.v4' in ptx + + if torch.version.hip is None: + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx # --------------- # test dot # --------------- -@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols']) -def test_dot(epilogue, device='cuda'): +@pytest.mark.parametrize("dtype, epilogue", [(dtype, epilogue)\ + for dtype in ['float16','float32'] \ + for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols']]) +def test_dot(dtype, epilogue, device='cuda'): + dtype = cvt[dtype] + torch.manual_seed(0) # triton kernel @triton.jit @@ -486,10 +526,10 @@ def test_dot(epilogue, device='cuda'): tl.store(Zs, z) # input M, N, K = 64, 64, 32 - x = triton.testing.random((M, K), dtype=torch.float16, device=device) - y = triton.testing.random((K, N), dtype=torch.float16, device=device) + x = triton.testing.random((M, K), dtype=dtype, device=device) + y = triton.testing.random((K, N), dtype=dtype, device=device) # triton result - z = triton.testing.random((M, N), dtype=torch.float16, device=device) + z = triton.testing.random((M, N), dtype=dtype, device=device) z_tri = z.clone() pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), y, y.stride(0), y.stride(1), @@ -508,12 +548,14 @@ def test_dot(epilogue, device='cuda'): z_ref += z[0,:][None, :] z_ref = z_ref.to(torch.float16) # compare - ptx = pgm.asm['ptx'] - # print(ptx) triton.testing.assert_almost_equal(z_tri, z_ref) - # make sure ld/st are vectorized - assert 'ld.global.v4' in ptx - assert 'st.global.v4' in ptx + + # print(ptx) + if torch.version.hip is None: + ptx = pgm.asm['ptx'] + # make sure ld/st are vectorized + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx def test_dot_without_load(): @triton.jit @@ -611,17 +653,18 @@ def test_load_cache_modifier(cache): tl.store(dst+offsets, x) pgm = _kernel[(1,)](dst, src, CACHE=cache) - ptx = pgm.asm['ptx'] + if torch.version.hip is None: + ptx = pgm.asm['ptx'] - if cache == '': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' not in ptx - if cache == '.cg': - assert 'ld.global.cg' in ptx - assert 'ld.global.ca' not in ptx - if cache == '.ca': - assert 'ld.global.ca' in ptx - assert 'ld.global.cg' not in ptx + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx # --------------- # test store @@ -647,4 +690,4 @@ def test_noop(device='cuda'): def kernel(**meta): pass x = triton.testing.random((1,), dtype=torch.int32, device=device) - kernel[(1, )](x) \ No newline at end of file + kernel[(1, )](x)