|
|
|
@@ -556,14 +556,14 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
|
|
|
|
// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
|
|
|
|
|
// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
|
|
|
|
|
Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
|
|
|
|
|
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false);
|
|
|
|
|
std::vector<llvm::Type*> tys = {builder_->getFloatTy()};
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
|
|
|
|
|
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for_each(x, [&](indices_t idx){
|
|
|
|
|
Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e);
|
|
|
|
|
set_value(x, idx, builder_->CreateCall(ex2, {ex2arg}));
|
|
|
|
|
set_value(x, idx, builder_->CreateCall(ex2, std::vector<llvm::Value*>{ex2arg}));
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -584,7 +584,7 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
|
|
|
|
Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val,
|
|
|
|
|
AtomicOrdering::Monotonic,
|
|
|
|
|
AtomicOrdering::Monotonic);
|
|
|
|
|
old = builder_->CreateExtractValue(old, {0});
|
|
|
|
|
old = builder_->CreateExtractValue(old, std::vector<unsigned>{0});
|
|
|
|
|
Value *atom_ptr;
|
|
|
|
|
atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))));
|
|
|
|
|
atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3));
|
|
|
|
@@ -640,8 +640,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
|
|
|
|
|
|
|
|
|
Type *fp32_ty = builder_->getFloatTy();
|
|
|
|
|
Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2);
|
|
|
|
|
Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
|
|
|
|
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
|
|
|
|
Type *fp32_pack8_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
|
|
|
|
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
|
|
|
|
@@ -720,15 +720,15 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
|
|
|
|
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc,
|
|
|
|
|
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc
|
|
|
|
|
};
|
|
|
|
|
Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
|
|
|
|
|
fc[idx[0]] = builder_->CreateExtractValue(nc, {0});
|
|
|
|
|
fc[idx[1]] = builder_->CreateExtractValue(nc, {1});
|
|
|
|
|
fc[idx[2]] = builder_->CreateExtractValue(nc, {2});
|
|
|
|
|
fc[idx[3]] = builder_->CreateExtractValue(nc, {3});
|
|
|
|
|
fc[idx[4]] = builder_->CreateExtractValue(nc, {4});
|
|
|
|
|
fc[idx[5]] = builder_->CreateExtractValue(nc, {5});
|
|
|
|
|
fc[idx[6]] = builder_->CreateExtractValue(nc, {6});
|
|
|
|
|
fc[idx[7]] = builder_->CreateExtractValue(nc, {7});
|
|
|
|
|
Value *nc = builder_->CreateCall(mma_fn, std::vector<llvm::Value*>{ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
|
|
|
|
|
fc[idx[0]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{0});
|
|
|
|
|
fc[idx[1]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{1});
|
|
|
|
|
fc[idx[2]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{2});
|
|
|
|
|
fc[idx[3]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{3});
|
|
|
|
|
fc[idx[4]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{4});
|
|
|
|
|
fc[idx[5]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{5});
|
|
|
|
|
fc[idx[6]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{6});
|
|
|
|
|
fc[idx[7]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{7});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@@ -770,7 +770,7 @@ void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_ti
|
|
|
|
|
a = builder_->CreateFPCast(a, c_ty);
|
|
|
|
|
if(b->getType() != c_ty)
|
|
|
|
|
b = builder_->CreateFPCast(b, c_ty);
|
|
|
|
|
res = builder_->CreateCall(f_mul_add, {a, b, res});
|
|
|
|
|
res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
|
|
|
|
|
}
|
|
|
|
|
set_value(dot, idx, res);
|
|
|
|
|
});
|
|
|
|
@@ -790,7 +790,7 @@ void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distrib
|
|
|
|
|
a = builder_->CreateFPCast(a, c_ty);
|
|
|
|
|
if(b->getType() != c_ty)
|
|
|
|
|
b = builder_->CreateFPCast(b, c_ty);
|
|
|
|
|
res = builder_->CreateCall(f_mul_add, {a, b, res});
|
|
|
|
|
res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
|
|
|
|
|
set_value(dot, idx, res);
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
@@ -805,7 +805,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
|
|
|
|
|
|
|
|
|
distributed_tile *TD = (distributed_tile*)tmap_.at(D);
|
|
|
|
|
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_);
|
|
|
|
|
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
|
|
|
|
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
|
|
|
|
|
auto A_shapes = A->get_type()->get_tile_shapes();
|
|
|
|
|
size_t red_axis = 1;
|
|
|
|
|
unsigned NK = A_shapes[red_axis];
|
|
|
|
@@ -835,8 +835,8 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) {
|
|
|
|
|
for_each(sqt, [&](indices_t idx){
|
|
|
|
|
Value *val = get_value(sqt->get_operand(0), idx);
|
|
|
|
|
Module* module = builder_->GetInsertBlock()->getModule();
|
|
|
|
|
Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, {val->getType()});
|
|
|
|
|
Value *ret = builder_->CreateCall(sqrt, {val});
|
|
|
|
|
Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, std::vector<llvm::Type*>{val->getType()});
|
|
|
|
|
Value *ret = builder_->CreateCall(sqrt, std::vector<llvm::Value*>{val});
|
|
|
|
|
set_value(sqt, idx, ret);
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
@@ -849,7 +849,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|
|
|
|
unsigned axis = x->get_axis();
|
|
|
|
|
|
|
|
|
|
Type *fp32_ty = builder_->getFloatTy();
|
|
|
|
|
FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false);
|
|
|
|
|
FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty}, false);
|
|
|
|
|
InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false);
|
|
|
|
|
InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false);
|
|
|
|
|
|
|
|
|
@@ -871,8 +871,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|
|
|
|
}
|
|
|
|
|
case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
|
|
|
|
|
case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
|
|
|
|
|
case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y});
|
|
|
|
|
case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y});
|
|
|
|
|
case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, std::vector<llvm::Value*>{x, y});
|
|
|
|
|
case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, std::vector<llvm::Value*>{x, y});
|
|
|
|
|
default: assert(false); return nullptr;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@@ -910,11 +910,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|
|
|
|
thread_acc = accumulate(thread_acc, current);
|
|
|
|
|
});
|
|
|
|
|
// reduce within wrap
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false);
|
|
|
|
|
FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), std::vector<llvm::Type*>{thread_acc->getType(), builder_->getInt32Ty()}, false);
|
|
|
|
|
InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
|
|
|
|
|
Value *warp_acc = thread_acc;
|
|
|
|
|
for(int i = 16; i > 0; i >>= 1)
|
|
|
|
|
warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)}));
|
|
|
|
|
warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{warp_acc, builder_->getInt32(i)}));
|
|
|
|
|
// shared memory pointer
|
|
|
|
|
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
|
|
|
|
Type *res_ty = arg_tile->get_ty();
|
|
|
|
@@ -935,7 +935,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|
|
|
|
builder_->SetInsertPoint(bb_final_acc);
|
|
|
|
|
Value* final_val = builder_->CreateLoad(load_ptr);
|
|
|
|
|
for(int i = (num_warps_+1)/2; i > 0; i >>= 1)
|
|
|
|
|
final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)}));
|
|
|
|
|
final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{final_val, builder_->getInt32(i)}));
|
|
|
|
|
builder_->CreateStore(final_val, load_ptr);
|
|
|
|
|
builder_->CreateBr(bb_final_acc_done);
|
|
|
|
|
// // store first warp done
|
|
|
|
|