Improve ROCm support. (#780)
- updates to support ROCm 5.2
- workarounds in tests where NV tools were used unconditionally
- implemented `get_num_blocks()` and `add_memfence()` for AMD GPU
- backported some atomics from history
- added bf16 support
- minor warnings cleanup
- added Dockerfile to run on a ROCm-enabled machine

Co-authored-by: B1tway <andrew.shukshov@gmail.com>
Co-authored-by: Andrey Shukshov <36711069+B1tway@users.noreply.github.com>
@@ -69,18 +69,21 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
   if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
   if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
   if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
-    return (*builder_)->CreateGEP(gep->getPointerOperand(),
-                                  (*builder_)->CreateAdd(cst1, cst2));
+    return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(),
+                                  gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2));
   }
   // ptr + (off + cst) -> (ptr + off) + cst
   if(auto* bin = dyn_cast<BinaryOperator>(off))
   if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
   if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
-    return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)),
-                                  bin->getOperand(1));
+    Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
+                                        ptr, bin->getOperand(0));
+    return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(),
+                                  gep, bin->getOperand(1));
   }
   // default
-  return (*builder_)->CreateGEP(ptr, off, name);
+  return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
+                                ptr, off, name);
 }

 //Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
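Note: these hunks migrate off the deprecated untyped CreateGEP overloads — ahead of LLVM's opaque-pointer transition, the pointee type must be passed explicitly. A minimal sketch of the pattern (the helper name is illustrative, not part of the patch):

    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // Untyped form (deprecated): b.CreateGEP(ptr, off)
    // Typed form: derive the element type from the still-typed pointer;
    // under opaque pointers it would have to be tracked separately.
    Value *typed_gep(IRBuilder<> &b, Value *ptr, Value *off) {
      Type *elt_ty = ptr->getType()->getScalarType()->getPointerElementType();
      return b.CreateGEP(elt_ty, ptr, off);
    }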
@@ -91,6 +94,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 // types
 #define void_ty builder_->getVoidTy()
 #define f16_ty  builder_->getHalfTy()
+#define bf16_ty builder_->getInt16Ty()
 #define f32_ty  builder_->getFloatTy()
 #define f64_ty  builder_->getDoubleTy()
 #define i8_ty   builder_->getInt8Ty()
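Note: bf16 values are carried as their raw 16-bit pattern (i16) rather than as a native LLVM bfloat type. Since a bf16 is exactly the upper half of an f32, widening and narrowing are 16-bit shifts. A scalar sketch of the semantics the backend relies on (illustrative helpers, not part of the patch):

    #include <cstdint>
    #include <cstring>

    float bf16_to_f32(uint16_t bits) {
      uint32_t w = uint32_t(bits) << 16;  // bf16 payload sits in the high 16 bits
      float f;
      std::memcpy(&f, &w, sizeof f);
      return f;
    }

    uint16_t f32_to_bf16_trunc(float f) {
      uint32_t w;
      std::memcpy(&w, &f, sizeof w);
      return uint16_t(w >> 16);           // truncating (round-toward-zero) narrowing
    }

    // e.g. bf16 addition as emitted further below: widen, add in f32, truncate back.
    uint16_t bf16_add(uint16_t a, uint16_t b) {
      return f32_to_bf16_trunc(bf16_to_f32(a) + bf16_to_f32(b));
    }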
@@ -124,7 +128,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define icmp_ult(...)   builder_->CreateICmpULT(__VA_ARGS__)
 #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
 #define intrinsic(...)  builder_->CreateIntrinsic(__VA_ARGS__)
-#define load(...)       builder_->CreateLoad(__VA_ARGS__)
+#define load(ptr)       builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
 #define lshr(...)       builder_->CreateLShr(__VA_ARGS__)
 #define max_num(...)    builder_->CreateMaxNum(__VA_ARGS__)
 #define min_num(...)    builder_->CreateMinNum(__VA_ARGS__)
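Note: the load macro gets the same treatment as the GEP calls above — it is narrowed to a single-pointer form so the pointee type can be handed to CreateLoad explicitly: `load(p)` now expands to `builder_->CreateLoad(p->getType()->getPointerElementType(), p)`. Any call site that passed extra arguments through the old variadic form presumably has to be updated to the one-argument shape.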
@@ -293,8 +297,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
  */
 void generator::visit_binary_operator(ir::binary_operator*x) {
   using ll = llvm::Instruction::BinaryOps;
-  using tt = ir::binary_op_t;
   auto cvt = [](ir::binary_op_t op){
+    using tt = ir::binary_op_t;
     switch(op) {
     case tt::Add: return ll::Add;
     case tt::FAdd: return ll::FAdd;
@@ -320,13 +324,47 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
   for(indices_t idx: idxs_.at(x)){
     Value *lhs = vals_[x->get_operand(0)][idx];
     Value *rhs = vals_[x->get_operand(1)][idx];
-    auto op = cvt(x->get_op());
-    if(op == ll::Add)
-      vals_[x][idx] = add(lhs, rhs);
-    else if(op == ll::Mul)
-      vals_[x][idx] = mul(lhs, rhs);
-    else
-      vals_[x][idx] = bin_op(op, lhs, rhs);
+    if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
+      assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
+      if (x->get_op() == tt::FAdd) { // widen to f32, add, narrow back to bf16
+        InlineAsm *bf16_add_asm =
+            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                           "v_lshlrev_b32 $1, 16, $1\n"
+                           "v_lshlrev_b32 $2, 16, $2\n"
+                           "v_add_f32 $0, $1, $2\n"
+                           "v_lshrrev_b32 $0, 16, $0\n",
+                           "=v,v,v", false);
+        vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
+      } else if (x->get_op() == tt::FSub) { // widen to f32, subtract, narrow back
+        InlineAsm *bf16_sub_asm =
+            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                           "v_lshlrev_b32 $1, 16, $1\n"
+                           "v_lshlrev_b32 $2, 16, $2\n"
+                           "v_sub_f32 $0, $1, $2\n"
+                           "v_lshrrev_b32 $0, 16, $0\n",
+                           "=v,v,v", false);
+        vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
+      } else if (x->get_op() == tt::FMul) { // widen to f32, multiply, narrow back
+        InlineAsm *bf16_mul_asm =
+            InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                           "v_lshlrev_b32 $1, 16, $1\n"
+                           "v_lshlrev_b32 $2, 16, $2\n"
+                           "v_mul_f32 $0, $1, $2\n"
+                           "v_lshrrev_b32 $0, 16, $0\n",
+                           "=v,v,v", false);
+        vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
+      } else
+        throw std::runtime_error("invalid bin op for bf16");
+    }
+    else {
+      auto op = cvt(x->get_op());
+      if(op == ll::Add)
+        vals_[x][idx] = add(lhs, rhs);
+      else if(op == ll::Mul)
+        vals_[x][idx] = mul(lhs, rhs);
+      else
+        vals_[x][idx] = bin_op(op, lhs, rhs);
+    }
   }
 }

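Note: the "=v,v,v" constraint string pins the result and both operands to VGPRs. Each bf16 operand is shifted into the high half of a 32-bit register, which makes it a valid f32; the op is then performed by the ordinary f32 VALU instruction, and the final v_lshrrev_b32 shifts the result back down. The narrowing truncates the low mantissa bits (round-toward-zero) instead of rounding to nearest-even — see the scalar sketch after the bf16_ty hunk above — trading a little accuracy for zero rounding code. Note also that the asm writes its input registers ($1 and $2), which strictly speaking would call for early-clobber or tied constraints.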
@@ -979,6 +1017,35 @@ void generator::visit_log_inst(ir::log_inst* x){
 /**
  * \brief Code Generation for `atomic_cas`
  */
+#if defined(USE_ROCM)
+void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
+  BasicBlock *current = builder_->GetInsertBlock();
+  Module *module = current->getModule();
+  Value *tid = tgt_->get_local_id(module, *builder_, 0);
+  Value *pred = icmp_eq(tid, i32(0));
+  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
+  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+  add_barrier();
+  tgt_->add_memfence(module, *builder_);
+  cond_br(pred, tid_0_bb, tid_0_done_bb);
+  builder_->SetInsertPoint(tid_0_bb);
+  Value *cas_ptr = vals_[cas->get_operand(0)][{}];
+  Value *cas_cmp = vals_[cas->get_operand(1)][{}];
+  Value *cas_val = vals_[cas->get_operand(2)][{}];
+  Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
+  old = extract_val(old, std::vector<unsigned>{0});
+  Value *atom_ptr;
+  atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
+  atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
+  store(old, atom_ptr);
+  br(tid_0_done_bb);
+  builder_->SetInsertPoint(tid_0_done_bb);
+  tgt_->add_memfence(module, *builder_);
+  add_barrier();
+  vals_[cas][{}] = load(atom_ptr);
+  add_barrier();
+}
+#else
 void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
   BasicBlock *current = builder_->GetInsertBlock();
   Module *module = current->getModule();
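Note: the ROCm variant funnels the CAS through a single work-item and broadcasts the result via shared memory (LLVM address space 3 on AMDGPU). The kernel-level shape of the generated code, as illustrative pseudo-code:

    // fence(); barrier();
    // if (tid == 0)
    //   shared_tmp = atomic_cmpxchg(ptr, cmp, val);  // one lane does the CAS
    // fence(); barrier();
    // result = shared_tmp;                           // all threads read the old value
    // barrier();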
@@ -1013,12 +1080,66 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
   vals_[cas][{}] = load(atom_ptr);
   add_barrier();
 }
+#endif // defined(USE_ROCM)
 
 /**
  * \brief Code Generation for `atomic_rmw`
  */
+#if defined(USE_ROCM)
+void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
+  ir::value* ptr = atom->get_operand(0);
+  if (atom->get_op() == ir::atomic_rmw_op_t::Add ||
+      atom->get_op() == ir::atomic_rmw_op_t::Max ||
+      atom->get_op() == ir::atomic_rmw_op_t::Min ||
+      atom->get_op() == ir::atomic_rmw_op_t::UMax ||
+      atom->get_op() == ir::atomic_rmw_op_t::UMin ||
+      atom->get_op() == ir::atomic_rmw_op_t::Xchg) {
+    BasicBlock *current = builder_->GetInsertBlock();
+    Module *module = current->getModule();
+    Value *rmw_ptr = vals_[atom->get_operand(0)][{}];
+    Value *rmw_val = vals_[atom->get_operand(1)][{}];
+    Value *tid = tgt_->get_local_id(module, *builder_, 0);
+    Value *pred = icmp_eq(tid, i32(0));
+    BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
+    BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+    tgt_->add_memfence(module, *builder_);
+    add_barrier();
+    cond_br(pred, tid_0_bb, tid_0_done_bb);
+    builder_->SetInsertPoint(tid_0_bb);
+    AtomicRMWInst::BinOp binop;
+    switch (atom->get_op()) {
+    case ir::atomic_rmw_op_t::Add:
+      binop = AtomicRMWInst::Add;
+      break;
+    case ir::atomic_rmw_op_t::Max:
+      binop = AtomicRMWInst::Max;
+      break;
+    case ir::atomic_rmw_op_t::Min:
+      binop = AtomicRMWInst::Min;
+      break;
+    case ir::atomic_rmw_op_t::UMax:
+      binop = AtomicRMWInst::UMax;
+      break;
+    case ir::atomic_rmw_op_t::UMin:
+      binop = AtomicRMWInst::UMin;
+      break;
+    case ir::atomic_rmw_op_t::Xchg:
+      binop = AtomicRMWInst::Xchg;
+      break;
+    default:
+      throw std::runtime_error("Not supported!");
+    }
+    atomic_rmw(binop, rmw_ptr, rmw_val, MaybeAlign(), AtomicOrdering::Monotonic,
+               SyncScope::System);
+    br(tid_0_done_bb);
+    builder_->SetInsertPoint(tid_0_done_bb);
+    tgt_->add_memfence(module, *builder_);
+    return;
+  }
+  throw std::runtime_error("Not supported!");
+}
+#else // defined(USE_ROCM)
 void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
   ir::value *ptr = atom->get_operand(0);
   ir::value* val = atom->get_operand(1);
   ir::value* msk = atom->get_operand(2);
 
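Note: the ROCm path lowers atomic_rmw straight to LLVM's atomicrmw instruction for the integer-friendly ops (Add/Max/Min/UMax/UMin/Xchg); anything else (e.g. FAdd, And, Or) falls through to the runtime_error. Two things worth flagging: like the CAS above, only work-item 0 issues the atomic, but unlike the CAS the old value is never written back to shared memory, so the RMW's result is not observable by the other threads on this path; and the mask operand consulted by the NV implementation below is ignored here.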
@@ -1100,6 +1221,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
     }
   }
 }
+#endif // defined(USE_ROCM)
 
 /**
  * \brief Code Generation for `mma.884` (V100)
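(The remaining hunks are in the AMD target lowering — the amd_cl_target methods named in the commit message.)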
@@ -41,7 +41,8 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
 }
 
 Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
-  throw std::runtime_error("not implemented on AMD");
+  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_waitcnt);
+  return builder.CreateIntrinsic(Intrinsic::amdgcn_s_waitcnt, {}, {builder.getInt32(0)});
 }
 
 
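Note: on AMDGPU the fence is realized with s_waitcnt — an all-zero immediate waits until the vector-memory, LGKM and export counters have all drained, i.e. every outstanding memory access has completed. Roughly, this emits `call void @llvm.amdgcn.s.waitcnt(i32 0)`. The Function* fetched via getDeclaration is not actually needed (CreateIntrinsic looks the declaration up itself), so the first added line is redundant — harmless, but a candidate for the "minor warnings cleanup" named in the commit message.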
@@ -56,7 +57,50 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
 }
 
 Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
-  throw std::runtime_error("not implemented on AMD");
+  Function &F = *builder.GetInsertBlock()->getParent();
+  Module *Mod = F.getParent();
+  // We are indexing into this struct, and want to extract the grid_size_*
+  // fields.
+  //
+  // typedef struct hsa_kernel_dispatch_packet_s {
+  //   uint16_t header;
+  //   uint16_t setup;
+  //   uint16_t workgroup_size_x;
+  //   uint16_t workgroup_size_y;
+  //   uint16_t workgroup_size_z;
+  //   uint16_t reserved0;
+  //   uint32_t grid_size_x;
+  //   uint32_t grid_size_y;
+  //   uint32_t grid_size_z;
+  //   ...
+  // } hsa_kernel_dispatch_packet_t;
+  //
+  Function *DispatchPtrFn =
+      Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_dispatch_ptr);
+
+  CallInst *DispatchPtr = builder.CreateCall(DispatchPtrFn, {});
+  DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
+  DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+  F.removeFnAttr("amdgpu-no-dispatch-ptr");
+
+  // Size of the dispatch packet struct.
+  DispatchPtr->addDereferenceableAttr(AttributeList::ReturnIndex, 64);
+
+  Type *I32Ty = Type::getInt32Ty(Mod->getContext());
+  // TODO: include AMDGPUAS:: declarations.
+  Value *CastDispatchPtr = builder.CreateBitCast(
+      DispatchPtr, PointerType::get(I32Ty, 4 /*AMDGPUAS::CONSTANT_ADDRESS*/));
+
+  // grid_size_x sits at a 3*32-bit offset into the packet.
+  assert(ax < 3);
+  Value *GEP =
+      builder.CreateConstInBoundsGEP1_64(I32Ty, CastDispatchPtr, ax + 3);
+  LoadInst *Load = builder.CreateAlignedLoad(I32Ty, GEP, Align(4));
+
+  MDNode *MD = MDNode::get(Mod->getContext(), None);
+  Load->setMetadata(LLVMContext::MD_invariant_load, MD);
+
+  return Load;
+}
 
 Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
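Note: per the HSA ABI, the grid_size_* fields of the dispatch packet count work-items, not workgroups, so the value loaded above is the launch size in threads along `ax`. If a caller expects the number of workgroups, it would additionally have to be divided by the corresponding workgroup_size_* field. A hypothetical sketch of that extra step (not part of the patch):

    // Given the DispatchPtr from above: workgroups(ax) = grid_size / wg_size.
    Value *num_workgroups(IRBuilder<> &builder, Value *DispatchPtr,
                          LLVMContext &ctx, unsigned ax) {
      Type *I16Ty = Type::getInt16Ty(ctx);
      Type *I32Ty = Type::getInt32Ty(ctx);
      // workgroup_size_{x,y,z}: uint16_t at byte offsets 4, 6, 8 -> i16 index ax+2.
      Value *WgPtr = builder.CreateBitCast(
          DispatchPtr, PointerType::get(I16Ty, 4 /*CONSTANT_ADDRESS*/));
      Value *WgSize = builder.CreateZExt(
          builder.CreateAlignedLoad(
              I16Ty, builder.CreateConstInBoundsGEP1_64(I16Ty, WgPtr, ax + 2),
              Align(2)),
          I32Ty);
      // grid_size_{x,y,z}: uint32_t at byte offsets 12, 16, 20 -> i32 index ax+3.
      Value *GridPtr = builder.CreateBitCast(
          DispatchPtr, PointerType::get(I32Ty, 4));
      Value *GridSize = builder.CreateAlignedLoad(
          I32Ty, builder.CreateConstInBoundsGEP1_64(I32Ty, GridPtr, ax + 3),
          Align(4));
      return builder.CreateUDiv(GridSize, WgSize);
    }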