Improve ROCm support. (#780)

- updates to support ROCm 5.2
- workarounds for tests that used NVIDIA tools unconditionally
- implemented `get_num_blocks()` and `add_memfence()` for AMD GPUs
- backported some atomics from history
- added bf16 support
- minor warnings cleanup
- added a Dockerfile to run on a ROCm-enabled machine

Co-authored-by: B1tway <andrew.shukshov@gmail.com>
Co-authored-by: Andrey Shukshov <36711069+B1tway@users.noreply.github.com>
Daniil Fukalov, committed by GitHub, 2022-10-14 21:33:42 +03:00
parent 94d5c2e8b5
commit 406d03bfaf
17 changed files with 435 additions and 155 deletions

View File

@@ -69,18 +69,21 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
-  return (*builder_)->CreateGEP(gep->getPointerOperand(),
-                                (*builder_)->CreateAdd(cst1, cst2));
+  return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(),
+                                gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2));
}
// ptr + (off + cst) -> (ptr + off) + cst
if(auto* bin = dyn_cast<BinaryOperator>(off))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
-  return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)),
-                                bin->getOperand(1));
+  Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
+                                      ptr, bin->getOperand(0));
+  return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(),
+                                gep, bin->getOperand(1));
}
// default
-return (*builder_)->CreateGEP(ptr, off, name);
+return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
+                              ptr, off, name);
}
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
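Context for the extra type argument threaded through every CreateGEP call above: as part of LLVM's opaque-pointer transition, newer releases drop the CreateGEP overloads that infer the pointee type from the pointer operand, so the source element type must be spelled out. A minimal sketch of the pattern (hypothetical typed_gep helper, assuming LLVM 13/14-era headers where getPointerElementType() is still available):

#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Hypothetical helper, not part of the patch: pass the GEP source element
// type explicitly instead of letting LLVM infer it from the pointer type.
static Value *typed_gep(IRBuilder<> &builder, Value *ptr, Value *off) {
  Type *elt_ty = ptr->getType()->getScalarType()->getPointerElementType();
  return builder.CreateGEP(elt_ty, ptr, off);
}

The same API shift explains the load macro change further down, which now passes the result type to CreateLoad.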
@@ -91,6 +94,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
+#define bf16_ty builder_->getInt16Ty()
#define f32_ty builder_->getFloatTy()
#define f64_ty builder_->getDoubleTy()
#define i8_ty builder_->getInt8Ty()
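Why bf16_ty is defined as i16: this path has no dedicated bf16 type, so bf16 values travel as their raw 16 bits, which are exactly the upper half of an IEEE-754 f32. A host-side sketch of that correspondence (note the narrowing here is truncation, not round-to-nearest-even, matching the shift-based inline asm added below):

#include <cstdint>
#include <cstring>

// bf16 payload <-> f32 (sketch): widening is a 16-bit left shift of the
// payload into an f32 bit pattern; narrowing keeps only the top 16 bits.
static float bf16_to_f32(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}
static uint16_t f32_to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  return static_cast<uint16_t>(u >> 16); // truncates the low mantissa bits
}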
@@ -124,7 +128,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
-#define load(...) builder_->CreateLoad(__VA_ARGS__)
+#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
@@ -293,8 +297,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
*/
void generator::visit_binary_operator(ir::binary_operator*x) {
using ll = llvm::Instruction::BinaryOps;
-using tt = ir::binary_op_t;
auto cvt = [](ir::binary_op_t op){
+  using tt = ir::binary_op_t;
switch(op) {
case tt::Add: return ll::Add;
case tt::FAdd: return ll::FAdd;
@@ -320,13 +324,47 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
-auto op = cvt(x->get_op());
-if(op == ll::Add)
-  vals_[x][idx] = add(lhs, rhs);
-else if(op == ll::Mul)
-  vals_[x][idx] = mul(lhs, rhs);
-else
-  vals_[x][idx] = bin_op(op, lhs, rhs);
+if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
+  assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
+  // bf16 ops: shift each 16-bit payload up into f32 position, operate in
+  // f32, then shift the result back down to a bf16 payload.
+  if (x->get_op() == tt::FAdd) {
+    InlineAsm *bf16_add_asm =
+        InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                       "v_lshlrev_b32 $1, 16, $1\n"
+                       "v_lshlrev_b32 $2, 16, $2\n"
+                       "v_add_f32 $0, $1, $2\n"
+                       "v_lshrrev_b32 $0, 16, $0\n",
+                       "=v,v,v", false);
+    vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
+  } else if (x->get_op() == tt::FSub) {
+    InlineAsm *bf16_sub_asm =
+        InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                       "v_lshlrev_b32 $1, 16, $1\n"
+                       "v_lshlrev_b32 $2, 16, $2\n"
+                       "v_sub_f32 $0, $1, $2\n"
+                       "v_lshrrev_b32 $0, 16, $0\n",
+                       "=v,v,v", false);
+    vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
+  } else if (x->get_op() == tt::FMul) {
+    InlineAsm *bf16_mul_asm =
+        InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
+                       "v_lshlrev_b32 $1, 16, $1\n"
+                       "v_lshlrev_b32 $2, 16, $2\n"
+                       "v_mul_f32 $0, $1, $2\n"
+                       "v_lshrrev_b32 $0, 16, $0\n",
+                       "=v,v,v", false);
+    vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
+  } else
+    throw std::runtime_error("invalid bin op for bf16");
+}
+else {
+  auto op = cvt(x->get_op());
+  if(op == ll::Add)
+    vals_[x][idx] = add(lhs, rhs);
+  else if(op == ll::Mul)
+    vals_[x][idx] = mul(lhs, rhs);
+  else
+    vals_[x][idx] = bin_op(op, lhs, rhs);
+}
}
}
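The inline asm above pins the widen/compute/narrow sequence to specific VALU instructions. For comparison, a hedged sketch of the same bf16 add in portable IRBuilder calls (hypothetical emit_bf16_add_ir, not what the patch emits; the pure-IR form would leave instruction selection to the backend):

#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Hypothetical alternative: bf16 add as plain IR. Widen both i16 payloads
// to f32 bit patterns, add as floats, then shift back down to bf16 bits.
static Value *emit_bf16_add_ir(IRBuilder<> &builder, Value *lhs, Value *rhs) {
  Type *i32_ty = builder.getInt32Ty();
  auto widen = [&](Value *v) {
    Value *w = builder.CreateShl(builder.CreateZExt(v, i32_ty),
                                 ConstantInt::get(i32_ty, 16));
    return builder.CreateBitCast(w, builder.getFloatTy());
  };
  Value *sum  = builder.CreateFAdd(widen(lhs), widen(rhs));
  Value *bits = builder.CreateBitCast(sum, i32_ty);
  return builder.CreateTrunc(
      builder.CreateLShr(bits, ConstantInt::get(i32_ty, 16)),
      builder.getInt16Ty());
}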
@@ -979,6 +1017,35 @@ void generator::visit_log_inst(ir::log_inst* x){
/**
* \brief Code Generation for `atomic_cas`
*/
+#if defined(USE_ROCM)
+void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
+  BasicBlock *current = builder_->GetInsertBlock();
+  Module *module = current->getModule();
+  Value *tid = tgt_->get_local_id(module, *builder_, 0);
+  Value *pred = icmp_eq(tid, i32(0));
+  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
+  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+  add_barrier();
+  tgt_->add_memfence(module, *builder_);
+  cond_br(pred, tid_0_bb, tid_0_done_bb);
+  builder_->SetInsertPoint(tid_0_bb);
+  Value *cas_ptr = vals_[cas->get_operand(0)][{}];
+  Value *cas_cmp = vals_[cas->get_operand(1)][{}];
+  Value *cas_val = vals_[cas->get_operand(2)][{}];
+  Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
+  old = extract_val(old, std::vector<unsigned>{0});
+  Value *atom_ptr;
+  atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
+  atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
+  store(old, atom_ptr);
+  br(tid_0_done_bb);
+  builder_->SetInsertPoint(tid_0_done_bb);
+  tgt_->add_memfence(module, *builder_);
+  add_barrier();
+  vals_[cas][{}] = load(atom_ptr);
+  add_barrier();
+}
+#else
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
@@ -1013,12 +1080,66 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
vals_[cas][{}] = load(atom_ptr);
add_barrier();
}
+#endif // defined(USE_ROCM)
/**
* \brief Code Generation for `atomic_rmw`
*/
+#if defined(USE_ROCM)
+void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
+  ir::value* ptr = atom->get_operand(0);
+  if (atom->get_op() == ir::atomic_rmw_op_t::Add ||
+      atom->get_op() == ir::atomic_rmw_op_t::Max ||
+      atom->get_op() == ir::atomic_rmw_op_t::Min ||
+      atom->get_op() == ir::atomic_rmw_op_t::UMax ||
+      atom->get_op() == ir::atomic_rmw_op_t::UMin ||
+      atom->get_op() == ir::atomic_rmw_op_t::Xchg) {
+    BasicBlock *current = builder_->GetInsertBlock();
+    Module *module = current->getModule();
+    Value *rmw_ptr = vals_[atom->get_operand(0)][{}];
+    Value *rmw_val = vals_[atom->get_operand(1)][{}];
+    Value *tid = tgt_->get_local_id(module, *builder_, 0);
+    Value *pred = icmp_eq(tid, i32(0));
+    BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
+    BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+    tgt_->add_memfence(module, *builder_);
+    add_barrier();
+    cond_br(pred, tid_0_bb, tid_0_done_bb);
+    builder_->SetInsertPoint(tid_0_bb);
+    AtomicRMWInst::BinOp binop;
+    switch (atom->get_op()) {
+    case ir::atomic_rmw_op_t::Add:
+      binop = AtomicRMWInst::Add;
+      break;
+    case ir::atomic_rmw_op_t::Max:
+      binop = AtomicRMWInst::Max;
+      break;
+    case ir::atomic_rmw_op_t::Min:
+      binop = AtomicRMWInst::Min;
+      break;
+    case ir::atomic_rmw_op_t::UMax:
+      binop = AtomicRMWInst::UMax;
+      break;
+    case ir::atomic_rmw_op_t::UMin:
+      binop = AtomicRMWInst::UMin;
+      break;
+    case ir::atomic_rmw_op_t::Xchg:
+      binop = AtomicRMWInst::Xchg;
+      break;
+    default:
+      throw std::runtime_error("Not supported!");
+    }
+    atomic_rmw(binop, rmw_ptr, rmw_val, MaybeAlign(), AtomicOrdering::Monotonic,
+               SyncScope::System);
+    br(tid_0_done_bb);
+    builder_->SetInsertPoint(tid_0_done_bb);
+    tgt_->add_memfence(module, *builder_);
+    return;
+  }
+  throw std::runtime_error("Not supported!");
+}
+#else // defined(USE_ROCM)
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
ir::value *ptr = atom->get_operand(0);
ir::value* val = atom->get_operand(1);
ir::value* msk = atom->get_operand(2);
@@ -1100,6 +1221,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
}
}
}
+#endif // defined(USE_ROCM)
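Both ROCm lowerings above share one shape: thread 0 performs the atomic while the rest of the workgroup waits at a barrier, and (for the CAS) the old value is broadcast back through shared memory. A CPU-side analogue of that pattern, as a sketch (plain C++20; std::barrier stands in for add_barrier(), a local int stands in for LDS):

#include <atomic>
#include <barrier>
#include <thread>
#include <vector>

int main() {
  std::atomic<int> global{10};
  int shared_slot = 0;                     // stands in for shared memory / LDS
  const int nthreads = 4;
  std::barrier sync(nthreads);
  std::vector<std::thread> pool;
  for (int tid = 0; tid < nthreads; ++tid)
    pool.emplace_back([&, tid] {
      if (tid == 0)
        shared_slot = global.fetch_add(5); // only "lane 0" touches memory
      sync.arrive_and_wait();              // like add_barrier()
      int old = shared_slot;               // every thread sees the same old value
      (void)old;
    });
  for (auto &t : pool) t.join();
}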
/**
* \brief Code Generation for `mma.884` (V100)

View File

@@ -41,7 +41,8 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
-  throw std::runtime_error("not implemented on AMD");
+  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_waitcnt);
+  return builder.CreateCall(barrier, {builder.getInt32(0)});
}
@@ -56,7 +57,50 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
-  throw std::runtime_error("not implemented on AMD");
+  Function &F = *builder.GetInsertBlock()->getParent();
+  Module *Mod = F.getParent();
+  // We are indexing into this struct, and want to extract the grid_size_*
+  // fields.
+  //
+  // typedef struct hsa_kernel_dispatch_packet_s {
+  //   uint16_t header;
+  //   uint16_t setup;
+  //   uint16_t workgroup_size_x;
+  //   uint16_t workgroup_size_y;
+  //   uint16_t workgroup_size_z;
+  //   uint16_t reserved0;
+  //   uint32_t grid_size_x;
+  //   uint32_t grid_size_y;
+  //   uint32_t grid_size_z;
+  //   ...
+  // } hsa_kernel_dispatch_packet_t;
+  //
+  Function *DispatchPtrFn =
+      Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_dispatch_ptr);
+  CallInst *DispatchPtr = builder.CreateCall(DispatchPtrFn, {});
+  DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
+  DispatchPtr->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+  F.removeFnAttr("amdgpu-no-dispatch-ptr");
+  // Size of the dispatch packet struct.
+  DispatchPtr->addDereferenceableAttr(AttributeList::ReturnIndex, 64);
+  Type *I32Ty = Type::getInt32Ty(Mod->getContext());
+  // TODO: include AMDGPUAS:: declarations.
+  Value *CastDispatchPtr = builder.CreateBitCast(
+      DispatchPtr, PointerType::get(I32Ty, 4 /*AMDGPUAS::CONSTANT_ADDRESS*/));
+  // grid_size_x sits 3 i32 words into the packet, hence ax + 3
+  assert(ax < 3);
+  Value *GEP =
+      builder.CreateConstInBoundsGEP1_64(I32Ty, CastDispatchPtr, ax + 3);
+  LoadInst *Load = builder.CreateAlignedLoad(I32Ty, GEP, Align(4));
+  MDNode *MD = MDNode::get(Mod->getContext(), None);
+  Load->setMetadata(LLVMContext::MD_invariant_load, MD);
+  return Load;
}
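For reference, the `ax + 3` GEP above lands on grid_size_{x,y,z} because those fields sit three 32-bit words into the dispatch packet. A host-side sketch pinning that down (field names from the HSA spec; only the leading fields shown):

#include <cstddef>
#include <cstdint>

// Prefix of hsa_kernel_dispatch_packet_t, per the comment in get_num_blocks().
struct dispatch_packet_prefix {
  uint16_t header, setup;
  uint16_t workgroup_size_x, workgroup_size_y, workgroup_size_z;
  uint16_t reserved0;
  uint32_t grid_size_x, grid_size_y, grid_size_z;
};
// grid_size_x is at byte offset 12, i.e. i32 index 3 -- hence `ax + 3`.
static_assert(offsetof(dispatch_packet_prefix, grid_size_x) == 3 * sizeof(uint32_t),
              "grid_size_x expected 3 i32 words into the packet");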
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {