From 5f3d48c1d0b09adc00a822fc4482fbf3e81cfb4b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 7 Jun 2019 21:19:47 -0700 Subject: [PATCH] [tensor cores] added basic codegen template for using wmma --- lib/codegen/selection.cpp | 24 +++++++++++++++--------- lib/driver/module.cpp | 8 ++++---- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 8d5ba3d4a..e6a04a3b3 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -887,18 +887,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & }); Type *void_ty = builder.getVoidTy(); + Type *int32_ty = builder.getInt32Ty(); Type *fp32_ty = builder.getFloatTy(); Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2); -// Type *fp32_vec8_ty = VectorType::get(fp32_ty, 8); -// Type *fp16x2_vec2 = VectorType::get(fp16x2_ty, 2); - FunctionType *mma_ty = FunctionType::get(void_ty, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}, false); + Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}); + FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {int32_ty, int32_ty, int32_ty, int32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 \n\ - {$0, $1, $2, $3, $4, $5, $6, $7}, \n\ - {$8, $9}, \n\ - {$10, $11}, \n\ - {$0, $1, $2, $3, $4, $5, $6, $7};", "+f, +f, +f, +f, +f, +f, +f, +f, r, r, r, r", false); - builder.CreateCall(mma_fn, {fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7], ha0, ha1, hb0, hb1}); + InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 " + "{$0, $1, $2, $3, $4, $5, $6, $7}, " + "{$8, $9}, " + "{$10, $11}, " + "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); + Value *nc = builder.CreateCall(mma_fn, {builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7]}); + std::cout << mma_fn->getFunctionType()->getFunctionNumParams() << std::endl; + unsigned i = 0; + result->for_each([&](indices_t idx){ + result->set_value(idx, builder.CreateExtractValue(nc, {i++})); + }); + std::cout << "haha" << std::endl; } } else diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index ebc876559..a9d8ab549 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -106,10 +106,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple file_type_t ft) { init_llvm(); // debug -// llvm::legacy::PassManager pm; -// pm.add(llvm::createPrintModulePass(llvm::outs())); -// pm.add(llvm::createVerifierPass()); -// pm.run(*module); + llvm::legacy::PassManager pm; + pm.add(llvm::createPrintModulePass(llvm::outs())); + pm.add(llvm::createVerifierPass()); + pm.run(*module); // create machine module->setTargetTriple(triple); std::string error;