[tensor cores] added basic codegen template for using wmma
This commit is contained in:
@@ -887,18 +887,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
});
|
});
|
||||||
|
|
||||||
Type *void_ty = builder.getVoidTy();
|
Type *void_ty = builder.getVoidTy();
|
||||||
|
Type *int32_ty = builder.getInt32Ty();
|
||||||
Type *fp32_ty = builder.getFloatTy();
|
Type *fp32_ty = builder.getFloatTy();
|
||||||
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||||
// Type *fp32_vec8_ty = VectorType::get(fp32_ty, 8);
|
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||||
// Type *fp16x2_vec2 = VectorType::get(fp16x2_ty, 2);
|
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {int32_ty, int32_ty, int32_ty, int32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||||
FunctionType *mma_ty = FunctionType::get(void_ty, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}, false);
|
|
||||||
|
|
||||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 \n\
|
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
|
||||||
{$0, $1, $2, $3, $4, $5, $6, $7}, \n\
|
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||||
{$8, $9}, \n\
|
"{$8, $9}, "
|
||||||
{$10, $11}, \n\
|
"{$10, $11}, "
|
||||||
{$0, $1, $2, $3, $4, $5, $6, $7};", "+f, +f, +f, +f, +f, +f, +f, +f, r, r, r, r", false);
|
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||||
builder.CreateCall(mma_fn, {fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7], ha0, ha1, hb0, hb1});
|
Value *nc = builder.CreateCall(mma_fn, {builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7]});
|
||||||
|
std::cout << mma_fn->getFunctionType()->getFunctionNumParams() << std::endl;
|
||||||
|
unsigned i = 0;
|
||||||
|
result->for_each([&](indices_t idx){
|
||||||
|
result->set_value(idx, builder.CreateExtractValue(nc, {i++}));
|
||||||
|
});
|
||||||
|
std::cout << "haha" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@@ -106,10 +106,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
|||||||
file_type_t ft) {
|
file_type_t ft) {
|
||||||
init_llvm();
|
init_llvm();
|
||||||
// debug
|
// debug
|
||||||
// llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||||
// pm.add(llvm::createVerifierPass());
|
pm.add(llvm::createVerifierPass());
|
||||||
// pm.run(*module);
|
pm.run(*module);
|
||||||
// create machine
|
// create machine
|
||||||
module->setTargetTriple(triple);
|
module->setTargetTriple(triple);
|
||||||
std::string error;
|
std::string error;
|
||||||
|
Reference in New Issue
Block a user