[examples] added basic skeleton to generate matrix multiplication PTX

This commit is contained in:
Philippe Tillet
2019-02-07 22:42:54 -05:00
parent 1b9a7a8e97
commit dd35277858
4 changed files with 74 additions and 11 deletions

View File

@@ -12,9 +12,10 @@ include_directories(${BISON_Parser_INCLUDE_DIRECTORIES})
# LLVM
find_package(LLVM REQUIRED CONFIG)
message(STATUS ${LLVM_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
llvm_map_components_to_libnames(llvm_libs support core irreader)
llvm_map_components_to_libnames(llvm_libs support core irreader MC NVPTXCodeGen all)
#Default build type
if(NOT CMAKE_BUILD_TYPE)
@@ -33,6 +34,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
# TDL
file(GLOB_RECURSE LIBTDL_SRC lib/*.cpp)
add_library(tdl SHARED ${LIBTDL_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
message(STATUS ${llvm_libs})
target_link_libraries(tdl ${llvm_libs})
# Examples

View File

@@ -13,6 +13,12 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/LegacyPassManager.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
@@ -36,7 +42,7 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
for(k = K; k >= 0; k = k - 8){\
fp32 a[32, 8] = *pa;\
fp32 b[32, 8] = *pb;\
C = C + 1;\
C = dot(a,b,C);\
pa = pa + 8*M;\
pb = pb + 8*K;\
}\
@@ -44,6 +50,16 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
}\
";
static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) {
std::string Ret = "e";
if (!is64Bit)
Ret += "-p:32:32";
else if (UseShortPointers)
Ret += "-p3:32:32-p4:32:32-p5:32:32";
Ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
return Ret;
}
int main() {
YY_BUFFER_STATE buffer = yy_scan_string(src);
yyparse();
@@ -86,12 +102,37 @@ int main() {
liveness.run(module);
allocation.run();
selection.run(module, llvm_module);
// std::vector<unsigned*> params = tune.get_params(module);
// std::cout << params.size() << std::endl;
// selection.run(module, llvm_module);
// print LLVM program
llvm::PrintModulePass print(llvm::outs());
llvm::AnalysisManager<llvm::Module> analysis;
print.run(llvm_module, analysis);
// // print LLVM program
// llvm::PrintModulePass print(llvm::outs());
// llvm::AnalysisManager<llvm::Module> analysis;
// print.run(llvm_module, analysis);
// create target machine
{
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
llvm_module.setTargetTriple("nvptx64-nvidia-cuda");
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(llvm_module.getTargetTriple(), error);
llvm::TargetMachine *machine = target->createTargetMachine(llvm_module.getTargetTriple(), "sm_52", "",
llvm::TargetOptions(), llvm::Reloc::Model(),
llvm::None, llvm::CodeGenOpt::Aggressive);
llvm_module.setDataLayout(computeDataLayout(true, true));
// emit machine code
llvm::legacy::PassManager pass;
llvm::SmallVector<char, 0> buffer;
llvm::raw_svector_ostream stream(buffer);
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
pass.run(llvm_module);
std::string src(buffer.begin(), buffer.end());
std::cout << src << std::endl;
}
return 0;
}

View File

@@ -441,6 +441,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ti->set_value(idx, in->get_value(idx));
});
}
// matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2);
result->for_each([&](indices_t idx){
Value *res = tmap_.at(C)->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1];
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {idx[1], builder.getInt32(K)};
Value *a = tmap_.at(A)->get_value(a_idx);
Value *b = tmap_.at(B)->get_value(b_idx);
res = builder.CreateAdd(res, builder.CreateMul(a, b));
}
result->set_value(idx, res);
});
}
// element-wise
else {
result->for_each([&](indices_t idx){

View File

@@ -17,10 +17,12 @@ void place_shared_copy::run(ir::module &mod) {
builder.set_insert_point(i);
ir::value *x = i->get_operand(0);
ir::value *y = i->get_operand(1);
ir::value *rx = builder.create_copy_to_shared(x);
ir::value *ry = builder.create_copy_to_shared(y);
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
ir::instruction *ry = (ir::instruction*)builder.create_copy_to_shared(y);
x->replace_all_uses_with(rx);
y->replace_all_uses_with(ry);
rx->set_operand(0, x);
ry->set_operand(0, y);
}
}