Added inline PTX for mma.sync
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Attributes.h"
|
||||
#include "llvm/IR/InlineAsm.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
@@ -470,24 +471,27 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
Value *_8 = builder.getInt32(8);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
// offset_i = (tid & 2) + (tid & 8)
|
||||
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
builder.CreateAnd(u_thread_id, _8));
|
||||
// offset_j = (tid & 1) + (tid & 4)*2 + (tid & 16)/4
|
||||
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
|
||||
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
|
||||
builder.CreateAdd(builder.CreateMul(builder.CreateAnd(u_thread_id, _4), _2),
|
||||
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
|
||||
// idx_i
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned i = 0; i < 2; i++)
|
||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*4)));
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned i = 0; i < 2; i++){
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4 + 1)));
|
||||
}
|
||||
|
||||
// idx_j
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned j = 0; j < 2; j++)
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*2)));
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(j*2)));
|
||||
}
|
||||
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_j};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_i};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -822,29 +826,80 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
bool BT = dot->is_b_trans();
|
||||
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
std::cout << NK << std::endl;
|
||||
if(NK != 1)
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN)
|
||||
{
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
Value *_1 = builder.getInt32(1);
|
||||
Value *_2 = builder.getInt32(2);
|
||||
Value *_3 = builder.getInt32(3);
|
||||
Value *_4 = builder.getInt32(4);
|
||||
Value *_8 = builder.getInt32(8);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
// offset_a_i = (tid & 3)
|
||||
// offset_a_k = (tid & 4)*2 + (tid & 16)/4
|
||||
Value *offset_a_i = builder.CreateAnd(tid, _3);
|
||||
Value *offset_a_k = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4),
|
||||
_2),
|
||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||
_4));
|
||||
// offset_b_i = (tid & 3)
|
||||
// offset_b_k = (tid & 8) + (tid & 16)/4
|
||||
Value *offset_b_i = builder.CreateAnd(tid, _3);
|
||||
Value *offset_b_k = builder.CreateAdd(builder.CreateAnd(tid, _8),
|
||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||
_4));
|
||||
Value *ha0 = TA->get_value({offset_a_i, offset_a_k});
|
||||
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _1), offset_a_k});
|
||||
Value *hb0 = TB->get_value({offset_b_i, offset_b_k});
|
||||
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _1), offset_b_k});
|
||||
std::vector<Value *> fc;
|
||||
result->for_each([&](indices_t idx){
|
||||
fc.push_back(result->get_value(idx));
|
||||
});
|
||||
|
||||
Type *void_ty = builder.getVoidTy();
|
||||
Type *fp32_ty = builder.getFloatTy();
|
||||
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
// Type *fp32_vec8_ty = VectorType::get(fp32_ty, 8);
|
||||
// Type *fp16x2_vec2 = VectorType::get(fp16x2_ty, 2);
|
||||
FunctionType *mma_ty = FunctionType::get(void_ty, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}, false);
|
||||
|
||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 \n\
|
||||
{$0, $1, $2, $3, $4, $5, $6, $7}, \n\
|
||||
{$8, $9}, \n\
|
||||
{$10, $11}, \n\
|
||||
{$0, $1, $2, $3, $4, $5, $6, $7};", "+f, +f, +f, +f, +f, +f, +f, +f, r, r, r, r", false);
|
||||
builder.CreateCall(mma_fn, {fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7], ha0, ha1, hb0, hb1});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
Reference in New Issue
Block a user