triton/lib/codegen/selection/generator.cc
Keren Zhou db3aa1d1fb [FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch:

1. Use the static Triton types (e.g., `tl.int32`) instead of creating new
type objects. Otherwise, `tl.int32` and a freshly constructed
`tl.dtype('int32')` are not treated as the same thing.

2. The name of an extern inst should be left empty rather than set to the
symbol name of the inst; the TTIR generator assigns names automatically.
Otherwise, multiple identical extern insts end up with the same variable name.

Before the PR:

```bash
  __nv_exp = extern_elementwise f64<1024> %11;
  __nv_exp = extern_elementwise f64<1024> %11;
```

After the PR:

```bash
  %12 = extern_elementwise f64<1024> %11;
  %13 = extern_elementwise f64<1024> %11;
```
2022-10-13 17:18:16 -07:00


#include <numeric>
#include <sstream>
#include <iomanip>
#include <stdexcept>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/ir/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
namespace triton{
namespace codegen{
using namespace llvm;
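// `adder`, `multiplier` and `geper` are small peephole wrappers around IRBuilder's
// CreateAdd/CreateMul/CreateGEP: they reassociate constant operands towards the outside of the
// expression so that constants are easier to fold later. They are used throughout the generator
// via the `add`, `mul` and `gep` members initialized in the constructor.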
Value* adder::operator()(Value *x, Value *y, const std::string& name) {
// (x + cst) + y -> (x + y) + cst
if(auto* bin = dyn_cast<BinaryOperator>(x))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1))){
return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
bin->getOperand(1));
}
// (x + (y + cst)) -> (x + y) + cst
if(auto* bin = dyn_cast<BinaryOperator>(y))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1))){
return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
bin->getOperand(1));
}
// default
return (*builder_)->CreateAdd(x, y, name);
}
Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
// (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
if(auto* bin = dyn_cast<BinaryOperator>(x))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1)))
if(dyn_cast<Constant>(y)){
return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
(*builder_)->CreateMul(bin->getOperand(1), y));
}
// default
return (*builder_)->CreateMul(x, y, name);
}
Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(),
gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2));
}
// ptr + (off + cst) -> (ptr + off) + cst
if(auto* bin = dyn_cast<BinaryOperator>(off))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
ptr, bin->getOperand(0));
return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(),
gep, bin->getOperand(1));
}
// default
return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
ptr, off, name);
}
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
//}
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getInt16Ty()
#define f32_ty builder_->getFloatTy()
#define i1_ty builder_->getInt1Ty()
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
#define i32_ty builder_->getInt32Ty()
#define i64_ty builder_->getInt64Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
#define i16(...) builder_->getInt16(__VA_ARGS__)
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
#define br(...) builder_->CreateBr(__VA_ARGS__)
#define call(...) builder_->CreateCall(__VA_ARGS__)
#define cast(...) builder_->CreateCast(__VA_ARGS__)
#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__)
#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__)
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
#define icmp_uge(...) builder_->CreateICmpUGE(__VA_ARGS__)
#define icmp_ule(...) builder_->CreateICmpULE(__VA_ARGS__)
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
#define ret(...) builder_->CreateRet(__VA_ARGS__)
#define select(...) builder_->CreateSelect(__VA_ARGS__)
#define store(...) builder_->CreateStore(__VA_ARGS__)
#define sub(...) builder_->CreateSub(__VA_ARGS__)
#define shl(...) builder_->CreateShl(__VA_ARGS__)
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
#define urem(...) builder_->CreateURem(__VA_ARGS__)
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
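// The macros above are shorthand for the corresponding IRBuilder calls and keep the visitor code
// terse. Note that bf16 values are carried in i16 registers (see bf16_ty above and the BF16TyID
// case in cvt below); dedicated inline PTX is emitted wherever actual bf16 arithmetic is needed.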
/**
* \brief Convert Triton-IR Type to LLVM-IR Type
*/
Type *generator::cvt(ir::type *ty) {
// struct
if(ty->is_struct_ty()){
std::vector<Type*> tys;
for(size_t i = 0; i < ty->get_struct_numel(); i++)
tys.push_back(cvt(ty->get_struct_type(i)));
return StructType::get(builder_->getContext(), tys, true);
}
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *ret_ty = cvt(tt->get_return_ty());
std::vector<Type*> arg_tys(tt->get_num_params());
for(size_t i = 0; i < arg_tys.size(); i++)
arg_tys[i] = cvt(tt->get_param_ty(i));
return FunctionType::get(ret_ty, arg_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = cvt(ty->get_pointer_element_ty());
unsigned addr_space = ty->get_pointer_address_space();
return ptr_ty(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(*ctx_, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
/**
* \brief Convert Triton-IR Attribute to LLVM-IR Attribute
*/
llvm::Attribute generator::cvt(ir::attribute attr) {
switch(attr.get_kind()){
case ir::noalias: return llvm::Attribute::get(*ctx_, llvm::Attribute::NoAlias);
case ir::readonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::ReadOnly);
case ir::writeonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::WriteOnly);
case ir::aligned: return llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, attr.get_value());
case ir::retune: return llvm::Attribute::get(*ctx_, llvm::Attribute::None);
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
}
}
/**
* \brief Constructor of LLVM code generator
*/
generator::generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps)
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
}
/**
* \brief Code Generation for `value`
*/
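// For block values that live in shared memory, this also records which shared-memory pointer
// (and, for double buffering, which offset) the value maps to:
// - N-buffered layouts: the phi reads through the base pointer shared_ptr_, values in `firsts`
//   write at a fixed per-stage offset from shared_pre_ptr_, and the latch writes at the stage
//   selected by write_smem_idx_.
// - double-buffered layouts: the phi carries shared_off_, the latch uses shared_next_ptr_ and
//   the first value uses shared_pre_ptr_.
// The results are published in shmems_ / shoffs_ for consumers such as visit_dot and visit_cts.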
void generator::visit_value(ir::value* v) {
if(!seen_.insert(v).second)
return;
if(v->get_type()->is_block_ty()){
if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer();
analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer();
// offset
Value *offset = nullptr;
// base pointer
Value *ptr = shared_ptr_[layout];
if (n_buffer) {
// ptr = base (shared_ptr_[layout]) + smem_idx * size
// read_smem_idx
if (v == n_buffer->phi) {
ptr = shared_ptr_[layout];
}
// write_smem_idx
if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) {
int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v);
int elements = write_smem_idx * layout->get_per_stage_elements();
ptr = gep(shared_pre_ptr_[layout], i32(elements));
} else if (v == n_buffer->latch) {
Value* write_smem_idx = write_smem_idx_[layout];
Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements()));
ptr = gep(shared_pre_ptr_[layout], elements);
}
} else if (double_buffer) {
if(v == double_buffer->phi)
offset = shared_off_[layout];
if(v == double_buffer->latch)
ptr = shared_next_ptr_[layout];
else if(v == double_buffer->first)
ptr = shared_pre_ptr_[layout];
} // else do nothing
// what visit_dot & visit_cts & ... see
shmems_[v] = ptr;
// now only latches have an offset (PHINode); it is only used by finalize_shared_layout()
shoffs_[v] = offset;
}
}
// visit operands
BasicBlock *current = builder_->GetInsertBlock();
auto *inst = dynamic_cast<ir::instruction*>(v);
if(inst)
for(ir::value *op: inst->ops()){
if(dynamic_cast<ir::constant*>(op) || !dynamic_cast<ir::phi_node*>(v))
visit_value(op);
}
init_idx(v);
// change insert point for phi node
builder_->SetInsertPoint(current);
auto *phi = dynamic_cast<ir::phi_node*>(v);
if(phi && !current->empty() && current->getFirstNonPHI())
builder_->SetInsertPoint(&*current->getFirstNonPHI());
// visit user
if(auto *usr = dynamic_cast<ir::user*>(v)){
if(!dynamic_cast<ir::function*>(usr))
usr->accept(this);
}
// revert insert point
if(phi && !current->empty() && current->getFirstNonPHI())
builder_->SetInsertPoint(current);
}
/**
* \brief Code Generation for `phi`
*/
void generator::visit_phi_node(ir::phi_node* x) {
Type *ty = cvt(x->get_type()->get_scalar_ty());
for(indices_t idx: idxs_.at(x))
vals_[x][idx] = phi(ty, x->get_num_operands());
}
/**
* \brief Code Generation for `call`
*/
void generator::visit_call_inst(ir::call_inst* call) {
throw std::runtime_error("call not supported! Triton should be inlining everything.");
}
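// Lowers a device-side kernel launch (CUDA dynamic parallelism): only thread 0 of the block
// builds the grid/block descriptors, requests a parameter buffer via cudaGetParameterBufferV2,
// copies the kernel arguments into it with their natural alignment, and calls cudaLaunchDeviceV2.
// All other threads branch straight to launch_done.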
void generator::visit_launch_inst(ir::launch_inst *launch) {
ir::function* fn = (ir::function*)launch->get_operand(0);
// forward-declare cudaGetParameterBufferV2
std::vector<Type*> get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0),
ArrayType::get(builder_->getInt32Ty(), 3),
ArrayType::get(builder_->getInt32Ty(), 3),
builder_->getInt32Ty()};
FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false);
Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2", mod_);
AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]);
AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]);
ConstantInt* _0 = builder_->getInt32(0);
ConstantInt* _1 = builder_->getInt32(1);
ConstantInt* _2 = builder_->getInt32(2);
// create basic block
BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent());
BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb);
Value *tid = tgt_->get_local_id(mod_, *builder_, 0);
Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0));
builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb);
builder_->SetInsertPoint(launch_bb);
//
builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0}));
builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1}));
builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2}));
Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]);
builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0}));
builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1}));
builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2}));
Function* called_fn = fns_[fn];
Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]);
Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)});
// forward-declare cudaLaunchDeviceV2
std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()};
FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false);
Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2", mod_);
// TODO: add branch
Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()),
builder_->getInt64(0));
BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb);
builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb);
builder_->SetInsertPoint(launch2_bb);
unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace();
unsigned off = 0;
unsigned last_size = 0;
for(ir::value* arg: launch->get_values()){
Value* curr_arg = vals_[arg][{}];
Type* curr_arg_ty = curr_arg->getType();
// handle struct alignment
off += last_size;
unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8;
off = (off + size - 1) / size * size;
// get pointer to current arg
Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off));
curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space));
// store arg
builder_->CreateStore(curr_arg, curr_arg_ptr);
last_size = size;
}
builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)});
builder_->CreateBr(launch_done_bb);
// done
builder_->SetInsertPoint(launch_done_bb);
}
/**
* \brief Code Generation for `binary_operator`
*/
void generator::visit_binary_operator(ir::binary_operator*x) {
using ll = llvm::Instruction::BinaryOps;
using tt = ir::binary_op_t;
auto cvt = [](ir::binary_op_t op){
switch(op) {
case tt::Add: return ll::Add;
case tt::FAdd: return ll::FAdd;
case tt::Sub: return ll::Sub;
case tt::FSub: return ll::FSub;
case tt::Mul: return ll::Mul;
case tt::FMul: return ll::FMul;
case tt::UDiv: return ll::UDiv;
case tt::SDiv: return ll::SDiv;
case tt::FDiv: return ll::FDiv;
case tt::URem: return ll::URem;
case tt::SRem: return ll::SRem;
case tt::FRem: return ll::FRem;
case tt::Shl: return ll::Shl;
case tt::LShr: return ll::LShr;
case tt::AShr: return ll::AShr;
case tt::And: return ll::And;
case tt::Or: return ll::Or;
case tt::Xor: return ll::Xor;
default: throw std::runtime_error("unreachable switch");
}
};
// x->print(std::cout);
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
// manually select bf16 bin op
if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
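// The PTX ISA targeted here has no plain bf16 add/sub/mul, so all three are emulated with a
// single fma.rn.bf16:
//   a + b -> fma(a, 1.0, b)    a - b -> fma(b, -1.0, a)    a * b -> fma(a, b, -0.0)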
if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
InlineAsm *bf16_add_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"{ .reg .b16 c; \n\t"
" mov.b16 c, 0x3f80U; \n\t" // 1.0
" fma.rn.bf16 $0, $1, c, $2; } \n\t",
"=h,h,h", false);
vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
} else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
InlineAsm *bf16_sub_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0xbf80U; \n\t" // -1.0
" fma.rn.bf16 $0, $2, c, $1;} \n\t",
"=h,h,h", false);
vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
} else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
InlineAsm *bf16_mul_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0x8000U; \n\t" // 0.0
" fma.rn.bf16 $0, $1, $2, c;} \n\t",
"=h,h,h", false);
vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
} else
throw std::runtime_error("invalid bin op for bf16");
} else { // not bf16
auto op = cvt(x->get_op());
if(op == ll::Add)
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
x->get_type()->get_scalar_ty()->is_fp32_ty()){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
}
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}
}
}
/**
* \brief Code Generation for `getelementptr`
*/
void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
for(indices_t idx: idxs_.at(x)){
Value *ptr = vals_[x->get_pointer_operand()][idx];
std::vector<Value*> vals;
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
vals.push_back(vals_[*it][idx]);
assert(vals.size() == 1);
vals_[x][idx] = gep(ptr, vals[0]);
}
}
/**
* \brief Code Generation for `icmp`
*/
void generator::visit_icmp_inst(ir::icmp_inst* x) {
auto cvt = [](ir::cmp_pred_t pred) {
using ll = llvm::CmpInst::Predicate;
using tt = ir::cmp_pred_t;
switch(pred){
case tt::FIRST_ICMP_PREDICATE: return ll::FIRST_ICMP_PREDICATE;
case tt::ICMP_EQ: return ll::ICMP_EQ;
case tt::ICMP_NE: return ll::ICMP_NE;
case tt::ICMP_UGT: return ll::ICMP_UGT;
case tt::ICMP_UGE: return ll::ICMP_UGE;
case tt::ICMP_ULT: return ll::ICMP_ULT;
case tt::ICMP_ULE: return ll::ICMP_ULE;
case tt::ICMP_SGT: return ll::ICMP_SGT;
case tt::ICMP_SGE: return ll::ICMP_SGE;
case tt::ICMP_SLT: return ll::ICMP_SLT;
case tt::ICMP_SLE: return ll::ICMP_SLE;
case tt::LAST_ICMP_PREDICATE: return ll::LAST_ICMP_PREDICATE;
default: throw std::runtime_error("unreachable switch");
}
};
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
vals_[x][idx] = icmp(cvt(x->get_pred()), lhs, rhs);
}
}
/**
* \brief Code Generation for `fcmp`
*/
void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
auto cvt = [](ir::cmp_pred_t pred) {
using ll = llvm::CmpInst::Predicate;
using tt = ir::cmp_pred_t;
switch(pred){
case tt::FIRST_FCMP_PREDICATE: return ll::FIRST_FCMP_PREDICATE;
case tt::FCMP_FALSE: return ll::FCMP_FALSE;
case tt::FCMP_OEQ: return ll::FCMP_OEQ;
case tt::FCMP_OGT: return ll::FCMP_OGT;
case tt::FCMP_OGE: return ll::FCMP_OGE;
case tt::FCMP_OLT: return ll::FCMP_OLT;
case tt::FCMP_OLE: return ll::FCMP_OLE;
case tt::FCMP_ONE: return ll::FCMP_ONE;
case tt::FCMP_ORD: return ll::FCMP_ORD;
case tt::FCMP_UNO: return ll::FCMP_UNO;
case tt::FCMP_UEQ: return ll::FCMP_UEQ;
case tt::FCMP_UGT: return ll::FCMP_UGT;
case tt::FCMP_UGE: return ll::FCMP_UGE;
case tt::FCMP_ULT: return ll::FCMP_ULT;
case tt::FCMP_ULE: return ll::FCMP_ULE;
case tt::FCMP_UNE: return ll::FCMP_UNE;
case tt::FCMP_TRUE: return ll::FCMP_TRUE;
case tt::LAST_FCMP_PREDICATE: return ll::LAST_FCMP_PREDICATE;
default: throw std::runtime_error("unreachable switch");
}
};
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
vals_[x][idx] = fcmp(cvt(x->get_pred()), lhs, rhs);
}
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty);
in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty);
in2 = cast(llvm::Instruction::FPTrunc, in2, f16_ty);
in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty);
Value *ret0, *ret1, *ret2, *ret3;
std::tie(ret0, ret1, ret2, ret3) = fp16x4_to_fp8x4(in0, in1, in2, in3);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
Value *ret0, *ret1, *ret2, *ret3;
std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3);
ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty);
ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty);
ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty);
ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
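// The fp8 format used here is 1 sign bit, 4 exponent bits, 3 mantissa bits, laid out so that an
// fp8 value maps to fp16 by shifting its exponent/mantissa bits into the fp16 positions (see the
// comment in fp16x4_to_fp8x4 below). Four fp8s packed in one i32 are expanded into two <2 x half>
// vectors at once.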
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
"{"
".reg .b32 a<2>, b<2>; \n\t"
"prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
"shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 position)
"shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position)
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign)
"}", "=r,=r,r", false);
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
packed_in = insert_elt(packed_in, in1, (uint64_t)1);
packed_in = insert_elt(packed_in, in2, (uint64_t)2);
packed_in = insert_elt(packed_in, in3, (uint64_t)3);
Value *in = bit_cast(packed_in, i32_ty);
Value *ret = call(ptx, {in});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
/* fp16 bit representation is seeeeemmmmmmmmmm (s=sign, e=exponent, m=mantissa)
* fp8 bit representation is seeeemmm
* The 4 fp8 exponent bits are the low order 4 exponent bits in fp16.
* The 3 fp8 mantissa bits are the high order 3 mantissa bits in fp16.
* Note that the low order exponent bits and high order mantissa bits in fp16 are contiguous.
* We want to round to nearest fp8 value. To do that add 1 to 4th mantissa bit in fp16 (that's
* one more than the number of mantissa bits in fp8).
* fp8 = (fp16 & 0x8000) | (((fp16 << 1) + 0x0080) & 0x7fff)
*
* We compute two fp16s in one uint32. The addition could cause bit flips from one fp16 to the
* other. To avoid this we zero out the most significant exponent bit. If that bit is set then
* the value isn't representable in float8 anyway so we assume it's never set (and give garbage
* output if it is). If we were willing to assume the most significant exponent was never set
* we could save the first two lop3.b32 instructions below.
*/
InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
"{"
".reg .b32 a<2>, b<2>; \n\t"
"shl.b32 a0, $1, 1; \n\t" // a0 = input0 << 1
"shl.b32 a1, $2, 1; \n\t" // a1 = input1 << 1
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // a0 = (a0 & 0x7fff7fff)
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // a1 = (a1 & 0x7fff7fff)
"add.u32 a0, a0, 0x00800080; \n\t" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n\t" // a1 += 0x00800080
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1
"prmt.b32 $0, b0, b1, 0x7531; \n\t" // If b0 = 0xabcd and b1=0x0123 sets output to 0xac02
"}", "=r,r,r", false);
Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2));
Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2));
packed_in0 = insert_elt(packed_in0, in0, (int)0);
packed_in0 = insert_elt(packed_in0, in1, (int)1);
packed_in1 = insert_elt(packed_in1, in2, (int)0);
packed_in1 = insert_elt(packed_in1, in3, (int)1);
Value *in_arg0 = bit_cast(packed_in0, i32_ty);
Value *in_arg1 = bit_cast(packed_in1, i32_ty);
Value *ret = call(ptx, {in_arg0, in_arg1});
Value *ret0 = extract_elt(ret, (int)0);
Value *ret1 = extract_elt(ret, (int)1);
Value *ret2 = extract_elt(ret, (int)2);
Value *ret3 = extract_elt(ret, (int)3);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) {
// current exp offset: 15
// Add 112 (127-15) to compensate the difference in exponent bias
// bf16 = (nosign >> (8-4) + 112 << 7) | sign;
// bf16 = (nosign >> 4 + 0x3800) | sign;
Type *ret_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
"{"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t"
"prmt.b32 a0, 0, $2, 0x5040; \n\t" // 0xdcba => 0xb0a0
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // 0xdcba => 0xd0c0
"and.b32 sign0, a0, 0x80008000; \n\t"
"and.b32 sign1, a1, 0x80008000; \n\t"
"and.b32 nosign0, a0, 0x7fff7fff; \n\t"
"and.b32 nosign1, a1, 0x7fff7fff; \n\t"
"shr.b32 nosign0, nosign0, 4; \n\t"
"shr.b32 nosign1, nosign1, 4; \n\t"
"add.u32 nosign0, nosign0, 0x38003800; \n\t"
"add.u32 nosign1, nosign1, 0x38003800; \n\t"
"or.b32 $0, sign0, nosign0; \n\t"
"or.b32 $1, sign1, nosign1; \n\t"
"}", "=r,=r,r", false);
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
packed_in = insert_elt(packed_in, in1, (uint64_t)1);
packed_in = insert_elt(packed_in, in2, (uint64_t)2);
packed_in = insert_elt(packed_in, in3, (uint64_t)3);
Value *in = bit_cast(packed_in, i32_ty);
Value *ret = call(ptx, {in});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*, Value*, Value*> generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
/* Assuming fp8 exponent offset is 15. bf16 exponent offset is 127.
Max value in fp8: 0b01111111 (0x7f),
bf16: 0x3ff0
Min value in fp8: 0b00000000 (0x00)
bf16: 0x3800
// @note: +0x8 is for "rounding to nearest zero"
fp8 = (nosign(bf16) - (112 << 7) + 0x8) << 4;
return fp8 | sign; // also permute bytes
*/
InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
"{\n\t"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t"
"mov.u32 fp8_min, 0x38003800; \n\t"
"mov.u32 fp8_max, 0x3ff03ff0; \n\t"
"mov.u32 rn_, 0x80008; \n\t"
"mov.u32 zero, 0; \n\t"
"and.b32 sign0, $1, 0x80008000; \n\t"
"and.b32 sign1, $2, 0x80008000; \n\t"
"prmt.b32 sign, sign0, sign1, 0x7531; \n\t"
"and.b32 nosign0, $1, 0x7fff7fff; \n\t"
"and.b32 nosign1, $2, 0x7fff7fff; \n\t"
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max)
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t"
"add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero
"add.u32 nosign1, nosign1, rn_; \n\t"
"sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset
"sub.u32 nosign1, nosign1, 0x38003800; \n\t"
"shr.u32 nosign0, nosign0, 4; \n\t"
"shr.u32 nosign1, nosign1, 4; \n\t"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t"
"or.b32 $0, nosign, sign; \n\t"
""
"}", "=r,r,r", false);
Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2));
Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2));
packed_in0 = insert_elt(packed_in0, in0, (int)0);
packed_in0 = insert_elt(packed_in0, in1, (int)1);
packed_in1 = insert_elt(packed_in1, in2, (int)0);
packed_in1 = insert_elt(packed_in1, in3, (int)1);
Value *in_arg0 = bit_cast(packed_in0, i32_ty);
Value *in_arg1 = bit_cast(packed_in1, i32_ty);
Value *ret = call(ptx, {in_arg0, in_arg1});
Value *ret0 = extract_elt(ret, (int)0);
Value *ret1 = extract_elt(ret, (int)1);
Value *ret2 = extract_elt(ret, (int)2);
Value *ret3 = extract_elt(ret, (int)3);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
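// bf16 <-> fp32 conversions: on sm_80+ the native cvt instructions are used. On older targets,
// bf16 -> fp32 simply places the 16 bits in the upper half of the fp32 (low mantissa bits zero),
// and fp32 -> bf16 keeps the upper 16 bits, i.e. truncation without rounding.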
Value* generator::bf16_to_fp32(Value *in0){
if (tgt_->as_nvidia()->sm() >= 80) {
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
"cvt.rn.f32.bf16 $0, $1;", "=r,h", false);
return call(ptx, {in0});
} else {
Value *ret = UndefValue::get(vec_ty(i16_ty, 2));
ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1);
ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0);
return bit_cast(ret, f32_ty);
}
}
Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0});
}
return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1);
}
/**
* \brief Code Generation for `cast`
*/
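// Casts involving fp8 are handled four elements at a time (hence the contiguity check below),
// bf16 <-> fp32 goes through the helpers above, and every other cast maps directly to an LLVM
// cast instruction.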
void generator::visit_cast_inst(ir::cast_inst* x) {
ir::value *op = x->get_operand(0);
ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
auto x_idxs = idxs_.at(x);
auto op_idxs = idxs_.at(op);
// <> FP8
if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
// ensure that conversions can be vectorized
int ld = layouts_->get(x)->get_order(0);
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
if(contiguous % 4 != 0)
throw std::runtime_error("unsupported fp32 -> fp8 conversion");
// run the conversion
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty())
return fp32x4_to_fp8x4(a, b, c, d);
if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty())
return fp16x4_to_fp8x4(a, b, c, d);
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty())
return fp8x4_to_fp16x4(a, b, c, d);
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
return fp8x4_to_fp32x4(a, b, c, d);
// fp8 <> bf16
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty())
return fp8x4_to_bf16x4(a, b, c, d);
if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty())
return bf16x4_to_fp8x4(a, b, c, d);
throw std::runtime_error("unsupported conversion");
};
for(size_t i = 0; i < x_idxs.size(); i+=4){
std::tie(vals_[x][x_idxs[i+0]],
vals_[x][x_idxs[i+1]],
vals_[x][x_idxs[i+2]],
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]],
vals_[op][op_idxs[i+1]],
vals_[op][op_idxs[i+2]],
vals_[op][op_idxs[i+3]]);
}
return;
}
// <> BF16
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
// FP32 -> BF16
if(op_sca_ty->is_fp32_ty()){
for (indices_t idx: idxs_.at(x)) {
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
}
return;
}
// BF16 -> FP32
if(ret_sca_ty->is_fp32_ty()){
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
return;
}
}
Type *ty = cvt(x->get_type()->get_scalar_ty());
auto cvt = [](ir::cast_op_t op){
using ll = llvm::Instruction::CastOps;
using tt = ir::cast_op_t;
switch(op){
case tt::Trunc: return ll::Trunc;
case tt::ZExt: return ll::ZExt;
case tt::SExt: return ll::SExt;
case tt::FPTrunc: return ll::FPTrunc;
case tt::FPExt: return ll::FPExt;
case tt::UIToFP: return ll::UIToFP;
case tt::SIToFP: return ll::SIToFP;
case tt::FPToUI: return ll::FPToUI;
case tt::FPToSI: return ll::FPToSI;
case tt::PtrToInt: return ll::PtrToInt;
case tt::IntToPtr: return ll::IntToPtr;
case tt::BitCast: return ll::BitCast;
case tt::AddrSpaceCast: return ll::AddrSpaceCast;
default: throw std::runtime_error("unreachable switch");
}
};
for(indices_t idx: idxs_.at(x)){
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = cast(cvt(x->get_op()), arg, ty);
}
}
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
){
/* unpacking 8 int2s packed into an int16 to 8 float16s
* the algorithm is similar to
* https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1492-L1563
*/
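// The trick: each packed integer field is masked/permuted into the mantissa bits of an fp16 lane,
// where it reads as a small subnormal. Multiplying by the power-of-two constants below
// (0x7800 = 2^15, 0x7000 = 2^13, 0x6800 = 2^11, 0x6000 = 2^9) brings every lane to value/512, so
// the final fma with scale*512 (see prepare_scale_shift) and shift yields value*scale + shift.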
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
"{"
".reg .b32 a<2>, b<4>; \n\t" // input is 0xab,cd,ef,gh,ab,cd,ef,gh, each a, b etc occupies two bits.
"and.b32 a0, 0x30300303, $4; \n\t" // set a0 to 0x0b,00,0f,00,00,0d,00,0h
"and.b32 a1, 0xc0c00c0c, $4; \n\t" // set a1 to 0xa0,00,e0,00,00,c0,00,g0
"prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x00,00,00,0d,00,00,00,0h
"prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00,00,00,c0,00,00,00,g0
"prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x00,00,0b,00,00,00,0f,00
"prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00,00,a0,00,00,00,e0,00
"mov.b32 a0, 0x78007800; \n\t" // a0 = 32768
"mov.b32 a1, 0x70007000; \n\t" // a1 = 8192
"mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768.
"mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 8192.
"mov.b32 a0, 0x68006800; \n\t" // a0 = 2048
"mov.b32 a1, 0x60006000; \n\t" // a1 = 512
"mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 2048.
"mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 512.
"fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
"fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
"fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift.
"fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift.
"}", "=r,=r,=r,=r,r,r,r", false);
Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2));
packed_in = insert_elt(packed_in, in0, (int)0);
packed_in = insert_elt(packed_in, in0, (int)1);
Value *in = bit_cast(packed_in, i32_ty);
Value *ret = call(ptx, {in, scale_x512, shift});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *packed_ret2 = extract_val(ret, {2});
Value *packed_ret3 = extract_val(ret, {3});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
Value *ret2 = extract_elt(packed_ret2, (uint64_t)0); // f
Value *ret3 = extract_elt(packed_ret3, (uint64_t)0); // e
Value *ret4 = extract_elt(packed_ret0, (uint64_t)1); // d
Value *ret5 = extract_elt(packed_ret1, (uint64_t)1); // c
Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
}
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
){
/* unpacking 8 int4s packed into an int32 to 8 float16s
* the algorithm is similar to
* https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619
*/
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
"{"
".reg .b32 a<2>, b<4>; \n\t"
"and.b32 a0, 0x0f0f0f0f, $4; \n\t" // If input is 0xabcdefgh set a to 0x0b0d0f0h
"and.b32 a1, 0xf0f0f0f0, $4; \n\t" // If input is 0xabcdefgh set a to 0xa0c0e0g0
"prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x000f000h
"prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00e000g0
"prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x000b000d
"prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00a000c0
"mov.b32 a0, 0x78007800; \n\t"
"mov.b32 a1, 0x68006800; \n\t"
"mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768.
"mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 2048.
"mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 32768.
"mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 2048.
"fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
"fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
"fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out0 = b0 * scale + shift.
"fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out1 = b1 * scale + shift.
"}", "=r,=r,=r,=r,r,r,r", false);
Value *ret = call(ptx, {in0, scale_x512, shift});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *packed_ret2 = extract_val(ret, {2});
Value *packed_ret3 = extract_val(ret, {3});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
Value *ret2 = extract_elt(packed_ret0, (uint64_t)1); // f
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // e
Value *ret4 = extract_elt(packed_ret2, (uint64_t)0); // d
Value *ret5 = extract_elt(packed_ret3, (uint64_t)0); // c
Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
}
std::tuple<Value*, Value*, Value*, Value*> generator::int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift){
/* unpacking 4 int8s packed into an int32 to 4 fp16s
* the algorithm is similar to
* https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1622-L1646
*/
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
"{"
".reg .b32 a, b<2>; \n\t"
"prmt.b32 b0, 0, $2, 0x0504; \n\t" // If input is 0xabcdefgh set b0 to 0x00ef00gh
"prmt.b32 b1, 0, $2, 0x0706; \n\t" // If input is 0xabcdefgh set b1 to 0x00ab00cd
"mov.b32 a, 0x78007800; \n\t"
"mul.f16x2 b0, b0, a; \n\t" // b0 = b0 * 32768.
"mul.f16x2 b1, b1, a; \n\t" // b1 = b1 * 32768.
"fma.rn.f16x2 $0, b0, $3, $4; \n\t" // out0 = b0 * scale + shift.
"fma.rn.f16x2 $1, b1, $3, $4; \n\t" // out1 = b1 * scale + shift.
"}", "=r,=r,r,r,r", false);
Value *ret = call(ptx, {in0, scale_x512, shift});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // gh
Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); // ef
Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); // cd
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // ab
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*> generator::prepare_scale_shift(Value *scale, Value *shift){
Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty));
Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2));
p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0);
p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1);
p_scale_x512 = bit_cast(p_scale_x512, i32_ty);
Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2));
p_shift = insert_elt(p_shift, shift, (int)0);
p_shift = insert_elt(p_shift, shift, (int)1);
p_shift = bit_cast(p_shift, i32_ty);
return std::make_tuple(p_scale_x512, p_shift);
}
/**
* \brief Code Generation for `dequantize`
*/
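// Dequantization expands each packed integer word into 8 fp16s (int2/int4 packed in int16/int32)
// or 4 fp16s (int8 packed in int32). The output's last block dimension and per-thread contiguity
// must therefore be 8x (resp. 4x) those of the input, which is what the checks below enforce.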
void generator::visit_dequantize_inst(ir::dequantize_inst* x) {
ir::value *op = x->get_operand(0);
auto src_ty_size_in_bits = op->get_type()->get_scalar_ty()->get_primitive_size_in_bits();
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto x_idxs = idxs_.at(x);
auto op_idxs = idxs_.at(op);
ir::value *scale = x->get_operand(1);
ir::value *shift = x->get_operand(2);
Value *p_scale_x512, *p_shift;
std::tie(p_scale_x512, p_shift) = prepare_scale_shift(vals_[scale][{}], vals_[shift][{}]);
int ld = layouts_->get(x)->get_order(0);
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
int op_ld = layouts_->get(op)->get_order(0);
int op_contiguous = layouts_->get(op)->to_scanline()->nts(op_ld);
std::string err_msg;
err_msg = "unsupported dequantization, cannot vectorize properly. x_idxs.size(): "
+ std::to_string(x_idxs.size()) + "; op_idxs.size(): "
+ std::to_string(op_idxs.size()) + "; contiguous: "
+ std::to_string(contiguous) + "; op_contiguous: "
+ std::to_string(op_contiguous) + ". If this condition "
"is not met, please try adjusting block_size, num_warps, or "
"using tl.multiple_of to hint the input/output ptr address.";
if (ret_last_dim == 8 * op_last_dim) {
if((x_idxs.size() != 8 * op_idxs.size()) || (contiguous != 8 * op_contiguous)) {
throw std::runtime_error(err_msg);
}
auto cvt = [&](
Value* a, Value* scale, Value* shift
){
if (src_ty_size_in_bits == 16){ // int2 quantization, int16 to 8 fp16s
return int16_to_float16x8(a, scale, shift);
} else if (src_ty_size_in_bits == 32) { // int4 quantization, int32 to 8 fp16s
return int32_to_float16x8(a, scale, shift);
} else {
throw std::runtime_error("unsupported conversion");
}
};
for(size_t j = 0; j < op_idxs.size(); j++){
size_t i = j * 8;
std::tie(vals_[x][x_idxs[i+0]],
vals_[x][x_idxs[i+1]],
vals_[x][x_idxs[i+2]],
vals_[x][x_idxs[i+3]],
vals_[x][x_idxs[i+4]],
vals_[x][x_idxs[i+5]],
vals_[x][x_idxs[i+6]],
vals_[x][x_idxs[i+7]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
}
} else if (ret_last_dim == 4 * op_last_dim && src_ty_size_in_bits == 32) { // int8 quantization, int32 to 4 fp16s
if((x_idxs.size() != 4 * op_idxs.size()) || (contiguous != 4 * op_contiguous)) {
throw std::runtime_error(err_msg);
}
auto cvt = [&](Value* a, Value* scale, Value* shift){
return int32_to_float16x4(a, scale, shift);
};
for(size_t j = 0; j < op_idxs.size(); j++){
size_t i = j * 4;
std::tie(vals_[x][x_idxs[i+0]],
vals_[x][x_idxs[i+1]],
vals_[x][x_idxs[i+2]],
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
}
} else {
throw std::runtime_error("unsupported dequantization");
}
return;
}
/**
* \brief Code Generation for `return`
*/
void generator::visit_return_inst(ir::return_inst* rr) {
ir::value *ret_val = rr->get_return_value();
ret(ret_val ? vals_[ret_val][{}] : nullptr);
}
/**
* \brief Code Generation for `cond_branch`
*/
void generator::visit_cond_branch_inst(ir::cond_branch_inst* br) {
BasicBlock *true_dest = bbs_.at(br->get_true_dest());
BasicBlock *false_dest = bbs_.at(br->get_false_dest());
Value *cond = vals_[br->get_cond()][{}];
cond_br(cond, true_dest, false_dest);
}
/**
* \brief Code Generation for `uncond_branch`
*/
void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
BasicBlock *dest = bbs_.at(br->get_dest());
br(dest);
}
/**
* \brief Code Generation for a (synchronous) `load`
*/
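// Loads are vectorized: the vector width is the layout's per-thread contiguity, capped by the
// known pointer (and mask) alignment. Sub-word elements are packed into 32/64-bit words and each
// vector is emitted as a single predicated `ld.global` inline-asm; when a mask is present, the
// `other` values are written to the destination registers with predicated `mov` instructions.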
void generator::visit_load_inst(ir::load_inst* x){
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *lane = urem(tid, i32(32));
ir::value *op = x->get_pointer_operand();
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
// compute vector width
size_t vec = 1;
bool is_mma_first_row = false;
if(op->get_type()->is_block_ty()){
auto ord = ords_.at(op);
size_t aln = alignment_->get(op, ord[0]);
if(mx){
size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
max_eq = std::max<size_t>(max_eq, 1);
aln = std::min(aln, max_eq);
}
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(x));
assert(layout);
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
// TODO: generalize
is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
(a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1));
if(is_mma_first_row)
vec = std::min<size_t>(2, aln);
}
// code generation
auto idxs = idxs_.at(x);
for(size_t i = 0; i < idxs.size(); i += vec){
indices_t idx = idxs[i];
// pointer value
Value *ptr = vals_[op][idx];
// masked load
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
size_t in_off;
if(in_gep){
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
ptr = cst ? in_gep->getPointerOperand() : in_gep;
}
else{
in_off = 0;
}
Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue();
// if(!op->get_type()->is_block_ty()){
// pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0)));
// }
Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr;
size_t nbits = dtsize*8;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
// and there are (nbits * vec)/width of them
int max_word_width = std::max<int>(32, nbits);
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
has_l2_evict_policy = false;
// has_evict_policy = false; // currently disabled until supported in `store`
// -----
// create inline asm string
// -----
std::ostringstream asm_oss;
asm_oss << "@$" << n_words; // predicate
asm_oss << " ld";
if(x->get_is_volatile())
asm_oss << ".volatile";
asm_oss << ".global";
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
asm_oss << " {";
for(int i = 0; i < n_words; i++){ // return values
if(i > 0) asm_oss << ",";
asm_oss << "$" << i;
}
asm_oss << "}";
asm_oss << ", [ $" << n_words + 1; // load
asm_oss << " + " << in_off << "]"; // constant offset
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
asm_oss << ";";
bool has_other = other && (other != UndefValue::get(other->getType()));
std::vector<Value *> others;
// handle `other` values for indices where the mask
// is false
if(has_other)
for(size_t ii = 0; ii < n_words; ii++){
size_t size = width / nbits;
Value *v = UndefValue::get(vec_ty(ty, size));
for(size_t s = 0; s < size; s++){
ir::value *false_val = mx->get_false_value_operand();
v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s);
}
v = bit_cast(v, IntegerType::get(*ctx_, width));
// PTX doesn't support mov.u8, so we need to use mov.u16
auto mov_width = width < 16 ? 16 : width;
asm_oss << "\n ";
asm_oss << "@!$" << n_words << " mov.u" << mov_width;
asm_oss << " $" << ii << ", ";
std::ios_base::fmtflags flags(asm_oss.flags());
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
asm_oss << "0x" << std::hex << cst->getSExtValue();
else{
asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii;
others.push_back(v);
}
asm_oss.flags(flags);
asm_oss << ";";
}
// ----
// create inline ASM signature
// ---
std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
// ret_ty->print(llvm::outs());
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
if (has_l2_evict_policy)
arg_tys.push_back(i64_ty);
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
// ---
// create inline ASM constraints
// ---
std::string asm_cstrt;
for(int ii = 0; ii < n_words; ii++){
if(ii > 0) asm_cstrt += ",";
asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
}
asm_cstrt += ",b,l";
for(size_t ii = 0; ii < others.size(); ii++){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_l2_evict_policy)
asm_cstrt += ",l";
// ---
// finally call inline ASM
// ---
InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
std::vector<Value*> args = {pred, ptr};
for(Value *v: others)
args.push_back(v);
if (has_l2_evict_policy)
args.push_back(policies_.at(x->get_eviction_policy()));
Value *_ret = call(inlineAsm, args);
// if(!op->get_type()->is_block_ty()){
// Value* cond = icmp_eq(tid, i32(0));
// Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3));
// Instruction* bar = add_barrier();
// Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false);
// builder_->SetInsertPoint(term);
// store(_ret, shptr);
// builder_->SetInsertPoint(bar->getParent());
// _ret = load(shptr);
// add_barrier();
// }
// ---
// extract and store return values
// ---
std::vector<Value *> rets;
for(unsigned int ii = 0; ii < n_words; ii++){
Value *curr;
if(ret_ty->isStructTy())
curr = extract_val(_ret, {ii});
else
curr = _ret;
rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
}
int tmp = (width / (dtsize * 8));
for(size_t ii = 0; ii < vec; ii++)
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
}
}
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
visit_load_inst(x);
}
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
visit_load_inst(x);
}
/**
* \brief Code Generation for a (synchronous) `store`
*/
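// Stores mirror the load path: values are re-packed into 32/64-bit words and written with a
// predicated `st.global` inline-asm. i1 values are sign-extended to i8 first, since there is no
// 1-bit store.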
void generator::visit_store_inst(ir::store_inst * x){
ir::masked_store_inst *mx = dynamic_cast<ir::masked_store_inst*>(x);
// operands
ir::value *ptr_op = x->get_pointer_operand();
ir::value *val_op = x->get_value_operand();
ir::value *msk_op = nullptr;
if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
msk_op = msk_st->get_mask_operand();
// vector size
size_t vec = 1;
if(val_op->get_type()->is_block_ty()){
auto ord = ords_.at(x->get_pointer_operand());
size_t aln = alignment_->get(ptr_op, ord[0]);
size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
if(mx){
size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
max_eq = std::max<size_t>(max_eq, 1);
aln = std::min(aln, max_eq);
}
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(ptr_op));
assert(layout);
// vec = std::min(nts, aln);
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
// TODO: generalize
bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
(a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1));
if(is_mma_first_row)
vec = std::min<size_t>(2, aln);
}
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
has_l2_evict_policy = false;
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if(ty->isIntegerTy(1))
ty = builder_->getInt8Ty();
for(size_t i = 0; i < idxs.size(); i += vec){
indices_t idx = idxs[i];
// pointers
Value *ptr = vals_[ptr_op][idx];
size_t dtsize = std::max<int>(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8);
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
size_t in_off;
if(in_gep){
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
ptr = cst ? in_gep->getPointerOperand() : in_gep;
}
else{
in_off = 0;
}
// mask
Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue();
size_t nbits = dtsize*8;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
// and there are (nbits * vec)/width of them
int max_word_width = std::max<int>(32, nbits);
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
// -----
// create inline asm string
// -----
std::ostringstream asm_oss;
asm_oss << "@$0"; // predicate
asm_oss << " st.global";
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
asm_oss << " [ $1 + " << in_off << "]";
asm_oss << " , {";
for(int i = 0; i < n_words; i++){ // return values
if(i > 0) asm_oss << ",";
asm_oss << "$" << 2 + i;
}
asm_oss << "}";
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
asm_oss << ";";
// ----
// create inline ASM signature
// ---
Type* val_arg_ty = IntegerType::get(*ctx_, width);
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(int ii = 0; ii < n_words; ii++)
arg_tys.push_back(val_arg_ty);
if (has_l2_evict_policy)
arg_tys.push_back(i64_ty);
FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false);
// ---
// create inline ASM constraints
// ---
std::string asm_cstrt = "b,l";
for(int ii = 0; ii < n_words; ii++){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_l2_evict_policy)
asm_cstrt += ",l";
// ---
// finally call inline ASM
// ---
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
std::vector<Value*> args = {pred, ptr};
for(unsigned int ii = 0; ii < n_words; ii++){
size_t n_subw = width / nbits;
Value* curr = UndefValue::get(vec_ty(ty, n_subw));
for(unsigned int jj = 0; jj < n_subw; jj++){
Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]];
if(new_elt->getType()->isIntegerTy(1))
new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty());
new_elt = bit_cast(new_elt, ty);
curr = builder_->CreateInsertElement(curr, new_elt, jj);
}
args.push_back(bit_cast(curr, val_arg_ty));
}
if (has_l2_evict_policy)
args.push_back(policies_.at(x->get_eviction_policy()));
call(_asm, args);
}
}
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
visit_store_inst(x);
}
void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
visit_store_inst(x);
}
// --
void generator::visit_extract_value_inst(ir::extract_value_inst *x) {
auto idxs = idxs_.at(x);
ir::value* agg = x->get_operand(0);
unsigned insert_idx = x->get_idx();
for(size_t i = 0; i < idxs.size(); i++){
auto idx = idxs[i];
vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {insert_idx});
}
}
void generator::visit_insert_value_inst(ir::insert_value_inst *x){
auto idxs = idxs_.at(x);
ir::value* agg = x->get_operand(0);
ir::value* val = x->get_operand(1);
unsigned insert_idx = x->get_idx();
for(size_t i = 0; i < idxs.size(); i++){
auto idx = idxs[i];
vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx});
}
}
// --
/**
* \brief Code Generation for `cat`
*/
void generator::visit_cat_inst(ir::cat_inst* x) {
auto idxs = idxs_.at(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
int i = 0;
for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){
vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
}
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
}
}
/**
* \brief Code Generation for `reshape`
*/
void generator::visit_reshape_inst(ir::reshape_inst* x) {
  auto idxs = idxs_.at(x);
  ir::value* op = x->get_operand(0);
  for(size_t i = 0; i < idxs.size(); i++)
    vals_[x][idxs[i]] = vals_[op][idxs_[op][i]];
}
/**
* \brief Code Generation for `splat`
*/
void generator::visit_splat_inst(ir::splat_inst* x) {
for(auto idx: idxs_.at(x))
vals_[x][idx] = vals_[x->get_operand(0)][{}];
}
/**
* \brief Code Generation for `broadcast`
*/
void generator::visit_broadcast_inst(ir::broadcast_inst* x) {
ir::value* op = x->get_operand(0);
const auto& shape = op->get_type()->get_block_shapes();
for(auto out_idx: idxs_.at(x)){
indices_t in_idx = out_idx;
for(size_t k = 0; k < in_idx.size(); k++)
in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k];
vals_[x][out_idx] = vals_[op][in_idx];
}
// for(size_t i = 0; i < idxs_.at(x).size(); i++)
// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]];
}
/**
* \brief Code Generation for `downcast`
*/
void generator::visit_downcast_inst(ir::downcast_inst* x) {
vals_[x][{}] = vals_[x->get_operand(0)][{i32(0)}];
}
/**
* \brief Code Generation for `get_program_id`
*/
void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) {
Module *module = builder_->GetInsertBlock()->getModule();
Value *ret = tgt_->get_block_id(module, *builder_, pid->get_axis());
vals_[pid][{}] = ret;
}
/**
* \brief Code Generation for `get_num_programs`
*/
void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) {
Module *module = builder_->GetInsertBlock()->getModule();
Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis());
vals_[np][{}] = ret;
}
/**
* \brief Code Generation for `exp`
*/
void generator::visit_exp_inst(ir::exp_inst* x){
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
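  // exp(x) = 2^(x * log2(e)); lowered to the approximate ex2.approx.f32 PTX instruction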
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
// Value *ex2arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
}
}
/**
* \brief Code Generation for `cos`
*/
void generator::visit_cos_inst(ir::cos_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
}
/**
* \brief Code Generation for `umulhi`
*/
void generator::visit_umulhi_inst(ir::umulhi_inst* x){
std::vector<llvm::Type*> tys = {i32_ty, i32_ty};
FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false);
InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false);
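  // mul.hi.u32 returns the upper 32 bits of the full 64-bit unsigned product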
for(auto idx: idxs_.at(x)){
Value* lhs = vals_[x->get_operand(0)][idx];
Value* rhs = vals_[x->get_operand(1)][idx];
vals_[x][idx] = call(umulhi, std::vector<llvm::Value*>{lhs, rhs});
}
}
/**
* \brief Code Generation for `sin`
*/
void generator::visit_sin_inst(ir::sin_inst* x){
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
}
/**
* \brief Code Generation for `log`
*/
void generator::visit_log_inst(ir::log_inst* x){
Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
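  // ln(x) = log2(x) * ln(2): scale the lg2.approx.f32 result by ln(2) = 0.6931471805599453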
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
for(auto idx: idxs_.at(x)){
Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
vals_[x][idx] = fmul(lg2arg, rcplog2e);
}
}
/**
* \brief Code Generation for `atomic_cas`
*/
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *pred = icmp_eq(tid, i32(0));
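  // only thread 0 performs the global CAS; the old value is then broadcast
  // to all threads of the CTA through shared memory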
// BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
// BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
add_barrier();
tgt_->add_memfence(module, *builder_);
Value *atom_ptr;
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3));
// cond_br(pred, tid_0_bb, tid_0_done_bb);
// builder_->SetInsertPoint(tid_0_bb);
Value *cas_ptr = vals_[cas->get_operand(0)][{}];
Value *cas_cmp = vals_[cas->get_operand(1)][{}];
Value *cas_val = vals_[cas->get_operand(2)][{}];
std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;";
FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false);
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true);
add_barrier();
Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val});
add_barrier();
std::string asm2_str = "@$0 st.shared.b32 [$1], $2;";
FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false);
InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true);
add_barrier();
call(iasm2, {pred, atom_ptr, old});
tgt_->add_memfence(module, *builder_);
add_barrier();
vals_[cas][{}] = load(atom_ptr);
add_barrier();
}
/**
* \brief Code Generation for `atomic_rmw`
*/
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
ir::value* ptr = atom->get_operand(0);
ir::value* val = atom->get_operand(1);
ir::value* msk = atom->get_operand(2);
// vector size
int vec = 1;
Value *mask = builder_->getInt1(true);
if(atom->get_type()->is_block_ty()){
auto shape = atom->get_type()->get_block_shapes();
int ld = ords_.at(ptr)[0];
unsigned alignment = alignment_->get(ptr, ld);
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
// mask out inactive threads
analysis::data_layout* layout = layouts_->get(val);
auto curr_axes = a_axes_->get(val);
auto layt_axes = layout->get_axes();
for(unsigned k = 0; k < layt_axes.size(); k++){
unsigned ax = layt_axes.at(k);
distributed_axis dax = axes_.at(ax);
      // if the axis belongs to the original layout but not to the current one,
      // the thread id along it must be 0
if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end())
mask = and_(mask, icmp_eq(dax.thread_id, i32(0)));
}
    // the last axis may spill over
Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0);
int per_thread = 1;
for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; }
int numel = 1;
for(int s: layout->get_shape()) { numel *= s; }
mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel)));
}
for(int i = 0; i < idxs_.at(val).size(); i += vec){
auto idx = idxs_[val][i];
Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
for(int ii = 0; ii < vec; ii++)
rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
Value *rmw_ptr = vals_[ptr][idx];
Value *rmw_msk = vals_[msk][idx];
rmw_msk = and_(rmw_msk, mask);
if(vec == 1)
rmw_val = extract_elt(rmw_val, i32(0));
Type* ty = rmw_val->getType();
size_t nbits = ty->getScalarSizeInBits();
// extract pointer offset
std::string offset = "";
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
if(gep->getNumIndices() == 1)
if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
rmw_ptr = gep->getPointerOperand();
}
rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
// asm argument type
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
// asm function type
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
// asm string
std::string s_nbits = std::to_string(nbits);
std::string name;
std::string s_ty;
using tt = ir::atomic_rmw_op_t;
switch(atom->get_op()){
      case tt::Or: name = "or"; s_ty = "b"; break;
      case tt::And: name = "and"; s_ty = "b"; break;
      case tt::Xor: name = "xor"; s_ty = "b"; break;
      case tt::Add: name = "add"; s_ty = "s"; break;
      case tt::Min: name = "min"; s_ty = "s"; break;
      case tt::Max: name = "max"; s_ty = "s"; break;
      case tt::UMin: name = "min"; s_ty = "u"; break;
      case tt::UMax: name = "max"; s_ty = "u"; break;
      case tt::FAdd: name = "add"; s_ty = "f"; break;
      case tt::Xchg: name = "exch"; s_ty = "b"; break;
}
std::string s_vec = vec == 2 ? "x2" : "";
std::string mod = nbits == 16 ? ".noftz" : "";
std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h");
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
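    // e.g., a scalar f32 add generates: "@$1 atom.global.gpu.add.f32 $0, [$2], $3;"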
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
// call asm
if(atom->get_type()->is_block_ty())
vals_[atom][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
else{
Module *mod = builder_->GetInsertBlock()->getModule();
tgt_->add_memfence(mod, *builder_);
add_barrier();
Value *tid = tgt_->get_local_id(mod, *builder_, 0);
rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0)));
Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
Value *atom_ptr;
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), "");
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
store(old, atom_ptr);
add_barrier();
vals_[atom][idx] = load(atom_ptr);
add_barrier();
}
}
}
/**
* \brief Code Generation for `mma.884` (V100)
*/
//TODO: clean-up
void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
// shapes
auto shape_c = C->get_type()->get_block_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
// order
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
  bool is_a_trans = C->is_trans_a();
  // is_a_trans = false;
  if(is_a_trans){
std::swap(ord_a[0], ord_a[1]);
std::swap(shape_a[0], shape_a[1]);
std::swap(offset_a_m_, offset_a_k_);
}
// std::cout << "visiting" << std::endl;
// if(C->is_trans_b()){
// std::swap(ord_b[0], ord_b[1]);
// std::swap(shape_b[0], shape_b[1]);
// }
// layouts
analysis::mma_layout* layout_c = layouts_->get(C)->to_mma();
analysis::shared_layout* layout_a = layouts_->get(A)->to_shared();
analysis::shared_layout* layout_b = layouts_->get(B)->to_shared();
// vectorization
int vec_a = swizzle_->get_vec(layout_a);
int vec_b = swizzle_->get_vec(layout_b);
// strides
bool is_a_row = ord_a[0] != 0;
bool is_b_row = ord_b[0] != 0;
int stride_am = is_a_row ? shape_a[1] : 1;
int stride_ak = is_a_row ? 1 : shape_a[0];
int stride_a0 = is_a_row ? stride_ak : stride_am;
int stride_a1 = is_a_row ? stride_am : stride_ak;
int stride_bn = is_b_row ? 1 : shape_b[0];
int stride_bk = is_b_row ? shape_b[1] : 1;
int stride_b0 = is_b_row ? stride_bn : stride_bk;
int stride_b1 = is_b_row ? stride_bk : stride_bn;
int stride_rep_m = layout_c->wpt(0) * layout_c->fpw(0) * 8;
int stride_rep_n = layout_c->wpt(1) * layout_c->fpw(1) * 8;
int stride_rep_k = 1;
// swizzling
int per_phase_a = swizzle_->get_per_phase(layout_a);
int max_phase_a = swizzle_->get_max_phase(layout_a);
int step_a0 = is_a_row ? stride_rep_k : stride_rep_m;
int num_ptr_a = std::max(2 * per_phase_a * max_phase_a / step_a0, 1);
int per_phase_b = swizzle_->get_per_phase(layout_b);
int max_phase_b = swizzle_->get_max_phase(layout_b);
int step_b0 = is_b_row ? stride_rep_n : stride_rep_k;
int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1);
// max_phase_a = 4;
// vec_a = 8;
// std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl;
// std::cout << vec_a << " " << vec_b << std::endl;
/* --------------------------------- */
/* --- pre-compute pointer lanes --- */
/* --------------------------------- */
BasicBlock* curr_bb = builder_->GetInsertBlock();
BasicBlock* entry = &curr_bb->getParent()->getEntryBlock();
if(entry != curr_bb)
builder_->SetInsertPoint(entry->getTerminator());
Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c];
Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c];
Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a));
std::vector<Value*> off_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++){
Value* off_a0i = add(off_a0, i32(i*(is_a_row?4:stride_rep_m)));
off_a0i = exact_udiv(off_a0i, i32(vec_a));
off_a0i = xor_(off_a0i, phase_a);
off_a0i = mul(off_a0i, i32(vec_a));
off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
}
Value* off_b0 = is_b_row ? offset_b_n_[layout_c] : offset_b_k_[layout_c];
Value* off_b1 = is_b_row ? offset_b_k_[layout_c] : offset_b_n_[layout_c];
Value* phase_b = urem(udiv(off_b1, i32(per_phase_b)), i32(max_phase_b));
std::vector<Value*> off_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++){
Value* off_b0i = add(off_b0, i32(i*(is_b_row?stride_rep_n:4)));
off_b0i = udiv(off_b0i, i32(vec_b));
off_b0i = xor_(off_b0i, phase_b);
off_b0i = mul(off_b0i, i32(vec_b));
off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
}
builder_->SetInsertPoint(curr_bb);
/* --------------------------------- */
/* --- MMA intrinsic --- */
/* --------------------------------- */
Type *f16x2_ty = vec_ty(f16_ty, 2);
Type *ret_ty = StructType::get(*ctx_, {f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty});
std::vector<Type*> arg_ty = {f16x2_ty, f16x2_ty, f16x2_ty, f16x2_ty,
f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty};
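  // in the mma.sync inline asm below, the "0,1,...,7" constraints tie the eight
  // accumulator inputs to the outputs, so each mma.sync accumulates in place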
InlineAsm *mma = InlineAsm::get(FunctionType::get(ret_ty, arg_ty, false),
" mma.sync.aligned.m8n8k4."
+ std::string(is_a_row ? "row" : "col")
+ "."
+ std::string(is_b_row ? "row" : "col")
+ ".f32.f16.f16.f32 "
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
"{$8, $9}, "
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
std::vector<Value*> ptr_a(num_ptr_a);
std::vector<Value*> ptr_b(num_ptr_b);
std::map<std::pair<int, int>, std::pair<Value*, Value*>> has, hbs;
for(int i = 0; i < num_ptr_a; i++)
ptr_a[i] = gep(shmems_[A], off_a[i]);
for(int i = 0; i < num_ptr_b; i++)
ptr_b[i] = gep(shmems_[B], off_b[i]);
// initialize accumulators
std::vector<Value*> acc;
for(indices_t idx: idxs_.at(C))
acc.push_back(vals_[D][idx]);
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0);
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1);
// create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
};
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
// Cache lds value. If values are prefetched, create phi node
// @param inc: incoming block (0 = header, 1 = loop)
auto register_lds =
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
if (K == 0 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
Value* ptra;
if(K==0 && is_prefetch){
if(inc == 0)
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
else
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
else
ptra = ptr_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
register_lds(has, m, K, inc, ha00, ha01, is_prefetch);
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch);
else
register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch);
}
};
auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
int offidx = (is_b_row? n : K/4) % num_ptr_b;
Value* ptrb;
if(K==0 && is_prefetch){
if(inc == 0)
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
else
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
} else
ptrb = ptr_b[offidx];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch);
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch);
else
register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch);
}
};
// update accumulators
if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
has[{m, 0}].first = phi(f16x2_ty, 2);
has[{m, 0}].second = phi(f16x2_ty, 2);
if (!is_a_row && vec_a>4) {
has[{m+1, 0}].first = phi(f16x2_ty, 2);
has[{m+1, 0}].second = phi(f16x2_ty, 2);
}
}
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) {
hbs[{n, 0}].first = phi(f16x2_ty, 2);
hbs[{n, 0}].second = phi(f16x2_ty, 2);
if (is_b_row && vec_b>4) {
hbs[{n+1, 0}].first = phi(f16x2_ty, 2);
hbs[{n+1, 0}].second = phi(f16x2_ty, 2);
}
}
// insert prefetched lds at the end of loop header
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
load_a(m, 0, 0, true);
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
load_b(n, 0, 0, true);
// update accumulators
builder_->SetInsertPoint(curr_bb);
for (unsigned K = 0; K < NK; K += 4) {
int NEXTK = (K + 4) % NK;
// prefetch A
for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
load_a(m, NEXTK, 1, true);
// prefetch B
for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
load_b(n, NEXTK, 1, true);
// tensor core ops
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++){
call_mma(m, n, K);
}
}
} else { // not prefetched
for(unsigned K = 0; K < NK; K += 4)
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++) {
if(has.find({m, K}) == has.end())
load_a(m, K, /*inc*/0, /*is_prefetch*/false);
if(hbs.find({n, K}) == hbs.end())
load_b(n, K, /*inc*/0, /*is_prefetch*/false);
call_mma(m, n, K);
}
}
// write back accumulators
for(size_t i = 0; i < idxs_.at(C).size(); i++)
vals_[C][idxs_[C][i]] = acc[i];
}
namespace {
class mma16816_smem_loader {
public:
mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
std::vector<unsigned> tile_shape,
std::vector<int> instr_shape, std::vector<int> mat_shape,
int per_phase, int max_phase, int dtsize, Builder *builder,
adder add, multiplier mul, geper gep)
: wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
instr_shape_(instr_shape), mat_shape_(mat_shape),
per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
add(add), mul(mul), gep(gep) {
// compute compile-time constant variables & types
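    // naming convention: "c_" refers to the contiguous (fastest-changing, order[0]) axis,
    // "s_" to the strided (order[1]) axis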
c_mat_shape_ = mat_shape[order[0]];
s_mat_shape_ = mat_shape[order[1]];
c_stride_ = tile_shape[order[1]];
s_stride_ = tile_shape[order[0]];
// rule: k must be the fast-changing axis
need_trans_ = k_order_ != order_[0];
can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
    // we need more pointers along the fast-changing axis
if (can_use_ldmatrix_)
num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
    else // warning: this only works for tf32 that needs a transpose
num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
num_ptr_ = std::max<int>(num_ptr_, 2);
// special rule for i8/u8, 4 ptrs for each matrix
if (!can_use_ldmatrix_ && dtsize_ == 1)
num_ptr_ *= 4;
// load_v4 stride (in num of mats)
int load_stride_in_mat[2];
load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2
load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]);
p_load_stride_in_mat_ = load_stride_in_mat[order[0]];
// stride in mat, used by load_v4
s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]);
}
std::vector<Value*> compute_offs(Value *warp_off, Value *lane) {
// TODO: this needs to be moved to constructor (and extracted to arr_order)
mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_;
warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1];
    // starting matrix logical offset (rename it as base_mat_off?)
Value *mat_off[2] = {nullptr, nullptr};
if (can_use_ldmatrix_) {
// c: lane idx inside a group (a group is a collection of 8 contiguous threads)
// s: group idx (0,1,2,3) inside a warp
Value *c = urem(lane, i32(8));
Value *s = udiv(lane, i32(8));
// We can decompose s => s_0, s_1...
Value *s0 = urem(s, i32(2));
Value *s1 = udiv(s, i32(2));
// We use different orders for a & b for better performance.
Value *k_mat_arr = (k_order_ == 1) ? s1 : s0;
Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
mat_off[k_order_] = k_mat_arr;
// physical offset (before swizzling)
Value *c_mat_off = mat_off[order_[0]];
Value *s_mat_off = mat_off[order_[1]];
// offset inside a matrix
Value *s_off_in_mat = c;
std::vector<Value*> offs(num_ptr_);
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
// pre-compute strided offset
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
for (int i=0; i < num_ptr_; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_));
c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle
offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_)));
}
return offs;
} else if (dtsize_ == 4 && need_trans_) {
// load tf32 matrices with lds32
Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
Value *s_off_in_mat = urem(lane, i32(4)); //
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
std::vector<Value*> offs(num_ptr_);
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
if (k_mat_arr_int > 0) // we don't need pointers for k
continue;
Value *k_mat_arr = i32(k_mat_arr_int);
Value *nk_mat_arr = i32(nk_mat_arr_int);
// physical offset (before swizzling)
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
Value *s_mat_off = k_mat_arr; // always 0?
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
        // FIXME: (k_order_ == 1?) is really a dirty hack
for (int i = 0; i < num_ptr_/2; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
c_mat_off_i = xor_(c_mat_off_i, phase);
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
// TODO: move this out of the loop
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_)));
}
}
return offs;
// throw std::runtime_error("not implemented");
} else if (dtsize_ == 1 && need_trans_) {
// load i8/u8 matrices with lds8
Value *c_off_in_mat = udiv(lane, i32(4)); //
Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols
// Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
std::vector<Value*> offs(num_ptr_);
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
if (k_mat_arr_int > 0) // we don't need pointers for k
continue;
Value *k_mat_arr = i32(k_mat_arr_int);
Value *nk_mat_arr = i32(nk_mat_arr_int);
// physical offset (before swizzling)
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
Value *s_mat_off = k_mat_arr; // always 0?
for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
for (int elem_off = 0; elem_off < 4; ++elem_off) {
int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;
Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off));
// disable swizzling ...
// Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
// c_mat_off_i = xor_(c_mat_off_i, phase);
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_)));
// To prevent out-of-bound access when the tile is too small
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_)));
}
}
}
return offs;
} else
throw std::runtime_error("invalid smem load config");
}
std::tuple<Value*, Value*, Value*, Value*>
load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
int mat_idx[2] = {mat0, mat1};
int k = mat_idx[k_order_];
int ptr_idx = -1;
if (can_use_ldmatrix_)
ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
else if (dtsize_ == 4 && need_trans_) // tf32 & trans
ptr_idx = mat_idx[order_[0]];
else // i8 & trans
ptr_idx = mat_idx[order_[0]] * 4;
auto get_ptr = [&](int idx) -> Value* {
Value *ptr = nullptr;
if (k == 0 && is_prefetch) {
if (inc == 0)
ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty);
else
ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty);
} else
ptr = ptrs.at(idx);
return ptr;
};
Value *ptr = get_ptr(ptr_idx);
Value *res_v4 = nullptr;
if (can_use_ldmatrix_) {
std::string trans = need_trans_ ? ".trans" : "";
      // the offset (in bytes) on the strided axis is a constant
int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
"ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
"{$0, $1, $2, $3}, "
"[$4 + " + std::to_string(s_offset) + "];",
"=r,=r,=r,=r,r", true);
assert(ptr);
res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
if (k == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
return {extract_val(res_v4, std::vector<unsigned>{0}),
extract_val(res_v4, std::vector<unsigned>{1}),
extract_val(res_v4, std::vector<unsigned>{2}),
extract_val(res_v4, std::vector<unsigned>{3})};
} else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices
Value *ptr2 = get_ptr(ptr_idx+1);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
Value *elem0, *elem1, *elem2, *elem3;
if (k_order_ == 1) {
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem1 = load(gep(ptr2, i32(s_offset_elem)));
elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
} else { // for b (k first)
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem2 = load(gep(ptr2, i32(s_offset_elem)));
elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
}
if (k == 0 && inc == 1 && is_prefetch) {
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
}
return {elem0, elem1, elem2, elem3};
} else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices
Value *ptr00 = get_ptr(ptr_idx);
Value *ptr01 = get_ptr(ptr_idx+1);
Value *ptr02 = get_ptr(ptr_idx+2);
Value *ptr03 = get_ptr(ptr_idx+3);
Value *ptr10 = get_ptr(ptr_idx+4);
Value *ptr11 = get_ptr(ptr_idx+5);
Value *ptr12 = get_ptr(ptr_idx+6);
Value *ptr13 = get_ptr(ptr_idx+7);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
Value *i8v4_elems[4];
Value *i32_elems[4];
for (int i=0; i<4; ++i)
i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));
Value *i8_elems[4*4];
if (k_order_ == 1) { //
i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
for (int m=0; m<4; ++m) {
for (int e=0; e<4; ++e)
i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
}
} else { // for b (k first)
i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
for (int m=0; m<4; ++m) {
for (int e=0; e<4; ++e)
i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
}
}
if (k == 0 && inc == 1 && is_prefetch) {
for (int m = 0; m < 4; ++m)
for (int e = 0; e < 4; ++e)
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]);
}
return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]};
} else
throw std::runtime_error("invalid smem load");
}
int get_num_ptr() const { return num_ptr_; }
private:
int wpt_;
std::vector<int> order_;
int k_order_;
std::vector<unsigned> tile_shape_;
std::vector<int> instr_shape_;
std::vector<int> mat_shape_;
int per_phase_, max_phase_;
int dtsize_;
// generated
int c_mat_shape_, s_mat_shape_;
int c_stride_, s_stride_;
// p_: on the pointer axis
int p_load_stride_in_mat_;
int s_mat_stride_;
// stride when moving to next not-k mat
int warp_off_stride_;
int mat_arr_stride_; // matrix arrangement (inside a load) stride
bool need_trans_, can_use_ldmatrix_;
int num_ptr_;
Builder *builder_;
adder add;
multiplier mul;
geper gep;
};
}
/**
* \brief Code Generation for `mma.16816` (A100)
*/
//TODO: clean-up
void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
const std::vector<unsigned>& shapes = C->get_type()->get_block_shapes();
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
for(indices_t idx: idxs_.at(C)){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
fcs[key].push_back(vals_[D][idx]);
};
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
if(C->is_trans_a()){
std::swap(ord_a[0], ord_a[1]);
std::swap(shape_a[0], shape_a[1]);
}
auto ord_b = layouts_->get(B)->get_order();
if(C->is_trans_b()){
std::swap(ord_b[0], ord_b[1]);
std::swap(shape_b[0], shape_b[1]);
}
NK = shape_a[1];
analysis::mma_layout* layout = layouts_->get(C)->to_mma();
std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
const int mma_instr_m = mma_instr_shape[0];
const int mma_instr_n = mma_instr_shape[1];
const int mma_instr_k = mma_instr_shape[2];
std::vector<int> mat_shape = layout->get_mma_mat_shape();
const int mat_shape_m = mat_shape[0];
const int mat_shape_n = mat_shape[1];
const int mat_shape_k = mat_shape[2];
const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
// floating point types
Type *fp32_ty = f32_ty;
Type *fp16x2_ty = vec_ty(f16_ty, 2);
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
// integer types
Type *i8x4_ty = vec_ty(i8_ty, 4);
Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty});
Type *i32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty});
FunctionType *ldmatrix_ty = nullptr;
FunctionType *mma_ty = nullptr;
Type *phi_ty = nullptr;
Type *smem_ptr_ty = nullptr;
ir::type *A_ir_ty = A->get_type()->get_scalar_ty();
ir::type *B_ir_ty = B->get_type()->get_scalar_ty();
if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
} else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(bf16_ty, 3);
ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = bf16x2_ty;
} else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(fp32_ty, 3);
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp32_ty;
} else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) {
// FIXME: We should use i8 here (but nvptx will generate extra casts when using i8)
mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
smem_ptr_ty = ptr_ty(i8_ty, 3);
ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = i32_ty;
// mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
// smem_ptr_ty = ptr_ty(i8_ty, 3);
// ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
// phi_ty = i8x4_ty;
} else
throw std::runtime_error("mma16816 data type not supported");
// left-hand-side values
std::map<std::pair<unsigned, unsigned>, Value*> ha;
std::map<std::pair<unsigned, unsigned>, Value*> hb;
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
  // if true, this will move pointer declarations to the entry basic block
  // non-prefetched cases tend to be more constrained in resource usage,
  // so we don't pre-compute ptrs there, in order to save registers
bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB);
if(licm_ptrs)
builder_->SetInsertPoint(FirstBB->getTerminator());
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value *lane = urem(thread, i32(32));
Value *warp = udiv(thread, i32(32));
Value *warp_mn = udiv(warp, i32(layout->wpt(0)));
Value *warp_m = urem(warp, i32(layout->wpt(0)));
Value *warp_n = urem(warp_mn, i32(layout->wpt(1)));
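  // decompose the linear warp id into (warp_m, warp_n) coordinates within the CTA tile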
std::vector<Value *>& fc = fcs.begin()->second;
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds2 =
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
} else
vals[{mn, k}] = val;
};
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
// v (s0_0(0), s1_0(2), | *num_rep_k
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
// -----------
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
std::function<void(int,int,int,bool)> load_a;
analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared();
bool is_a_shared = layout_a != nullptr;
if(is_a_shared) {
const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
int num_ptr_a = a_loader.get_num_ptr();
// pointers
std::vector<Value*> ptrs_a(num_ptr_a);
if(licm_ptrs)
builder_->SetInsertPoint(CurrBB);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
if(licm_ptrs)
builder_->SetInsertPoint(FirstBB->getTerminator());
// loading function
load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};
}
else {
load_a = [&](int m, int k, int inc, bool is_prefetch) {
distributed_axis ax_n = axes_.at(a_axes_->get(A, 1));
int ldm = ax_n.values.size();
if(ldm != num_rep_k*4)
throw std::runtime_error("Internal compiler error when trying to fuse matmuls!");
// std::cout << m << " " << k << std::endl;
// std::cout << idxs_[A].size() << std::endl;
// std::cout << (m+1)*ldm + k*2 + 3 << std::endl;
// int ldm = num_rep_k*4;
Value* ha0 = UndefValue::get(phi_ty); // e.g., fp16x2
Value* ha1 = UndefValue::get(phi_ty);
Value* ha2 = UndefValue::get(phi_ty);
Value* ha3 = UndefValue::get(phi_ty);
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0));
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1));
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1));
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0));
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1));
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0));
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1));
ha[{m, k}] = ha0;
ha[{m+1, k}] = ha1;
ha[{m, k+1}] = ha2;
ha[{m+1, k+1}] = ha3;
};
}
// | -> n (col-major)
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
// -----------
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared();
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);
std::vector<int> mma_instr_b{mma_instr_k, mma_instr_n};
std::vector<int> mat_shape_b{mat_shape_k, mat_shape_n};
int k_order_b = 0;
// if(C->is_trans_b()){
// std::swap(mma_instr_b[0], mma_instr_b[1]);
// std::swap(mat_shape_b[0], mat_shape_b[1]);
// k_order_b = k_order_b ^ 1;
// std::swap(ord_b[0], ord_b[1]);
// std::swap(shape_b[0], shape_b[1]);
// }
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b,
mma_instr_b, mat_shape_b,
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
if(licm_ptrs)
builder_->SetInsertPoint(CurrBB);
// pointers
int num_ptr_b = b_loader.get_num_ptr();
std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
// loading function
std::function<void(int,int,int,bool)> load_b;
load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
" {$0, $1, $2, $3},"
" {$4, $5, $6, $7},"
" {$8, $9},"
" {$10, $11, $12, $13};",
"=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
unsigned cols_per_thread = num_rep_n * 2;
std::vector<size_t> idx = {
(m + 0)*cols_per_thread + (n*2 + 0),
(m + 0)*cols_per_thread + (n*2 + 1),
(m + 1)*cols_per_thread + (n*2 + 0),
(m + 1)*cols_per_thread + (n*2 + 1)
};
Value *nc = call(mma_ty, mma_fn,
{ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
hb[{n, k}], hb[{n, k+1}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
};
if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
for(unsigned m = 0; m < num_rep_m; m++){
ha[{2*m, 0}] = phi(phi_ty, 2);
ha[{2*m+1, 0}] = phi(phi_ty, 2);
ha[{2*m, 1}] = phi(phi_ty, 2);
ha[{2*m+1, 1}] = phi(phi_ty, 2);
}
for(unsigned n = 0; n < num_rep_n; n+=2){
hb[{n, 0}] = phi(phi_ty, 2);
hb[{n+1, 0}] = phi(phi_ty, 2);
hb[{n, 1}] = phi(phi_ty, 2);
hb[{n+1, 1}] = phi(phi_ty, 2);
}
// insert prefetched lds at the end of loop header
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 0, 0, true);
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 0, 0, true);
// update accumulators
builder_->SetInsertPoint(CurrBB);
for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2
int next_k = (k + 1) % num_rep_k;
// prefetch A
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*next_k, 1, true);
// prefetch B
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*next_k, 1, true);
// tensor core ops
for(unsigned m = 0; m < num_rep_m; m++)
for(unsigned n = 0; n < num_rep_n; n++){
call_mma(2*m, n, 2*k);
}
}
}
else{
for (unsigned k = 0; k < num_rep_k; k++) {
for (unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*k, 0, /*is_prefetch*/false);
for (unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*k, 0, /*is_prefetch*/false);
for (unsigned m = 0; m < num_rep_m; m++)
for (unsigned n = 0; n < num_rep_n; n++)
call_mma(2*m, n, 2*k);
}
}
// write back
unsigned i = 0;
for(indices_t idx: idxs_.at(C)){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
if(i >= fcs.at(key).size())
i = 0;
vals_[C][idx] = fcs.at(key)[i++];
};
}
/**
* \brief Code Generation for FMA-based `dot` (FP32, FP64, Default)
*/
void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) {
auto shape_c = C->get_type()->get_block_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline();
analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1;
std::string a_trans = is_a_row ? "" : ".trans";
std::string b_trans = is_b_row ? ".trans" : "";
int stride_a_m = is_a_row ? shape_a[1] : 1;
int stride_a_k = is_a_row ? 1 : shape_a[0];
int stride_b_n = is_b_row ? 1 : shape_b[0];
int stride_b_k = is_b_row ? shape_b[1] : 1;
int stride_a0 = is_a_row ? stride_a_k : stride_a_m;
int stride_a1 = is_a_row ? stride_a_m : stride_a_k;
int stride_b0 = is_b_row ? stride_b_n : stride_b_k;
int stride_b1 = is_b_row ? stride_b_k : stride_b_n;
int lda = is_a_row ? stride_a_m : stride_a_k;
int ldb = is_b_row ? stride_b_k : stride_b_n;
int per_phase_a = swizzle_->get_per_phase(layout_a);
int max_phase_a = swizzle_->get_max_phase(layout_a);
int per_phase_b = swizzle_->get_per_phase(layout_b);
int max_phase_b = swizzle_->get_max_phase(layout_b);
int num_ptr_a = 8;
int num_ptr_b = 8;
int vec_a = 2;
int vec_b = 4;
distributed_axis ax_m = axes_.at(a_axes_->get(C, 0));
distributed_axis ax_n = axes_.at(a_axes_->get(C, 1));
// Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value* off_a0 = is_a_row ? i32(0) : mul(ax_m.thread_id, i32(ax_m.contiguous));
Value* off_a1 = is_a_row ? mul(ax_m.thread_id, i32(ax_m.contiguous)): i32(0);
std::vector<Value*> off_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++){
// Value* off_a0i = add(off_a0, i32(is_a_row ? vec_a : layout_c->mts(0)*vec_a));
// off_a0i = exact_udiv(off_a0i, i32(vec_a));
// off_a0i = xor_(off_a0i, phase_a);
// off_a0i = mul(off_a0i, i32(vec_a));
off_a[i] = add(mul(off_a0, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
}
Value* off_b0 = is_b_row ? mul(ax_n.thread_id, i32(ax_n.contiguous)): i32(0);
Value* off_b1 = is_b_row ? i32(0) : mul(ax_n.thread_id, i32(ax_n.contiguous));
std::vector<Value*> off_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++){
// Value* off_b0i = add(off_b0, i32(is_b_row ? layout_c->mts(1)*vec_b : vec_b));
// off_b0i = exact_udiv(off_b0i, i32(vec_b));
// off_b0i = xor_(off_b0i, phase_b);
// off_b0i = mul(off_b0i, i32(vec_b));
off_b[i] = add(mul(off_b0, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
}
std::vector<Value*> ptrs_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = gep(shmems_[A], off_a[i]);
std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = gep(shmems_[B], off_b[i]);
std::map<indices_t, Value*> ret = vals_[D];
std::map<std::pair<int, int>, Value*> has, hbs;
auto ord = layout_c->get_order();
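  // accumulate C[m, n] += A[m, k] * B[k, n] with one scalar llvm.fmuladd per element,
  // caching the A/B values loaded from shared memory in `has`/`hbs`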
for(unsigned k = 0; k < NK; k++){
int z = 0;
for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1]))
for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0]))
for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++)
for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){
unsigned m = (ord[0] == 1) ? i : j;
unsigned n = (ord[0] == 1) ? j : i;
unsigned mm = (ord[0] == 1) ? ii : jj;
unsigned nn = (ord[0] == 1) ? jj : ii;
if(has.find({m + mm, k}) == has.end()){
Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k));
Value* va = load(pa);
has[{m + mm, k}] = va;
}
if(hbs.find({n + nn, k}) == hbs.end()){
Value* pb = gep(ptrs_b[0], i32((n + nn)*stride_b_n + k*stride_b_k));
Value* vb = load(pb);
hbs[{n + nn, k}] = vb;
}
ret[idxs_[C].at(z)] = call(f_mul_add, {has[{m+mm,k}], hbs[{n+nn, k}], ret[idxs_[C].at(z)]});
z++;
}
}
for(indices_t idx: idxs_.at(C)){
vals_[C][idx] = ret[idx];
}
}
/**
* \brief Code Generation for `dot`
* Dispatches to appropriate specialized function
*/
void generator::visit_dot_inst(ir::dot_inst* dot) {
Function *fn = builder_->GetInsertBlock()->getParent();
Module *module = fn->getParent();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
Type *c_ty = cvt(D->get_type()->get_scalar_ty());
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
auto A_shapes = A->get_type()->get_block_shapes();
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
bool is_outer = NK == 1;
bool is_mma = layouts_->get(dot)->to_mma();
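  // dispatch: mma.884 for sm < 80, mma.16816 for sm >= 80; otherwise the FMA-based path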
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
A->get_type()->get_scalar_ty()->is_fp32_ty())
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
throw std::runtime_error("dot has invalid operand type");
}
void generator::visit_trans_inst(ir::trans_inst* trans) {
throw std::runtime_error("not supported");
}
/**
* \brief Code Generation for `sqrt`
*/
void generator::visit_sqrt_inst(ir::sqrt_inst* x) {
for(indices_t idx: idxs_.at(x)){
Value *val = vals_[x->get_operand(0)][idx];
Value *ret = intrinsic(Intrinsic::sqrt, {val->getType()}, {val});
vals_[x][idx] = ret;
}
}
Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx){
// strides
std::vector<Value*> strides(shapes.size(), builder_->getInt32(0));
strides[order[0]] = builder_->getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder_->CreateMul(strides[order[i-1]], builder_->getInt32(shapes[order[i-1]]));
// result
Value *result = builder_->getInt32(0);
for(size_t i = 0; i < idx.size(); i++)
result = builder_->CreateAdd(result, builder_->CreateMul(idx[i], strides[i]));
return result;
}
inline Value* generator::shfl_sync(Value* acc, int32_t i){
Type* ty = acc->getType();
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
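  // wider (e.g. 64-bit) values are split into two f32 halves, shuffled independently,
  // and then reassembled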
acc = bit_cast(acc, vec_ty(f32_ty, 2));
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
return bit_cast(ret, ty);
}
/**
 * \brief Code Generation for `reduce` (ND case, fast path using warp shuffles)
 */
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
ir::value *arg = x->get_operand(0);
const auto with_index = x->with_index();
unsigned axis = x->get_axis();
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
const auto &shapes = layout->get_shape();
Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
size_t n_bits = sca_ty->getPrimitiveSizeInBits();
std::string n_bits_str = std::to_string(n_bits);
std::string cst = (n_bits == 64) ? "l" : "r";
FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false);
InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true);
FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
Type *index_ty = IntegerType::get(*ctx_, 32);
FunctionType *st_shared_index_ty =
FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false);
InlineAsm *st_shared_index = InlineAsm::get(
st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;", "b,r,r", true);
FunctionType *ld_shared_index_ty =
FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false);
InlineAsm *ld_shared_index = InlineAsm::get(
ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];", "=r,b,r", true);
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value* warp = udiv(thread, i32(32));
Value* lane = urem(thread, i32(32));
unsigned shuffle_width = 0;
unsigned warps_per_inner = 0;
auto arg_vals = vals_.at(arg);
std::vector<indices_t> arg_idxs = idxs_.at(arg);
size_t n_elts = arg_idxs.size();
unsigned col_per_thread = 0;
Value* warp_j = nullptr;
if (analysis::scanline_layout *scanline = layout->to_scanline()) {
std::vector<int> order = layout->get_order();
unsigned mts = scanline->mts(order[0]);
shuffle_width = std::min<int>(mts, 32);
warps_per_inner = std::max<int>(mts / 32, 1);
col_per_thread = shapes[order[0]] / mts;
warp_j = urem(warp, i32(warps_per_inner));
} else if (layout->to_mma()) {
shuffle_width = 4;
warps_per_inner = layout->to_mma()->wpt(1);
col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size();
warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
}
assert(warp_j != nullptr);
// unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
//
Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)),
cvt(x->get_type()->get_scalar_ty()));
Value *index_base =
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
IntegerType::get(*ctx_, 32))
: nullptr;
// preds
Value* is_lane0 = icmp_eq(lane, i32(0));
Value* is_warp0 = icmp_eq(warp, i32(0));
Value* is_thread0 = icmp_eq(thread, i32(0));
Value* lane_j = urem(lane, i32(shuffle_width));
if(warps_per_inner > 1)
add_barrier();
  // compute the partial reduction for each warp and, if several warps share a row, stage it in shared memory
for(size_t i = 0; i < n_elts/col_per_thread; i++){
std::pair<Value*, Value*> acc;
// reduce within thread
for(size_t j = 0; j < col_per_thread; j++){
auto arg_idx = arg_idxs[i*col_per_thread + j];
bool is_first = j == 0;
do_acc(
acc, [&]() -> Value * { return arg_vals[arg_idx]; },
[&]() -> Value * { return arg_idx[axis]; }, is_first);
}
// reduce within warp
for(int k = shuffle_width/2 ; k > 0; k >>= 1) {
do_acc(
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
}
// store partial result to shared memory
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
// single warp on the reduce dimension -- no need to use shmem
if(warps_per_inner==1){
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
}
else{
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
if (with_index) {
call(st_shared_index,
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
}
}
}
if(warps_per_inner==1)
return;
add_barrier();
  // at this point, the partial accumulators are synchronized in shared memory;
  // we just need to reduce `warps_per_inner` values per output element
for(size_t i = 0; i < n_elts/col_per_thread; i++){
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
std::pair<Value*, Value*> acc;
acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true),
gep(index_base, ld_off)})
: nullptr;
for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
do_acc(
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
}
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
}
// add_barrier();
}
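/**
 * \brief Code Generation for `reduce` (generic ND case, staged through shared memory)
 *
 * Each thread reduces its own elements, stores its partial result at its lane
 * position of a shared-memory tile, and a barrier-separated tree reduction
 * over that tile produces the final values.
 */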
void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
ir::value *arg = x->get_operand(0);
unsigned axis = x->get_axis();
auto with_index = x->with_index();
// reduce within thread
// index-><current reduced value, current min/max index (optional)>
std::map<indices_t, std::pair<Value*, Value*>> accs;
for(indices_t idx: idxs_.at(arg)){
indices_t pidx = idx;
pidx[axis] = i32(0);
bool is_first = accs.find(pidx) == accs.end();
do_acc(
accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
[&]() -> Value * { return idx[axis]; }, is_first);
};
// reduce within blocks
auto *data_layout = layouts_->get(layouts_->tmp(x));
auto *data_ptr =
cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
auto *index_ptr =
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
IntegerType::get(*ctx_, 32))
: data_ptr;
auto shape = data_layout->get_shape();
auto order = data_layout->get_order();
Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
for(auto& x: accs) {
// current element being computed
std::pair<Value *, Value *> acc = x.second;
indices_t write_idx = x.first;
write_idx[axis] = lane;
// shared memory write pointer
Value *write_off = shared_off(shape, order, write_idx);
Value *write_ptr = gep(data_ptr, write_off);
Value *index_write_ptr = gep(index_ptr, write_off);
// initialize shared memory
add_barrier();
store(acc.first, write_ptr);
if (with_index) {
store(acc.second, index_write_ptr);
}
// build result
indices_t idx(write_idx.size(), i32(0));
for(size_t i = shape[axis]/2; i > 0; i >>= 1){
idx[axis] = i32(i);
// read pointer
Value *read_msk = icmp_ult(lane, i32(i));
Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
Value *read_ptr = gep(write_ptr, read_off);
Value *index_read_ptr = gep(index_write_ptr, read_off);
add_barrier();
// update accumulator
do_acc(
acc, [&]() -> Value * { return load(read_ptr); },
[&]() -> Value * { return load(index_read_ptr); }, false);
add_barrier();
store(acc.first, write_ptr);
if (with_index) {
store(acc.second, index_write_ptr);
}
}
}
add_barrier();
// write back
for(indices_t idx: idxs_.at(x)){
indices_t read_idx = idx;
read_idx.insert(read_idx.begin() + axis, i32(0));
Value *read_off = shared_off(shape, order, read_idx);
Value *read_ptr =
with_index ? gep(index_ptr, read_off) : gep(data_ptr, read_off);
vals_[x][idx] = load(read_ptr);
};
}
/**
 * \brief Code Generation for `reduce` (entry point: builds the accumulation
 *        function and neutral element, then dispatches to a fast or generic lowering)
 */
void generator::visit_reduce_inst(ir::reduce_inst* x) {
Type *ty = cvt(x->get_type()->get_scalar_ty());
// accumulation function
ir::reduce_inst::op_t op = x->get_op();
auto do_acc_op = [&](Value *x, Value *y) -> Value* {
switch(op){
case ir::reduce_inst::ADD: return add(x, y);
case ir::reduce_inst::SUB: return sub(x, y);
case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
case ir::reduce_inst::FADD: return fadd(x, y);
case ir::reduce_inst::FSUB: return fsub(x, y);
case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
case ir::reduce_inst::FMAX: return max_num(x, y);
case ir::reduce_inst::FMIN: return min_num(x, y);
case ir::reduce_inst::XOR: return xor_(x, y);
default: throw std::runtime_error("unreachable");
}
};
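  // For reductions that carry an index (ARG*), do_acc_op returns an i1
  // predicate meaning "keep the current accumulator", which the selects in
  // do_acc below consume; for plain reductions it returns the combined value.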
auto do_acc = [&](std::pair<Value *, Value *> &acc,
std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn,
bool is_first) -> void {
auto *val = load_value_fn();
if (x->with_index()) {
auto *index = load_index_fn();
if (is_first) {
acc.first = val;
acc.second = index;
} else {
Value *ret = do_acc_op(acc.first, val);
acc.first = select(ret, acc.first, val);
acc.second = select(ret, acc.second, index);
}
} else {
acc.first = is_first ? val : do_acc_op(acc.first, val);
}
};
// neutral element
Value *neutral;
switch(op) {
case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
default: throw std::runtime_error("unreachable");
}
ir::value *arg = x->get_operand(0);
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
bool is_a100_mma = layouts_->is_a100_mma(x);
if (is_coalesced_scanline || is_a100_mma)
visit_reducend_inst_fast(x, do_acc, neutral);
else
visit_reducend_inst(x, do_acc, neutral);
}
/**
* \brief Code Generation for `select`
*/
void generator::visit_select_inst(ir::select_inst* x) {
for(indices_t idx: idxs_.at(x)){
vals_[x][idx] = select(vals_[x->get_operand(0)][idx],
vals_[x->get_operand(1)][idx],
vals_[x->get_operand(2)][idx]);
}
}
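// Converts a tensor between two distributed layouts by round-tripping it
// through shared memory: for each (i, j) repetition, the in-layout values are
// stored to a shared tile (vectorized by in_vec), a barrier is issued, and the
// tile is reloaded at the out-layout indices (vectorized by out_vec). The
// leading shared dimension is padded by max(in_vec, out_vec), which helps keep
// the two access patterns from hitting the same banks.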
void generator::visit_layout_convert(ir::value *out, ir::value *in){
ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
// pointer to temporary shared memory
Type *ty = cvt(out->get_type()->get_scalar_ty());
// Orders
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
Value *base;
int off = alloc_->offset(layouts_->get(layouts_->tmp(out)));
// std::cout << off << std::endl;
base = gep(shmem_, i32(off));
base = bit_cast(base, ptr_ty(ty, 3));
std::vector<int> n_reps;
for(int i = 0; i < shape.size(); i++){
int in_per_cta = in_layout->shape_per_cta(i);
int out_per_cta = out_layout->shape_per_cta(i);
int max_per_cta = std::max(in_per_cta, out_per_cta);
n_reps.push_back(shape[i]/max_per_cta);
}
std::vector<std::vector<Value*>> in_ax;
std::vector<std::vector<Value*>> out_ax;
for(int d = 0; d < shape.size(); d++){
in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
}
auto in_ord =
in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order();
auto out_ord =
out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order();
  // out_ord[0] == 0 or in_ord[0] == 0 means the first dimension is
  // non-contiguous. in_vec can be greater than 1 only if both out_ord[0]
  // and in_ord[0] are contiguous.
int in_vec = out_ord[0] == 0 ? 1
: in_ord[0] == 0 ? 1
: in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
Value *in_ld = i32(shape[in_ord[0]] + pad);
Value *out_ld = i32(shape[out_ord[0]] + pad);
for(int i = 0; i < n_reps[0]; i++)
for(int j = 0; j < n_reps[1]; j++){
int max_ii, max_jj;
add_barrier();
max_ii = in_ax[0].size()/n_reps[0];
max_jj = in_ax[1].size()/n_reps[1];
for(int ii = 0; ii < max_ii; ii++)
for(int jj = 0; jj < max_jj; jj+=in_vec){
// shared mem pointer
indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
Value *ptr = gep(base, off);
// stash value to shared mem
Value* vals = UndefValue::get(vec_ty(ty, in_vec));
for(int jjj = 0; jjj < in_vec; jjj++){
indices_t idxs = {in_ax[0][i*max_ii + ii],
in_ax[1][j*max_jj + jj + jjj]};
Value* val = bit_cast(vals_[in][idxs], ty);
vals = insert_elt(vals, val, jjj);
}
ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace()));
store(vals, ptr);
}
add_barrier();
max_ii = out_ax[0].size()/n_reps[0];
max_jj = out_ax[1].size()/n_reps[1];
for(int ii = 0; ii < max_ii; ii++)
for(int jj = 0; jj < max_jj; jj+=out_vec){
// shared mem pointer
indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
Value *ptr = gep(base, off);
ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace()));
      // load value from shared mem
Value* vals = load(ptr);
for(int jjj = 0; jjj < out_vec; jjj++){
indices_t idxs = {out_ax[0][i*max_ii + ii],
out_ax[1][j*max_jj + jj + jjj]};
vals_[out][idxs] = extract_elt(vals, jjj);
}
}
}
}
void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) {
visit_layout_convert(rc, rc->get_operand(0));
}
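// Lowers a masked asynchronous copy to cp.async. Each in_vec-wide chunk emits
//   cp.async{.cg|.ca}.shared.global [smem + off], [gmem + off], cp-size, src-size;
// where src-size is selected down to 0 for masked-off chunks (so the hardware
// zero-fills the destination), .cg is used for 16-byte copies and .ca
// otherwise, and the shared-memory offsets are XOR-swizzled to match the
// swizzled shared layout. A trailing cp.async.commit_group closes the batch.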
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
unsigned in_vec = 1;
ir::value *arg = x->get_pointer_operand();
analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->get_order();
auto in_order = in_layout->get_order();
// tiles
if(out_order == in_order)
in_vec = in_layout->nts(in_order[0]);
int out_vec = swizzle_->get_vec(out_layout);
int min_vec = std::min<int>(out_vec, in_vec);
int s = std::max<int>(out_vec / in_vec, 1);
//
int per_phase = swizzle_->get_per_phase(out_layout);
int max_phase = swizzle_->get_max_phase(out_layout);
//
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
auto shapes = x->get_type()->get_block_shapes();
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
std::map<std::pair<int, int>, Value*> tmp;
std::vector<std::pair<Value*, int>> shared;
for(int i = 0; i < idxs_.at(arg).size(); i++){
unsigned id = i / min_vec;
// input ptr info
int id_0 = id % (in_ld/min_vec);
int id_1 = id / (in_ld/min_vec);
int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
int off = (off_1*shapes[in_order[0]] + off_0);
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(tmp.find(key) == tmp.end()){
if(CurrBB != FirstBB)
builder_->SetInsertPoint(FirstBB->getTerminator());
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
phase = urem(phase, i32(max_phase));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
off_0 = udiv(off_0, i32(min_vec));
off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
off_0 = mul(off_0 , i32(min_vec));
Value* off = add(off_0, off_1);
if(CurrBB != FirstBB)
builder_->SetInsertPoint(CurrBB);
tmp[key] = gep(shmems_[x], {off});
}
shared.push_back({tmp[key], off});
}
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
auto idx = idxs_[arg][i];
// input ptr info
Value *ptr = vals_[arg][idx];
size_t in_off = 0;
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
if(in_gep){
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
ptr= cst ? in_gep->getPointerOperand() : in_gep;
}
// output ptr info
Value* out_base = shared[i].first;
int out_off = shared[i].second*dtsize;
// asm
std::string mod = (in_vec*dtsize == 16) ? ".cg" : ".ca";
// Value* false_value = vals_[x->get_false_value_operand()][idx];
// bool is_zero_false_value = false;
// if(Constant* cst = dyn_cast<Constant>(false_value))
// is_zero_false_value = cst->isZeroValue();
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0));
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;";
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), ptr->getType(), builder_->getInt32Ty()}, false);
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true);
call(iasm, {out_base, ptr, src_size});
}
std::string asm_str = "cp.async.commit_group;";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
}
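// Copies a distributed tensor into a (possibly swizzled) shared layout.
// Values are packed into min_vec-wide vectors and stored at column offsets of
// the form ((col / s) ^ phase) * s + col % s (col in units of min_vec), with
// phase = (row / per_phase) % max_phase, which helps spread the vectors
// written by a warp across distinct shared-memory banks.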
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
unsigned in_vec = 1;
ir::value *arg = cts->get_operand(0);
analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
auto out_order = out_layout->get_order();
auto in_order = in_layout->get_order();
// tiles
if(out_order == in_order)
in_vec = in_layout->contig_per_thread(in_order[0]);
int out_vec = swizzle_->get_vec(out_layout);
int min_vec = std::min<int>(out_vec, in_vec);
int s = std::max<int>(out_vec / in_vec, 1);
//
int per_phase = swizzle_->get_per_phase(out_layout);
int max_phase = swizzle_->get_max_phase(out_layout);
//
int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
if(in_layout->to_mma()){
mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]);
mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]);
per_phase = 1;
max_phase = 8;
}
int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
if(in_layout->to_mma()){
n_shared_0 = 8;
n_shared_1 = 1;
}
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_block_shapes();
// store to shared
Value *current = nullptr;
std::map<std::pair<int, int>, Value*> ptrs;
for(int i = 0; i < idxs_.at(arg).size(); i++){
auto idx = idxs_[arg][i];
Value *in_value = vals_[arg][idx];
if(i % min_vec == 0)
current = UndefValue::get(vec_ty(in_value->getType(), min_vec));
current = insert_elt(current, in_value, i % min_vec);
if(i % min_vec == min_vec - 1){
unsigned id = i / min_vec;
// input ptr info
int id_0 = id % (in_ld/min_vec);
int id_1 = id / (in_ld/min_vec);
// std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl;
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(ptrs.find(key) == ptrs.end()){
if(FirstBB->getTerminator())
builder_->SetInsertPoint(FirstBB->getTerminator());
else
builder_->SetInsertPoint(FirstBB);
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
phase = urem(phase, i32(max_phase));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
off_0 = udiv(off_0, i32(min_vec));
off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
off_0 = mul(off_0 , i32(min_vec));
Value* off = add(off_0, off_1);
builder_->SetInsertPoint(CurrBB);
ptrs[key] = gep(shmems_.at(cts), {off});
}
int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
if(in_layout->to_mma()){
off_0 = id_0/n_shared_0*n_shared_0*8;
off_1 = id_1/n_shared_1*n_shared_1*8;
}
int off = (off_1*shapes[in_order[0]] + off_0);
Value* ptr = gep(ptrs[key], {i32(off)});
ptr = bit_cast(ptr, current->getType()->getPointerTo(3));
      // store the packed vector to shared memory
store(current, ptr);
}
};
}
void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst*) {
throw std::runtime_error("TODO");
}
Instruction* generator::add_barrier() {
Module *module = builder_->GetInsertBlock()->getModule();
return tgt_->add_barrier(module, *builder_);
}
void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
void generator::visit_clock_inst(ir::clock_inst* clock){
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true);
vals_[clock][{}] = call(iasm);
}
void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true);
vals_[timer][{}] = call(iasm);
}
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
ir::value *v = i->get_operand(0);
int inc = i->get_inc();
if (inc == 0) {
    // If dot has not been visited, do nothing.
  } else {
    // If dot has been visited, insert prefetched lds
    assert(inc == 1);
    assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
           "dot hasn't been visited");
// sink lds & extract element
// move lds & all uses to current location
std::stack<Value*> work_stack;
for (Value *value : prefetch_latch_to_bb_[v])
work_stack.push(value);
std::vector<Instruction*> dead_instrs;
while (!work_stack.empty()) {
Value *m = work_stack.top();
work_stack.pop();
for (auto u : m->users())
work_stack.push(u);
assert(isa<Instruction>(m));
auto m_instr = static_cast<Instruction*>(m);
m_instr->removeFromParent();
m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end()));
assert(m_instr->getParent() == &*builder_->GetInsertBlock());
builder_->SetInsertPoint(m_instr->getParent());
}
}
}
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
}
/**
* \brief Code Generation for `extern_elementwise`
*/
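// Declares (or reuses) an external scalar function with the instruction's
// symbol name and calls it once per distributed index; the providing library
// is recorded via add_extern_lib, presumably so its bitcode can be linked in
// later.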
void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
std::vector<Type *> operand_types;
for (size_t j = 0; j < i->get_num_operands(); j++) {
operand_types.push_back(
cvt(i->get_operand(j)->get_type()->get_scalar_ty()));
}
Type *ret_type = cvt(i->get_type()->get_scalar_ty());
FunctionType *FT =
FunctionType::get(ret_type, std::move(operand_types), false);
Function *F = llvm::cast<llvm::Function>(
mod_->getOrInsertFunction(i->get_symbol_name(), FT).getCallee());
for (auto idx : idxs_.at(i)) {
std::vector<llvm::Value *> args;
for (size_t j = 0; j < i->get_num_operands(); j++) {
args.emplace_back(vals_[i->get_operand(j)][idx]);
}
vals_[i][idx] = call(F, std::move(args));
}
add_extern_lib(i->get_lib_name(), i->get_lib_path());
}
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
// for(indices_t idx: idxs_.at(x)){
// assert(idx.size() == 1);
// if(idx[0] == i32(0))
// vals_[x][idx] = idx[0];
// else{
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
// assert(bin_add);
// vals_[x][idx] = bin_add->getOperand(0);
// }
// }
//}
//void generator::visit_make_range_sta(ir::make_range_sta* x) {
// for(indices_t idx: idxs_.at(x)){
// assert(idx.size() == 1);
// if(idx[0] == i32(0)){
// vals_[x][idx] = idx[0];
// }
// else{
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
// assert(bin_add);
// Value *cst = bin_add->getOperand(1);
// assert(isa<Constant>(cst));
// vals_[x][idx] = cst;
// }
// };
//}
void generator::visit_make_range(ir::make_range* x) {
for(indices_t idx: idxs_.at(x)){
Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value());
vals_[x][idx] = add(start, idx[0]);
}
}
void generator::visit_undef_value(ir::undef_value *x) {
ir::type* sca_ty = x->get_type()->get_scalar_ty();
Type* ty = cvt(sca_ty);
for(indices_t idx: idxs_.at(x))
vals_[x][idx] = llvm::UndefValue::get(ty);
}
void generator::visit_constant_int(ir::constant_int *x){
Type *ty = cvt(x->get_type()->get_scalar_ty());
for(indices_t idx: idxs_.at(x))
vals_[x][idx] = ConstantInt::get(ty, x->get_value());
}
void generator::visit_constant_fp(ir::constant_fp *x){
Type *ty = cvt(x->get_type()->get_scalar_ty());
for(indices_t idx: idxs_.at(x)) {
// manually select bf16 constant
if (x->get_type()->get_scalar_ty()->is_bf16_ty()) {
// highest 16 bits of fp32
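      // (i.e. a truncating fp32 -> bf16 conversion; the constant is then
      // materialized through mov.b16 inline asm, presumably to sidestep
      // bf16 ConstantFP handling in this LLVM version)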
float fp32_value = x->get_value();
uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value)
& 0xffff0000) >> 16;
std::stringstream const_str;
const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned
InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false),
" mov.b16 $0, " + const_str.str() + ";",
"=h", false);
vals_[x][idx] = builder_->CreateCall(bf16_const, {});
} else
vals_[x][idx] = ConstantFP::get(ty, x->get_value());
}
}
void generator::visit_alloc_const(ir::alloc_const *alloc) {
unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
Type *element_ty = cvt(alloc->get_type()->get_pointer_element_ty());
Type *array_ty = llvm::ArrayType::get(element_ty, size);
Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
vals_[alloc][{}] = bit_cast(array, element_ty->getPointerTo(4));
}
void generator::forward_declare(ir::function* fn){
FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
if(!tgt_->is_gpu()){
Type *fn_ret_ty = fn_ty->getReturnType();
std::vector<Type*> fn_args_ty;
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
fn_args_ty.push_back(fn_ty->getParamType(i));
fn_args_ty.push_back(i32_ty);
fn_args_ty.push_back(i32_ty);
fn_args_ty.push_back(i32_ty);
fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
}
Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
fns_[fn] = ret;
}
Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
Type *ty) {
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
return base;
}
void generator::visit_function(ir::function* fn) {
idxs_.clear();
vals_.clear();
seen_.clear();
LLVMContext &ctx = builder_->getContext();
Function* ret = fns_[fn];
// set attributes
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
for(ir::attribute attr: attr_pair.second)
if(attr.is_llvm_attr()){
llvm::Attribute llattr = cvt(attr);
if(llattr.getKindAsEnum() != llvm::Attribute::None)
        ret->addAttribute(id, llattr);
}
}
// set metadata
if(tgt_->is_gpu()){
tgt_->set_kernel(*builder_, ctx, mod_, ret);
Metadata *md_args[] = {
ValueAsMetadata::get(ret),
MDString::get(ctx, "maxntidx"),
ValueAsMetadata::get(i32(num_warps_*32))
};
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
}
// set arguments
for(unsigned i = 0; i < fn->args().size(); i++)
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
// create blocks
auto blocks = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: blocks) {
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
bbs_[block] = dst_block;
}
builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
// create policies
if(tgt_->as_nvidia()->sm() >= 80)
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;";
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
policies_[evict] = call(iasm);
}
// initialize layouts
for(auto x: layouts_->get_all()){
visit_layout(x.second);
}
// generate LLVM-IR code
for(ir::basic_block *block: blocks)
visit_basic_block(block);
// finalize
finalize_function(fn);
}
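// Computes the per-thread row/column indices of an mma accumulator tile.
// For sm < 80 the lane offset is assembled from warp/quad/pair offsets of the
// legacy mma fragment shape (fpw/spw/rep); for sm >= 80 the m16n8 fragment
// layout is used, where lane L owns rows (L / 4) and (L / 4 + 8) and columns
// 2 * (L % 4) and 2 * (L % 4) + 1 of each warp tile.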
void generator::visit_layout_mma(analysis::mma_layout* layout) {
ir::value *a = nullptr;
ir::value *b = nullptr;
for(ir::value* v: layout->get_values())
if(ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(v)){
a = dot->get_operand(0);
b = dot->get_operand(1);
}
analysis::data_layout* layout_a = layouts_->get(a);
analysis::data_layout* layout_b = layouts_->get(b);
const auto& shape = layout->get_shape();
Value *_1 = i32(1);
Value *_2 = i32(2);
Value *_3 = i32(3);
Value *_4 = i32(4);
Value *_8 = i32(8);
Value *_16 = i32(16);
Value *_32 = i32(32);
int cc = tgt_->as_nvidia()->sm();
std::vector<Value*> idx_m;
std::vector<Value*> idx_n;
std::vector<Value*> idx_z;
//
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value *lane = urem(thread, _32);
Value *warp = udiv(thread, _32);
/* lane offset */
if(cc < 80){
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0;
bool is_b_row = ord_b[0] != 0;
/* warp offset */
Value *warp_0 = urem(warp, i32(layout->wpt(0)));
Value *warp_12 = udiv(warp, i32(layout->wpt(0)));
Value *warp_1 = urem(warp_12, i32(layout->wpt(1)));
Value *off_warp_m = mul(warp_0, i32(layout->spw(0)));
Value *off_warp_n = mul(warp_1, i32(layout->spw(1)));
// Quad offset
Value *off_quad_m = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(0)));
Value *off_quad_n = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(1)));
// Pair offset
Value *off_pair_m = udiv(urem(lane, _16), _4);
off_pair_m = urem(off_pair_m, i32(layout->fpw(0)));
off_pair_m = mul(off_pair_m, i32(4));
Value *off_pair_n = udiv(urem(lane, _16), _4);
off_pair_n = udiv(off_pair_n, i32(layout->fpw(0)));
off_pair_n = urem(off_pair_n, i32(layout->fpw(1)));
off_pair_n = mul(off_pair_n, i32(4));
// scale
off_pair_m = mul(off_pair_m, i32(layout->rep(0)/2));
off_quad_m = mul(off_quad_m, i32(layout->rep(0)/2));
off_pair_n = mul(off_pair_n, i32(layout->rep(1)/2));
off_quad_n = mul(off_quad_n, i32(layout->rep(1)/2));
// Quad pair offset
Value *off_lane_m = add(off_pair_m, off_quad_m);
Value *off_lane_n = add(off_pair_n, off_quad_n);
// a offset
offset_a_m_[layout] = add(off_warp_m, off_lane_m);
offset_a_k_[layout] = and_(lane, _3);
// b offsets
offset_b_n_[layout] = add(off_warp_n, off_lane_n);
offset_b_k_[layout] = and_(lane, _3);
// i indices
Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]);
for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0))
for(unsigned mm = 0; mm < layout->rep(0); mm++)
idx_m.push_back(add(offset_c_m, i32(m + mm*2)));
// j indices
Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n));
for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1))
for(unsigned nn = 0; nn < layout->rep(1); nn++){
idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1))));
idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1)));
}
if(is_a_row){
offset_a_m_[layout] = add(offset_a_m_[layout], urem(thread, i32(4)));
offset_a_k_[layout] = i32(0);
}
if(!is_b_row){
offset_b_n_[layout] = add(offset_b_n_[layout], urem(thread, i32(4)));
offset_b_k_[layout] = i32(0);
}
/* axes */
axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0};
axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1};
}
else{
/* warp offset */
Value *warp_0 = urem(warp, i32(layout->wpt(0)));
Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1)));
Value *off_warp_m = mul(warp_0, i32(layout->spw(0)));
Value *off_warp_n = mul(warp_1, i32(layout->spw(1)));
Value *off_lane_m = urem(lane, _16);
Value *off_lane_n = urem(lane, _8);
/* offsets */
// a offset
offset_a_m_[layout] = add(off_warp_m, off_lane_m);
offset_a_k_[layout] = i32(0);
// b offsets
offset_b_n_[layout] = add(off_warp_n, off_lane_n);
offset_b_k_[layout] = i32(0);
// c offset
Value *off_c_m = add(udiv(lane, _4), off_warp_m);
Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n);
for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){
idx_m.push_back(add(off_c_m, i32(m)));
idx_m.push_back(add(off_c_m, i32(m + 8)));
}
for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){
idx_n.push_back(add(off_c_n, i32(n)));
idx_n.push_back(add(off_c_n, i32(n + 1)));
}
/* axes */
axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0};
axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1};
}
}
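// Computes per-thread indices for a scanline (striped) layout: the linear
// thread id is delinearized along `order` (fastest dimension first), and along
// each dimension a thread owns nts(k) contiguous elements repeated every
// shape_per_cta(k) elements.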
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0);
auto order = layout->get_order();
const auto& shape = layout->get_shape();
// Delinearize
size_t dim = shape.size();
std::vector<Value*> thread_ids(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = i32(layout->mts(order[k]));
Value *rem = urem(thread_id, dim_k);
thread_id = udiv(thread_id, dim_k);
thread_ids[order[k]] = rem;
}
Constant *dim_k = i32(layout->mts(order[dim - 1]));
thread_ids[order[dim - 1]] = urem(thread_id, dim_k);
// Create axes
for(unsigned k = 0; k < dim; k++) {
int nts = layout->nts(k);
int mts = layout->mts(k);
std::string str_k = std::to_string(k);
Value *contiguous_k = i32(nts);
Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k);
unsigned per_cta = layout->shape_per_cta(k);
unsigned per_thread = nts * shape[k] / per_cta;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts * per_cta + n % nts;
idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]};
}
}
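// Materializes the base pointer(s) of a shared-memory layout. For N-buffered
// and double-buffered allocations, the current/next pointers (and stage
// indices) are created as PHI nodes in the block of the associated phi and
// wired up later in finalize_shared_layout; otherwise the pointer is a plain
// GEP off the shared-memory base.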
void generator::visit_layout_shared(analysis::shared_layout* layout) {
Type* ty = cvt(layout->get_type());
PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace());
if (layout->get_N_buffer()) {
// create pointers
shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty);
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout->get_N_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = bbs_.at(phi->get_parent());
if(parent->empty())
builder_->SetInsertPoint(parent);
    else if (Instruction *first_non_phi = parent->getFirstNonPHI())
      builder_->SetInsertPoint(first_non_phi);
    else
      builder_->SetInsertPoint(parent);
// create smem_idx
read_smem_idx_[layout] = phi(i32_ty, 2);
write_smem_idx_[layout] = phi(i32_ty, 2);
// create pointers
// ptr of the current iteration
shared_ptr_[layout] = phi(ptr_ty, 2);
// ptr of the next iteration
shared_next_ptr_[layout] = phi(ptr_ty, 2);
builder_->SetInsertPoint(current);
} else if(layout->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout->get_double_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = bbs_.at(phi->get_parent());
if(parent->empty())
builder_->SetInsertPoint(parent);
else
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
// create pointers
shared_ptr_[layout] = phi(ptr_ty, 2);
shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], shared_ptr_[layout]->getType());
shared_off_[layout] = phi(i32_ty, 2);
shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr");
builder_->SetInsertPoint(current);
} else{
size_t offset = alloc_->offset(layout);
shared_ptr_[layout] = gep(shmem_, i32(offset));
shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty);
}
}
void generator::visit_basic_block(ir::basic_block * block) {
BasicBlock *parent = bbs_[block];
builder_->SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
visit_value(i);
// std::cout << "done" << std::endl;
}
// Update ir bb -> llvm bb mapping
bbs_[block] = builder_->GetInsertBlock();
}
void generator::visit_argument(ir::argument* arg) {
}
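// Builds the list of distributed indices for a value: scalars get a single
// empty index, shared tensors are skipped, and block tensors get the cartesian
// product of their per-dimension axis values, enumerated with the layout's
// fastest-varying dimension in the innermost loop.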
void generator::init_idx(ir::value *v) {
idxs_[v].clear();
if(!v->get_type()->is_block_ty()){
idxs_[v].push_back({});
return;
}
if(layouts_->get(v)->to_shared())
return;
const auto &shapes = v->get_type()->get_block_shapes();
size_t rank = shapes.size();
std::vector<distributed_axis> axes(rank);
std::vector<int> ord(rank);
// compute axes
// std::cout << "axes" << std::endl;
for(size_t d = 0; d < shapes.size(); d++){
// std::cout << d << " " << shapes[d] << std::endl;
// std::cout << a_axes_->get(v, d) << std::endl;
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {i32(0)};
}
}
// std::cout << "axes ok" << std::endl;
// compute order
analysis::data_layout* layout = layouts_->get(v);
std::iota(ord.begin(), ord.end(), 0);
auto cmp = [&](int x, int y) {
unsigned axx = a_axes_->get(v, x);
unsigned axy = a_axes_->get(v, y);
size_t posx = layout->find_axis(axx);
size_t posy = layout->find_axis(axy);
if(posx < rank && posy < rank)
return layout->get_order(posx) < layout->get_order(posy);
return false;
};
std::sort(ord.begin(), ord.end(), cmp);
ords_[v] = ord;
// indices
if(axes.size() == 1)
for(Value* x0: axes[ord[0]].values){
idxs_[v].push_back({x0});
}
if(axes.size() == 2)
for(Value* x1: axes[ord[1]].values)
for(Value* x0: axes[ord[0]].values){
indices_t idx(2);
idx[ord[0]] = x0;
idx[ord[1]] = x1;
idxs_[v].push_back(idx);
}
if(axes.size() == 3)
for(Value* x2: axes[ord[2]].values)
for(Value* x1: axes[ord[1]].values)
for(Value* x0: axes[ord[0]].values){
indices_t idx(3);
idx[ord[0]] = x0;
idx[ord[1]] = x1;
idx[ord[2]] = x2;
idxs_[v].push_back(idx);
}
}
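// Wires up the PHI nodes created in visit_layout_shared. For N-buffered
// layouts, the read/write stage indices rotate as
//   idx = (idx == num_stages - 1) ? 0 : idx + 1;
// and the current/next stage pointers advance by per_stage_elements
// accordingly; for double-buffered layouts, the pointer offset simply flips
// sign on every iteration.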
void generator::finalize_shared_layout(analysis::shared_layout *shared) {
if (auto n_buffer = shared->get_N_buffer()) {
// if (*_smem_idx == #stages-1) {
// *_smem_idx = 0;
// } else *_smem_idx++;
auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) {
// insert point
Value *idx = smem_idx[shared];
builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
PHINode *_ret = phi(i32_ty, 2);
Instruction *then_term = nullptr;
Instruction *else_term = nullptr;
Instruction *dummy = builder_->CreateRet(nullptr);
llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr);
dummy->removeFromParent();
builder_->SetInsertPoint(then_term);
Value *zero_smem_idx = i32(0);
builder_->SetInsertPoint(else_term);
Value *inc_smem_idx = add(idx, i32(1));
builder_->SetInsertPoint(_ret->getParent());
_ret->addIncoming(zero_smem_idx, then_term->getParent());
_ret->addIncoming(inc_smem_idx, else_term->getParent());
// update ir::bb -> llvm::bb mapping
bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock();
// idx = init_stage;
// loop: ...
if (auto idx_phi = llvm::dyn_cast<PHINode>(smem_idx[shared])) {
idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0)));
idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1)));
} else
throw std::runtime_error("Should be PHINode");
};
// read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2
finalize_smem_idx(read_smem_idx_, 2);
finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1);
// finalize pointers
ir::phi_node *pn = n_buffer->phi;
BasicBlock *header = bbs_.at(pn->get_incoming_block(0));
BasicBlock *loop = bbs_.at(pn->get_incoming_block(1));
// %curr_ptr = phi %shared_pre_ptr, %next_ptr
// %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size))
if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
} else
throw std::runtime_error("Should be PHINode");
BasicBlock *current = builder_->GetInsertBlock();
builder_->SetInsertPoint(header->getTerminator());
Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements()));
builder_->SetInsertPoint(current->getTerminator());
assert(isa<PHINode>(shared_next_ptr_[shared]));
static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header);
Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements()));
Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset);
static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr, loop);
} else if(shared->get_double_buffer()) {
auto info = *shared->get_double_buffer();
ir::phi_node *phi = info.phi;
PHINode *ptr = (PHINode*)shmems_[phi];
PHINode *offset = (PHINode*)shoffs_[phi];
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block* inc_block = phi->get_incoming_block(n);
ir::value* inc_val = phi->get_incoming_value(n);
BasicBlock *llvm_inc_block = bbs_.at(inc_block);
if(inc_val == info.latch){
builder_->SetInsertPoint(llvm_inc_block->getTerminator());
Value *next_offset = neg(offset);
offset->addIncoming(next_offset, llvm_inc_block);
}
else {
unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
offset->addIncoming(i32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
}
ptr->addIncoming(shmems_[inc_val], llvm_inc_block);
}
}
}
void generator::finalize_function(ir::function *fn) {
// finalize double-buffering
for(const auto& x: layouts_->get_all())
if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
finalize_shared_layout(shared);
// finalize phi
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst))
finalize_phi_node(phi);
for(auto& x: lazy_phi_incs_)
std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]);
}
void generator::finalize_phi_node(ir::phi_node *x) {
if(shmems_.find(x) != shmems_.end())
return;
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::basic_block *_block = x->get_incoming_block(n);
BasicBlock *block = bbs_.at(_block);
for(indices_t idx: idxs_.at(x)){
PHINode *phi = (PHINode*)vals_[x][idx];
Value *inc = vals_[x->get_incoming_value(n)][idx];
// x->print(std::cout);
phi->addIncoming(inc, block);
}
}
}
void generator::packed_type(ir::value* i){
Type* dtype = cvt(i->get_type()->get_tile_element_ty());
auto* layout = dynamic_cast<analysis::scanline_layout*>(layouts_->get(i));
assert(layout);
}
void generator::visit(ir::module &src, llvm::Module &dst) {
mod_ = &dst;
ctx_ = &dst.getContext();
builder_ = new Builder(*ctx_);
// allocate shared memory
if(tgt_->is_gpu())
if(unsigned alloc_size = alloc_->allocated_size()){
Type *int_8_ty = Type::getInt8Ty(*ctx_);
Type *int_32_ty = Type::getInt32Ty(*ctx_);
ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
Type *ptr_ty = ptr_ty(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
shmem_ = bit_cast(sh_mem_array, ptr_ty);
}
// instantiate device functions
// for(ir::function *fn: src.get_function_list())
// for(ir::basic_block *bb: fn->blocks())
// for(ir::instruction *i: bb->get_inst_list())
// if(auto *call = dynamic_cast<ir::call_inst*>(i)){
// std::cout << "call??" << std::endl;
// }
// visit functions
for(ir::function *fn: src.get_function_list())
forward_declare(fn);
for(ir::function *fn: src.get_function_list())
visit_function(fn);
}
void generator::add_extern_lib(const std::string &lib_name,
const std::string &lib_path) {
if (extern_lib_map_.count(lib_name) == 0) {
extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path);
} else if (extern_lib_map_.at(lib_name)->path() != lib_path) {
throw std::runtime_error("A library has multiple paths (1) " + lib_path +
" (2) " + extern_lib_map_.at(lib_name)->path());
}
}
} // namespace codegen
} // namespace triton