TDL [Parser]: Initial commit
This commit is contained in:
@@ -1,21 +1,7 @@
|
||||
cmake_minimum_required(VERSION 2.8.7)
project(TDL)

# Root of the llvm-tlvm checkout. The build tree is expected in <root>/build.
# Cached so that it can be overridden with -DTLVM_ROOT=<path>; the default
# keeps the location that used to be hard-coded below.
set(TLVM_ROOT "/home/philippe/Development/llvm-tlvm" CACHE PATH
    "Root of the llvm-tlvm source tree (build tree expected in <root>/build)")

# LLVM
include(cmake/FindLLVM.cmake)

# Link directories
link_directories(${TLVM_ROOT}/build/lib)

# Include directories
include_directories(${TLVM_ROOT}/include)
include_directories(${TLVM_ROOT}/build/include)

# Flags
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Wextra -pedantic -Wno-strict-aliasing")

# Executables
# One executable per listed program, built from <name>.cpp and linked
# against the libraries reported by FindLLVM.
foreach(PROG gemm conv)
    add_executable(${PROG} ${PROG}.cpp)
    set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
    target_link_libraries(${PROG} ${LLVM_LIBRARIES})
endforeach()
|
||||
# Parser and lexer generation.
# Both tools are mandatory: without them BISON_TARGET / FLEX_TARGET are not
# even defined, so fail early with a clear message.
find_package(BISON REQUIRED)
find_package(FLEX REQUIRED)

BISON_TARGET(Parser parser.y ${CMAKE_CURRENT_BINARY_DIR}/parser.cpp)
FLEX_TARGET(Lexer scanner.l ${CMAKE_CURRENT_BINARY_DIR}/scanner.cpp)
# Make the generated scanner depend on the Bison output so the parser header
# exists before the scanner is compiled.
ADD_FLEX_BISON_DEPENDENCY(Lexer Parser)

# The generated parser header lives next to parser.cpp; expose its directory.
# PATH is the spelling of DIRECTORY that also works on CMake <= 2.8.11.
get_filename_component(BISON_Parser_INCLUDE_DIRECTORIES ${BISON_Parser_OUTPUT_HEADER} PATH)
include_directories(${BISON_Parser_INCLUDE_DIRECTORIES})

# NOTE(review): "test" is a reserved target name under policy CMP0037 in
# newer CMake; the name is kept here so existing build commands still work.
add_executable(test main.cpp ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
|
||||
|
295
ast.h
Normal file
295
ast.h
Normal file
@@ -0,0 +1,295 @@
|
||||
#include "parser.hpp"
|
||||
#include <cassert>
|
||||
#include <list>
|
||||
#include <string>
|
||||
|
||||
// Short name for the parser's token enumeration (yytokentype is expected to
// come from parser.hpp, included above).
using token_type = yytokentype;
|
||||
|
||||
namespace ast{
|
||||
|
||||
// Common base class of every AST node.
// The destructor is virtual so that a node owned through a node* can be
// destroyed correctly whatever its concrete type is.
class node {
public:
  virtual ~node() = default;
};
|
||||
|
||||
template<class T>
|
||||
class list: public node {
|
||||
public:
|
||||
list(const T& x): values_{x} {}
|
||||
node* append(const T& x) { values_.push_back(x); return this;}
|
||||
|
||||
private:
|
||||
std::list<T> values_;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
node* append_ptr_list(node *result, node *in){
|
||||
return static_cast<list<T*>*>(result)->append((T*)in);
|
||||
}
|
||||
|
||||
class binary_operator: public node{
|
||||
public:
|
||||
enum OP_T{
|
||||
MUL, DIV, REM,
|
||||
ADD, SUB,
|
||||
LEFT_SHIFT, RIGHT_SHIFT,
|
||||
LT, GT,
|
||||
LE, GE,
|
||||
EQ, NE,
|
||||
AND, XOR, OR,
|
||||
LAND, LOR
|
||||
};
|
||||
|
||||
static OP_T get_op(token_type token){
|
||||
switch(token){
|
||||
case LEFT_OP: return LEFT_SHIFT;
|
||||
case RIGHT_OP: return RIGHT_SHIFT;
|
||||
case LE_OP: return LE;
|
||||
case GE_OP: return GE;
|
||||
case EQ_OP: return EQ;
|
||||
case NE_OP: return NE;
|
||||
case AND_OP: return LAND;
|
||||
case OR_OP: return LOR;
|
||||
default: assert(false && "unreachable"); throw;
|
||||
}
|
||||
}
|
||||
|
||||
static OP_T get_op(char token){
|
||||
switch(token){
|
||||
case '*': return MUL;
|
||||
case '/': return DIV;
|
||||
case '%': return REM;
|
||||
case '+': return ADD;
|
||||
case '-': return SUB;
|
||||
case '<': return LT;
|
||||
case '>': return GT;
|
||||
case '&': return AND;
|
||||
case '^': return XOR;
|
||||
case '|': return OR;
|
||||
default: assert(false && "unreachable"); throw;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
binary_operator(token_type op, node *lhs, node *rhs)
|
||||
: op_(get_op(op)), lhs_(lhs), rhs_(rhs) { }
|
||||
binary_operator(char op, node *lhs, node *rhs)
|
||||
: op_(get_op(op)), lhs_(lhs), rhs_(rhs){ }
|
||||
|
||||
private:
|
||||
const OP_T op_;
|
||||
const node *lhs_;
|
||||
const node *rhs_;
|
||||
};
|
||||
|
||||
|
||||
// AST node holding an integer constant.
class constant: public node{
public:
  // Stores `value` unchanged.
  constant(int value)
    : value_{value}
  { }

private:
  const int value_;
};
|
||||
|
||||
// AST node holding a name.
class identifier: public node{
public:
  // Copies the C string `name` points to; the pointer itself is not kept.
  identifier(char *&name)
    : name_(name)
  { }

private:
  std::string name_;
};
|
||||
|
||||
// AST node holding the text of a string.
class string_literal: public node{
public:
  // Copies the C string `value` points to; the pointer itself is not kept.
  string_literal(char *&value)
    : value_(value)
  { }

  // Stored text (publicly accessible).
  std::string value_;
};
|
||||
|
||||
class unary_operator: public node{
|
||||
public:
|
||||
unary_operator(token_type token, node *arg): token_(token), arg_(arg) { }
|
||||
|
||||
private:
|
||||
const token_type token_;
|
||||
const node *arg_;
|
||||
};
|
||||
|
||||
class cast_operator: public node{
|
||||
public:
|
||||
cast_operator(token_type type, node *arg): type_(type), arg_(arg) { }
|
||||
|
||||
public:
|
||||
const token_type type_;
|
||||
const node *arg_;
|
||||
};
|
||||
|
||||
class conditional_expression: public node{
|
||||
public:
|
||||
conditional_expression(node *cond, node *true_value, node *false_value)
|
||||
: cond_(cond), true_value_(true_value), false_value_(false_value) { }
|
||||
|
||||
public:
|
||||
const node *cond_;
|
||||
const node *true_value_;
|
||||
const node *false_value_;
|
||||
};
|
||||
|
||||
class assignment_expression: public node{
|
||||
typedef binary_operator::OP_T op_t;
|
||||
|
||||
public:
|
||||
assignment_expression(node *lvalue, token_type op, node *rvalue)
|
||||
: lvalue_(lvalue), op_(binary_operator::get_op(op)), rvalue_(rvalue) { }
|
||||
|
||||
public:
|
||||
op_t op_;
|
||||
const node *lvalue_;
|
||||
const node *rvalue_;
|
||||
};
|
||||
|
||||
class compound_statement: public node{
|
||||
public:
|
||||
compound_statement() : statements_() {}
|
||||
compound_statement(node *stmt): statements_{stmt} {}
|
||||
compound_statement* append(node *stmt) { statements_.push_back(stmt); return this; }
|
||||
|
||||
private:
|
||||
std::list<node*> statements_;
|
||||
};
|
||||
|
||||
class selection_statement: public node{
|
||||
public:
|
||||
selection_statement(node *cond, node *if_value, node *else_value = nullptr)
|
||||
: cond_(cond), if_value_(if_value), else_value_(else_value) { }
|
||||
|
||||
public:
|
||||
const node *cond_;
|
||||
const node *if_value_;
|
||||
const node *else_value_;
|
||||
};
|
||||
|
||||
class iteration_statement: public node{
|
||||
public:
|
||||
iteration_statement(node *init, node *stop, node *exec, node *statements)
|
||||
: init_(init), stop_(stop), exec_(exec), statements_(statements) { }
|
||||
|
||||
private:
|
||||
const node *init_;
|
||||
const node *stop_;
|
||||
const node *exec_;
|
||||
const node *statements_;
|
||||
};
|
||||
|
||||
// Node that carries no data.
class no_op: public node { };
|
||||
|
||||
// Types
|
||||
// Common base of the declarator node classes; adds nothing to node itself.
class declarator: public node { };
|
||||
|
||||
// Declarator that tracks a pointer order (an unsigned counter).
class pointer_declarator: public declarator{
public:
  // Starts with the given order.
  pointer_declarator(unsigned order)
    : order_{order}
  { }

  // Raises the order by one and returns this object so calls can be chained.
  pointer_declarator *inc(){
    ++order_;
    return this;
  }

private:
  unsigned order_;
};
|
||||
|
||||
class tile_declarator: public declarator{
|
||||
public:
|
||||
tile_declarator(node *shapes)
|
||||
: shapes_((list<constant*>*)(shapes)) { }
|
||||
|
||||
public:
|
||||
const list<constant*>* shapes_;
|
||||
};
|
||||
|
||||
class parameter: public declarator {
|
||||
public:
|
||||
parameter(token_type type, node *decl)
|
||||
: type_(type), decl_(decl) { }
|
||||
|
||||
public:
|
||||
const token_type type_;
|
||||
const node *decl_;
|
||||
};
|
||||
|
||||
class function_declarator: public declarator{
|
||||
public:
|
||||
function_declarator(node *args)
|
||||
: args_((list<node*>)args) { }
|
||||
|
||||
public:
|
||||
const list<node*> args_;
|
||||
};
|
||||
|
||||
class compound_declarator: public declarator{
|
||||
public:
|
||||
compound_declarator(node *ptr, node *tile)
|
||||
: ptr_(ptr), tile_(tile) { }
|
||||
|
||||
public:
|
||||
const node *ptr_;
|
||||
const node *tile_;
|
||||
};
|
||||
|
||||
class init_declarator : public declarator{
|
||||
public:
|
||||
init_declarator(node *decl, node *initializer)
|
||||
: decl_(decl), initializer_(initializer){ }
|
||||
|
||||
public:
|
||||
const node *decl_;
|
||||
const node *initializer_;
|
||||
};
|
||||
|
||||
class declaration: public node{
|
||||
public:
|
||||
declaration(node *spec, node *init)
|
||||
: spec_(spec), init_(init) { }
|
||||
|
||||
public:
|
||||
const node *spec_;
|
||||
const node *init_;
|
||||
};
|
||||
|
||||
class type: public node{
|
||||
public:
|
||||
type(token_type spec, node * decl)
|
||||
: spec_(spec), decl_(decl) { }
|
||||
|
||||
public:
|
||||
const token_type spec_;
|
||||
const node *decl_;
|
||||
};
|
||||
|
||||
class translation_unit: public node{
|
||||
public:
|
||||
translation_unit(node *item)
|
||||
: decls_(item) { }
|
||||
|
||||
translation_unit *add(node *item) {
|
||||
decls_.append(item);
|
||||
return this;
|
||||
}
|
||||
|
||||
private:
|
||||
list<node*> decls_;
|
||||
};
|
||||
|
||||
class function_definition: public node{
|
||||
public:
|
||||
function_definition(node *header, node *body)
|
||||
: header_((declarator *)header), body_((compound_statement*)body) { }
|
||||
|
||||
public:
|
||||
const declarator *header_;
|
||||
const compound_statement *body_;
|
||||
};
|
||||
|
||||
}
|
@@ -1,88 +0,0 @@
|
||||
# - Find LLVM
|
||||
# This module can be used to find LLVM.
|
||||
# It requires that the llvm-config executable be available on the system path.
|
||||
# Once found, llvm-config is used for everything else.
|
||||
#
|
||||
# Typical usage could be:
|
||||
# find_package(LLVM QUIET REQUIRED COMPONENTS jit native interpreter)
|
||||
#
|
||||
# If the QUIET flag is not set, the specified components and LLVM version are
|
||||
# outputted.
|
||||
#
|
||||
# If the COMPONENTS are not set, the default set of "all" is used.
|
||||
#
|
||||
# The following variables are set:
|
||||
#
|
||||
# LLVM_FOUND - Set to YES if LLVM is found.
|
||||
# LLVM_VERSION - Set to the decimal version of the LLVM library.
|
||||
# LLVM_C_FLAGS - All flags that should be passed to a C compiler.
|
||||
# LLVM_CXX_FLAGS - All flags that should be passed to a C++ compiler.
|
||||
# LLVM_CPP_FLAGS - All flags that should be passed to the C pre-processor.
|
||||
# LLVM_LD_FLAGS - Additional flags to pass to the linker.
|
||||
# LLVM_LIBRARY_DIRS - A list of directories where the LLVM libraries are located.
|
||||
# LLVM_INCLUDE_DIRS - A list of directories where the LLVM headers are located.
|
||||
# LLVM_LIBRARIES - A list of libraries which should be linked against.
|
||||
|
||||
# A macro to run llvm-config.
#   _var_name - name of the variable that receives llvm-config's output
#   ARGN      - arguments forwarded to llvm-config
# If llvm-config cannot be located, <_var_name> is set to
# "<_var_name>-NOTFOUND". If llvm-config exits with a non-zero status, an
# error is reported through message(SEND_ERROR ...).
macro(_llvm_config _var_name)
    # Firstly, locate the LLVM config executable
    find_program(_llvm_config_exe
        NAMES llvm-config
        PATHS /home/philippe/Development/llvm-tlvm/build/bin/
        DOC "llvm-config executable location"
    )

    # If no llvm-config executable was found, set the output variable to not
    # found.
    if(NOT _llvm_config_exe)
        set(${_var_name} "${_var_name}-NOTFOUND")
    else(NOT _llvm_config_exe)
        # Otherwise, run llvm-config
        execute_process(
            COMMAND ${_llvm_config_exe} ${ARGN}
            OUTPUT_VARIABLE ${_var_name}
            RESULT_VARIABLE _llvm_config_retval
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        # The exit status lands in _llvm_config_retval (named by
        # RESULT_VARIABLE above); a non-zero value means the call failed.
        if(_llvm_config_retval)
            message(SEND_ERROR
                "Error running llvm-config with arguments: ${ARGN}")
        endif(_llvm_config_retval)
    endif(NOT _llvm_config_exe)
endmacro(_llvm_config)
|
||||
|
||||
# The default set of components
set(_llvm_components all)

# If components have been specified via find_package, use them
if(LLVM_FIND_COMPONENTS)
    set(_llvm_components ${LLVM_FIND_COMPONENTS})
endif(LLVM_FIND_COMPONENTS)

if(NOT LLVM_FIND_QUIETLY)
    message(STATUS "Looking for LLVM components: ${_llvm_components}")
endif(NOT LLVM_FIND_QUIETLY)

# Query llvm-config once per exported variable.
_llvm_config(LLVM_VERSION --version)
_llvm_config(LLVM_C_FLAGS --cflags)
_llvm_config(LLVM_CXX_FLAGS --cxxflags)
_llvm_config(LLVM_CPP_FLAGS --cppflags)
_llvm_config(LLVM_LD_FLAGS --ldflags)
_llvm_config(LLVM_LIBRARY_DIRS --libdir)
_llvm_config(LLVM_INCLUDE_DIRS --includedir)
# Only the library list depends on the requested components; with the
# default ("all") this asks for the full set.
_llvm_config(LLVM_LIBRARIES --libs ${_llvm_components})

if(NOT LLVM_FIND_QUIETLY)
    message(STATUS "Found LLVM version: ${LLVM_VERSION}")
endif(NOT LLVM_FIND_QUIETLY)

# handle the QUIETLY and REQUIRED arguments and set LLVM_FOUND to TRUE if
# all listed variables are TRUE
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(LLVM
    DEFAULT_MSG
    LLVM_LIBRARIES
    LLVM_INCLUDE_DIRS
    LLVM_LIBRARY_DIRS)
|
||||
|
||||
# vim:sw=4:ts=4:autoindent
|
456
conv.cpp
456
conv.cpp
@@ -1,456 +0,0 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/Host.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/CodeGen/TargetPassConfig.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
|
||||
// Index computation
// Turns the 5-D coordinate (x, y, z, w, u) into a linear offset, with u
// varying fastest. s1..s4 are the extents of the y, z, w and u dimensions;
// the extent of the x dimension is accepted but not needed.
inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
                   int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
{
  int32_t offset = u;
  offset += w*s4;
  offset += z*s4*s3;
  offset += y*s4*s3*s2;
  offset += x*s4*s3*s2*s1;
  return offset;
}
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpp_conv_nchw(int32_t C, int32_t N, int32_t K,
|
||||
int32_t D, int32_t H, int32_t W,
|
||||
int32_t T, int32_t R, int32_t S,
|
||||
int32_t pad_d, int32_t pad_h, int32_t pad_w,
|
||||
int32_t stride_d, int32_t stride_h, int32_t stride_w,
|
||||
int32_t M, int32_t P, int32_t Q,
|
||||
std::vector<std::vector<OUT_DTYPE>>& O, IN_DTYPE* I, IN_DTYPE* F)
|
||||
{
|
||||
size_t num_outputs = O.size();
|
||||
static const int PACK_IN = 1;
|
||||
static const int PACK_OUT = 1;
|
||||
if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4");
|
||||
if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4");
|
||||
C /= PACK_IN;
|
||||
K /= PACK_OUT;
|
||||
int32_t Kout = K;
|
||||
IN_DTYPE accs[PACK_OUT];
|
||||
for(size_t o = 0; o < num_outputs; o++)
|
||||
for(int32_t m = 0 ; m < M; ++m)
|
||||
for(int32_t p = 0 ; p < P; ++p)
|
||||
for(int32_t q = 0; q < Q; ++q)
|
||||
for(int32_t n = 0; n < N; ++n)
|
||||
for(int32_t k = 0; k < Kout ; ++k)
|
||||
{
|
||||
for(int32_t i = 0 ; i < PACK_OUT; ++i)
|
||||
accs[i] = 0;
|
||||
int32_t mm = m*stride_d - pad_d;
|
||||
int32_t pp = p*stride_h - pad_h;
|
||||
int32_t qq = q*stride_w - pad_w;
|
||||
for(int32_t kk = 0; kk < PACK_OUT; ++kk)
|
||||
for(int32_t c = 0; c < C; ++c)
|
||||
for(int32_t t = 0; t < T; ++t)
|
||||
for(int32_t r = 0; r < R; ++r)
|
||||
for(int32_t s = 0; s < S; ++s){
|
||||
int32_t d = mm + t;
|
||||
int32_t h = pp + r;
|
||||
int32_t w = qq + s;
|
||||
bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W);
|
||||
IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0;
|
||||
IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)];
|
||||
accs[kk] += i*f;
|
||||
}
|
||||
O[o][idx(n, k, m, p, q, N, K, M, P, Q)] = accs[0];
|
||||
}
|
||||
}
|
||||
|
||||
// Assigns a fixed set of values to the tuning parameters of the function
// named "kernel" in `module`, then checks the target's constraints.
//   machine - target machine providing the tuner and the pass configuration
//   module  - module expected to contain a function called "kernel"
// The process exits with EXIT_FAILURE when the kernel is missing, when more
// parameters are discovered than values are available, or when a constraint
// is violated.
void autotune(llvm::TargetMachine *machine, llvm::Module &module){
  // Target parameters
  std::vector<unsigned> ranges = {
    // asm
    2, 16, 1, 64,
    // bsn
    2, 16, 1, 64,
    // pa
    1, 2, 4, 8,
    // pb
    1, 2, 4,
    // sm
    2, 1, 16, 2, 2, 2
  };

  // Function
  llvm::Function *F = module.getFunction("kernel");
  // getFunction yields null when the module has no such function.
  if(F == nullptr){
    printf("Function \"kernel\" not found in module\n");
    exit(EXIT_FAILURE);
  }

  // Auto-tuning
  llvm::legacy::PassManager pass;
  llvm::TargetPassConfig *pass_config = static_cast<llvm::LLVMTargetMachine*>(machine)->createPassConfig(pass);
  llvm::FunctionPass *tuning_params = pass_config->createTargetTuningParameters();
  tuning_params->runOnFunction(*F);

  // Gather all parameters
  // `unique` filters out parameter slots already seen on an earlier
  // instruction, so each slot is assigned exactly once below.
  llvm::DenseSet<unsigned*> unique;
  llvm::SmallVector<unsigned*, 8> params;
  for(llvm::BasicBlock &bb: *F)
  for(llvm::Instruction &instr: bb){
    // Get tuning parameters for this particular instruction
    std::vector<llvm::TargetTuner::ParamType> instr_params;
    machine->getTargetTuner().getParams(&instr, instr_params);
    for(llvm::TargetTuner::ParamType &param: instr_params){
      // This parameter has not been seen before
      if(unique.insert(param.Value).second){
        std::cout << "PARAM: " << instr.getName().data() << " " << param.Name << std::endl;
        params.push_back(param.Value);
      }
    }
  }

  // Gather all constraints
  std::vector<std::function<bool()>> constraints;
  for(llvm::BasicBlock &bb: *F)
  for(llvm::Instruction &instr: bb)
    machine->getTargetTuner().getConstraints(&instr, constraints);

  // Assign parameters
  std::cout << params.size() << " " << ranges.size() << std::endl;
  // Every discovered parameter needs a value; refuse to read past `ranges`.
  if(params.size() > ranges.size()){
    printf("Not enough values for the kernel parameters\n");
    exit(EXIT_FAILURE);
  }
  for(unsigned i = 0; i < params.size(); i++)
    *params[i] = ranges[i];

  // Verify constraints
  // Every constraint is evaluated, even after one has already failed.
  bool valid = true;
  for(auto &constraint: constraints){
    if(!constraint())
      valid = false;
  }

  if(!valid){
    printf("Invalid kernel parameters\n");
    exit(EXIT_FAILURE);
  }
}
|
||||
|
||||
int main(){
|
||||
std::string error;
|
||||
|
||||
llvm::InitializeAllTargetInfos();
|
||||
llvm::InitializeAllTargets();
|
||||
llvm::InitializeAllTargetMCs();
|
||||
llvm::InitializeAllAsmParsers();
|
||||
llvm::InitializeAllAsmPrinters();
|
||||
|
||||
// Module
|
||||
llvm::LLVMContext context;
|
||||
std::unique_ptr<llvm::Module> module = llvm::make_unique<llvm::Module>("TLVM toy example", context);
|
||||
llvm::IRBuilder<> builder(context);
|
||||
|
||||
unsigned RR = 3, SS = 3;
|
||||
unsigned Nfilt = RR * SS;
|
||||
unsigned block = 8;
|
||||
unsigned nlut = (block + Nfilt - 1)/Nfilt * Nfilt;
|
||||
|
||||
// Globals
|
||||
llvm::Type* bool_t = llvm::Type::getInt1Ty(context);
|
||||
llvm::Type* mask_tile_t = llvm::TileType::get(bool_t, 2);
|
||||
llvm::Type* numeric_t = llvm::Type::getFloatTy(context);
|
||||
llvm::PointerType* numeric_ptr_t = llvm::PointerType::get(numeric_t, 0);
|
||||
llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context);
|
||||
llvm::PointerType* lut_ptr_t = llvm::PointerType::get(int32_t, 4);
|
||||
llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context);
|
||||
|
||||
llvm::Type* tile_t = llvm::TileType::get(numeric_t, 2);
|
||||
llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1);
|
||||
llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2);
|
||||
llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1);
|
||||
llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2);
|
||||
|
||||
llvm::PointerType* tile_ptr_t = llvm::PointerType::get(tile_t, 0);
|
||||
llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t});
|
||||
llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t});
|
||||
llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t});
|
||||
llvm::Function* gtp_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_1d, {int32_slice_t->getPointerTo(4), int32_t->getPointerTo(4), int32_slice_t});
|
||||
llvm::Function* stp_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_1d, {int32_slice_t->getPointerTo(4), int32_slice_t});
|
||||
|
||||
llvm::Function* gtp_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile_ptr_t, numeric_ptr_t, int32_tile_t});
|
||||
llvm::Function* stp_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile_ptr_t, int32_tile_t});
|
||||
llvm::Intrinsic::ID mma_id = llvm::Intrinsic::tlvm_mma_nt;
|
||||
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
|
||||
llvm::Function* outer_and_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
|
||||
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t});
|
||||
|
||||
// Hyperparameters
|
||||
llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0);
|
||||
llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1);
|
||||
llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2);
|
||||
|
||||
// Constants
|
||||
llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0);
|
||||
llvm::Constant *_f0 = llvm::ConstantFP::get(numeric_t, 0);
|
||||
llvm::Constant *_0 = llvm::ConstantTile::get(_f0, {bm, bn});
|
||||
|
||||
// LUT
|
||||
unsigned num_delta = nlut;
|
||||
unsigned num_inc_delta = nlut;
|
||||
unsigned num_masks = nlut;
|
||||
unsigned num_inc_masks = nlut;
|
||||
unsigned cst_size = num_delta + num_inc_delta + num_masks + num_inc_masks;
|
||||
llvm::GlobalVariable *lut_array =
|
||||
new llvm::GlobalVariable(*module, llvm::ArrayType::get(int32_t, cst_size), false, llvm::GlobalVariable::InternalLinkage,
|
||||
nullptr, "lut_array", nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
|
||||
llvm::Value *cst_ptr = builder.CreateBitCast(lut_array, lut_ptr_t);
|
||||
|
||||
|
||||
// Function
|
||||
llvm::FunctionType* prototype = llvm::FunctionType::get(llvm::Type::getVoidTy(context), std::vector<llvm::Type*>{numeric_ptr_t, numeric_ptr_t, numeric_ptr_t, int32_t, int32_t, int32_t, int32_t, int32_t}, false);
|
||||
llvm::Function* F = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, "kernel", module.get());
|
||||
std::vector<llvm::Value*> args;
|
||||
F->addAttribute(1, llvm::Attribute::ReadOnly);
|
||||
F->addAttribute(1, llvm::Attribute::NoAlias);
|
||||
F->addAttribute(2, llvm::Attribute::ReadOnly);
|
||||
F->addAttribute(2, llvm::Attribute::NoAlias);
|
||||
std::transform(F->arg_begin(), F->arg_end(), std::back_inserter(args), [&](llvm::Argument& x){ return &x;});
|
||||
llvm::Value *base_pc = args[0], *base_pa = args[1], *base_pb = args[2];
|
||||
llvm::Value *C = args[3], *H = args[4], *W = args[5], *N = args[6], *K = args[7];
|
||||
llvm::Value *R = builder.getInt32(RR), *S = builder.getInt32(SS);
|
||||
|
||||
// All basic blocks
|
||||
llvm::BasicBlock* PrologBB = llvm::BasicBlock::Create(context, "prologue", F);
|
||||
llvm::BasicBlock* LoopBB = llvm::BasicBlock::Create(context, "loop", F);
|
||||
llvm::BasicBlock* EarlyExitBB = llvm::BasicBlock::Create(context, "early_exit", F);
|
||||
llvm::BasicBlock* LastIterBB = llvm::BasicBlock::Create(context, "last_iter", F);
|
||||
llvm::BasicBlock* EpilogueBB = llvm::BasicBlock::Create(context, "epilogue", F);
|
||||
|
||||
|
||||
// First basic block
|
||||
builder.SetInsertPoint(PrologBB);
|
||||
llvm::Value* sa0 = builder.CreateCall(read_slice_x, {bm}, "sa0");
|
||||
llvm::Value* sb0 = builder.CreateCall(read_slice_y, {bn}, "sb0");
|
||||
llvm::Value* sa1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sa1");
|
||||
llvm::Value* sb1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sb1");
|
||||
|
||||
llvm::Value* lda_w = builder.getInt32(1);
|
||||
llvm::Value* lda_h = builder.CreateMul(lda_w, W);
|
||||
llvm::Value* lda_c = builder.CreateMul(lda_h, H);
|
||||
llvm::Value* lda_n = builder.CreateMul(lda_c, C);
|
||||
|
||||
llvm::Value* ldb_s = builder.getInt32(1);
|
||||
llvm::Value* ldb_r = builder.CreateMul(ldb_s, S);
|
||||
llvm::Value* ldb_c = builder.CreateMul(ldb_r, R);
|
||||
llvm::Value* ldb_k = builder.CreateMul(ldb_c, C);
|
||||
|
||||
llvm::Value* CRS = builder.CreateMul(C, builder.CreateMul(R, S));
|
||||
llvm::Value* PQN = builder.CreateMul(H, builder.CreateMul(W, N));
|
||||
|
||||
// Images HWN offset
|
||||
llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {bm, lda_n}));
|
||||
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {bm, lda_h})));
|
||||
offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {bm, lda_w})));
|
||||
// Images CRS offset
|
||||
llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {bk, S}));
|
||||
llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {bk, S}));
|
||||
llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {bk, R}));
|
||||
llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {bk, R}));
|
||||
llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {bk, lda_c}));
|
||||
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {bk, lda_h})));
|
||||
offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {bk, lda_w})));
|
||||
// Images pointer
|
||||
llvm::Value* off_a = builder.CreateCall(outer_add, {offa_0, offa_1});
|
||||
llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_pa, off_a}, "start_pa");
|
||||
llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_aa");
|
||||
llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_a");
|
||||
// Filters pointer
|
||||
llvm::Value* tldb_s = builder.CreateCall(splat_1d, {bk, K});
|
||||
llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_b");
|
||||
llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_pb, off_b}, "start_pb");
|
||||
llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_bb");
|
||||
llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_b");
|
||||
// Filters increment
|
||||
llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_b_0");
|
||||
llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_b_1");
|
||||
llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_b");
|
||||
// Pointers to constant memory
|
||||
llvm::Value* base_incdelta = builder.CreateGEP(cst_ptr, builder.getInt32(0));
|
||||
llvm::Value* base_delta = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_delta));
|
||||
llvm::Value* base_incmask = builder.CreateGEP(cst_ptr, builder.getInt32(num_delta));
|
||||
llvm::Value* base_mask = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_masks));
|
||||
// Delta pointers
|
||||
llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta");
|
||||
llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta");
|
||||
// Masks
|
||||
llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)});
|
||||
llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1);
|
||||
llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0}, "start_pincmask");
|
||||
llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0}, "start_pmask");
|
||||
// Enter loop
|
||||
builder.CreateBr(LoopBB);
|
||||
builder.SetInsertPoint(LoopBB);
|
||||
// PHI nodes
|
||||
llvm::PHINode* c = builder.CreatePHI(_0->getType(), 3, "c");
|
||||
llvm::PHINode* crs = builder.CreatePHI(int32_t, 3, "crs");
|
||||
llvm::PHINode* pa = builder.CreatePHI(start_pa->getType(), 3, "pa");
|
||||
llvm::PHINode* pb = builder.CreatePHI(start_pb->getType(), 3, "pb");
|
||||
llvm::PHINode *a = builder.CreatePHI(start_a->getType(), 3, "a");
|
||||
llvm::PHINode *b = builder.CreatePHI(start_b->getType(), 3, "b");
|
||||
llvm::PHINode *pdelta = builder.CreatePHI(start_pdelta->getType(), 3);
|
||||
llvm::PHINode *pincdelta = builder.CreatePHI(start_pincdelta->getType(), 3);
|
||||
llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3);
|
||||
llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3);
|
||||
llvm::Value* next_c = builder.CreateCall(mma, {a, b, c}, "next_c");
|
||||
c->addIncoming(_0, PrologBB);
|
||||
c->addIncoming(next_c, LoopBB);
|
||||
// Induction variable
|
||||
llvm::Value *next_crs = builder.CreateSub(crs, bk);
|
||||
crs->addIncoming(CRS, PrologBB);
|
||||
crs->addIncoming(next_crs, LoopBB);
|
||||
// Update pointer
|
||||
llvm::Value *inc_delta = builder.CreateLoad(pincdelta);
|
||||
llvm::Value *inc_mask = builder.CreateLoad(pincmasks);
|
||||
llvm::Value *inc_a_1 = builder.CreateLoad(pdelta);
|
||||
llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {bm, builder.getInt32(0)});
|
||||
llvm::Value *inc_a = builder.CreateCall(outer_add, {inc_a_0, inc_a_1});
|
||||
llvm::Value *next_pa = builder.CreateCall(stp_2d, {pa, inc_a}, "next_pa");
|
||||
llvm::Value *next_pb = builder.CreateCall(stp_2d, {pb, inc_b}, "next_pb");
|
||||
llvm::Value *next_pdelta = builder.CreateCall(stp_1d, {pdelta, inc_delta}, "next_pdelta");
|
||||
llvm::Value *next_pincdelta = builder.CreateCall(stp_1d, {pincdelta, inc_delta}, "next_pincdelta");
|
||||
llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask}, "next_pmask");
|
||||
llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask}, "next_pincmask");
|
||||
pdelta->addIncoming(start_pdelta, PrologBB);
|
||||
pdelta->addIncoming(next_pdelta, LoopBB);
|
||||
pincdelta->addIncoming(start_pincdelta, PrologBB);
|
||||
pincdelta->addIncoming(next_pincdelta, LoopBB);
|
||||
pmasks->addIncoming(start_pmask, PrologBB);
|
||||
pmasks->addIncoming(next_pmask, LoopBB);
|
||||
pincmasks->addIncoming(start_pincmask, PrologBB);
|
||||
pincmasks->addIncoming(next_pincmask, LoopBB);
|
||||
pa->addIncoming(start_pa, PrologBB);
|
||||
pa->addIncoming(next_pa, LoopBB);
|
||||
pb->addIncoming(start_pb, PrologBB);
|
||||
pb->addIncoming(next_pb, LoopBB);
|
||||
// End condition
|
||||
llvm::Value* no_bounds_check = builder.CreateICmpSGT(next_crs, builder.getInt32(0), "no_bounds_check");
|
||||
// Masks
|
||||
llvm::Value* mask_a_0 = builder.CreateLoad(pmasks, "mask_a_0");
|
||||
llvm::Value* mask_a_i32 = builder.CreateCall(outer_and_int32, {mask_a_0, mask_a_1}, "mask_a_i32");
|
||||
llvm::Value* mask_a = builder.CreateICmpNE(mask_a_i32, llvm::ConstantTile::get(_s0, {bm, bk}), "mask_a");
|
||||
llvm::Value* mask_b = builder.CreateCall(splat_2d, {bn, bk, no_bounds_check}, "mask_b");
|
||||
// Pre-fetch
|
||||
llvm::Value* next_aa = builder.CreateCall(masked_load, {next_pa, mask_a}, "next_aa");
|
||||
llvm::Value* next_bb = builder.CreateCall(masked_load, {next_pb, mask_b}, "next_bb");
|
||||
llvm::Value* next_a = builder.CreateCall(reshape, {next_aa, bm, bk}, "next_a");
|
||||
llvm::Value* next_b = builder.CreateCall(reshape, {next_bb, bn, bk}, "next_b");
|
||||
a->addIncoming(start_a, PrologBB);
|
||||
a->addIncoming(next_a, LoopBB);
|
||||
b->addIncoming(start_b, PrologBB);
|
||||
b->addIncoming(next_b, LoopBB);
|
||||
// End condition
|
||||
builder.CreateCondBr(no_bounds_check, LoopBB, EarlyExitBB);
|
||||
// Early exit
|
||||
builder.SetInsertPoint(EarlyExitBB);
|
||||
llvm::Value* exit = builder.CreateICmpSLE(next_crs, _s0);
|
||||
builder.CreateCondBr(exit, EpilogueBB, LastIterBB);
|
||||
// Last Iteration
|
||||
builder.SetInsertPoint(LastIterBB);
|
||||
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {bn, K}));
|
||||
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {bk, next_crs}));
|
||||
llvm::Value* last_maskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "last_maskb");
|
||||
llvm::Value* last_bb = builder.CreateCall(masked_load, {next_pb, last_maskb}, "last_bb");
|
||||
llvm::Value* last_b = builder.CreateCall(reshape, {last_bb, bn, bk}, "last_b");
|
||||
llvm::Value* loop = builder.CreateICmpSGT(next_crs, _s0);
|
||||
a->addIncoming(next_a, LastIterBB);
|
||||
b->addIncoming(last_b, LastIterBB);
|
||||
c->addIncoming(next_c, LastIterBB);
|
||||
crs->addIncoming(next_crs, LastIterBB);
|
||||
pa->addIncoming(next_pa, LastIterBB);
|
||||
pb->addIncoming(next_pb, LastIterBB);
|
||||
pdelta->addIncoming(next_pdelta, LastIterBB);
|
||||
pincdelta->addIncoming(next_pincdelta, LastIterBB);
|
||||
pmasks->addIncoming(next_pmask, LastIterBB);
|
||||
pincmasks->addIncoming(next_pincmask, LastIterBB);
|
||||
builder.CreateCondBr(loop, LoopBB, EpilogueBB);
|
||||
|
||||
// Epilogue
|
||||
builder.SetInsertPoint(EpilogueBB);
|
||||
llvm::Value* sc_pqn = builder.CreateCall(read_slice_x, {bm}, "sc_pqn");
|
||||
llvm::Value* sc_k = builder.CreateCall(read_slice_y, {bn}, "sc_k");
|
||||
// Output strides
|
||||
llvm::Value* ldc_q = builder.getInt32(1);
|
||||
llvm::Value* ldc_p = builder.CreateMul(lda_w, W);
|
||||
llvm::Value* ldc_k = builder.CreateMul(lda_h, H);
|
||||
llvm::Value* ldb_n = builder.CreateMul(lda_c, K);
|
||||
// Output PQN offset
|
||||
llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {bm, N}));
|
||||
llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {bm, W}));
|
||||
llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {bm, ldb_n}));
|
||||
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {bm, ldc_p})));
|
||||
offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {bm, ldc_q})));
|
||||
// Output K offset
|
||||
llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {bn, ldc_k}));
|
||||
// Output pointer
|
||||
llvm::Value* offc = builder.CreateCall(outer_add, {offc0, offc1});
|
||||
llvm::Value* pc = builder.CreateCall(gtp_2d, {base_pc, offc});
|
||||
// Output masks
|
||||
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {bm, PQN}));
|
||||
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {bn, K}));
|
||||
llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1});
|
||||
builder.CreateCall(masked_store, {next_c, pc, maskc});
|
||||
builder.CreateRet(NULL);
|
||||
|
||||
|
||||
// Set metadata
|
||||
llvm::Metadata *md_args[] = {
|
||||
llvm::ValueAsMetadata::get(F),
|
||||
llvm::MDString::get(context, "kernel"),
|
||||
llvm::ValueAsMetadata::get(llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1))
|
||||
};
|
||||
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(llvm::MDNode::get(context, md_args));
|
||||
|
||||
// Machine
|
||||
module->setTargetTriple("nvptx64-nvidia-cuda");
|
||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), "sm_52", "",
|
||||
llvm::TargetOptions(), llvm::Reloc::Model(),
|
||||
llvm::CodeModel::Model(), llvm::CodeGenOpt::Aggressive);
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
|
||||
// Auto-tuning
|
||||
autotune(machine, *module);
|
||||
|
||||
// Emit
|
||||
llvm::legacy::PassManager pass;
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
|
||||
pass.run(*module);
|
||||
std::string src(buffer.begin(), buffer.end());
|
||||
|
||||
// Execute
|
||||
std::cout << src << std::endl;
|
||||
}
|
342
gemm.cpp
342
gemm.cpp
@@ -1,342 +0,0 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/Host.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/CodeGen/TargetPassConfig.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
|
||||
|
||||
// Operand layout switches read by main(): together they pick the
// tlvm_mma_{nn,nt,tn,tt} intrinsic, and each one swaps the stride, slice,
// block-size and increment variables of its operand (AT for A, BT for B).
bool AT{false};
bool BT{true};
|
||||
|
||||
|
||||
/// Assigns a fixed list of values to the tuning parameters of the function
/// named "kernel" in `module`, then checks the constraints reported by the
/// target tuner. Terminates the process with EXIT_FAILURE when the kernel is
/// missing, when more parameters are discovered than values are listed, or
/// when a constraint is violated.
///
/// @param machine Target machine; its TargetTuner supplies the per-instruction
///                parameters and constraints.
/// @param module  Module expected to contain the "kernel" function.
void autotune(llvm::TargetMachine *machine, llvm::Module &module){
  // Values written to the discovered parameters, in discovery order.
  // The group labels are kept from the original source.
  const std::vector<unsigned> ranges = {
    // asm
    2, 16, 1, 64,
    // bsn
    2, 16, 1, 64,
    // pa
    1, 2, 4, 8,
    // pb
    1, 2, 4,
    // sm
    2, 1, 16, 2, 2, 2
  };

  // Function
  llvm::Function *F = module.getFunction("kernel");
  if(F == nullptr){
    // getFunction() returns null when the symbol is absent; it is dereferenced below.
    std::cerr << "autotune: module has no function named \"kernel\"" << std::endl;
    exit(EXIT_FAILURE);
  }

  // Auto-tuning
  // NOTE(review): pass_config and tuning_pass are never added to `pass` nor
  // deleted here -- confirm who owns them.
  llvm::legacy::PassManager pass;
  llvm::TargetPassConfig *pass_config = static_cast<llvm::LLVMTargetMachine*>(machine)->createPassConfig(pass);
  llvm::FunctionPass *tuning_pass = pass_config->createTargetTuningParameters();
  tuning_pass->runOnFunction(*F);

  // Gather all parameters, keeping each storage location only once
  llvm::DenseSet<unsigned*> unique;
  llvm::SmallVector<unsigned*, 8> params;
  for(llvm::BasicBlock &bb: *F){
    for(llvm::Instruction &instr: bb){
      // Get tuning parameters for this particular instruction
      std::vector<llvm::TargetTuner::ParamType> instr_params;
      machine->getTargetTuner().getParams(&instr, instr_params);
      for(llvm::TargetTuner::ParamType &param: instr_params){
        // This parameter has not been seen before
        if(unique.insert(param.Value).second){
          std::cout << instr.getName().data() << " " << param.Name << std::endl;
          params.push_back(param.Value);
        }
      }
    }
  }

  // Gather all constraints
  std::vector<std::function<bool()>> constraints;
  for(llvm::BasicBlock &bb: *F)
    for(llvm::Instruction &instr: bb)
      machine->getTargetTuner().getConstraints(&instr, constraints);

  // Assign parameters
  std::cout << params.size() << " " << ranges.size() << std::endl;
  if(params.size() > ranges.size()){
    // Without this guard ranges[i] below would be read past the end of the vector.
    std::cerr << "autotune: " << params.size() << " tuning parameters but only "
              << ranges.size() << " values provided" << std::endl;
    exit(EXIT_FAILURE);
  }
  for(unsigned i = 0; i < params.size(); i++)
    *params[i] = ranges[i];

  // Verify constraints (every constraint is evaluated, no short-circuit)
  bool valid = true;
  for(auto &constraint: constraints)
    valid &= constraint();

  if(!valid){
    printf("Invalid kernel parameters\n");
    exit(EXIT_FAILURE);
  }
}
|
||||
|
||||
int main(){
|
||||
// llvm::DebugFlag = true;
|
||||
|
||||
std::string error;
|
||||
|
||||
llvm::InitializeAllTargetInfos();
|
||||
llvm::InitializeAllTargets();
|
||||
llvm::InitializeAllTargetMCs();
|
||||
llvm::InitializeAllAsmParsers();
|
||||
llvm::InitializeAllAsmPrinters();
|
||||
|
||||
// Module
|
||||
llvm::LLVMContext context;
|
||||
std::unique_ptr<llvm::Module> module = llvm::make_unique<llvm::Module>("TLVM toy example", context);
|
||||
llvm::IRBuilder<> builder(context);
|
||||
|
||||
// Globals
|
||||
llvm::Type* bool_t = llvm::Type::getInt1Ty(context);
|
||||
llvm::Type* mask_tile_t = llvm::TileType::get(bool_t, 2);
|
||||
llvm::Type* numeric_t = llvm::Type::getFloatTy(context);
|
||||
llvm::PointerType* numeric_ptr_t = llvm::PointerType::get(numeric_t, 0);
|
||||
llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context);
|
||||
llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context);
|
||||
|
||||
llvm::Type* tile2d_t = llvm::TileType::get(numeric_t, 2);
|
||||
llvm::Type* tile3d_t = llvm::TileType::get(numeric_t, 3);
|
||||
llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1);
|
||||
llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2);
|
||||
llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1);
|
||||
llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2);
|
||||
|
||||
llvm::PointerType* tile2d_ptr_t = llvm::PointerType::get(tile2d_t, 0);
|
||||
llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t});
|
||||
llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t});
|
||||
llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t});
|
||||
llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile2d_ptr_t, numeric_ptr_t, int32_tile_t});
|
||||
llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile2d_ptr_t, int32_tile_t});
|
||||
llvm::Intrinsic::ID mma_id;
|
||||
if(!AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_nn;
|
||||
if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt;
|
||||
if(AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_tn;
|
||||
if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt;
|
||||
llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t});
|
||||
llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t});
|
||||
llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma, {tile3d_t});
|
||||
llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_3d, {tile3d_t, tile2d_t});
|
||||
llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t});
|
||||
llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t});
|
||||
llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile2d_t, tile2d_ptr_t, mask_tile_t});
|
||||
llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile2d_t, tile2d_ptr_t, mask_tile_t});
|
||||
|
||||
// Hyperparameters
|
||||
llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0);
|
||||
llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1);
|
||||
llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2);
|
||||
llvm::Hyperparameter *br = llvm::Hyperparameter::get(int32_t, 3);
|
||||
|
||||
// Constants
|
||||
llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0);
|
||||
llvm::Constant *_f0 = llvm::ConstantFP::get(numeric_t, 0);
|
||||
llvm::Constant *_0 = llvm::ConstantTile::get(_f0, {bm, bn});
|
||||
|
||||
// Function
|
||||
llvm::FunctionType* prototype = llvm::FunctionType::get(llvm::Type::getVoidTy(context), std::vector<llvm::Type*>{numeric_ptr_t, numeric_ptr_t, numeric_ptr_t, int32_t, int32_t, int32_t, int32_t}, false);
|
||||
llvm::Function* F = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, "kernel", module.get());
|
||||
std::vector<llvm::Value*> arguments;
|
||||
F->addAttribute(1, llvm::Attribute::ReadOnly);
|
||||
F->addAttribute(1, llvm::Attribute::NoAlias);
|
||||
F->addAttribute(2, llvm::Attribute::ReadOnly);
|
||||
F->addAttribute(2, llvm::Attribute::NoAlias);
|
||||
std::transform(F->arg_begin(), F->arg_end(), std::back_inserter(arguments), [&](llvm::Argument& x){ return &x;});
|
||||
arguments[0]->setName("pa");
|
||||
arguments[1]->setName("pb");
|
||||
arguments[2]->setName("pc");
|
||||
arguments[3]->setName("M");
|
||||
arguments[4]->setName("N");
|
||||
arguments[5]->setName("K");
|
||||
arguments[6]->setName("bound");
|
||||
|
||||
// All basic blocks
|
||||
llvm::BasicBlock* PrologBB = llvm::BasicBlock::Create(context, "prologue", F);
|
||||
llvm::BasicBlock* LoopBB = llvm::BasicBlock::Create(context, "loop", F);
|
||||
llvm::BasicBlock* EarlyExitBB = llvm::BasicBlock::Create(context, "early_exit", F);
|
||||
llvm::BasicBlock* LastIterBB = llvm::BasicBlock::Create(context, "last_iter", F);
|
||||
llvm::BasicBlock* EpilogueBB = llvm::BasicBlock::Create(context, "epilogue", F);
|
||||
|
||||
|
||||
// First basic block
|
||||
builder.SetInsertPoint(PrologBB);
|
||||
|
||||
llvm::CallInst* aasm = builder.CreateCall(read_slice_x, {bm}, "asm");
|
||||
llvm::CallInst* bbsn = builder.CreateCall(read_slice_y, {bn}, "bsn");
|
||||
llvm::CallInst* ask = builder.CreateCall(range, {builder.getInt32(0), bk}, "ask");
|
||||
llvm::CallInst* bsk = builder.CreateCall(range, {builder.getInt32(0), bk}, "bsk");
|
||||
|
||||
llvm::Value *M = arguments[3], *N = arguments[4], *K = arguments[5];
|
||||
llvm::Value *bound = arguments[6];
|
||||
llvm::Value *AS0 = M, *AS1 = K;
|
||||
llvm::Value *sa0 = aasm, *sa1 = ask;
|
||||
llvm::Value *ba0 = bm, *ba1 = bk;
|
||||
llvm::Value *inca0 = _s0, *inca1 = bk;
|
||||
if(AT){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(sa0, sa1);
|
||||
std::swap(ba0, ba1);
|
||||
std::swap(inca0, inca1);
|
||||
}
|
||||
llvm::Value *BS0 = K, *BS1 = N;
|
||||
llvm::Value *sb0 = bsk, *sb1 = bbsn;
|
||||
llvm::Value *bb0 = bk, *bb1 = bn;
|
||||
llvm::Value *incb0 = bk, *incb1 = _s0;
|
||||
if(BT){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(sb0, sb1);
|
||||
std::swap(bb0, bb1);
|
||||
std::swap(incb0, incb1);
|
||||
}
|
||||
|
||||
llvm::CallInst* tlda = builder.CreateCall(splat_1d, {ba1, AS0}, "lda");
|
||||
llvm::CallInst* tldb = builder.CreateCall(splat_1d, {bb1, BS1}, "ldb");
|
||||
llvm::CallInst* offa = builder.CreateCall(outer_add, {sa0, builder.CreateMul(sa1, tlda)}, "offa");
|
||||
llvm::CallInst* offb = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb)}, "offb");
|
||||
llvm::CallInst* startpa = builder.CreateCall(gtp, {arguments[0], offa}, "startpa");
|
||||
llvm::CallInst* startpb = builder.CreateCall(gtp, {arguments[1], offb}, "startpb");
|
||||
llvm::LoadInst* startfa = builder.CreateLoad(startpa, "startfa");
|
||||
llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb");
|
||||
llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1, br}, "starta");
|
||||
llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1, br}, "startb");
|
||||
llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0");
|
||||
llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)});
|
||||
llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)});
|
||||
llvm::Value* tincb1 = builder.CreateCall(splat_1d, {bb1, builder.CreateMul(incb1, BS1)});
|
||||
llvm::Value* inca = builder.CreateCall(outer_add, {tinca0, tinca1}, "inca");
|
||||
llvm::Value* incb = builder.CreateCall(outer_add, {tincb0, tincb1}, "incb");
|
||||
// Enter loop
|
||||
builder.CreateBr(LoopBB);
|
||||
builder.SetInsertPoint(LoopBB);
|
||||
// PHI nodes
|
||||
llvm::PHINode* c = builder.CreatePHI(_0->getType(), 2, "c");
|
||||
llvm::PHINode* k = builder.CreatePHI(int32_t, 2, "k");
|
||||
llvm::PHINode* pa = builder.CreatePHI(startpa->getType(), 2, "pa");
|
||||
llvm::PHINode* pb = builder.CreatePHI(startpb->getType(), 2, "pb");
|
||||
llvm::PHINode *a = builder.CreatePHI(starta->getType(), 2, "a");
|
||||
llvm::PHINode *b = builder.CreatePHI(startb->getType(), 2, "b");
|
||||
llvm::Value* nextc = builder.CreateCall(mma, {a, b, c}, "nextc");
|
||||
c->addIncoming(_0, PrologBB);
|
||||
c->addIncoming(nextc, LoopBB);
|
||||
// Induction variable
|
||||
llvm::Value *nextk = builder.CreateSub(k, bk);
|
||||
k->addIncoming(K, PrologBB);
|
||||
k->addIncoming(nextk, LoopBB);
|
||||
// Update pointer
|
||||
llvm::Value *nextpa = builder.CreateCall(stp, {pa, inca}, "nextpa");
|
||||
llvm::Value *nextpb = builder.CreateCall(stp, {pb, incb}, "nextpb");
|
||||
pa->addIncoming(startpa, PrologBB);
|
||||
pa->addIncoming(nextpa, LoopBB);
|
||||
pb->addIncoming(startpb, PrologBB);
|
||||
pb->addIncoming(nextpb, LoopBB);
|
||||
// End condition
|
||||
llvm::Value* no_bounds_check = builder.CreateICmpSGT(nextk, bound);
|
||||
// Masks
|
||||
llvm::Value* maska = builder.CreateCall(splat_2d, {ba0, ba1, no_bounds_check}, "maska");
|
||||
llvm::Value* maskb = builder.CreateCall(splat_2d, {bb0, bb1, no_bounds_check}, "maskb");
|
||||
// Pre-fetch
|
||||
llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa");
|
||||
llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb");
|
||||
llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1, br}, "nexta");
|
||||
llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1, br}, "nextb");
|
||||
a->addIncoming(starta, PrologBB);
|
||||
a->addIncoming(nexta, LoopBB);
|
||||
b->addIncoming(startb, PrologBB);
|
||||
b->addIncoming(nextb, LoopBB);
|
||||
// End condition
|
||||
builder.CreateCondBr(no_bounds_check, LoopBB, EarlyExitBB);
|
||||
// Early exit
|
||||
builder.SetInsertPoint(EarlyExitBB);
|
||||
llvm::Value* exit = builder.CreateICmpSLE(nextk, _s0);
|
||||
builder.CreateCondBr(exit, EpilogueBB, LastIterBB);
|
||||
// Last Iteration
|
||||
builder.SetInsertPoint(LastIterBB);
|
||||
llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {ba0, M}));
|
||||
llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ba1, bk}));
|
||||
llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bb0, N}));
|
||||
llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bb1, bk}));
|
||||
llvm::Value* lastmaska = builder.CreateCall(outer_and, {in_bounds_a0, in_bounds_a1}, "lastmaska");
|
||||
llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb");
|
||||
llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa");
|
||||
llvm::Value* lastfb = builder.CreateCall(masked_load, {nextpb, lastmaskb}, "lastfb");
|
||||
llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1, br}, "lasta");
|
||||
llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1, br}, "lastb");
|
||||
llvm::Value* loop = builder.CreateICmpSGT(nextk, _s0);
|
||||
a->addIncoming(lasta, LastIterBB);
|
||||
b->addIncoming(lastb, LastIterBB);
|
||||
c->addIncoming(nextc, LastIterBB);
|
||||
k->addIncoming(nextk, LastIterBB);
|
||||
pa->addIncoming(nextpa, LastIterBB);
|
||||
pb->addIncoming(nextpb, LastIterBB);
|
||||
builder.CreateCondBr(loop, LoopBB, EpilogueBB);
|
||||
// Epilogue
|
||||
builder.SetInsertPoint(EpilogueBB);
|
||||
llvm::CallInst* sm = builder.CreateCall(read_slice_x, {bm}, "sm");
|
||||
llvm::CallInst* sn = builder.CreateCall(read_slice_y, {bn}, "sn");
|
||||
llvm::CallInst* ldc = builder.CreateCall(splat_1d, {bn, M}, "lda");
|
||||
llvm::CallInst* offc = builder.CreateCall(outer_add, {sm, builder.CreateMul(sn, ldc)}, "offc");
|
||||
llvm::CallInst* pc = builder.CreateCall(gtp, {arguments[2], offc}, "pc");
|
||||
llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {bm, M}));
|
||||
llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {bn, N}));
|
||||
llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1}, "maskc");
|
||||
builder.CreateCall(masked_store, {nextc, pc, maskc});
|
||||
builder.CreateRet(NULL);
|
||||
|
||||
|
||||
// Set metadata
|
||||
llvm::Metadata *md_args[] = {
|
||||
llvm::ValueAsMetadata::get(F),
|
||||
llvm::MDString::get(context, "kernel"),
|
||||
llvm::ValueAsMetadata::get(llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1))
|
||||
};
|
||||
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(llvm::MDNode::get(context, md_args));
|
||||
|
||||
// Machine
|
||||
module->setTargetTriple("nvptx64-nvidia-cuda");
|
||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), "sm_52", "",
|
||||
llvm::TargetOptions(), llvm::Reloc::Model(),
|
||||
llvm::CodeModel::Model(), llvm::CodeGenOpt::Aggressive);
|
||||
module->setDataLayout(machine->createDataLayout());
|
||||
|
||||
// Auto-tuning
|
||||
autotune(machine, *module);
|
||||
|
||||
// Emit
|
||||
llvm::legacy::PassManager pass;
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
|
||||
pass.run(*module);
|
||||
std::string src(buffer.begin(), buffer.end());
|
||||
|
||||
// Execute
|
||||
std::cout << src << std::endl;
|
||||
}
|
14
main.cpp
Normal file
14
main.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
extern YY_BUFFER_STATE yy_scan_string(const char * str);
|
||||
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
|
||||
int main() {
|
||||
char string[] = "void test(int);";
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(string);
|
||||
yy_delete_buffer(buffer);
|
||||
return 0;
|
||||
}
|
305
parser.y
Normal file
305
parser.y
Normal file
@@ -0,0 +1,305 @@
|
||||
%{
|
||||
namespace ast{
|
||||
class node;
|
||||
}
|
||||
using namespace ast;
|
||||
#define YYSTYPE node*
|
||||
#include "../ast.h"
|
||||
using namespace ast;
|
||||
|
||||
extern char* yytext;
|
||||
void yyerror(const char *s);
|
||||
int yylex(void);
|
||||
|
||||
translation_unit *ast_root;
|
||||
|
||||
%}
|
||||
|
||||
%token IDENTIFIER CONSTANT STRING_LITERAL
|
||||
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
|
||||
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
|
||||
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
||||
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME
|
||||
%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64
|
||||
%token IF ELSE FOR
|
||||
%token DEF
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
|
||||
|
||||
/* -------------------------- */
|
||||
/* Types */
|
||||
/* -------------------------- */
|
||||
|
||||
/* No explicit action: the default $$ = $1 forwards whatever the scanner left
   in yylval. NOTE(review): type_name casts that value back to a yytokentype,
   so confirm the scanner stores the token code in yylval. */
type_specifier
  : VOID
  | UINT8 | UINT16 | UINT32 | UINT64
  | INT8 | INT16 | INT32 | INT64
  | FP32 | FP64
  ;

/* One pointer_declarator per run of '*'; each extra '*' calls inc() on the
   declarator built for the inner run, which is $2 ($1 is the '*' token). */
pointer
  : '*' { $$ = new pointer_declarator(1); }
  | '*' pointer { $$ = ((pointer_declarator*)$2)->inc(); }
  ;

abstract_declarator
  : pointer { $$ = $1; }
  | direct_abstract_declarator { $$ = $1; }
  | pointer direct_abstract_declarator { $$ = new compound_declarator($1, $2); }
  ;

/* The constant list sits between the brackets, i.e. in $2. */
direct_abstract_declarator
  : '[' constant_list ']' { $$ = new tile_declarator($2); }
  ;

constant :
  CONSTANT { $$ = new constant(atoi(yytext)); }
  ;

/* Comma-separated list: the new element is $3, $2 is the ',' token. */
constant_list
  : constant { $$ = new list<constant*>((constant*)$1); }
  | constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
  ;

type_name
  : type_specifier { $$ = new type((yytokentype)(size_t)$1, nullptr); }
  | type_specifier abstract_declarator { $$ = new type((yytokentype)(size_t)$1, $2); }
  ;
|
||||
|
||||
/* -------------------------- */
|
||||
/* Expressions */
|
||||
/* -------------------------- */
|
||||
|
||||
identifier
  : IDENTIFIER { $$ = new identifier(yytext); }
  ;

/* For the parenthesised form the inner expression is $2; $1 is the '(' token. */
primary_expression
  : identifier { $$ = $1; }
  | constant { $$ = $1; }
  | STRING_LITERAL { $$ = new string_literal(yytext); }
  | '(' unary_expression ')' { $$ = $2; }
  ;

unary_expression
  : primary_expression { $$ = $1; }
  | INC_OP unary_expression { $$ = new unary_operator(INC_OP, $2); }
  | DEC_OP unary_expression { $$ = new unary_operator(DEC_OP, $2); }
  | unary_operator cast_expression { $$ = new unary_operator((yytokentype)(size_t)$1, $2); }
  ;

/* No explicit action: the default $$ = $1 forwards the scanner's yylval. */
unary_operator
  : '&'
  | '*'
  | '+'
  | '-'
  | '~'
  | '!'
  ;

/* NOTE(review): in the cast form $1 is the '(' token and $2 the type_name;
   the cast operand is $4. The arguments look copied from the unary_operator
   rule, but the cast_operator constructor is not visible here, so the action
   is left unchanged -- confirm against ast.h. */
cast_expression
  : unary_expression { $$ = $1; }
  | '(' type_name ')' cast_expression { $$ = new cast_operator((yytokentype)(size_t)$1, $2); }
  ;

multiplicative_expression
  : cast_expression { $$ = $1; }
  | multiplicative_expression '*' cast_expression { $$ = new binary_operator('*', $1, $3); }
  | multiplicative_expression '/' cast_expression { $$ = new binary_operator('/', $1, $3); }
  | multiplicative_expression '%' cast_expression { $$ = new binary_operator('%', $1, $3); }
  ;

additive_expression
  : multiplicative_expression { $$ = $1; }
  | additive_expression '+' multiplicative_expression { $$ = new binary_operator('+', $1, $3); }
  | additive_expression '-' multiplicative_expression { $$ = new binary_operator('-', $1, $3); }
  ;

shift_expression
  : additive_expression { $$ = $1; }
  | shift_expression LEFT_OP additive_expression { $$ = new binary_operator(LEFT_OP, $1, $3); }
  | shift_expression RIGHT_OP additive_expression { $$ = new binary_operator(RIGHT_OP, $1, $3); }
  ;

relational_expression
  : shift_expression { $$ = $1; }
  | relational_expression '<' shift_expression { $$ = new binary_operator('<', $1, $3); }
  | relational_expression '>' shift_expression { $$ = new binary_operator('>', $1, $3); }
  | relational_expression LE_OP shift_expression { $$ = new binary_operator(LE_OP, $1, $3); }
  | relational_expression GE_OP shift_expression { $$ = new binary_operator(GE_OP, $1, $3); }
  ;

equality_expression
  : relational_expression { $$ = $1; }
  | equality_expression EQ_OP relational_expression { $$ = new binary_operator(EQ_OP, $1, $3); }
  | equality_expression NE_OP relational_expression { $$ = new binary_operator(NE_OP, $1, $3); }
  ;

and_expression
  : equality_expression { $$ = $1; }
  | and_expression '&' equality_expression { $$ = new binary_operator('&', $1, $3); }
  ;

exclusive_or_expression
  : and_expression { $$ = $1; }
  | exclusive_or_expression '^' and_expression { $$ = new binary_operator('^', $1, $3); }
  ;

inclusive_or_expression
  : exclusive_or_expression { $$ = $1; }
  | inclusive_or_expression '|' exclusive_or_expression { $$ = new binary_operator('|', $1, $3); }
  ;

logical_and_expression
  : inclusive_or_expression { $$ = $1; }
  | logical_and_expression AND_OP inclusive_or_expression { $$ = new binary_operator(AND_OP, $1, $3); }
  ;

logical_or_expression
  : logical_and_expression { $$ = $1; }
  | logical_or_expression OR_OP logical_and_expression { $$ = new binary_operator(OR_OP, $1, $3); }
  ;

/* cond ? a : b -- the three operands are $1, $3 and $5; $2 and $4 are the
   '?' and ':' tokens. */
conditional_expression
  : logical_or_expression { $$ = $1; }
  | logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $3, $5); }
  ;

/* No explicit action: the default $$ = $1 forwards the scanner's yylval. */
assignment_operator
  : '='
  | MUL_ASSIGN
  | DIV_ASSIGN
  | MOD_ASSIGN
  | ADD_ASSIGN
  | SUB_ASSIGN
  | LEFT_ASSIGN
  | RIGHT_ASSIGN
  | AND_ASSIGN
  | XOR_ASSIGN
  | OR_ASSIGN
  ;

assignment_expression
  : conditional_expression { $$ = $1; }
  | unary_expression assignment_operator assignment_expression { $$ = new assignment_expression($1, (yytokentype)(size_t)$2, $3); }
  ;

expression
  : assignment_expression { $$ = $1; }
  ;
|
||||
|
||||
/* -------------------------- */
|
||||
/* Statements */
|
||||
/* -------------------------- */
|
||||
|
||||
statement
  : compound_statement { $$ = $1; }
  | expression_statement { $$ = $1; }
  | selection_statement { $$ = $1; }
  | iteration_statement { $$ = $1; }
  ;

/* The statement list sits between the braces, i.e. in $2. */
compound_statement
  : '{' '}' { $$ = new compound_statement(); }
  | '{' statement_list '}' { $$ = $2; }
  ;

statement_list
  : statement { $$ = new compound_statement($1); }
  | statement_list statement { $$ = append_ptr_list<compound_statement>($1, $2); }
  ;

expression_statement
  : ';' { $$ = new no_op(); }
  | expression ';' { $$ = $1; }
  ;

/* IF '(' cond ')' then [ELSE else]: cond is $3, the branches are $5 and $7. */
selection_statement
  : IF '(' expression ')' statement { $$ = new selection_statement($3, $5); }
  | IF '(' expression ')' statement ELSE statement { $$ = new selection_statement($3, $5, $7); }
  ;

/* FOR '(' init stop [step] ')' body: init and stop are $3 and $4; without a
   step expression the body is $6, with one the step is $5 and the body $7. */
iteration_statement
  : FOR '(' expression_statement expression_statement ')' statement { $$ = new iteration_statement($3, $4, NULL, $6); }
  | FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
  ;
|
||||
|
||||
|
||||
/* -------------------------- */
|
||||
/* Declarator */
|
||||
/* -------------------------- */
|
||||
|
||||
|
||||
/* The bracketed / parenthesised list is the third symbol of each alternative
   ($2 is the '[' or '(' token).
   NOTE(review): the declarator being extended ($1) is not passed to the new
   node -- confirm that tile_declarator / function_declarator do not need it. */
direct_declarator
  : identifier { $$ = $1; }
  | direct_declarator '[' constant_list ']' { $$ = new tile_declarator($3); }
  | direct_declarator '(' parameter_list ')' { $$ = new function_declarator($3); }
  | direct_declarator '(' identifier_list ')' { $$ = new function_declarator($3); }
  | direct_declarator '(' ')' { $$ = new function_declarator(nullptr); }
  ;

/* Comma-separated lists: the appended element is $3, $2 is the ',' token. */
identifier_list
  : identifier { $$ = new list<identifier*>((identifier*)$1); }
  | identifier_list ',' identifier { $$ = append_ptr_list<identifier>($1, $3); }
  ;

parameter_list
  : parameter_declaration { $$ = new list<parameter*>((parameter*)$1); }
  | parameter_list ',' parameter_declaration { $$ = append_ptr_list<parameter>($1, $3); }
  ;

parameter_declaration
  : declaration_specifiers declarator { $$ = new parameter((yytokentype)(size_t)$1, $2); }
  | declaration_specifiers abstract_declarator { $$ = new parameter((yytokentype)(size_t)$1, $2); }
  | declaration_specifiers { $$ = new parameter((yytokentype)(size_t)$1, nullptr); }
  ;

declaration_specifiers
  : type_specifier { $$ = $1; }
  ;

init_declarator_list
  : init_declarator { $$ = new list<init_declarator*>((init_declarator*)$1); }
  | init_declarator_list ',' init_declarator { $$ = append_ptr_list<init_declarator>($1, $3); }
  ;

declaration
  : declaration_specifiers ';' { $$ = new declaration($1, nullptr); }
  | declaration_specifiers init_declarator_list ';' { $$ = new declaration($1, $2); }
  ;

declarator
  : pointer direct_declarator { $$ = new compound_declarator($1, $2); }
  | direct_declarator { $$ = $1; }
  ;

/* For the braced form the constant is $2; $1 is the '{' token. */
initializer
  : assignment_expression { $$ = $1; }
  | '{' constant '}' { $$ = $2; }
  ;

/* declarator '=' initializer: the initializer is $3; $2 is the '=' token. */
init_declarator
  : declarator { $$ = new init_declarator($1, nullptr); }
  | declarator '=' initializer { $$ = new init_declarator($1, $3); }
  ;
|
||||
|
||||
/* -------------------------- */
|
||||
/* Translation Unit */
|
||||
/* -------------------------- */
|
||||
|
||||
/* Start symbol. Every reduction publishes the current root through ast_root
   (declared in the prologue) so the tree is reachable after yyparse(). */
translation_unit
  : external_declaration { $$ = new translation_unit($1); ast_root = (translation_unit*)$$; }
  | translation_unit external_declaration { $$ = ((translation_unit*)($1))->add($2); ast_root = (translation_unit*)$$; }
  ;

external_declaration
  : function_definition { $$ = $1; }
  | declaration { $$ = $1; }
  ;

function_definition
  : declarator compound_statement { $$ = new function_definition($1, $2); }
  ;
|
||||
|
128
scanner.l
Normal file
128
scanner.l
Normal file
@@ -0,0 +1,128 @@
|
||||
/* Named pattern fragments used by the rules: decimal digit, identifier
   letter, hex digit, exponent part, float suffix and integer suffix. */
D           [0-9]
L           [a-zA-Z_]
H           [a-fA-F0-9]
E           [Ee][+-]?{D}+
FS          (f|F|l|L)
IS          (u|U|l|L)*

%{
#include <stdio.h>
#include "parser.hpp"

/* Helpers called from rule actions. count() and check_type() are defined in
   the user-code section of this file; comment() is only declared here. */
void count();
int check_type();
int comment();

%}
|
||||
|
||||
%%
    /* Every token-producing rule calls count() before returning its token.
       Keyword rules precede the identifier rule so they win a tie on an
       equal-length match. */
"def"                       { count(); return DEF; }
"if"                        { count(); return IF; }
"else"                      { count(); return ELSE; }
"for"                       { count(); return FOR; }
"void"                      { count(); return VOID; }
"uint8"                     { count(); return UINT8; }
"uint16"                    { count(); return UINT16; }
"uint32"                    { count(); return UINT32; }
"uint64"                    { count(); return UINT64; }
"int8"                      { count(); return INT8; }
"int16"                     { count(); return INT16; }
"int32"                     { count(); return INT32; }
"int64"                     { count(); return INT64; }
"fp32"                      { count(); return FP32; }
"fp64"                      { count(); return FP64; }

{L}({L}|{D})*               { count(); return check_type(); }

0[xX]{H}+{IS}?              { count(); return CONSTANT; }
0{D}+{IS}?                  { count(); return CONSTANT; }
{D}+{IS}?                   { count(); return CONSTANT; }
L?'(\\.|[^\\'])+'           { count(); return CONSTANT; }

{D}+{E}{FS}?                { count(); return CONSTANT; }
{D}*"."{D}+({E})?{FS}?      { count(); return CONSTANT; }
{D}+"."{D}*({E})?{FS}?      { count(); return CONSTANT; }

L?\"(\\.|[^\\"])*\"         { count(); return STRING_LITERAL; }

">>="                       { count(); return RIGHT_ASSIGN; }
"<<="                       { count(); return LEFT_ASSIGN; }
"+="                        { count(); return ADD_ASSIGN; }
"-="                        { count(); return SUB_ASSIGN; }
"*="                        { count(); return MUL_ASSIGN; }
"/="                        { count(); return DIV_ASSIGN; }
"%="                        { count(); return MOD_ASSIGN; }
"&="                        { count(); return AND_ASSIGN; }
"^="                        { count(); return XOR_ASSIGN; }
"|="                        { count(); return OR_ASSIGN; }
">>"                        { count(); return RIGHT_OP; }
"<<"                        { count(); return LEFT_OP; }
"++"                        { count(); return INC_OP; }
"--"                        { count(); return DEC_OP; }
"->"                        { count(); return PTR_OP; }
"&&"                        { count(); return AND_OP; }
"||"                        { count(); return OR_OP; }
"<="                        { count(); return LE_OP; }
">="                        { count(); return GE_OP; }
"=="                        { count(); return EQ_OP; }
"!="                        { count(); return NE_OP; }
";"                         { count(); return ';'; }
("{"|"<%")                  { count(); return '{'; }
("}"|"%>")                  { count(); return '}'; }
","                         { count(); return ','; }
":"                         { count(); return ':'; }
"="                         { count(); return '='; }
"("                         { count(); return '('; }
")"                         { count(); return ')'; }
("["|"<:")                  { count(); return '['; }
("]"|":>")                  { count(); return ']'; }
"."                         { count(); return '.'; }
"&"                         { count(); return '&'; }
"!"                         { count(); return '!'; }
"~"                         { count(); return '~'; }
"-"                         { count(); return '-'; }
"+"                         { count(); return '+'; }
"*"                         { count(); return '*'; }
"/"                         { count(); return '/'; }
"%"                         { count(); return '%'; }
"<"                         { count(); return '<'; }
">"                         { count(); return '>'; }
"^"                         { count(); return '^'; }
"|"                         { count(); return '|'; }
"?"                         { count(); return '?'; }

[ \t\v\n\f]                 { count(); }
.                           { /* ignore bad characters */ }

%%
|
||||
|
||||
/* End-of-input hook for the scanner: always returns 1. */
int yywrap()
{
    return 1;
}
|
||||
|
||||
|
||||
int column = 0;
|
||||
|
||||
void count()
|
||||
{
|
||||
int i;
|
||||
|
||||
for (i = 0; yytext[i] != '\0'; i++)
|
||||
if (yytext[i] == '\n')
|
||||
column = 0;
|
||||
else if (yytext[i] == '\t')
|
||||
column += 8 - (column % 8);
|
||||
else
|
||||
column++;
|
||||
|
||||
ECHO;
|
||||
}
|
||||
|
||||
/* Called by yyparse on error: prints the message `s` to standard output,
   prefixed with "Error: " and followed by a newline. */
void yyerror(const char *s)
{
    fprintf(stdout, "Error: %s\n", s);
}
|
||||
|
||||
int check_type()
|
||||
{
|
||||
return(IDENTIFIER);
|
||||
}
|
Reference in New Issue
Block a user