[examples] added basic skeleton to generate matrix multiplication PTX
This commit is contained in:
@@ -12,9 +12,10 @@ include_directories(${BISON_Parser_INCLUDE_DIRECTORIES})
|
|||||||
|
|
||||||
# LLVM
|
# LLVM
|
||||||
find_package(LLVM REQUIRED CONFIG)
|
find_package(LLVM REQUIRED CONFIG)
|
||||||
|
message(STATUS ${LLVM_INCLUDE_DIRS})
|
||||||
include_directories(${LLVM_INCLUDE_DIRS})
|
include_directories(${LLVM_INCLUDE_DIRS})
|
||||||
add_definitions(${LLVM_DEFINITIONS})
|
add_definitions(${LLVM_DEFINITIONS})
|
||||||
llvm_map_components_to_libnames(llvm_libs support core irreader)
|
llvm_map_components_to_libnames(llvm_libs support core irreader MC NVPTXCodeGen all)
|
||||||
|
|
||||||
#Default build type
|
#Default build type
|
||||||
if(NOT CMAKE_BUILD_TYPE)
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
@@ -33,6 +34,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
|
|||||||
# TDL
|
# TDL
|
||||||
file(GLOB_RECURSE LIBTDL_SRC lib/*.cpp)
|
file(GLOB_RECURSE LIBTDL_SRC lib/*.cpp)
|
||||||
add_library(tdl SHARED ${LIBTDL_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
|
add_library(tdl SHARED ${LIBTDL_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
|
||||||
|
message(STATUS ${llvm_libs})
|
||||||
target_link_libraries(tdl ${llvm_libs})
|
target_link_libraries(tdl ${llvm_libs})
|
||||||
|
|
||||||
# Examples
|
# Examples
|
||||||
|
@@ -13,6 +13,12 @@
|
|||||||
#include "llvm/IR/LLVMContext.h"
|
#include "llvm/IR/LLVMContext.h"
|
||||||
#include "llvm/IR/PassManager.h"
|
#include "llvm/IR/PassManager.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include "llvm/Support/TargetRegistry.h"
|
||||||
|
#include "llvm/Support/TargetSelect.h"
|
||||||
|
#include "llvm/Target/TargetMachine.h"
|
||||||
|
#include "llvm/Target/TargetOptions.h"
|
||||||
|
#include "llvm/CodeGen/TargetPassConfig.h"
|
||||||
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
|
|
||||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||||
extern int yyparse();
|
extern int yyparse();
|
||||||
@@ -36,7 +42,7 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
|||||||
for(k = K; k >= 0; k = k - 8){\
|
for(k = K; k >= 0; k = k - 8){\
|
||||||
fp32 a[32, 8] = *pa;\
|
fp32 a[32, 8] = *pa;\
|
||||||
fp32 b[32, 8] = *pb;\
|
fp32 b[32, 8] = *pb;\
|
||||||
C = C + 1;\
|
C = dot(a,b,C);\
|
||||||
pa = pa + 8*M;\
|
pa = pa + 8*M;\
|
||||||
pb = pb + 8*K;\
|
pb = pb + 8*K;\
|
||||||
}\
|
}\
|
||||||
@@ -44,6 +50,16 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
|||||||
}\
|
}\
|
||||||
";
|
";
|
||||||
|
|
||||||
|
/// Builds the LLVM data-layout string used for the generated module.
///
/// The result is assembled from three pieces, in order:
///   1. "e"                                   (always)
///   2. an optional pointer-size component:
///        - "-p:32:32"                        when is64Bit is false
///        - "-p3:32:32-p4:32:32-p5:32:32"     when is64Bit is true and
///                                            UseShortPointers is true
///        - nothing                           otherwise
///   3. "-i64:64-i128:128-v16:16-v32:32-n16:32:64"   (always)
///
/// @param is64Bit          selects between the 32-bit pointer component and
///                         the 64-bit variants described above.
/// @param UseShortPointers only consulted when is64Bit is true; it is ignored
///                         for the 32-bit case.
/// @return the concatenated data-layout specification.
static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) {
  std::string Ret = "e";
  if (!is64Bit)
    Ret += "-p:32:32";
  else if (UseShortPointers)
    // Per-address-space pointer entries (p3/p4/p5) are emitted only in 64-bit mode.
    Ret += "-p3:32:32-p4:32:32-p5:32:32";
  Ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
  return Ret;
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||||
yyparse();
|
yyparse();
|
||||||
@@ -86,12 +102,37 @@ int main() {
|
|||||||
liveness.run(module);
|
liveness.run(module);
|
||||||
allocation.run();
|
allocation.run();
|
||||||
selection.run(module, llvm_module);
|
selection.run(module, llvm_module);
|
||||||
// std::vector<unsigned*> params = tune.get_params(module);
|
|
||||||
// std::cout << params.size() << std::endl;
|
// // print LLVM program
|
||||||
// selection.run(module, llvm_module);
|
// llvm::PrintModulePass print(llvm::outs());
|
||||||
// print LLVM program
|
// llvm::AnalysisManager<llvm::Module> analysis;
|
||||||
llvm::PrintModulePass print(llvm::outs());
|
// print.run(llvm_module, analysis);
|
||||||
llvm::AnalysisManager<llvm::Module> analysis;
|
|
||||||
print.run(llvm_module, analysis);
|
// create target machine
|
||||||
|
{
|
||||||
|
llvm::InitializeAllTargetInfos();
|
||||||
|
llvm::InitializeAllTargets();
|
||||||
|
llvm::InitializeAllTargetMCs();
|
||||||
|
llvm::InitializeAllAsmParsers();
|
||||||
|
llvm::InitializeAllAsmPrinters();
|
||||||
|
|
||||||
|
llvm_module.setTargetTriple("nvptx64-nvidia-cuda");
|
||||||
|
std::string error;
|
||||||
|
auto target = llvm::TargetRegistry::lookupTarget(llvm_module.getTargetTriple(), error);
|
||||||
|
llvm::TargetMachine *machine = target->createTargetMachine(llvm_module.getTargetTriple(), "sm_52", "",
|
||||||
|
llvm::TargetOptions(), llvm::Reloc::Model(),
|
||||||
|
llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||||
|
llvm_module.setDataLayout(computeDataLayout(true, true));
|
||||||
|
|
||||||
|
// emit machine code
|
||||||
|
llvm::legacy::PassManager pass;
|
||||||
|
llvm::SmallVector<char, 0> buffer;
|
||||||
|
llvm::raw_svector_ostream stream(buffer);
|
||||||
|
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
|
||||||
|
pass.run(llvm_module);
|
||||||
|
std::string src(buffer.begin(), buffer.end());
|
||||||
|
std::cout << src << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@@ -441,6 +441,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
ti->set_value(idx, in->get_value(idx));
|
ti->set_value(idx, in->get_value(idx));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
// matrix multiplication
|
||||||
|
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
|
||||||
|
ir::value *A = ins->get_operand(0);
|
||||||
|
ir::value *B = ins->get_operand(1);
|
||||||
|
ir::value *C = ins->get_operand(2);
|
||||||
|
result->for_each([&](indices_t idx){
|
||||||
|
Value *res = tmap_.at(C)->get_value(idx);
|
||||||
|
unsigned NK = A->get_type()->get_tile_shapes()[1];
|
||||||
|
for(unsigned K = 0; K < NK; ++K){
|
||||||
|
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||||
|
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||||
|
Value *a = tmap_.at(A)->get_value(a_idx);
|
||||||
|
Value *b = tmap_.at(B)->get_value(b_idx);
|
||||||
|
res = builder.CreateAdd(res, builder.CreateMul(a, b));
|
||||||
|
}
|
||||||
|
result->set_value(idx, res);
|
||||||
|
});
|
||||||
|
}
|
||||||
// element-wise
|
// element-wise
|
||||||
else {
|
else {
|
||||||
result->for_each([&](indices_t idx){
|
result->for_each([&](indices_t idx){
|
||||||
|
@@ -17,10 +17,12 @@ void place_shared_copy::run(ir::module &mod) {
|
|||||||
builder.set_insert_point(i);
|
builder.set_insert_point(i);
|
||||||
ir::value *x = i->get_operand(0);
|
ir::value *x = i->get_operand(0);
|
||||||
ir::value *y = i->get_operand(1);
|
ir::value *y = i->get_operand(1);
|
||||||
ir::value *rx = builder.create_copy_to_shared(x);
|
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
|
||||||
ir::value *ry = builder.create_copy_to_shared(y);
|
ir::instruction *ry = (ir::instruction*)builder.create_copy_to_shared(y);
|
||||||
x->replace_all_uses_with(rx);
|
x->replace_all_uses_with(rx);
|
||||||
y->replace_all_uses_with(ry);
|
y->replace_all_uses_with(ry);
|
||||||
|
rx->set_operand(0, x);
|
||||||
|
ry->set_operand(0, y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user