Added inline PTX for mma.sync

This commit is contained in:
Philippe Tillet
2019-06-07 19:39:33 -07:00
parent 6fce9f28ae
commit ec4c6aaaaa
3 changed files with 92 additions and 48 deletions

View File

@@ -14,6 +14,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
namespace triton{
namespace codegen{
@@ -470,24 +471,27 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
Value *_8 = builder.getInt32(8);
Value *_16 = builder.getInt32(16);
// offset_i = tid & 2 + tid & 8
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
builder.CreateAnd(u_thread_id, _8));
// offset_j = (tid & 1) + (tid & 4)*2 + (tid & 16)/4
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
builder.CreateAdd(builder.CreateMul(builder.CreateAnd(u_thread_id, _4), _2),
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
// idx_i
std::vector<Value*> idx_i;
for(unsigned i = 0; i < 2; i++)
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*4)));
std::vector<Value*> idx_j;
for(unsigned i = 0; i < 2; i++){
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4)));
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4 + 1)));
}
// idx_j
std::vector<Value*> idx_j;
for(unsigned j = 0; j < 2; j++)
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*2)));
std::vector<Value*> idx_i;
for(unsigned j = 0; j < 2; j++){
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(j*2)));
}
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_j};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_i};
}
}
@@ -822,29 +826,80 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
bool BT = dot->is_b_trans();
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
std::cout << NK << std::endl;
if(NK != 1)
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN)
{
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
}
result->set_value(idx, res);
});
}
else
{
Value *_1 = builder.getInt32(1);
Value *_2 = builder.getInt32(2);
Value *_3 = builder.getInt32(3);
Value *_4 = builder.getInt32(4);
Value *_8 = builder.getInt32(8);
Value *_16 = builder.getInt32(16);
BasicBlock *current = builder.GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, builder, 0);
// offset_a_i = (tid & 3)
// offset_a_j = (tid & 4)*2 + (tid & 16)/4;
Value *offset_a_i = builder.CreateAnd(tid, _3);
Value *offset_a_k = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4),
_2),
builder.CreateUDiv(builder.CreateAnd(tid, _16),
_4));
// offset_b_i = (tid & 3)
// offset_b_j = (tid & 8)*1 + (tid & 16)/4
Value *offset_b_i = builder.CreateAnd(tid, _3);
Value *offset_b_k = builder.CreateAdd(builder.CreateAnd(tid, _8),
builder.CreateUDiv(builder.CreateAnd(tid, _16),
_4));
Value *ha0 = TA->get_value({offset_a_i, offset_a_k});
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _1), offset_a_k});
Value *hb0 = TB->get_value({offset_b_i, offset_b_k});
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _1), offset_b_k});
std::vector<Value *> fc;
result->for_each([&](indices_t idx){
fc.push_back(result->get_value(idx));
});
Type *void_ty = builder.getVoidTy();
Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
// Type *fp32_vec8_ty = VectorType::get(fp32_ty, 8);
// Type *fp16x2_vec2 = VectorType::get(fp16x2_ty, 2);
FunctionType *mma_ty = FunctionType::get(void_ty, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 \n\
{$0, $1, $2, $3, $4, $5, $6, $7}, \n\
{$8, $9}, \n\
{$10, $11}, \n\
{$0, $1, $2, $3, $4, $5, $6, $7};", "+f, +f, +f, +f, +f, +f, +f, +f, r, r, r, r", false);
builder.CreateCall(mma_fn, {fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7], ha0, ha1, hb0, hb1});
}
}
else
{