[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)

This commit is contained in:
daadaada
2022-01-12 02:20:31 +08:00
committed by GitHub
parent efdabe6073
commit 94a2e10fe5
17 changed files with 717 additions and 263 deletions

View File

@@ -0,0 +1,78 @@
#pragma once
#include <numeric>
#include <sstream>
#include <iomanip>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
namespace triton::codegen {
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
#define br(...) builder_->CreateBr(__VA_ARGS__)
#define call(...) builder_->CreateCall(__VA_ARGS__)
#define cast(...) builder_->CreateCast(__VA_ARGS__)
#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...) builder_->CreateLoad(__VA_ARGS__)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
#define ret(...) builder_->CreateRet(__VA_ARGS__)
#define select(...) builder_->CreateSelect(__VA_ARGS__)
#define store(...) builder_->CreateStore(__VA_ARGS__)
#define sub(...) builder_->CreateSub(__VA_ARGS__)
#define shl(...) builder_->CreateShl(__VA_ARGS__)
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
#define urem(...) builder_->CreateURem(__VA_ARGS__)
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
} // namespace triton::codegen

View File

@@ -81,12 +81,13 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
//}
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
@@ -133,7 +134,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
/**
* \brief Convert Triton-IR Type to LLVM-IR Type
*/
@@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) {
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
@@ -457,19 +457,25 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
}
Value* generator::bf16_to_fp32(Value *in0){
Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2));
ret = insert_elt(ret, in0, (uint64_t)1);
ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0);
return bit_cast(ret, builder_->getFloatTy());
if (tgt_->as_nvidia()->sm() >= 80) {
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
"cvt.rn.f32.bf16 $0, $1;", "=r,h", false);
return call(ptx, {in0});
} else {
Value *ret = UndefValue::get(vec_ty(i16_ty, 2));
ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1);
ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0);
return bit_cast(ret, f32_ty);
}
}
Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0});
}
return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1);
return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1);
}
/**
@@ -514,12 +520,16 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
// FP32 -> BF16
if(op_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
// for(size_t i = 0; i < x_idxs.size(); i++)
// vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
for (indices_t idx: idxs_.at(x)) {
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
}
// BF16 -> FP32
if(ret_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
return;
}
@@ -678,6 +688,7 @@ void generator::visit_load_inst(ir::load_inst* x){
// ---
std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
// ret_ty->print(llvm::outs());
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
@@ -747,15 +758,19 @@ void generator::visit_store_inst(ir::store_inst * x){
}
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
for(size_t i = 0; i < idxs.size(); i += vec){
auto idx = idxs[i];
// pointer
Value *ptr = vals_[ptr_op][idx];
ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1));
// vectorize
Type *v_ty = vec_ty(ty, vec);
ptr = bit_cast(ptr, v_ty->getPointerTo(1));
// value
Value* val = UndefValue::get(vec_ty(ty, vec));
Value* val = UndefValue::get(v_ty);
for(size_t ii = 0; ii < vec; ii++)
val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii);
val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
@@ -1317,6 +1332,229 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
vals_[C][idxs_[C][i]] = acc[i];
}
namespace {
class mma16816_smem_loader {
public:
mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
std::vector<unsigned> tile_shape,
std::vector<int> instr_shape, std::vector<int> mat_shape,
int per_phase, int max_phase, int dtsize, Builder *builder,
adder add, multiplier mul, geper gep)
: wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
instr_shape_(instr_shape), mat_shape_(mat_shape),
per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
add(add), mul(mul), gep(gep) {
// compute compile-time constant variables & types
c_mat_shape_ = mat_shape[order[0]];
s_mat_shape_ = mat_shape[order[1]];
c_stride_ = tile_shape[order[1]];
s_stride_ = tile_shape[order[0]];
// rule: k must be the fast-changing axis
need_trans_ = k_order_ != order_[0];
can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
// std::cout << can_use_ldmatrix_ << std::endl;
// std::cout << need_trans_ << std::endl;
// we need more pointers at the fast-changing axis,
if (can_use_ldmatrix_)
num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
else // warning: this only works for tf32 & need transpose
num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
num_ptr_ = std::max<int>(num_ptr_, 2);
// load_v4 stride (in num of mats)
int load_stride_in_mat[2];
load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2
load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]);
p_load_stride_in_mat_ = load_stride_in_mat[order[0]];
// stride in mat, used by load_v4
s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]);
}
std::vector<Value*> compute_offs(Value *warp_off, Value *lane) {
// TODO: this needs to be moved to constructor (and extracted to arr_order)
mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_;
warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1];
// start matrix logic offset (rename it as base_mat_off?)
Value *mat_off[2] = {nullptr, nullptr};
if (can_use_ldmatrix_) {
// c: lane idx inside a group (a group is a collection of 8 contiguous threads)
// s: group idx (0,1,2,3) inside a warp
Value *c = urem(lane, i32(8));
Value *s = udiv(lane, i32(8));
// We can decompose s => s_0, s_1...
Value *s0 = urem(s, i32(2));
Value *s1 = udiv(s, i32(2));
// We use different orders for a & b for better performance.
Value *k_mat_arr = (k_order_ == 1) ? s1 : s0;
Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
mat_off[k_order_] = k_mat_arr;
// physical offset (before swizzling)
Value *c_mat_off = mat_off[order_[0]];
Value *s_mat_off = mat_off[order_[1]];
// offset inside a matrix
Value *s_off_in_mat = c;
std::vector<Value*> offs(num_ptr_);
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
// pre-compute strided offset
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
for (int i=0; i < num_ptr_; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_));
c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle
offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_)));
}
return offs;
} else if (dtsize_ == 4 && need_trans_) {
// load tf32 matrices with lds32
Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
Value *s_off_in_mat = urem(lane, i32(4)); //
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
std::vector<Value*> offs(num_ptr_);
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
if (k_mat_arr_int > 0) // we don't need pointers for k
continue;
Value *k_mat_arr = i32(k_mat_arr_int);
Value *nk_mat_arr = i32(nk_mat_arr_int);
// physical offset (before swizzling)
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
Value *s_mat_off = k_mat_arr; // always 0?
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
// FIXME: (k_order_ == 1?) is really dirty hack
for (int i = 0; i < num_ptr_/2; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
c_mat_off_i = xor_(c_mat_off_i, phase);
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
// TODO: move this out of the loop
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_)));
}
}
return offs;
// throw std::runtime_error("not implemented");
} else
throw std::runtime_error("invalid smem load config");
}
std::tuple<Value*, Value*, Value*, Value*>
load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
int mat_idx[2] = {mat0, mat1};
int k = mat_idx[k_order_];
int ptr_idx = -1;
if (can_use_ldmatrix_)
ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
else // tf32 & trans
ptr_idx = mat_idx[order_[0]];
auto get_ptr = [&](int idx) -> Value* {
Value *ptr = nullptr;
if (k == 0 && is_prefetch) {
if (inc == 0)
ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty);
else
ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty);
} else
ptr = ptrs.at(idx);
return ptr;
};
Value *ptr = get_ptr(ptr_idx);
Value *res_v4 = nullptr;
if (can_use_ldmatrix_) {
std::string trans = need_trans_ ? ".trans" : "";
// the offset (in byte) on the strided axis is a constant
int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
"ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
"{$0, $1, $2, $3}, "
"[$4 + " + std::to_string(s_offset) + "];",
"=r,=r,=r,=r,r", true);
assert(ptr);
res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
if (k == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
return {extract_val(res_v4, std::vector<unsigned>{0}),
extract_val(res_v4, std::vector<unsigned>{1}),
extract_val(res_v4, std::vector<unsigned>{2}),
extract_val(res_v4, std::vector<unsigned>{3})};
} else {
// assert(false && "should not be here");
assert(dtsize_ == 4 && need_trans_);
Value *ptr2 = get_ptr(ptr_idx+1);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
Value *elem0, *elem1, *elem2, *elem3;
if (k_order_ == 1) {
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem1 = load(gep(ptr2, i32(s_offset_elem)));
elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
} else { // for b (k first)
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem2 = load(gep(ptr2, i32(s_offset_elem)));
elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
}
if (k == 0 && inc == 1 && is_prefetch) {
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
}
return {elem0, elem1, elem2, elem3};
}
}
int get_num_ptr() const { return num_ptr_; }
private:
int wpt_;
std::vector<int> order_;
int k_order_;
std::vector<unsigned> tile_shape_;
std::vector<int> instr_shape_;
std::vector<int> mat_shape_;
int per_phase_, max_phase_;
int dtsize_;
// generated
int c_mat_shape_, s_mat_shape_;
int c_stride_, s_stride_;
// p_: on the pointer axis
int p_load_stride_in_mat_;
int s_mat_stride_;
// stride when moving to next not-k mat
int warp_off_stride_;
int mat_arr_stride_; // matrix arrangement (inside a load) stride
bool need_trans_, can_use_ldmatrix_;
int num_ptr_;
Builder *builder_;
adder add;
multiplier mul;
geper gep;
};
}
/**
* \brief Code Generation for `mma.16816` (A100)
*/
@@ -1338,35 +1576,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1;
std::string a_trans = is_a_row ? "" : ".trans";
std::string b_trans = is_b_row ? ".trans" : "";
int stride_a_m = is_a_row ? shape_a[1] : 1;
int stride_a_k = is_a_row ? 1 : shape_a[0];
int stride_b_n = is_b_row ? 1 : shape_b[0];
int stride_b_k = is_b_row ? shape_b[1] : 1;
int stride_a0 = is_a_row ? stride_a_k : stride_a_m;
int stride_a1 = is_a_row ? stride_a_m : stride_a_k;
int stride_b0 = is_b_row ? stride_b_n : stride_b_k;
int stride_b1 = is_b_row ? stride_b_k : stride_b_n;
int lda = is_a_row ? stride_a_m : stride_a_k;
int ldb = is_b_row ? stride_b_k : stride_b_n;
int per_phase_a = swizzle_->get_per_phase(layout_a);
int max_phase_a = swizzle_->get_max_phase(layout_a);
int per_phase_b = swizzle_->get_per_phase(layout_b);
int max_phase_b = swizzle_->get_max_phase(layout_b);
int num_ptr_a = 8;
int num_ptr_b = 8;
int vec_a = 8;
int vec_b = 8;
std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
const int mma_instr_m = mma_instr_shape[0];
const int mma_instr_n = mma_instr_shape[1];
const int mma_instr_k = mma_instr_shape[2];
std::vector<int> mat_shape = layout->get_mma_mat_shape();
const int mat_shape_m = mat_shape[0];
const int mat_shape_n = mat_shape[1];
const int mat_shape_k = mat_shape[2];
const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);
const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
Type *fp32_ty = f32_ty;
Type *fp16x2_ty = vec_ty(f16_ty, 2);
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{ptr_ty(f16_ty, 3)}, false);
FunctionType *ldmatrix_ty = nullptr;
FunctionType *mma_ty = nullptr;
Type *phi_ty = nullptr;
Type *smem_ptr_ty = nullptr;
ir::type *A_ir_ty = A->get_type()->get_scalar_ty();
ir::type *B_ir_ty = B->get_type()->get_scalar_ty();
if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
} else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
// FIXME: We should use bf16 here.
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
// mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
// smem_ptr_ty = ptr_ty(bf16_ty, 3);
// ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
// phi_ty = bf16x2_ty;
} else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(fp32_ty, 3);
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp32_ty;
} else
throw std::runtime_error("mma16816 data type not supported");
// left-hand-side values
std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
std::map<std::pair<unsigned, unsigned>, Value*> ha;
std::map<std::pair<unsigned, unsigned>, Value*> hb;
BasicBlock* CurrBB = builder_->GetInsertBlock();
@@ -1377,79 +1645,66 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value *lane = urem(thread, i32(32));
Value *warp = udiv(thread, i32(32));
Value *warp12 = udiv(warp, i32(layout->wpt(0)));
Value *warp0 = urem(warp, i32(layout->wpt(0)));
Value *warp1 = urem(warp12, i32(layout->wpt(1)));
Value *warp_mn = udiv(warp, i32(layout->wpt(0)));
Value *warp_m = urem(warp, i32(layout->wpt(0)));
Value *warp_n = urem(warp_mn, i32(layout->wpt(1)));
std::vector<Value *>& fc = fcs.begin()->second;
Value *tidr8 = urem(lane, i32(8));
Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a));
Value* off_a0 = mul(tidr8, i32(lda));
Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8));
Value *off_ak = mul(udiv(lane, i32(16)), i32(8));
off_am = urem(off_am, i32(shape_a[0]));
off_ak = urem(off_ak, i32(shape_a[1]));
off_a0 = add(off_a0, is_a_row ? off_ak : off_am);
Value* off_a1 = is_a_row ? off_am : off_ak;
std::vector<Value*> off_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++){
Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0))));
off_a0i = exact_udiv(off_a0i, i32(vec_a));
off_a0i = xor_(off_a0i, phase_a);
off_a0i = mul(off_a0i, i32(vec_a));
off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
}
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b));
Value* off_b0 = mul(tidr8, i32(ldb));
Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8));
Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8));
off_bn = urem(off_bn, i32(shape_b[1]));
off_bk = urem(off_bk, i32(shape_b[0]));
off_b0 = add(off_b0, is_b_row ? off_bn : off_bk);
Value* off_b1 = is_b_row ? off_bk : off_bn;
std::vector<Value*> off_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++){
Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16)));
off_b0i = exact_udiv(off_b0i, i32(vec_b));
off_b0i = xor_(off_b0i, phase_b);
off_b0i = mul(off_b0i, i32(vec_b));
off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
}
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
// v (s0_0(0), s1_0(2), | *num_rep_k
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
// -----------
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
int num_ptr_a = a_loader.get_num_ptr();
// | -> n (col-major)
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
// -----------
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
int num_ptr_b = b_loader.get_num_ptr();
builder_->SetInsertPoint(CurrBB);
// A pointer
std::vector<Value*> ptrs_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = gep(shmems_[A], {off_a[i]});
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
// B pointer
std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = gep(shmems_[B], {off_b[i]});
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{$0, $1, $2, $3}, "
"{$4, $5, $6, $7}, "
"{$8, $9}, "
"{$10, $11, $12, $13};",
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
" {$0, $1, $2, $3},"
" {$4, $5, $6, $7},"
" {$8, $9},"
" {$10, $11, $12, $13};",
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0);
unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1);
// create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
unsigned cols_per_thread = num_rep_0 * 2;
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
unsigned cols_per_thread = num_rep_m * 2;
std::vector<size_t> idx = {
(m*2 + 0) + (n*2 + 0)*cols_per_thread,
(m*2 + 0) + (n*2 + 1)*cols_per_thread,
(m*2 + 1) + (n*2 + 0)*cols_per_thread,
(m*2 + 1) + (n*2 + 1)*cols_per_thread
(m + 0) + (n*2 + 0)*cols_per_thread,
(m + 0) + (n*2 + 1)*cols_per_thread,
(m + 1) + (n*2 + 0)*cols_per_thread,
(m + 1) + (n*2 + 1)*cols_per_thread
};
Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
hb[{n, K}], hb[{n, K+8}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
Value *nc = call(mma_ty, mma_fn,
{ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
hb[{n, k}], hb[{n, k+1}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
@@ -1459,131 +1714,83 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds =
[&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto register_lds2 =
[&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
} else
vals[{m, K}] = val;
vals[{n, k}] = val;
};
auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
Value* ptra;
if(K == 0 && is_prefetch){
if(inc == 0)
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
else
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
else
ptra = ptrs_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
"=r,=r,=r,=r,r", true);
Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};
auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
int offidx = (is_b_row ? n : K/16) % num_ptr_b;
Value* ptrb;
if(K == 0 && is_prefetch){
if(inc == 0)
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
else
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
}
else
ptrb = ptrs_b[offidx];
int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
"=r,=r,=r,=r,r", true);
Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
register_lds2(hb, n, K, inc, hb0, is_prefetch);
register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
for(unsigned m = 0; m < num_rep_0; m++){
ha[{m, 0}].first = phi(fp16x2_ty, 2);
ha[{m, 0}].second = phi(fp16x2_ty, 2);
ha[{m, 8}].first = phi(fp16x2_ty, 2);
ha[{m, 8}].second = phi(fp16x2_ty, 2);
for(unsigned m = 0; m < num_rep_m; m++){
ha[{2*m, 0}] = phi(phi_ty, 2);
ha[{2*m+1, 0}] = phi(phi_ty, 2);
ha[{2*m, 1}] = phi(phi_ty, 2);
ha[{2*m+1, 1}] = phi(phi_ty, 2);
}
for(unsigned n = 0; n < num_rep_1; n+=2){
hb[{n, 0}] = phi(fp16x2_ty, 2);
hb[{n+1, 0}] = phi(fp16x2_ty, 2);
hb[{n, 8}] = phi(fp16x2_ty, 2);
hb[{n+1, 8}] = phi(fp16x2_ty, 2);
for(unsigned n = 0; n < num_rep_n; n+=2){
hb[{n, 0}] = phi(phi_ty, 2);
hb[{n+1, 0}] = phi(phi_ty, 2);
hb[{n, 1}] = phi(phi_ty, 2);
hb[{n+1, 1}] = phi(phi_ty, 2);
}
// insert prefetched lds at the end of loop header
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, 0, 0, true);
for(unsigned n = 0; n < num_rep_1; n+=2)
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 0, 0, true);
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 0, 0, true);
// update accumulators
builder_->SetInsertPoint(CurrBB);
for(unsigned K = 0; K < NK; K += 16){
int NEXTK = (K + 16) % NK;
for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2
int next_k = (k + 1) % num_rep_k;
// prefetch A
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, NEXTK, 1, true);
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*next_k, 1, true);
// prefetch B
for(unsigned n = 0; n < num_rep_1; n+=2)
load_b(n, NEXTK, 1, true);
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*next_k, 1, true);
// tensor core ops
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
call_mma(m, n, K);
for(unsigned m = 0; m < num_rep_m; m++)
for(unsigned n = 0; n < num_rep_n; n++){
call_mma(2*m, n, 2*k);
}
}
}
else{
for(unsigned K = 0; K < NK; K += 16)
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
if(ha.find({m, K}) == ha.end())
load_a(m, K, 0, false);
if(hb.find({n, K})==hb.end())
load_b(n, K, 0, false);
call_mma(m, n, K);
}
for (unsigned k = 0; k < num_rep_k; k++) {
for (unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*k, 0, /*is_prefetch*/false);
for (unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*k, 0, /*is_prefetch*/false);
for (unsigned m = 0; m < num_rep_m; m++)
for (unsigned n = 0; n < num_rep_n; n++)
call_mma(2*m, n, 2*k);
}
}
// write back
unsigned i = 0;
@@ -1714,7 +1921,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK);
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
}
@@ -1752,13 +1959,13 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
acc = bit_cast(acc, vec_ty(f32_ty, 2));
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
return builder_->CreateBitCast(ret, ty);
return bit_cast(ret, ty);
}
/**
@@ -1936,6 +2143,10 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
// pointer to temporary shared memory
Type *ty = cvt(out->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
// Orders
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
@@ -1976,7 +2187,7 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
// stash value to shared mem
indices_t idxs = {in_ax[0][i*max_ii + ii],
in_ax[1][j*max_jj + jj]};
store(vals_[in][idxs], ptr);
store(bit_cast(vals_[in][idxs], ty), ptr);
}
add_barrier();
max_ii = out_ax[0].size()/n_reps[0];