[tensor cores] Added a basic codegen template for using WMMA

This commit is contained in:
Philippe Tillet
2019-06-07 21:19:47 -07:00
parent ec4c6aaaaa
commit 5f3d48c1d0
2 changed files with 19 additions and 13 deletions

View File

@@ -887,18 +887,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
}); });
Type *void_ty = builder.getVoidTy(); Type *void_ty = builder.getVoidTy();
Type *int32_ty = builder.getInt32Ty();
Type *fp32_ty = builder.getFloatTy(); Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2); Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
// Type *fp32_vec8_ty = VectorType::get(fp32_ty, 8); Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
// Type *fp16x2_vec2 = VectorType::get(fp16x2_ty, 2); FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {int32_ty, int32_ty, int32_ty, int32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
FunctionType *mma_ty = FunctionType::get(void_ty, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 \n\ InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
{$0, $1, $2, $3, $4, $5, $6, $7}, \n\ "{$0, $1, $2, $3, $4, $5, $6, $7}, "
{$8, $9}, \n\ "{$8, $9}, "
{$10, $11}, \n\ "{$10, $11}, "
{$0, $1, $2, $3, $4, $5, $6, $7};", "+f, +f, +f, +f, +f, +f, +f, +f, r, r, r, r", false); "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
builder.CreateCall(mma_fn, {fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7], ha0, ha1, hb0, hb1}); Value *nc = builder.CreateCall(mma_fn, {builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7]});
std::cout << mma_fn->getFunctionType()->getFunctionNumParams() << std::endl;
unsigned i = 0;
result->for_each([&](indices_t idx){
result->set_value(idx, builder.CreateExtractValue(nc, {i++}));
});
std::cout << "haha" << std::endl;
} }
} }
else else

View File

@@ -106,10 +106,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
file_type_t ft) { file_type_t ft) {
init_llvm(); init_llvm();
// debug // debug
// llvm::legacy::PassManager pm; llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs())); pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass()); pm.add(llvm::createVerifierPass());
// pm.run(*module); pm.run(*module);
// create machine // create machine
module->setTargetTriple(triple); module->setTargetTriple(triple);
std::string error; std::string error;