|
|
|
@@ -14,7 +14,13 @@
|
|
|
|
|
#include "triton/ir/type.h"
|
|
|
|
|
#include "llvm/IR/Module.h"
|
|
|
|
|
#include "llvm/IR/IRBuilder.h"
|
|
|
|
|
#include "llvm/IR/Verifier.h"
|
|
|
|
|
#include "llvm/IR/Type.h"
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
#include "llvm/IR/IntrinsicsAMDGPU.h"
|
|
|
|
|
#else
|
|
|
|
|
#include "llvm/IR/IntrinsicsNVPTX.h"
|
|
|
|
|
#endif
|
|
|
|
|
#include "llvm/IR/BasicBlock.h"
|
|
|
|
|
#include "llvm/IR/Attributes.h"
|
|
|
|
|
#include "llvm/IR/InlineAsm.h"
|
|
|
|
@@ -86,6 +92,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
|
|
|
|
#define void_ty builder_->getVoidTy()
|
|
|
|
|
#define f16_ty builder_->getHalfTy()
|
|
|
|
|
#define f32_ty builder_->getFloatTy()
|
|
|
|
|
#define f64_ty builder_->getDoubleTy()
|
|
|
|
|
#define i8_ty builder_->getInt8Ty()
|
|
|
|
|
#define i32_ty builder_->getInt32Ty()
|
|
|
|
|
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
|
|
|
@@ -464,7 +471,7 @@ Value* generator::bf16_to_fp32(Value *in0){
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value* generator::fp32_to_bf16(Value *in0){
|
|
|
|
|
if(tgt_->as_nvidia()->sm() >= 80){
|
|
|
|
|
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){
|
|
|
|
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
|
|
|
|
|
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
|
|
|
|
|
return call(ptx, {in0});
|
|
|
|
@@ -584,6 +591,22 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|
|
|
|
ir::value *op = x->get_pointer_operand();
|
|
|
|
|
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
|
|
|
|
|
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
|
|
|
|
|
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
// code generation
|
|
|
|
|
auto idxs = idxs_.at(x);
|
|
|
|
|
for(size_t i = 0; i <idxs.size(); i += 1){
|
|
|
|
|
indices_t idx = idxs[i];
|
|
|
|
|
// pointer value
|
|
|
|
|
Value *ptr = vals_[op][idx];
|
|
|
|
|
|
|
|
|
|
// create load
|
|
|
|
|
Value *_ret = builder_->CreateLoad(ty, ptr);
|
|
|
|
|
|
|
|
|
|
// upload to global vals map
|
|
|
|
|
vals_[x][idx] = _ret;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
// compute vector width
|
|
|
|
|
size_t vec = 1;
|
|
|
|
|
if(op->get_type()->is_block_ty()){
|
|
|
|
@@ -715,6 +738,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|
|
|
|
for(size_t ii = 0; ii < vec; ii++)
|
|
|
|
|
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
|
|
|
@@ -733,6 +757,23 @@ void generator::visit_store_inst(ir::store_inst * x){
|
|
|
|
|
// operands
|
|
|
|
|
ir::value *ptr_op = x->get_pointer_operand();
|
|
|
|
|
ir::value *val_op = x->get_value_operand();
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
auto idxs = idxs_.at(val_op);
|
|
|
|
|
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < idxs.size(); i += 1)
|
|
|
|
|
{
|
|
|
|
|
auto idx = idxs[i];
|
|
|
|
|
// pointer
|
|
|
|
|
Value *ptr = vals_[ptr_op][idx];
|
|
|
|
|
|
|
|
|
|
// value
|
|
|
|
|
Value *val = vals_.at(val_op)[idxs[i]];
|
|
|
|
|
|
|
|
|
|
// store value at pointer
|
|
|
|
|
store(val, ptr);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
// vector size
|
|
|
|
|
size_t vec = 1;
|
|
|
|
|
if(val_op->get_type()->is_block_ty()){
|
|
|
|
@@ -766,6 +807,7 @@ void generator::visit_store_inst(ir::store_inst * x){
|
|
|
|
|
else
|
|
|
|
|
store(val, ptr);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
|
|
|
|
|
visit_store_inst(x);
|
|
|
|
@@ -858,7 +900,12 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
|
|
|
|
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
|
|
|
|
std::vector<llvm::Type*> tys = {f32_ty};
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys);
|
|
|
|
|
#else
|
|
|
|
|
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
for(auto idx: idxs_.at(x)){
|
|
|
|
|
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
|
|
|
|
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
|
|
|
@@ -871,7 +918,11 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
|
|
|
|
void generator::visit_cos_inst(ir::cos_inst* x){
|
|
|
|
|
std::vector<llvm::Type*> tys = {f32_ty};
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys);
|
|
|
|
|
#else
|
|
|
|
|
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
|
|
|
|
|
#endif
|
|
|
|
|
for(auto idx: idxs_.at(x)){
|
|
|
|
|
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
|
|
|
|
}
|
|
|
|
@@ -897,7 +948,11 @@ void generator::visit_umulhi_inst(ir::umulhi_inst* x){
|
|
|
|
|
void generator::visit_sin_inst(ir::sin_inst* x){
|
|
|
|
|
std::vector<llvm::Type*> tys = {f32_ty};
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys);
|
|
|
|
|
#else
|
|
|
|
|
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
|
|
|
|
|
#endif
|
|
|
|
|
for(auto idx: idxs_.at(x)){
|
|
|
|
|
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
|
|
|
|
}
|
|
|
|
@@ -910,7 +965,11 @@ void generator::visit_log_inst(ir::log_inst* x){
|
|
|
|
|
Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
|
|
|
|
|
std::vector<llvm::Type*> tys = {f32_ty};
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys);
|
|
|
|
|
#else
|
|
|
|
|
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
|
|
|
|
|
#endif
|
|
|
|
|
for(auto idx: idxs_.at(x)){
|
|
|
|
|
Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
|
|
|
|
vals_[x][idx] = fmul(lg2arg, rcplog2e);
|
|
|
|
@@ -1701,10 +1760,14 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
|
|
|
|
size_t red_axis = 1;
|
|
|
|
|
unsigned NK = A_shapes[red_axis];
|
|
|
|
|
bool is_outer = NK == 1;
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
bool is_mma = layouts_->get(dot)->to_mma();
|
|
|
|
|
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
|
|
|
|
|
#else
|
|
|
|
|
bool is_mma = false;
|
|
|
|
|
#endif
|
|
|
|
|
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
|
|
|
|
|
return visit_mma884(dot, A, B, D, NK);
|
|
|
|
|
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
|
|
|
|
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
|
|
|
|
return visit_mma16816(dot, A, B, D, NK);
|
|
|
|
|
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
|
|
|
|
}
|
|
|
|
@@ -1739,8 +1802,14 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
|
|
|
|
|
|
|
|
|
|
inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
|
|
|
|
Type* ty = acc->getType();
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
|
|
|
|
|
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
|
|
|
|
#else
|
|
|
|
|
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
|
|
|
|
|
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
if(ty->getPrimitiveSizeInBits() <= 32)
|
|
|
|
|
return call(shfl, {acc, i32(i)});
|
|
|
|
|
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
|
|
|
|
@@ -1902,8 +1971,14 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|
|
|
|
default: throw std::runtime_error("unreachable");
|
|
|
|
|
}
|
|
|
|
|
ir::value *arg = x->get_operand(0);
|
|
|
|
|
if(arg->get_type()->get_tile_rank() == 1)
|
|
|
|
|
if (arg->get_type()->get_tile_rank() == 1)
|
|
|
|
|
{
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
visit_reducend_inst(x, do_acc, neutral);
|
|
|
|
|
#else
|
|
|
|
|
visit_reduce1d_inst(x, do_acc, neutral);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
visit_reducend_inst(x, do_acc, neutral);
|
|
|
|
|
}
|
|
|
|
@@ -2286,12 +2361,14 @@ void generator::visit_function(ir::function* fn) {
|
|
|
|
|
// set metadata
|
|
|
|
|
if(tgt_->is_gpu()){
|
|
|
|
|
tgt_->set_kernel(*builder_, ctx, mod_, ret);
|
|
|
|
|
#ifndef USE_ROCM
|
|
|
|
|
Metadata *md_args[] = {
|
|
|
|
|
ValueAsMetadata::get(ret),
|
|
|
|
|
MDString::get(ctx, "maxntidx"),
|
|
|
|
|
ValueAsMetadata::get(i32(num_warps_*32))
|
|
|
|
|
};
|
|
|
|
|
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
// set arguments
|
|
|
|
|
for(unsigned i = 0; i < fn->args().size(); i++)
|
|
|
|
@@ -2311,6 +2388,9 @@ void generator::visit_function(ir::function* fn) {
|
|
|
|
|
visit_basic_block(block);
|
|
|
|
|
// finalize
|
|
|
|
|
finalize_function(fn);
|
|
|
|
|
|
|
|
|
|
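// run the LLVM verifier on the generated function (NOTE(review): the boolean result is discarded here, so a broken function is not reported — confirm this is intended)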
|
|
|
|
|
llvm::verifyFunction(*ret);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -2334,7 +2414,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
|
|
|
|
Value *_8 = i32(8);
|
|
|
|
|
Value *_16 = i32(16);
|
|
|
|
|
Value *_32 = i32(32);
|
|
|
|
|
#ifdef USE_ROCM
|
|
|
|
|
int cc = 1; // generate ir for older CUDA cards
|
|
|
|
|
#else
|
|
|
|
|
int cc = tgt_->as_nvidia()->sm();
|
|
|
|
|
#endif
|
|
|
|
|
std::vector<Value*> idx_m;
|
|
|
|
|
std::vector<Value*> idx_n;
|
|
|
|
|
std::vector<Value*> idx_z;
|
|
|
|
|