[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)
@@ -81,12 +81,13 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
//}


// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
@@ -133,7 +134,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)


/**
 * \brief Convert Triton-IR Type to LLVM-IR Type
 */
@@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) {
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
@@ -457,19 +457,25 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
}

Value* generator::bf16_to_fp32(Value *in0){
Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2));
ret = insert_elt(ret, in0, (uint64_t)1);
ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0);
return bit_cast(ret, builder_->getFloatTy());
if (tgt_->as_nvidia()->sm() >= 80) {
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
"cvt.rn.f32.bf16 $0, $1;", "=r,h", false);
return call(ptx, {in0});
} else {
Value *ret = UndefValue::get(vec_ty(i16_ty, 2));
ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1);
ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0);
return bit_cast(ret, f32_ty);
}
}

Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0});
}
return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1);
return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1);
}

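Editorial note (not part of the commit): the fallback path above works because bfloat16 is exactly the upper 16 bits of an IEEE-754 float32, so widening is done by placing the bf16 pattern in the high half of a 32-bit word and zeroing the low half, and narrowing by keeping only the high half (truncation, whereas the sm_80 cvt.rn path rounds to nearest). A minimal standalone C++ sketch of the same bit manipulation:

#include <cstdint>
#include <cstring>

// Widen: put the bf16 bit pattern into the top 16 bits of a float32.
float bf16_bits_to_fp32(uint16_t bf16_bits) {
  uint32_t u = static_cast<uint32_t>(bf16_bits) << 16; // low 16 bits stay zero
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Narrow: keep the top 16 bits (truncation; cvt.rn.bf16.f32 would round instead).
uint16_t fp32_to_bf16_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}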
/**
@@ -514,12 +520,16 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
// FP32 -> BF16
if(op_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
// for(size_t i = 0; i < x_idxs.size(); i++)
// vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
for (indices_t idx: idxs_.at(x)) {
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
}
// BF16 -> FP32
if(ret_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
return;
}

@@ -678,6 +688,7 @@ void generator::visit_load_inst(ir::load_inst* x){
// ---
std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
// ret_ty->print(llvm::outs());
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
@@ -747,15 +758,19 @@ void generator::visit_store_inst(ir::store_inst * x){
}
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
for(size_t i = 0; i < idxs.size(); i += vec){
auto idx = idxs[i];
// pointer
Value *ptr = vals_[ptr_op][idx];
ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1));
// vectorize
Type *v_ty = vec_ty(ty, vec);
ptr = bit_cast(ptr, v_ty->getPointerTo(1));
// value
Value* val = UndefValue::get(vec_ty(ty, vec));
Value* val = UndefValue::get(v_ty);
for(size_t ii = 0; ii < vec; ii++)
val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii);
val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
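Editorial note on the bf16-store workaround above (an illustration, not part of the commit): because the LLVM 11 NVPTX backend cannot select a bf16 store, values are bit_cast to half purely for the memory operation; both types are 16 bits wide, so the stored bit pattern is unchanged. A hedged scalar analogue:

#include <cstdint>
#include <cstring>

// A bf16 value is just a 16-bit pattern; storing it through any other 16-bit
// type (uint16_t here stands in for f16) preserves the bits exactly.
void store_bf16_as_16bit(uint16_t bf16_bits, uint16_t *dst) {
  std::memcpy(dst, &bf16_bits, sizeof(uint16_t)); // width- and bit-preserving
}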
@@ -1317,6 +1332,229 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
vals_[C][idxs_[C][i]] = acc[i];
}

namespace {
class mma16816_smem_loader {
public:
mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
std::vector<unsigned> tile_shape,
std::vector<int> instr_shape, std::vector<int> mat_shape,
int per_phase, int max_phase, int dtsize, Builder *builder,
adder add, multiplier mul, geper gep)
: wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
instr_shape_(instr_shape), mat_shape_(mat_shape),
per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
add(add), mul(mul), gep(gep) {
// compute compile-time constant variables & types
c_mat_shape_ = mat_shape[order[0]];
s_mat_shape_ = mat_shape[order[1]];

c_stride_ = tile_shape[order[1]];
s_stride_ = tile_shape[order[0]];

// rule: k must be the fast-changing axis
need_trans_ = k_order_ != order_[0];
can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);

// std::cout << can_use_ldmatrix_ << std::endl;
// std::cout << need_trans_ << std::endl;

// we need more pointers at the fast-changing axis,
if (can_use_ldmatrix_)
num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
else // warning: this only works for tf32 & need transpose
num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
num_ptr_ = std::max<int>(num_ptr_, 2);


// load_v4 stride (in num of mats)
int load_stride_in_mat[2];
load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2
load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]);
p_load_stride_in_mat_ = load_stride_in_mat[order[0]];
// stride in mat, used by load_v4
s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]);
}

std::vector<Value*> compute_offs(Value *warp_off, Value *lane) {
// TODO: this needs to be moved to constructor (and extracted to arr_order)
mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_;
warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1];
// start matrix logic offset (rename it as base_mat_off?)
Value *mat_off[2] = {nullptr, nullptr};

if (can_use_ldmatrix_) {
// c: lane idx inside a group (a group is a collection of 8 contiguous threads)
// s: group idx (0,1,2,3) inside a warp
Value *c = urem(lane, i32(8));
Value *s = udiv(lane, i32(8));
// We can decompose s => s_0, s_1...
Value *s0 = urem(s, i32(2));
Value *s1 = udiv(s, i32(2));

// We use different orders for a & b for better performance.
Value *k_mat_arr = (k_order_ == 1) ? s1 : s0;
Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
mat_off[k_order_] = k_mat_arr;
// physical offset (before swizzling)
Value *c_mat_off = mat_off[order_[0]];
Value *s_mat_off = mat_off[order_[1]];
// offset inside a matrix
Value *s_off_in_mat = c;

std::vector<Value*> offs(num_ptr_);
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
// pre-compute strided offset
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
for (int i=0; i < num_ptr_; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_));
c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle
offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_)));
}
return offs;
} else if (dtsize_ == 4 && need_trans_) {
// load tf32 matrices with lds32
Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
Value *s_off_in_mat = urem(lane, i32(4)); //

Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
std::vector<Value*> offs(num_ptr_);
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
if (k_mat_arr_int > 0) // we don't need pointers for k
continue;
Value *k_mat_arr = i32(k_mat_arr_int);
Value *nk_mat_arr = i32(nk_mat_arr_int);
// physical offset (before swizzling)
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
Value *s_mat_off = k_mat_arr; // always 0?
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
// FIXME: (k_order_ == 1?) is really dirty hack
for (int i = 0; i < num_ptr_/2; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
c_mat_off_i = xor_(c_mat_off_i, phase);
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
// TODO: move this out of the loop
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_)));
}
}
return offs;
// throw std::runtime_error("not implemented");
} else
throw std::runtime_error("invalid smem load config");
}
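Editorial sketch of the XOR swizzle idea used in compute_offs above (illustrative only; the per_phase/max_phase values below are assumptions, the real ones come from the swizzle analysis pass): rows that would otherwise hit the same shared-memory banks get different column permutations by XOR-ing the column (matrix) index with a row-derived phase, which keeps the ldmatrix/lds32 accesses conflict-free.

#include <cstdio>

// Hypothetical standalone model of the swizzled column computation.
// 'row' plays the role of s_off_in_mat, 'col' the role of c_mat_off.
int swizzled_col(int row, int col, int per_phase, int max_phase) {
  int phase = (row / per_phase) % max_phase; // urem(udiv(row, per_phase), max_phase)
  return col ^ phase;                        // xor_(c_mat_off_i, phase)
}

int main() {
  const int per_phase = 1, max_phase = 8;    // assumed values for illustration
  for (int row = 0; row < 8; ++row) {
    for (int col = 0; col < 8; ++col)
      std::printf("%d ", swizzled_col(row, col, per_phase, max_phase));
    std::printf("\n"); // each row is a distinct permutation of 0..7
  }
  return 0;
}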

std::tuple<Value*, Value*, Value*, Value*>
load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
int mat_idx[2] = {mat0, mat1};
int k = mat_idx[k_order_];

int ptr_idx = -1;
if (can_use_ldmatrix_)
ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
else // tf32 & trans
ptr_idx = mat_idx[order_[0]];

auto get_ptr = [&](int idx) -> Value* {
Value *ptr = nullptr;
if (k == 0 && is_prefetch) {
if (inc == 0)
ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty);
else
ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty);
} else
ptr = ptrs.at(idx);
return ptr;
};
Value *ptr = get_ptr(ptr_idx);

Value *res_v4 = nullptr;
if (can_use_ldmatrix_) {
std::string trans = need_trans_ ? ".trans" : "";
// the offset (in byte) on the strided axis is a constant
int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
"ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
"{$0, $1, $2, $3}, "
"[$4 + " + std::to_string(s_offset) + "];",
"=r,=r,=r,=r,r", true);
assert(ptr);
res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
if (k == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
return {extract_val(res_v4, std::vector<unsigned>{0}),
extract_val(res_v4, std::vector<unsigned>{1}),
extract_val(res_v4, std::vector<unsigned>{2}),
extract_val(res_v4, std::vector<unsigned>{3})};
} else {
// assert(false && "should not be here");
assert(dtsize_ == 4 && need_trans_);
Value *ptr2 = get_ptr(ptr_idx+1);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
Value *elem0, *elem1, *elem2, *elem3;
if (k_order_ == 1) {
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem1 = load(gep(ptr2, i32(s_offset_elem)));
elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
} else { // for b (k first)
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem2 = load(gep(ptr2, i32(s_offset_elem)));
elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
}
if (k == 0 && inc == 1 && is_prefetch) {
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
}
return {elem0, elem1, elem2, elem3};
}
}

int get_num_ptr() const { return num_ptr_; }

private:
int wpt_;
std::vector<int> order_;
int k_order_;
std::vector<unsigned> tile_shape_;
std::vector<int> instr_shape_;
std::vector<int> mat_shape_;
int per_phase_, max_phase_;
int dtsize_;

// generated
int c_mat_shape_, s_mat_shape_;
int c_stride_, s_stride_;
// p_: on the pointer axis
int p_load_stride_in_mat_;
int s_mat_stride_;
// stride when moving to next not-k mat
int warp_off_stride_;
int mat_arr_stride_; // matrix arrangement (inside a load) stride
bool need_trans_, can_use_ldmatrix_;
int num_ptr_;

Builder *builder_;
adder add;
multiplier mul;
geper gep;
};
}

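Background on why load_x4 returns four values and why the pack types below are structs of four 2x16-bit elements (editorial note, hedged): a single ldmatrix.sync.aligned.m8n8.x4 instruction loads four 8x8 16-bit matrices from shared memory, and each of the 32 threads in the warp receives one 32-bit register per matrix, i.e. four registers each packing two 16-bit elements. A small host-side C++ sketch of unpacking one such 32-bit fragment:

#include <cstdint>
#include <utility>

// Split one 32-bit ldmatrix fragment into its two packed 16-bit elements
// (low half first, then high half). Purely illustrative, not part of the commit.
std::pair<uint16_t, uint16_t> unpack_b16x2(uint32_t frag) {
  return {static_cast<uint16_t>(frag & 0xffffu),
          static_cast<uint16_t>(frag >> 16)};
}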
/**
 * \brief Code Generation for `mma.16816` (A100)
 */
@@ -1338,35 +1576,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1;
std::string a_trans = is_a_row ? "" : ".trans";
std::string b_trans = is_b_row ? ".trans" : "";
int stride_a_m = is_a_row ? shape_a[1] : 1;
int stride_a_k = is_a_row ? 1 : shape_a[0];
int stride_b_n = is_b_row ? 1 : shape_b[0];
int stride_b_k = is_b_row ? shape_b[1] : 1;
int stride_a0 = is_a_row ? stride_a_k : stride_a_m;
int stride_a1 = is_a_row ? stride_a_m : stride_a_k;
int stride_b0 = is_b_row ? stride_b_n : stride_b_k;
int stride_b1 = is_b_row ? stride_b_k : stride_b_n;
int lda = is_a_row ? stride_a_m : stride_a_k;
int ldb = is_b_row ? stride_b_k : stride_b_n;
int per_phase_a = swizzle_->get_per_phase(layout_a);
int max_phase_a = swizzle_->get_max_phase(layout_a);
int per_phase_b = swizzle_->get_per_phase(layout_b);
int max_phase_b = swizzle_->get_max_phase(layout_b);
int num_ptr_a = 8;
int num_ptr_b = 8;
int vec_a = 8;
int vec_b = 8;

std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
const int mma_instr_m = mma_instr_shape[0];
const int mma_instr_n = mma_instr_shape[1];
const int mma_instr_k = mma_instr_shape[2];

std::vector<int> mat_shape = layout->get_mma_mat_shape();
const int mat_shape_m = mat_shape[0];
const int mat_shape_n = mat_shape[1];
const int mat_shape_k = mat_shape[2];

const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);

const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);

Type *fp32_ty = f32_ty;
Type *fp16x2_ty = vec_ty(f16_ty, 2);
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{ptr_ty(f16_ty, 3)}, false);

FunctionType *ldmatrix_ty = nullptr;
FunctionType *mma_ty = nullptr;
Type *phi_ty = nullptr;
Type *smem_ptr_ty = nullptr;

ir::type *A_ir_ty = A->get_type()->get_scalar_ty();
ir::type *B_ir_ty = B->get_type()->get_scalar_ty();
if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
} else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
// FIXME: We should use bf16 here.
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
// mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
// smem_ptr_ty = ptr_ty(bf16_ty, 3);
// ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
// phi_ty = bf16x2_ty;
} else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(fp32_ty, 3);
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp32_ty;
} else
throw std::runtime_error("mma16816 data type not supported");

// left-hand-side values
std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
std::map<std::pair<unsigned, unsigned>, Value*> ha;
std::map<std::pair<unsigned, unsigned>, Value*> hb;

BasicBlock* CurrBB = builder_->GetInsertBlock();
@@ -1377,79 +1645,66 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value *lane = urem(thread, i32(32));
Value *warp = udiv(thread, i32(32));
Value *warp12 = udiv(warp, i32(layout->wpt(0)));
Value *warp0 = urem(warp, i32(layout->wpt(0)));
Value *warp1 = urem(warp12, i32(layout->wpt(1)));
Value *warp_mn = udiv(warp, i32(layout->wpt(0)));
Value *warp_m = urem(warp, i32(layout->wpt(0)));
Value *warp_n = urem(warp_mn, i32(layout->wpt(1)));
std::vector<Value *>& fc = fcs.begin()->second;

Value *tidr8 = urem(lane, i32(8));
Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a));
Value* off_a0 = mul(tidr8, i32(lda));
Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8));
Value *off_ak = mul(udiv(lane, i32(16)), i32(8));
off_am = urem(off_am, i32(shape_a[0]));
off_ak = urem(off_ak, i32(shape_a[1]));
off_a0 = add(off_a0, is_a_row ? off_ak : off_am);
Value* off_a1 = is_a_row ? off_am : off_ak;
std::vector<Value*> off_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++){
Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0))));
off_a0i = exact_udiv(off_a0i, i32(vec_a));
off_a0i = xor_(off_a0i, phase_a);
off_a0i = mul(off_a0i, i32(vec_a));
off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
}
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;

Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b));
Value* off_b0 = mul(tidr8, i32(ldb));
Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8));
Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8));
off_bn = urem(off_bn, i32(shape_b[1]));
off_bk = urem(off_bk, i32(shape_b[0]));
off_b0 = add(off_b0, is_b_row ? off_bn : off_bk);
Value* off_b1 = is_b_row ? off_bk : off_bn;
std::vector<Value*> off_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++){
Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16)));
off_b0i = exact_udiv(off_b0i, i32(vec_b));
off_b0i = xor_(off_b0i, phase_b);
off_b0i = mul(off_b0i, i32(vec_b));
off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
}
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
// v (s0_0(0), s1_0(2), | *num_rep_k
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
// -----------
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
int num_ptr_a = a_loader.get_num_ptr();

// | -> n (col-major)
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
// -----------
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
int num_ptr_b = b_loader.get_num_ptr();

builder_->SetInsertPoint(CurrBB);
// A pointer
std::vector<Value*> ptrs_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = gep(shmems_[A], {off_a[i]});
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
// B pointer
std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = gep(shmems_[B], {off_b[i]});
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);

FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{$0, $1, $2, $3}, "
"{$4, $5, $6, $7}, "
"{$8, $9}, "
"{$10, $11, $12, $13};",
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
" {$0, $1, $2, $3},"
" {$4, $5, $6, $7},"
" {$8, $9},"
" {$10, $11, $12, $13};",
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);

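Editorial note: the instruction mnemonic now comes from layout->get_ptx_instr(), whose implementation is not part of this diff. As a hedged reference only, the Ampere mma shapes this path targets are of the following form; the f16 string is the one the removed hard-coded call used, while the bf16 and tf32 names are the corresponding PTX ISA 7.x instructions and are listed here as an assumption about what the layout helper would return:

// Illustrative only; not necessarily the literal output of layout->get_ptx_instr().
const char *mma_f16  = "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32";
const char *mma_bf16 = "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32";
const char *mma_tf32 = "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32";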
unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0);
unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1);

// create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
unsigned cols_per_thread = num_rep_0 * 2;
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
unsigned cols_per_thread = num_rep_m * 2;
std::vector<size_t> idx = {
(m*2 + 0) + (n*2 + 0)*cols_per_thread,
(m*2 + 0) + (n*2 + 1)*cols_per_thread,
(m*2 + 1) + (n*2 + 0)*cols_per_thread,
(m*2 + 1) + (n*2 + 1)*cols_per_thread
(m + 0) + (n*2 + 0)*cols_per_thread,
(m + 0) + (n*2 + 1)*cols_per_thread,
(m + 1) + (n*2 + 0)*cols_per_thread,
(m + 1) + (n*2 + 1)*cols_per_thread
};
Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
hb[{n, K}], hb[{n, K+8}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
Value *nc = call(mma_ty, mma_fn,
{ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
hb[{n, k}], hb[{n, k+1}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
@@ -1459,131 +1714,83 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);

auto register_lds =
[&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};

auto register_lds2 =
[&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
} else
vals[{m, K}] = val;
vals[{n, k}] = val;
};

auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
Value* ptra;
if(K == 0 && is_prefetch){
if(inc == 0)
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
else
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
else
ptra = ptrs_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
"=r,=r,=r,=r,r", true);
Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};

auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
int offidx = (is_b_row ? n : K/16) % num_ptr_b;
Value* ptrb;
if(K == 0 && is_prefetch){
if(inc == 0)
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
else
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
}
else
ptrb = ptrs_b[offidx];
int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
"=r,=r,=r,=r,r", true);
Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
register_lds2(hb, n, K, inc, hb0, is_prefetch);
register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
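To make the new indexing convention concrete (editorial sketch with assumed repetition counts, not part of the commit): the refactored load_a/load_b address operand fragments at 8x8-matrix granularity, so a single load_a(2*m, 2*k, ...) call fills the four ha keys covering that instruction's A tile, which call_mma(2*m, n, 2*k) then consumes.

#include <cstdio>

int main() {
  const int num_rep_m = 2, num_rep_k = 2; // assumed values for illustration
  for (int m = 0; m < num_rep_m; ++m)
    for (int k = 0; k < num_rep_k; ++k)
      // load_a(2*m, 2*k, ...) registers these four (row-mat, k-mat) keys in ha:
      std::printf("(%d,%d) (%d,%d) (%d,%d) (%d,%d)\n",
                  2*m, 2*k, 2*m + 1, 2*k, 2*m, 2*k + 1, 2*m + 1, 2*k + 1);
  return 0;
}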

if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
for(unsigned m = 0; m < num_rep_0; m++){
ha[{m, 0}].first = phi(fp16x2_ty, 2);
ha[{m, 0}].second = phi(fp16x2_ty, 2);
ha[{m, 8}].first = phi(fp16x2_ty, 2);
ha[{m, 8}].second = phi(fp16x2_ty, 2);
for(unsigned m = 0; m < num_rep_m; m++){
ha[{2*m, 0}] = phi(phi_ty, 2);
ha[{2*m+1, 0}] = phi(phi_ty, 2);
ha[{2*m, 1}] = phi(phi_ty, 2);
ha[{2*m+1, 1}] = phi(phi_ty, 2);
}
for(unsigned n = 0; n < num_rep_1; n+=2){
hb[{n, 0}] = phi(fp16x2_ty, 2);
hb[{n+1, 0}] = phi(fp16x2_ty, 2);
hb[{n, 8}] = phi(fp16x2_ty, 2);
hb[{n+1, 8}] = phi(fp16x2_ty, 2);
for(unsigned n = 0; n < num_rep_n; n+=2){
hb[{n, 0}] = phi(phi_ty, 2);
hb[{n+1, 0}] = phi(phi_ty, 2);
hb[{n, 1}] = phi(phi_ty, 2);
hb[{n+1, 1}] = phi(phi_ty, 2);
}
// insert prefetched lds at the end of loop header
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, 0, 0, true);
for(unsigned n = 0; n < num_rep_1; n+=2)
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 0, 0, true);
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 0, 0, true);
// update accumulators
builder_->SetInsertPoint(CurrBB);
for(unsigned K = 0; K < NK; K += 16){
int NEXTK = (K + 16) % NK;
for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2
int next_k = (k + 1) % num_rep_k;
// prefetch A
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, NEXTK, 1, true);
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*next_k, 1, true);
// prefetch B
for(unsigned n = 0; n < num_rep_1; n+=2)
load_b(n, NEXTK, 1, true);
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*next_k, 1, true);
// tensor core ops
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
call_mma(m, n, K);
for(unsigned m = 0; m < num_rep_m; m++)
for(unsigned n = 0; n < num_rep_n; n++){
call_mma(2*m, n, 2*k);
}
}
}
else{
for(unsigned K = 0; K < NK; K += 16)
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
if(ha.find({m, K}) == ha.end())
load_a(m, K, 0, false);
if(hb.find({n, K})==hb.end())
load_b(n, K, 0, false);
call_mma(m, n, K);
}
for (unsigned k = 0; k < num_rep_k; k++) {
for (unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*k, 0, /*is_prefetch*/false);
for (unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*k, 0, /*is_prefetch*/false);
for (unsigned m = 0; m < num_rep_m; m++)
for (unsigned n = 0; n < num_rep_n; n++)
call_mma(2*m, n, 2*k);
}
}
// write back
unsigned i = 0;
@@ -1714,7 +1921,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK);
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
}

@@ -1752,13 +1959,13 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
acc = bit_cast(acc, vec_ty(f32_ty, 2));
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
return builder_->CreateBitCast(ret, ty);
return bit_cast(ret, ty);
}

/**
@@ -1936,6 +2143,10 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
// pointer to temporary shared memory
Type *ty = cvt(out->get_type()->get_scalar_ty());

if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;

// Orders
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
@@ -1976,7 +2187,7 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
// stash value to shared mem
indices_t idxs = {in_ax[0][i*max_ii + ii],
in_ax[1][j*max_jj + jj]};
store(vals_[in][idxs], ptr);
store(bit_cast(vals_[in][idxs], ty), ptr);
}
add_barrier();
max_ii = out_ax[0].size()/n_reps[0];