[ROCM] enable matmul(dot) and others (#391)
This commit is contained in:
@@ -14,7 +14,13 @@
|
||||
#include "triton/ir/type.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#ifdef USE_ROCM
|
||||
#include "llvm/IR/IntrinsicsAMDGPU.h"
|
||||
#else
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
#endif
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Attributes.h"
|
||||
#include "llvm/IR/InlineAsm.h"
|
||||
@@ -86,6 +92,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
#define f32_ty builder_->getFloatTy()
|
||||
#define f64_ty builder_->getDoubleTy()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i32_ty builder_->getInt32Ty()
|
||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||
@@ -464,7 +471,7 @@ Value* generator::bf16_to_fp32(Value *in0){
|
||||
}
|
||||
|
||||
Value* generator::fp32_to_bf16(Value *in0){
|
||||
if(tgt_->as_nvidia()->sm() >= 80){
|
||||
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
|
||||
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
|
||||
return call(ptx, {in0});
|
||||
@@ -584,6 +591,22 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
ir::value *op = x->get_pointer_operand();
|
||||
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
|
||||
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// code generation
|
||||
auto idxs = idxs_.at(x);
|
||||
for(size_t i = 0; i <idxs.size(); i += 1){
|
||||
indices_t idx = idxs[i];
|
||||
// pointer value
|
||||
Value *ptr = vals_[op][idx];
|
||||
|
||||
// create load
|
||||
Value *_ret = builder_->CreateLoad(ty, ptr);
|
||||
|
||||
// upload to global vals map
|
||||
vals_[x][idx] = _ret;
|
||||
}
|
||||
#else
|
||||
// compute vector width
|
||||
size_t vec = 1;
|
||||
if(op->get_type()->is_block_ty()){
|
||||
@@ -715,6 +738,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
for(size_t ii = 0; ii < vec; ii++)
|
||||
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
@@ -733,6 +757,23 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
// operands
|
||||
ir::value *ptr_op = x->get_pointer_operand();
|
||||
ir::value *val_op = x->get_value_operand();
|
||||
#ifdef USE_ROCM
|
||||
auto idxs = idxs_.at(val_op);
|
||||
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
|
||||
|
||||
for (size_t i = 0; i < idxs.size(); i += 1)
|
||||
{
|
||||
auto idx = idxs[i];
|
||||
// pointer
|
||||
Value *ptr = vals_[ptr_op][idx];
|
||||
|
||||
// value
|
||||
Value *val = vals_.at(val_op)[idxs[i]];
|
||||
|
||||
// store value at pointer
|
||||
store(val, ptr);
|
||||
}
|
||||
#else
|
||||
// vector size
|
||||
size_t vec = 1;
|
||||
if(val_op->get_type()->is_block_ty()){
|
||||
@@ -766,6 +807,7 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
else
|
||||
store(val, ptr);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
|
||||
visit_store_inst(x);
|
||||
@@ -858,7 +900,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys);
|
||||
#else
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
||||
#endif
|
||||
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
||||
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
||||
@@ -871,7 +918,11 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
void generator::visit_cos_inst(ir::cos_inst* x){
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys);
|
||||
#else
|
||||
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
|
||||
#endif
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
@@ -897,7 +948,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){
|
||||
void generator::visit_sin_inst(ir::sin_inst* x){
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys);
|
||||
#else
|
||||
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
|
||||
#endif
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
@@ -910,7 +965,11 @@ void generator::visit_log_inst(ir::log_inst* x){
|
||||
Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
#ifdef USE_ROCM
|
||||
llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys);
|
||||
#else
|
||||
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
|
||||
#endif
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
vals_[x][idx] = fmul(lg2arg, rcplog2e);
|
||||
@@ -1701,10 +1760,14 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
size_t red_axis = 1;
|
||||
unsigned NK = A_shapes[red_axis];
|
||||
bool is_outer = NK == 1;
|
||||
#ifdef USE_ROCM
|
||||
bool is_mma = layouts_->get(dot)->to_mma();
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
|
||||
#else
|
||||
bool is_mma = false;
|
||||
#endif
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
|
||||
return visit_mma884(dot, A, B, D, NK);
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||
return visit_mma16816(dot, A, B, D, NK);
|
||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||
}
|
||||
@@ -1739,8 +1802,14 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
|
||||
|
||||
inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
||||
Type* ty = acc->getType();
|
||||
#ifdef USE_ROCM
|
||||
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
||||
#else
|
||||
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
||||
#endif
|
||||
|
||||
if(ty->getPrimitiveSizeInBits() <= 32)
|
||||
return call(shfl, {acc, i32(i)});
|
||||
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
|
||||
@@ -1902,8 +1971,14 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
ir::value *arg = x->get_operand(0);
|
||||
if(arg->get_type()->get_tile_rank() == 1)
|
||||
if (arg->get_type()->get_tile_rank() == 1)
|
||||
{
|
||||
#ifdef USE_ROCM
|
||||
visit_reducend_inst(x, do_acc, neutral);
|
||||
#else
|
||||
visit_reduce1d_inst(x, do_acc, neutral);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
visit_reducend_inst(x, do_acc, neutral);
|
||||
}
|
||||
@@ -2286,12 +2361,14 @@ void generator::visit_function(ir::function* fn) {
|
||||
// set metadata
|
||||
if(tgt_->is_gpu()){
|
||||
tgt_->set_kernel(*builder_, ctx, mod_, ret);
|
||||
#ifndef USE_ROCM
|
||||
Metadata *md_args[] = {
|
||||
ValueAsMetadata::get(ret),
|
||||
MDString::get(ctx, "maxntidx"),
|
||||
ValueAsMetadata::get(i32(num_warps_*32))
|
||||
};
|
||||
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||
#endif
|
||||
}
|
||||
// set arguments
|
||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||
@@ -2311,6 +2388,9 @@ void generator::visit_function(ir::function* fn) {
|
||||
visit_basic_block(block);
|
||||
// finalize
|
||||
finalize_function(fn);
|
||||
|
||||
// verifyFunction
|
||||
llvm::verifyFunction(*ret);
|
||||
}
|
||||
|
||||
|
||||
@@ -2334,7 +2414,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
||||
Value *_8 = i32(8);
|
||||
Value *_16 = i32(16);
|
||||
Value *_32 = i32(32);
|
||||
#ifdef USE_ROCM
|
||||
int cc = 1; // generate ir for older CUDA cards
|
||||
#else
|
||||
int cc = tgt_->as_nvidia()->sm();
|
||||
#endif
|
||||
std::vector<Value*> idx_m;
|
||||
std::vector<Value*> idx_n;
|
||||
std::vector<Value*> idx_z;
|
||||
|
Reference in New Issue
Block a user