[ROCM] enable matmul(dot) and others (#391)
This commit is contained in:
@@ -7,6 +7,8 @@ if(NOT TRITON_LLVM_BUILD_DIR)
|
||||
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
|
||||
endif()
|
||||
|
||||
set(TRITON_USE_ROCM "$ENV{TRITON_USE_ROCM}")
|
||||
set(TRITON_ROCM_DEBUG "$ENV{TRITON_ROCM_DEBUG}")
|
||||
|
||||
project(triton)
|
||||
include(CTest)
|
||||
@@ -37,7 +39,11 @@ if(WIN32)
|
||||
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
||||
if (TRITON_USE_ROCM)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
||||
endif()
|
||||
|
||||
|
||||
##########
|
||||
@@ -110,6 +116,15 @@ libLLVMBinaryFormat.a
|
||||
libLLVMAMDGPUInfo.a
|
||||
libLLVMSupport.a
|
||||
libLLVMDemangle.a
|
||||
libLLVMPasses.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMipo.a
|
||||
libLLVMObjCARCOpts.a
|
||||
libLLVMCoroutines.a
|
||||
libLLVMAnalysis.a
|
||||
)
|
||||
endif()
|
||||
include_directories("${LLVM_INCLUDE_DIRS}")
|
||||
@@ -128,6 +143,13 @@ if(BUILD_PYTHON_MODULE)
|
||||
endif()
|
||||
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
||||
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
||||
if (TRITON_USE_ROCM)
|
||||
add_definitions(-DUSE_ROCM)
|
||||
endif()
|
||||
if (TRITON_ROCM_DEBUG)
|
||||
add_definitions(-DDEBUG_ROCM)
|
||||
endif()
|
||||
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
|
||||
endif()
|
||||
|
||||
|
@@ -11,7 +11,6 @@
|
||||
#include "triton/external/CUDA/nvml.h"
|
||||
|
||||
//// HIP backend
|
||||
//#define __HIP_PLATFORM_AMD__
|
||||
#include "triton/external/hip.h"
|
||||
|
||||
//Exceptions
|
||||
|
@@ -42,6 +42,7 @@ public:
|
||||
value *get_int64(int64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_float64(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
// Types
|
||||
type *get_void_ty();
|
||||
|
@@ -145,7 +145,7 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
if(tgt->as_nvidia()->sm() < 80){
|
||||
if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
|
||||
fpw_ = {2, 2, 1};
|
||||
auto ord_a = layout_a->get_order();
|
||||
auto ord_b = layout_b->get_order();
|
||||
|
@@ -26,6 +26,7 @@ namespace codegen {
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||
int cc, int num_warps, int num_stages, int& shared_static) {
|
||||
|
||||
// generate llvm code
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
||||
|
@@ -14,7 +14,13 @@
|
||||
#include "triton/ir/type.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#ifdef USE_ROCM
|
||||
#include "llvm/IR/IntrinsicsAMDGPU.h"
|
||||
#else
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
#endif
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Attributes.h"
|
||||
#include "llvm/IR/InlineAsm.h"
|
||||
@@ -86,6 +92,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
#define f32_ty builder_->getFloatTy()
|
||||
#define f64_ty builder_->getDoubleTy()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i32_ty builder_->getInt32Ty()
|
||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||
@@ -464,7 +471,7 @@ Value* generator::bf16_to_fp32(Value *in0){
|
||||
}
|
||||
|
||||
Value* generator::fp32_to_bf16(Value *in0){
|
||||
if(tgt_->as_nvidia()->sm() >= 80){
|
||||
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
|
||||
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
|
||||
return call(ptx, {in0});
|
||||
@@ -584,6 +591,22 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
ir::value *op = x->get_pointer_operand();
|
||||
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
|
||||
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// code generation
|
||||
auto idxs = idxs_.at(x);
|
||||
for(size_t i = 0; i <idxs.size(); i += 1){
|
||||
indices_t idx = idxs[i];
|
||||
// pointer value
|
||||
Value *ptr = vals_[op][idx];
|
||||
|
||||
// create load
|
||||
Value *_ret = builder_->CreateLoad(ty, ptr);
|
||||
|
||||
// upload to global vals map
|
||||
vals_[x][idx] = _ret;
|
||||
}
|
||||
#else
|
||||
// compute vector width
|
||||
size_t vec = 1;
|
||||
if(op->get_type()->is_block_ty()){
|
||||
@@ -715,6 +738,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
for(size_t ii = 0; ii < vec; ii++)
|
||||
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
@@ -733,6 +757,23 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
// operands
|
||||
ir::value *ptr_op = x->get_pointer_operand();
|
||||
ir::value *val_op = x->get_value_operand();
|
||||
#ifdef USE_ROCM
|
||||
auto idxs = idxs_.at(val_op);
|
||||
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
|
||||
|
||||
for (size_t i = 0; i < idxs.size(); i += 1)
|
||||
{
|
||||
auto idx = idxs[i];
|
||||
// pointer
|
||||
Value *ptr = vals_[ptr_op][idx];
|
||||
|
||||
// value
|
||||
Value *val = vals_.at(val_op)[idxs[i]];
|
||||
|
||||
// store value at pointer
|
||||
store(val, ptr);
|
||||
}
|
||||
#else
|
||||
// vector size
|
||||
size_t vec = 1;
|
||||
if(val_op->get_type()->is_block_ty()){
|
||||
@@ -766,6 +807,7 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
else
|
||||
store(val, ptr);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
|
||||
visit_store_inst(x);
|
||||
@@ -858,7 +900,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys);
|
||||
#else
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
||||
#endif
|
||||
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
||||
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
||||
@@ -871,7 +918,11 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
void generator::visit_cos_inst(ir::cos_inst* x){
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys);
|
||||
#else
|
||||
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
|
||||
#endif
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
@@ -897,7 +948,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){
|
||||
void generator::visit_sin_inst(ir::sin_inst* x){
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys);
|
||||
#else
|
||||
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
|
||||
#endif
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
@@ -910,7 +965,11 @@ void generator::visit_log_inst(ir::log_inst* x){
|
||||
Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys);
|
||||
#else
|
||||
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
|
||||
#endif
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
vals_[x][idx] = fmul(lg2arg, rcplog2e);
|
||||
@@ -1701,10 +1760,14 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
size_t red_axis = 1;
|
||||
unsigned NK = A_shapes[red_axis];
|
||||
bool is_outer = NK == 1;
|
||||
#ifdef USE_ROCM
|
||||
bool is_mma = layouts_->get(dot)->to_mma();
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
|
||||
#else
|
||||
bool is_mma = false;
|
||||
#endif
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
|
||||
return visit_mma884(dot, A, B, D, NK);
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||
return visit_mma16816(dot, A, B, D, NK);
|
||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||
}
|
||||
@@ -1739,8 +1802,14 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
|
||||
|
||||
inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
||||
Type* ty = acc->getType();
|
||||
#ifdef USE_ROCM
|
||||
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
||||
#else
|
||||
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
||||
#endif
|
||||
|
||||
if(ty->getPrimitiveSizeInBits() <= 32)
|
||||
return call(shfl, {acc, i32(i)});
|
||||
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
|
||||
@@ -1902,8 +1971,14 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
ir::value *arg = x->get_operand(0);
|
||||
if(arg->get_type()->get_tile_rank() == 1)
|
||||
if (arg->get_type()->get_tile_rank() == 1)
|
||||
{
|
||||
#ifdef USE_ROCM
|
||||
visit_reducend_inst(x, do_acc, neutral);
|
||||
#else
|
||||
visit_reduce1d_inst(x, do_acc, neutral);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
visit_reducend_inst(x, do_acc, neutral);
|
||||
}
|
||||
@@ -2286,12 +2361,14 @@ void generator::visit_function(ir::function* fn) {
|
||||
// set metadata
|
||||
if(tgt_->is_gpu()){
|
||||
tgt_->set_kernel(*builder_, ctx, mod_, ret);
|
||||
#ifndef USE_ROCM
|
||||
Metadata *md_args[] = {
|
||||
ValueAsMetadata::get(ret),
|
||||
MDString::get(ctx, "maxntidx"),
|
||||
ValueAsMetadata::get(i32(num_warps_*32))
|
||||
};
|
||||
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||
#endif
|
||||
}
|
||||
// set arguments
|
||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||
@@ -2311,6 +2388,9 @@ void generator::visit_function(ir::function* fn) {
|
||||
visit_basic_block(block);
|
||||
// finalize
|
||||
finalize_function(fn);
|
||||
|
||||
// verifyFunction
|
||||
llvm::verifyFunction(*ret);
|
||||
}
|
||||
|
||||
|
||||
@@ -2334,7 +2414,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
||||
Value *_8 = i32(8);
|
||||
Value *_16 = i32(16);
|
||||
Value *_32 = i32(32);
|
||||
#ifdef USE_ROCM
|
||||
int cc = 1; // generate ir for older CUDA cards
|
||||
#else
|
||||
int cc = tgt_->as_nvidia()->sm();
|
||||
#endif
|
||||
std::vector<Value*> idx_m;
|
||||
std::vector<Value*> idx_n;
|
||||
std::vector<Value*> idx_z;
|
||||
|
@@ -41,7 +41,7 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
|
||||
}
|
||||
|
||||
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
|
||||
throw std::runtime_error("not implemented");
|
||||
throw std::runtime_error("not implemented on AMD");
|
||||
}
|
||||
|
||||
|
||||
@@ -156,7 +156,7 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
|
||||
}
|
||||
|
||||
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
throw std::runtime_error("not implemented");
|
||||
throw std::runtime_error("not implemented on CPU");
|
||||
}
|
||||
|
||||
|
||||
|
@@ -25,6 +25,7 @@
|
||||
#endif
|
||||
#include <memory>
|
||||
#include <regex>
|
||||
#include <iomanip>
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "triton/driver/error.h"
|
||||
@@ -56,6 +57,8 @@
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
#include "llvm/IR/IntrinsicsAMDGPU.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
||||
// end AMD stuff
|
||||
|
||||
namespace triton{
|
||||
@@ -264,8 +267,13 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
std::string triple = "amdgcn-amd-amdhsa";
|
||||
std::string layout = "";
|
||||
std::string features;
|
||||
std::string features="+sramecc,-xnack";
|
||||
std::string proc = "gfx908";
|
||||
// name kernel
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
std::stringstream cur_time;
|
||||
cur_time << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d--%I-%M-%S");
|
||||
std::string kernel_name = module->getModuleIdentifier() + "_" + cur_time.str();
|
||||
// verify and store llvm
|
||||
llvm::legacy::PassManager pm;
|
||||
pm.add(llvm::createVerifierPass());
|
||||
@@ -281,7 +289,7 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
llvm::Reloc::PIC_, llvm::None,
|
||||
llvm::CodeGenOpt::Aggressive);
|
||||
llvm::CodeGenOpt::None);
|
||||
// set data layout
|
||||
if(layout.empty())
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
@@ -294,11 +302,10 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
|
||||
// create dump files
|
||||
std::string module_name = module->getModuleIdentifier();
|
||||
std::error_code ec;
|
||||
|
||||
// Save GCN ISA binary.
|
||||
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
|
||||
std::string isabin_path = std::string("/tmp/") + kernel_name + std::string(".o");
|
||||
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
|
||||
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
|
||||
if (ec)
|
||||
@@ -309,15 +316,25 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
|
||||
pass.run(*module);
|
||||
|
||||
#ifdef DEBUG_ROCM
|
||||
std::cout << "Generating GCN ISA file" << std::endl;
|
||||
llvm::SmallVector<char, 0> debugBuffer;
|
||||
llvm::legacy::PassManager debugPass;
|
||||
llvm::raw_svector_ostream debugStream(debugBuffer);
|
||||
machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO:cause segfault on REM ops also cause @llvm.amdgcn.if bug
|
||||
debugPass.run(*module);
|
||||
|
||||
// Save GCN ISA.
|
||||
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
std::string amdgcn_path = std::string("/tmp/") + kernel_name + std::string(".gcn");
|
||||
std::string result(debugBuffer.begin(), debugBuffer.end());
|
||||
std::ofstream amdgcn(amdgcn_path);
|
||||
amdgcn << result;
|
||||
amdgcn.close();
|
||||
#endif
|
||||
|
||||
// generate HASCO file
|
||||
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
|
||||
std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
|
||||
std::string error_message;
|
||||
int lld_result =
|
||||
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
|
||||
|
@@ -60,6 +60,9 @@ value *builder::get_float16(float val)
|
||||
value *builder::get_float32(float val)
|
||||
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_float64(float val)
|
||||
{ return constant_fp::get(type::get_fp64_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_range(int32_t _lo, int32_t _hi) {
|
||||
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
|
||||
constant_int* hi = static_cast<constant_int*>(get_int32(_hi));
|
||||
|
@@ -559,7 +559,27 @@ ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::
|
||||
ptr = dispatch::cast(ptr, ptr_ty, builder);
|
||||
}
|
||||
// cast to target data-type
|
||||
#ifdef USE_ROCM
|
||||
ir::type *src_ty = val->get_type();
|
||||
ir::type *dst_ty = elt_ty;
|
||||
if (src_ty->is_block_ty())
|
||||
dst_ty = ir::block_type::get(dst_ty, src_ty->get_block_shapes());
|
||||
|
||||
ir::type *src_sca_ty = src_ty->get_scalar_ty();
|
||||
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
|
||||
// check if truncation is need
|
||||
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
|
||||
dst_sca_ty->is_floating_point_ty() &&
|
||||
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
|
||||
if (truncate_fp && elt_ty->is_fp16_ty())
|
||||
{
|
||||
std::cout << "WARNING: "<<"casting down to fp16 is broken on ROCM" << std::endl;
|
||||
}
|
||||
|
||||
val = dispatch::cast(val, elt_ty, builder);
|
||||
#else
|
||||
val = dispatch::cast(val, elt_ty, builder);
|
||||
#endif
|
||||
if (!mask)
|
||||
return builder->create_store(ptr, val);
|
||||
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
|
||||
@@ -659,6 +679,8 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
|
||||
lhs = dispatch::cast(lhs, builder->get_float_ty(), builder);
|
||||
rhs = dispatch::cast(rhs, builder->get_float_ty(), builder);
|
||||
ir::value *_0 = builder->get_float32(0);
|
||||
unsigned M = lhs->get_type()->get_block_shapes()[0];
|
||||
unsigned N = rhs->get_type()->get_block_shapes()[1];
|
||||
|
@@ -15,10 +15,11 @@ from setuptools.command.test import test as TestCommand
|
||||
import distutils.spawn
|
||||
import urllib.request
|
||||
import tarfile
|
||||
import torch
|
||||
|
||||
def get_llvm():
|
||||
# tries to find system LLVM
|
||||
versions = ['-11.0', '-11', '-11-64']
|
||||
versions = ['-13.0', '-13', '-13-64']
|
||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||
paths = [p for p in paths if p is not None]
|
||||
@@ -27,7 +28,7 @@ def get_llvm():
|
||||
if platform.system() == "Windows":
|
||||
return '', ''
|
||||
# download if nothing is installed
|
||||
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
||||
name = 'clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04'
|
||||
dir = '/tmp'
|
||||
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
||||
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
||||
@@ -36,7 +37,7 @@ def get_llvm():
|
||||
shutil.rmtree(os.path.join(dir, name))
|
||||
except:
|
||||
pass
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/{name}.tar.xz".format(name=name)
|
||||
print('downloading and extracting ' + url + '...')
|
||||
ftpstream = urllib.request.urlopen(url)
|
||||
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
||||
@@ -80,7 +81,7 @@ class CMakeBuild(build_ext):
|
||||
|
||||
def build_extension(self, ext):
|
||||
llvm_include_dir, llvm_library_dir = get_llvm()
|
||||
# self.debug = True
|
||||
self.debug = True
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# create build directories
|
||||
build_suffix = 'debug' if self.debug else 'release'
|
||||
@@ -90,7 +91,10 @@ class CMakeBuild(build_ext):
|
||||
if not os.path.exists(llvm_build_dir):
|
||||
os.makedirs(llvm_build_dir)
|
||||
# python directories
|
||||
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
|
||||
if torch.version.hip is not None:
|
||||
python_include_dirs= [distutils.sysconfig.get_python_inc()] +['/opt/rocm/include']
|
||||
else:
|
||||
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
|
||||
cmake_args = [
|
||||
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
||||
"-DBUILD_TUTORIALS=OFF",
|
||||
@@ -117,6 +121,9 @@ class CMakeBuild(build_ext):
|
||||
build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]
|
||||
|
||||
env = os.environ.copy()
|
||||
|
||||
if torch.version.hip is not None:
|
||||
env["TRITON_USE_ROCM"] = "ON"
|
||||
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
||||
|
||||
|
@@ -45,6 +45,35 @@ def test_empty_kernel(dtype_x, device='cuda'):
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
|
||||
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
|
||||
|
||||
# ---------------
|
||||
# test load and store op
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype,size", [
|
||||
(dtype, size)
|
||||
for dtype in dtypes
|
||||
for size in [128, 256, 512, 1024, 2048, 4096]
|
||||
])
|
||||
def test_load_and_store_op(dtype, size, device='cuda'):
|
||||
SIZE = size
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
def kernel(Z, X, **meta):
|
||||
off = tl.arange(0, meta['SIZE'])
|
||||
x = tl.load(X + off)
|
||||
tl.store(Z + off, x)
|
||||
# inputs
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype], device=device)
|
||||
|
||||
# output tensors
|
||||
z_ref = x.clone() # reference result
|
||||
z_tri = torch.empty_like(x) # triton result
|
||||
|
||||
# run load and store kernel
|
||||
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
|
||||
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(z_ref, z_tri)
|
||||
|
||||
# generic test functions
|
||||
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
|
||||
SIZE = 128
|
||||
@@ -340,18 +369,23 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
||||
('float32', 'int32', True)
|
||||
])
|
||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
|
||||
if torch.version.hip is not None:
|
||||
assert 'bfloat' not in dtype_x
|
||||
assert 'bfloat' not in dtype_z
|
||||
|
||||
SIZE = 1024
|
||||
x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device)
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
x = tl.load(X)
|
||||
off = tl.arange(0, meta['SIZE'])
|
||||
x = tl.load(X+ off)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
|
||||
tl.store(Z, z)
|
||||
tl.store(Z+ off, z)
|
||||
|
||||
# triton result
|
||||
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
|
||||
kernel[(1, )](x, z_tri, BITCAST=bitcast)
|
||||
z_tri = torch.empty((SIZE, ), dtype=cvt[dtype_z], device=device)
|
||||
kernel[(1, )](x, z_tri, SIZE=SIZE, BITCAST=bitcast)
|
||||
# torch result
|
||||
if bitcast:
|
||||
import numpy as np
|
||||
@@ -359,7 +393,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
z_ref = torch.from_numpy(z_ref).to(device)
|
||||
else:
|
||||
z_ref = x.to(z_tri.dtype)
|
||||
assert z_tri == z_ref
|
||||
triton.testing.assert_almost_equal(z_ref, z_tri)
|
||||
|
||||
# ---------------
|
||||
# test reduce
|
||||
@@ -448,17 +482,23 @@ def test_permute(dtype, shape, perm, device='cuda'):
|
||||
z_ref = x.permute(*perm).contiguous()
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
# parse ptx to make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
if torch.version.hip is None:
|
||||
# parse ptx to make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
# ---------------
|
||||
# test dot
|
||||
# ---------------
|
||||
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, device='cuda'):
|
||||
@pytest.mark.parametrize("dtype, epilogue", [(dtype, epilogue)\
|
||||
for dtype in ['float16','float32'] \
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols']])
|
||||
def test_dot(dtype, epilogue, device='cuda'):
|
||||
dtype = cvt[dtype]
|
||||
|
||||
torch.manual_seed(0)
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -486,10 +526,10 @@ def test_dot(epilogue, device='cuda'):
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
M, N, K = 64, 64, 32
|
||||
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
|
||||
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
|
||||
x = triton.testing.random((M, K), dtype=dtype, device=device)
|
||||
y = triton.testing.random((K, N), dtype=dtype, device=device)
|
||||
# triton result
|
||||
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
|
||||
z = triton.testing.random((M, N), dtype=dtype, device=device)
|
||||
z_tri = z.clone()
|
||||
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
|
||||
y, y.stride(0), y.stride(1),
|
||||
@@ -508,12 +548,14 @@ def test_dot(epilogue, device='cuda'):
|
||||
z_ref += z[0,:][None, :]
|
||||
z_ref = z_ref.to(torch.float16)
|
||||
# compare
|
||||
ptx = pgm.asm['ptx']
|
||||
# print(ptx)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
# make sure ld/st are vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
# print(ptx)
|
||||
if torch.version.hip is None:
|
||||
ptx = pgm.asm['ptx']
|
||||
# make sure ld/st are vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
def test_dot_without_load():
|
||||
@triton.jit
|
||||
@@ -611,17 +653,18 @@ def test_load_cache_modifier(cache):
|
||||
tl.store(dst+offsets, x)
|
||||
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
if torch.version.hip is None:
|
||||
ptx = pgm.asm['ptx']
|
||||
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'ld.global.cg' in ptx
|
||||
assert 'ld.global.ca' not in ptx
|
||||
if cache == '.ca':
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'ld.global.cg' in ptx
|
||||
assert 'ld.global.ca' not in ptx
|
||||
if cache == '.ca':
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
@@ -647,4 +690,4 @@ def test_noop(device='cuda'):
|
||||
def kernel(**meta):
|
||||
pass
|
||||
x = triton.testing.random((1,), dtype=torch.int32, device=device)
|
||||
kernel[(1, )](x)
|
||||
kernel[(1, )](x)
|
||||
|
Reference in New Issue
Block a user