[codegen] separated lower_dot_inst into lower_outer_dot || lower_hmma_dot || lower_scanline_dot
Philippe Tillet
2019-08-12 21:48:30 -07:00
parent 4bc5758a22
commit b8cd63e0da
2 changed files with 198 additions and 184 deletions
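In short, the new lower_dot becomes a thin dispatcher over three specialized routines. A condensed restatement of the control flow added at the end of this diff (argument lists elided, names as they appear below):

    void selection::lower_dot(ir::dot_inst *dot, ...) {
      // NK = size of the reduction (K) axis of A
      if(NK != 1) {
        // A and B are read from shared memory
        if(params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN)
          lower_scanline_dot(...);   // per-thread FFMA loop over K
        else
          lower_hmma_dot(...);       // tensor-core mma.sync path
      }
      else
        lower_outer_dot(...);        // K == 1: outer product on distributed tiles
    }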


@@ -999,194 +999,148 @@ void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, I
});
}
void selection::lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
}
void selection::lower_scalar_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
}
void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
const auto& shapes = dot->get_type()->get_tile_shapes();
distributed_tile* result = (distributed_tile*)tmap_.at(dot);
Module *module = fn->getParent();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *C = dot->get_operand(2);
bool AT = dot->is_a_trans();
bool BT = dot->is_b_trans();
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = dot->is_a_trans() ? 0 : 1;
unsigned NK = A_shapes[red_axis]->get_value();
// std::cout << red_axis << " " << NK << std::endl;
if(NK != 1)
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN) {
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
// input indices
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
// add batching dimension
for(size_t i = 2; i < idx.size(); i++){
a_idx.insert(a_idx.end(), idx[i]);
b_idx.insert(b_idx.end(), idx[i]);
}
// load value
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
if(a->getType() != c_ty)
a = builder.CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
TA->set_vector_size(4*pack_size_0_);
TB->set_vector_size(4*pack_size_1_);
TA->set_return_mode(true);
TB->set_return_mode(true);
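// group the f32 accumulator values by their batch (outer) indices; each group is updated by the mma.sync loop below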
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
TC->for_each([&](indices_t idx){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
fcs[key].push_back(TD->get_value(idx));
});
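// operand/accumulator types for the PTX mma.sync: two f16 values packed per register, eight f32 accumulators returned as a struct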
Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
Value *offset_a_i = offset_a_i_;
Value *offset_a_k = offset_a_k_;
Value *offset_b_j = offset_b_j_;
Value *offset_b_k = offset_b_k_;
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
if(dot->is_a_trans()){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_a_k = builder.getInt32(0);
}
if(!dot->is_b_trans()){
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_b_k = builder.getInt32(0);
}
std::string op_a = dot->is_a_trans() ? "row" : "col";
std::string op_b = dot->is_b_trans() ? "row" : "col";
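// Volta HMMA: mma.sync.m8n8k4 multiplies an 8x4 and a 4x8 f16 fragment and accumulates into eight f32 values; the row/col suffixes follow the transposition flags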
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
"{$8, $9}, "
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
unsigned fpw_0 = params_->get_param(dot, "fpw.d0")->get_value();
unsigned fpw_1 = params_->get_param(dot, "fpw.d1")->get_value();
unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = params_->get_param(dot, "wpt.d0")->get_value();
unsigned wpt_1 = params_->get_param(dot, "wpt.d1")->get_value();
unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1;
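// repetitions of the warp tile along dim 0, and the leading dimension used to index the flattened accumulator array fc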
unsigned num_rep_i = shapes[0]->get_value() / stride_rep_i;
unsigned ld_fc = num_rep_i * 2;
for(auto& x: fcs){
std::vector<Value *>& fc = x.second;
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder.getInt32(K);
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
if(dot->is_a_trans())
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
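// flattened positions of the eight accumulators read and written by this mma.sync (column stride ld_fc)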
std::vector<size_t> idx = {
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
};
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
}
}
else {
TA->set_vector_size(4*pack_size_0_);
TB->set_vector_size(4*pack_size_1_);
TA->set_return_mode(true);
TB->set_return_mode(true);
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
result->for_each([&](indices_t idx){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
fcs[key].push_back(TC->get_value(idx));
});
Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
Value *offset_a_i = offset_a_i_;
Value *offset_a_k = offset_a_k_;
Value *offset_b_j = offset_b_j_;
Value *offset_b_k = offset_b_k_;
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
if(dot->is_a_trans()){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_a_k = builder.getInt32(0);
}
if(!dot->is_b_trans()){
offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4)));
offset_b_k = builder.getInt32(0);
}
std::string op_a = dot->is_a_trans() ? "row" : "col";
std::string op_b = dot->is_b_trans() ? "row" : "col";
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
"{$8, $9}, "
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
unsigned fpw_0 = params_->get_param(dot, "fpw.d0")->get_value();
unsigned fpw_1 = params_->get_param(dot, "fpw.d1")->get_value();
unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = params_->get_param(dot, "wpt.d0")->get_value();
unsigned wpt_1 = params_->get_param(dot, "wpt.d1")->get_value();
unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0]->get_value() / stride_rep_i;
unsigned ld_fc = num_rep_i * 2;
for(auto& x: fcs){
std::vector<Value *>& fc = x.second;
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder.getInt32(K);
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
if(dot->is_a_trans())
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
std::vector<size_t> idx = {
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
};
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
}
}
}
}
// write back
unsigned i = 0;
result->for_each([&](indices_t idx){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
if(i >= fcs.at(key).size())
i = 0;
result->set_value(idx, fcs.at(key)[i++]);
});
TA->set_return_mode(false);
TB->set_return_mode(false);
}
}
else
{
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
distributed_tile *TB = (distributed_tile*)tmap_.at(B);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
indices_t a_idx = {idx[0], builder.getInt32(0)};
indices_t b_idx = {builder.getInt32(0), idx[1]};
if(AT)
// write back
unsigned i = 0;
TC->for_each([&](indices_t idx){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
if(i >= fcs.at(key).size())
i = 0;
TC->set_value(idx, fcs.at(key)[i++]);
});
TA->set_return_mode(false);
TB->set_return_mode(false);
}
void selection::lower_scanline_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add) {
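// strided-scan path: every thread walks the reduction axis and accumulates its output elements with scalar fmuladd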
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
TC->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
// input indices
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(dot->is_a_trans())
std::swap(a_idx[0], a_idx[1]);
if(BT)
if(dot->is_b_trans())
std::swap(b_idx[0], b_idx[1]);
// add batching dimension
for(size_t i = 2; i < idx.size(); i++){
a_idx.insert(a_idx.end(), idx[i]);
b_idx.insert(b_idx.end(), idx[i]);
}
// load value
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
if(a->getType() != c_ty)
@@ -1194,8 +1148,59 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty);
res = builder.CreateCall(f_mul_add, {a, b, res});
result->set_value(idx, res);
});
}
TC->set_value(idx, res);
});
}
void selection::lower_outer_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD,
Type *c_ty, Function *f_mul_add) {
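// NK == 1: the reduction axis is trivial, so each output element is a single fmuladd of one A value and one B value read from distributed tiles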
TC->for_each([&](indices_t idx){
Value *res = TD->get_value(idx);
indices_t a_idx = {idx[0], builder.getInt32(0)};
indices_t b_idx = {builder.getInt32(0), idx[1]};
if(dot->is_a_trans())
std::swap(a_idx[0], a_idx[1]);
if(dot->is_b_trans())
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
if(a->getType() != c_ty)
a = builder.CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty);
res = builder.CreateCall(f_mul_add, {a, b, res});
TC->set_value(idx, res);
});
}
void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* TC = (distributed_tile*)tmap_.at(dot);
Module *module = fn->getParent();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
distributed_tile *TD = (distributed_tile*)tmap_.at(D);
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), ctx);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = dot->is_a_trans() ? 0 : 1;
unsigned NK = A_shapes[red_axis]->get_value();
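// NK == 1 degenerates into an outer product on distributed tiles; otherwise A and B are read from shared memory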
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN)
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
else
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
}
else {
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
distributed_tile *TB = (distributed_tile*)tmap_.at(B);
lower_outer_dot(dot, ctx, fn, builder, TC, TA, TB, TD, c_ty, f_mul_add);
}
}