Added inline PTX for mma.sync
This commit is contained in:
@@ -45,12 +45,7 @@ void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
|
||||
fp16* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
|
||||
fp16 a[TM, TK] = *pa;
|
||||
fp16 b[TN, TK] = *pb;
|
||||
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
|
||||
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
|
||||
last_a = last_a / TK * TK;
|
||||
last_b = last_b / TK * TK;
|
||||
int32 bound = K - max(last_a, last_b);
|
||||
for(int32 k = K; k > bound; k = k - TK){
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
c = dot(a, trans(b), c);
|
||||
pa = pa + TK*lda;
|
||||
pb = pb + TK*ldb;
|
||||
@@ -59,15 +54,6 @@ void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
for(int32 k = bound; k > 0; k = k - 1){
|
||||
int1 checka[TM, 1] = rxc[:, newaxis] < M;
|
||||
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
|
||||
fp16* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
|
||||
fp16* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
|
||||
fp16 a[TM, 1] = checka ? *pa : 0;
|
||||
fp16 b[TN, 1] = checkb ? *pb : 0;
|
||||
c = dot(a, trans(b), c);
|
||||
}
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
|
@@ -14,6 +14,7 @@
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Attributes.h"
|
||||
#include "llvm/IR/InlineAsm.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
@@ -470,24 +471,27 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
Value *_8 = builder.getInt32(8);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
// offset_i = tid & 2 + tid & 8
|
||||
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
builder.CreateAnd(u_thread_id, _8));
|
||||
// offset_j = (tid & 1) + (tid & 4)*2 + (tid & 16)/4
|
||||
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
|
||||
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
|
||||
builder.CreateAdd(builder.CreateMul(builder.CreateAnd(u_thread_id, _4), _2),
|
||||
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
|
||||
// idx_i
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned i = 0; i < 2; i++)
|
||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*4)));
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned i = 0; i < 2; i++){
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4 + 1)));
|
||||
}
|
||||
|
||||
// idx_j
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned j = 0; j < 2; j++)
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*2)));
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(j*2)));
|
||||
}
|
||||
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_j};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_i};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -822,29 +826,80 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
bool BT = dot->is_b_trans();
|
||||
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
std::cout << NK << std::endl;
|
||||
if(NK != 1)
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN)
|
||||
{
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
Value *_1 = builder.getInt32(1);
|
||||
Value *_2 = builder.getInt32(2);
|
||||
Value *_3 = builder.getInt32(3);
|
||||
Value *_4 = builder.getInt32(4);
|
||||
Value *_8 = builder.getInt32(8);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
// offset_a_i = (tid & 3)
|
||||
// offset_a_j = (tid & 4)*2 + (tid & 16)/4;
|
||||
Value *offset_a_i = builder.CreateAnd(tid, _3);
|
||||
Value *offset_a_k = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4),
|
||||
_2),
|
||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||
_4));
|
||||
// offset_b_i = (tid & 3)
|
||||
// offset_b_j = (tid & 8)*1 + (tid & 16)/4
|
||||
Value *offset_b_i = builder.CreateAnd(tid, _3);
|
||||
Value *offset_b_k = builder.CreateAdd(builder.CreateAnd(tid, _8),
|
||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||
_4));
|
||||
Value *ha0 = TA->get_value({offset_a_i, offset_a_k});
|
||||
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _1), offset_a_k});
|
||||
Value *hb0 = TB->get_value({offset_b_i, offset_b_k});
|
||||
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _1), offset_b_k});
|
||||
std::vector<Value *> fc;
|
||||
result->for_each([&](indices_t idx){
|
||||
fc.push_back(result->get_value(idx));
|
||||
});
|
||||
|
||||
Type *void_ty = builder.getVoidTy();
|
||||
Type *fp32_ty = builder.getFloatTy();
|
||||
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
// Type *fp32_vec8_ty = VectorType::get(fp32_ty, 8);
|
||||
// Type *fp16x2_vec2 = VectorType::get(fp16x2_ty, 2);
|
||||
FunctionType *mma_ty = FunctionType::get(void_ty, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}, false);
|
||||
|
||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 \n\
|
||||
{$0, $1, $2, $3, $4, $5, $6, $7}, \n\
|
||||
{$8, $9}, \n\
|
||||
{$10, $11}, \n\
|
||||
{$0, $1, $2, $3, $4, $5, $6, $7};", "+f, +f, +f, +f, +f, +f, +f, +f, r, r, r, r", false);
|
||||
builder.CreateCall(mma_fn, {fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7], ha0, ha1, hb0, hb1});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@@ -119,7 +119,7 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "-ptx60", opt,
|
||||
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
// set data layout
|
||||
if(layout.empty())
|
||||
@@ -243,14 +243,17 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_52", layout, buffer, "", Assembly);
|
||||
return std::string(buffer.begin(), buffer.end());
|
||||
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_75", layout, buffer, "", Assembly);
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
std::string to_replace = ".version 6.3";
|
||||
result.replace(result.find(to_replace), to_replace.size(), ".version 6.4");
|
||||
return result;
|
||||
}
|
||||
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
Reference in New Issue
Block a user