[examples] added basic skeleton to generate matrix multiplication PTX
@@ -12,9 +12,10 @@ include_directories(${BISON_Parser_INCLUDE_DIRECTORIES})
# LLVM
find_package(LLVM REQUIRED CONFIG)
message(STATUS ${LLVM_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
llvm_map_components_to_libnames(llvm_libs support core irreader)
llvm_map_components_to_libnames(llvm_libs support core irreader MC NVPTXCodeGen all)

#Default build type
if(NOT CMAKE_BUILD_TYPE)
@@ -33,6 +34,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
# TDL
file(GLOB_RECURSE LIBTDL_SRC lib/*.cpp)
add_library(tdl SHARED ${LIBTDL_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
message(STATUS ${llvm_libs})
target_link_libraries(tdl ${llvm_libs})
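Note on the component list above: MC and NVPTXCodeGen pull in LLVM's machine-code layer and the NVPTX backend that the new example needs for PTX emission (the trailing "all" links every component anyway, so the explicit names mostly serve as documentation). As an illustrative sketch, not part of this commit, those libraries are what make a narrower target registration possible instead of the InitializeAll* helpers used in the example:

    // Sketch (assumption): register only the NVPTX target.
    #include "llvm/Support/TargetSelect.h"

    void init_nvptx_only() {
      LLVMInitializeNVPTXTargetInfo();
      LLVMInitializeNVPTXTarget();
      LLVMInitializeNVPTXTargetMC();
      LLVMInitializeNVPTXAsmPrinter();
    }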
# Examples
@@ -13,6 +13,12 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/LegacyPassManager.h"

typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
@@ -36,7 +42,7 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
for(k = K; k >= 0; k = k - 8){\
fp32 a[32, 8] = *pa;\
fp32 b[32, 8] = *pb;\
C = C + 1;\
C = dot(a,b,C);\
pa = pa + 8*M;\
pb = pb + 8*K;\
}\
@@ -44,6 +50,16 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
}\
";

static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) {
  std::string Ret = "e";
  if (!is64Bit)
    Ret += "-p:32:32";
  else if (UseShortPointers)
    Ret += "-p3:32:32-p4:32:32-p5:32:32";
  Ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
  return Ret;
}
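Worked out for the computeDataLayout(true, true) call used in main() below, the resulting layout string is:

    e-p3:32:32-p4:32:32-p5:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64

that is, little-endian, 64-bit generic pointers (the -p:32:32 override is skipped in the 64-bit case), 32-bit pointers in the shared, const, and local address spaces (3, 4, 5), and the usual NVPTX integer and vector alignments.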

int main() {
  YY_BUFFER_STATE buffer = yy_scan_string(src);
  yyparse();
@@ -86,12 +102,37 @@ int main() {
  liveness.run(module);
  allocation.run();
  selection.run(module, llvm_module);
  // std::vector<unsigned*> params = tune.get_params(module);
  // std::cout << params.size() << std::endl;
  // selection.run(module, llvm_module);
  // print LLVM program
  llvm::PrintModulePass print(llvm::outs());
  llvm::AnalysisManager<llvm::Module> analysis;
  print.run(llvm_module, analysis);

  // // print LLVM program
  // llvm::PrintModulePass print(llvm::outs());
  // llvm::AnalysisManager<llvm::Module> analysis;
  // print.run(llvm_module, analysis);

  // create target machine
  {
    llvm::InitializeAllTargetInfos();
    llvm::InitializeAllTargets();
    llvm::InitializeAllTargetMCs();
    llvm::InitializeAllAsmParsers();
    llvm::InitializeAllAsmPrinters();

    llvm_module.setTargetTriple("nvptx64-nvidia-cuda");
    std::string error;
    auto target = llvm::TargetRegistry::lookupTarget(llvm_module.getTargetTriple(), error);
    llvm::TargetMachine *machine = target->createTargetMachine(llvm_module.getTargetTriple(), "sm_52", "",
                                                                llvm::TargetOptions(), llvm::Reloc::Model(),
                                                                llvm::None, llvm::CodeGenOpt::Aggressive);
    llvm_module.setDataLayout(computeDataLayout(true, true));

    // emit machine code
    llvm::legacy::PassManager pass;
    llvm::SmallVector<char, 0> buffer;
    llvm::raw_svector_ostream stream(buffer);
    machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
    pass.run(llvm_module);
    std::string src(buffer.begin(), buffer.end());
    std::cout << src << std::endl;
  }

  return 0;
}
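A possible hardening of the emission block above, shown only as a sketch against the same locals (llvm_module, error, machine, pass, stream): the skeleton ignores the error string from lookupTarget and the return value of addPassesToEmitFile, both of which are worth checking before trusting the printed PTX.

    // Sketch (assumption): drop-in checks for the corresponding lines in main().
    const llvm::Target *target =
        llvm::TargetRegistry::lookupTarget(llvm_module.getTargetTriple(), error);
    if (!target) {
      llvm::errs() << "target lookup failed: " << error << "\n";
      return 1;
    }
    // addPassesToEmitFile returns true when the target cannot emit this file type.
    if (machine->addPassesToEmitFile(pass, stream, nullptr,
                                     llvm::TargetMachine::CGFT_AssemblyFile)) {
      llvm::errs() << "NVPTX backend cannot emit assembly\n";
      return 1;
    }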
@@ -441,6 +441,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
      ti->set_value(idx, in->get_value(idx));
    });
  }
  // matrix multiplication
  else if(dynamic_cast<ir::matmul_inst*>(ins)) {
    ir::value *A = ins->get_operand(0);
    ir::value *B = ins->get_operand(1);
    ir::value *C = ins->get_operand(2);
    result->for_each([&](indices_t idx){
      Value *res = tmap_.at(C)->get_value(idx);
      unsigned NK = A->get_type()->get_tile_shapes()[1];
      for(unsigned K = 0; K < NK; ++K){
        indices_t a_idx = {idx[0], builder.getInt32(K)};
        indices_t b_idx = {idx[1], builder.getInt32(K)};
        Value *a = tmap_.at(A)->get_value(a_idx);
        Value *b = tmap_.at(B)->get_value(b_idx);
        res = builder.CreateAdd(res, builder.CreateMul(a, b));
      }
      result->set_value(idx, res);
    });
  }
  // element-wise
  else {
    result->for_each([&](indices_t idx){
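The matmul lowering above unrolls the reduction dimension NK and emits one multiply-accumulate per output index, reading a along {idx[0], K} and b along {idx[1], K}, i.e. an A times B-transposed accumulation. One caveat worth flagging: CreateAdd and CreateMul build integer IR, so if the tile elements are fp32 as in the example kernel, the floating-point forms would presumably be needed. A minimal sketch, assuming an IRBuilder is at hand:

    #include "llvm/IR/IRBuilder.h"

    // Hypothetical floating-point multiply-accumulate for the inner loop.
    llvm::Value *fp_mac(llvm::IRBuilder<> &builder, llvm::Value *a,
                        llvm::Value *b, llvm::Value *acc) {
      return builder.CreateFAdd(acc, builder.CreateFMul(a, b));
    }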
@@ -17,10 +17,12 @@ void place_shared_copy::run(ir::module &mod) {
    builder.set_insert_point(i);
    ir::value *x = i->get_operand(0);
    ir::value *y = i->get_operand(1);
    ir::value *rx = builder.create_copy_to_shared(x);
    ir::value *ry = builder.create_copy_to_shared(y);
    ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
    ir::instruction *ry = (ir::instruction*)builder.create_copy_to_shared(y);
    x->replace_all_uses_with(rx);
    y->replace_all_uses_with(ry);
    rx->set_operand(0, x);
    ry->set_operand(0, y);
  }
}
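The change in this hunk is the type of rx and ry: create_copy_to_shared hands back the new copy as an ir::value, but the pass needs it as an ir::instruction so that set_operand can be called on it. The ordering matters: replace_all_uses_with(rx) redirects every use of x, including the operand of the freshly created copy itself, so the following set_operand(0, x) restores the copy's source. Roughly:

    // Before the pass:        After place_shared_copy (intended):
    //   u = op(x, y)            rx = copy_to_shared(x)
    //                           ry = copy_to_shared(y)
    //                           u  = op(rx, ry)
    //
    // replace_all_uses_with(rx) alone would also turn rx into
    // copy_to_shared(rx); set_operand(0, x) undoes that self-reference.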