From dc755612b970042e7fe3a57281525b5f09d279ea Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 15 Dec 2018 22:29:36 -0500 Subject: [PATCH] TDL [Parser]: Initial commit --- CMakeLists.txt | 28 +-- README.md | 1 - ast.h | 295 ++++++++++++++++++++++++++++ cmake/FindLLVM.cmake | 88 --------- conv.cpp | 456 ------------------------------------------- gemm.cpp | 342 -------------------------------- main.cpp | 14 ++ parser.y | 305 +++++++++++++++++++++++++++++ scanner.l | 128 ++++++++++++ 9 files changed, 749 insertions(+), 908 deletions(-) delete mode 100644 README.md create mode 100644 ast.h delete mode 100644 cmake/FindLLVM.cmake delete mode 100644 conv.cpp delete mode 100644 gemm.cpp create mode 100644 main.cpp create mode 100644 parser.y create mode 100644 scanner.l diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f1650aca..308a86ad1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,21 +1,7 @@ -cmake_minimum_required(VERSION 2.8.7) -project(TDL) - -# LLVM -include(cmake/FindLLVM.cmake) - -# Link directories -link_directories(/home/philippe/Development/llvm-tlvm/build/lib) -# Include directories -include_directories(/home/philippe/Development/llvm-tlvm/include) -include_directories(/home/philippe/Development/llvm-tlvm/build/include) - -# Flags -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Wextra -pedantic -Wno-strict-aliasing") - -# Executables -foreach(PROG gemm conv) - add_executable(${PROG} ${PROG}.cpp) - set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG}) - target_link_libraries(${PROG} ${LLVM_LIBRARIES}) -endforeach() +find_package(BISON) +BISON_TARGET(Parser parser.y ${CMAKE_CURRENT_BINARY_DIR}/parser.cpp) +find_package(FLEX) +FLEX_TARGET(Lexer scanner.l ${CMAKE_CURRENT_BINARY_DIR}/scanner.cpp) +get_filename_component(BISON_Parser_INCLUDE_DIRECTORIES ${BISON_Parser_OUTPUT_HEADER} DIRECTORY) +include_directories(${BISON_Parser_INCLUDE_DIRECTORIES}) +add_executable(test main.cpp ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS}) diff --git a/README.md b/README.md deleted file mode 100644 index 1a34b0d0e..000000000 --- a/README.md +++ /dev/null @@ -1 +0,0 @@ -# tdl-examples diff --git a/ast.h b/ast.h new file mode 100644 index 000000000..6cf22e89b --- /dev/null +++ b/ast.h @@ -0,0 +1,295 @@ +#include "parser.hpp" +#include +#include +#include + +typedef yytokentype token_type; + +namespace ast{ + +class node { }; + +template +class list: public node { +public: + list(const T& x): values_{x} {} + node* append(const T& x) { values_.push_back(x); return this;} + +private: + std::list values_; +}; + +template +node* append_ptr_list(node *result, node *in){ + return static_cast*>(result)->append((T*)in); +} + +class binary_operator: public node{ +public: + enum OP_T{ + MUL, DIV, REM, + ADD, SUB, + LEFT_SHIFT, RIGHT_SHIFT, + LT, GT, + LE, GE, + EQ, NE, + AND, XOR, OR, + LAND, LOR + }; + + static OP_T get_op(token_type token){ + switch(token){ + case LEFT_OP: return LEFT_SHIFT; + case RIGHT_OP: return RIGHT_SHIFT; + case LE_OP: return LE; + case GE_OP: return GE; + case EQ_OP: return EQ; + case NE_OP: return NE; + case AND_OP: return LAND; + case OR_OP: return LOR; + default: assert(false && "unreachable"); throw; + } + } + + static OP_T get_op(char token){ + switch(token){ + case '*': return MUL; + case '/': return DIV; + case '%': return REM; + case '+': return ADD; + case '-': return SUB; + case '<': return LT; + case '>': return GT; + case '&': return AND; + case '^': return XOR; + case '|': return OR; + default: assert(false && "unreachable"); throw; + } + } + +public: + binary_operator(token_type op, node *lhs, node *rhs) + : op_(get_op(op)), lhs_(lhs), rhs_(rhs) { } + binary_operator(char op, node *lhs, node *rhs) + : op_(get_op(op)), lhs_(lhs), rhs_(rhs){ } + +private: + const OP_T op_; + const node *lhs_; + const node *rhs_; +}; + + +class constant: public node{ +public: + constant(int value): value_(value) { } + +private: + const int value_; +}; + +class identifier: public node{ +public: + identifier(char *&name): name_(name) { } + +private: + std::string name_; +}; + +class string_literal: public node{ +public: + string_literal(char *&value): value_(value) { } + +public: + std::string value_; +}; + +class unary_operator: public node{ +public: + unary_operator(token_type token, node *arg): token_(token), arg_(arg) { } + +private: + const token_type token_; + const node *arg_; +}; + +class cast_operator: public node{ +public: + cast_operator(token_type type, node *arg): type_(type), arg_(arg) { } + +public: + const token_type type_; + const node *arg_; +}; + +class conditional_expression: public node{ +public: + conditional_expression(node *cond, node *true_value, node *false_value) + : cond_(cond), true_value_(true_value), false_value_(false_value) { } + +public: + const node *cond_; + const node *true_value_; + const node *false_value_; +}; + +class assignment_expression: public node{ + typedef binary_operator::OP_T op_t; + +public: + assignment_expression(node *lvalue, token_type op, node *rvalue) + : lvalue_(lvalue), op_(binary_operator::get_op(op)), rvalue_(rvalue) { } + +public: + op_t op_; + const node *lvalue_; + const node *rvalue_; +}; + +class compound_statement: public node{ +public: + compound_statement() : statements_() {} + compound_statement(node *stmt): statements_{stmt} {} + compound_statement* append(node *stmt) { statements_.push_back(stmt); return this; } + +private: + std::list statements_; +}; + +class selection_statement: public node{ +public: + selection_statement(node *cond, node *if_value, node *else_value = nullptr) + : cond_(cond), if_value_(if_value), else_value_(else_value) { } + +public: + const node *cond_; + const node *if_value_; + const node *else_value_; +}; + +class iteration_statement: public node{ +public: + iteration_statement(node *init, node *stop, node *exec, node *statements) + : init_(init), stop_(stop), exec_(exec), statements_(statements) { } + +private: + const node *init_; + const node *stop_; + const node *exec_; + const node *statements_; +}; + +class no_op: public node { }; + +// Types +class declarator: public node{ + +}; + +class pointer_declarator: public declarator{ +public: + pointer_declarator(unsigned order) + : order_(order) { } + + pointer_declarator *inc(){ + order_ += 1; + return this; + } + +private: + unsigned order_; +}; + +class tile_declarator: public declarator{ +public: + tile_declarator(node *shapes) + : shapes_((list*)(shapes)) { } + +public: + const list* shapes_; +}; + +class parameter: public declarator { +public: + parameter(token_type type, node *decl) + : type_(type), decl_(decl) { } + +public: + const token_type type_; + const node *decl_; +}; + +class function_declarator: public declarator{ +public: + function_declarator(node *args) + : args_((list)args) { } + +public: + const list args_; +}; + +class compound_declarator: public declarator{ +public: + compound_declarator(node *ptr, node *tile) + : ptr_(ptr), tile_(tile) { } + +public: + const node *ptr_; + const node *tile_; +}; + +class init_declarator : public declarator{ +public: + init_declarator(node *decl, node *initializer) + : decl_(decl), initializer_(initializer){ } + +public: + const node *decl_; + const node *initializer_; +}; + +class declaration: public node{ +public: + declaration(node *spec, node *init) + : spec_(spec), init_(init) { } + +public: + const node *spec_; + const node *init_; +}; + +class type: public node{ +public: + type(token_type spec, node * decl) + : spec_(spec), decl_(decl) { } + +public: + const token_type spec_; + const node *decl_; +}; + +class translation_unit: public node{ +public: + translation_unit(node *item) + : decls_(item) { } + + translation_unit *add(node *item) { + decls_.append(item); + return this; + } + +private: + list decls_; +}; + +class function_definition: public node{ +public: + function_definition(node *header, node *body) + : header_((declarator *)header), body_((compound_statement*)body) { } + +public: + const declarator *header_; + const compound_statement *body_; +}; + +} diff --git a/cmake/FindLLVM.cmake b/cmake/FindLLVM.cmake deleted file mode 100644 index b3196d444..000000000 --- a/cmake/FindLLVM.cmake +++ /dev/null @@ -1,88 +0,0 @@ -# - Find LLVM -# This module can be used to find LLVM. -# It requires that the llvm-config executable be available on the system path. -# Once found, llvm-config is used for everything else. -# -# Typical usage could be: -# find_package(LLVM QUIET REQUIRED COMPONENTS jit native interpreter) -# -# If the QUIET flag is not set, the specified components and LLVM version are -# outputted. -# -# If the COMPONENTS are not set, the default set of "all" is used. -# -# The following variables are set: -# -# LLVM_FOUND - Set to YES if LLVM is found. -# LLVM_VERSION - Set to the decimal version of the LLVM library. -# LLVM_C_FLAGS - All flags that should be passed to a C compiler. -# LLVM_CXX_FLAGS - All flags that should be passed to a C++ compiler. -# LLVM_CPP_FLAGS - All flags that should be passed to the C pre-processor. -# LLVM_LD_FLAGS - Additional flags to pass to the linker. -# LLVM_LIBRARY_DIRS - A list of directories where the LLVM libraries are located. -# LLVM_INCLUDE_DIRS - A list of directories where the LLVM headers are located. -# LLVM_LIBRARIES - A list of libraries which should be linked against. - -# A macro to run llvm config -macro(_llvm_config _var_name) - # Firstly, locate the LLVM config executable - find_program(_llvm_config_exe - NAMES llvm-config - PATHS /home/philippe/Development/llvm-tlvm/build/bin/ - DOC "llvm-config executable location" - ) - - # If no llvm-config executable was found, set the output variable to not - # found. - if(NOT _llvm_config_exe) - set(${_var_name} "${_var_name}-NOTFOUND") - else(NOT _llvm_config_exe) - # Otherwise, run llvm-config - execute_process( - COMMAND ${_llvm_config_exe} ${ARGN} - OUTPUT_VARIABLE ${_var_name} - RESULT_VARIABLE _llvm_config_retval - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - if(RESULT_VARIABLE) - message(SEND_ERROR - "Error running llvm-config with arguments: ${ARGN}") - endif(RESULT_VARIABLE) - endif(NOT _llvm_config_exe) -endmacro(_llvm_config) - -# The default set of components -set(_llvm_components all) - -# If components have been specified via find_package, use them -if(LLVM_FIND_COMPONENTS) - set(_llvm_components ${LLVM_FIND_COMPONENTS}) -endif(LLVM_FIND_COMPONENTS) - -if(NOT LLVM_FIND_QUIETLY) - message(STATUS "Looking for LLVM components: ${_llvm_components}") -endif(NOT LLVM_FIND_QUIETLY) - -_llvm_config(LLVM_VERSION --version) -_llvm_config(LLVM_C_FLAGS --cflags) -_llvm_config(LLVM_CXX_FLAGS --cxxflags) -_llvm_config(LLVM_CPP_FLAGS --cppflags) -_llvm_config(LLVM_LD_FLAGS --ldflags) -_llvm_config(LLVM_LIBRARY_DIRS --libdir) -_llvm_config(LLVM_INCLUDE_DIRS --includedir) -_llvm_config(LLVM_LIBRARIES --libs) - -if(NOT LLVM_FIND_QUIETLY) - message(STATUS "Found LLVM version: ${LLVM_VERSION}") -endif(NOT LLVM_FIND_QUIETLY) - -# handle the QUIETLY and REQUIRED arguments and set LLVM_FOUND to TRUE if -# all listed variables are TRUE -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(LLVM - DEFAULT_MSG - LLVM_LIBRARIES - LLVM_INCLUDE_DIRS - LLVM_LIBRARY_DIRS) - -# vim:sw=4:ts=4:autoindent diff --git a/conv.cpp b/conv.cpp deleted file mode 100644 index fa99d301e..000000000 --- a/conv.cpp +++ /dev/null @@ -1,456 +0,0 @@ -#include - -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ExecutionEngine/ExecutionEngine.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Verifier.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/Support/FileSystem.h" -#include "llvm/Support/Host.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Support/TargetRegistry.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Target/TargetMachine.h" -#include "llvm/Target/TargetOptions.h" -#include "llvm/CodeGen/TargetPassConfig.h" -#include "llvm/Support/Debug.h" -#include "llvm/Transforms/Utils/Cloning.h" - -// Index computation -inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u, - int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4) -{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; } - -template -void cpp_conv_nchw(int32_t C, int32_t N, int32_t K, - int32_t D, int32_t H, int32_t W, - int32_t T, int32_t R, int32_t S, - int32_t pad_d, int32_t pad_h, int32_t pad_w, - int32_t stride_d, int32_t stride_h, int32_t stride_w, - int32_t M, int32_t P, int32_t Q, - std::vector>& O, IN_DTYPE* I, IN_DTYPE* F) -{ - size_t num_outputs = O.size(); - static const int PACK_IN = 1; - static const int PACK_OUT = 1; - if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4"); - if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4"); - C /= PACK_IN; - K /= PACK_OUT; - int32_t Kout = K; - IN_DTYPE accs[PACK_OUT]; - for(size_t o = 0; o < num_outputs; o++) - for(int32_t m = 0 ; m < M; ++m) - for(int32_t p = 0 ; p < P; ++p) - for(int32_t q = 0; q < Q; ++q) - for(int32_t n = 0; n < N; ++n) - for(int32_t k = 0; k < Kout ; ++k) - { - for(int32_t i = 0 ; i < PACK_OUT; ++i) - accs[i] = 0; - int32_t mm = m*stride_d - pad_d; - int32_t pp = p*stride_h - pad_h; - int32_t qq = q*stride_w - pad_w; - for(int32_t kk = 0; kk < PACK_OUT; ++kk) - for(int32_t c = 0; c < C; ++c) - for(int32_t t = 0; t < T; ++t) - for(int32_t r = 0; r < R; ++r) - for(int32_t s = 0; s < S; ++s){ - int32_t d = mm + t; - int32_t h = pp + r; - int32_t w = qq + s; - bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W); - IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0; - IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)]; - accs[kk] += i*f; - } - O[o][idx(n, k, m, p, q, N, K, M, P, Q)] = accs[0]; - } -} - -void autotune(llvm::TargetMachine *machine, llvm::Module &module){ - // Target parameters - std::vector ranges = { - // asm - 2, 16, 1, 64, - // bsn - 2, 16, 1, 64, - // pa - 1, 2, 4, 8, - // pb - 1, 2, 4, - // sm - 2, 1, 16, 2, 2, 2 - }; - - // Function - llvm::Function *F = module.getFunction("kernel"); - - // Auto-tuning - llvm::legacy::PassManager pass; - llvm::TargetPassConfig *pass_config = static_cast(machine)->createPassConfig(pass); - llvm::FunctionPass *tuning_params = pass_config->createTargetTuningParameters(); - tuning_params->runOnFunction(*F); - - - // Gather all parameters - llvm::DenseSet unique; - llvm::SmallVector params; - for(llvm::BasicBlock &bb: *F) - for(llvm::Instruction &instr: bb){ - // Get tuning parameters for this particular instruction - std::vector tuning_params; - machine->getTargetTuner().getParams(&instr, tuning_params); - for(llvm::TargetTuner::ParamType ¶m: tuning_params){ - // This parameter has not been seen before - if(unique.insert(param.Value).second){ - std::cout << "PARAM: " << instr.getName().data() << " " << param.Name << std::endl; - params.push_back(param.Value); - } - } - } - - // Gather all constraints - std::vector> constraints; - for(llvm::BasicBlock &bb: *F) - for(llvm::Instruction &instr: bb) - machine->getTargetTuner().getConstraints(&instr, constraints); - - // Assign parameters - std::cout << params.size() << " " << ranges.size() << std::endl; - for(unsigned i = 0; i < params.size(); i++) - *params[i] = ranges[i]; - - // Verify constraints - bool valid = true; - for(auto &constraint: constraints){ - valid = valid & constraint(); - } - - if(!valid){ - printf("Invalid kernel parameters\n"); - exit(EXIT_FAILURE); - } -} - -int main(){ - std::string error; - - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); - - // Module - llvm::LLVMContext context; - std::unique_ptr module = llvm::make_unique("TLVM toy example", context); - llvm::IRBuilder<> builder(context); - - unsigned RR = 3, SS = 3; - unsigned Nfilt = RR * SS; - unsigned block = 8; - unsigned nlut = (block + Nfilt - 1)/Nfilt * Nfilt; - - // Globals - llvm::Type* bool_t = llvm::Type::getInt1Ty(context); - llvm::Type* mask_tile_t = llvm::TileType::get(bool_t, 2); - llvm::Type* numeric_t = llvm::Type::getFloatTy(context); - llvm::PointerType* numeric_ptr_t = llvm::PointerType::get(numeric_t, 0); - llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context); - llvm::PointerType* lut_ptr_t = llvm::PointerType::get(int32_t, 4); - llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context); - - llvm::Type* tile_t = llvm::TileType::get(numeric_t, 2); - llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1); - llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2); - llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1); - llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2); - - llvm::PointerType* tile_ptr_t = llvm::PointerType::get(tile_t, 0); - llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t}); - llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t}); - llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t}); - llvm::Function* gtp_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_1d, {int32_slice_t->getPointerTo(4), int32_t->getPointerTo(4), int32_slice_t}); - llvm::Function* stp_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_1d, {int32_slice_t->getPointerTo(4), int32_slice_t}); - - llvm::Function* gtp_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile_ptr_t, numeric_ptr_t, int32_tile_t}); - llvm::Function* stp_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile_ptr_t, int32_tile_t}); - llvm::Intrinsic::ID mma_id = llvm::Intrinsic::tlvm_mma_nt; - llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t}); - llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t}); - llvm::Function* outer_and_int32 = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int32_slice_t, int32_slice_t}); - llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma_id, {tile_t}); - llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_2d, {tile_t}); - llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t}); - llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t}); - - llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile_t, tile_ptr_t, mask_tile_t}); - llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile_t, tile_ptr_t, mask_tile_t}); - - // Hyperparameters - llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0); - llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1); - llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2); - - // Constants - llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0); - llvm::Constant *_f0 = llvm::ConstantFP::get(numeric_t, 0); - llvm::Constant *_0 = llvm::ConstantTile::get(_f0, {bm, bn}); - - // LUT - unsigned num_delta = nlut; - unsigned num_inc_delta = nlut; - unsigned num_masks = nlut; - unsigned num_inc_masks = nlut; - unsigned cst_size = num_delta + num_inc_delta + num_masks + num_inc_masks; - llvm::GlobalVariable *lut_array = - new llvm::GlobalVariable(*module, llvm::ArrayType::get(int32_t, cst_size), false, llvm::GlobalVariable::InternalLinkage, - nullptr, "lut_array", nullptr, llvm::GlobalVariable::NotThreadLocal, 4); - llvm::Value *cst_ptr = builder.CreateBitCast(lut_array, lut_ptr_t); - - - // Function - llvm::FunctionType* prototype = llvm::FunctionType::get(llvm::Type::getVoidTy(context), std::vector{numeric_ptr_t, numeric_ptr_t, numeric_ptr_t, int32_t, int32_t, int32_t, int32_t, int32_t}, false); - llvm::Function* F = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, "kernel", module.get()); - std::vector args; - F->addAttribute(1, llvm::Attribute::ReadOnly); - F->addAttribute(1, llvm::Attribute::NoAlias); - F->addAttribute(2, llvm::Attribute::ReadOnly); - F->addAttribute(2, llvm::Attribute::NoAlias); - std::transform(F->arg_begin(), F->arg_end(), std::back_inserter(args), [&](llvm::Argument& x){ return &x;}); - llvm::Value *base_pc = args[0], *base_pa = args[1], *base_pb = args[2]; - llvm::Value *C = args[3], *H = args[4], *W = args[5], *N = args[6], *K = args[7]; - llvm::Value *R = builder.getInt32(RR), *S = builder.getInt32(SS); - - // All basic blocks - llvm::BasicBlock* PrologBB = llvm::BasicBlock::Create(context, "prologue", F); - llvm::BasicBlock* LoopBB = llvm::BasicBlock::Create(context, "loop", F); - llvm::BasicBlock* EarlyExitBB = llvm::BasicBlock::Create(context, "early_exit", F); - llvm::BasicBlock* LastIterBB = llvm::BasicBlock::Create(context, "last_iter", F); - llvm::BasicBlock* EpilogueBB = llvm::BasicBlock::Create(context, "epilogue", F); - - - // First basic block - builder.SetInsertPoint(PrologBB); - llvm::Value* sa0 = builder.CreateCall(read_slice_x, {bm}, "sa0"); - llvm::Value* sb0 = builder.CreateCall(read_slice_y, {bn}, "sb0"); - llvm::Value* sa1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sa1"); - llvm::Value* sb1 = builder.CreateCall(range, {builder.getInt32(0), bk}, "sb1"); - - llvm::Value* lda_w = builder.getInt32(1); - llvm::Value* lda_h = builder.CreateMul(lda_w, W); - llvm::Value* lda_c = builder.CreateMul(lda_h, H); - llvm::Value* lda_n = builder.CreateMul(lda_c, C); - - llvm::Value* ldb_s = builder.getInt32(1); - llvm::Value* ldb_r = builder.CreateMul(ldb_s, S); - llvm::Value* ldb_c = builder.CreateMul(ldb_r, R); - llvm::Value* ldb_k = builder.CreateMul(ldb_c, C); - - llvm::Value* CRS = builder.CreateMul(C, builder.CreateMul(R, S)); - llvm::Value* PQN = builder.CreateMul(H, builder.CreateMul(W, N)); - - // Images HWN offset - llvm::Value* sa_hw = builder.CreateUDiv(sa0, builder.CreateCall(splat_1d, {bm, N})); - llvm::Value* sa_n = builder.CreateURem(sa0, builder.CreateCall(splat_1d, {bm, N})); - llvm::Value* sa_h = builder.CreateUDiv(sa_hw, builder.CreateCall(splat_1d, {bm, W})); - llvm::Value* sa_w = builder.CreateURem(sa_hw, builder.CreateCall(splat_1d, {bm, W})); - llvm::Value* offa_0 = builder.CreateMul(sa_n, builder.CreateCall(splat_1d, {bm, lda_n})); - offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_h, builder.CreateCall(splat_1d, {bm, lda_h}))); - offa_0 = builder.CreateAdd(offa_0, builder.CreateMul(sa_w, builder.CreateCall(splat_1d, {bm, lda_w}))); - // Images CRS offset - llvm::Value* sa_cr = builder.CreateUDiv(sa1, builder.CreateCall(splat_1d, {bk, S})); - llvm::Value* sa_s = builder.CreateURem(sa1, builder.CreateCall(splat_1d, {bk, S})); - llvm::Value* sa_c = builder.CreateUDiv(sa_cr, builder.CreateCall(splat_1d, {bk, R})); - llvm::Value* sa_r = builder.CreateURem(sa_cr, builder.CreateCall(splat_1d, {bk, R})); - llvm::Value* offa_1 = builder.CreateMul(sa_c, builder.CreateCall(splat_1d, {bk, lda_c})); - offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_r, builder.CreateCall(splat_1d, {bk, lda_h}))); - offa_1 = builder.CreateAdd(offa_1, builder.CreateMul(sa_s, builder.CreateCall(splat_1d, {bk, lda_w}))); - // Images pointer - llvm::Value* off_a = builder.CreateCall(outer_add, {offa_0, offa_1}); - llvm::Value* start_pa = builder.CreateCall(gtp_2d, {base_pa, off_a}, "start_pa"); - llvm::LoadInst* start_aa = builder.CreateLoad(start_pa, false, "start_aa"); - llvm::Value* start_a = builder.CreateCall(reshape, {start_aa, bm, bk}, "start_a"); - // Filters pointer - llvm::Value* tldb_s = builder.CreateCall(splat_1d, {bk, K}); - llvm::Value* off_b = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb_s)}, "off_b"); - llvm::Value* start_pb = builder.CreateCall(gtp_2d, {base_pb, off_b}, "start_pb"); - llvm::Value* start_bb = builder.CreateLoad(start_pb, false, "start_bb"); - llvm::Value* start_b = builder.CreateCall(reshape, {start_bb, bn, bk}, "start_b"); - // Filters increment - llvm::Value* inc_b_0 = builder.CreateCall(splat_1d, {bn, _s0}, "inc_b_0"); - llvm::Value* inc_b_1 = builder.CreateCall(splat_1d, {bk, builder.CreateMul(bk, ldb_k)}, "inc_b_1"); - llvm::Value* inc_b = builder.CreateCall(outer_add, {inc_b_0, inc_b_1}, "inc_b"); - // Pointers to constant memory - llvm::Value* base_incdelta = builder.CreateGEP(cst_ptr, builder.getInt32(0)); - llvm::Value* base_delta = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_delta)); - llvm::Value* base_incmask = builder.CreateGEP(cst_ptr, builder.getInt32(num_delta)); - llvm::Value* base_mask = builder.CreateGEP(cst_ptr, builder.getInt32(num_inc_masks)); - // Delta pointers - llvm::Value* start_pincdelta = builder.CreateCall(gtp_1d, {base_incdelta, sa1}, "start_pincdelta"); - llvm::Value* start_pdelta = builder.CreateCall(gtp_1d, {base_delta, builder.CreateCall(splat_1d, {bk, _s0})}, "start_pdelta"); - // Masks - llvm::Value* _1 = builder.CreateCall(splat_1d, {bk, builder.getInt32(1)}); - llvm::Value* mask_a_1 = builder.CreateShl(_1, sa1); - llvm::Value* start_pincmask = builder.CreateCall(gtp_1d, {base_incmask, sa0}, "start_pincmask"); - llvm::Value* start_pmask = builder.CreateCall(gtp_1d, {base_mask, sa0}, "start_pmask"); - // Enter loop - builder.CreateBr(LoopBB); - builder.SetInsertPoint(LoopBB); - // PHI nodes - llvm::PHINode* c = builder.CreatePHI(_0->getType(), 3, "c"); - llvm::PHINode* crs = builder.CreatePHI(int32_t, 3, "crs"); - llvm::PHINode* pa = builder.CreatePHI(start_pa->getType(), 3, "pa"); - llvm::PHINode* pb = builder.CreatePHI(start_pb->getType(), 3, "pb"); - llvm::PHINode *a = builder.CreatePHI(start_a->getType(), 3, "a"); - llvm::PHINode *b = builder.CreatePHI(start_b->getType(), 3, "b"); - llvm::PHINode *pdelta = builder.CreatePHI(start_pdelta->getType(), 3); - llvm::PHINode *pincdelta = builder.CreatePHI(start_pincdelta->getType(), 3); - llvm::PHINode *pmasks = builder.CreatePHI(start_pmask->getType(), 3); - llvm::PHINode *pincmasks = builder.CreatePHI(start_pincmask->getType(), 3); - llvm::Value* next_c = builder.CreateCall(mma, {a, b, c}, "next_c"); - c->addIncoming(_0, PrologBB); - c->addIncoming(next_c, LoopBB); - // Induction variable - llvm::Value *next_crs = builder.CreateSub(crs, bk); - crs->addIncoming(CRS, PrologBB); - crs->addIncoming(next_crs, LoopBB); - // Update pointer - llvm::Value *inc_delta = builder.CreateLoad(pincdelta); - llvm::Value *inc_mask = builder.CreateLoad(pincmasks); - llvm::Value *inc_a_1 = builder.CreateLoad(pdelta); - llvm::Value *inc_a_0 = builder.CreateCall(splat_1d, {bm, builder.getInt32(0)}); - llvm::Value *inc_a = builder.CreateCall(outer_add, {inc_a_0, inc_a_1}); - llvm::Value *next_pa = builder.CreateCall(stp_2d, {pa, inc_a}, "next_pa"); - llvm::Value *next_pb = builder.CreateCall(stp_2d, {pb, inc_b}, "next_pb"); - llvm::Value *next_pdelta = builder.CreateCall(stp_1d, {pdelta, inc_delta}, "next_pdelta"); - llvm::Value *next_pincdelta = builder.CreateCall(stp_1d, {pincdelta, inc_delta}, "next_pincdelta"); - llvm::Value *next_pmask = builder.CreateCall(stp_1d, {pmasks, inc_mask}, "next_pmask"); - llvm::Value *next_pincmask = builder.CreateCall(stp_1d, {pincmasks, inc_mask}, "next_pincmask"); - pdelta->addIncoming(start_pdelta, PrologBB); - pdelta->addIncoming(next_pdelta, LoopBB); - pincdelta->addIncoming(start_pincdelta, PrologBB); - pincdelta->addIncoming(next_pincdelta, LoopBB); - pmasks->addIncoming(start_pmask, PrologBB); - pmasks->addIncoming(next_pmask, LoopBB); - pincmasks->addIncoming(start_pincmask, PrologBB); - pincmasks->addIncoming(next_pincmask, LoopBB); - pa->addIncoming(start_pa, PrologBB); - pa->addIncoming(next_pa, LoopBB); - pb->addIncoming(start_pb, PrologBB); - pb->addIncoming(next_pb, LoopBB); - // End condition - llvm::Value* no_bounds_check = builder.CreateICmpSGT(next_crs, builder.getInt32(0), "no_bounds_check"); - // Masks - llvm::Value* mask_a_0 = builder.CreateLoad(pmasks, "mask_a_0"); - llvm::Value* mask_a_i32 = builder.CreateCall(outer_and_int32, {mask_a_0, mask_a_1}, "mask_a_i32"); - llvm::Value* mask_a = builder.CreateICmpNE(mask_a_i32, llvm::ConstantTile::get(_s0, {bm, bk}), "mask_a"); - llvm::Value* mask_b = builder.CreateCall(splat_2d, {bn, bk, no_bounds_check}, "mask_b"); - // Pre-fetch - llvm::Value* next_aa = builder.CreateCall(masked_load, {next_pa, mask_a}, "next_aa"); - llvm::Value* next_bb = builder.CreateCall(masked_load, {next_pb, mask_b}, "next_bb"); - llvm::Value* next_a = builder.CreateCall(reshape, {next_aa, bm, bk}, "next_a"); - llvm::Value* next_b = builder.CreateCall(reshape, {next_bb, bn, bk}, "next_b"); - a->addIncoming(start_a, PrologBB); - a->addIncoming(next_a, LoopBB); - b->addIncoming(start_b, PrologBB); - b->addIncoming(next_b, LoopBB); - // End condition - builder.CreateCondBr(no_bounds_check, LoopBB, EarlyExitBB); - // Early exit - builder.SetInsertPoint(EarlyExitBB); - llvm::Value* exit = builder.CreateICmpSLE(next_crs, _s0); - builder.CreateCondBr(exit, EpilogueBB, LastIterBB); - // Last Iteration - builder.SetInsertPoint(LastIterBB); - llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(sb0, builder.CreateCall(splat_1d, {bn, K})); - llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(sb1, builder.CreateCall(splat_1d, {bk, next_crs})); - llvm::Value* last_maskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "last_maskb"); - llvm::Value* last_bb = builder.CreateCall(masked_load, {next_pb, last_maskb}, "last_bb"); - llvm::Value* last_b = builder.CreateCall(reshape, {last_bb, bn, bk}, "last_b"); - llvm::Value* loop = builder.CreateICmpSGT(next_crs, _s0); - a->addIncoming(next_a, LastIterBB); - b->addIncoming(last_b, LastIterBB); - c->addIncoming(next_c, LastIterBB); - crs->addIncoming(next_crs, LastIterBB); - pa->addIncoming(next_pa, LastIterBB); - pb->addIncoming(next_pb, LastIterBB); - pdelta->addIncoming(next_pdelta, LastIterBB); - pincdelta->addIncoming(next_pincdelta, LastIterBB); - pmasks->addIncoming(next_pmask, LastIterBB); - pincmasks->addIncoming(next_pincmask, LastIterBB); - builder.CreateCondBr(loop, LoopBB, EpilogueBB); - - // Epilogue - builder.SetInsertPoint(EpilogueBB); - llvm::Value* sc_pqn = builder.CreateCall(read_slice_x, {bm}, "sc_pqn"); - llvm::Value* sc_k = builder.CreateCall(read_slice_y, {bn}, "sc_k"); - // Output strides - llvm::Value* ldc_q = builder.getInt32(1); - llvm::Value* ldc_p = builder.CreateMul(lda_w, W); - llvm::Value* ldc_k = builder.CreateMul(lda_h, H); - llvm::Value* ldb_n = builder.CreateMul(lda_c, K); - // Output PQN offset - llvm::Value* sc_pq = builder.CreateUDiv(sc_pqn, builder.CreateCall(splat_1d, {bm, N})); - llvm::Value* sc_n = builder.CreateURem(sc_pqn, builder.CreateCall(splat_1d, {bm, N})); - llvm::Value* sc_p = builder.CreateUDiv(sc_pq, builder.CreateCall(splat_1d, {bm, W})); - llvm::Value* sc_q = builder.CreateURem(sc_pq, builder.CreateCall(splat_1d, {bm, W})); - llvm::Value* offc0 = builder.CreateMul(sc_n, builder.CreateCall(splat_1d, {bm, ldb_n})); - offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_p, builder.CreateCall(splat_1d, {bm, ldc_p}))); - offc0 = builder.CreateAdd(offc0, builder.CreateMul(sc_q, builder.CreateCall(splat_1d, {bm, ldc_q}))); - // Output K offset - llvm::Value* offc1 = builder.CreateMul(sc_k, builder.CreateCall(splat_1d, {bn, ldc_k})); - // Output pointer - llvm::Value* offc = builder.CreateCall(outer_add, {offc0, offc1}); - llvm::Value* pc = builder.CreateCall(gtp_2d, {base_pc, offc}); - // Output masks - llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sc_pqn, builder.CreateCall(splat_1d, {bm, PQN})); - llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sc_k, builder.CreateCall(splat_1d, {bn, K})); - llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1}); - builder.CreateCall(masked_store, {next_c, pc, maskc}); - builder.CreateRet(NULL); - - - // Set metadata - llvm::Metadata *md_args[] = { - llvm::ValueAsMetadata::get(F), - llvm::MDString::get(context, "kernel"), - llvm::ValueAsMetadata::get(llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)) - }; - module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(llvm::MDNode::get(context, md_args)); - - // Machine - module->setTargetTriple("nvptx64-nvidia-cuda"); - auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); - - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), "sm_52", "", - llvm::TargetOptions(), llvm::Reloc::Model(), - llvm::CodeModel::Model(), llvm::CodeGenOpt::Aggressive); - module->setDataLayout(machine->createDataLayout()); - - // Auto-tuning - autotune(machine, *module); - - // Emit - llvm::legacy::PassManager pass; - llvm::SmallVector buffer; - llvm::raw_svector_ostream stream(buffer); - machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); - pass.run(*module); - std::string src(buffer.begin(), buffer.end()); - - // Execute - std::cout << src << std::endl; -} diff --git a/gemm.cpp b/gemm.cpp deleted file mode 100644 index 5433fd8d9..000000000 --- a/gemm.cpp +++ /dev/null @@ -1,342 +0,0 @@ -#include - -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ExecutionEngine/ExecutionEngine.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Verifier.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/Support/FileSystem.h" -#include "llvm/Support/Host.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Support/TargetRegistry.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Target/TargetMachine.h" -#include "llvm/Target/TargetOptions.h" -#include "llvm/CodeGen/TargetPassConfig.h" -#include "llvm/Support/Debug.h" -#include "llvm/Transforms/Utils/Cloning.h" - - -bool AT = false; -bool BT = true; - - -void autotune(llvm::TargetMachine *machine, llvm::Module &module){ - // Target parameters - std::vector ranges = { - // asm - 2, 16, 1, 64, - // bsn - 2, 16, 1, 64, - // pa - 1, 2, 4, 8, - // pb - 1, 2, 4, - // sm - 2, 1, 16, 2, 2, 2 - }; - - // Function - llvm::Function *F = module.getFunction("kernel"); - - // Auto-tuning - llvm::legacy::PassManager pass; - llvm::TargetPassConfig *pass_config = static_cast(machine)->createPassConfig(pass); - llvm::FunctionPass *tuning_params = pass_config->createTargetTuningParameters(); - tuning_params->runOnFunction(*F); - - - // Gather all parameters - llvm::DenseSet unique; - llvm::SmallVector params; - for(llvm::BasicBlock &bb: *F) - for(llvm::Instruction &instr: bb){ - // Get tuning parameters for this particular instruction - std::vector tuning_params; - machine->getTargetTuner().getParams(&instr, tuning_params); - for(llvm::TargetTuner::ParamType ¶m: tuning_params){ - // This parameter has not been seen before - if(unique.insert(param.Value).second){ - std::cout << instr.getName().data() << " " << param.Name << std::endl; - params.push_back(param.Value); - } - } - } - - // Gather all constraints - std::vector> constraints; - for(llvm::BasicBlock &bb: *F) - for(llvm::Instruction &instr: bb) - machine->getTargetTuner().getConstraints(&instr, constraints); - - // Assign parameters - std::cout << params.size() << " " << ranges.size() << std::endl; - for(unsigned i = 0; i < params.size(); i++) - *params[i] = ranges[i]; - - // Verify constraints - bool valid = true; - for(auto &constraint: constraints){ - valid = valid & constraint(); - } - - if(!valid){ - printf("Invalid kernel parameters\n"); - exit(EXIT_FAILURE); - } -} - -int main(){ -// llvm::DebugFlag = true; - - std::string error; - - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); - - // Module - llvm::LLVMContext context; - std::unique_ptr module = llvm::make_unique("TLVM toy example", context); - llvm::IRBuilder<> builder(context); - - // Globals - llvm::Type* bool_t = llvm::Type::getInt1Ty(context); - llvm::Type* mask_tile_t = llvm::TileType::get(bool_t, 2); - llvm::Type* numeric_t = llvm::Type::getFloatTy(context); - llvm::PointerType* numeric_ptr_t = llvm::PointerType::get(numeric_t, 0); - llvm::IntegerType* int32_t = llvm::Type::getInt32Ty(context); - llvm::IntegerType* int1_t = llvm::Type::getInt1Ty(context); - - llvm::Type* tile2d_t = llvm::TileType::get(numeric_t, 2); - llvm::Type* tile3d_t = llvm::TileType::get(numeric_t, 3); - llvm::Type* int32_slice_t = llvm::TileType::get(int32_t, 1); - llvm::Type* int32_tile_t = llvm::TileType::get(int32_t, 2); - llvm::Type* int1_slice_t = llvm::TileType::get(int1_t, 1); - llvm::Type* int1_tile_t = llvm::TileType::get(int1_t, 2); - - llvm::PointerType* tile2d_ptr_t = llvm::PointerType::get(tile2d_t, 0); - llvm::Function* read_slice_x = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_x, {int32_slice_t}); - llvm::Function* read_slice_y = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_read_slice_y, {int32_slice_t}); - llvm::Function* range = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_range, {int32_slice_t}); - llvm::Function* gtp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_gtp_2d, {tile2d_ptr_t, numeric_ptr_t, int32_tile_t}); - llvm::Function* stp = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_stp_2d, {tile2d_ptr_t, int32_tile_t}); - llvm::Intrinsic::ID mma_id; - if(!AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_nn; - if(!AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_nt; - if(AT && !BT) mma_id = llvm::Intrinsic::tlvm_mma_tn; - if(AT && BT) mma_id = llvm::Intrinsic::tlvm_mma_tt; - llvm::Function* outer_add = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_add, {int32_tile_t, int32_slice_t, int32_slice_t}); - llvm::Function* outer_and = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_outer_and, {int1_tile_t, int1_slice_t, int1_slice_t}); - llvm::Function* mma = llvm::Intrinsic::getDeclaration(module.get(), mma, {tile3d_t}); - llvm::Function* reshape = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_reshape_3d, {tile3d_t, tile2d_t}); - llvm::Function* splat_2d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_2d, {mask_tile_t, bool_t}); - llvm::Function* splat_1d = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_splat_1d, {int32_slice_t, int32_t}); - llvm::Function* masked_load = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_load, {tile2d_t, tile2d_ptr_t, mask_tile_t}); - llvm::Function* masked_store = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::tlvm_masked_store, {tile2d_t, tile2d_ptr_t, mask_tile_t}); - - // Hyperparameters - llvm::Hyperparameter *bm = llvm::Hyperparameter::get(int32_t, 0); - llvm::Hyperparameter *bn = llvm::Hyperparameter::get(int32_t, 1); - llvm::Hyperparameter *bk = llvm::Hyperparameter::get(int32_t, 2); - llvm::Hyperparameter *br = llvm::Hyperparameter::get(int32_t, 3); - - // Constants - llvm::Constant *_s0 = llvm::ConstantInt::get(int32_t, 0); - llvm::Constant *_f0 = llvm::ConstantFP::get(numeric_t, 0); - llvm::Constant *_0 = llvm::ConstantTile::get(_f0, {bm, bn}); - - // Function - llvm::FunctionType* prototype = llvm::FunctionType::get(llvm::Type::getVoidTy(context), std::vector{numeric_ptr_t, numeric_ptr_t, numeric_ptr_t, int32_t, int32_t, int32_t, int32_t}, false); - llvm::Function* F = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, "kernel", module.get()); - std::vector arguments; - F->addAttribute(1, llvm::Attribute::ReadOnly); - F->addAttribute(1, llvm::Attribute::NoAlias); - F->addAttribute(2, llvm::Attribute::ReadOnly); - F->addAttribute(2, llvm::Attribute::NoAlias); - std::transform(F->arg_begin(), F->arg_end(), std::back_inserter(arguments), [&](llvm::Argument& x){ return &x;}); - arguments[0]->setName("pa"); - arguments[1]->setName("pb"); - arguments[2]->setName("pc"); - arguments[3]->setName("M"); - arguments[4]->setName("N"); - arguments[5]->setName("K"); - arguments[6]->setName("bound"); - - // All basic blocks - llvm::BasicBlock* PrologBB = llvm::BasicBlock::Create(context, "prologue", F); - llvm::BasicBlock* LoopBB = llvm::BasicBlock::Create(context, "loop", F); - llvm::BasicBlock* EarlyExitBB = llvm::BasicBlock::Create(context, "early_exit", F); - llvm::BasicBlock* LastIterBB = llvm::BasicBlock::Create(context, "last_iter", F); - llvm::BasicBlock* EpilogueBB = llvm::BasicBlock::Create(context, "epilogue", F); - - - // First basic block - builder.SetInsertPoint(PrologBB); - - llvm::CallInst* aasm = builder.CreateCall(read_slice_x, {bm}, "asm"); - llvm::CallInst* bbsn = builder.CreateCall(read_slice_y, {bn}, "bsn"); - llvm::CallInst* ask = builder.CreateCall(range, {builder.getInt32(0), bk}, "ask"); - llvm::CallInst* bsk = builder.CreateCall(range, {builder.getInt32(0), bk}, "bsk"); - - llvm::Value *M = arguments[3], *N = arguments[4], *K = arguments[5]; - llvm::Value *bound = arguments[6]; - llvm::Value *AS0 = M, *AS1 = K; - llvm::Value *sa0 = aasm, *sa1 = ask; - llvm::Value *ba0 = bm, *ba1 = bk; - llvm::Value *inca0 = _s0, *inca1 = bk; - if(AT){ - std::swap(AS0, AS1); - std::swap(sa0, sa1); - std::swap(ba0, ba1); - std::swap(inca0, inca1); - } - llvm::Value *BS0 = K, *BS1 = N; - llvm::Value *sb0 = bsk, *sb1 = bbsn; - llvm::Value *bb0 = bk, *bb1 = bn; - llvm::Value *incb0 = bk, *incb1 = _s0; - if(BT){ - std::swap(BS0, BS1); - std::swap(sb0, sb1); - std::swap(bb0, bb1); - std::swap(incb0, incb1); - } - - llvm::CallInst* tlda = builder.CreateCall(splat_1d, {ba1, AS0}, "lda"); - llvm::CallInst* tldb = builder.CreateCall(splat_1d, {bb1, BS1}, "ldb"); - llvm::CallInst* offa = builder.CreateCall(outer_add, {sa0, builder.CreateMul(sa1, tlda)}, "offa"); - llvm::CallInst* offb = builder.CreateCall(outer_add, {sb0, builder.CreateMul(sb1, tldb)}, "offb"); - llvm::CallInst* startpa = builder.CreateCall(gtp, {arguments[0], offa}, "startpa"); - llvm::CallInst* startpb = builder.CreateCall(gtp, {arguments[1], offb}, "startpb"); - llvm::LoadInst* startfa = builder.CreateLoad(startpa, "startfa"); - llvm::LoadInst* startfb = builder.CreateLoad(startpb, "startfb"); - llvm::Value* starta = builder.CreateCall(reshape, {startfa, ba0, ba1, br}, "starta"); - llvm::Value* startb = builder.CreateCall(reshape, {startfb, bb0, bb1, br}, "startb"); - llvm::Value* tinca0 = builder.CreateCall(splat_1d, {ba0, builder.CreateMul(inca0, AS0)}, "tinca0"); - llvm::Value* tinca1 = builder.CreateCall(splat_1d, {ba1, builder.CreateMul(inca1, AS1)}); - llvm::Value* tincb0 = builder.CreateCall(splat_1d, {bb0, builder.CreateMul(incb0, BS0)}); - llvm::Value* tincb1 = builder.CreateCall(splat_1d, {bb1, builder.CreateMul(incb1, BS1)}); - llvm::Value* inca = builder.CreateCall(outer_add, {tinca0, tinca1}, "inca"); - llvm::Value* incb = builder.CreateCall(outer_add, {tincb0, tincb1}, "incb"); - // Enter loop - builder.CreateBr(LoopBB); - builder.SetInsertPoint(LoopBB); - // PHI nodes - llvm::PHINode* c = builder.CreatePHI(_0->getType(), 2, "c"); - llvm::PHINode* k = builder.CreatePHI(int32_t, 2, "k"); - llvm::PHINode* pa = builder.CreatePHI(startpa->getType(), 2, "pa"); - llvm::PHINode* pb = builder.CreatePHI(startpb->getType(), 2, "pb"); - llvm::PHINode *a = builder.CreatePHI(starta->getType(), 2, "a"); - llvm::PHINode *b = builder.CreatePHI(startb->getType(), 2, "b"); - llvm::Value* nextc = builder.CreateCall(mma, {a, b, c}, "nextc"); - c->addIncoming(_0, PrologBB); - c->addIncoming(nextc, LoopBB); - // Induction variable - llvm::Value *nextk = builder.CreateSub(k, bk); - k->addIncoming(K, PrologBB); - k->addIncoming(nextk, LoopBB); - // Update pointer - llvm::Value *nextpa = builder.CreateCall(stp, {pa, inca}, "nextpa"); - llvm::Value *nextpb = builder.CreateCall(stp, {pb, incb}, "nextpb"); - pa->addIncoming(startpa, PrologBB); - pa->addIncoming(nextpa, LoopBB); - pb->addIncoming(startpb, PrologBB); - pb->addIncoming(nextpb, LoopBB); - // End condition - llvm::Value* no_bounds_check = builder.CreateICmpSGT(nextk, bound); - // Masks - llvm::Value* maska = builder.CreateCall(splat_2d, {ba0, ba1, no_bounds_check}, "maska"); - llvm::Value* maskb = builder.CreateCall(splat_2d, {bb0, bb1, no_bounds_check}, "maskb"); - // Pre-fetch - llvm::Value* nextfa = builder.CreateCall(masked_load, {nextpa, maska}, "nextfa"); - llvm::Value* nextfb = builder.CreateCall(masked_load, {nextpb, maskb}, "nextfb"); - llvm::Value* nexta = builder.CreateCall(reshape, {nextfa, ba0, ba1, br}, "nexta"); - llvm::Value* nextb = builder.CreateCall(reshape, {nextfb, bb0, bb1, br}, "nextb"); - a->addIncoming(starta, PrologBB); - a->addIncoming(nexta, LoopBB); - b->addIncoming(startb, PrologBB); - b->addIncoming(nextb, LoopBB); - // End condition - builder.CreateCondBr(no_bounds_check, LoopBB, EarlyExitBB); - // Early exit - builder.SetInsertPoint(EarlyExitBB); - llvm::Value* exit = builder.CreateICmpSLE(nextk, _s0); - builder.CreateCondBr(exit, EpilogueBB, LastIterBB); - // Last Iteration - builder.SetInsertPoint(LastIterBB); - llvm::Value* in_bounds_a0 = builder.CreateICmpSLT(aasm, builder.CreateCall(splat_1d, {ba0, M})); - llvm::Value* in_bounds_a1 = builder.CreateICmpSLT(ask, builder.CreateCall(splat_1d, {ba1, bk})); - llvm::Value* in_bounds_b0 = builder.CreateICmpSLT(bbsn, builder.CreateCall(splat_1d, {bb0, N})); - llvm::Value* in_bounds_b1 = builder.CreateICmpSLT(bsk, builder.CreateCall(splat_1d, {bb1, bk})); - llvm::Value* lastmaska = builder.CreateCall(outer_and, {in_bounds_a0, in_bounds_a1}, "lastmaska"); - llvm::Value* lastmaskb = builder.CreateCall(outer_and, {in_bounds_b0, in_bounds_b1}, "lastmaskb"); - llvm::Value* lastfa = builder.CreateCall(masked_load, {nextpa, lastmaska}, "lastfa"); - llvm::Value* lastfb = builder.CreateCall(masked_load, {nextpb, lastmaskb}, "lastfb"); - llvm::Value* lasta = builder.CreateCall(reshape, {lastfa, ba0, ba1, br}, "lasta"); - llvm::Value* lastb = builder.CreateCall(reshape, {lastfb, bb0, bb1, br}, "lastb"); - llvm::Value* loop = builder.CreateICmpSGT(nextk, _s0); - a->addIncoming(lasta, LastIterBB); - b->addIncoming(lastb, LastIterBB); - c->addIncoming(nextc, LastIterBB); - k->addIncoming(nextk, LastIterBB); - pa->addIncoming(nextpa, LastIterBB); - pb->addIncoming(nextpb, LastIterBB); - builder.CreateCondBr(loop, LoopBB, EpilogueBB); - // Epilogue - builder.SetInsertPoint(EpilogueBB); - llvm::CallInst* sm = builder.CreateCall(read_slice_x, {bm}, "sm"); - llvm::CallInst* sn = builder.CreateCall(read_slice_y, {bn}, "sn"); - llvm::CallInst* ldc = builder.CreateCall(splat_1d, {bn, M}, "lda"); - llvm::CallInst* offc = builder.CreateCall(outer_add, {sm, builder.CreateMul(sn, ldc)}, "offc"); - llvm::CallInst* pc = builder.CreateCall(gtp, {arguments[2], offc}, "pc"); - llvm::Value* in_bounds_c0 = builder.CreateICmpSLT(sm, builder.CreateCall(splat_1d, {bm, M})); - llvm::Value* in_bounds_c1 = builder.CreateICmpSLT(sn, builder.CreateCall(splat_1d, {bn, N})); - llvm::Value* maskc = builder.CreateCall(outer_and, {in_bounds_c0, in_bounds_c1}, "maskc"); - builder.CreateCall(masked_store, {nextc, pc, maskc}); - builder.CreateRet(NULL); - - - // Set metadata - llvm::Metadata *md_args[] = { - llvm::ValueAsMetadata::get(F), - llvm::MDString::get(context, "kernel"), - llvm::ValueAsMetadata::get(llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)) - }; - module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(llvm::MDNode::get(context, md_args)); - - // Machine - module->setTargetTriple("nvptx64-nvidia-cuda"); - auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); - - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), "sm_52", "", - llvm::TargetOptions(), llvm::Reloc::Model(), - llvm::CodeModel::Model(), llvm::CodeGenOpt::Aggressive); - module->setDataLayout(machine->createDataLayout()); - - // Auto-tuning - autotune(machine, *module); - - // Emit - llvm::legacy::PassManager pass; - llvm::SmallVector buffer; - llvm::raw_svector_ostream stream(buffer); - machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); - pass.run(*module); - std::string src(buffer.begin(), buffer.end()); - - // Execute - std::cout << src << std::endl; -} diff --git a/main.cpp b/main.cpp new file mode 100644 index 000000000..5a01e7c68 --- /dev/null +++ b/main.cpp @@ -0,0 +1,14 @@ +#include +#include + +typedef struct yy_buffer_state * YY_BUFFER_STATE; +extern int yyparse(); +extern YY_BUFFER_STATE yy_scan_string(const char * str); +extern void yy_delete_buffer(YY_BUFFER_STATE buffer); + +int main() { + char string[] = "void test(int);"; + YY_BUFFER_STATE buffer = yy_scan_string(string); + yy_delete_buffer(buffer); + return 0; +} diff --git a/parser.y b/parser.y new file mode 100644 index 000000000..3501fba0f --- /dev/null +++ b/parser.y @@ -0,0 +1,305 @@ +%{ +namespace ast{ +class node; +} +using namespace ast; +#define YYSTYPE node* +#include "../ast.h" +using namespace ast; + +extern char* yytext; +void yyerror(const char *s); +int yylex(void); + +translation_unit *ast_root; + +%} + +%token IDENTIFIER CONSTANT STRING_LITERAL +%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP +%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN +%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN +%token XOR_ASSIGN OR_ASSIGN TYPE_NAME +%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64 +%token IF ELSE FOR +%token DEF + +%start translation_unit +%% + + +/* -------------------------- */ +/* Types */ +/* -------------------------- */ + +type_specifier + : VOID + | UINT8 | UINT16 | UINT32 | UINT64 + | INT8 | INT16 | INT32 | INT64 + | FP32 | FP64 + ; + +pointer + : '*' { $$ = new pointer_declarator(1); } + | '*' pointer { $$ = ((pointer_declarator*)$1)->inc(); } + +abstract_declarator + : pointer { $$ = $1; } + | direct_abstract_declarator { $$ = $1; } + | pointer direct_abstract_declarator { $$ = new compound_declarator($1, $2); } + ; + +direct_abstract_declarator + : '[' constant_list ']' { $$ = new tile_declarator($1); } + +constant : + CONSTANT { $$ = new constant(atoi(yytext)); } + ; + +constant_list + : constant { $$ = new list((constant*)$1); } + | constant_list ',' constant { $$ = append_ptr_list($1, $2); } + ; + +type_name + : type_specifier { $$ = new type((yytokentype)(size_t)$1, nullptr); } + | type_specifier abstract_declarator { $$ = new type((yytokentype)(size_t)$1, $2); } + ; + +/* -------------------------- */ +/* Expressions */ +/* -------------------------- */ + +identifier + : IDENTIFIER { $$ = new identifier(yytext); } + ; + +primary_expression + : identifier { $$ = $1; } + | constant { $$ = $1; } + | STRING_LITERAL { $$ = new string_literal(yytext); } + | '(' unary_expression ')' { $$ = $1; } + ; + +unary_expression + : primary_expression { $$ = $1; } + | INC_OP unary_expression { $$ = new unary_operator(INC_OP, $2); } + | DEC_OP unary_expression { $$ = new unary_operator(DEC_OP, $2); } + | unary_operator cast_expression { $$ = new unary_operator((yytokentype)(size_t)$1, $2); } + ; + +unary_operator + : '&' + | '*' + | '+' + | '-' + | '~' + | '!' + ; + +cast_expression + : unary_expression { $$ = $1; } + | '(' type_name ')' cast_expression { $$ = new cast_operator((yytokentype)(size_t)$1, $2); } + ; + +multiplicative_expression + : cast_expression { $$ = $1; } + | multiplicative_expression '*' cast_expression { $$ = new binary_operator('*', $1, $3); } + | multiplicative_expression '/' cast_expression { $$ = new binary_operator('/', $1, $3); } + | multiplicative_expression '%' cast_expression { $$ = new binary_operator('%', $1, $3); } + ; + +additive_expression + : multiplicative_expression { $$ = $1; } + | additive_expression '+' multiplicative_expression { $$ = new binary_operator('+', $1, $3); } + | additive_expression '-' multiplicative_expression { $$ = new binary_operator('-', $1, $3); } + ; + +shift_expression + : additive_expression { $$ = $1; } + | shift_expression LEFT_OP additive_expression { $$ = new binary_operator(LEFT_OP, $1, $3); } + | shift_expression RIGHT_OP additive_expression { $$ = new binary_operator(RIGHT_OP, $1, $3); } + ; + +relational_expression + : shift_expression { $$ = $1; } + | relational_expression '<' shift_expression { $$ = new binary_operator('<', $1, $3); } + | relational_expression '>' shift_expression { $$ = new binary_operator('>', $1, $3); } + | relational_expression LE_OP shift_expression { $$ = new binary_operator(LE_OP, $1, $3); } + | relational_expression GE_OP shift_expression { $$ = new binary_operator(GE_OP, $1, $3); } + ; + +equality_expression + : relational_expression { $$ = $1; } + | equality_expression EQ_OP relational_expression { $$ = new binary_operator(EQ_OP, $1, $3); } + | equality_expression NE_OP relational_expression { $$ = new binary_operator(NE_OP, $1, $3); } + ; + +and_expression + : equality_expression { $$ = $1; } + | and_expression '&' equality_expression { $$ = new binary_operator('&', $1, $3); } + ; + +exclusive_or_expression + : and_expression { $$ = $1; } + | exclusive_or_expression '^' and_expression { $$ = new binary_operator('^', $1, $3); } + ; + +inclusive_or_expression + : exclusive_or_expression { $$ = $1; } + | inclusive_or_expression '|' exclusive_or_expression { $$ = new binary_operator('|', $1, $3); } + ; + +logical_and_expression + : inclusive_or_expression { $$ = $1; } + | logical_and_expression AND_OP inclusive_or_expression { $$ = new binary_operator(AND_OP, $1, $3); } + ; + +logical_or_expression + : logical_and_expression { $$ = $1; } + | logical_or_expression OR_OP logical_and_expression { $$ = new binary_operator(OR_OP, $1, $3); } + ; + +conditional_expression + : logical_or_expression { $$ = $1; } + | logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $2, $3); } + ; + +assignment_operator + : '=' + | MUL_ASSIGN + | DIV_ASSIGN + | MOD_ASSIGN + | ADD_ASSIGN + | SUB_ASSIGN + | LEFT_ASSIGN + | RIGHT_ASSIGN + | AND_ASSIGN + | XOR_ASSIGN + | OR_ASSIGN + ; + + +assignment_expression + : conditional_expression { $$ = $1; } + | unary_expression assignment_operator assignment_expression { $$ = new assignment_expression($1, (yytokentype)(size_t)$2, $3); } + ; + +expression + : assignment_expression { $$ = $1; } + ; + +/* -------------------------- */ +/* Statements */ +/* -------------------------- */ + +statement + : compound_statement { $$ = $1; } + | expression_statement { $$ = $1; } + | selection_statement { $$ = $1; } + | iteration_statement { $$ = $1; } + ; + +compound_statement + : '{' '}' { $$ = new compound_statement(); } + | '{' statement_list '}' { $$ = $1; } + ; + +statement_list + : statement { $$ = new compound_statement($1); } + | statement_list statement { $$ = append_ptr_list($1, $2); } + ; + +expression_statement + : ';' { $$ = new no_op(); } + | expression ';' { $$ = $1; } + ; + +selection_statement + : IF '(' expression ')' statement { $$ = new selection_statement($1, $2); } + | IF '(' expression ')' statement ELSE statement { $$ = new selection_statement($1, $2, $3); } + ; + +iteration_statement + : FOR '(' expression_statement expression_statement ')' statement { $$ = new iteration_statement($1, $2, NULL, $3); } + | FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($1, $2, $3, $3); } + ; + + +/* -------------------------- */ +/* Declarator */ +/* -------------------------- */ + + +direct_declarator + : identifier { $$ = $1; } + | direct_declarator '[' constant_list ']' { $$ = new tile_declarator($2); } + | direct_declarator '(' parameter_list ')' { $$ = new function_declarator($2); } + | direct_declarator '(' identifier_list ')' { $$ = new function_declarator($2); } + | direct_declarator '(' ')' { $$ = new function_declarator(nullptr); } + ; + +identifier_list + : identifier { $$ = new list((identifier*)$1); } + | identifier_list ',' identifier { $$ = append_ptr_list($1, $2); } + ; + +parameter_list + : parameter_declaration { $$ = new list((parameter*)$1); } + | parameter_list ',' parameter_declaration { $$ = append_ptr_list($1, $2); } + ; + +parameter_declaration + : declaration_specifiers declarator { $$ = new parameter((yytokentype)(size_t)$1, $2); } + | declaration_specifiers abstract_declarator { $$ = new parameter((yytokentype)(size_t)$1, $2); } + | declaration_specifiers { $$ = new parameter((yytokentype)(size_t)$1, nullptr); } + ; + + +declaration_specifiers + : type_specifier { $$ = $1; } + ; + +init_declarator_list + : init_declarator { $$ = new list((init_declarator*)$1); } + | init_declarator_list ',' init_declarator { $$ = append_ptr_list($1, $2); } + ; + +declaration + : declaration_specifiers ';' { $$ = new declaration($1, nullptr); } + | declaration_specifiers init_declarator_list ';' { $$ = new declaration($1, $2); } + ; + +declarator + : pointer direct_declarator { $$ = new compound_declarator($1, $2); } + | direct_declarator { $$ = $1; } + ; + +initializer + : assignment_expression { $$ = $1; } + | '{' constant '}' { $$ = $1; } + ; + +init_declarator + : declarator { $$ = new init_declarator($1, nullptr); } + | declarator '=' initializer { $$ = new init_declarator($1, $2); } + ; + +/* -------------------------- */ +/* Translation Unit */ +/* -------------------------- */ + +translation_unit + : external_declaration { $$ = new translation_unit($1); } + | translation_unit external_declaration { $$ = ((translation_unit*)($1))->add($2); } + ; + +external_declaration + : function_definition { $$ = $1; } + | declaration { $$ = $1; } + ; + +function_definition + : declarator compound_statement { $$ = new function_definition($1, $2); } + ; + diff --git a/scanner.l b/scanner.l new file mode 100644 index 000000000..394cca7c4 --- /dev/null +++ b/scanner.l @@ -0,0 +1,128 @@ +D [0-9] +L [a-zA-Z_] +H [a-fA-F0-9] +E [Ee][+-]?{D}+ +FS (f|F|l|L) +IS (u|U|l|L)* + +%{ +#include +#include "parser.hpp" + +void count(); +int check_type(); +int comment(); + +%} + +%% +"def" { count(); return(DEF); } +"if" { count(); return(IF); } +"else" { count(); return(ELSE); } +"for" { count(); return(FOR); } +"void" { count(); return(VOID); } +"uint8" { count(); return(UINT8); } +"uint16" { count(); return(UINT16); } +"uint32" { count(); return(UINT32); } +"uint64" { count(); return(UINT64); } +"int8" { count(); return(INT8); } +"int16" { count(); return(INT16); } +"int32" { count(); return(INT32); } +"int64" { count(); return(INT64); } +"fp32" { count(); return(FP32); } +"fp64" { count(); return(FP64); } + +{L}({L}|{D})* { count(); return(check_type()); } + +0[xX]{H}+{IS}? { count(); return(CONSTANT); } +0{D}+{IS}? { count(); return(CONSTANT); } +{D}+{IS}? { count(); return(CONSTANT); } +L?'(\\.|[^\\'])+' { count(); return(CONSTANT); } + +{D}+{E}{FS}? { count(); return(CONSTANT); } +{D}*"."{D}+({E})?{FS}? { count(); return(CONSTANT); } +{D}+"."{D}*({E})?{FS}? { count(); return(CONSTANT); } + +L?\"(\\.|[^\\"])*\" { count(); return(STRING_LITERAL); } + +">>=" { count(); return(RIGHT_ASSIGN); } +"<<=" { count(); return(LEFT_ASSIGN); } +"+=" { count(); return(ADD_ASSIGN); } +"-=" { count(); return(SUB_ASSIGN); } +"*=" { count(); return(MUL_ASSIGN); } +"/=" { count(); return(DIV_ASSIGN); } +"%=" { count(); return(MOD_ASSIGN); } +"&=" { count(); return(AND_ASSIGN); } +"^=" { count(); return(XOR_ASSIGN); } +"|=" { count(); return(OR_ASSIGN); } +">>" { count(); return(RIGHT_OP); } +"<<" { count(); return(LEFT_OP); } +"++" { count(); return(INC_OP); } +"--" { count(); return(DEC_OP); } +"->" { count(); return(PTR_OP); } +"&&" { count(); return(AND_OP); } +"||" { count(); return(OR_OP); } +"<=" { count(); return(LE_OP); } +">=" { count(); return(GE_OP); } +"==" { count(); return(EQ_OP); } +"!=" { count(); return(NE_OP); } +";" { count(); return(';'); } +("{"|"<%") { count(); return('{'); } +("}"|"%>") { count(); return('}'); } +"," { count(); return(','); } +":" { count(); return(':'); } +"=" { count(); return('='); } +"(" { count(); return('('); } +")" { count(); return(')'); } +("["|"<:") { count(); return('['); } +("]"|":>") { count(); return(']'); } +"." { count(); return('.'); } +"&" { count(); return('&'); } +"!" { count(); return('!'); } +"~" { count(); return('~'); } +"-" { count(); return('-'); } +"+" { count(); return('+'); } +"*" { count(); return('*'); } +"/" { count(); return('/'); } +"%" { count(); return('%'); } +"<" { count(); return('<'); } +">" { count(); return('>'); } +"^" { count(); return('^'); } +"|" { count(); return('|'); } +"?" { count(); return('?'); } + +[ \t\v\n\f] { count(); } +. { /* ignore bad characters */ } + +%% + +int yywrap() +{ return(1); } + + +int column = 0; + +void count() +{ + int i; + + for (i = 0; yytext[i] != '\0'; i++) + if (yytext[i] == '\n') + column = 0; + else if (yytext[i] == '\t') + column += 8 - (column % 8); + else + column++; + + ECHO; +} + +void yyerror (const char *s) /* Called by yyparse on error */ +{ + printf ("Error: %s\n", s); +} + +int check_type() +{ + return(IDENTIFIER); +}