diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index a0a699711..fe9f5f21b 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -78,51 +78,51 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")"; std::string res = R"( - #define TM 128 - #define TN 128 - #define TK 32 +#define bool _Bool +#define true 1 +#define false 0 +#define __bool_true_false_are_defined 1 +extern int get_program_id(int); - #define bool _Bool - #define true 1 - #define false 0 - #define __bool_true_false_are_defined 1 +#define TN 128 +#define TK 32 - extern int get_program_id(int); +static const int TM = 128; - void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))), - restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))), - restrict )" + c_ty + R"( * C __attribute__((aligned(16))), - int M, int N, int K, - int lda __attribute__((multiple_of(8))), - int ldb __attribute__((multiple_of(8))), - int ldc) { - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int rxa[{TM}] = ridx * TM + 0 ... TM; - int ryb[{TN}] = ridy * TN + 0 ... TN; - int rka[{TK}] = 0 ... TK; - int rkb[{TK}] = 0 ... TK; - float xc[{)" + XCS + R"(}] = 0; - )" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; - )" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; - )" + a_ty + R"( a[{)" + AS + R"(}] = *pa; - )" + b_ty + R"( b[{)" + BS + R"(}] = *pb; - for(int k = K; k > 0; k = k - TK){ - xc = )" + usea + " @ " + useb + R"( + xc; - pa = pa + TK)" + lda0 + R"(; - pb = pb + TK)" + ldb0 + R"(; - a = *pa; - b = *pb; - } - int rxc[{TM}] = ridx * TM + (0 ... TM); - int ryc[{TN}] = ridy * TN + (0 ... 
TN); - )" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - )" + c_ty + R"( c[{TM, TN}] = xc; - bool checkc0[{TM}] = rxc < M; - bool checkc1[{TN}] = ryc < N; - bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :]; - *pc = c; +void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))), + restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))), + restrict )" + c_ty + R"( * C __attribute__((aligned(16))), + int M, int N, int K, + int lda __attribute__((multiple_of(8))), + int ldb __attribute__((multiple_of(8))), + int ldc) { + int ridx = get_program_id(0); + int ridy = get_program_id(1); + int rxa[{TM}] = ridx * TM + 0 ... TM; + int ryb[{TN}] = ridy * TN + 0 ... TN; + int rka[{TK}] = 0 ... TK; + int rkb[{TK}] = 0 ... TK; + float xc[{)" + XCS + R"(}] = 0; + )" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; + )" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; + )" + a_ty + R"( a[{)" + AS + R"(}] = *pa; + )" + b_ty + R"( b[{)" + BS + R"(}] = *pb; + for(int k = K; k > 0; k = k - TK){ + xc = )" + usea + " @ " + useb + R"( + xc; + pa = pa + TK)" + lda0 + R"(; + pb = pb + TK)" + ldb0 + R"(; + a = *pa; + b = *pb; } + int rxc[{TM}] = ridx * TM + (0 ... TM); + int ryc[{TN}] = ridy * TN + (0 ... 
TN); + )" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + )" + c_ty + R"( c[{TM, TN}] = xc; + bool checkc0[{TM}] = rxc < M; + bool checkc1[{TN}] = ryc < N; + bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :]; + *pc = c; +} )"; return res; diff --git a/include/triton/lang/wgtcc/parser.h b/include/triton/lang/wgtcc/parser.h index 92ed7c38e..eedaeb5e6 100644 --- a/include/triton/lang/wgtcc/parser.h +++ b/include/triton/lang/wgtcc/parser.h @@ -25,7 +25,7 @@ class Parser { friend class Generator; public: - explicit Parser(const TokenSequence& ts) + explicit Parser(TokenSequence& ts) : unit_(TranslationUnit::New()), ts_(ts), externalSymbols_(new Scope(nullptr, S_BLOCK)), @@ -224,7 +224,7 @@ private: // The root of the AST TranslationUnit* unit_; - TokenSequence ts_; + TokenSequence& ts_; // It is not the real scope, // It contains all external symbols(resolved and not resolved) diff --git a/include/triton/lang/wgtcc/type.h b/include/triton/lang/wgtcc/type.h index b43b74339..1cb10777b 100644 --- a/include/triton/lang/wgtcc/type.h +++ b/include/triton/lang/wgtcc/type.h @@ -223,7 +223,7 @@ public: virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); } virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; } virtual bool IsFloat() const { - return (tag_ & T_FLOAT) || (tag_ & T_DOUBLE); + return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE); } virtual bool IsBool() const { return tag_ & T_BOOL; } bool IsComplex() const { return tag_ & T_COMPLEX; } diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 2880a4e54..0a91984e2 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -20,6 +20,7 @@ #include "triton/codegen/transform/shmem/barriers.h" #include "triton/codegen/transform/reassociate.h" #include "triton/codegen/transform/vectorize.h" +#include "triton/lang/wgtcc/parser.h" namespace llvm { class Module; @@ -87,9 +88,9 @@ private: 
typedef std::pair cache_val_t; private: - triton::lang::translation_unit *make_ast(const char *src); - std::unique_ptr make_ir(triton::lang::translation_unit *program); - options autotune(lang::translation_unit *ast, driver::stream *stream, const grid_fn_ty& grid, const std::vector &args); + triton::lang::translation_unit *make_ast(const std::string &src); + std::unique_ptr make_ir(Parser &parser); + options autotune(Parser &parser, driver::stream *stream, const grid_fn_ty& grid, const std::vector &args); std::unique_ptr make_bin(ir::module &function, driver::context *context, const options &opt); @@ -100,11 +101,12 @@ public: std::string make_tensorflow_src(const std::vector &outputs, const std::string &macro); private: + TokenSequence ts_; + Parser parser_; // execution context ir::context ctx_; // program representations std::string src_; - lang::translation_unit *ast_; std::map cache_; }; diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 8e23959c0..7711b6d24 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -106,10 +106,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple file_type_t ft) { init_llvm(); // debug -// llvm::legacy::PassManager pm; -// pm.add(llvm::createPrintModulePass(llvm::outs())); -// pm.add(llvm::createVerifierPass()); -// pm.run(*module); + llvm::legacy::PassManager pm; + pm.add(llvm::createPrintModulePass(llvm::outs())); + pm.add(llvm::createVerifierPass()); + pm.run(*module); // create machine module->setTargetTriple(triple); std::string error; diff --git a/lib/lang/wgtcc/ast.cc b/lib/lang/wgtcc/ast.cc index 0a7327fa3..8cb029021 100644 --- a/lib/lang/wgtcc/ast.cc +++ b/lib/lang/wgtcc/ast.cc @@ -461,7 +461,10 @@ void BinaryOp::MatmulOpTypeChecking() { QualType retType = lhsType->Derived(); if(retType != rhsType->Derived()) Error(this, "matrix multiplication operands have incompatible data types"); - type_ = TileType::New(retShape, lhsType->Derived()); + ArithmType* ScalType = 
lhsType->ScalarType()->ToArithm(); + if(ScalType->Tag() & T_HALF) + ScalType = ArithmType::New(T_FLOAT); + type_ = TileType::New(retShape, ScalType); } void BinaryOp::ShiftOpTypeChecking() { diff --git a/lib/lang/wgtcc/code_gen.cc b/lib/lang/wgtcc/code_gen.cc index a234e8a47..d7188b2b1 100644 --- a/lib/lang/wgtcc/code_gen.cc +++ b/lib/lang/wgtcc/code_gen.cc @@ -29,7 +29,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { ir::value* lhs = ret_; // op info - auto type = binary->lhs_->Type(); + auto type = binary->lhs_->Type()->ScalarType(); auto flt = type->IsFloat(); auto sign = !type->IsUnsigned(); diff --git a/lib/lang/wgtcc/type.cc b/lib/lang/wgtcc/type.cc index c83ddd37d..25d5c56ce 100644 --- a/lib/lang/wgtcc/type.cc +++ b/lib/lang/wgtcc/type.cc @@ -318,7 +318,7 @@ bool ArrayType::Compatible(const Type& other) const { bool TileType::Compatible(const Type& other) const { // For two tile type to be compatible, // the element types must be compatible - // and they must have the same shape + // and they must have the same shapea auto otherTile = other.ToTile(); if(!otherTile) return false; diff --git a/lib/runtime/function.cpp b/lib/runtime/function.cpp index 55e0d5fc1..598ae146b 100644 --- a/lib/runtime/function.cpp +++ b/lib/runtime/function.cpp @@ -17,6 +17,7 @@ #include "triton/ir/function.h" #include "triton/tools/bench.hpp" #include "llvm/IR/Module.h" +#include "triton/ir/print.h" typedef struct yy_buffer_state * YY_BUFFER_STATE; @@ -117,50 +118,17 @@ void function::caller::operator ()(driver::stream *stream, const std::array function::make_ir(triton::lang::translation_unit *program) { +std::unique_ptr function::make_ir(Parser& parser) { // create Triton-IR from AST ir::module* module = new ir::module("", ctx_); - program->codegen(module); + Generator gen(&parser); + gen.Gen(module); return std::unique_ptr(module); } -options function::autotune(lang::translation_unit *ast, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector& args) { - 
std::unique_ptr ir = make_ir(ast); +options function::autotune(Parser& parser, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector& args) { + std::unique_ptr ir = make_ir(parser); // extract tunable values std::vector> values; for(auto it: ir->globals()) @@ -186,7 +154,7 @@ options function::autotune(lang::translation_unit *ast, driver::stream* stream, for(auto it: values) opt.params[it.first] = params[i++]; // make binary - auto ir = make_ir(ast); + auto ir = make_ir(parser); auto bin = make_bin(*ir, stream->context(), opt); // benchmark ir::function *tmp = ir->get_function_list()[0]; @@ -242,6 +210,8 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c // generate llvm code llvm::LLVMContext ctx; std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); + ir::print(module, std::cout); + exit(EXIT_FAILURE); selection.run(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, llvm.get())); @@ -249,9 +219,11 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c } -function::function(const std::string &src): src_(src) { - // src -> ast - ast_ = make_ast(src_.c_str()); +function::function(const std::string &src): parser_(ts_), src_(src){ + Preprocessor cpp(&src_, true); + cpp.Process(ts_); + ts_.Print(); + parser_.Parse(); } void function::operator()(const std::vector& args, const grid_fn_ty& grid_fn, driver::stream *stream) { @@ -277,8 +249,8 @@ void function::operator()(const std::vector& args, const grid_fn_ty& grid_f } /* re-tune and re-compile */ - options opt = autotune(ast_, stream, grid_fn, args); - std::unique_ptr ir = make_ir(ast_); + options opt = autotune(parser_, stream, grid_fn, args); + std::unique_ptr ir = make_ir(parser_); std::unique_ptr bin = make_bin(*ir, stream->context(), opt); ir::function* fn = ir->get_function_list().front(); const caller& run = cache_.insert({key, cache_val_t{opt, caller(fn, std::move(bin), opt.num_warps*32)}}).first->second.second;