Added inline PTX for mma.sync

This commit is contained in:
Philippe Tillet
2019-06-07 19:39:33 -07:00
parent 6fce9f28ae
commit ec4c6aaaaa
3 changed files with 92 additions and 48 deletions

View File

@@ -45,12 +45,7 @@ void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
fp16* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp16 a[TM, TK] = *pa;
fp16 b[TN, TK] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
for(int32 k = K; k > 0; k = k - TK){
c = dot(a, trans(b), c);
pa = pa + TK*lda;
pb = pb + TK*ldb;
@@ -59,15 +54,6 @@ void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp16* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp16* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp16 a[TM, 1] = checka ? *pa : 0;
fp16 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];

View File

@@ -14,6 +14,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
namespace triton{
namespace codegen{
@@ -470,24 +471,27 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
Value *_8 = builder.getInt32(8);
Value *_16 = builder.getInt32(16);
// offset_i = tid & 2 + tid & 8
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
builder.CreateAnd(u_thread_id, _8));
// offset_j = (tid & 1) + (tid & 4)*2 + (tid & 16)/4
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
builder.CreateAdd(builder.CreateMul(builder.CreateAnd(u_thread_id, _4), _2),
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
// idx_i
std::vector<Value*> idx_i;
for(unsigned i = 0; i < 2; i++)
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*4)));
std::vector<Value*> idx_j;
for(unsigned i = 0; i < 2; i++){
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4)));
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4 + 1)));
}
// idx_j
std::vector<Value*> idx_j;
for(unsigned j = 0; j < 2; j++)
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*2)));
std::vector<Value*> idx_i;
for(unsigned j = 0; j < 2; j++){
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(j*2)));
}
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_j};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_i};
}
}
@@ -822,29 +826,80 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
bool BT = dot->is_b_trans();
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
std::cout << NK << std::endl;
if(NK != 1)
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN)
{
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
}
result->set_value(idx, res);
});
}
else
{
Value *_1 = builder.getInt32(1);
Value *_2 = builder.getInt32(2);
Value *_3 = builder.getInt32(3);
Value *_4 = builder.getInt32(4);
Value *_8 = builder.getInt32(8);
Value *_16 = builder.getInt32(16);
BasicBlock *current = builder.GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, builder, 0);
// offset_a_i = (tid & 3)
// offset_a_j = (tid & 4)*2 + (tid & 16)/4;
Value *offset_a_i = builder.CreateAnd(tid, _3);
Value *offset_a_k = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4),
_2),
builder.CreateUDiv(builder.CreateAnd(tid, _16),
_4));
// offset_b_i = (tid & 3)
// offset_b_j = (tid & 8)*1 + (tid & 16)/4
Value *offset_b_i = builder.CreateAnd(tid, _3);
Value *offset_b_k = builder.CreateAdd(builder.CreateAnd(tid, _8),
builder.CreateUDiv(builder.CreateAnd(tid, _16),
_4));
Value *ha0 = TA->get_value({offset_a_i, offset_a_k});
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _1), offset_a_k});
Value *hb0 = TB->get_value({offset_b_i, offset_b_k});
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _1), offset_b_k});
std::vector<Value *> fc;
result->for_each([&](indices_t idx){
fc.push_back(result->get_value(idx));
});
Type *void_ty = builder.getVoidTy();
Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
// Type *fp32_vec8_ty = VectorType::get(fp32_ty, 8);
// Type *fp16x2_vec2 = VectorType::get(fp16x2_ty, 2);
FunctionType *mma_ty = FunctionType::get(void_ty, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 \n\
{$0, $1, $2, $3, $4, $5, $6, $7}, \n\
{$8, $9}, \n\
{$10, $11}, \n\
{$0, $1, $2, $3, $4, $5, $6, $7};", "+f, +f, +f, +f, +f, +f, +f, +f, r, r, r, r", false);
builder.CreateCall(mma_fn, {fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7], ha0, ha1, hb0, hb1});
}
}
else
{

View File

@@ -119,7 +119,7 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "-ptx60", opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
@@ -243,14 +243,17 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
// create
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_52", layout, buffer, "", Assembly);
return std::string(buffer.begin(), buffer.end());
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_75", layout, buffer, "", Assembly);
std::string result(buffer.begin(), buffer.end());
std::string to_replace = ".version 6.3";
result.replace(result.find(to_replace), to_replace.size(), ".version 6.4");
return result;
}
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};