diff --git a/CMakeLists.txt b/CMakeLists.txt index b6681ecdc..5f65560c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++14") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++17") diff --git a/docs/conf.py b/docs/conf.py index 606659a14..5c343cec0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -35,6 +35,13 @@ extensions = [] # Math Jax extensions += ['sphinx.ext.mathjax'] +# Auto Doc +import sys +import os +sys.path.insert(0, os.path.abspath('../python/')) +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon'] +autosummary_generate = True + # Sphinx gallery extensions += ['sphinx_gallery.gen_gallery'] from sphinx_gallery.sorting import FileNameSortKey diff --git a/docs/index.rst b/docs/index.rst index b5f4bda31..c722d85bd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,15 +17,27 @@ Getting Started getting-started/installation getting-started/tutorials/index -Programming Guide +Language Reference +------------------- + +- Checkout the :doc:`Python API Documentation ` + + +.. toctree:: + :maxdepth: 1 + :caption: Language Reference + :hidden: + + language-reference/python-api/index + + +Going Further ------------------ Check out the following documents to learn more about Triton and how it compares against other DSLs for DNNs: - Chapter 1: :doc:`Introduction ` - Chapter 2: :doc:`Related Work ` -- Chapter 3: :doc:`The Triton-C Language ` -- Chapter 4: :doc:`The Triton-IR Intermediate Representation ` .. toctree:: :maxdepth: 1 @@ -33,6 +45,4 @@ Check out the following documents to learn more about Triton and how it compares :hidden: programming-guide/chapter-1/introduction - programming-guide/chapter-2/related-work - programming-guide/chapter-3/triton-c - programming-guide/chapter-4/triton-ir \ No newline at end of file + programming-guide/chapter-2/related-work \ No newline at end of file diff --git a/docs/language-reference/python-api/index.rst b/docs/language-reference/python-api/index.rst new file mode 100644 index 000000000..152b7bd2b --- /dev/null +++ b/docs/language-reference/python-api/index.rst @@ -0,0 +1,117 @@ +Python API +=========== + +.. currentmodule:: triton + + +Programming Model +------------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + program_id + num_programs + + +Creation Ops +------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + arange + zeros + + +Shape Manipulation Ops +----------------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + broadcast_to + reshape + ravel + + + +Linear Algebra Ops +------------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + dot + +Memory Ops +-------------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + load + store + atomic_cas + atomic_xchg + + +Indexing Ops +-------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + where + + +Math Ops +---------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + exp + log + sigmoid + softmax + + +Reduction Ops +--------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + max + min + sum + + +Comparison ops +--------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + minimum + maximum + + +Compiler Hint Ops +------------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + multiple_of diff --git a/docs/programming-guide/chapter-3/triton-c.rst b/docs/programming-guide/chapter-3/triton-c.rst deleted file mode 100644 index 2281680ba..000000000 --- a/docs/programming-guide/chapter-3/triton-c.rst +++ /dev/null @@ -1,84 +0,0 @@ -======================= -The Triton-C Language -======================= - -In the introduction, we stressed the importance of blocked algorithms and described their core principles in pseudo-code. To facilitate their implementation on modern GPU hardware, we present Triton-C, a single-threaded imperative kernel language in which block variables are first-class citizen. This language may be used either directly by developers familiar with C, or as an intermediate language for existing (and future) transcompilers. In this chapter, we describe its differences with C, its Numpy-like semantics and its "Single-Program, Multiple-Data" (SPMD) programming model. - -------------------- -Differences with C -------------------- - -The syntax of Triton-C is based on that of ANSI C, but was modified and extended to accomodate the semantics and programming model described in the next two subsections. These changes fall into the following categories: - -+++++++++++ -Extensions -+++++++++++ - -**Variable declarations**: Triton adds special-purpose syntax for multi-dimensional array declarations (e.g., :code:`int block[16, 16]`), which purposely differs from that of nested arrays (i.e., arrays of pointers) found in ANSI C (e.g., :code:`int block[16][16]`). Block dimensions must be constant but can also be made parametric with the use of pre-processor macros. One-dimensional blocks of integers may be initialized using ellipses (e.g., :code:`int range[16] = 0 ... 16`). - -**Primitive types**: Triton-C supports the following primitive data-types: :code:`bool`, :code:`uint8`, :code:`uint16`, :code:`uint32`, :code:`uint64`, :code:`int8`, :code:`int16`, :code:`int32`, :code:`int64`, :code:`half`, :code:`float`, :code:`double`. - -**Operators and built-in function**: The usual C operators were extended to support element-wise array operations (:code:`+`, :code:`-`, :code:`&&`, :code:`*`, etc.) and complex array operations(:code:`@` for matrix multiplication). Additionally, some built-in functions were added for concurrency (:code:`get_program_id`, :code:`atomic_add`). - -**Slicing and broadcasting**: Multi-dimensional blocks can be broadcast along any particular dimension using numpy-like slicing syntax (e.g., :code:`int array[8, 8] = range[:, newaxis]` for stacking columns). Note that, as of now, slicing blocks to retrieve sub-blocks (or scalars) is forbidden as it is incompatible with the automatic parallelization methods used by our JIT. Reductions can be achieved using a syntax similar to slicing (e.g., :code:`array[+]` for summing an array, or :code:`array[:, max]` for row-wise maximum). Currently supported reduction operators are :code:`+`, :code:`min`, :code:`max`. - -**Masked pointer dereferencement**: Block-level operations in Triton-C are "atomic", in the sense that they execute either completely or not at all. Basic element-wise control-flow for block-level operations can nonetheless be achieved using ternary operators and the *masked pointer dereferencement* operator exemplified below: - -.. code-block:: C - :force: - - // create mask - bool mask[16, 16] = ...; - // conditional addition - float x[16, 16] = mask ? a + b : 0; - // conditional load - float y[16] 16] = mask ? *ptr : 0; - // conditional store - *?(mask)ptr = y; - \end{lstlisting} - - -+++++++++++++ -Restrictions -+++++++++++++ - -The Triton project is still in its infancy. As such, there are quite a few features of ANSI C that are not supported: - -**Non-kernel functions**: Right now, all function definitions must be kernels, i.e. be preceded with the :code:`__global__` attribute. We are aware that this is a severe limitations, and the reason why it exists is because our automatic parallelization engine would not be capable of handling array parameter arguments. - -**Non-primitive types**: Non-primitive types defined with :code:`struct` and :code:`union` are currently not supported, again because it is unclear at this point how these constructs would hook into our block-level data-flow analysis passes. - -**While loops**: We just haven't had time to implement those yet. - ----------------- -Semantics ----------------- - -The existence of built-in **blocked** types, variable and operations in Triton-C offers two main benefits. First, it simplifies the structure of blocked programs by hiding important details pertaining to concurrent programming such as memory coalescing, cache management and specialized tensor instrinsics. Second, it opens the door for compilers to perform these optimizations automatically. However, it also means that programs have some kind of *block-level semantics* that does not exist in C. Though some aspects of it (e.g., the :code:`@` operator) are pretty intuitive, one in particular might be puzzling to some GPU programmers: broadcasting semantics. - -+++++++++++++++++++++++ -Broadcasting Semantics -+++++++++++++++++++++++ - - -Block variables in Triton are strongly typed, meaning that certain instructions statically require their operands to satisfy strict shape constraints. For example, a scalar may not be added to an array unless it is first appropriately broadcast. *Broadcasting semantics* (first introduced in `Numpy `_) provides two formal rules for performing these conversions automatically in the case of binary operators: (1) the shape of the lowest-dimension operand is left-padded with ones until both operands have the same dimensionality; and (2) the content of both operands is replicated as many times as needed until their shape is identical. An error is emitted if this cannot be done. - -.. code-block:: C - - int a[16], b[32, 16], c[16, 1]; - // a is first reshaped to [1, 16] - // and then broadcast to [32, 16] - int x_1[32, 16] = a[newaxis, :] + b; - // Same as above but implicitly - int x_2[32, 16] = a + b; - // a is first reshaped to [1, 16] - // a is broadcast to [16, 16] - // c is broadcast to [16, 16] - int y[16, 16] = a + c; - ------------------- -Programming Model ------------------- - -As discussed in the `CUDA documentation `_, The execution of CUDA code on GPUs is supported by an `SPMD `_ programming model in which each kernel instance is associated with an identifiable *thread-block*, itself decomposed into *warps* of 32 *threads*. The Triton programming model is similar, but each kernel is *single-threaded* -- though automatically parallelized -- and associated with a global :code:`program id` which varies from instance to instance. This approach leads to simpler kernels in which CUDA-like concurrency primitives (shared memory synchronization, inter-thread communication, etc.) do not exist. The global program ids associated with each kernel instance can be queried using the :code:`get_program_id(axis)` built-in function where :code:`0 <= axis <= 2`. This is, for example, useful to create e.g., blocks of pointers as shown in the tutorials. - diff --git a/docs/programming-guide/chapter-4/broadcast-1.png b/docs/programming-guide/chapter-4/broadcast-1.png deleted file mode 100644 index 4e8071651..000000000 Binary files a/docs/programming-guide/chapter-4/broadcast-1.png and /dev/null differ diff --git a/docs/programming-guide/chapter-4/broadcast-2.png b/docs/programming-guide/chapter-4/broadcast-2.png deleted file mode 100644 index 9e8fc6ee1..000000000 Binary files a/docs/programming-guide/chapter-4/broadcast-2.png and /dev/null differ diff --git a/docs/programming-guide/chapter-4/triton-ir.rst b/docs/programming-guide/chapter-4/triton-ir.rst deleted file mode 100644 index e729b127f..000000000 --- a/docs/programming-guide/chapter-4/triton-ir.rst +++ /dev/null @@ -1,82 +0,0 @@ -========================================== -The Triton-IR Intermediate Representation -========================================== - -Triton-IR is an LLVM-based Intermediate Representation (IR) whose purpose is to provide an environment suitable for block-level program analysis, transformation and optimization. -In our implementation, Triton-IR programs are constructed directly from Triton-C after parsing, but they could also be formed directly by higher-level DSLs in the future. -Triton-IR and LLVM-IR programs share the same high-level structure, but the former also includes a number of extensions necessary for block-level data-flow analysis. -These extensions are crucial for carrying out the optimizations outlined in the next chapter of this document. - ---------------------------------- -Structure of a Triton-IR Program ---------------------------------- - -++++++++ -Modules -++++++++ - -At the highest level, Triton-IR programs consist of one or multiple basic units of compilation known as *modules*. These modules are compiled independently from one another, and eventually aggregated by a linker whose role is to resolve forward declarations and adequately merge global definitions. Each module itself is composed of functions, global variables, constants and other miscellaneous symbols such as metadata and attributes. - -++++++++++ -Functions -++++++++++ - -Triton-IR function definitions consist of a return type, a name and a potentially empty arguments list. Additional visibility, alignment and linkage specifiers can be added if desired. Function attributes (such as inlining hints) and parameter attributes (such as "readonly", aliasing hints) can also be specified, allowing compiler backends to perform more aggressive optimizations by, for instance, making better use of non-coherent caches found on NVIDIA GPUs. This header is followed by a body composed of a list of basic blocks whose interdependencies form the Control Flow Graph (CFG) of the function. - -+++++++++++++ -Basic Blocks -+++++++++++++ - -Basic blocks are straight-line code sequences that may only contain so-called *terminator* instructions (i.e., branching, return) at their end. To simplify program analysis, Triton-IR uses the Static Single Assignment (SSA) form, meaning that each variable in each basic block must be (1) assigned to only once and (2) defined before being used. In so doing, each basic block implicitly defines a Data-Flow Graph (DFG). In our case, the SSA form is created directly from Triton-C's Abstract Syntax Trees (ASTs) using an algorithm from the literature [BRAUN13]_. - ---------------------------------- -Block-Level Dataflow Analysis ---------------------------------- - -+++++++ -Types -+++++++ - -Multi-dimensional blocks are at the center of data-flow analysis in Triton-JIT. They can be declared using syntax similar to vector declarations in LLVM-IR. For example, :code:`i32<8, 8>` is the type corresponding to :math:`8 \times 8` blocks of 32-bit integers. Note that there is no preprocessor in Triton-IR, hence parametric shape values must be resolved before programs are generated. In our case, this is done by Triton-JIT's auto-tuner. - -+++++++++++++ -Instructions -+++++++++++++ - -Triton-IR introduces a set of *reblocking* instructions whose purpose is to support broadcasting semantics as described in the previous chapter. The :code:`reshape` instruction creates a block of the specified shape using the raw data from its input argument. This is particularly useful to re-interpret variables as higher-dimensional arrays by padding their input shapes with ones in preparation for broadcasting. The :code:`broadcast` instruction creates a block of the specified shapes by replicating its input argument as many times as necessary along dimensions of size 1 -- as shown below for the :code:`broadcast<3,3>` instruction. - -|pic1| and |pic2| - -.. |pic1| image:: broadcast-1.png - :width: 40% - -.. |pic2| image:: broadcast-2.png - :width: 40% - -Usual scalar instructions (:code:`cmp`, :code:`getelementptr`, :code:`add`, :code:`load`...) were preserved and extended to signify element-wise operations when applicable. Finally, Triton-IR also exposes specialized arithmetic instructions for reductions (:code:`reduce`) and matrix multiplications (:code:`dot`). - ----------------------------------- -Block-Level Control Flow Analysis ----------------------------------- - -In Triton-IR, operations on block variables are atomic: they execute either in full or not at all. As a result, traditional control flow structures (e.g., conditional, loops) are not applicable to individual block elements. This is problematic, since a program may need to e.g., partially guard blocked loads against memory access violations. - -This could be potentially solved through the use of the Predicated SSA (PSSA) [CARTER99]_ [STOUTCHININ01]_ form for Triton-IR. However, this would create a lot of unnecessary complexity for GPUs, where the benefits of PSSA are close to none as divergent program paths within warps are serialized anyway. Therefore, recent versions of Triton handle intra-block control flow in a much simpler way, using conditional instructions such as :code:`select`, :code:`masked_load` and :code:`masked_store`: - -.. code-block:: C - - // For all indices [idx], return cond[idx] ? true_value[idx] : false_value[idx]; - select TYPE cond, true_value, false_value; - // For all indices [idx], return cond[idx] ? *true_addr[idx] : false_value[idx]; - masked_load TYPE cond, true_addr, false_value; - // For all indices [idx], execute *true_addr[idx] = true_value[idx] if cond[idx] - masked_store TYPE cond, true_addr, true_value; - - ------------- -References ------------- - -.. [BRAUN13] M. Braun et al., "Simple and Efficient Construction of Static Single Assignment Form", CC 2013 -.. [CARTER99] L. Carter et al., "Predicated Static Single Assignment", PACT 1999 -.. [STOUTCHININ01] A. Stoutchinin et al., "Efficient Static Single Assignment Form for Predication", MICRO 2001 diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h index 129c02bc6..c1b67372c 100644 --- a/include/triton/codegen/pass.h +++ b/include/triton/codegen/pass.h @@ -1,30 +1,31 @@ #ifndef _TRITON_CODEGEN_PASS_H_ #define _TRITON_CODEGEN_PASS_H_ -#include + +#include namespace triton{ namespace ir{ class module; } +namespace driver{ + class device; + class module; + class kernel; +} +} +namespace triton{ namespace codegen{ -class pass { -public: - virtual void run(ir::module& m); -}; +// TODO: +// There should be a proper pass manager there! +void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, + driver::module*& mod, driver::kernel*& ker, size_t& shared_mem); -class pass_manager { -public: - void add(pass* p); - void run(ir::module& m); - -private: - std::list passes; -}; - } } + +#endif diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 1524f53e4..60c3933ab 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -119,7 +119,7 @@ public: void visit_exp_inst(ir::exp_inst*); void visit_log_inst(ir::log_inst*); void visit_get_program_id_inst(ir::get_program_id_inst*); - void visit_get_num_program_inst(ir::get_num_program_inst*); + void visit_get_num_programs_inst(ir::get_num_programs_inst*); void visit_atomic_cas_inst(ir::atomic_cas_inst*); void visit_atomic_exch_inst(ir::atomic_exch_inst*); void visit_atomic_add_inst(ir::atomic_add_inst*); diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h index df98d5eb2..b31cf6f8a 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -59,7 +59,7 @@ public: // CUDA class cu_module: public module { - std::string compile_llvm_module(std::unique_ptr module, driver::device* device); + std::string compile_llvm_module(llvm::Module* module, driver::device* device); void init_from_ptx(const std::string& ptx); public: diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 99c639cc0..1498a8e57 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #ifndef _TRITON_IR_BUILDER_H_ #define _TRITON_IR_BUILDER_H_ @@ -27,6 +27,8 @@ class builder{ public: // Constructor builder(context &ctx); + // Getters + const context& get_context() { return ctx_; } // Setters void set_insert_point(iterator instr); void set_insert_point(instruction* i); @@ -38,6 +40,9 @@ public: value *get_int1(bool val); value *get_int32(int32_t val); value *get_int64(int64_t val); + value *get_float16(float val); + value *get_float32(float val); + value *get_range(int32_t lo, int32_t hi); // Types type *get_void_ty(); type *get_int1_ty(); @@ -50,11 +55,10 @@ public: type *get_double_ty(); // Insert template - InstTy* insert(InstTy *inst, const std::string &name = ""){ + InstTy* insert(InstTy *inst){ assert(block_); block_->get_inst_list().insert(insert_point_, inst); inst->set_parent(block_); - inst->set_name(name); // for(ir::value* op: inst->ops()) // op->add_use(inst); return inst; @@ -64,91 +68,87 @@ public: value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_ret_void(); // Cast instructions - value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = ""); - value* create_ptr_to_int(value *src, type *dst_ty, const std::string &name = ""); - value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = ""); - value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = ""); - value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = ""); - value* create_fp_to_ui(value *src, type *dst_ty, const std::string &name = ""); - value* create_fp_ext(value *src, type *dst_ty, const std::string &name = ""); - value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = ""); - value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = ""); - value *create_downcast(value *arg, const std::string &name = ""); + value *create_cast(cast_op_t op, value *v, type *dst_ty); + value* create_ptr_to_int(value *src, type *dst_ty); + value* create_si_to_fp(value *src, type *dst_ty); + value* create_ui_to_fp(value *src, type *dst_ty); + value* create_fp_to_si(value *src, type *dst_ty); + value* create_fp_to_ui(value *src, type *dst_ty); + value* create_fp_ext(value *src, type *dst_ty); + value* create_fp_trunc(value *src, type *dst_ty); + value* create_int_cast(value *src, type *dst_ty, bool is_signed); + value *create_downcast(value *arg); // Phi instruction - phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = ""); + phi_node* create_phi(type *ty, unsigned num_reserved); // Binary instructions - value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw); - value *create_fmul(value *lhs, value *rhs, const std::string &name = ""); - value *create_fdiv(value *lhs, value *rhs, const std::string &name = ""); - value *create_frem(value *lhs, value *rhs, const std::string &name = ""); - value *create_fadd(value *lhs, value *rhs, const std::string &name = ""); - value *create_fsub(value *lhs, value *rhs, const std::string &name = ""); - value *create_mul(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false); - value *create_sdiv(value *lhs, value *rhs, const std::string &name = ""); - value *create_udiv(value *lhs, value *rhs, const std::string &name = ""); - value *create_srem(value *lhs, value *rhs, const std::string &name = ""); - value *create_urem(value *lhs, value *rhs, const std::string &name = ""); - value *create_add(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false); - value *create_sub(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false); - value *create_shl(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false); - value *create_lshr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false); - value *create_ashr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false); + value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw); + value *create_fmul(value *lhs, value *rhs); + value *create_fdiv(value *lhs, value *rhs); + value *create_frem(value *lhs, value *rhs); + value *create_fadd(value *lhs, value *rhs); + value *create_fsub(value *lhs, value *rhs); + value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); + value *create_sdiv(value *lhs, value *rhs); + value *create_udiv(value *lhs, value *rhs); + value *create_srem(value *lhs, value *rhs); + value *create_urem(value *lhs, value *rhs); + value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); + value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); + value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); + value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); + value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); // GEP - value *create_gep(value *ptr, const std::vector& idx_list, const std::string &name = ""); + value *create_gep(value *ptr, const std::vector& idx_list); // Comparison (int) - value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpSLE(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpSLT(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpSGE(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpSGT(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpULE(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpULT(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpUGE(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpUGT(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpEQ(value *lhs, value *rhs, const std::string &name = ""); - value *create_icmpNE(value *lhs, value *rhs, const std::string &name = ""); + value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs); + value *create_icmpSLE(value *lhs, value *rhs); + value *create_icmpSLT(value *lhs, value *rhs); + value *create_icmpSGE(value *lhs, value *rhs); + value *create_icmpSGT(value *lhs, value *rhs); + value *create_icmpULE(value *lhs, value *rhs); + value *create_icmpULT(value *lhs, value *rhs); + value *create_icmpUGE(value *lhs, value *rhs); + value *create_icmpUGT(value *lhs, value *rhs); + value *create_icmpEQ(value *lhs, value *rhs); + value *create_icmpNE(value *lhs, value *rhs); // Comparison (float) - value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = ""); - value *create_fcmpOLT(value *lhs, value *rhs, const std::string &name = ""); - value *create_fcmpOGT(value *lhs, value *rhs, const std::string &name = ""); - value *create_fcmpOLE(value *lhs, value *rhs, const std::string &name = ""); - value *create_fcmpOGE(value *lhs, value *rhs, const std::string &name = ""); - value *create_fcmpOEQ(value *lhs, value *rhs, const std::string &name = ""); - value *create_fcmpONE(value *lhs, value *rhs, const std::string &name = ""); + value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs); + value *create_fcmpOLT(value *lhs, value *rhs); + value *create_fcmpOGT(value *lhs, value *rhs); + value *create_fcmpOLE(value *lhs, value *rhs); + value *create_fcmpOGE(value *lhs, value *rhs); + value *create_fcmpOEQ(value *lhs, value *rhs); + value *create_fcmpONE(value *lhs, value *rhs); // Logical - value *create_and(value *lhs, value *rhs, const std::string &name = ""); - value *create_xor(value *lhs, value *rhs, const std::string &name = ""); - value *create_or(value *lhs, value *rhs, const std::string &name = ""); - // Unary -// value *create_fneg(value *arg, const std::string &name = ""); -// value *create_neg(value *arg, const std::string &name = ""); -// value *create_not(value *arg, const std::string &name = ""); + value *create_and(value *lhs, value *rhs); + value *create_xor(value *lhs, value *rhs); + value *create_or(value *lhs, value *rhs); // Input/Output - value *create_load(value *arg, const std::string &name = ""); - value *create_store(value *ptr, value *val, const std::string &name = ""); - value *create_masked_load(value *arg, value *mask, value *false_value, const std::string &name = ""); - value *create_masked_store(value *ptr, value *val, value *mask, const std::string &name = ""); - // Tile instruction - value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); - value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); - value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); + value *create_load(value *arg); + value *create_store(value *ptr, value *val); + value *create_masked_load(value *arg, value *mask, value *false_value); + value *create_masked_store(value *ptr, value *val, value *mask); + // Block instruction + value *create_splat(value *arg, const type::block_shapes_t &shapes); + value *create_reshape(value *arg, const type::block_shapes_t &shapes); + value *create_broadcast(value *arg, const type::block_shapes_t &shapes); // Built-in instruction - value *create_get_program_id(unsigned axis, const std::string &name = ""); - value *create_get_num_program(unsigned axis, const std::string &name = ""); - value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = ""); - value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); - value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = ""); - value *create_exp(value* arg, const std::string &name = ""); - value *create_log(value* arg, const std::string &name = ""); - value *create_dot(value *A, value *B, value *C, const std::string &name = ""); - value *create_trans(value *A, const std::vector &perm = {}, const std::string &name = ""); - value *create_sqrt(value *A, const std::string &name = ""); - value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = ""); - value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); + value *create_get_program_id(unsigned axis); + value *create_get_num_programs(unsigned axis); + value *create_atomic_cas(value *ptr, value *cmp, value *val); + value *create_atomic_exch(value *ptr, value *val); + value *create_atomic_add(value *ptr, value *val, value *msk); + value *create_exp(value* arg); + value *create_log(value* arg); + value *create_dot(value *A, value *B, value *C); + value *create_trans(value *A, const std::vector &perm = {}); + value *create_sqrt(value *A); + value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); + value *create_select(value *pred, value *if_value, value *else_value); // Intrinsics - value *create_copy_to_shared(value *arg, const std::string &name = ""); - value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = ""); - value *create_copy_from_shared(value *arg, const std::string &name = ""); + value *create_copy_to_shared(value *arg); + value *create_masked_load_async(value *arg, value *mask, value *false_value); + value *create_copy_from_shared(value *arg); value *create_barrier(const std::string &name = ""); value *create_async_wait(int N); @@ -158,6 +158,7 @@ private: iterator insert_point_; }; + } } diff --git a/include/triton/ir/context.h b/include/triton/ir/context.h index 83627e869..55edf31cd 100644 --- a/include/triton/ir/context.h +++ b/include/triton/ir/context.h @@ -9,6 +9,7 @@ namespace triton{ namespace ir{ +class builder; class type; class context_impl; @@ -16,8 +17,11 @@ class context_impl; class context { public: context(); + context(const context&) = delete; + context& operator=(const context&) = delete; public: + ir::builder* builder = nullptr; std::shared_ptr p_impl; }; diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index a016d1add..3db225a37 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -28,13 +28,16 @@ public: integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; // Pointer types std::map, pointer_type*> ptr_tys; - std::map, tile_type*> tile_tys; + // Block types + std::map, block_type*> block_tys; + // Int constants std::map, constant_int*> int_constants_; // Float constants std::map, constant_fp*> fp_constants_; // undef values std::map uv_constants_; + }; } diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h new file mode 100644 index 000000000..c2e5ac320 --- /dev/null +++ b/include/triton/ir/dispatch.h @@ -0,0 +1,97 @@ +#pragma once + +#ifndef _TRITON_IR_DISPATCH_H_ +#define _TRITON_IR_DISPATCH_H_ + +#include "triton/ir/builder.h" + +namespace triton{ +namespace ir{ + + +/*---------------------------------------------- + higher level functions that follow the likely + semantics of most expected frontends + ----------------------------------------------*/ + +struct semantic_error: public std::runtime_error { + semantic_error(const std::string& msg): + std::runtime_error(msg) { } +}; + +struct dispatch{ + typedef ir::type::block_shapes_t shape_t; + + + // programming model + static ir::value *program_id(int axis, ir::builder *builder); + static ir::value *num_programs(int axis, ir::builder *builder); + + // binary operators + static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder); + + // unary operators + static ir::value *plus(ir::value *input, ir::builder *builder); + static ir::value *minus(ir::value *input, ir::builder *builder); + static ir::value *invert(ir::value *input, ir::builder *builder); + + // comparison operators + static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder); + + // block creation + static ir::value* arange(int start, int end, ir::builder *builder); + static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder); + + + // casting ops + static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder); + static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder); + static std::tuple broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder); + static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); + + // memory operators + static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder); + static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); + static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); + static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder); + + // linear algebra + static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder); + + // indexing + static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); + + // reduction + static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder); + static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder); + static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); + + // math + static ir::value *exp(ir::value *x, ir::builder *builder); + static ir::value *log(ir::value *x, ir::builder *builder); + static ir::value *sqrt(ir::value *x, ir::builder *builder); + + // internal (debug/optimization) + static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); + static ir::value *debug_barrier(ir::builder *builder); +}; + +} +} + +#endif diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 2df493d01..2a944fbb5 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -35,7 +35,7 @@ private: /* Attribute */ enum attribute_kind_t { - readonly, + readonly = 0, writeonly, noalias, aligned, @@ -71,7 +71,7 @@ public: case writeonly: return ".writeonly"; case noalias: return ".noalias"; case aligned: return ".aligned(" + std::to_string(value_) + ")"; - case multiple_of: return ".readonly"; + case multiple_of: return ".multipleof(" + std::to_string(value_) + ")"; case retune: return ".retunr"; default: break; } @@ -102,7 +102,7 @@ private: public: // accessors - const args_t &args() { return args_; } + const args_t &args() const { return args_; } function_type* get_fn_type() { return fn_ty_; } // factory methods diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 6971a751b..89598ef61 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -514,7 +514,7 @@ public: class retile_inst: public unary_inst { protected: - retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next); + retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next); }; // reshape @@ -525,7 +525,7 @@ private: std::string repr_impl() const { return "reshape"; } public: - static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, + static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(reshape_inst) _TRITON_DEFINE_ACCEPT(reshape_inst) @@ -539,7 +539,7 @@ private: std::string repr_impl() const { return "splat"; } public: - static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, + static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(splat_inst) _TRITON_DEFINE_ACCEPT(splat_inst) @@ -553,7 +553,7 @@ private: std::string repr_impl() const { return "broadcast"; } public: - static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, + static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(broadcast_inst) _TRITON_DEFINE_ACCEPT(broadcast_inst) @@ -597,16 +597,16 @@ private: unsigned axis_; }; -class get_num_program_inst: public builtin_inst { +class get_num_programs_inst: public builtin_inst { private: - get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next); - std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; } + get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next); + std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; } public: static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); unsigned get_axis() const { return axis_; } - _TRITON_DEFINE_CLONE(get_num_program_inst) - _TRITON_DEFINE_ACCEPT(get_num_program_inst) + _TRITON_DEFINE_CLONE(get_num_programs_inst) + _TRITON_DEFINE_ACCEPT(get_num_programs_inst) private: unsigned axis_; diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 0d9c625f1..7e4a08209 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -36,6 +36,11 @@ class alloc_const; /* Module */ struct scope { +public: + const std::map& get_values() { return values; } + void set_type(const std::string& name, ir::type* ty) { types[name] = ty; } + ir::type* get_type(const std::string& name) { return types.at(name); } +private: std::map types; std::map values; }; @@ -61,8 +66,7 @@ private: void push_function(function *fn) { functions_.push_back(fn); } public: - module(const std::string &name); - context& get_context(); + module(const std::string &name, builder& builder); builder& get_builder(); // Setters void set_value(const std::string& name, basic_block* block, value *x); @@ -95,8 +99,7 @@ public: private: std::string name_; - context context_; - builder builder_; + builder& builder_; std::map values_; std::map types_; std::set const_; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 05fb795c4..27e6acef0 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -18,7 +18,7 @@ class constant_int; /* Type */ class type { public: - typedef std::vector tile_shapes_t; + typedef std::vector block_shapes_t; protected: typedef std::vector contained_tys_vec_t; @@ -43,7 +43,7 @@ public: FunctionTyID, ///< 11: Functions PointerTyID, ///< 12: Pointers StructTyID, ///< 13: Struct - TileTyID, ///< 14: Tile + BlockTyID, ///< 14: Tile }; public: @@ -62,7 +62,7 @@ public: unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; - const tile_shapes_t& get_tile_shapes() const; + block_shapes_t get_block_shapes() const; const size_t get_tile_rank() const; const size_t get_tile_ranks1() const; unsigned get_tile_num_elements() const; @@ -83,7 +83,7 @@ public: get_integer_bitwidth() == bitwidth;} bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } - bool is_tile_ty() const { return id_ == TileTyID; } + bool is_block_ty() const { return id_ == BlockTyID; } // Composite predicates bool is_int_or_tileint_ty(); @@ -110,7 +110,7 @@ public: // repr std::string tile_repr() const { std::string res = get_tile_element_ty()->repr(); - auto shapes = get_tile_shapes(); + auto shapes = get_block_shapes(); res += "<"; for(size_t i = 0; i < shapes.size(); i++){ if(i > 0) @@ -137,7 +137,7 @@ public: case FunctionTyID: return "fn"; case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case StructTyID: return "struct"; - case TileTyID: return tile_repr(); + case BlockTyID: return tile_repr(); default: break; } assert(false); @@ -180,23 +180,23 @@ public: type* get_type_at_index(value *idx) const; }; -class tile_type: public composite_type { +class block_type: public composite_type { private: - tile_type(type *ty, const tile_shapes_t &shapes); + block_type(type *ty, const block_shapes_t &shapes); static bool is_valid_elt_ty(type *ty); public: // accessors - const tile_shapes_t& get_shapes() const { return shapes_; } + const block_shapes_t& get_shapes() const { return shapes_; } unsigned get_num_elements() const; unsigned get_bitwidth() const; // factory methods - static tile_type* get(type *ty, const tile_shapes_t &shapes); - static tile_type* get_same_shapes(type *ty, type *ref); + static block_type* get(type *ty, const block_shapes_t &shapes); + static block_type* get_same_shapes(type *ty, type *ref); private: - tile_shapes_t shapes_; + block_shapes_t shapes_; }; class pointer_type: public type { diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index a54e5edfc..7547c749f 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -52,7 +52,7 @@ class exp_inst; class log_inst; class get_program_id_inst; -class get_num_program_inst; +class get_num_programs_inst; class atomic_cas_inst; class atomic_exch_inst; class atomic_add_inst; @@ -128,7 +128,7 @@ public: virtual void visit_downcast_inst(downcast_inst*) = 0; virtual void visit_get_program_id_inst(get_program_id_inst*) = 0; - virtual void visit_get_num_program_inst(get_num_program_inst*) = 0; + virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0; virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0; virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0; virtual void visit_atomic_add_inst(atomic_add_inst*) = 0; diff --git a/include/triton/lang/ast.h b/include/triton/lang/ast.h deleted file mode 100644 index 0f57d86cc..000000000 --- a/include/triton/lang/ast.h +++ /dev/null @@ -1,823 +0,0 @@ -#pragma once - -#ifndef _WGTCC_AST_H_ -#define _WGTCC_AST_H_ - -#include "error.h" -#include "token.h" -#include "type.h" - -#include -#include -#include -#include - - -class Visitor; -template class Evaluator; -class AddrEvaluator; -class Generator; - -class Scope; -class Parser; -class ASTNode; -class Token; -class TokenSequence; - -// Expressions -class Expr; -class BinaryOp; -class UnaryOp; -class ConditionalOp; -class FuncCall; -class TempVar; -class Constant; - -class Identifier; -class Object; -struct Initializer; -class Declaration; -class Enumerator; - -// Statements -class Stmt; -class IfStmt; -class ForStmt; -class JumpStmt; -class LabelStmt; -class EmptyStmt; -class CompoundStmt; -class FuncDef; -class TranslationUnit; - - -/* - * AST Node - */ - -class ASTNode { -public: - struct Attr{ - - enum KindT{ - MULTIPLEOF, - ALIGNED, - NOALIAS, - READONLY, - WRITEONLY, - RETUNE, - }; - - KindT kind; - std::vector vals; - }; - using AttrList = std::vector; - -public: - virtual ~ASTNode() {} - virtual void Accept(Visitor* v) = 0; - -protected: - ASTNode() {} - - MemPool* pool_ {nullptr}; -}; - -using ExtDecl = ASTNode; - - -/* - * Statements - */ - -class Stmt : public ASTNode { -public: - virtual ~Stmt() {} - -protected: - Stmt() {} -}; - - -class EmptyStmt : public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static EmptyStmt* New(); - virtual ~EmptyStmt() {} - virtual void Accept(Visitor* v); - -protected: - EmptyStmt() {} -}; - - -class LabelStmt : public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static LabelStmt* New(); - ~LabelStmt() {} - virtual void Accept(Visitor* v); - std::string Repr() const { return ".L" + std::to_string(tag_); } - -protected: - LabelStmt(): tag_(GenTag()) {} - -private: - static int GenTag() { - static int tag = 0; - return ++tag; - } - - int tag_; // 使用整型的tag值,而不直接用字符串 -}; - - -class IfStmt : public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; -public: - static IfStmt* New(Expr* cond, Stmt* then, Stmt* els=nullptr); - virtual ~IfStmt() {} - virtual void Accept(Visitor* v); - -protected: - IfStmt(Expr* cond, Stmt* then, Stmt* els = nullptr) - : cond_(cond), then_(then), else_(els) {} - -private: - Expr* cond_; - Stmt* then_; - Stmt* else_; -}; - -class ForStmt: public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; -public: - static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr); - virtual ~ForStmt() {} - virtual void Accept(Visitor* v); - -protected: - ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr) - : body_(body), init_(init), cond_(cond), step_(step) {} - -private: - Stmt* body_; - Stmt* init_; - Expr* cond_; - Expr* step_; -}; - -class JumpStmt : public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static JumpStmt* New(LabelStmt* label); - virtual ~JumpStmt() {} - virtual void Accept(Visitor* v); - void SetLabel(LabelStmt* label) { label_ = label; } - -protected: - JumpStmt(LabelStmt* label): label_(label) {} - -private: - LabelStmt* label_; -}; - - -class ReturnStmt: public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static ReturnStmt* New(Expr* expr); - virtual ~ReturnStmt() {} - virtual void Accept(Visitor* v); - -protected: - ReturnStmt(::Expr* expr): expr_(expr) {} - -private: - ::Expr* expr_; -}; - - -using StmtList = std::list; - -class CompoundStmt : public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static CompoundStmt* New(StmtList& stmts, ::Scope* scope=nullptr); - virtual ~CompoundStmt() {} - virtual void Accept(Visitor* v); - StmtList& Stmts() { return stmts_; } - ::Scope* Scope() { return scope_; } - -protected: - CompoundStmt(const StmtList& stmts, ::Scope* scope=nullptr) - : stmts_(stmts), scope_(scope) {} - -private: - StmtList stmts_; - ::Scope* scope_; -}; - - -struct Initializer { - Initializer(Type* type, - int offset, - Expr* expr, - unsigned char bitFieldBegin=0, - unsigned char bitFieldWidth=0) - : type_(type), - offset_(offset), - bitFieldBegin_(bitFieldBegin), - bitFieldWidth_(bitFieldWidth), - expr_(expr) {} - - bool operator<(const Initializer& rhs) const; - - // It could be the object it self or, it will be the member - // that was initialized - Type* type_; - int offset_; - unsigned char bitFieldBegin_; - unsigned char bitFieldWidth_; - - Expr* expr_; -}; - - -using InitList = std::set; - -class Declaration: public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static Declaration* New(Object* obj); - virtual ~Declaration() {} - virtual void Accept(Visitor* v); - InitList& Inits() { return inits_; } - Object* Obj() { return obj_; } - void AddInit(Initializer init); - -protected: - Declaration(Object* obj): obj_(obj) {} - - Object* obj_; - InitList inits_; -}; - - -/* - * Expr - * BinaryOp - * UnaryOp - * ConditionalOp - * FuncCall - * Constant - * Identifier - * Object - * TempVar - */ - -class Expr : public Stmt { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - friend class LValAssigner; - -public: - virtual ~Expr() {} - ::Type* Type() { return type_.GetPtr(); } - virtual bool IsLVal() = 0; - virtual void TypeChecking() = 0; - void EnsureCompatible(const QualType lhs, const QualType rhs) const; - void EnsureCompatibleOrVoidPointer(const QualType lhs, - const QualType rhs) const; - const Token* Tok() const { return tok_; } - void SetTok(const Token* tok) { tok_ = tok; } - - static Expr* MayCast(Expr* expr); - static Expr* MayCast(Expr* expr, QualType desType); - static ::Type* TryExtractScalarType(Expr* loc, Expr *operand); - static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty); - - virtual bool IsNullPointerConstant() const { return false; } - bool IsConstQualified() const { return type_.IsConstQualified(); } - bool IsRestrictQualified() const { return type_.IsRestrictQualified(); } - bool IsVolatileQualified() const { return type_.IsVolatileQualified(); } - -protected: - // You can construct a expression without specifying a type, - // then the type should be evaluated in TypeChecking() - Expr(const Token* tok, QualType type): tok_(tok), type_(type) {} - - const Token* tok_; - QualType type_; -}; - - -/* - * '+', '-', '*', '/', '%', '<', '>', '<<', '>>', '|', '&', '^' - * '=',(复合赋值运算符被拆分为两个运算) - * '==', '!=', '<=', '>=', - * '&&', '||' - * '['(下标运算符), '.'(成员运算符) - * ','(逗号运算符), - */ -class BinaryOp : public Expr { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - friend class LValAssigner; - friend class Declaration; - -public: - static BinaryOp* New(const Token* tok, Expr* lhs, Expr* rhs); - static BinaryOp* New(const Token* tok, int op, Expr* lhs, Expr* rhs); - virtual ~BinaryOp() {} - virtual void Accept(Visitor* v); - - // Member ref operator is a lvalue - virtual bool IsLVal() { - switch (op_) { - case '.': return !Type()->ToArray() && lhs_->IsLVal(); - case ']': return !Type()->ToArray(); - case Token::MASKED_DEREF: return true; - default: return false; - } - } - ArithmType* Convert(); - static void Broadcast(Expr* loc, Expr*& lhs, Expr*& rhs, QualType &type); - - virtual void TypeChecking(); - void SubScriptingOpTypeChecking(); - void MemberRefOpTypeChecking(); - void MultiOpTypeChecking(); - void AdditiveOpTypeChecking(); - void ShiftOpTypeChecking(); - void RangeOpTypeChecking(); - void MatmulOpTypeChecking(); - void MaskedDerefOpTypeChecking(); - void RelationalOpTypeChecking(); - void EqualityOpTypeChecking(); - void BitwiseOpTypeChecking(); - void LogicalOpTypeChecking(); - void AssignOpTypeChecking(); - void CommaOpTypeChecking(); - -protected: - BinaryOp(const Token* tok, int op, Expr* lhs, Expr* rhs) - : Expr(tok, nullptr), op_(op) { - lhs_ = lhs, rhs_ = rhs; - if (op != '.') { - lhs_ = MayCast(lhs); - rhs_ = MayCast(rhs); - } - } - - int op_; - Expr* lhs_; - Expr* rhs_; -}; - - -/* - * Unary Operator: - * '++' (prefix/postfix) - * '--' (prefix/postfix) - * '&' (ADDR) - * '*' (DEREF) - * '+' (PLUS) - * '-' (MINUS) - * '~' - * '!' - * CAST // like (int)3 - */ -class UnaryOp : public Expr { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - friend class LValAssigner; - -public: - static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0); - virtual ~UnaryOp() {} - virtual void Accept(Visitor* v); - virtual bool IsLVal(); - ::Type *Convert(); - static int encodeRed(int ax, int tag); - static void decodeRed(int info, int& ax, int& tag); - void TypeChecking(); - void IncDecOpTypeChecking(); - void AddrOpTypeChecking(); - void DerefOpTypeChecking(); - void ReduceOpTypeChecking(); - void UnaryArithmOpTypeChecking(); - void BitcastOpTypeChecking(); - void CastOpTypeChecking(); - void IntrinsicOpTypeChecking(); - -protected: - UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0) - : Expr(operand->Tok(), type), op_(op), info_(info) { - operand_ = operand; - if (op_ != Token::CAST && op_ != Token::ADDR) { - operand_ = MayCast(operand); - } - } - - int op_; - int info_; - Expr* operand_; -}; - -class TransOp: public Expr { - friend class Generator; - -public: - using PermInt = std::vector; - -public: - static TransOp* New(const PermInt& perm, Expr* operand); - const PermInt& getPerm() const { return perm_; } - void Accept(Visitor* v); - bool IsLVal() { return false; } - void TypeChecking(); - -protected: - TransOp(const PermInt& perm, Expr* operand) - : Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {} - -private: - Expr* operand_; - PermInt perm_; -}; - - -// cond ? true : false -class ConditionalOp : public Expr { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static ConditionalOp* New(const Token* tok, - Expr* cond, Expr* exprTrue, Expr* exprFalse); - virtual ~ConditionalOp() {} - virtual void Accept(Visitor* v); - virtual bool IsLVal() { return false; } - ArithmType* Convert(); - virtual void TypeChecking(); - -protected: - ConditionalOp(Expr* cond, Expr* exprTrue, Expr* exprFalse) - : Expr(cond->Tok(), nullptr), cond_(MayCast(cond)), - exprTrue_(MayCast(exprTrue)), exprFalse_(MayCast(exprFalse)) {} - -private: - Expr* cond_; - Expr* exprTrue_; - Expr* exprFalse_; -}; - - -class FuncCall : public Expr { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - using ArgList = std::vector; - -public: - static FuncCall* New(Expr* designator, const ArgList& args); - ~FuncCall() {} - virtual void Accept(Visitor* v); - - // A function call is ofcourse not lvalue - virtual bool IsLVal() { return false; } - ArgList* Args() { return &args_; } - Expr* Designator() { return designator_; } - const std::string& Name() const { return tok_->str_; } - ::FuncType* FuncType() { return designator_->Type()->ToFunc(); } - virtual void TypeChecking(); - -protected: - FuncCall(Expr* designator, const ArgList& args) - : Expr(designator->Tok(), nullptr), - designator_(designator), args_(args) {} - - Expr* designator_; - ArgList args_; -}; - - -class Constant: public Expr { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static Constant* New(const Token* tok, int tag, long val); - static Constant* New(const Token* tok, int tag, double val); - static Constant* New(const Token* tok, int tag, const std::string* val); - ~Constant() {} - virtual void Accept(Visitor* v); - virtual bool IsLVal() { return false; } - virtual void TypeChecking() {} - - long IVal() const { return ival_; } - double FVal() const { return fval_; } - const std::string* SVal() const { return sval_; } - std::string SValRepr() const; - std::string Repr() const { return std::string(".LC") + std::to_string(id_); } - -protected: - Constant(const Token* tok, QualType type, long val) - : Expr(tok, type), ival_(val) {} - Constant(const Token* tok, QualType type, double val) - : Expr(tok, type), fval_(val) {} - Constant(const Token* tok, QualType type, const std::string* val) - : Expr(tok, type), sval_(val) {} - - union { - long ival_; - double fval_; - struct { - long id_; - const std::string* sval_; - }; - }; -}; - - -class TempVar : public Expr { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static TempVar* New(QualType type); - virtual ~TempVar() {} - virtual void Accept(Visitor* v); - virtual bool IsLVal() { return true; } - virtual void TypeChecking() {} - -protected: - TempVar(QualType type): Expr(nullptr, type), tag_(GenTag()) {} - -private: - static int GenTag() { - static int tag = 0; - return ++tag; - } - - int tag_; -}; - - -enum Linkage { - L_NONE, - L_EXTERNAL, - L_INTERNAL, -}; - - -class Identifier: public Expr { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - friend class LValAssigner; - -public: - static Identifier* New(const Token* tok, QualType type, Linkage linkage, const AttrList& attrList={}); - virtual ~Identifier() {} - virtual void Accept(Visitor* v); - virtual bool IsLVal() { return false; } - virtual Object* ToObject() { return nullptr; } - virtual Enumerator* ToEnumerator() { return nullptr; } - - // An identifer can be: - // object, sturct/union/enum tag, typedef name, function, label. - Identifier* ToTypeName() { - // A typename has no linkage - // And a function has external or internal linkage - if (ToObject() || ToEnumerator() || linkage_ != L_NONE) - return nullptr; - return this; - } - virtual const std::string Name() const { return tok_->str_; } - enum Linkage Linkage() const { return linkage_; } - void SetLinkage(enum Linkage linkage) { linkage_ = linkage; } - virtual void TypeChecking() {} - -protected: - Identifier(const Token* tok, QualType type, enum Linkage linkage, const AttrList& attrList={}) - : Expr(tok, type), linkage_(linkage), attrList_(attrList) {} - - // An identifier has property linkage - enum Linkage linkage_; - AttrList attrList_; -}; - - -class Enumerator: public Identifier { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static Enumerator* New(const Token* tok, int val); - virtual ~Enumerator() {} - virtual void Accept(Visitor* v); - virtual Enumerator* ToEnumerator() { return this; } - int Val() const { return cons_->IVal(); } - -protected: - Enumerator(const Token* tok, int val) - : Identifier(tok, ArithmType::New(T_INT), L_NONE), - cons_(Constant::New(tok, T_INT, (long)val)) {} - - Constant* cons_; -}; - - -class Object : public Identifier { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - friend class LValAssigner; - -public: - static Object* New(const Token* tok, - QualType type, - int storage=0, - enum Linkage linkage=L_NONE, - unsigned char bitFieldBegin=0, - unsigned char bitFieldWidth=0, - const AttrList& attrList={}); - static Object* NewAnony(const Token* tok, - QualType type, - int storage=0, - enum Linkage linkage=L_NONE, - unsigned char bitFieldBegin=0, - unsigned char bitFieldWidth=0, - const AttrList& attrList={}); - ~Object() {} - virtual void Accept(Visitor* v); - virtual Object* ToObject() { return this; } - virtual bool IsLVal() { - // TODO(wgtdkp): not all object is lval? - return true; - } - bool IsStatic() const { - return (Storage() & S_STATIC) || (Linkage() != L_NONE); - } - int Storage() const { return storage_; } - void SetStorage(int storage) { storage_ = storage; } - int Align() const { return align_; } - void SetAlign(int align) { - assert(align > 0); - // Allowing reduce alignment to implement __attribute__((packed)) - //if (align < align_) - // Error(this, "alignment specifier cannot reduce alignment"); - align_ = align; - } - int Offset() const { return offset_; } - void SetOffset(int offset) { offset_ = offset; } - Declaration* Decl() { return decl_; } - void SetDecl(Declaration* decl) { decl_ = decl; } - const AttrList& GetAttrList() const { return attrList_; } - unsigned char BitFieldBegin() const { return bitFieldBegin_; } - unsigned char BitFieldEnd() const { return bitFieldBegin_ + bitFieldWidth_; } - unsigned char BitFieldWidth() const { return bitFieldWidth_; } - static unsigned long BitFieldMask(Object* bitField) { - return BitFieldMask(bitField->bitFieldBegin_, bitField->bitFieldWidth_); - } - static unsigned long BitFieldMask(unsigned char begin, unsigned char width) { - auto end = begin + width; - return ((0xFFFFFFFFFFFFFFFFUL << (64 - end)) >> (64 - width)) << begin; - } - - bool HasInit() const { return decl_ && decl_->Inits().size(); } - bool Anonymous() const { return anonymous_; } - virtual const std::string Name() const { return Identifier::Name(); } - std::string Repr() const { - assert(IsStatic() || anonymous_); - if (anonymous_) - return "anonymous." + std::to_string(id_); - if (linkage_ == L_NONE) - return Name() + "." + std::to_string(id_); - return Name(); - } - -protected: - Object(const Token* tok, - QualType type, - int storage=0, - enum Linkage linkage=L_NONE, - unsigned char bitFieldBegin=0, - unsigned char bitFieldWidth=0, - const AttrList& attrList={}) - : Identifier(tok, type, linkage), - storage_(storage), - offset_(0), - align_(type->Align()), - decl_(nullptr), - bitFieldBegin_(bitFieldBegin), - bitFieldWidth_(bitFieldWidth), - anonymous_(false), - attrList_(attrList){} - -private: - int storage_; - int offset_; - int align_; - - Declaration* decl_; - - unsigned char bitFieldBegin_; - // 0 means it's not a bitfield - unsigned char bitFieldWidth_; - - bool anonymous_; - long id_ {0}; - - AttrList attrList_; -}; - - -/* - * Declaration - */ - -class FuncDef : public ExtDecl { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - using ParamList = std::vector; - -public: - static FuncDef* New(Identifier* ident, LabelStmt* retLabel); - virtual ~FuncDef() {} - virtual void Accept(Visitor* v); - ::FuncType* FuncType() { return ident_->Type()->ToFunc(); } - CompoundStmt* Body() { return body_; } - void SetBody(CompoundStmt* body) { body_ = body; } - std::string Name() const { return ident_->Name(); } - enum Linkage Linkage() { return ident_->Linkage(); } - -protected: - FuncDef(Identifier* ident, LabelStmt* retLabel) - : ident_(ident), retLabel_(retLabel) {} - -private: - Identifier* ident_; - LabelStmt* retLabel_; - CompoundStmt* body_; -}; - - -using ExtDeclList = std::list; - -class TranslationUnit : public ASTNode { - template friend class Evaluator; - friend class AddrEvaluator; - friend class Generator; - -public: - static TranslationUnit* New() { return new TranslationUnit();} - virtual ~TranslationUnit() {} - virtual void Accept(Visitor* v); - void Add(ExtDecl* extDecl) { extDecls_.push_back(extDecl); } - ExtDeclList& ExtDecls() { return extDecls_; } - const ExtDeclList& ExtDecls() const { return extDecls_; } - -private: - TranslationUnit() {} - - ExtDeclList extDecls_; -}; - -#endif diff --git a/include/triton/lang/code_gen.h b/include/triton/lang/code_gen.h deleted file mode 100644 index 531da19c5..000000000 --- a/include/triton/lang/code_gen.h +++ /dev/null @@ -1,167 +0,0 @@ -#pragma once - -#ifndef _WGTCC_CODE_GEN_H_ -#define _WGTCC_CODE_GEN_H_ - -#include "ast.h" -#include "visitor.h" -#include - -namespace triton{ -namespace ir{ - -class value; -class module; -class type; -class context; -class builder; -class attribute; - -} -} - -using namespace triton; - -class Parser; -struct Addr; -template<> class Evaluator; -struct StaticInitializer; -class LValAssigner; - -using TypeList = std::vector; -using LocationList = std::vector; -using StaticInitList = std::vector; - -// Error -inline void should_not_happen(const std::string& suffix) { throw std::runtime_error("internal compiler error: " + suffix); } -inline void error_not_implemented(const std::string& msg) { throw std::runtime_error(msg); } - -class Generator: public Visitor { - friend class Evaluator; - friend class LValAssigner; - -protected: - struct scope { - std::map types; - std::map values; - }; - - void set_ret(ir::value* value); - ir::value *GenUnaryMinus(ir::value* arg); - ir::value *GenUnaryInc(UnaryOp* arg, bool is_postfix, bool is_inc); - -public: - Generator(Parser* parser) : parser_(parser) {} - - void Visit(ASTNode* node) { node->Accept(this); } - void VisitExpr(Expr* expr) { expr->Accept(this); } - void VisitStmt(Stmt* stmt) { stmt->Accept(this); } - - // Expression - void VisitBinaryOp(BinaryOp* binaryOp); - void VisitUnaryOp(UnaryOp* unaryOp); - void VisitTransOp(TransOp* transOp); - void VisitConditionalOp(ConditionalOp* condOp); - void VisitFuncCall(FuncCall* funcCall); - void VisitObject(Object* obj); - void VisitEnumerator(Enumerator* enumer); - void VisitIdentifier(Identifier* ident); - void VisitConstant(Constant* cons); - void VisitTempVar(TempVar* tempVar); - - // Statement - void VisitDeclaration(Declaration* init); - void VisitEmptyStmt(EmptyStmt* emptyStmt); - void VisitIfStmt(IfStmt* ifStmt); - void VisitForStmt(ForStmt* ifStmt); - void VisitJumpStmt(JumpStmt* jumpStmt); - void VisitReturnStmt(ReturnStmt* returnStmt); - void VisitLabelStmt(LabelStmt* labelStmt); - void VisitCompoundStmt(CompoundStmt* compoundStmt); - - void VisitFuncDef(FuncDef* funcDef); - void VisitTranslationUnit(TranslationUnit* unit); - - void Gen(ir::module *mod); - -protected: - // Triton-IR attributes - ir::attribute GenIRAttr(ASTNode::Attr attr); - - // Triton-IR metadata - void SetIRMetadata(ASTNode::Attr attr, ir::value *rhs); - - // Triton-IR values - ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs); - ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty); - ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty); - ir::value* GenSemCastOp(ir::value* op, ir::type* type); - ir::value* GenBitCastOp(ir::value* src, ir::type* dst_ty); - - // Triton-IR types - static ir::type* GenIRType(::Type* type, ir::context &ctx); - static ir::type* GenIRArithmType(ArithmType* type, ir::context& ctx); - static ir::type* GenIRArrayType(ArrayType* type, ir::context& ctx); - static ir::type* GenIRTileType(TileType* type, ir::context& ctx); - static ir::type* GenIRFuncType(FuncType* type, ir::context& ctx); - static ir::type* GenIRPointerType(PointerType* type, ir::context& ctx); - static ir::type* GenIRStructType(StructType* type, ir::context& ctx); - void AllocObjects(Scope* scope, const FuncDef::ParamList& params=FuncDef::ParamList()); - - // SSA - void pushScope(); - void popScope(); - -private: - Parser* parser_; - ir::value* ret_; - ir::builder* bld_; - ir::context* ctx_; - ir::module* mod_; - -private: -// std::stack scopes_; - LValAssigner* assign_; -}; - - -class LValAssigner: public Visitor { -public: - LValAssigner(Generator* gen): gen_(gen) {} - - // Expression - void VisitBinaryOp(BinaryOp* binaryOp); - void VisitUnaryOp(UnaryOp* unaryOp); - void VisitObject(Object* obj); - void VisitIdentifier(Identifier* ident); - - void VisitConditionalOp(ConditionalOp*) { should_not_happen("conditional cannot be lvalue"); } - void VisitFuncCall(FuncCall*) { should_not_happen("funccall cannot be lvalue"); } - void VisitTransOp(TransOp*) { should_not_happen("transop cannot be lvalue"); } - void VisitEnumerator(Enumerator*) { should_not_happen("enumerator cannot be lvalue"); } - void VisitConstant(Constant*) { should_not_happen("constant cannot be lvalue"); } - void VisitTempVar(TempVar*) { should_not_happen("tempvar cannot be lvalue"); } - void VisitDeclaration(Declaration*) { should_not_happen("declaration cannot be lvalue"); } - void VisitEmptyStmt(EmptyStmt*) { should_not_happen("empty statement cannot be lvalue"); } - void VisitIfStmt(IfStmt*) { should_not_happen("if statement cannot be lvalue"); } - void VisitForStmt(ForStmt*) { should_not_happen("for statement cannot be lvalue"); } - void VisitJumpStmt(JumpStmt*) { should_not_happen("jump statement cannot be lvalue"); } - void VisitReturnStmt(ReturnStmt*) { should_not_happen("return statement cannot be lvalue"); } - void VisitLabelStmt(LabelStmt*) { should_not_happen("label statement cannot be lvalue"); } - void VisitCompoundStmt(CompoundStmt*) { should_not_happen("compound statement cannot be lvalue"); } - void VisitFuncDef(FuncDef*) { should_not_happen("function definition cannot be lvalue"); } - void VisitTranslationUnit(TranslationUnit*) { should_not_happen("translation unit cannot be lvalue"); } - - ir::value* GenExpr(Expr* expr, ir::value* rhs) { - rhs_ = rhs; - expr->Accept(this); - return ret_; - } - -private: - ir::value* ret_; - ir::value* rhs_; - Generator* gen_; -}; - -#endif diff --git a/include/triton/lang/cpp.h b/include/triton/lang/cpp.h deleted file mode 100644 index cfd839dd7..000000000 --- a/include/triton/lang/cpp.h +++ /dev/null @@ -1,164 +0,0 @@ -#pragma once - -#ifndef _WGTCC_CPP_H_ -#define _WGTCC_CPP_H_ - -#include "scanner.h" - -#include -#include -#include -#include -#include -#include - -class Macro; -struct CondDirective; - -using MacroMap = std::map; -using ParamList = std::list; -using ParamMap = std::map; -using PPCondStack = std::stack; -using PathList = std::list; - - -class Macro { -public: - Macro(const TokenSequence& repSeq, bool preDef=false) - : funcLike_(false), variadic_(false), - preDef_(preDef), repSeq_(repSeq) {} - - Macro(bool variadic, ParamList& params, - TokenSequence& repSeq, bool preDef=false) - : funcLike_(true), variadic_(variadic), preDef_(preDef), - params_(params), repSeq_(repSeq) {} - - ~Macro() {} - bool FuncLike() { return funcLike_; } - bool ObjLike() { return !FuncLike(); } - bool Variadic() { return variadic_; } - bool PreDef() { return preDef_; } - ParamList& Params() { return params_; } - TokenSequence RepSeq(const std::string* filename, unsigned line); - -private: - bool funcLike_; - bool variadic_; - bool preDef_; - ParamList params_; - TokenSequence repSeq_; -}; - - -struct CondDirective { - int tag_; - bool enabled_; - bool cond_; -}; - - -class Preprocessor { -public: - Preprocessor(const std::string* str, bool isSrc = true) - : curLine_(1), lineLine_(0), curCond_(true), fName_(nullptr), fSrc_(nullptr) { - if(isSrc) - fSrc_ = str; - else - fName_ = str; - // Add predefined - Init(); - } - - - ~Preprocessor() {} - void Finalize(TokenSequence os); - void Process(TokenSequence& os); - void Expand(TokenSequence& os, TokenSequence is, bool inCond=false); - void Subst(TokenSequence& os, TokenSequence is, - bool leadingWS, const HideSet& hs, ParamMap& params); - void Glue(TokenSequence& os, TokenSequence is); - void Glue(TokenSequence& os, const Token* tok); - const Token* Stringize(TokenSequence is); - void Stringize(std::string& str, TokenSequence is); - const Token* ParseActualParam(TokenSequence& is, Macro* macro, ParamMap& paramMap); - int GetDirective(TokenSequence& is); - const Token* EvalDefOp(TokenSequence& is); - void ReplaceIdent(TokenSequence& is); - void ParseDirective(TokenSequence& os, TokenSequence& is, int directive); - void ParseIf(TokenSequence ls); - void ParseIfdef(TokenSequence ls); - void ParseIfndef(TokenSequence ls); - void ParseElif(TokenSequence ls); - void ParseElse(TokenSequence ls); - void ParseEndif(TokenSequence ls); - void ParseInclude(TokenSequence& is, TokenSequence ls); - void ParseDef(TokenSequence ls); - void ParseUndef(TokenSequence ls); - void ParseLine(TokenSequence ls); - void ParseError(TokenSequence ls); - void ParsePragma(TokenSequence ls); - void IncludeSrc(TokenSequence& is, const std::string* text, const std::string* filename); - void IncludeFile(TokenSequence& is, const std::string* filename); - bool ParseIdentList(ParamList& params, TokenSequence& is); - - - Macro* FindMacro(const std::string& name) { - auto res = macroMap_.find(name); - if (res == macroMap_.end()) - return nullptr; - return &res->second; - } - - void AddMacro(const std::string& name, - std::string* text, bool preDef=false); - - void AddMacro(const std::string& name, const Macro& macro) { - auto res = macroMap_.find(name); - if (res != macroMap_.end()) { - // TODO(wgtdkp): give warning - macroMap_.erase(res); - } - macroMap_.insert(std::make_pair(name, macro)); - } - - void RemoveMacro(const std::string& name) { - auto res = macroMap_.find(name); - if (res == macroMap_.end()) - return; - if(res->second.PreDef()) // Cannot undef predefined macro - return; - macroMap_.erase(res); - } - - std::string* SearchFile(const std::string& name, - const bool libHeader, - bool next, - const std::string& curPath); - - void AddSearchPath(std::string path); - void HandleTheFileMacro(TokenSequence& os, const Token* macro); - void HandleTheLineMacro(TokenSequence& os, const Token* macro); - void UpdateFirstTokenLine(TokenSequence ts); - - bool NeedExpand() const { - if (ppCondStack_.empty()) - return true; - auto top = ppCondStack_.top(); - return top.enabled_ && top.cond_; - } - -private: - void Init(); - - PPCondStack ppCondStack_; - unsigned curLine_; - unsigned lineLine_; - bool curCond_; - - MacroMap macroMap_; - PathList searchPaths_; - const std::string* fName_; - const std::string* fSrc_; -}; - -#endif diff --git a/include/triton/lang/encoding.h b/include/triton/lang/encoding.h deleted file mode 100644 index 297b2b732..000000000 --- a/include/triton/lang/encoding.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#ifndef _WGTCC_ENCODING_H_ -#define _WGTCC_ENCODING_H_ - -#include - - -enum class Encoding { - NONE, - CHAR16, - CHAR32, - UTF8, - WCHAR -}; - - -void ConvertToUTF16(std::string& str); -void ConvertToUTF32(std::string& str); -void AppendUCN(std::string& str, int c); - -#endif diff --git a/include/triton/lang/error.h b/include/triton/lang/error.h deleted file mode 100644 index 386ca3a3e..000000000 --- a/include/triton/lang/error.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#ifndef _WGTCC_ERROR_H_ -#define _WGTCC_ERROR_H_ - - -struct SourceLocation; -class Token; -class Expr; - - -[[noreturn]] void Error(const char* format, ...); -[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...); -[[noreturn]] void Error(const Token* tok, const char* format, ...); -[[noreturn]] void Error(const Expr* expr, const char* format, ...); - -#endif diff --git a/include/triton/lang/evaluator.h b/include/triton/lang/evaluator.h deleted file mode 100644 index ac8404550..000000000 --- a/include/triton/lang/evaluator.h +++ /dev/null @@ -1,130 +0,0 @@ -#pragma once - -#ifndef _WGTCC_EVALUATOR_H_ -#define _WGTCC_EVALUATOR_H_ - -#include "ast.h" -#include "error.h" -#include "visitor.h" - - -class Expr; - -template -class Evaluator: public Visitor { -public: - Evaluator() {} - - virtual ~Evaluator() {} - - virtual void VisitBinaryOp(BinaryOp* binary); - virtual void VisitUnaryOp(UnaryOp* unary); - virtual void VisitConditionalOp(ConditionalOp* cond); - - virtual void VisitFuncCall(FuncCall* funcCall) { - Error(funcCall, "expect constant expression"); - } - virtual void VisitEnumerator(Enumerator* enumer) { - val_ = static_cast(enumer->Val()); - } - virtual void VisitIdentifier(Identifier* ident) { - Error(ident, "expect constant expression"); - } - virtual void VisitTransOp(TransOp* trans) { - Error(trans, "expect constant expression"); - } - virtual void VisitObject(Object* obj) { - Error(obj, "expect constant expression"); - } - virtual void VisitConstant(Constant* cons) { - if (cons->Type()->IsFloat()) { - val_ = static_cast(cons->FVal()); - } else if (cons->Type()->IsInteger()) { - val_ = static_cast(cons->IVal()); - } else { - assert(false); - } - } - virtual void VisitTempVar(TempVar* tempVar) { assert(false); } - - // We may should assert here - virtual void VisitDeclaration(Declaration* init) {} - virtual void VisitIfStmt(IfStmt* ifStmt) {} - virtual void VisitForStmt(ForStmt* forStmt) {} - virtual void VisitJumpStmt(JumpStmt* jumpStmt) {} - virtual void VisitReturnStmt(ReturnStmt* returnStmt) {} - virtual void VisitLabelStmt(LabelStmt* labelStmt) {} - virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {} - virtual void VisitCompoundStmt(CompoundStmt* compStmt) {} - virtual void VisitFuncDef(FuncDef* funcDef) {} - virtual void VisitTranslationUnit(TranslationUnit* unit) {} - - T Eval(Expr* expr) { - expr->Accept(this); - return val_; - } - -private: - T val_; -}; - - -struct Addr { - std::string label_; - int offset_; -}; - -template<> -class Evaluator: public Visitor { -public: - Evaluator() {} - virtual ~Evaluator() {} - virtual void VisitBinaryOp(BinaryOp* binary); - virtual void VisitUnaryOp(UnaryOp* unary); - virtual void VisitConditionalOp(ConditionalOp* cond); - - virtual void VisitFuncCall(FuncCall* funcCall) { - Error(funcCall, "expect constant expression"); - } - virtual void VisitTransOp(TransOp* trans) { - Error(trans, "expect constant expression"); - } - virtual void VisitEnumerator(Enumerator* enumer) { - addr_.offset_ = enumer->Val(); - } - virtual void VisitIdentifier(Identifier* ident) { - addr_.label_ = ident->Name(); - addr_.offset_ = 0; - } - virtual void VisitObject(Object* obj) { - if (!obj->IsStatic()) { - Error(obj, "expect static object"); - } - addr_.label_ = obj->Repr(); - addr_.offset_ = 0; - } - virtual void VisitConstant(Constant* cons); - virtual void VisitTempVar(TempVar* tempVar) { assert(false); } - - // We may should assert here - virtual void VisitDeclaration(Declaration* init) {} - virtual void VisitIfStmt(IfStmt* ifStmt) {} - virtual void VisitForStmt(ForStmt* forStmt) {} - virtual void VisitJumpStmt(JumpStmt* jumpStmt) {} - virtual void VisitReturnStmt(ReturnStmt* returnStmt) {} - virtual void VisitLabelStmt(LabelStmt* labelStmt) {} - virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {} - virtual void VisitCompoundStmt(CompoundStmt* compStmt) {} - virtual void VisitFuncDef(FuncDef* funcDef) {} - virtual void VisitTranslationUnit(TranslationUnit* unit) {} - - Addr Eval(Expr* expr) { - expr->Accept(this); - return addr_; - } - -private: - Addr addr_; -}; - -#endif diff --git a/include/triton/lang/mem_pool.h b/include/triton/lang/mem_pool.h deleted file mode 100644 index 9b6ab53c1..000000000 --- a/include/triton/lang/mem_pool.h +++ /dev/null @@ -1,103 +0,0 @@ -#pragma once - -#ifndef _WGTCC_MEM_POOL_H_ -#define _WGTCC_MEM_POOL_H_ - -#include -#include - - -class MemPool { -public: - MemPool(): allocated_(0) {} - virtual ~MemPool() {} - MemPool(const MemPool& other) = delete; - MemPool& operator=(const MemPool& other) = delete; - virtual void* Alloc() = 0; - virtual void Free(void* addr) = 0; - virtual void Clear() = 0; - -protected: - size_t allocated_; -}; - - -template -class MemPoolImp: public MemPool { -public: - MemPoolImp() : root_(nullptr) {} - virtual ~MemPoolImp() {} - MemPoolImp(const MemPool& other) = delete; - MemPoolImp& operator=(MemPool& other) = delete; - virtual void* Alloc(); - virtual void Free(void* addr); - virtual void Clear(); - -private: - enum { - COUNT = (4 * 1024) / sizeof(T) - }; - - union Chunk { - Chunk* next_; - char mem_[sizeof(T)]; - }; - - struct Block { - Block() { - for (size_t i = 0; i < COUNT - 1; ++i) - chunks_[i].next_ = &chunks_[i+1]; - chunks_[COUNT-1].next_ = nullptr; - } - Chunk chunks_[COUNT]; - }; - - std::vector blocks_; - Chunk* root_; -}; - - -template -void* MemPoolImp::Alloc() { - if (nullptr == root_) { // 空间不够,需要分配空间 - auto block = new Block(); - root_ = block->chunks_; - // 如果blocks实现为std::list, 那么push_back实际的overhead更大 - // 这也表明,即使我们不需要随机访问功能(那么std::vector的拷贝是一种overhead), - // 仍然倾向于使用std::vector, - // 当然std::vector的指数级capacity增长会造成内存浪费。 - blocks_.push_back(block); - } - - auto ret = root_; - root_ = root_->next_; - - ++allocated_; - return ret; -} - - -template -void MemPoolImp::Free(void* addr) { - if (nullptr == addr) - return; - - auto chunk = static_cast(addr); - chunk->next_ = root_; - root_ = chunk; - - --allocated_; -} - - -template -void MemPoolImp::Clear() { - for (auto block: blocks_) - delete block; - - blocks_.resize(0); - root_ = nullptr; - allocated_ = 0; -} - -#endif diff --git a/include/triton/lang/parser.h b/include/triton/lang/parser.h deleted file mode 100644 index b9542b40e..000000000 --- a/include/triton/lang/parser.h +++ /dev/null @@ -1,260 +0,0 @@ -#pragma once - -#ifndef _PARSER_H_ -#define _PARSER_H_ - -#include "ast.h" -#include "encoding.h" -#include "error.h" -#include "mem_pool.h" -#include "scope.h" -#include "token.h" - -#include -#include -#include - - -class Preprocessor; - -struct DeclInfo { - DeclInfo(const Token* _tok, - QualType _type, - ASTNode::AttrList _attrs = {}) - : tok(_tok), type(_type), attrs(_attrs) {} - - const Token* tok; - QualType type; - ASTNode::AttrList attrs; -}; - - -class Parser { - using LiteralList = std::vector; - using StaticObjectList = std::vector; - using CaseLabelList = std::vector>; - using LabelJumpList = std::list>; - using LabelMap = std::map; - friend class Generator; - -public: - explicit Parser(TokenSequence& ts) - : unit_(TranslationUnit::New()), - ts_(ts), - externalSymbols_(new Scope(nullptr, S_BLOCK)), - errTok_(nullptr), - curScope_(new Scope(nullptr, S_FILE)), - curFunc_(nullptr), - breakDest_(nullptr), - continueDest_(nullptr), - caseLabels_(nullptr), - defaultLabel_(nullptr) { - ts_.SetParser(this); - } - - ~Parser() {} - - Constant* ParseConstant(const Token* tok); - Constant* ParseFloat(const Token* tok); - Constant* ParseInteger(const Token* tok); - Constant* ParseCharacter(const Token* tok); - Encoding ParseLiteral(std::string& str, const Token* tok); - Constant* ConcatLiterals(const Token* tok); - Expr* ParseGeneric(); - - void Parse(); - void ParseTranslationUnit(); - FuncDef* ParseFuncDef(Identifier* ident); - - - // Expressions - Expr* ParseExpr(); - Expr* ParsePrimaryExpr(); - QualType TryCompoundLiteral(); - Object* ParseCompoundLiteral(QualType type); - Expr* ParsePostfixExpr(); - Expr* ParsePostfixExprTail(Expr* primExpr); - Expr* ParseSubScripting(Expr* pointer); - BinaryOp* ParseMemberRef(const Token* tok, int op, Expr* lhs); - UnaryOp* ParsePostfixIncDec(const Token* tok, Expr* operand); - FuncCall* ParseFuncCall(Expr* caller); - - Expr* ParseUnaryExpr(); - Constant* ParseSizeof(); - Constant* ParseAlignof(); - UnaryOp* ParsePrefixIncDec(const Token* tok); - UnaryOp* ParseUnaryIntrinsicOp(int op); - UnaryOp* ParseUnaryOp(const Token* tok, int op); - Expr* ParseDerefOp(const Token* tok); - - QualType ParseTypeName(); - Expr* ParseCastExpr(); - Expr* ParseRangeExpr(); - Expr* ParseMatmulExpr(); - Expr* ParseMultiplicativeExpr(); - Expr* ParseAdditiveExpr(); - Expr* ParseShiftExpr(); - Expr* ParseRelationalExpr(); - Expr* ParseEqualityExpr(); - Expr* ParseBitiwiseAndExpr(); - Expr* ParseBitwiseXorExpr(); - Expr* ParseBitwiseOrExpr(); - Expr* ParseLogicalAndExpr(); - Expr* ParseLogicalOrExpr(); - Expr* ParseConditionalExpr(); - Expr* ParseCommaExpr(); - Expr* ParseAssignExpr(); - - // Declarations - CompoundStmt* ParseDecl(); - void ParseStaticAssert(); - QualType ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec); - QualType ParseSpecQual(); - int ParseAlignas(); - Type* ParseStructUnionSpec(bool isStruct); - StructType* ParseStructUnionDecl(StructType* type); - void ParseBitField(StructType* structType, const Token* tok, QualType type); - Type* ParseEnumSpec(); - Type* ParseEnumerator(ArithmType* type); - int ParseQual(); - QualType ParsePointer(QualType typePointedTo); - DeclInfo ParseDeclarator(QualType type); - QualType ParseArrayFuncDeclarator(const Token* ident, QualType base); - int ParseArrayLength(); - TileType::ShapeInt ParseTileShape(); - bool ParseParamList(FuncType::ParamList& params); - Object* ParseParamDecl(); - - QualType ParseAbstractDeclarator(QualType type); - Identifier* ParseDirectDeclarator(QualType type, - int storageSpec, - int funcSpec, - int align); - // Initializer - void ParseInitializer(Declaration* decl, - QualType type, - int offset, - bool designated=false, - bool forceBrace=false, - unsigned char bitFieldBegin=0, - unsigned char bitFieldWidth=0); - void ParseArrayInitializer(Declaration* decl, - ArrayType* type, - int offset, - bool designated); - StructType::Iterator ParseStructDesignator(StructType* type, - const std::string& name); - void ParseStructInitializer(Declaration* decl, - StructType* type, - int offset, - bool designated); - bool ParseLiteralInitializer(Declaration* init, - ArrayType* type, - int offset); - Declaration* ParseInitDeclarator(Identifier* ident); - Declaration* ParseInitDeclaratorSub(Object* obj); - - // Statements - Stmt* ParseStmt(); - CompoundStmt* ParseCompoundStmt(FuncType* funcType=nullptr); - IfStmt* ParseIfStmt(); - CompoundStmt* ParseSwitchStmt(); - CompoundStmt* ParseWhileStmt(); - CompoundStmt* ParseDoStmt(); - ForStmt *ParseForStmt(); - JumpStmt* ParseGotoStmt(); - JumpStmt* ParseContinueStmt(); - JumpStmt* ParseBreakStmt(); - ReturnStmt* ParseReturnStmt(); - CompoundStmt* ParseLabelStmt(const Token* label); - CompoundStmt* ParseCaseStmt(); - CompoundStmt* ParseDefaultStmt(); - Identifier* ProcessDeclarator(const Token* tok, - QualType type, const ASTNode::AttrList &attrs, - int storageSpec, - int funcSpec, - int align); - // GNU extensions - ASTNode::AttrList TryAttributeSpecList(); - void ParseAttributeSpec(ASTNode::AttrList &attrList); - ASTNode::Attr ParseAttribute(); - bool IsTypeName(const Token* tok) const{ - if (tok->IsTypeSpecQual()) - return true; - - if (tok->IsIdentifier()) { - auto ident = curScope_->Find(tok); - if (ident && ident->ToTypeName()) - return true; - } - return false; - } - bool IsType(const Token* tok) const{ - if (tok->IsDecl()) - return true; - - if (tok->IsIdentifier()) { - auto ident = curScope_->Find(tok); - return (ident && ident->ToTypeName()); - } - - return false; - } - void EnsureInteger(Expr* expr) { - if (!expr->Type()->IsInteger()) { - Error(expr, "expect integer expression"); - } - } - - void EnterBlock(FuncType* funcType=nullptr); - void ExitBlock() { curScope_ = curScope_->Parent(); } - void EnterProto() { curScope_ = new Scope(curScope_, S_PROTO); } - void ExitProto() { curScope_ = curScope_->Parent(); } - FuncDef* EnterFunc(Identifier* ident); - void ExitFunc(); - - LabelStmt* FindLabel(const std::string& label) { - auto ret = curLabels_.find(label); - if (curLabels_.end() == ret) - return nullptr; - return ret->second; - } - void AddLabel(const std::string& label, LabelStmt* labelStmt) { - assert(nullptr == FindLabel(label)); - curLabels_[label] = labelStmt; - } - TranslationUnit* Unit() { return unit_; } - FuncDef* CurFunc() { return curFunc_; } - const TokenSequence& ts() const { return ts_; } - -protected: - static bool IsBuiltin(FuncType* type); - static bool IsBuiltin(const std::string& name); - static Identifier* GetBuiltin(const Token* tok); - static void DefineBuiltins(); - - static FuncType* vaStartType_; - static FuncType* vaArgType_; - - // The root of the AST - TranslationUnit* unit_; - - TokenSequence& ts_; - - // It is not the real scope, - // It contains all external symbols(resolved and not resolved) - Scope* externalSymbols_; - - const Token* errTok_; - Scope* curScope_; - FuncDef* curFunc_; - LabelMap curLabels_; - LabelJumpList unresolvedJumps_; - - LabelStmt* breakDest_; - LabelStmt* continueDest_; - CaseLabelList* caseLabels_; - LabelStmt* defaultLabel_; -}; - -#endif diff --git a/include/triton/lang/scanner.h b/include/triton/lang/scanner.h deleted file mode 100644 index 57cdff9a0..000000000 --- a/include/triton/lang/scanner.h +++ /dev/null @@ -1,86 +0,0 @@ -#pragma once - -#ifndef _WGTCC_SCANNER_H_ -#define _WGTCC_SCANNER_H_ - -#include "error.h" -#include "encoding.h" -#include "token.h" - -#include -#include - - -class Scanner { -public: - explicit Scanner(const Token* tok) - : Scanner(&tok->str_, tok->loc_) {} - Scanner(const std::string* text, const SourceLocation& loc) - : Scanner(text, loc.filename_, loc.line_, loc.column_) {} - explicit Scanner(const std::string* text, - const std::string* filename=nullptr, - unsigned line=1, unsigned column=1) - : text_(text), tok_(Token::END) { - // TODO(wgtdkp): initialization - p_ = &(*text_)[0]; - loc_ = {filename, p_, line, 1}; - } - - virtual ~Scanner() {} - Scanner(const Scanner& other) = delete; - Scanner& operator=(const Scanner& other) = delete; - - // Scan plain text and generate tokens in ts. - // The param 'ts' need not be empty, if so, the tokens - // are inserted at the *header* of 'ts'. - // The param 'ws' tells if there is leading white space - // before this token, it is only SkipComment() that will - // set this param. - Token* Scan(bool ws=false); - void Tokenize(TokenSequence& ts); - static std::string ScanHeadName(const Token* lhs, const Token* rhs); - Encoding ScanCharacter(int& val); - Encoding ScanLiteral(std::string& val); - std::string ScanIdentifier(); - -private: - Token* SkipIdentifier(); - Token* SkipNumber(); - Token* SkipLiteral(); - Token* SkipCharacter(); - Token* MakeToken(int tag); - Token* MakeNewLine(); - Encoding ScanEncoding(int c); - int ScanEscaped(); - int ScanHexEscaped(); - int ScanOctEscaped(int c); - int ScanUCN(int len); - void SkipWhiteSpace(); - void SkipComment(); - bool IsUCN(int c) { return c == '\\' && (Test('u') || Test('U')); } - bool IsOctal(int c) { return '0' <= c && c <= '7'; } - int XDigit(int c); - bool Empty() const { return *p_ == 0; } - int Peek(); - bool Test(int c) { return Peek() == c; }; - int Next(); - void PutBack(); - bool Try(int c) { - if (Peek() == c) { - Next(); - return true; - } - return false; - }; - void Mark() { tok_.loc_ = loc_; }; - - const std::string* text_; - SourceLocation loc_; - Token tok_; - const char* p_; -}; - - -std::string* ReadFile(const std::string& filename); - -#endif diff --git a/include/triton/lang/scope.h b/include/triton/lang/scope.h deleted file mode 100644 index b958d3ecd..000000000 --- a/include/triton/lang/scope.h +++ /dev/null @@ -1,72 +0,0 @@ -#pragma once - -#ifndef _WGTCC_SCOPE_H_ -#define _WGTCC_SCOPE_H_ - -#include -#include -#include -#include - - -class Identifier; -class Token; - - -enum ScopeType { - S_FILE, - S_PROTO, - S_BLOCK, - S_FUNC, -}; - - -class Scope { - friend class StructType; - using TagList = std::vector; - using IdentMap = std::map; - -public: - explicit Scope(Scope* parent, enum ScopeType type) - : parent_(parent), type_(type) {} - ~Scope() {} - Scope* Parent() { return parent_; } - void SetParent(Scope* parent) { parent_ = parent; } - enum ScopeType Type() const { return type_; } - - Identifier* Find(const Token* tok); - Identifier* FindInCurScope(const Token* tok); - Identifier* FindTag(const Token* tok); - Identifier* FindTagInCurScope(const Token* tok); - TagList AllTagsInCurScope() const; - - void Insert(Identifier* ident); - void Insert(const std::string& name, Identifier* ident); - void InsertTag(Identifier* ident); - void Print(); - bool operator==(const Scope& other) const { return type_ == other.type_; } - IdentMap::iterator begin() { return identMap_.begin(); } - IdentMap::iterator end() { return identMap_.end(); } - size_t size() const { return identMap_.size(); } - -private: - Identifier* Find(const std::string& name); - Identifier* FindInCurScope(const std::string& name); - Identifier* FindTag(const std::string& name); - Identifier* FindTagInCurScope(const std::string& name); - std::string TagName(const std::string& name) { - return name + "@:tag"; - } - static bool IsTagName(const std::string& name) { - return name.size() > 5 && name[name.size() - 5] == '@'; - } - const Scope& operator=(const Scope& other); - Scope(const Scope& scope); - - Scope* parent_; - enum ScopeType type_; - - IdentMap identMap_; -}; - -#endif diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h deleted file mode 100644 index 178f8c42e..000000000 --- a/include/triton/lang/token.h +++ /dev/null @@ -1,434 +0,0 @@ -#pragma once - -#ifndef _WGTCC_TOKEN_H_ -#define _WGTCC_TOKEN_H_ - -#include "error.h" - -#include -#include -#include -#include -#include -#include -#include - - -class Generator; -class Parser; -class Scanner; -class Token; -class TokenSequence; - -using HideSet = std::set; -using TokenList = std::list; - - -struct SourceLocation { - const std::string* filename_; - const char* lineBegin_; - unsigned line_; - unsigned column_; - - const char* Begin() const { - return lineBegin_ + column_ - 1; - } -}; - - -class Token { - friend class Scanner; -public: - enum { - // Punctuators - LPAR = '(', - RPAR = ')', - LSQB = '[', - RSQB = ']', - COLON = ':', - COMMA = ',', - SEMI = ';', - ADD = '+', - SUB = '-', - MUL = '*', - DIV = '/', - OR = '|', - AND = '&', - XOR = '^', - LESS = '<', - GREATER = '>', - EQUAL = '=', - DOT = '.', - MOD = '%', - LBRACE = '{', - RBRACE = '}', - TILDE = '~', - NOT = '!', - COND = '?', - SHARP = '#', - MATMUL = '@', - NEW_LINE = '\n', - - DSHARP = 128, // '##' - PTR, - INC, - DEC, - LEFT, - RIGHT, - LE, - GE, - EQ, - NE, - LOGICAL_AND, - LOGICAL_OR, - - MUL_ASSIGN, - DIV_ASSIGN, - MOD_ASSIGN, - ADD_ASSIGN, - SUB_ASSIGN, - LEFT_ASSIGN, - RIGHT_ASSIGN, - AND_ASSIGN, - XOR_ASSIGN, - OR_ASSIGN, - - ELLIPSIS, - MASKED_DEREF, - // Punctuators end - - // KEYWORD BEGIN - // TYPE QUALIFIER BEGIN - CONST, - RESTRICT, - VOLATILE, - ATOMIC, - // TYPE QUALIFIER END - - // TYPE SPECIFIER BEGIN - VOID, - CHAR, - SHORT, - INT, - LONG, - HALF, - FLOAT, - DOUBLE, - SIGNED, - UNSIGNED, - BOOL, // _Bool - COMPLEX, // _Complex - STRUCT, - UNION, - ENUM, - // TYPE SPECIFIER END - - ATTRIBUTE, // GNU extension __attribute__ - // FUNCTION SPECIFIER BEGIN - INLINE, - NORETURN, // _Noreturn - // FUNCTION SPECIFIER END - - // TILE ARITHMETICS BEGIN - NEWAXIS, - MAX, - MIN, - // TILE ARITHMETICS END - - ALIGNAS, // _Alignas - // For syntactic convenience - STATIC_ASSERT, // _Static_assert - // STORAGE CLASS SPECIFIER BEGIN - TYPEDEF, - EXTERN, - STATIC, - THREAD, // _Thread_local - AUTO, - GLOBAL, - CMEM, // constant memory - - // STORAGE CLASS SPECIFIER END - BREAK, - CASE, - CONTINUE, - DEFAULT, - DO, - ELSE, - FOR, - GOTO, - IF, - RETURN, - SIZEOF, - SWITCH, - WHILE, - ALIGNOF, // _Alignof - GENERIC, // _Generic - IMAGINARY, // _Imaginary - // function keywords - BITCAST, - EXP, - LOG, - SQRTF, - // KEYWORD END - - IDENTIFIER, - CONSTANT, - I_CONSTANT, - C_CONSTANT, - F_CONSTANT, - LITERAL, - - // For the parser, a identifier is a typedef name or user defined type - POSTFIX_INC, - POSTFIX_DEC, - PREFIX_INC, - PREFIX_DEC, - ADDR, // '&' - DEREF, // '*' - PLUS, - MINUS, - CAST, - REDUCE, - - // For preprocessor - PP_IF, - PP_IFDEF, - PP_IFNDEF, - PP_ELIF, - PP_ELSE, - PP_ENDIF, - PP_INCLUDE, - PP_DEFINE, - PP_UNDEF, - PP_LINE, - PP_ERROR, - PP_PRAGMA, - PP_NONE, - PP_EMPTY, - - - IGNORE, - INVALID, - END, - NOTOK = -1, - }; - - static Token* New(int tag); - static Token* New(const Token& other); - static Token* New(int tag, - const SourceLocation& loc, - const std::string& str, - bool ws=false); - Token& operator=(const Token& other) { - tag_ = other.tag_; - ws_ = other.ws_; - loc_ = other.loc_; - str_ = other.str_; - hs_ = other.hs_ ? new HideSet(*other.hs_): nullptr; - return *this; - } - virtual ~Token() {} - - // Token::NOTOK represents not a kw. - static int KeyWordTag(const std::string& key) { - auto kwIter = kwTypeMap_.find(key); - if (kwTypeMap_.end() == kwIter) - return Token::NOTOK; // Not a key word type - return kwIter->second; - } - static bool IsKeyWord(const std::string& name); - static bool IsKeyWord(int tag) { return CONST <= tag && tag < IDENTIFIER; } - bool IsKeyWord() const { return IsKeyWord(tag_); } - bool IsPunctuator() const { return 0 <= tag_ && tag_ <= ELLIPSIS; } - bool IsLiteral() const { return tag_ == LITERAL; } - bool IsConstant() const { return CONSTANT <= tag_ && tag_ <= F_CONSTANT; } - bool IsIdentifier() const { return IDENTIFIER == tag_; } - bool IsEOF() const { return tag_ == Token::END; } - bool IsTypeSpecQual() const { return CONST <= tag_ && tag_ <= ENUM; } - bool IsDecl() const { return CONST <= tag_ && tag_ <= GLOBAL; } - static const char* Lexeme(int tag) { - auto iter = tagLexemeMap_.find(tag); - if (iter == tagLexemeMap_.end()) - return nullptr; - - return iter->second; - } - - int tag_; - - // 'ws_' standards for weither there is preceding white space - // This is to simplify the '#' operator(stringize) in macro expansion - bool ws_ { false }; - SourceLocation loc_; - - std::string str_; - HideSet* hs_ { nullptr }; - -private: - explicit Token(int tag): tag_(tag) {} - Token(int tag, const SourceLocation& loc, - const std::string& str, bool ws=false) - : tag_(tag), ws_(ws), loc_(loc), str_(str) {} - - Token(const Token& other) { - *this = other; - } - - static const std::unordered_map kwTypeMap_; - static const std::unordered_map tagLexemeMap_; -}; - - -class TokenSequence { - friend class Preprocessor; - -public: - TokenSequence(): tokList_(new TokenList()), - begin_(tokList_->begin()), end_(tokList_->end()) {} - explicit TokenSequence(Token* tok) { - TokenSequence(); - InsertBack(tok); - } - explicit TokenSequence(TokenList* tokList) - : tokList_(tokList), - begin_(tokList->begin()), - end_(tokList->end()) {} - TokenSequence(TokenList* tokList, - TokenList::iterator begin, - TokenList::iterator end) - : tokList_(tokList), begin_(begin), end_(end) {} - ~TokenSequence() {} - TokenSequence(const TokenSequence& other) { *this = other; } - const TokenSequence& operator=(const TokenSequence& other) { - tokList_ = other.tokList_; - begin_ = other.begin_; - end_ = other.end_; - return *this; - } - void Copy(const TokenSequence& other) { - tokList_ = new TokenList(other.begin_, other.end_); - begin_ = tokList_->begin(); - end_ = tokList_->end(); - for (auto iter = begin_; iter != end_; ++iter) - *iter = Token::New(**iter); - } - void UpdateHeadLocation(const SourceLocation& loc) { - assert(!Empty()); - auto tok = const_cast(Peek()); - tok->loc_ = loc; - } - void FinalizeSubst(bool leadingWS, const HideSet& hs) { - auto ts = *this; - while (!ts.Empty()) { - auto tok = const_cast(ts.Next()); - if (!tok->hs_) - tok->hs_ = new HideSet(hs); - else - tok->hs_->insert(hs.begin(), hs.end()); - } - // Even if the token sequence is empty - const_cast(Peek())->ws_ = leadingWS; - } - - const Token* Expect(int expect); - bool Try(int tag) { - if (Peek()->tag_ == tag) { - Next(); - return true; - } - return false; - } - bool Test(int tag) { return Peek()->tag_ == tag; } - const Token* Next() { - auto ret = Peek(); - if (!ret->IsEOF()) { - ++begin_; - Peek(); // May skip newline token, but why ? - } else { - ++exceed_end; - } - return ret; - } - void PutBack() { - assert(begin_ != tokList_->begin()); - if (exceed_end > 0) { - --exceed_end; - } else { - --begin_; - if ((*begin_)->tag_ == Token::NEW_LINE) - PutBack(); - } - } - const Token* Peek() const; - const Token* Peek2() { - if (Empty()) - return Peek(); // Return the Token::END - Next(); - auto ret = Peek(); - PutBack(); - return ret; - } - const Token* Back() const { - auto back = end_; - return *--back; - } - void PopBack() { - assert(!Empty()); - assert(end_ == tokList_->end()); - auto size_eq1 = tokList_->back() == *begin_; - tokList_->pop_back(); - end_ = tokList_->end(); - if (size_eq1) - begin_ = end_; - } - TokenList::iterator Mark() { return begin_; } - void ResetTo(TokenList::iterator mark) { begin_ = mark; } - bool Empty() const { return Peek()->tag_ == Token::END; } - void InsertBack(TokenSequence& ts) { - auto pos = tokList_->insert(end_, ts.begin_, ts.end_); - if (begin_ == end_) { - begin_ = pos; - } - } - void InsertBack(const Token* tok) { - auto pos = tokList_->insert(end_, tok); - if (begin_ == end_) { - begin_ = pos; - } - } - - // If there is preceding newline - void InsertFront(TokenSequence& ts) { - auto pos = GetInsertFrontPos(); - begin_ = tokList_->insert(pos, ts.begin_, ts.end_); - } - void InsertFront(const Token* tok) { - auto pos = GetInsertFrontPos(); - begin_ = tokList_->insert(pos, tok); - } - bool IsBeginOfLine() const; - TokenSequence GetLine(); - void SetParser(Parser* parser) { parser_ = parser; } - void Print(FILE* fp=stdout) const; - void Print(std::string *str) const; - -private: - // Find a insert position with no preceding newline - TokenList::iterator GetInsertFrontPos() { - auto pos = begin_; - if (pos == tokList_->begin()) - return pos; - --pos; - while (pos != tokList_->begin() && (*pos)->tag_ == Token::NEW_LINE) - --pos; - return ++pos; - } - - TokenList* tokList_; - mutable TokenList::iterator begin_; - TokenList::iterator end_; - Parser* parser_ {nullptr}; - int exceed_end {0}; -}; - -#endif diff --git a/include/triton/lang/type.h b/include/triton/lang/type.h deleted file mode 100644 index 8b63b401c..000000000 --- a/include/triton/lang/type.h +++ /dev/null @@ -1,453 +0,0 @@ -#pragma once - -#ifndef _WGTCC_TYPE_H_ -#define _WGTCC_TYPE_H_ - -#include "mem_pool.h" -#include "scope.h" - -#include -#include -#include -#include - - -class Scope; -class Token; -class Expr; - -class Type; -class QualType; -class VoidType; -class Identifier; -class Object; -class Constant; - -class ArithmType; -class DerivedType; -class ArrayType; -class TileType; -class FuncType; -class PointerType; -class StructType; -class EnumType; - - -enum { - // Storage class specifiers - S_TYPEDEF = 0x01, - S_EXTERN = 0x02, - S_STATIC = 0x04, - S_THREAD = 0x08, - S_CONSTANT = 0x10, - S_GLOBAL = 0x20, - - // Type specifier - T_SIGNED = 0x40, - T_UNSIGNED = 0x80, - T_CHAR = 0x100, - T_SHORT = 0x200, - T_INT = 0x400, - T_LONG = 0x800, - T_VOID = 0x1000, - T_HALF = 0x2000, - T_FLOAT = 0x4000, - T_DOUBLE = 0x8000, - T_BOOL = 0x10000, - T_COMPLEX = 0x20000, - // T_ATOMIC = 0x40000, - T_STRUCT_UNION = 0x80000, - T_ENUM = 0x100000, - T_TYPEDEF_NAME = 0x200000, - - T_LLONG = 0x4000000, - - // Function specifier - F_INLINE = 0x8000000, - F_NORETURN = 0x10000000, -}; - - -struct Qualifier { - enum { - CONST = 0x01, - RESTRICT = 0x02, - VOLATILE = 0x04, - CMEM = 0x08, - MASK = CONST | RESTRICT | VOLATILE | CMEM - }; -}; - - -class QualType { -public: - QualType(Type* ptr, int quals=0x00) - : ptr_(reinterpret_cast(ptr)) { - assert((quals & ~Qualifier::MASK) == 0); - ptr_ |= quals; - } - - operator bool() const { return !IsNull(); } - bool IsNull() const { return GetPtr() == nullptr; } - const Type* GetPtr() const { - return reinterpret_cast(ptr_ & ~Qualifier::MASK); - } - Type* GetPtr() { - return reinterpret_cast(ptr_ & ~Qualifier::MASK); - } - Type& operator*() { return *GetPtr(); } - const Type& operator*() const { return *GetPtr(); } - Type* operator->() { return GetPtr(); } - const Type* operator->() const { return GetPtr(); } - - // Indicate whether the specified types are identical(exclude qualifiers). - friend bool operator==(QualType lhs, QualType rhs) { - return lhs.operator->() == rhs.operator->(); - } - friend bool operator!=(QualType lhs, QualType rhs) { - return !(lhs == rhs); - } - - int Qual() const { return ptr_ & 0x07; } - bool IsConstQualified() const { return ptr_ & Qualifier::CONST; } - bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; } - bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; } - bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; } - -private: - intptr_t ptr_; -}; - - -class Type { -public: - static const int intWidth_ = 4; - static const int machineWidth_ = 8; - - bool operator!=(const Type& other) const = delete; - bool operator==(const Type& other) const = delete; - - virtual bool Compatible(const Type& other) const { - return complete_ == other.complete_; - } - - virtual ~Type() {} - - // For Debugging - virtual std::string Str() const = 0; - virtual int Width() const = 0; - virtual int Align() const { return Width(); } - static int MakeAlign(int offset, int align) { - if ((offset % align) == 0) - return offset; - if (offset >= 0) - return offset + align - (offset % align); - else - return offset - align - (offset % align); - } - - static QualType MayCast(QualType type, bool inProtoScope=false); - bool Complete() const { return complete_; } - void SetComplete(bool complete) const { complete_ = complete; } - - bool IsReal() const { return IsInteger() || IsFloat(); }; - virtual bool IsScalar() const { return false; } - virtual bool IsFloat() const { return false; } - virtual bool IsInteger() const { return false; } - virtual bool IsBool() const { return false; } - virtual bool IsVoidPointer() const { return false; } - virtual bool IsUnsigned() const { return false; } - virtual bool IsTile() const { return ToTile() != nullptr; } - - const Type* ScalarType() const; - Type* ScalarType(); - - virtual VoidType* ToVoid() { return nullptr; } - virtual const VoidType* ToVoid() const { return nullptr; } - virtual ArithmType* ToArithm() { return nullptr; } - virtual const ArithmType* ToArithm() const { return nullptr; } - virtual ArrayType* ToArray() { return nullptr; } - virtual const ArrayType* ToArray() const { return nullptr; } - virtual TileType* ToTile() { return nullptr; } - virtual const TileType* ToTile() const { return nullptr; } - virtual FuncType* ToFunc() { return nullptr; } - virtual const FuncType* ToFunc() const { return nullptr; } - virtual PointerType* ToPointer() { return nullptr; } - virtual const PointerType* ToPointer() const { return nullptr; } - virtual DerivedType* ToDerived() { return nullptr; } - virtual const DerivedType* ToDerived() const { return nullptr; } - virtual StructType* ToStruct() { return nullptr; } - virtual const StructType* ToStruct() const { return nullptr; } - -protected: - Type(MemPool* pool, bool complete) - : complete_(complete), pool_(pool) {} - - mutable bool complete_; - MemPool* pool_; -}; - - -class VoidType : public Type { -public: - static VoidType* New(); - virtual ~VoidType() {} - virtual VoidType* ToVoid() { return this; } - virtual const VoidType* ToVoid() const { return this; } - virtual bool Compatible(const Type& other) const { return other.ToVoid(); } - virtual int Width() const { - // Non-standard GNU extension - return 1; - } - virtual std::string Str() const { return "void:1"; } - -protected: - explicit VoidType(MemPool* pool): Type(pool, false) {} -}; - - -class ArithmType : public Type { -public: - static ArithmType* New(int typeSpec); - - virtual ~ArithmType() {} - virtual ArithmType* ToArithm() { return this; } - virtual const ArithmType* ToArithm() const { return this; } - virtual bool Compatible(const Type& other) const { - // C11 6.2.7 [1]: Two types have compatible type if their types are the same - // But I would to loose this constraints: integer and pointer are compatible - // if (IsInteger() && other.ToPointer()) - // return other.Compatible(*this); - return this == &other; - } - - virtual int Width() const; - virtual std::string Str() const; - virtual bool IsScalar() const { return true; } - virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); } - virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; } - virtual bool IsFloat() const { - return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE); - } - virtual bool IsBool() const { return tag_ & T_BOOL; } - bool IsComplex() const { return tag_ & T_COMPLEX; } - int Tag() const { return tag_; } - int Rank() const; - static ArithmType* IntegerPromote(ArithmType* type) { - assert(type->IsInteger()); - if (type->Rank() < ArithmType::New(T_INT)->Rank()) - return ArithmType::New(T_INT); - return type; - } - static ArithmType* MaxType(ArithmType* lhsType, - ArithmType* rhsType); - -protected: - explicit ArithmType(MemPool* pool, int spec) - : Type(pool, true), tag_(Spec2Tag(spec)) {} - -private: - static int Spec2Tag(int spec); - - int tag_; -}; - - -class DerivedType : public Type { -public: - QualType Derived() const { return derived_; } - void SetDerived(QualType derived) { derived_ = derived; } - virtual DerivedType* ToDerived() { return this; } - virtual const DerivedType* ToDerived() const { return this; } - -protected: - DerivedType(MemPool* pool, QualType derived) - : Type(pool, true), derived_(derived) {} - - QualType derived_; -}; - - -class PointerType : public DerivedType { -public: - static PointerType* New(QualType derived); - virtual ~PointerType() {} - virtual PointerType* ToPointer() { return this; } - virtual const PointerType* ToPointer() const { return this; } - virtual bool Compatible(const Type& other) const; - virtual int Width() const { return 8; } - virtual bool IsScalar() const { return true; } - virtual bool IsVoidPointer() const { return derived_->ToVoid(); } - virtual std::string Str() const { - return derived_->Str() + "*:" + std::to_string(Width()); - } - -protected: - PointerType(MemPool* pool, QualType derived): DerivedType(pool, derived) {} -}; - - -class ArrayType : public DerivedType { -public: - static ArrayType* New(int len, QualType eleType); - static ArrayType* New(Expr* expr, QualType eleType); - virtual ~ArrayType() { /*delete derived_;*/ } - - virtual ArrayType* ToArray() { return this; } - virtual const ArrayType* ToArray() const { return this; } - virtual bool Compatible(const Type& other) const; - virtual int Width() const { - return Complete() ? (derived_->Width() * len_): 0; - } - virtual int Align() const { return derived_->Align(); } - virtual std::string Str() const { - return derived_->Str() + "[]:" + std::to_string(Width()); - } - - int GetElementOffset(int idx) const { return derived_->Width() * idx; } - int Len() const { return len_; } - void SetLen(int len) { len_ = len; } - bool Variadic() const { return lenExpr_ != nullptr; } - -protected: - ArrayType(MemPool* pool, Expr* lenExpr, QualType derived) - : DerivedType(pool, derived), - lenExpr_(lenExpr), len_(0) { - SetComplete(false); - } - - ArrayType(MemPool* pool, int len, QualType derived) - : DerivedType(pool, derived), - lenExpr_(nullptr), len_(len) { - SetComplete(len_ >= 0); - } - const Expr* lenExpr_; - int len_; -}; - -class TileType : public DerivedType { -public: - using ShapeExpr = std::vector; - using ShapeInt = std::vector; - -public: - static TileType* New(const ShapeInt& shape, QualType eleType); - virtual ~TileType() { } - - virtual TileType* ToTile() { return this; } - virtual const TileType* ToTile() const { return this; } - virtual bool Compatible(const Type& other) const; - virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; } - virtual int Align() const { return derived_->Align(); } - virtual std::string Str() const { - return derived_->Str() + "[{}]:" + std::to_string(Width()); - } - - ShapeInt Shape() { return shape_; } - - int NumEle() const { - int ret = 1; - for(int s: shape_) - ret *= s; - return ret; - } - - bool CheckPow2NumEl() const { - int n = NumEle(); - return n && !(n & (n - 1)); - } - -protected: - TileType(MemPool* pool, const ShapeInt& shape, QualType derived); - -protected: - ShapeExpr shapeExpr_; - ShapeInt shape_; -}; - -class FuncType : public DerivedType { -public: - using ParamList = std::vector; - -public: - static FuncType* New(QualType derived, - int funcSpec, - bool variadic, - const ParamList& params); - virtual ~FuncType() {} - virtual FuncType* ToFunc() { return this; } - virtual const FuncType* ToFunc() const { return this; } - virtual bool Compatible(const Type& other) const; - virtual int Width() const { return 1; } - virtual std::string Str() const; - const ParamList& Params() const { return params_; } - void SetParams(const ParamList& params) { params_ = params; } - bool Variadic() const { return variadic_; } - bool IsInline() const { return inlineNoReturn_ & F_INLINE; } - bool IsNoReturn() const { return inlineNoReturn_ & F_NORETURN; } - -protected: - FuncType(MemPool* pool, QualType derived, int inlineReturn, - bool variadic, const ParamList& params) - : DerivedType(pool, derived), inlineNoReturn_(inlineReturn), - variadic_(variadic), params_(params) { - SetComplete(false); - } - -private: - int inlineNoReturn_; - bool variadic_; - ParamList params_; -}; - - -class StructType : public Type { -public: - using MemberList = std::list; - using Iterator = std::list::iterator; - -public: - static StructType* New(bool isStruct, - bool hasTag, - Scope* parent); - virtual ~StructType() {} - virtual StructType* ToStruct() { return this; } - virtual const StructType* ToStruct() const { return this; } - virtual bool Compatible(const Type& other) const; - virtual int Width() const { return width_; } - virtual int Align() const { return align_; } - virtual std::string Str() const; - - // struct/union - void AddMember(Object* member); - void AddBitField(Object* member, int offset); - bool IsStruct() const { return isStruct_; } - Object* GetMember(const std::string& member); - Scope* MemberMap() { return memberMap_; } - MemberList& Members() { return members_; } - int Offset() const { return offset_; } - bool HasTag() const { return hasTag_; } - void MergeAnony(Object* anony); - void Finalize(); - -protected: - // Default is incomplete - StructType(MemPool* pool, bool isStruct, bool hasTag, Scope* parent); - - StructType(const StructType& other); - -private: - void CalcWidth(); - - bool isStruct_; - bool hasTag_; - Scope* memberMap_; - - MemberList members_; - int offset_; - int width_; - int align_; - int bitFieldAlign_; -}; - -#endif diff --git a/include/triton/lang/visitor.h b/include/triton/lang/visitor.h deleted file mode 100644 index 239071edf..000000000 --- a/include/triton/lang/visitor.h +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once - -#ifndef _WGTCC_VISITOR_H_ -#define _WGTCC_VISITOR_H_ - - -class BinaryOp; -class UnaryOp; -class TransOp; -class ConditionalOp; -class FuncCall; -class Identifier; -class Object; -class Enumerator; -class Constant; -class TempVar; - -class Declaration; -class IfStmt; -class ForStmt; -class JumpStmt; -class ReturnStmt; -class LabelStmt; -class EmptyStmt; -class CompoundStmt; -class FuncDef; -class TranslationUnit; - - -class Visitor { -public: - virtual ~Visitor() {} - virtual void VisitBinaryOp(BinaryOp* binary) = 0; - virtual void VisitUnaryOp(UnaryOp* unary) = 0; - virtual void VisitTransOp(TransOp* trans) = 0; - virtual void VisitConditionalOp(ConditionalOp* cond) = 0; - virtual void VisitFuncCall(FuncCall* funcCall) = 0; - virtual void VisitEnumerator(Enumerator* enumer) = 0; - virtual void VisitIdentifier(Identifier* ident) = 0; - virtual void VisitObject(Object* obj) = 0; - virtual void VisitConstant(Constant* cons) = 0; - virtual void VisitTempVar(TempVar* tempVar) = 0; - - virtual void VisitDeclaration(Declaration* init) = 0; - virtual void VisitIfStmt(IfStmt* ifStmt) = 0; - virtual void VisitForStmt(ForStmt* ifStmt) = 0; - virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0; - virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0; - virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0; - virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) = 0; - virtual void VisitCompoundStmt(CompoundStmt* compStmt) = 0; - virtual void VisitFuncDef(FuncDef* funcDef) = 0; - virtual void VisitTranslationUnit(TranslationUnit* unit) = 0; -}; - -#endif diff --git a/include/triton/runtime/arg.h b/include/triton/runtime/arg.h deleted file mode 100644 index 7ba2f63d3..000000000 --- a/include/triton/runtime/arg.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#ifndef _TRITON_RUNTIME_ARG_H_ -#define _TRITON_RUNTIME_ARG_H_ - -#include -#include -#include - -namespace triton{ -namespace ir{ - class type; -} - -namespace driver{ - class buffer; -} - -namespace runtime { - - - - -} -} - -#endif diff --git a/include/triton/runtime/error.h b/include/triton/runtime/error.h deleted file mode 100644 index d03c96e35..000000000 --- a/include/triton/runtime/error.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#ifndef _TRITON_RUNTIME_ERROR_H_ -#define _TRITON_RUNTIME_ERROR_H_ - -#include -#include - -namespace triton { -namespace runtime{ -namespace exception { - -class base: public std::exception {}; -#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } }; - -TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory") -TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers") - -class no_valid_configuration: public exception::base { -public: - no_valid_configuration(const std::string& err): err_(err) { } - const char * what() const throw(){ return err_.c_str(); } -private: - std::string err_; -}; - - -#undef TRITON_CREATE_RUNTIME_EXCEPTION - -} -} -} - -#endif diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h deleted file mode 100644 index c33f9d0e1..000000000 --- a/include/triton/runtime/function.h +++ /dev/null @@ -1,159 +0,0 @@ -#pragma once - -#ifndef _TRITON_RUNTIME_FUNCTION_H_ -#define _TRITON_RUNTIME_FUNCTION_H_ - -#include -#include -#include -#include -#include -#include -#include -// codegen -#include "triton/ir/function.h" -#include "triton/ir/context.h" -#include "triton/runtime/arg.h" -#include "triton/runtime/error.h" - -// driver forward declaration -namespace triton { -namespace driver{ - class module; - class stream; - class kernel; - class context; - class device; -} -} -// ir forward declaration -namespace triton{ -namespace ir { -class module; -class function; -class context; -} -} - -namespace triton{ -namespace runtime{ - - -/* ------------------------- */ -/* Compilation options */ -/* ------------------------- */ - -struct options_t { - template - T D(const std::string& name) const { - return std::stoi(defines.at(name)); - } - std::unordered_map defines; - int num_warps; -}; - -/* ------------------------- */ -/* Runtime arguments */ -/* ------------------------- */ - -enum arg_type { - INT1_T, - INT8_T, - INT16_T, - INT32_T, - INT64_T, - HALF_T, - FLOAT_T, - DOUBLE_T, - BUFFER_T -}; - -inline size_t size_of(arg_type ty){ - switch(ty){ - case INT1_T : return 1; - case INT8_T : return 1; - case INT16_T : return 2; - case INT32_T : return 4; - case INT64_T : return 8; - case HALF_T : return 2; - case FLOAT_T : return 4; - case DOUBLE_T: return 8; - case BUFFER_T: return 8; - default: throw std::runtime_error("unknown type"); - } -} - -template -void add_arg(std::stringstream& ss, T arg) { - ss.write((char*)&arg, sizeof(T)); -} - - -/* ------------------------- */ -/* ------------------------- */ - -class kernel{ -public: - typedef std::vector grid_t; - -public: - static std::shared_ptr src_to_ir(const std::string& src, const options_t& opt); - static std::tuple, - std::shared_ptr, - size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt); - -public: - kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map &attrs = {}); - void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const; - std::string get_asm(const std::string &mode); - -public: - const options_t opt; - -private: - driver::device* dev_; - // handles - std::shared_ptr ir_; - std::shared_ptr mod_; - std::shared_ptr ker_; - // shared mem - size_t shared_mem_; -}; - -struct config { - std::map defines; - int num_warps; -}; - -class function { -public: - typedef std::function grid_fn_ty; - typedef std::pair> kernel_pair_t; - typedef std::map, kernel*> cache_t; - typedef std::vector autotune_confs_t; - -public: - function(const std::string& src, const options_t& opt, driver::device *device, - const std::vector& tune_confs = {}, const std::vector &tune_key = {}); - kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream); - void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream); - const std::vector get_signature() { return sig_; } - -private: - std::map, std::vector>> kernels_; - std::map, kernel*> cache_; - std::vector sig_; - std::vector align_idxs_; - std::vector int_idxs_; - std::vector key_idxs_; - std::vector arg_size_; - std::vector arg_off_; - std::vector opts_; - std::string src_; - driver::device* device_; -}; - -} -} - -#endif diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index ed1d5c881..a31b1f24f 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -55,8 +55,8 @@ inline T add_to_cache(ir::value *i, T value, std::map &map) { std::vector align::get_shapes(ir::value *v) { ir::type *ty = v->get_type(); - if(ty->is_tile_ty()) - return ty->get_tile_shapes(); + if(ty->is_block_ty()) + return ty->get_block_shapes(); else return {1}; } @@ -95,7 +95,7 @@ std::vector align::populate_is_constant_reshape(ir::reshape_ins auto x_shapes = get_shapes(x); std::vector result; ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_tile_shapes(); + auto op_shapes = op->get_type()->get_block_shapes(); auto op_cst = populate_is_constant(op); unsigned current = 0; bool is_skewed = false; @@ -119,7 +119,7 @@ std::vector align::populate_is_constant_broadcast(ir::broadcast auto x_shapes = get_shapes(x); std::vector result; ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_tile_shapes(); + auto op_shapes = op->get_type()->get_block_shapes(); auto op_cst = populate_is_constant(op); for(size_t d = 0; d < x_shapes.size(); d++) if(op_shapes[d] == 1) @@ -229,7 +229,7 @@ std::vector align::populate_max_contiguous_reshape(ir::reshape_inst* x auto shapes = get_shapes(x); std::vector result; ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_tile_shapes(); + auto op_shapes = op->get_type()->get_block_shapes(); auto op_mc = populate_max_contiguous(op); unsigned current = 0; bool is_skewed = false; @@ -251,7 +251,7 @@ std::vector align::populate_max_contiguous_broadcast(ir::broadcast_ins auto shapes = get_shapes(x); std::vector result; ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_tile_shapes(); + auto op_shapes = op->get_type()->get_block_shapes(); auto op_mc = populate_max_contiguous(op); for(size_t d = 0; d < shapes.size(); d++) if(op_shapes[d] == 1) @@ -317,9 +317,9 @@ std::vector align::populate_max_contiguous_gep(ir::getelementptr_inst* } std::vector align::populate_max_contiguous_default(ir::value* v) { - if(!v->get_type()->is_tile_ty()) + if(!v->get_type()->is_block_ty()) return add_to_cache(v, {1}, max_contiguous_); - auto shapes = v->get_type()->get_tile_shapes(); + auto shapes = v->get_type()->get_block_shapes(); if(dynamic_cast(v)) return add_to_cache(v, {shapes[0]}, max_contiguous_); if(dynamic_cast(v)) @@ -450,8 +450,8 @@ std::vector align::populate_starting_multiple_cast(ir::cast_inst* x){ std::vector align::populate_starting_multiple_default(ir::value* v) { ir::type* ty = v->get_type(); - if(ty->is_tile_ty()) { - return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_); + if(ty->is_block_ty()) { + return add_to_cache(v, ty->get_block_shapes(), starting_multiple_); } if(auto *x = dynamic_cast(v)){ std::set attributes = x->get_parent()->get_attributes(x); @@ -462,7 +462,7 @@ std::vector align::populate_starting_multiple_default(ir::value* v) { if(attr.get_kind() == ir::aligned){ ir::type* ty = x->get_type()->get_pointer_element_ty(); int nbits = ty->get_primitive_size_in_bits(); - int nbytes = nbits / 8; + int nbytes = std::max(nbits / 8, 1); return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_); } } diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 1ec198787..d68be1d82 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -15,7 +15,7 @@ void axes::update_graph_reduce(ir::instruction *i) { auto* red = static_cast(i); unsigned axis = red->get_axis(); ir::value *arg = red->get_operand(0); - auto in_shapes = arg->get_type()->get_tile_shapes(); + auto in_shapes = arg->get_type()->get_block_shapes(); unsigned current = 0; for(unsigned d = 0; d < in_shapes.size(); d++){ if(d == axis) @@ -29,8 +29,8 @@ void axes::update_graph_reshape(ir::instruction *i) { // operands ir::value *op = reshape->get_operand(0); // shapes - auto op_shapes = op->get_type()->get_tile_shapes(); - auto res_shapes = reshape->get_type()->get_tile_shapes(); + auto op_shapes = op->get_type()->get_block_shapes(); + auto res_shapes = reshape->get_type()->get_block_shapes(); // construct edges unsigned current = 0; bool is_skewed = false; @@ -58,10 +58,10 @@ void axes::update_graph_trans(ir::instruction *i) { void axes::update_graph_broadcast(ir::instruction *i) { auto *broadcast = static_cast(i); - auto shapes = broadcast->get_type()->get_tile_shapes(); + auto shapes = broadcast->get_type()->get_block_shapes(); ir::value *op = broadcast->get_operand(0); ir::type *op_ty = op->get_type(); - const auto& op_shapes = op_ty->get_tile_shapes(); + const auto& op_shapes = op_ty->get_block_shapes(); // add edge between non-broadcast axes for(unsigned d = 0; d < shapes.size(); d ++) if(op_shapes[d] == shapes[d]) @@ -70,7 +70,7 @@ void axes::update_graph_broadcast(ir::instruction *i) { void axes::update_graph_dot(ir::instruction *i) { auto *dot = static_cast(i); - auto shapes = dot->get_type()->get_tile_shapes(); + auto shapes = dot->get_type()->get_block_shapes(); ir::value *A = dot->get_operand(0); ir::value *B = dot->get_operand(1); ir::value *D = dot->get_operand(2); @@ -83,7 +83,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) { if(i->get_num_operands() == 0) return; ir::value *op = i->get_operand(0); - if(!op->get_type()->is_tile_ty()) + if(!op->get_type()->is_block_ty()) return; auto rank = op->get_type()->get_tile_rank(); for(unsigned d = 0; d < rank; d++) @@ -96,7 +96,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) { } void axes::update_graph_no_edge(ir::instruction *i) { - if(!i->get_type()->is_tile_ty()) + if(!i->get_type()->is_block_ty()) return; auto rank = i->get_type()->get_tile_rank(); for(unsigned d = 0; d < rank; d++) diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index cd27da12f..0d53f4c73 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -325,9 +325,9 @@ layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, void layouts::connect(ir::value *x, ir::value *y) { if(x == y) return; - if(!x->get_type()->is_tile_ty()) + if(!x->get_type()->is_block_ty()) return; - if(!y->get_type()->is_tile_ty()) + if(!y->get_type()->is_block_ty()) return; std::vector x_axes = axes_->get(x); std::vector y_axes = axes_->get(y); @@ -364,7 +364,7 @@ void layouts::create(size_t id, const std::vector& values) { std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast(v); }); ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp); const auto& axes = axes_->get(largest); - const auto& shapes = largest->get_type()->get_tile_shapes(); + const auto& shapes = largest->get_type()->get_block_shapes(); auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) { return dynamic_cast(v) || dynamic_cast(v); @@ -411,7 +411,7 @@ void layouts::run(ir::module &mod) { ir::value *arg = red->get_operand(0); unsigned axis = red->get_axis(); // shape - auto shapes = arg->get_type()->get_tile_shapes(); + auto shapes = arg->get_type()->get_block_shapes(); scanline_layout *layout = get(arg)->to_scanline(); shapes[axis] = layout->mts(axis); // create layout @@ -425,8 +425,8 @@ void layouts::run(ir::module &mod) { if(!in_layout || !out_layout) return; id++; - ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes(); - ir::type::tile_shapes_t shape(in_shape.size()); + ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes(); + ir::type::block_shapes_t shape(in_shape.size()); size_t ld = out_layout->get_order(0); shape[ld] = in_shape[ld]; for(size_t k = 0; k < in_shape.size(); k++) diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc new file mode 100644 index 000000000..884a89859 --- /dev/null +++ b/lib/codegen/pass.cc @@ -0,0 +1,103 @@ +#include "triton/codegen/pass.h" +#include "triton/codegen/analysis/align.h" +#include "triton/codegen/analysis/allocation.h" +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/liveness.h" +#include "triton/codegen/analysis/swizzle.h" +#include "triton/codegen/selection/generator.h" +#include "triton/codegen/transform/coalesce.h" +#include "triton/codegen/transform/cts.h" +#include "triton/codegen/transform/dce.h" +#include "triton/codegen/transform/disassociate.h" +#include "triton/codegen/transform/membar.h" +#include "triton/codegen/transform/peephole.h" +#include "triton/codegen/transform/pipeline.h" +#include "triton/codegen/transform/reassociate.h" +#include "triton/driver/device.h" +#include "triton/driver/kernel.h" +#include "triton/driver/module.h" +#include "triton/ir/function.h" +#include "triton/ir/module.h" +#include "triton/ir/print.h" +#include "llvm/IR/Module.h" + +namespace triton { +namespace codegen { + +// TODO: +// There should be a proper pass manager there! +void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, + driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) { + // generate llvm code + llvm::LLVMContext ctx; + std::string name = ir.get_function_list()[0]->get_name(); + std::unique_ptr llvm(new llvm::Module(name, ctx)); + // optimizations + std::unique_ptr target = dev->make_target(); + bool cts_use_async = target->as_nvidia()->sm() >= 80; + // create passes + codegen::analysis::align align; + codegen::analysis::axes axes; + codegen::transform::cts cts(cts_use_async); + codegen::transform::pipeline pipeline(cts_use_async); + codegen::transform::disassociate disassociate; + codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get()); + codegen::analysis::liveness liveness(&layouts); + codegen::analysis::swizzle swizzle(&layouts, target.get()); + codegen::analysis::allocation allocation(&liveness); + codegen::transform::membar barriers(&liveness, &layouts, &allocation); + codegen::transform::dce dce; + codegen::transform::peephole peephole(target.get(), &layouts); + codegen::transform::reassociate reassociate; + codegen::transform::coalesce coalesce(&align, &layouts); + codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps); + // run passes + dce.run(ir); + //ir::print(ir, std::cout); + peephole.run(ir); + dce.run(ir); + pipeline.run(ir); + dce.run(ir); + //ir::print(ir, std::cout); + disassociate.run(ir); + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + peephole.run(ir); + dce.run(ir); + if (target->is_gpu()) + cts.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + coalesce.run(ir); + dce.run(ir); + align.run(ir); + dce.run(ir); + if (target->is_gpu()) { + reassociate.run(ir); + cts.run(ir); + } + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + peephole.run(ir); + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + swizzle.run(ir); + liveness.run(ir); + allocation.run(ir); + barriers.run(ir); + // ir::print(ir, std::cout); + isel.visit(ir, *llvm); + mod = driver::module::create(dev, std::move(llvm)); + ker = driver::kernel::create(&*mod, name.c_str()); + shared_mem = allocation.allocated_size(); +} + +} // namespace codegen +} // namespace triton \ No newline at end of file diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index c25495dd3..f143f8354 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -150,7 +150,7 @@ generator::generator(analysis::axes *a_axes, void generator::visit_value(ir::value* v) { if(!seen_.insert(v).second) return; - if(v->get_type()->is_tile_ty()){ + if(v->get_type()->is_block_ty()){ if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){ auto double_buffer = layout->get_double_buffer(); // offset @@ -384,7 +384,7 @@ void generator::visit_load_inst(ir::load_inst* x){ // compute vector width size_t vec = 1; - if(op->get_type()->is_tile_ty()){ + if(op->get_type()->is_block_ty()){ auto ord = ords_.at(op); size_t aln = alignment_->get(op, ord[0]); size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]); @@ -407,7 +407,10 @@ void generator::visit_load_inst(ir::load_inst* x){ PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2); Instruction *then_term; Instruction *else_term; + builder_->SetInsertPoint(_ret->getParent()); + Instruction* dummy = builder_->CreateRet(nullptr); llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term); + dummy->removeFromParent(); builder_->SetInsertPoint(then_term); Value* then_ret = load(ptr); builder_->SetInsertPoint(else_term); @@ -441,7 +444,7 @@ void generator::visit_store_inst(ir::store_inst * x){ ir::value *val_op = x->get_value_operand(); // vector size size_t vec = 1; - if(val_op->get_type()->is_tile_ty()){ + if(val_op->get_type()->is_block_ty()){ auto ord = ords_.at(x->get_pointer_operand()); size_t aln = alignment_->get(ptr_op, ord[0]); size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous; @@ -461,7 +464,10 @@ void generator::visit_store_inst(ir::store_inst * x){ if(mx){ Value *msk = vals_[mx->get_mask_operand()][idx]; Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); + builder_->SetInsertPoint(no_op->getParent()); + Instruction* dummy = builder_->CreateRet(nullptr); Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false); + dummy->removeFromParent(); builder_->SetInsertPoint(term); store(val, ptr); builder_->SetInsertPoint(no_op); @@ -501,13 +507,15 @@ void generator::visit_splat_inst(ir::splat_inst* x) { */ void generator::visit_broadcast_inst(ir::broadcast_inst* x) { ir::value* op = x->get_operand(0); - const auto& shape = op->get_type()->get_tile_shapes(); + const auto& shape = op->get_type()->get_block_shapes(); for(auto out_idx: idxs_.at(x)){ indices_t in_idx = out_idx; for(size_t k = 0; k < in_idx.size(); k++) in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k]; vals_[x][out_idx] = vals_[op][in_idx]; } +// for(size_t i = 0; i < idxs_.at(x).size(); i++) +// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; } /** @@ -527,9 +535,9 @@ void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) { } /** - * \brief Code Generation for `get_num_program` + * \brief Code Generation for `get_num_programs` */ -void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) { +void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) { Module *module = builder_->GetInsertBlock()->getModule(); Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis()); vals_[np][{}] = ret; @@ -621,7 +629,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) { //TODO: clean-up void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { - if(add->get_type()->is_tile_ty()){ + if(add->get_type()->is_block_ty()){ ir::value* ptr = add->get_operand(0); ir::value* val = add->get_operand(1); ir::value* msk = add->get_operand(2); @@ -706,9 +714,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { //TODO: clean-up void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { // shapes - auto shape_c = C->get_type()->get_tile_shapes(); - auto shape_a = A->get_type()->get_tile_shapes(); - auto shape_b = B->get_type()->get_tile_shapes(); + auto shape_c = C->get_type()->get_block_shapes(); + auto shape_a = A->get_type()->get_block_shapes(); + auto shape_b = B->get_type()->get_block_shapes(); // order auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); @@ -877,7 +885,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va */ //TODO: clean-up void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { - const auto& shapes = dot->get_type()->get_tile_shapes(); + const auto& shapes = dot->get_type()->get_block_shapes(); std::map, std::vector> fcs; @@ -887,8 +895,8 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir fcs[key].push_back(vals_[D][idx]); }; - auto shape_a = A->get_type()->get_tile_shapes(); - auto shape_b = B->get_type()->get_tile_shapes(); + auto shape_a = A->get_type()->get_block_shapes(); + auto shape_b = B->get_type()->get_block_shapes(); auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); analysis::mma_layout* layout = layouts_->get(dot)->to_mma(); @@ -1059,9 +1067,9 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir * \brief Code Generation for FMA-based `dot` (FP32, FP64, Default) */ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) { - auto shape_c = C->get_type()->get_tile_shapes(); - auto shape_a = A->get_type()->get_tile_shapes(); - auto shape_b = B->get_type()->get_tile_shapes(); + auto shape_c = C->get_type()->get_block_shapes(); + auto shape_a = A->get_type()->get_block_shapes(); + auto shape_b = B->get_type()->get_block_shapes(); auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline(); @@ -1161,7 +1169,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { ir::value *D = dot->get_operand(2); Type *c_ty = cvt(D->get_type()->get_scalar_ty()); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector{c_ty}); - auto A_shapes = A->get_type()->get_tile_shapes(); + auto A_shapes = A->get_type()->get_block_shapes(); size_t red_axis = 1; unsigned NK = A_shapes[red_axis]; bool is_outer = NK == 1; @@ -1236,7 +1244,10 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::functionSetInsertPoint(barrier->getParent()); + Instruction* dummy = builder_->CreateRet(nullptr); Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); + dummy->removeFromParent(); builder_->SetInsertPoint(term); Value* ret = load(gep(base, thread)); for(int i = (num_warps_+1)/2; i > 0; i >>= 1){ @@ -1359,10 +1370,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { * \brief Code Generation for `select` */ void generator::visit_select_inst(ir::select_inst* x) { - for(indices_t idx: idxs_.at(x)) + for(indices_t idx: idxs_.at(x)){ vals_[x][idx] = select(vals_[x->get_operand(0)][idx], vals_[x->get_operand(1)][idx], vals_[x->get_operand(2)][idx]); + } } /** @@ -1370,7 +1382,7 @@ void generator::visit_select_inst(ir::select_inst* x) { */ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { ir::value *op = rc->get_operand(0); - ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes(); + ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes(); // pointer to temporary shared memory Type *ty = cvt(rc->get_type()->get_scalar_ty()); // layout @@ -1435,7 +1447,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); int n_shared_0 = std::max(in_vec / out_vec, 1); - auto shapes = x->get_type()->get_tile_shapes(); + auto shapes = x->get_type()->get_block_shapes(); BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); std::map, Value*> tmp; @@ -1520,7 +1532,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); - auto shapes = cts->get_type()->get_tile_shapes(); + auto shapes = cts->get_type()->get_block_shapes(); // store to shared Value *current = nullptr; @@ -1901,13 +1913,13 @@ void generator::visit_argument(ir::argument* arg) { void generator::init_idx(ir::value *v) { idxs_[v].clear(); - if(!v->get_type()->is_tile_ty()){ + if(!v->get_type()->is_block_ty()){ idxs_[v].push_back({}); return; } if(layouts_->get(v)->to_shared()) return; - const auto &shapes = v->get_type()->get_tile_shapes(); + const auto &shapes = v->get_type()->get_block_shapes(); size_t rank = shapes.size(); std::vector axes(rank); std::vector ord(rank); diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index bea371c44..517fd96d9 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -37,7 +37,7 @@ int membar::group_of(ir::value* v, std::vector &async_write) { membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) { val_set_t ret; for(ir::value* a: as){ - if(!a->get_type()->is_tile_ty()) + if(!a->get_type()->is_block_ty()) continue; analysis::shared_layout* a_layout = layouts_->get(a)->to_shared(); if(!a_layout) @@ -45,7 +45,7 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b int a_start = alloc_->offset(a_layout); int a_end = a_start + a_layout->get_size(); for(ir::value* b: bs){ - if(!b->get_type()->is_tile_ty()) + if(!b->get_type()->is_block_ty()) continue; analysis::shared_layout* b_layout = layouts_->get(b)->to_shared(); if(!b_layout) @@ -80,7 +80,7 @@ void membar::transfer(ir::basic_block *block, // Get shared memory reads std::set read; std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()), - [&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();}); + [&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();}); // RAW (async) val_set_t tmp; std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin())); diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index 392f6ea94..f5eeeb5d0 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -58,7 +58,8 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) { } bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ - // dot(a, b, 0) + c -> dot(a, b, c) + // dot(a, b, c) + d -> dot(a, b, c + d) + // d + dot(a, b, c) -> dot(a, b, c + d) auto add = dynamic_cast(value); if(add && add->get_op() == ir::binary_op_t::FAdd) { ir::value *lhs = add->get_operand(0); @@ -131,10 +132,10 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ if(!x) return false; ir::value *arg = x->get_operand(0); - auto shapes = arg->get_type()->get_tile_shapes(); + auto shapes = arg->get_type()->get_block_shapes(); if(shapes[x->get_axis()] == 1){ builder.set_insert_point(x); - ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes()); + ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes()); x->replace_all_uses_with(new_red); return true; } diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index 32af28463..00520e9d6 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -23,6 +23,24 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector(v); + if(!i) + return v; + if(ir::phi_node* phi = dynamic_cast(v)) + return phi->get_incoming_value(phi_idx); + + std::vector new_ops; + for(ir::value* op: i->ops()){ + new_ops.push_back(rematerialize(builder, op, phi_idx)); + } + ir::instruction* ret = i->clone(); + for(size_t k = 0; k < new_ops.size(); k++) + ret->set_operand(k, new_ops[k]); + builder.insert(ret); + return ret; +} + void pipeline::run(ir::module &mod) { // *Very* conservative heuristics for pre-fetching. // A load instruction can be pipelined if: @@ -55,21 +73,27 @@ void pipeline::run(ir::module &mod) { // pre-fetch first iteration builder.set_insert_point(header->get_inst_list().back()); ir::value* first_ptr = ptr->get_value_for_block(header); - ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes()); + ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes()); ir::value* false_value; if(auto* masked_load = dynamic_cast(load)){ - first_mask = builder.create_and(first_mask, masked_load->get_mask_operand()); - false_value = masked_load->get_false_value_operand(); + ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0); + ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0); + first_mask = builder.create_and(first_mask, remat_mask); + false_value = remat_false_value; } else - false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes()); + false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value); // pre-fetch next iteration builder.set_insert_point(block->get_inst_list().back()); ir::value* next_ptr = ptr->get_value_for_block(block); - ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes()); - if(auto* masked_load = dynamic_cast(load)) - next_mask = builder.create_and(next_mask, masked_load->get_mask_operand()); + ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes()); + if(auto* masked_load = dynamic_cast(load)){ + ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1); + ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1); + next_mask = builder.create_and(next_mask, remat_mask); + false_value = remat_false_value; + } ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value); // phi node builder.set_insert_point(block->get_first_non_phi()); diff --git a/lib/codegen/transform/reassociate.cc b/lib/codegen/transform/reassociate.cc index 01293e1a5..0dad7a19a 100644 --- a/lib/codegen/transform/reassociate.cc +++ b/lib/codegen/transform/reassociate.cc @@ -40,7 +40,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, // handle retiling if(ir::instruction* op = dynamic_cast(old_value)){ - auto shapes = op->get_type()->get_tile_shapes(); + auto shapes = op->get_type()->get_block_shapes(); ir::value *old_arg = op->get_operand(0); ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst); // retile(x + y) = retile(x) + retile(y) @@ -54,19 +54,19 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, builder.set_insert_point(op); new_lhs = builder.create_reshape(old_lhs, shapes); new_rhs = builder.create_reshape(old_rhs, shapes); - new_value = builder.create_add(new_lhs, new_rhs, op->get_name()); + new_value = builder.create_add(new_lhs, new_rhs); } if(dynamic_cast(op)){ builder.set_insert_point(op); new_lhs = builder.create_broadcast(old_lhs, shapes); new_rhs = builder.create_broadcast(old_rhs, shapes); - new_value = builder.create_add(new_lhs, new_rhs, op->get_name()); + new_value = builder.create_add(new_lhs, new_rhs); } if(dynamic_cast(op)){ builder.set_insert_point(op); new_lhs = builder.create_splat(old_lhs, shapes); new_rhs = builder.create_splat(old_rhs, shapes); - new_value = builder.create_add(new_lhs, new_rhs, op->get_name()); + new_value = builder.create_add(new_lhs, new_rhs); } } } @@ -84,10 +84,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, ir::value *rlhs = bin_lhs->get_operand(1); // (cst + x) + y -> cst + (x + y) if(is_cst(llhs)) - new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs), name); + new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs)); // (x + cst) + y -> cst + (x + y) if(is_cst(rlhs)) - new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name); + new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs)); } // x + (y + z) if(ir::instruction* bin_rhs = is_bin_add(rhs)){ @@ -95,10 +95,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, ir::value *rrhs = bin_rhs->get_operand(1); // x + (cst + y) -> cst + (x + y) if(is_cst(lrhs)) - new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), name, cst); + new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst); // x + (y + cst) -> cst + (x + y) if(is_cst(rrhs)) - new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst); + new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst); } } // extract constant and non-constant @@ -166,7 +166,7 @@ void reassociate::run(ir::module &mod) { ir::value* dyn = infos.at(op).dyn_ptr; ir::value* cst = *sta->idx_begin(); if(dynamic_cast(rt)) { - auto shapes = rt->get_type()->get_tile_shapes(); + auto shapes = rt->get_type()->get_block_shapes(); ir::value* ndyn = builder.create_broadcast(dyn, shapes); ir::value* broadcast = builder.create_broadcast(cst, shapes); ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast}); @@ -202,7 +202,7 @@ void reassociate::run(ir::module &mod) { ir::value *cst = *sta->idx_begin(); ir::value *off = *pz->idx_begin(); ir::value *pz_dyn = builder.create_gep(dyn, {off}); - ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name()); + ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}); pz->replace_all_uses_with(pz_sta); infos[pz_sta].dyn_ptr = pz_dyn; infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta; @@ -235,7 +235,8 @@ void reassociate::run(ir::module &mod) { phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a)); builder.set_insert_point(phi->get_parent()->get_first_non_phi()); // re-add the offset - ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta"); + ir::value *phi_sta = builder.create_gep(phi_dyn, {off}); + phi_sta->set_name( phi->get_name() + "_sta"); phi->replace_all_uses_with(phi_sta); // remove offset from pz if(auto *x = dynamic_cast(pz)){ @@ -245,8 +246,8 @@ void reassociate::run(ir::module &mod) { builder.set_insert_point(*it); } ir::value *_0 = builder.get_int32(0); - if(off->get_type()->is_tile_ty()) - _0 = builder.create_splat(_0, off->get_type()->get_tile_shapes()); + if(off->get_type()->is_block_ty()) + _0 = builder.create_splat(_0, off->get_type()->get_block_shapes()); ir::value *neg_off = builder.create_sub(_0, off); ir::value *pz_dyn = builder.create_gep(pz, {neg_off}); phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z)); diff --git a/lib/codegen/transform/reorder.cc b/lib/codegen/transform/reorder.cc index 2949e427d..47dc47b6c 100644 --- a/lib/codegen/transform/reorder.cc +++ b/lib/codegen/transform/reorder.cc @@ -11,38 +11,38 @@ namespace codegen{ namespace transform{ void reorder::run(ir::module& mod){ - ir::builder &builder = mod.get_builder(); - std::vector> to_replace; +// ir::builder &builder = mod.get_builder(); +// std::vector> to_replace; - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - if(auto* ld = dynamic_cast(i)){ - ir::value* _ptr = ld->get_pointer_operand(); - ir::value* _msk = ld->get_mask_operand(); - ir::value* _val = ld->get_false_value_operand(); - auto ptr = std::find(block->begin(), block->end(), _ptr); - auto msk = std::find(block->begin(), block->end(), _msk); - auto val = std::find(block->begin(), block->end(), _val); - if(ptr == block->end() || msk == block->end() || val == block->end()) - continue; - auto it = std::find(block->begin(), block->end(), i); - int dist_ptr = std::distance(ptr, it); - int dist_msk = std::distance(msk, it); - int dist_val = std::distance(val, it); - if(dist_ptr < dist_msk && dist_ptr < dist_val) - builder.set_insert_point(++ptr); - if(dist_msk < dist_ptr && dist_msk < dist_val) - builder.set_insert_point(++msk); - if(dist_val < dist_ptr && dist_val < dist_msk) - builder.set_insert_point(++val); - ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val); - to_replace.push_back(std::make_pair(ld, new_ld)); - } - } +// for(ir::function *fn: mod.get_function_list()) +// for(ir::basic_block *block: fn->blocks()) +// for(ir::instruction* i: block->get_inst_list()){ +// if(auto* ld = dynamic_cast(i)){ +// ir::value* _ptr = ld->get_pointer_operand(); +// ir::value* _msk = ld->get_mask_operand(); +// ir::value* _val = ld->get_false_value_operand(); +// auto ptr = std::find(block->begin(), block->end(), _ptr); +// auto msk = std::find(block->begin(), block->end(), _msk); +// auto val = std::find(block->begin(), block->end(), _val); +// if(ptr == block->end() || msk == block->end() || val == block->end()) +// continue; +// auto it = std::find(block->begin(), block->end(), i); +// int dist_ptr = std::distance(ptr, it); +// int dist_msk = std::distance(msk, it); +// int dist_val = std::distance(val, it); +// if(dist_ptr < dist_msk && dist_ptr < dist_val) +// builder.set_insert_point(++ptr); +// if(dist_msk < dist_ptr && dist_msk < dist_val) +// builder.set_insert_point(++msk); +// if(dist_val < dist_ptr && dist_val < dist_msk) +// builder.set_insert_point(++val); +// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val); +// to_replace.push_back(std::make_pair(ld, new_ld)); +// } +// } - for(auto& x: to_replace) - x.first->replace_all_uses_with(x.second); +// for(auto& x: to_replace) +// x.first->replace_all_uses_with(x.second); } diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 67c08edc1..26fc60692 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -212,7 +212,7 @@ static std::map vptx = { {11020, 72} }; -std::string cu_module::compile_llvm_module(std::unique_ptr module, driver::device* device) { +std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) { // LLVM version in use may not officially support target hardware int max_nvvm_cc = 75; int max_nvvm_ptx = 64; @@ -316,6 +316,7 @@ void cu_module::init_from_ptx(const std::string& ptx) { // log = match.suffix(); // } // std::cout << log << std::endl; +// std::cout << ptx_ << std::endl; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, @@ -351,8 +352,9 @@ cu_module::cu_module(driver::device* device, std::unique_ptr ll_mo oss << *ll_module; oss.flush(); std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH"); - if(cache_path.empty()) - ptx_ = compile_llvm_module(std::move(ll_module), device); + if(cache_path.empty()){ + ptx_ = compile_llvm_module(ll_module.get(), device); + } else{ tools::mkdir(cache_path); // update cache path to PTX file @@ -370,7 +372,7 @@ cu_module::cu_module(driver::device* device, std::unique_ptr ll_mo ptx_ = _ptx.str(); // compile and write-back if read empty if(ptx_.empty()){ - ptx_ = compile_llvm_module(std::move(ll_module), device); + ptx_ = compile_llvm_module(ll_module.get(), device); std::ofstream ofs(cache_path); ofs << ptx_; } diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 92db216ff..051e26636 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -54,6 +54,18 @@ value *builder::get_int32(int32_t val) value *builder::get_int64(int64_t val) { return constant_int::get(type::get_int64_ty(ctx_), val);} +value *builder::get_float16(float val) +{ return constant_fp::get(type::get_half_ty(ctx_), val); } + +value *builder::get_float32(float val) +{ return constant_fp::get(type::get_float_ty(ctx_), val); } + +value *builder::get_range(int32_t _lo, int32_t _hi) { + constant_int* lo = static_cast(get_int32(_lo)); + constant_int* hi = static_cast(get_int32(_hi)); + return insert(make_range::create(lo, hi)); +} + type *builder::get_void_ty() { return type::get_void_ty(ctx_); } @@ -105,8 +117,8 @@ value *builder::create_ret_void() { // cast instructions //===----------------------------------------------------------------------===// #define DEFINE_CAST_INSTR(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *src, type *dst_ty, std::string const &name){\ - return create_cast(OPCODE, src, dst_ty, name);\ + value *builder::create_ ## SUFFIX(value *src, type *dst_ty){\ + return create_cast(OPCODE, src, dst_ty);\ } DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) @@ -117,20 +129,20 @@ DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI) DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt) DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc) -value* builder::create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name){ - return insert(cast_inst::create(op, v, dst_ty), name); +value* builder::create_cast(cast_op_t op, value *v, type *dst_ty){ + return insert(cast_inst::create(op, v, dst_ty)); } -value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name){ - return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed), name); +value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed){ + return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed)); } //===----------------------------------------------------------------------===// // phi instructions //===----------------------------------------------------------------------===// -phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string &name){ - return insert(phi_node::create(ty, num_reserved), name); +phi_node* builder::create_phi(type *ty, unsigned num_reserved){ + return insert(phi_node::create(ty, num_reserved)); } //===----------------------------------------------------------------------===// @@ -138,8 +150,8 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string //===----------------------------------------------------------------------===// #define DEFINE_BINARY_FLOAT(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\ - return insert(binary_operator::create(OPCODE, lhs, rhs), name);\ + value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\ + return insert(binary_operator::create(OPCODE, lhs, rhs));\ } // Binary @@ -156,22 +168,22 @@ DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub) value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs, - value *rhs, const std::string &name, + value *rhs, bool has_nuw, bool has_nsw) { - binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name); + binary_operator* result = insert(binary_operator::create(op, lhs, rhs)); if (has_nuw) result->set_has_no_unsigned_wrap(); if (has_nsw) result->set_has_no_signed_wrap(); return result; } #define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\ - value* builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw){\ - return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, has_nuw, has_nsw);\ + value* builder::create_ ## SUFFIX(value *lhs, value *rhs, bool has_nuw, bool has_nsw){\ + return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, has_nuw, has_nsw);\ }\ #define DEFINE_BINARY_INT(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\ - return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\ + value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\ + return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, false, false);\ } @@ -196,21 +208,21 @@ DEFINE_BINARY_INT(xor, binary_op_t::Xor) // getelementptr instructions //===----------------------------------------------------------------------===// -value* builder::create_gep(value *ptr, const std::vector& idx_list, const std::string &name){ - return insert(getelementptr_inst::create(ptr, idx_list), name); +value* builder::create_gep(value *ptr, const std::vector& idx_list){ + return insert(getelementptr_inst::create(ptr, idx_list)); } //===----------------------------------------------------------------------===// // icmp instructions //===----------------------------------------------------------------------===// -value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){ - return insert(icmp_inst::create(pred, lhs, rhs), name); +value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs){ + return insert(icmp_inst::create(pred, lhs, rhs)); } #define DEFINE_ICMP_INSTR(SUFFIX, OPCODE)\ - value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs, const std::string &name){\ - return create_icmp(OPCODE, lhs, rhs, name);\ + value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs){\ + return create_icmp(OPCODE, lhs, rhs);\ } // Signed @@ -232,13 +244,13 @@ DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE) // fcmp instructions //===----------------------------------------------------------------------===// -value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){ - return insert(fcmp_inst::create(pred, lhs, rhs), name); +value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs){ + return insert(fcmp_inst::create(pred, lhs, rhs)); } #define DEFINE_FCMP_INSTR(SUFFIX, OPCODE)\ - value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs, const std::string &name){\ - return create_fcmp(OPCODE, lhs, rhs, name);\ + value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs){\ + return create_fcmp(OPCODE, lhs, rhs);\ } // Ordered @@ -255,102 +267,92 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE) // load/store instructions //===----------------------------------------------------------------------===// -value *builder::create_load(value *ptr, const std::string &name){ - return insert(unmasked_load_inst::create(ptr, name)); -// type *ty = ptr->get_type()->get_pointer_element_ty(); -// value *mask = constant_int::get(get_int1_ty(), 1); -// value *undef = undef_value::get(ty); -// if(ptr->get_type()->is_tile_ty()){ -// auto shapes = ptr->get_type()->get_tile_shapes(); -// return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name)); -// } -// return insert(masked_load_inst::create(ptr, mask, undef, name)); - +value *builder::create_load(value *ptr){ + return insert(unmasked_load_inst::create(ptr)); } -value *builder::create_store(value *ptr, value *val, const std::string &name){ - return insert(unmasked_store_inst::create(ptr, val, name)); +value *builder::create_store(value *ptr, value *val){ + return insert(unmasked_store_inst::create(ptr, val)); } -value *builder::create_masked_load(value *ptr, value *mask, value *false_value, const std::string &name){ - return insert(masked_load_inst::create(ptr, mask, false_value, name)); +value *builder::create_masked_load(value *ptr, value *mask, value *false_value){ + return insert(masked_load_inst::create(ptr, mask, false_value)); } - -value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){ - return insert(masked_store_inst::create(ptr, val, mask, name)); +value *builder::create_masked_store(value *ptr, value *val, value *mask){ + return insert(masked_store_inst::create(ptr, val, mask)); } //===----------------------------------------------------------------------===// -// tile instructions +// block instructions //===----------------------------------------------------------------------===// -value *builder::create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name) { - return insert(reshape_inst::create(arg, shapes, name)); +value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) { + return insert(reshape_inst::create(arg, shapes)); } -value *builder::create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name) { - return insert(splat_inst::create(arg, shapes, name)); +value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) { + return insert(splat_inst::create(arg, shapes)); } -value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name) { - return insert(broadcast_inst::create(arg, shapes, name)); +value *builder::create_broadcast(value *arg, const type::block_shapes_t &shapes) { + return insert(broadcast_inst::create(arg, shapes)); } -value *builder::create_downcast(value *arg, const std::string &name) { - return insert(downcast_inst::create(arg, name)); +value *builder::create_downcast(value *arg) { + return insert(downcast_inst::create(arg)); } //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// -value *builder::create_get_program_id(unsigned axis, const std::string &name) { - return insert(get_program_id_inst::create(ctx_, axis, name)); +value *builder::create_get_program_id(unsigned axis) { + return insert(get_program_id_inst::create(ctx_, axis)); } -value *builder::create_get_num_program(unsigned axis, const std::string &name) { - return insert(get_num_program_inst::create(ctx_, axis, name)); +value *builder::create_get_num_programs(unsigned axis) { + return insert(get_num_programs_inst::create(ctx_, axis)); } -value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){ - return insert(atomic_cas_inst::create(ptr, cmp, val, name)); +value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ + return insert(atomic_cas_inst::create(ptr, cmp, val)); } -value *builder::create_atomic_exch(value *ptr, value *val, const std::string &name){ - return insert(atomic_exch_inst::create(ptr, val, name)); +value *builder::create_atomic_exch(value *ptr, value *val){ + return insert(atomic_exch_inst::create(ptr, val)); } -value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){ - return insert(atomic_add_inst::create(ptr, val, msk, name)); +value *builder::create_atomic_add(value *ptr, value *val, value *msk){ + return insert(atomic_add_inst::create(ptr, val, msk)); } -value *builder::create_exp(value *arg, const std::string &name){ - return insert(exp_inst::create(arg, name)); +value *builder::create_exp(value *arg){ + return insert(exp_inst::create(arg)); } -value *builder::create_log(value *arg, const std::string &name){ - return insert(log_inst::create(arg, name)); +value *builder::create_log(value *arg){ + return insert(log_inst::create(arg)); } -value *builder::create_dot(value *A, value *B, value *C, const std::string &name) { - return insert(dot_inst::create_nn(A, B, C, name)); +value *builder::create_dot(value *A, value *B, value *C) { + return insert(dot_inst::create_nn(A, B, C)); } -value *builder::create_trans(value *A, const std::vector& perm, const std::string &name) { - return insert(trans_inst::create(A, perm, name)); +value *builder::create_trans(value *A, const std::vector& perm) { + return insert(trans_inst::create(A, perm)); } -value *builder::create_sqrt(value *A, const std::string &name) { - return insert(sqrt_inst::create(A, name)); +value *builder::create_sqrt(value *A) { + return insert(sqrt_inst::create(A)); } -value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) { - return insert(reduce_inst::create(A, op, axis, name)); +value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis) { + return insert(reduce_inst::create(A, op, axis)); } -value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){ - return insert(select_inst::create(pred, if_value, else_value, name)); +value *builder::create_select(value *pred, value *if_value, value *else_value){ + return insert(select_inst::create(pred, if_value, else_value)); } //===----------------------------------------------------------------------===// @@ -358,26 +360,28 @@ value *builder::create_select(value *pred, value *if_value, value *else_value, c //===----------------------------------------------------------------------===// -value *builder::create_copy_to_shared(value *arg, const std::string &name) { - return insert(copy_to_shared_inst::create(arg, name)); +value *builder::create_copy_to_shared(value *arg) { + return insert(copy_to_shared_inst::create(arg)); } -value *builder::create_copy_from_shared(value *arg, const std::string &name) { - return insert(copy_from_shared_inst::create(arg, name)); +value *builder::create_copy_from_shared(value *arg) { + return insert(copy_from_shared_inst::create(arg)); } -value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) { - return insert(masked_load_async_inst::create(ptr, mask, false_value, name)); +value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value) { + return insert(masked_load_async_inst::create(ptr, mask, false_value)); } value *builder::create_barrier(const std::string &name) { - return insert(barrier_inst::create(ctx_, name)); + return insert(barrier_inst::create(ctx_)); } value *builder::create_async_wait(int N) { return insert(async_wait_inst::create(ctx_, N)); } + + } } diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index 3a7aabf9e..ece402979 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -30,9 +30,9 @@ constant *constant::get_null_value(type *ty) { constant *constant::get_all_ones_value(type *ty) { if(ty->is_integer_ty()) - return constant_int::get(ty, 0xFFFFFFFF); + return constant_int::get(ty, 0xFFFFFFFFFFFFFFFF); if(ty->is_floating_point_ty()) - return constant_fp::get(ty, 0xFFFFFFFF); + return constant_fp::get(ty, 0xFFFFFFFFFFFFFFFF); throw std::runtime_error("Cannot create all ones value for that type!"); } diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc new file mode 100644 index 000000000..1a3b8460c --- /dev/null +++ b/lib/ir/dispatch.cc @@ -0,0 +1,613 @@ +#include "triton/ir/dispatch.h" +#include + +namespace triton{ +namespace ir{ + + +ir::value* throw_unreachable(std::string key) { + throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. " + "This is likely a bug on our side."); + return 0; +} + +//===----------------------------------------------------------------------===// +// Programming Model +//===----------------------------------------------------------------------===// + +ir::value *dispatch::program_id(int axis, ir::builder *builder) { + return builder->create_get_program_id(axis); +} + +ir::value *dispatch::num_programs(int axis, ir::builder *builder) { + return builder->create_get_num_programs(axis); +} + +//===----------------------------------------------------------------------===// +// Implicit Casting Utilities +//===----------------------------------------------------------------------===// + +ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ + int a_rank = a_ty->get_integer_bitwidth(); + int b_rank = b_ty->get_integer_bitwidth(); + return a_rank > b_rank ? a_ty : b_ty; +} + +ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){ + context &ctx = a_ty->get_context(); + // 1) if one operand is double, the other is implicitly + // converted to double + if(a_ty->is_double_ty() || b_ty->is_double_ty()) + return type::get_double_ty(ctx); + // 2) if one operand is float, the other is implicitly + // converted to float + if(a_ty->is_float_ty() || b_ty->is_float_ty()) + return type::get_float_ty(ctx); + // 3 ) if one operand is half, the other is implicitly + // converted to half + if(a_ty->is_half_ty() || b_ty->is_half_ty()) + return type::get_half_ty(ctx); + if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) + throw_unreachable("augment_types"); + // 4 ) both operands are integer and undergo + // integer promotion + return integer_promote(a_ty, b_ty); +} + +//===----------------------------------------------------------------------===// +// Binary Operators +//===----------------------------------------------------------------------===// + +void throw_incompatible_types(ir::type* type_a, ir::type* type_b) { + throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr()); +} + +void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){ + + if(type_a->is_pointer_ty()){ + if(!allow_ptr_a) + throw_incompatible_types(type_a, type_b); + // T* + U* with T != U + if(type_b->is_pointer_ty() && (type_a != type_b)) + throw_incompatible_types(type_a, type_b); + // T* + float + if(type_b->is_floating_point_ty()) + throw_incompatible_types(type_a, type_b); + } +} + +void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder, + bool allow_lhs_ptr = false, bool allow_rhs_ptr = false, + bool arithmetic_check = true){ + // implicit broadcasting + std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder); + // implicit typecasting + ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty(); + ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty(); + check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr); + check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr); + if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){ + ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty); + lhs = dispatch::cast(lhs, ret_sca_ty, builder); + rhs = dispatch::cast(rhs, ret_sca_ty, builder); + } +} + +ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder, true, true); + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + // offset + ptr + // ptr + offset + if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty()) + std::swap(input, other); + if (input_scalar_ty->is_pointer_ty()) + return builder->create_gep(input, {other}); + // float + float + else if (input_scalar_ty->is_floating_point_ty()) + return builder->create_fadd(input, other); + // int + int + else if (input_scalar_ty->is_integer_ty()) + return builder->create_add(input, other); + return throw_unreachable("add"); +} + +ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder, true, false); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // ptr - offset + if (scalar_ty->is_pointer_ty()) + return builder->create_gep(input, {dispatch::minus(other, builder)}); + // float + float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fsub(input, other); + // int + int + else if (scalar_ty->is_integer_ty()) + return builder->create_sub(input, other); + return throw_unreachable("sub"); +} + +ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float * float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fmul(input, other); + // int * int + else if (scalar_ty->is_integer_ty()) + return builder->create_mul(input, other); + return throw_unreachable("mul"); +} + +ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder, false, false, false); + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + // float / int + if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty()) + other = cast(other, input_scalar_ty, builder); + // int / float + else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty()) + input = cast(input, other_scalar_ty, builder); + // int / int (cast to float32) + else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ + input = cast(input, builder->get_float_ty(), builder); + other = cast(other, builder->get_float_ty(), builder); + } + // float / float (cast to highest exponent type) + else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){ + if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width()) + other = cast(other, input_scalar_ty, builder); + else + input = cast(input, other_scalar_ty, builder); + } + // unreachable + else + return throw_unreachable("div"); + return builder->create_fdiv(input, other); +} + +ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){ + binary_op_type_checking(input, other, builder, false, false, false); + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ + ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty); + input = dispatch::cast(input, ret_ty, builder); + other = dispatch::cast(other, ret_ty, builder); + return builder->create_sdiv(input, other); + } + return throw_unreachable("floordiv"); +} + +ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float % int + if (scalar_ty->is_floating_point_ty()) + return builder->create_frem(input, other); + // int % int + else if (scalar_ty->is_integer_ty()) + return builder->create_srem(input, other); + return throw_unreachable("mod"); +} + + +void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){ + binary_op_type_checking(input, other, builder, false, false, false); + ir::type *input_sca_ty = input->get_type()->get_scalar_ty(); + ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); + if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) + throw_incompatible_types(input_sca_ty, other_sca_ty); + // for some reason pytorch assigns the result of binary op to have the type of the lhs... + if(force_lhs_type){ + if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth()) + other = dispatch::cast(other, input_sca_ty, builder); + } + else{ + if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth()) + input = dispatch::cast(input, other_sca_ty, builder); + else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth()) + other = dispatch::cast(other, input_sca_ty, builder); + } + +} + +ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder, true); + return builder->create_and(input, other); +} + +ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder, true); + return builder->create_or(input, other); +} + + +ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder, true); + return builder->create_xor(input, other); +} + + +ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder, false); + return builder->create_lshr(input, other); +} + + +ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder, false); + return builder->create_shl(input, other); +} + +//===----------------------------------------------------------------------===// +// Unary Operators +//===----------------------------------------------------------------------===// + +ir::value *dispatch::plus(ir::value *input, ir::builder *) { + return input; +} + +ir::value *dispatch::minus(ir::value *input, ir::builder *builder) { + ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); + if(input_sca_ty->is_pointer_ty()) + throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")"); + ir::value *_0 = ir::constant::get_null_value(input_sca_ty); + return dispatch::sub(_0, input, builder); +} + +ir::value *dispatch::invert(ir::value *input, ir::builder *builder) { + ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); + if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty()) + throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")"); + ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty); + return dispatch::xor_(input, _1, builder); +} + + +//===----------------------------------------------------------------------===// +// Comparison Operators +//===----------------------------------------------------------------------===// + +ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float > float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOGT(input, other); + // int > int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSGT(input, other); + return throw_unreachable("greater_than"); +} + +ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float >= float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOGE(input, other); + // int >= int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSGE(input, other); + return throw_unreachable("greater_equal"); +} + +ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float < float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOLT(input, other); + // int < int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSLT(input, other); + return throw_unreachable("less_than"); +} + +ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float < float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOLE(input, other); + // int < int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSLE(input, other); + return throw_unreachable("less_equal"); +} + +ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float == float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOEQ(input, other); + // int == int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpEQ(input, other); + return throw_unreachable("equal"); +} + +ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float == float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpONE(input, other); + // int == int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpNE(input, other); + return throw_unreachable("equal"); +} + +//===----------------------------------------------------------------------===// +// Block Creation +//===----------------------------------------------------------------------===// + +ir::value* dispatch::arange(int start, int end, ir::builder *builder) { + return builder->get_range(start, end); +} + +ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) { + ir::value *_0 = ir::constant::get_null_value(dtype); + return builder->create_splat(_0, shape); +} + +//===----------------------------------------------------------------------===// +// Shape Manipulation +//===----------------------------------------------------------------------===// + + +ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) { + unsigned numel = 1; + for(unsigned s: dst_shape) numel *= s; + if(input->get_type()->get_tile_num_elements() != numel) + throw semantic_error("cannot reshape block of different shape"); + return builder->create_reshape(input, dst_shape); +} + +ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) { + if (!input->get_type()->is_block_ty()) + return builder->create_splat(input, shape); + auto src_shape = input->get_type()->get_block_shapes(); + if (src_shape.size() != shape.size()) + throw std::runtime_error("Cannot broadcast"); + return builder->create_broadcast(input, shape); +} + +std::tuple dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { + ir::type *lhs_ty = lhs->get_type(); + ir::type *rhs_ty = rhs->get_type(); + // make_shape_compatible(block, scalar) + if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) + rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); + // make_shape_compatible(scalar, block) + else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) + lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes()); + // make_shape_compatible(block, block) + else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) { + auto lhs_shape = lhs_ty->get_block_shapes(); + auto rhs_shape = rhs_ty->get_block_shapes(); + if (lhs_shape.size() != rhs_shape.size()) + throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank"); + ir::type::block_shapes_t ret_shape; + for (size_t i = 0; i < lhs_shape.size(); ++i) { + unsigned left = lhs_shape[i]; + unsigned right = rhs_shape[i]; + if (left == 1) + ret_shape.push_back(right); + else if (right == 1) + ret_shape.push_back(left); + else if (left == right) + ret_shape.push_back(left); + else + throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) + + ": " + std::to_string(left) + " and " + std::to_string(right)); + } + if (lhs_shape != ret_shape) + lhs = builder->create_broadcast(lhs, ret_shape); + if (rhs_shape != ret_shape) + rhs = builder->create_broadcast(rhs, ret_shape); + } + return std::make_tuple(lhs, rhs); +} + +ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) { + ir::type *src_ty = input->get_type(); + if (src_ty->is_block_ty()) + dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); + if(src_ty == dst_ty) + return input; + ir::type *src_sca_ty = src_ty->get_scalar_ty(); + ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); + // FP Truncation + bool truncate_fp = src_sca_ty->is_floating_point_ty() && + dst_sca_ty->is_floating_point_ty() && + src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); + if (truncate_fp) + return builder->create_fp_trunc(input, dst_ty); + // FP Extension + bool ext_fp = src_sca_ty->is_floating_point_ty() && + dst_sca_ty->is_floating_point_ty() && + src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width(); + if (ext_fp) + return builder->create_fp_ext(input, dst_ty); + // Int cast + if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && + src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth()) + return builder->create_int_cast(input, dst_ty, true); + // Float -> Int + if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ + if(dst_sca_ty->is_bool_ty()) + return builder->create_fp_to_ui(input, dst_ty); + else + return builder->create_fp_to_si(input, dst_ty); + } + // int -> Float + if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){ + if(src_sca_ty->is_bool_ty()) + return builder->create_ui_to_fp(input, dst_ty); + else + return builder->create_si_to_fp(input, dst_ty); + } + // Ptr -> Ptr + if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) + return builder->create_cast(ir::BitCast, input, dst_ty); + // * -> Bool + if (dst_sca_ty->is_bool_ty()) { + if (src_sca_ty->is_pointer_ty()) + input = cast(input, builder->get_int64_ty(), builder); + ir::value *other = builder->get_int64(0); + if (src_ty->is_bool_ty()) + other = builder->create_splat(other, src_ty->get_block_shapes()); + return builder->create_icmpNE(input, other); + } + return throw_unreachable("cast"); +} + +//===----------------------------------------------------------------------===// +// Memory Operators +//===----------------------------------------------------------------------===// + +ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder* builder) { + if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) + throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); + if(ptr->get_type()->is_block_ty()){ + if(mask){ + mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); + } + if(other){ + other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder); + other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); + } + } + if (!mask && !other) + return builder->create_load(ptr); + if (!mask) + throw std::runtime_error("`other` cannot be provided without `mask`"); + ir::type *elt_ty = ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(); + auto shape = ptr->get_type()->get_block_shapes(); + if(!other){ + other = ir::undef_value::get(elt_ty); + if(ptr->get_type()->is_block_ty()) + other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); + } + return builder->create_masked_load(ptr, mask, other); +} + +ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { + if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) + throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); + if(ptr->get_type()->is_block_ty()) + val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); + if(mask) + mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); + ir::type *ptr_ty = ptr->get_type(); + val = dispatch::cast(val, ptr_ty->get_scalar_ty()->get_pointer_element_ty(), builder); + if (!mask) + return builder->create_store(ptr, val); + if(!mask->get_type()->get_scalar_ty()->is_bool_ty()) + throw semantic_error("Mask must have boolean scalar type"); + return builder->create_masked_store(ptr, val, mask); +} + +ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){ + return builder->create_atomic_cas(ptr, cmp, val); +} + +ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder){ + return builder->create_atomic_exch(ptr, val); +} + +//===----------------------------------------------------------------------===// +// Linear Algebra +//===----------------------------------------------------------------------===// + +ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) { + ir::value *_0 = builder->get_float32(0); + unsigned M = lhs->get_type()->get_block_shapes()[0]; + unsigned N = rhs->get_type()->get_block_shapes()[1]; + _0 = builder->create_splat(_0, {M, N}); + return builder->create_dot(lhs, rhs, _0); +} + + +//===----------------------------------------------------------------------===// +// Indexing +//===----------------------------------------------------------------------===// + +ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){ + condition = dispatch::cast(condition, builder->get_int1_ty(), builder); + if(condition->get_type()->is_block_ty()){ + x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder); + y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder); + } + if(x->get_type()->get_scalar_ty() != y->get_type()->get_scalar_ty()) + throw_incompatible_types(x->get_type()->get_scalar_ty(), y->get_type()->get_scalar_ty()); + return builder->create_select(condition, x, y); +} + + +//===----------------------------------------------------------------------===// +// Reductions +//===----------------------------------------------------------------------===// + +ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, + ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + if (scalar_ty->is_floating_point_ty()) + return builder->create_reduce(input, FLOAT_OP, axis); + else if (scalar_ty->is_integer_ty()) + return builder->create_reduce(input, INT_OP, axis); + return throw_unreachable(name); +} + +ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); +} + +ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); +} + +ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); +} + + +//===----------------------------------------------------------------------===// +// Math +//===----------------------------------------------------------------------===// + +ir::value *dispatch::exp(ir::value *x, ir::builder *builder) { + return builder->create_exp(x); +} + +ir::value *dispatch::log(ir::value *x, ir::builder *builder) { + return builder->create_log(x); +} + +ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { + return builder->create_sqrt(x); +} + + +// + +ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ + ir::instruction* i = dynamic_cast(x); + if(!i) + throw_unreachable("multiple_of"); + i->set_metadata(ir::metadata::multiple_of, value); + return i; +} + +ir::value *dispatch::debug_barrier(ir::builder *builder) { + return builder->create_barrier(); +} + + +} +} diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 8e197648d..8da0be6ef 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -1,4 +1,5 @@ #include +#include #include "triton/ir/context.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" @@ -30,9 +31,9 @@ void instruction::erase_from_parent() { } bool instruction::has_tile_result_or_op() { - bool result = get_type()->is_tile_ty(); + bool result = get_type()->is_block_ty(); for(unsigned i = 0; i < get_num_operands(); i++) - result |= get_operand(i)->get_type()->is_tile_ty(); + result |= get_operand(i)->get_type()->is_block_ty(); return result; } @@ -209,8 +210,8 @@ cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value * type* cmp_inst::make_cmp_result_type(type *ty){ type* int1_ty = type::get_int1_ty(ty->get_context()); - if (tile_type* tile_ty = dynamic_cast(ty)) - return tile_type::get_same_shapes(int1_ty, tile_ty); + if (block_type* tile_ty = dynamic_cast(ty)) + return block_type::get_same_shapes(int1_ty, tile_ty); return int1_ty; } @@ -279,7 +280,7 @@ std::string cast_inst::repr_impl() const { } // TODO bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) { - assert(arg->get_type()->is_tile_ty() == ty->is_tile_ty()); + assert(arg->get_type()->is_block_ty() == ty->is_block_ty()); return true; } @@ -383,11 +384,11 @@ type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vec unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space(); type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space); // Tile GEP - if(ty->is_tile_ty()) - return tile_type::get_same_shapes(ptr_ty, ty); + if(ty->is_block_ty()) + return block_type::get_same_shapes(ptr_ty, ty); for(value *idx : idx_list) - if (idx->get_type()->is_tile_ty()) - return tile_type::get_same_shapes(ptr_ty, ty); + if (idx->get_type()->is_block_ty()) + return block_type::get_same_shapes(ptr_ty, ty); // Scalar GEP return ptr_ty; } @@ -440,8 +441,8 @@ load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::str type *load_inst::get_pointee_type(type *ty) { type *scalar_ty = ty->get_scalar_ty(); type *pointee_ty = scalar_ty->get_pointer_element_ty(); - if(ty->is_tile_ty()) - return tile_type::get_same_shapes(pointee_ty, ty); + if(ty->is_block_ty()) + return block_type::get_same_shapes(pointee_ty, ty); return pointee_ty; } @@ -531,14 +532,14 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask // retile_inst classes //===----------------------------------------------------------------------===// -retile_inst::retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, +retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next) - : unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { } + : unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { } // reshape -instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes, +instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes, const std::string &name, instruction *next) { return new reshape_inst(arg, INST_RESHAPE, shapes, name, next); } @@ -546,14 +547,14 @@ instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes, // splat -instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes, +instruction* splat_inst::create(value *arg, const type::block_shapes_t &shapes, const std::string &name, instruction *next) { return new splat_inst(arg, INST_SPLAT, shapes, name, next); } // broadcast -instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes, +instruction* broadcast_inst::create(value *arg, const type::block_shapes_t &shapes, const std::string &name, instruction *next) { return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next); } @@ -610,20 +611,20 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C, ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector perm) { // get argument shapes - ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes(); + ir::block_type::block_shapes_t arg_shapes = ty->get_block_shapes(); // permutate argument shapes perm = init_perm(ty, perm); - ir::tile_type::tile_shapes_t res_shapes = arg_shapes; + ir::block_type::block_shapes_t res_shapes = arg_shapes; for(size_t i = 0; i < perm.size(); i++) res_shapes[i] = arg_shapes[perm[i]]; // construct type - return tile_type::get(ty->get_scalar_ty(), res_shapes); + return block_type::get(ty->get_scalar_ty(), res_shapes); } std::vector trans_inst::init_perm(ir::type* ty, const std::vector& perm) { if(!perm.empty()) return perm; - auto size = ty->get_tile_shapes().size(); + auto size = ty->get_block_shapes().size(); std::vector result; result.push_back(size - 1); for(size_t i = 0; i < size - 1; i++) @@ -682,13 +683,13 @@ std::string reduce_inst::to_str(op_t op) { } type* reduce_inst::get_res_type(value *arg, unsigned axis) { - ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes(); + ir::block_type::block_shapes_t shapes = arg->get_type()->get_block_shapes(); shapes.erase(shapes.begin() + axis); type *scalar_ty = arg->get_type()->get_scalar_ty(); if(shapes.empty()) // shapes.push_back(1); return scalar_ty; - return tile_type::get(scalar_ty, shapes); + return block_type::get(scalar_ty, shapes); } reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) @@ -733,13 +734,13 @@ instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std: } // get_num_program -get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next) +get_num_programs_inst::get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next) : builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){ } -instruction* get_num_program_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { - return new get_num_program_inst(type::get_int32_ty(ctx), axis, name, next); +instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { + return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next); } @@ -863,7 +864,7 @@ make_range *make_range::create(constant_int *first, constant_int *last) { assert(first->get_type()->is_integer_ty()); assert(first->get_type() == last->get_type()); assert(((constant_int*)first)->get_value() == 0); - type *ty = tile_type::get(first->get_type(), {(unsigned)last->get_value()}); + type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value()}); return new make_range(ty, first, last); } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 334768d27..1665bffb7 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -10,8 +10,8 @@ namespace triton{ namespace ir{ /* Module */ -module::module(const std::string &name) - : name_(name), builder_(context_) { +module::module(const std::string &name, builder &builder) + : name_(name), builder_(builder) { sealed_blocks_.insert(nullptr); } @@ -19,10 +19,6 @@ ir::builder& module::get_builder() { return builder_; } -ir::context& module::get_context() { - return context_; -} - void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ values_[val_key_t{name, block}] = value; auto it = metadatas_.find(name); @@ -98,7 +94,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block ir::value *result; bool is_const = const_.find(name) != const_.end(); auto &preds = block->get_predecessors(); - ir::type *ty = get_scope().types.at(name); + ir::type *ty = get_scope().get_type(name); if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ incomplete_phis_[block][name] = make_phi(ty, 1, block); result = (ir::value*)incomplete_phis_[block][name]; diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 8300a32c4..9d985dc25 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -14,7 +14,7 @@ namespace ir{ // attributes type *type::get_scalar_ty() const { - if(is_tile_ty()) + if(is_block_ty()) return get_tile_element_ty(); return const_cast(this); } @@ -28,7 +28,7 @@ unsigned type::get_primitive_size_in_bits() const { case FP128TyID: return 128; case PPC_FP128TyID: return 128; case IntegerTyID: return ((integer_type*)(this))->get_bitwidth(); - case TileTyID: return ((tile_type*)(this))->get_bitwidth(); + case BlockTyID: return ((block_type*)(this))->get_bitwidth(); default: return 0; } } @@ -37,19 +37,19 @@ unsigned type::get_integer_bitwidth() const { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } unsigned type::get_tile_bitwidth() const -{ return ((tile_type*)(this))->get_bitwidth(); } +{ return ((block_type*)(this))->get_bitwidth(); } unsigned type::get_fp_mantissa_width() const { id_t id = get_scalar_ty()->id_; assert(is_floating_point_ty() && "Not a floating point type!"); - if (id == HalfTyID) return 11; - if (id == FloatTyID) return 24; + if (id == HalfTyID) return 10; + if (id == FloatTyID) return 23; if (id == DoubleTyID) return 53; throw std::runtime_error("unreachable"); } type* type::get_tile_element_ty() const { - assert(is_tile_ty()); + assert(is_block_ty()); return contained_tys_[0]; } @@ -62,31 +62,31 @@ type * type::get_pointer_element_ty() const { type *ptr_ty = get_scalar_ty(); assert(ptr_ty->is_pointer_ty()); type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty(); - if(is_tile_ty()) - return tile_type::get_same_shapes(scalar_ty, (type*)this); + if(is_block_ty()) + return block_type::get_same_shapes(scalar_ty, (type*)this); return scalar_ty; } -const type::tile_shapes_t &type::get_tile_shapes() const { - assert(is_tile_ty()); - return ((tile_type*)this)->get_shapes(); +type::block_shapes_t type::get_block_shapes() const { + assert(is_block_ty()); + return ((block_type*)this)->get_shapes(); } const size_t type::get_tile_rank() const { - return get_tile_shapes().size(); + return get_block_shapes().size(); } const size_t type::get_tile_ranks1() const { int ret = 0; - for(int s: get_tile_shapes()) + for(int s: get_block_shapes()) ret += s > 1; return ret; } unsigned type::get_tile_num_elements() const { - const tile_shapes_t& shapes = get_tile_shapes(); + const block_shapes_t& shapes = get_block_shapes(); unsigned result = 1; for(auto shape: shapes) result *= shape; @@ -112,7 +112,7 @@ bool type::is_sized() const { return true; } // tile types are sizes - if(is_tile_ty()) + if(is_block_ty()) return get_scalar_ty()->is_sized(); return false; } @@ -160,12 +160,12 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){ //===----------------------------------------------------------------------===// type* composite_type::get_type_at_index(value *) const{ - assert(is_tile_ty()); + assert(is_block_ty()); return get_scalar_ty(); } bool composite_type::index_valid(value *idx) const{ - assert(is_tile_ty()); + assert(is_block_ty()); return idx->get_type()->is_int_or_tileint_ty(); } @@ -173,41 +173,41 @@ bool composite_type::index_valid(value *idx) const{ // tile_type class //===----------------------------------------------------------------------===// -tile_type::tile_type(type *ty, const tile_shapes_t &shapes) - : composite_type(ty->get_context(), TileTyID), shapes_(shapes) { +block_type::block_type(type *ty, const block_shapes_t &shapes) + : composite_type(ty->get_context(), BlockTyID), shapes_(shapes) { contained_tys_.push_back(ty); } -bool tile_type::is_valid_elt_ty(type *ty) { +bool block_type::is_valid_elt_ty(type *ty) { return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty(); } -unsigned tile_type::get_num_elements() const { +unsigned block_type::get_num_elements() const { unsigned res = 1; for(auto shape: shapes_) res *= shape; return res; } -unsigned tile_type::get_bitwidth() const { +unsigned block_type::get_bitwidth() const { return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits(); } -tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) { +block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) { assert(elt_ty && "Can't get a tile of type!"); assert(shapes.size() && "Can't create a tile with empty shapes!"); assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!"); // look-up context_impl *impl = elt_ty->get_context().p_impl.get(); - tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)]; + block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; if(!entry) - entry = new tile_type(elt_ty, shapes); + entry = new block_type(elt_ty, shapes); return entry; } -tile_type* tile_type::get_same_shapes(type *ty, type *ref){ - assert(ref->is_tile_ty()); - return get(ty, ref->get_tile_shapes()); +block_type* block_type::get_same_shapes(type *ty, type *ref){ + assert(ref->is_block_ty()); + return get(ty, ref->get_block_shapes()); } //===----------------------------------------------------------------------===// diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc deleted file mode 100644 index 62c6a0b6c..000000000 --- a/lib/lang/ast.cc +++ /dev/null @@ -1,1118 +0,0 @@ -#include "triton/lang/ast.h" -#include "triton/lang/error.h" -#include "triton/lang/evaluator.h" -#include "triton/lang/mem_pool.h" -#include "triton/lang/parser.h" -#include "triton/lang/token.h" - - -static MemPoolImp binaryOpPool; -static MemPoolImp transOpPool; -static MemPoolImp conditionalOpPool; -static MemPoolImp funcCallPool; -static MemPoolImp initializationPool; -static MemPoolImp objectPool; -static MemPoolImp identifierPool; -static MemPoolImp enumeratorPool; -static MemPoolImp constantPool; -static MemPoolImp tempVarPool; -static MemPoolImp unaryOpPool; -static MemPoolImp emptyStmtPool; -static MemPoolImp ifStmtPool; -static MemPoolImp forStmtPool; -static MemPoolImp jumpStmtPool; -static MemPoolImp returnStmtPool; -static MemPoolImp labelStmtPool; -static MemPoolImp compoundStmtPool; -static MemPoolImp funcDefPool; - - -/* - * Accept - */ - -void Declaration::Accept(Visitor* v) { - v->VisitDeclaration(this); -} - - -void EmptyStmt::Accept(Visitor* v) { - // Nothing to do -} - - -void LabelStmt::Accept(Visitor* v) { - v->VisitLabelStmt(this); -} - - -void IfStmt::Accept(Visitor* v) { - v->VisitIfStmt(this); -} - -void ForStmt::Accept(Visitor* v) { - v->VisitForStmt(this); -} - - -void JumpStmt::Accept(Visitor* v) { - v->VisitJumpStmt(this); -} - - -void ReturnStmt::Accept(Visitor* v) { - v->VisitReturnStmt(this); -} - - -void CompoundStmt::Accept(Visitor* v) { - v->VisitCompoundStmt(this); -} - - -void BinaryOp::Accept(Visitor* v) { - v->VisitBinaryOp(this); -} - - -void UnaryOp::Accept(Visitor* v) { - v->VisitUnaryOp(this); -} - -void TransOp::Accept(Visitor* v) { - v->VisitTransOp(this); -} - -void ConditionalOp::Accept(Visitor* v) { - v->VisitConditionalOp(this); -} - - -void FuncCall::Accept(Visitor* v) { - v->VisitFuncCall(this); -} - - -void Identifier::Accept(Visitor* v) { - v->VisitIdentifier(this); -} - - -void Object::Accept(Visitor* v) { - v->VisitObject(this); -} - - -void Constant::Accept(Visitor* v) { - v->VisitConstant(this); -} - - -void Enumerator::Accept(Visitor* v) -{ - v->VisitEnumerator(this); -} - - -void TempVar::Accept(Visitor* v) { - v->VisitTempVar(this); -} - - -void FuncDef::Accept(Visitor* v) { - v->VisitFuncDef(this); -} - - -void TranslationUnit::Accept(Visitor* v) { - v->VisitTranslationUnit(this); -} - - -// Casting array to pointer, function to pointer to function -Expr* Expr::MayCast(Expr* expr) { - auto type = Type::MayCast(expr->Type()); - // If the types are equal, no need cast - if (type != expr->Type()) { // Pointer comparison is enough - return UnaryOp::New(Token::CAST, expr, type); - } - return expr; -} - - -Expr* Expr::MayCast(Expr* expr, QualType desType) { - expr = MayCast(expr); - auto srcType = expr->Type(); - if (desType->ToPointer() && srcType->ToPointer()) - if (desType->IsVoidPointer() || srcType->IsVoidPointer()) - return expr; - if (!desType->Compatible(*expr->Type())) - expr = UnaryOp::New(Token::CAST, expr, desType); - return expr; -} - -// Extract the operand's scalar type if possible -// and emit an error otherwise -::Type* Expr::TryExtractScalarType(Expr* loc, Expr *operand) { - auto scalType = operand->Type()->ScalarType(); - if(!scalType) - Error(loc, "expect tile or scalar operand"); - return scalType; -} - -// If operand is a tile, return a tile of the same shape and -// provided element type -// If operand is a scalar, return provided element type -// directly -::Type* Expr::ScalarOrLikeTile(Expr* operand, ::Type* ty) { - assert(ty->IsScalar()); - ::Type *retTy = ty; - if(TileType *T = operand->Type()->ToTile()) - retTy = TileType::New(T->Shape(), retTy); - return retTy; -} - -BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) { - return New(tok, tok->tag_, lhs, rhs); -} - - -BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) { - switch (op) { - case ',': case '.': case '=': - case '*': case '/': case '%': - case '+': case '-': case '&': - case '^': case '|': case '<': - case '>': - case Token::LEFT: - case Token::RIGHT: - case Token::LE: - case Token::GE: - case Token::EQ: - case Token::NE: - case Token::LOGICAL_AND: - case Token::LOGICAL_OR: - case Token::ELLIPSIS: - case Token::MATMUL: - case Token::MASKED_DEREF: - break; - default: - assert(0); - } - - auto ret = new (binaryOpPool.Alloc()) BinaryOp(tok, op, lhs, rhs); - ret->pool_ = &binaryOpPool; - - ret->TypeChecking(); - return ret; -} - - -ArithmType* BinaryOp::Convert() { - // Both lhs and rhs are ensured to be have arithmetic scalar type - auto lhsType = lhs_->Type()->ScalarType()->ToArithm(); - auto rhsType = rhs_->Type()->ScalarType()->ToArithm(); - assert(lhsType && rhsType); - auto maxType = ArithmType::MaxType(lhsType, rhsType); - if (lhsType != maxType) { // Pointer comparation is enough! - lhs_ = UnaryOp::New(Token::CAST, lhs_, ScalarOrLikeTile(lhs_, maxType)); - } - if (rhsType != maxType) { - rhs_ = UnaryOp::New(Token::CAST, rhs_, ScalarOrLikeTile(rhs_, maxType)); - } - return maxType; -} - -void BinaryOp::Broadcast(Expr* loc, Expr *&lhs, Expr *&rhs, QualType& type) { - auto lhsType = lhs->Type()->ToTile(); - auto rhsType = rhs->Type()->ToTile(); - auto eleType = type->ScalarType(); - assert(eleType); - if(!lhsType && !rhsType) - return ; - else if(lhsType && !rhsType){ - type = TileType::New(lhsType->Shape(), eleType); - ::Type* rtype = TileType::New(lhsType->Shape(), rhs->Type()->ScalarType()); - rhs = UnaryOp::New(Token::CAST, rhs, rtype); - } - else if(!lhsType && rhsType){ - type = TileType::New(rhsType->Shape(), eleType); - ::Type* ltype = TileType::New(rhsType->Shape(), lhs->Type()->ScalarType()); - lhs = UnaryOp::New(Token::CAST, lhs, ltype); - - } - else { - auto lhsShape = lhsType->Shape(); - auto rhsShape = rhsType->Shape(); - auto lhsRank = lhsShape.size(); - auto rhsRank = rhsShape.size(); - auto retRank = std::max(lhsRank, rhsRank); - // pad to the left until shapes have the same rank - while(lhsShape.size() < retRank) - lhsShape.insert(lhsShape.begin(), 1); - while(rhsShape.size() < retRank) - rhsShape.insert(rhsShape.begin(), 1); - // broadcast if possible - TileType::ShapeInt retShape(retRank); - for(size_t i = 0; i < retRank; i++) { - if(lhsShape[i] == 1) - retShape[i] = rhsShape[i]; - else if(rhsShape[i] == 1) - retShape[i] = lhsShape[i]; - else if(lhsShape[i] == rhsShape[i]) - retShape[i] = lhsShape[i]; - else - Error(loc, "cannot broadcast dimension %d " - "for operands of shape %d and %d", - i, lhsShape[i], rhsShape[i]); - } - ::Type* ltype = TileType::New(retShape, lhsType->ScalarType()); - ::Type* rtype = TileType::New(retShape, rhsType->ScalarType()); - type = TileType::New(retShape, eleType); - if(retShape != lhsShape) - lhs = UnaryOp::New(Token::CAST, lhs, ltype); - if(retShape != rhsShape) - rhs = UnaryOp::New(Token::CAST, rhs, rtype); - } -} - -/* - * Type checking - */ - -void Expr::EnsureCompatibleOrVoidPointer(const QualType lhs, - const QualType rhs) const { - if (lhs->ToPointer() && rhs->ToPointer() && - (lhs->IsVoidPointer() || rhs->IsVoidPointer())) { - return; - } - EnsureCompatible(lhs, rhs); -} - - -void Expr::EnsureCompatible(const QualType lhs, const QualType rhs) const { - if (!lhs->Compatible(*rhs)) - Error(this, "incompatible types"); -} - - -void BinaryOp::TypeChecking() { - switch (op_) { - case '.': - return MemberRefOpTypeChecking(); - - case '*': - case '/': - case '%': - return MultiOpTypeChecking(); - - case '+': - case '-': - return AdditiveOpTypeChecking(); - - case Token::LEFT: - case Token::RIGHT: - return ShiftOpTypeChecking(); - - case '<': - case '>': - case Token::LE: - case Token::GE: - return RelationalOpTypeChecking(); - - case Token::EQ: - case Token::NE: - return EqualityOpTypeChecking(); - - case '&': - case '^': - case '|': - return BitwiseOpTypeChecking(); - - case Token::LOGICAL_AND: - case Token::LOGICAL_OR: - return LogicalOpTypeChecking(); - - case '=': - return AssignOpTypeChecking(); - - case ',': - return CommaOpTypeChecking(); - - case Token::ELLIPSIS: - return RangeOpTypeChecking(); - - case Token::MATMUL: - return MatmulOpTypeChecking(); - - case Token::MASKED_DEREF: - return MaskedDerefOpTypeChecking(); - - default: - assert(0); - } -} - - -void BinaryOp::CommaOpTypeChecking() { - type_ = rhs_->Type(); -} - - -void BinaryOp::SubScriptingOpTypeChecking() { - assert(false); -} - - -void BinaryOp::MemberRefOpTypeChecking() { - type_ = rhs_->Type(); -} - - -void BinaryOp::MultiOpTypeChecking() { - ::Type* lhsScalType = lhs_->Type()->ScalarType(); - ::Type* rhsScalType = rhs_->Type()->ScalarType(); - if(!lhsScalType || !rhsScalType) { - Error(this, "operands should have type or scalar type"); - } - if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) { - Error(this, "operands should have arithmetic type"); - } - if ('%' == op_ && - !(lhsScalType->IsInteger() && rhsScalType->IsInteger())) { - Error(this, "operands of '%%' should be integers"); - } - type_ = Convert(); - Broadcast(this, lhs_, rhs_, type_); -} - - -/* - * Additive operator is only allowed between: - * 1. arithmetic types (bool, interger, floating) - * 2. pointer can be used: - * 1. lhs of MINUS operator, and rhs must be integer or pointer; - * 2. lhs/rhs of ADD operator, and the other operand must be integer; - * 3. tiles can be used: - * 1. the scalar type of lhs/rhs satisfy the above requirements - * 2. lhs/rhs that have identical shape - * 3. lhs/rhs that can be broadcast as per numpy-like semantics - */ -void BinaryOp::AdditiveOpTypeChecking() { - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - auto lhsPtrType = lhsScalType->ToPointer(); - auto rhsPtrType = rhsScalType->ToPointer(); - if (lhsPtrType) { - if (op_ == '-') { - if (rhsPtrType) { - if (!lhsPtrType->Compatible(*rhsPtrType)) - Error(this, "invalid operands to binary -"); - type_ = ArithmType::New(T_LONG); // ptrdiff_t - } else if (!rhsScalType->IsInteger()) { - Error(this, "invalid operands to binary -"); - } else { - type_ = lhsPtrType; - } - } else if (!rhsScalType->IsInteger()) { - Error(this, "invalid operands to binary +"); - } else { - type_ = lhsPtrType; - } - } else if (rhsPtrType) { - if (op_ == '+' && !lhsScalType->IsInteger()) { - Error(this, "invalid operands to binary '+'"); - } else if (op_ == '-' && !lhsPtrType) { - Error(this, "invalid operands to binary '-'"); - } - type_ = op_ == '-' ? ArithmType::New(T_LONG): rhsScalType; - std::swap(lhs_, rhs_); // To simplify code gen - } else { - if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) { - Error(this, "invalid operands to binary %s", tok_->str_.c_str()); - } - type_ = Convert(); - } - Broadcast(this, lhs_, rhs_, type_); -} - -void BinaryOp::RangeOpTypeChecking() { - auto lhsType = lhs_->Type()->ToArithm(); - auto rhsType = rhs_->Type()->ToArithm(); - if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger()) - Error(this, "expect integers for range operator"); - lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lhsType)); - rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType)); - long begin = Evaluator().Eval(lhs_); - long end = Evaluator().Eval(rhs_); - int len = static_cast(end - begin); - if(len < 0) - Error(this, "range cannot be negative"); - TileType* ret = TileType::New(TileType::ShapeInt{len}, lhs_->Type()); - if(!ret->CheckPow2NumEl()) - Error(this, "range must have power of 2 number of elements"); - type_ = ret; -} - -void BinaryOp::MaskedDerefOpTypeChecking() { -// auto lhsTileType = lhs_->Type()->ToTile(); -// auto rhsTileType = rhs_->Type()->ToTile(); - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - auto lhsType = lhsScalType->ToArithm(); - auto rhsType = rhsScalType->ToPointer(); - if (!rhsType) - Error(this, "pointer expected for deref pointer in operator '*?'"); - if (!lhsType || (lhsType && !lhsType->IsBool())) - Error(this, "bool expected for deref mask in operator '*?'"); - type_ = ScalarOrLikeTile(rhs_, rhsType->Derived().GetPtr()); - Broadcast(this, lhs_, rhs_, type_); -} - -void BinaryOp::MatmulOpTypeChecking() { - auto lhsType = lhs_->Type()->ToTile(); - auto rhsType = rhs_->Type()->ToTile(); - if(!lhsType || !rhsType) - Error(this, "expect tile operands for matrix multiplication"); - auto lhsShape = lhsType->Shape(); - auto rhsShape = rhsType->Shape(); - size_t lhsRank = lhsShape.size(); - size_t rhsRank = rhsShape.size(); - if(lhsRank != rhsRank) - Error(this, "matrix multiplication operands have incompatible rank" - "%d and %d", lhsRank, rhsRank); - for(int d = 2; d < lhsRank; d++) - if(lhsShape[d] != rhsShape[d]) - Error(this, "matrix multiplication operands have incompatible batch dimension" - "%d and %d for axis %d", lhsShape[d], rhsShape[d], d); - if(lhsShape[1] != rhsShape[0]) - Error(this, "matrix multiplication operands have incompatible inner dimension" - " %d and %d", lhsShape[1], rhsShape[0]); - // ret shape - TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]}; - for(int d = 2; d < lhsRank; d++) - retShape.push_back(lhsShape[d]); - QualType retType = lhsType->Derived(); - if(retType != rhsType->Derived()) - Error(this, "matrix multiplication operands have incompatible data types"); - ArithmType* ScalType = lhsType->ScalarType()->ToArithm(); - if(ScalType->Tag() & T_HALF) - ScalType = ArithmType::New(T_FLOAT); - type_ = TileType::New(retShape, ScalType); -} - -void BinaryOp::ShiftOpTypeChecking() { - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - auto lhsType = lhsScalType->ToArithm(); - auto rhsType = rhsScalType->ToArithm(); - if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger()) - Error(this, "expect integers for shift operator"); - lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType))); - rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType))); - type_ = lhs_->Type(); - Broadcast(this, lhs_, rhs_, type_); -} - - -void BinaryOp::RelationalOpTypeChecking() { - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) { - EnsureCompatible(lhsScalType, rhsScalType); - } else { - if (!lhsScalType->IsReal() || !rhsScalType->IsReal()) { - Error(this, "expect real type of operands"); - } - Convert(); - } - type_ = ArithmType::New(T_BOOL); - Broadcast(this, lhs_, rhs_, type_); -} - - -void BinaryOp::EqualityOpTypeChecking() { - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) { - EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType); - } else { - if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) - Error(this, "invalid operands to binary %s", tok_->str_.c_str()); - Convert(); - } - type_ = ArithmType::New(T_BOOL); - Broadcast(this, lhs_, rhs_, type_); -} - - -void BinaryOp::BitwiseOpTypeChecking() { - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger()) - Error(this, "operands of '&' should be integer"); - type_ = Convert(); - Broadcast(this, lhs_, rhs_, type_); -} - - -void BinaryOp::LogicalOpTypeChecking() { - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar()) - Error(this, "the operand should be arithmetic type or pointer"); - type_ = ArithmType::New(T_BOOL); - Broadcast(this, lhs_, rhs_, type_); -} - - -void BinaryOp::AssignOpTypeChecking() { - if (lhs_->IsConstQualified()) { - Error(lhs_, "left operand of '=' is const qualified"); - } else if (!lhs_->IsLVal()) { - Error(lhs_, "lvalue expression expected"); - } - - ::Type* lhsScalType = TryExtractScalarType(this, lhs_); - ::Type* rhsScalType = TryExtractScalarType(this, rhs_); - if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) { - EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType); - } - - // The other constraints are lefted to cast operator - rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType)); - type_ = lhs_->Type(); - rhs_ = UnaryOp::New(Token::CAST, rhs_, type_); -} - -/* - * Unary Operators - */ - -UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) { - auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info); - ret->pool_ = &unaryOpPool; - - ret->TypeChecking(); - return ret; -} - - -int UnaryOp::encodeRed(int ax, int tag) { - int result = 0; - result |= ax; - result |= tag << 16; - return result; -} - -void UnaryOp::decodeRed(int info, int& ax, int& tag) { - ax = info & 0x0000FFFF; - tag = (info & 0xFFFF0000) >> 16; -} - -bool UnaryOp::IsLVal() { - // Only deref('*') could be lvalue; - return op_ == Token::DEREF; -} - - -::Type* UnaryOp::Convert() { - auto scalType = operand_->Type()->ScalarType(); - assert(scalType); - auto arithmType = scalType->ToArithm(); - assert(arithmType); - if (arithmType->IsInteger()) - arithmType = ArithmType::IntegerPromote(arithmType); - ::Type* retType = ScalarOrLikeTile(operand_, arithmType); - operand_ = Expr::MayCast(operand_, retType); - return retType; -} - - -void UnaryOp::TypeChecking() { - switch (op_) { - case Token::POSTFIX_INC: - case Token::POSTFIX_DEC: - case Token::PREFIX_INC: - case Token::PREFIX_DEC: - return IncDecOpTypeChecking(); - - case Token::ADDR: - return AddrOpTypeChecking(); - - case Token::DEREF: - return DerefOpTypeChecking(); - - case Token::PLUS: - case Token::MINUS: - case '~': - case '!': - return UnaryArithmOpTypeChecking(); - - case Token::BITCAST: - return BitcastOpTypeChecking(); - - case Token::CAST: - return CastOpTypeChecking(); - - case Token::REDUCE: - return ReduceOpTypeChecking(); - - case Token::EXP: - case Token::LOG: - case Token::SQRTF: - return IntrinsicOpTypeChecking(); - - default: - assert(false); - } -} - -void UnaryOp::IncDecOpTypeChecking() { - if (operand_->IsConstQualified()) { - Error(this, "increment/decrement of const qualified expression"); - } else if (!operand_->IsLVal()) { - Error(this, "lvalue expression expected"); - } - auto scalType = TryExtractScalarType(this, operand_); - if (!scalType->IsReal() && !scalType->ToPointer()) { - Error(this, "expect operand of real type or pointer"); - } - type_ = operand_->Type(); -} - - -void UnaryOp::AddrOpTypeChecking() { - auto funcType = operand_->Type()->ToFunc(); - if (funcType == nullptr && !operand_->IsLVal()) - Error(this, "expression must be an lvalue or function designator"); - if(operand_->Type()->IsTile()) - Error(this, "cannot take the address of a tile"); - type_ = PointerType::New(operand_->Type()); -} - - -void UnaryOp::DerefOpTypeChecking() { - auto scalType = TryExtractScalarType(this, operand_); - auto pointerType = scalType->ToPointer(); - if (!pointerType) - Error(this, "pointer expected for deref operator '*'"); - type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr()); -} - -void UnaryOp::ReduceOpTypeChecking() { - int ax, tag; - decodeRed(info_, ax, tag); - auto tileType = operand_->Type()->ToTile(); - if(!tileType) - Error(this, "array expected for reduction operation"); - auto shape = tileType->Shape(); - shape.erase(shape.begin() + ax); - if(shape.empty()) - type_ = tileType->Derived(); - else - type_ = TileType::New(shape, tileType->Derived()); -} - -void UnaryOp::UnaryArithmOpTypeChecking() { - auto scalType = TryExtractScalarType(this, operand_); - if (Token::PLUS == op_ || Token::MINUS == op_) { - if (!scalType->ToArithm()) - Error(this, "Arithmetic type expected"); - Convert(); - type_ = operand_->Type(); - } else if ('~' == op_) { - if (!scalType->IsInteger()) - Error(this, "integer expected for operator '~'"); - Convert(); - type_ = operand_->Type(); - } else if (!scalType->IsScalar()) { - Error(this, "arithmetic type or pointer expected for operator '!'"); - } else { - type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_INT)); - } -} - -void UnaryOp::BitcastOpTypeChecking() { - auto operandType = Type::MayCast(operand_->Type()); - if(type_->Width() != operandType->Width()) - Error(this, "cannot bitcast to type of different width"); -} - -void UnaryOp::CastOpTypeChecking() { - auto operandType = Type::MayCast(operand_->Type()); - // The type_ has been initiated to dest type - if (type_->ToVoid()) { - // The expression becomes a void expression - } else if(type_->IsTile() || operandType->IsTile()) { - /* Broadcasting rules: - * 1. Tiles with 1 element can be converted to scalar - * 2. Scalar can be converted to tiles of any shapes - * 3. Tiles can be converted to another tile only if the - * mismatching dimensions are unitary - */ - if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1) - Error(this, "tile with more than one element cannot be casted to scalar"); - if(type_->IsTile() && operandType->IsTile()){ - if(!type_->ToTile()->CheckPow2NumEl()) - Error(this, "tile must have power of 2 number of elements"); - auto operandShape = operandType->ToTile()->Shape(); - auto shape = type_->ToTile()->Shape(); - // this is a shape downcast - if(operandShape.size() > shape.size()){ - size_t operandNumel = 1; - size_t numel = 1; - for(auto x: operandShape) - operandNumel *= x; - for(auto x: shape) - numel *= x; - if(operandNumel != numel) - Error(this, "cast cannot change number of elements"); - return; - } - // this is a shape upcast - while(operandShape.size() < shape.size()) - operandShape.insert(operandShape.begin(), 1); - for(size_t i = 0; i < shape.size(); i++) { - if(shape[i] != 1 && operandShape[i] != 1 && shape[i] != operandShape[i]) - Error(this, "cannot broadcast dimension %d " - "for operands of shape %d and %d", - i, shape[i], operandShape[i]); - } - } - } else if (!type_->IsScalar() || !operandType->IsScalar()) { - if (!type_->Compatible(*operandType)) - Error(this, "the cast type should be arithemetic type or pointer"); - } else if (type_->IsFloat() && operandType->ToPointer()) { - Error(this, "cannot cast a pointer to floating"); - } else if (type_->ToPointer() && operandType->IsFloat()) { - Error(this, "cannot cast a floating to pointer"); - } -} - -void UnaryOp::IntrinsicOpTypeChecking() { - type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_FLOAT)); -} - -/* - * Transposition Operator - */ -void TransOp::TypeChecking() { - auto tileType = operand_->Type()->ToTile(); - if(!tileType) - Error(this, "tile expected for transposition operator '^'"); - auto opShape = tileType->Shape(); - if(perm_.size() != opShape.size()) - Error(this, "invalid permutations"); - // permutate input shape - TileType::ShapeInt resShape(opShape.size()); - for(int d = 0; d < opShape.size(); d++) - resShape[d] = opShape[perm_[d]]; - type_ = TileType::New(resShape, tileType->Derived()); -} - -TransOp* TransOp::New(const PermInt& perm, Expr* operand) { - auto ret = new (transOpPool.Alloc()) TransOp(perm, operand); - ret->pool_ = &transOpPool; - ret->TypeChecking(); - return ret; -} - -/* - * Conditional Operator - */ - -ConditionalOp* ConditionalOp::New(const Token* tok, - Expr* cond, - Expr* exprTrue, - Expr* exprFalse) { - auto ret = new (conditionalOpPool.Alloc()) - ConditionalOp(cond, exprTrue, exprFalse); - ret->pool_ = &conditionalOpPool; - - ret->TypeChecking(); - return ret; -} - - -ArithmType* ConditionalOp::Convert() { - auto lhsType = exprTrue_->Type()->ScalarType()->ToArithm(); - auto rhsType = exprFalse_->Type()->ScalarType()->ToArithm(); - assert(lhsType && rhsType); - auto type = ArithmType::MaxType(lhsType, rhsType); - if (lhsType != type) { // Pointer comparation is enough! - exprTrue_ = UnaryOp::New(Token::CAST, exprTrue_, type); - } - if (rhsType != type) { - exprFalse_ = UnaryOp::New(Token::CAST, exprFalse_, type); - } - - return type; -} - - -void ConditionalOp::TypeChecking() { - auto condScalarType = TryExtractScalarType(this, cond_); - if (!condScalarType) { - Error(cond_->Tok(), "condition must be tile or scalar"); - } - auto lhsType = TryExtractScalarType(this, exprTrue_); - auto rhsType = TryExtractScalarType(this, exprFalse_); - if (lhsType->ToArithm() && rhsType->ToArithm()) { - type_ = Convert(); - } else { - EnsureCompatibleOrVoidPointer(lhsType, rhsType); - type_ = lhsType; - } - BinaryOp::Broadcast(this, exprFalse_, exprTrue_, type_); -} - - -/* - * Function Call - */ - -FuncCall* FuncCall::New(Expr* designator, const ArgList& args) { - auto ret = new (funcCallPool.Alloc()) FuncCall(designator, args); - ret->pool_ = &funcCallPool; - - ret->TypeChecking(); - return ret; -} - - -void FuncCall::TypeChecking() { - auto pointerType = designator_->Type()->ToPointer(); - if (pointerType) { - if (!pointerType->Derived()->ToFunc()) - Error(designator_, "called object is not a function or function pointer"); - // Convert function pointer to function type - designator_ = UnaryOp::New(Token::DEREF, designator_); - } - auto funcType = designator_->Type()->ToFunc(); - if (!funcType) { - Error(designator_, "called object is not a function or function pointer"); - } else if (!funcType->Derived()->ToVoid() && - !funcType->Derived()->Complete()) { - Error(designator_, "invalid use of incomplete return type"); - } - - auto arg = args_.begin(); - for (auto param: funcType->Params()) { - if (arg == args_.end()) - Error(this, "too few arguments for function call"); - *arg = Expr::MayCast(*arg, param->Type()); - ++arg; - } - if (arg != args_.end() && !funcType->Variadic()) - Error(this, "too many arguments for function call"); - - // C11 6.5.2.2 [6]: promote float to double if it has no prototype - while (arg != args_.end()) { - if ((*arg)->Type()->IsFloat() && (*arg)->Type()->Width() == 4) { - auto type = ArithmType::New(T_DOUBLE); - *arg = UnaryOp::New(Token::CAST, *arg, type); - } - ++arg; - } - - type_ = funcType->Derived(); -} - - -/* - * Identifier - */ - -Identifier* Identifier::New(const Token* tok, - QualType type, - enum Linkage linkage, - const AttrList &attrList) { - auto ret = new (identifierPool.Alloc()) Identifier(tok, type, linkage, attrList); - ret->pool_ = &identifierPool; - return ret; -} - - -Enumerator* Enumerator::New(const Token* tok, int val) { - auto ret = new (enumeratorPool.Alloc()) Enumerator(tok, val); - ret->pool_ = &enumeratorPool; - return ret; -} - - -Declaration* Declaration::New(Object* obj) { - auto ret = new (initializationPool.Alloc()) Declaration(obj); - ret->pool_ = &initializationPool; - return ret; -} - -void Declaration::AddInit(Initializer init) { - init.expr_ = Expr::MayCast(init.expr_, init.type_); - auto res = inits_.insert(init); - if (!res.second) { - inits_.erase(res.first); - inits_.insert(init); - } -} - - -/* - * Object - */ - -Object* Object::New(const Token* tok, - QualType type, - int storage, - enum Linkage linkage, - unsigned char bitFieldBegin, - unsigned char bitFieldWidth, - const AttrList& attrList) { - auto ret = new (objectPool.Alloc()) - Object(tok, type, storage, linkage, bitFieldBegin, bitFieldWidth, attrList); - ret->pool_ = &objectPool; - - static long id = 0; - if (ret->IsStatic() || ret->Anonymous()) - ret->id_ = ++id; - return ret; -} - - -Object* Object::NewAnony(const Token* tok, - QualType type, - int storage, - enum Linkage linkage, - unsigned char bitFieldBegin, - unsigned char bitFieldWidth, - const AttrList& attrList) { - auto ret = new (objectPool.Alloc()) - Object(tok, type, storage, linkage, bitFieldBegin, bitFieldWidth, attrList); - ret->pool_ = &objectPool; - ret->anonymous_ = true; - - static long id = 0; - if (ret->IsStatic() || ret->anonymous_) - ret->id_ = ++id; - return ret; -} - - -/* - * Constant - */ - -Constant* Constant::New(const Token* tok, int tag, long val) { - auto type = ArithmType::New(tag); - auto ret = new (constantPool.Alloc()) Constant(tok, type, val); - ret->pool_ = &constantPool; - return ret; -} - - -Constant* Constant::New(const Token* tok, int tag, double val) { - auto type = ArithmType::New(tag); - auto ret = new (constantPool.Alloc()) Constant(tok, type, val); - ret->pool_ = &constantPool; - return ret; -} - - -Constant* Constant::New(const Token* tok, int tag, const std::string* val) { - auto derived = ArithmType::New(tag); - auto type = ArrayType::New(val->size() / derived->Width(), derived); - - auto ret = new (constantPool.Alloc()) Constant(tok, type, val); - ret->pool_ = &constantPool; - - static long id = 0; - ret->id_ = ++id; - return ret; -} - - -std::string Constant::SValRepr() const { - std::vector buf(4 * sval_->size() + 1); - for (size_t i = 0; i < sval_->size(); ++i) { - int c = (*sval_)[i]; - sprintf(&buf[i * 4], "\\x%1x%1x", (c >> 4) & 0xf, c & 0xf); - } - return std::string(buf.begin(), buf.end() - 1); -} - - -/* - * TempVar - */ - -TempVar* TempVar::New(QualType type) { - auto ret = new (tempVarPool.Alloc()) TempVar(type); - ret->pool_ = &tempVarPool; - return ret; -} - - -/* - * Statement - */ - -EmptyStmt* EmptyStmt::New() { - auto ret = new (emptyStmtPool.Alloc()) EmptyStmt(); - ret->pool_ = &emptyStmtPool; - return ret; -} - - -// The else stmt could be null -IfStmt* IfStmt::New(Expr* cond, Stmt* then, Stmt* els) { - auto ret = new (ifStmtPool.Alloc()) IfStmt(cond, then, els); - ret->pool_ = &ifStmtPool; - return ret; -} - - -CompoundStmt* CompoundStmt::New(std::list& stmts, ::Scope* scope) { - auto ret = new (compoundStmtPool.Alloc()) CompoundStmt(stmts, scope); - ret->pool_ = &compoundStmtPool; - return ret; -} - -ForStmt* ForStmt::New(Stmt* body, Stmt* init, Expr* cond, Expr* step) { - auto ret = new (forStmtPool.Alloc()) ForStmt(body, init, cond, step); - ret->pool_ = &forStmtPool; - return ret; -} - -JumpStmt* JumpStmt::New(LabelStmt* label) { - auto ret = new (jumpStmtPool.Alloc()) JumpStmt(label); - ret->pool_ = &jumpStmtPool; - return ret; -} - - -ReturnStmt* ReturnStmt::New(Expr* expr) { - auto ret = new (returnStmtPool.Alloc()) ReturnStmt(expr); - ret->pool_ = &returnStmtPool; - return ret; -} - - -LabelStmt* LabelStmt::New() { - auto ret = new (labelStmtPool.Alloc()) LabelStmt(); - ret->pool_ = &labelStmtPool; - return ret; -} - - -FuncDef* FuncDef::New(Identifier* ident, LabelStmt* retLabel) { - auto ret = new (funcDefPool.Alloc()) FuncDef(ident, retLabel); - ret->pool_ = &funcDefPool; - return ret; -} - - -bool Initializer::operator<(const Initializer& rhs) const { - if (offset_ < rhs.offset_) - return true; - return (offset_ == rhs.offset_ && bitFieldBegin_ < rhs.bitFieldBegin_); -} diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc deleted file mode 100644 index 77568014f..000000000 --- a/lib/lang/code_gen.cc +++ /dev/null @@ -1,847 +0,0 @@ -#include "triton/lang/code_gen.h" -#include "triton/lang/evaluator.h" -#include "triton/lang/parser.h" -#include "triton/lang/token.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" - -// Helpers -void Generator::set_ret(ir::value* value) { - ret_ = value; -} - -inline bool is_terminator(ir::value* x) { - return x && dynamic_cast(x); -} - - -// Expression - -void Generator::VisitBinaryOp(BinaryOp* binary) { - - Visit(binary->rhs_); - ir::value* rhs = ret_; - - if(binary->op_ == '=') - return set_ret(assign_->GenExpr(binary->lhs_, rhs)); - - Visit(binary->lhs_); - ir::value* lhs = ret_; - - // op info - auto type = binary->lhs_->Type()->ScalarType(); - auto flt = type->IsFloat(); - auto sign = !type->IsUnsigned(); - - // return - switch(binary->op_){ - case Token::LOGICAL_AND: return set_ret(bld_->create_and(lhs, rhs)); - case Token::LOGICAL_OR: return set_ret(bld_->create_or(lhs, rhs)); - case '|': return set_ret(bld_->create_or(lhs, rhs)); - case '&': return set_ret(bld_->create_and(lhs, rhs)); - case '^': return set_ret(bld_->create_xor(lhs, rhs)); - case Token::LEFT: return set_ret(bld_->create_shl(lhs, rhs)); - case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs)); - case '.': return error_not_implemented(". binary operator not implemented"); - case ',': return error_not_implemented(", binary operator not implemented"); - case '@' : { - ir::type* ret_ty = GenIRType(binary->Type(), *ctx_); - ir::type* ret_scal_ty = ret_ty->get_scalar_ty(); - ir::value* _0; - if(ret_scal_ty->is_float_ty()) - _0 = ir::constant_fp::get(ret_scal_ty, 0); - else - _0 = ir::constant_int::get(ret_scal_ty, 0); - _0 = bld_->create_splat(_0, ret_ty->get_tile_shapes()); - return set_ret(bld_->create_dot(lhs, rhs, _0)); - } - case Token::MASKED_DEREF: { - // TODO: FIXME - ir::type* ret_ty = GenIRType(binary->Type(), *ctx_); - ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty()); - auto it = bld_->get_insert_block(); - if(ret_ty->is_tile_ty()) - false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes()); - bld_->set_insert_point(it); - return set_ret(bld_->create_masked_load(rhs, lhs, false_value)); - } - case Token::ELLIPSIS: { - auto clhs = dynamic_cast(lhs); - auto crhs = dynamic_cast(rhs); - if(!clhs || !crhs) - error_not_implemented("ellipsis between variables not implemented"); - return set_ret(bld_->insert(ir::make_range::create(clhs, crhs))); - } - case '+': - if(binary->lhs_->Type()->ScalarType()->ToPointer()){ - return set_ret(bld_->create_gep(lhs, {rhs})); - } - else if(flt) - return set_ret(bld_->create_fadd(lhs, rhs)); - else - return set_ret(bld_->create_add(lhs, rhs)); - case '-': - if(binary->lhs_->Type()->ToPointer()) - return set_ret(bld_->create_gep(lhs, {GenUnaryMinus(rhs)})); - else if(flt) - return set_ret(bld_->create_fsub(lhs, rhs)); - else - return set_ret(bld_->create_sub(lhs, rhs)); - case '*': - if(flt) - return set_ret(bld_->create_fmul(lhs, rhs)); - else - return set_ret(bld_->create_mul(lhs, rhs)); - case '/': - if(flt) - return set_ret(bld_->create_fdiv(lhs, rhs)); - else if(sign) - return set_ret(bld_->create_sdiv(lhs, rhs)); - else if(!sign) - return set_ret(bld_->create_udiv(lhs, rhs)); - else - return should_not_happen("/ should not encounter type not in {float, int}"); - case '%': - if(flt) - return set_ret(bld_->create_frem(lhs, rhs)); - else if(sign) - return set_ret(bld_->create_srem(lhs, rhs)); - else - return set_ret(bld_->create_urem(lhs, rhs)); - case '<': - if(flt) - return set_ret(bld_->create_fcmpOLT(lhs, rhs)); - else if(sign) - return set_ret(bld_->create_icmpSLT(lhs, rhs)); - else if(!sign) - return set_ret(bld_->create_icmpULT(lhs, rhs)); - else - return should_not_happen("< should not encounter type not in {float, int}"); - case '>': - if(flt) - return set_ret(bld_->create_fcmpOGT(lhs, rhs)); - else if(sign) - return set_ret(bld_->create_icmpSGT(lhs, rhs)); - else if(!sign) - return set_ret(bld_->create_icmpUGT(lhs, rhs)); - else - return should_not_happen("> should not encounter type not in {float, int}"); - case Token::LE: - if(flt) - return set_ret(bld_->create_fcmpOLE(lhs, rhs)); - else if(sign) - return set_ret(bld_->create_icmpSLE(lhs, rhs)); - else if(!sign) - return set_ret(bld_->create_icmpULE(lhs, rhs)); - else - return should_not_happen("<= should not encounter type not in {float, int}"); - case Token::GE: - if(flt) - return set_ret(bld_->create_fcmpOGE(lhs, rhs)); - else if(sign) - return set_ret(bld_->create_icmpSGE(lhs, rhs)); - else if(!sign) - return set_ret(bld_->create_icmpUGE(lhs, rhs)); - else - return should_not_happen(">= should not encounter type not in {float, int}"); - case Token::EQ: - if(flt) - return set_ret(bld_->create_fcmpOEQ(lhs, rhs)); - else - return set_ret(bld_->create_icmpEQ(lhs, rhs)); - case Token::NE: - if(flt) - return set_ret(bld_->create_fcmpONE(lhs, rhs)); - else - return set_ret(bld_->create_icmpNE(lhs, rhs)); - default: - return error_not_implemented("binary operator " + std::to_string(binary->op_) + " not implemented"); - } - should_not_happen(""); -} - -ir::reduce_inst::op_t reduce_op(int tag, bool is_float) { - using ir::reduce_inst; - switch(tag){ - case Token::ADD: return is_float ? reduce_inst::FADD : reduce_inst::ADD; - case Token::SUB: return is_float ? reduce_inst::FSUB : reduce_inst::SUB; - case Token::MAX: return is_float ? reduce_inst::FMAX : reduce_inst::MAX; - case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN; - default: break; - } - error_not_implemented("reduction operator " + std::to_string(tag) + " not implemented"); - return reduce_inst::op_t(); -} - -ir::value* Generator::GenUnaryMinus(ir::value* arg) { - ir::type *ty = arg->get_type(); - ir::type *sca_ty = ty->get_scalar_ty(); - ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty); - if(ty->is_tile_ty()) - _0 = bld_->create_splat(_0, ty->get_tile_shapes()); - if(sca_ty->is_floating_point_ty()) - return bld_->create_fsub(_0, arg); - else - return bld_->create_sub(_0, arg); -} - -ir::value* Generator::GenUnaryInc(UnaryOp* expr, bool is_postfix, - bool is_inc) { - Visit(expr->operand_); - ir::value* arg = ret_; - - ir::value *_1 = nullptr; - ir::value *instr = nullptr; - - if (arg->get_type()->is_floating_point_ty()) { - _1 = ir::constant_fp::get(arg->get_type(), 1.0); - if (is_inc) - instr = bld_->create_fadd(arg, _1); - else - instr = bld_->create_fsub(arg, _1); - } else if (arg->get_type()->is_integer_ty()) { - _1 = ir::constant_int::get(arg->get_type(), 1); - if (is_inc) - instr = bld_->create_add(arg, _1); - else - instr = bld_->create_sub(arg, _1); - } else if (arg->get_type()->is_pointer_ty()) { - ir::type *ty = ir::type::get_int64_ty(*ctx_); - _1 = ir::constant_int::get(ty, 1); - if (is_inc) - instr = bld_->create_gep(arg, {_1}); - else { - ir::value *neg_1 = ir::constant_int::get(ty, -1); - instr = bld_->create_gep(arg, {neg_1}); - } - } else - error_not_implemented("data type not supported for unary inc"); - - mod_->set_value(arg->get_name(), instr); - - if (is_postfix) - return arg; - else - return instr; -} - -void Generator::VisitUnaryOp(UnaryOp* unary) { - // recursion - Visit(unary->operand_); - ir::value* arg = ret_; - ir::type *arg_ty = arg->get_type(); - ir::type *arg_scal_ty = arg_ty->get_scalar_ty(); - // return - switch (unary->op_) { - case Token::PREFIX_INC: return set_ret(GenUnaryInc(unary, false, true)); - case Token::PREFIX_DEC: return set_ret(GenUnaryInc(unary, false, false)); - case Token::POSTFIX_INC: return set_ret(GenUnaryInc(unary, true, true)); - case Token::POSTFIX_DEC: return set_ret(GenUnaryInc(unary, true, false)); - case Token::ADDR: return error_not_implemented("unary & not implemented"); - case Token::DEREF: return set_ret(bld_->create_load(arg)); - case Token::PLUS: return error_not_implemented("unary + not implemented"); - case Token::MINUS: return set_ret(GenUnaryMinus(arg)); - case '~': return error_not_implemented("unary ~ not implemented"); - case '!': return error_not_implemented("unary ! not implemented"); - case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_))); - case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_))); - case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast - case Token::LOG: return set_ret(bld_->create_log(arg)); - case Token::SQRTF: return set_ret(bld_->create_sqrt(arg)); - case Token::REDUCE: { - int ax, tag; - UnaryOp::decodeRed(unary->info_, ax, tag); - bool is_float = arg_scal_ty->is_floating_point_ty(); - ir::reduce_inst::op_t op = reduce_op(tag, is_float); - return set_ret(bld_->create_reduce(arg, op, ax)); - } - default: error_not_implemented("unary " + std::to_string(unary->op_) + " not implemented"); - } - return should_not_happen(""); -} - -void Generator::VisitTransOp(TransOp *trans) { - Visit(trans->operand_); - ir::value* arg = ret_; - return set_ret(bld_->create_trans(arg, trans->getPerm())); -} - -void Generator::VisitConditionalOp(ConditionalOp* condOp) { - auto &instructions = bld_->get_insert_block()->get_inst_list(); - VisitExpr(condOp->cond_); - ir::value* true_cond = ret_; - ir::instruction* start = instructions.back(); - VisitExpr(condOp->exprTrue_); - ir::value* true_val = ret_; - VisitExpr(condOp->exprFalse_); - ir::value* false_val = ret_; - auto begin = std::find(instructions.begin(), instructions.end(), start); - bool is_in_true_cond = true; - for(auto it = begin; it != instructions.end(); it++){ - ir::instruction* instr = *it; - // we mask load with `cond` when used to compute true_value - // we mask load with `!cond` when used to compute false_value - if(auto ld = dynamic_cast(instr)){ - bld_->set_insert_point(ld); - ir::type* ty = ld->get_type(); - ir::value* cond = is_in_true_cond ? true_cond : true_cond; - ir::value* ptr = ld->get_pointer_operand(); - ir::value* else_val = ir::undef_value::get(ty); - ir::value* masked_ld = bld_->create_masked_load(ptr, cond, else_val); - ld->replace_all_uses_with(masked_ld); - ld->erase_from_parent(); - if(true_val == ld) - true_val = masked_ld; - if(false_val == ld) - false_val = masked_ld; - it = std::find(instructions.begin(), instructions.end(), masked_ld); - } - if(instr == true_val) - is_in_true_cond = false; - } - bld_->set_insert_point(bld_->get_insert_block()); - return set_ret(bld_->create_select(true_cond, true_val, false_val)); - -// VisitExpr(condOp->cond_); -// ir::value* cond = ret_; -// VisitExpr(condOp->exprTrue_); -// ir::value* true_val = ret_; -// VisitExpr(condOp->exprFalse_); -// ir::value* false_val = ret_; -// if(ir::unmasked_load_inst* ld = dynamic_cast(true_val)) { -// if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty()) -// false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes()); -// ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val); -// ld->replace_all_uses_with(new_ld); -// ld->erase_from_parent(); -// return set_ret(new_ld); -// } -// return set_ret(bld_->create_select(cond, true_val, false_val)); -} - -void Generator::VisitFuncCall(FuncCall* funcCall) { - std::string name = funcCall->Name(); - if(name == "get_program_id"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* ret = ret_; - if(auto axis = dynamic_cast(ret)) - return set_ret(bld_->create_get_program_id(axis->get_value())); - else - return should_not_happen("get_program_id argument should be constant"); - } - if(name == "get_num_programs"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* ret = ret_; - if(auto axis = dynamic_cast(ret)) - return set_ret(bld_->create_get_num_program(axis->get_value())); - else - return should_not_happen("get_num_programs argument should be constant"); - } - if(name == "atomic_cas"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* ptr = ret_; - VisitExpr(funcCall->Args()->at(1)); - ir::value* cmp = ret_; - VisitExpr(funcCall->Args()->at(2)); - ir::value* val = ret_; - return set_ret(bld_->create_atomic_cas(ptr, cmp, val)); - } - if(name == "atomic_xchg"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* ptr = ret_; - VisitExpr(funcCall->Args()->at(1)); - ir::value* val = ret_; - return set_ret(bld_->create_atomic_exch(ptr, val)); - } - if(name.substr(0, 10) == "atomic_add"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* ptr = ret_; - VisitExpr(funcCall->Args()->at(1)); - ir::value* val = ret_; - VisitExpr(funcCall->Args()->at(2)); - ir::value* msk = ret_; - return set_ret(bld_->create_atomic_add(ptr, val, msk)); - } - if(name == "calloc"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* ret = ret_; - ir::constant_int *size = dynamic_cast(ret); - assert(size); - ir::alloc_const* alloc = new ir::alloc_const(bld_->get_int8_ty(), size); - mod_->add_alloc(alloc); - return set_ret(alloc); - } - //TODO: integrate this into conditionalop - if(name == "select"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* cond = ret_; - VisitExpr(funcCall->Args()->at(1)); - ir::value* true_val = ret_; - VisitExpr(funcCall->Args()->at(2)); - ir::value* false_val = ret_; - return set_ret(bld_->create_select(cond, true_val, false_val)); - } - if(name == "__debug_barrier"){ - bld_->create_barrier(); - return; - } - return error_not_implemented("function calls not implemented"); -} - -void Generator::VisitObject(Object* obj) { - return set_ret(mod_->get_value(obj->Name())); -} - -void Generator::VisitEnumerator(Enumerator* enumer) { - return error_not_implemented("enumeration not implemented"); -} - -void Generator::VisitIdentifier(Identifier* ident) { - return set_ret(mod_->get_value(ident->Name())); -} - -void Generator::VisitConstant(Constant* cons) { - Type* ctype = cons->Type(); - ir::type *type = GenIRType(cons->Type(), *ctx_); - if(ctype->IsInteger()) - return set_ret(ir::constant_int::get(type, cons->IVal())); - if(ctype->IsFloat() && ctype->IsReal()) - return set_ret(ir::constant_fp::get(type, cons->FVal())); - return error_not_implemented("constant of type not in {int, float} not implemented"); -} - -void Generator::VisitTempVar(TempVar* tempVar) { - return error_not_implemented("temporary variable not implemented"); -} - -// Statement -// TODO: int x = x; crashes -void Generator::VisitDeclaration(Declaration* decl) { - auto obj = decl->obj_; - // initialize to undef - - ir::type* ty = GenIRType(obj->Type(), *ctx_); - ir::value* val = ir::undef_value::get(ty); -//obj->GetAttrList() - // compute initializers - std::vector inits; - for (const Initializer& init: decl->Inits()) { - VisitExpr(init.expr_); - ir::value *val = ret_; - for(const auto& attr: obj->GetAttrList()) - SetIRMetadata(attr, val); - inits.push_back(val); - } - // initialize declaration - ir::type::id_t id = ty->get_type_id(); - if(id == ir::type::StructTyID) - error_not_implemented("struct not implemented"); - if(inits.size() > 1) - error_not_implemented("initializer list > 1 element not implemented"); - if(inits.size() > 0) - val = inits[0]; - assert(val->get_type() == ty); - // update scope symbols table - const std::string &name = obj->Name(); - if(!name.empty()){ - mod_->set_value(name, val); - mod_->get_scope().types[name] = ty; - } -} - -void Generator::VisitEmptyStmt(EmptyStmt*) { - return; -} - -void Generator::VisitIfStmt(IfStmt* ifStmt) { - ir::function *fn = bld_->get_insert_block()->get_parent(); - Stmt *then_ = ifStmt->then_; - Stmt *else_ = ifStmt->else_; - VisitExpr(ifStmt->cond_); - ir::value* cond = ret_; - ir::basic_block *then_bb = ir::basic_block::create(*ctx_, "then", fn); - ir::basic_block *else_bb = else_? ir::basic_block::create(*ctx_, "else", fn) : nullptr; - ir::basic_block *endif_bb = ir::basic_block::create(*ctx_, "endif", fn); - // seal blocks - mod_->seal_block(then_bb); - if(else_bb) - mod_->seal_block(else_bb); - // branches - if(else_) - bld_->create_cond_br(cond, then_bb, else_bb); - else - bld_->create_cond_br(cond, then_bb, endif_bb); - // then - bld_->set_insert_point(then_bb); - VisitStmt(then_); - if(!is_terminator(ret_)) - bld_->create_br(endif_bb); - // else - if(else_){ - bld_->set_insert_point(else_bb); - VisitStmt(else_); - if(!is_terminator(ret_)) - bld_->create_br(endif_bb); - } - // endif - mod_->seal_block(endif_bb); - bld_->set_insert_point(endif_bb); -} - -void Generator::VisitForStmt(ForStmt *forStmt) { - Stmt *init_ = forStmt->init_; - Expr *cond_ = forStmt->cond_; - Expr *step_ = forStmt->step_; - Stmt *body_ = forStmt->body_; - ir::basic_block *current_bb = bld_->get_insert_block(); - ir::function *fn = current_bb->get_parent(); - ir::basic_block *loop_bb = ir::basic_block::create(*ctx_, "loop", fn); - ir::basic_block *next_bb = ir::basic_block::create(*ctx_, "postloop", fn); - mod_->set_continue_fn([&](){ - if(step_) - VisitExpr(step_); - VisitExpr(cond_); - ir::value *cond = ret_; - return bld_->create_cond_br(cond, loop_bb, next_bb); - }); - if(init_) - VisitStmt(init_); - VisitExpr(cond_); - ir::value *cond = ret_; - bld_->create_cond_br(cond, loop_bb, next_bb); -// bld_->create_br(loop_bb); - bld_->set_insert_point(loop_bb); - if(body_) - VisitStmt(body_); - if(!is_terminator(ret_)) - mod_->get_continue_fn()(); - ir::basic_block *stop_bb = bld_->get_insert_block(); - mod_->seal_block(stop_bb); - mod_->seal_block(loop_bb); - mod_->seal_block(bld_->get_insert_block()); - mod_->seal_block(next_bb); - bld_->set_insert_point(next_bb); -} - -void Generator::VisitJumpStmt(JumpStmt* jumpStmt) { - return error_not_implemented("jump not implemented"); -} - -void Generator::VisitReturnStmt(ReturnStmt* returnStmt) { - ir::value *ret; - if(returnStmt->expr_) - return error_not_implemented("non-void return not implemented"); - else - ret = bld_->create_ret_void(); - return set_ret(ret); -} - -void Generator::VisitLabelStmt(LabelStmt* labelStmt) { - return error_not_implemented("label not implemented"); -} - -void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) { - if (compoundStmt->scope_) - pushScope(); - for (auto stmt: compoundStmt->stmts_) - Visit(stmt); - if(compoundStmt->scope_) - popScope(); -} - -void Generator::VisitFuncDef(FuncDef* funcDef) { - Stmt *body = funcDef->body_; - const std::string& name = funcDef->Name(); - FuncType* type = funcDef->FuncType(); - auto prototype = dynamic_cast(GenIRType(type, *ctx_)); - if(!prototype) - should_not_happen("could not parse function prototype"); - ir::function *fn = mod_->get_or_insert_function(name, prototype); - std::vector args = fn->args(); - size_t i = 0; - for(Object* obj: type->Params()){ - std::string name = obj->Name(); - args[i]->set_name(name); - if(obj->Type()->ToPointer()) - fn->add_attr(i + 1, ir::attribute(ir::aligned, 16)); - for(ASTNode::Attr attr: obj->GetAttrList()){ - fn->add_attr(i + 1, GenIRAttr(attr)); - } - if(obj->IsRestrictQualified()) - fn->add_attr(i, ir::attribute(ir::noalias)); - mod_->set_value(name, nullptr, args[i]); - mod_->get_scope().types[name] = args[i]->get_type(); - i++; - } - ir::basic_block *entry = ir::basic_block::create(mod_->get_context(), "entry", fn); - mod_->seal_block(entry); - mod_->get_builder().set_insert_point(entry); - VisitStmt(body); - if(!dynamic_cast(ret_)) - mod_->get_builder().create_ret_void(); -} - -void Generator::VisitTranslationUnit(TranslationUnit* unit) { - pushScope(); - for (auto extDecl: unit->ExtDecls()) - Visit(extDecl); - popScope(); -} - -void Generator::Gen(ir::module *mod) { - mod_ = mod; - ctx_ = &mod_->get_context(); - bld_ = &mod_->get_builder(); - assign_ = new LValAssigner(this); - VisitTranslationUnit(parser_->Unit()); - delete assign_; - assign_ = nullptr; -} - - - -ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) { - if(src->get_type() == dst_ty) - return src; - if(dst_ty->is_tile_ty()) { - ir::type *src_ty = src->get_type(); - auto dst_shapes = dst_ty->get_tile_shapes(); - if(!src_ty->is_tile_ty()) - return bld_->create_splat(src, dst_shapes); - auto src_shapes = src_ty->get_tile_shapes(); - if(src_shapes.size() != dst_shapes.size()){ - unsigned src_numel = 1; - for(unsigned s: src_shapes) - src_numel *= s; - unsigned dst_numel = 1; - for(unsigned s: dst_shapes) - dst_numel *= s; - if(src_numel == dst_numel) - return bld_->create_reshape(src, dst_shapes); - else { - auto padded_shapes = src_shapes; - while(padded_shapes.size() != dst_shapes.size()) - padded_shapes.insert(padded_shapes.begin(), 1); - // check that broadcast is legal - for(size_t d = 0; d < padded_shapes.size(); d++){ - if(dst_shapes[d] != padded_shapes[d] && - padded_shapes[d] != 1) - should_not_happen("broadcast should not happen between these shapes"); - } - // pad and broadcast - ir::value *padded = bld_->create_reshape(src, padded_shapes); - return bld_->create_broadcast(padded, dst_shapes); - } - } - else{ - return bld_->create_broadcast(src, dst_shapes); - } - } - else if(src->get_type()->is_tile_ty() && src->get_type()->get_tile_num_elements() == 1){ - return bld_->create_downcast(src); - } - return src; -} - -ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) { - ir::type *src_scalar_ty = src->get_type()->get_scalar_ty(); - ir::type *dst_scalar_ty = dst_ty->get_scalar_ty(); - if(src->get_type()->is_tile_ty()) - dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type()); - bool src_signed = false; - bool dst_signed = false; - if(src_scalar_ty == dst_scalar_ty) - return src; - else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_bool_ty()) - return bld_->create_icmpNE(bld_->create_ptr_to_int(src, ir::tile_type::get_same_shapes(bld_->get_int64_ty(), src->get_type())), - bld_->create_splat(bld_->get_int64(0), src->get_type()->get_tile_shapes())); - else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty()) - return bld_->create_si_to_fp(src, dst_ty); - else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty()) - return bld_->create_ui_to_fp(src, dst_ty); - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed) - return bld_->create_fp_to_si(src, dst_ty); - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed) - return bld_->create_fp_to_ui(src, dst_ty); - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && - src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width()) - return bld_->create_fp_ext(src, dst_ty); - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && - src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width()) - return bld_->create_fp_trunc(src, dst_ty); - else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() && - src_scalar_ty->get_integer_bitwidth()) - return bld_->create_int_cast(src, dst_ty, dst_signed); - else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty()) - return bld_->create_cast(ir::BitCast, src, dst_ty); - else{ - error_not_implemented("cast type not implemented"); - return nullptr; - } -} - -ir::value* Generator::GenSemCastOp(ir::value* src, ir::type* dst_ty) { - return GenNumcastOp(GenBroadcastOp(src, dst_ty), dst_ty); -} - -ir::value* Generator::GenBitCastOp(ir::value* src, ir::type* dst_ty) { - return bld_->create_cast(ir::BitCast, GenBroadcastOp(src, dst_ty), dst_ty); -} - - -// Triton-IR Attr -ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) { - if(attr.kind == ASTNode::Attr::MULTIPLEOF) { - VisitExpr(attr.vals[0]); - auto cst = dynamic_cast(ret_); - if(!cst) should_not_happen("multipleof only works on constants"); - return ir::attribute(ir::multiple_of, cst->get_value()); - } - if(attr.kind == ASTNode::Attr::ALIGNED) { - VisitExpr(attr.vals[0]); - auto cst = dynamic_cast(ret_); - return ir::attribute(ir::aligned, cst->get_value()); - } - if(attr.kind == ASTNode::Attr::NOALIAS) - return ir::attribute(ir::noalias); - if(attr.kind == ASTNode::Attr::READONLY) - return ir::attribute(ir::readonly); - if(attr.kind == ASTNode::Attr::WRITEONLY) - return ir::attribute(ir::writeonly); - if(attr.kind == ASTNode::Attr::RETUNE) - return ir::attribute(ir::retune); - error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented"); - return ir::attribute(ir::not_implemented); -} - -void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) { - auto *i = dynamic_cast(v); - if(!i) - return; - if(attr.kind == ASTNode::Attr::MULTIPLEOF) - i->set_metadata(ir::metadata::multiple_of, GenIRAttr(attr).get_value()); -} - -// Triton-IR Types -ir::type* Generator::GenIRType(::Type* type, ir::context& ctx) { - if(auto T = type->ToVoid()) - return ir::type::get_void_ty(ctx); - if(auto T = type->ToArithm()) - return GenIRArithmType(T, ctx); - if(auto T = type->ToArray()) - return GenIRArrayType(T, ctx); - if(auto T = type->ToTile()) - return GenIRTileType(T, ctx); - if(auto T = type->ToFunc()) - return GenIRFuncType(T, ctx); - if(auto T = type->ToPointer()) - return GenIRPointerType(T, ctx); - if(auto T = type->ToStruct()) - return GenIRStructType(T, ctx); - assert(false); - return nullptr; -} - -ir::type* Generator::GenIRArithmType(ArithmType* type, ir::context& ctx) { - int tag = type->Tag(); - if(tag & T_BOOL) - return ir::type::get_int1_ty(ctx); - if(tag & T_CHAR) - return ir::type::get_int8_ty(ctx); - if(tag & T_SHORT) - return ir::type::get_int16_ty(ctx); - if(tag & T_INT) - return ir::type::get_int32_ty(ctx); - if(tag & T_LONG) - return ir::type::get_int64_ty(ctx); - if(tag & T_HALF) - return ir::type::get_half_ty(ctx); - if(tag & T_FLOAT) - return ir::type::get_float_ty(ctx); - if(tag & T_DOUBLE) - return ir::type::get_double_ty(ctx); - assert(false); - return nullptr; -} - -ir::type* Generator::GenIRArrayType(ArrayType* type, ir::context& ctx) { - assert(false); - return nullptr; -} - -ir::type* Generator::GenIRTileType(TileType* type, ir::context& ctx) { - ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx); - auto _shape = type->Shape(); - ir::tile_type::tile_shapes_t shape; - for(int s: _shape) - shape.push_back(static_cast(s)); - return ir::tile_type::get(ele_ty, shape); -} - -ir::type* Generator::GenIRFuncType(FuncType* type, ir::context& ctx) { - ir::type* ret_ty = GenIRType(type->Derived().GetPtr(), ctx); - std::vector param_tys; - for(Object* obj: type->Params()) - param_tys.push_back(GenIRType(obj->Type(), ctx)); - return ir::function_type::get(ret_ty, param_tys); -} - -ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) { - ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx); - unsigned addr_space = 1; - if(type->Derived().IsConstantQualified()) - addr_space = 4; - return ir::pointer_type::get(ele_ty, addr_space); -} - -ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) { - error_not_implemented("struct not implemented"); - return nullptr; -} - -void Generator::AllocObjects(Scope* scope, const FuncDef::ParamList& params) { - return error_not_implemented("alloc not implemented"); -} - -// SSA -void Generator::pushScope() { - mod_->add_new_scope(); -} - -void Generator::popScope() { - mod_->pop_scope(); -} - -// LValue Generator -void LValAssigner::VisitBinaryOp(BinaryOp* binary) { - if(binary->op_ != Token::MASKED_DEREF) - error_not_implemented("lvalue for binary non masked-deref not implemented"); - gen_->VisitExpr(binary->lhs_); - ir::value* mask = gen_->ret_; - gen_->VisitExpr(binary->rhs_); - ir::value* addr = gen_->ret_; - ret_ = gen_->bld_->create_masked_store(addr, rhs_, mask); -} - -void LValAssigner::VisitUnaryOp(UnaryOp* unary) { - if(unary->op_ != Token::DEREF) - error_not_implemented("lvalue for unary non deref not implemented"); - gen_->VisitExpr(unary->operand_); - ir::value* addr = gen_->ret_; - ret_ = gen_->bld_->create_store(addr, rhs_); -} - -void LValAssigner::VisitObject(Object* obj) { - std::string name = obj->Name(); - gen_->mod_->set_value(name, rhs_); - ret_ = rhs_; -} - -void LValAssigner::VisitIdentifier(Identifier* ident) { - std::string name = ident->Name(); - gen_->mod_->set_value(name, rhs_); - ret_ = rhs_; -} - - - diff --git a/lib/lang/cpp.cc b/lib/lang/cpp.cc deleted file mode 100644 index 2cdfb453a..000000000 --- a/lib/lang/cpp.cc +++ /dev/null @@ -1,884 +0,0 @@ -#include "triton/lang/cpp.h" - -#include "triton/lang/evaluator.h" -#include "triton/lang/parser.h" - -#include -#include -#include -#include - - - -using DirectiveMap = std::unordered_map; - -static const DirectiveMap directiveMap { - {"if", Token::PP_IF}, - {"ifdef", Token::PP_IFDEF}, - {"ifndef", Token::PP_IFNDEF}, - {"elif", Token::PP_ELIF}, - {"else", Token::PP_ELSE}, - {"endif", Token::PP_ENDIF}, - {"include", Token::PP_INCLUDE}, - // Non-standard GNU extension - {"include_next", Token::PP_INCLUDE}, - {"define", Token::PP_DEFINE}, - {"undef", Token::PP_UNDEF}, - {"line", Token::PP_LINE}, - {"error", Token::PP_ERROR}, - {"pragma", Token::PP_PRAGMA} -}; - - -/* - * params: - * is: input token sequence - * os: output token sequence - */ -void Preprocessor::Expand(TokenSequence& os, TokenSequence is, bool inCond) { - Macro* macro = nullptr; - int direcitve; - while (!is.Empty()) { - UpdateFirstTokenLine(is); - auto tok = is.Peek(); - const auto& name = tok->str_; - - if ((direcitve = GetDirective(is)) != Token::INVALID) { - ParseDirective(os, is, direcitve); - } else if (!inCond && !NeedExpand()) { - // Discards the token - is.Next(); - } else if (inCond && name == "defined") { - is.Next(); - os.InsertBack(EvalDefOp(is)); - } else if (tok->hs_ && tok->hs_->find(name) != tok->hs_->end()) { - os.InsertBack(is.Next()); - } else if ((macro = FindMacro(name))) { - is.Next(); - - if (name == "__FILE__") { - HandleTheFileMacro(os, tok); - } else if (name == "__LINE__") { - HandleTheLineMacro(os, tok); - } else if (macro->ObjLike()) { - // Make a copy, as subst will change repSeq - auto repSeq = macro->RepSeq(tok->loc_.filename_, tok->loc_.line_); - - TokenList tokList; - TokenSequence repSeqSubsted(&tokList); - ParamMap paramMap; - // TODO(wgtdkp): hideset is not right - // Make a copy of hideset - // HS U {name} - auto hs = tok->hs_ ? *tok->hs_: HideSet(); - hs.insert(name); - Subst(repSeqSubsted, repSeq, tok->ws_, hs, paramMap); - is.InsertFront(repSeqSubsted); - } else if (is.Try('(')) { - ParamMap paramMap; - auto rpar = ParseActualParam(is, macro, paramMap); - auto repSeq = macro->RepSeq(tok->loc_.filename_, tok->loc_.line_); - TokenList tokList; - TokenSequence repSeqSubsted(&tokList); - - // (HS ^ HS') U {name} - // Use HS' U {name} directly - auto hs = rpar->hs_ ? *rpar->hs_: HideSet(); - hs.insert(name); - Subst(repSeqSubsted, repSeq, tok->ws_, hs, paramMap); - is.InsertFront(repSeqSubsted); - } else { - os.InsertBack(tok); - } - } else { - os.InsertBack(is.Next()); - } - } -} - - -static bool FindActualParam(TokenSequence& ap, - ParamMap& params, - const std::string& fp) { - auto res = params.find(fp); - if (res == params.end()) { - return false; - } - ap.Copy(res->second); - return true; -} - - -void Preprocessor::Subst(TokenSequence& os, - TokenSequence is, - bool leadingWS, - const HideSet& hs, - ParamMap& params) { - TokenSequence ap; - - while (!is.Empty()) { - if (is.Test('#') && FindActualParam(ap, params, is.Peek2()->str_)) { - is.Next(); is.Next(); - auto tok = Stringize(ap); - os.InsertBack(tok); - } else if (is.Test(Token::DSHARP) && - FindActualParam(ap, params, is.Peek2()->str_)) { - is.Next(); is.Next(); - if (!ap.Empty()) - Glue(os, ap); - } else if (is.Test(Token::DSHARP)) { - is.Next(); - auto tok = is.Next(); - Glue(os, tok); - } else if (is.Peek2()->tag_ == Token::DSHARP && - FindActualParam(ap, params, is.Peek()->str_)) { - is.Next(); - - if (ap.Empty()) { - is.Next(); - if (FindActualParam(ap, params, is.Peek()->str_)) { - is.Next(); - os.InsertBack(ap); - } - } else { - os.InsertBack(ap); - } - } else if (FindActualParam(ap, params, is.Peek()->str_)) { - auto tok = is.Next(); - const_cast(ap.Peek())->ws_ = tok->ws_; - Expand(os, ap); - } else { - os.InsertBack(is.Peek()); - is.Next(); - } - } - - os.FinalizeSubst(leadingWS, hs); -} - - -void Preprocessor::Glue(TokenSequence& os, const Token* tok) { - TokenList tokList {tok}; - TokenSequence is(&tokList); - Glue(os, is); -} - - -void Preprocessor::Glue(TokenSequence& os, TokenSequence is) { - auto lhs = os.Back(); - auto rhs = is.Peek(); - - auto str = new std::string(lhs->str_ + rhs->str_); - TokenSequence ts; - Scanner scanner(str, lhs->loc_); - scanner.Tokenize(ts); - - is.Next(); - - if (ts.Empty()) { - // TODO(wgtdkp): - // No new Token generated - // How to handle it??? - } else { - os.PopBack(); - auto newTok = const_cast(ts.Next()); - newTok->ws_ = lhs->ws_; - newTok->hs_ = lhs->hs_; - os.InsertBack(newTok); - } - - if (!ts.Empty()) { - Error(lhs, "macro expansion failed: cannot concatenate"); - } - - os.InsertBack(is); -} - - -/* - * This is For the '#' operator in func-like macro - */ -const Token* Preprocessor::Stringize(TokenSequence is) { - std::string str = "\""; - while (!is.Empty()) { - auto tok = is.Next(); - // Have preceding white space - // and is not the first token of the sequence - str.append(tok->ws_ && str.size() > 1, ' '); - if (tok->tag_ == Token::LITERAL || tok->tag_ == Token::C_CONSTANT) { - for (auto c: tok->str_) { - if (c == '"' || c == '\\') - str.push_back('\\'); - str.push_back(c); - } - } else { - str += tok->str_; - } - } - str.push_back('\"'); - - auto ret = Token::New(*is.Peek()); - ret->tag_ = Token::LITERAL; - ret->str_ = str; - return ret; -} - - -void Preprocessor::Finalize(TokenSequence os) { - while (!os.Empty()) { - auto tok = os.Next(); - if (tok->tag_ == Token::INVALID) { - Error(tok, "stray token in program"); - } else if (tok->tag_ == Token::IDENTIFIER) { - auto tag = Token::KeyWordTag(tok->str_); - if (Token::IsKeyWord(tag)) { - const_cast(tok)->tag_ = tag; - } else { - const_cast(tok)->str_ = Scanner(tok).ScanIdentifier(); - } - } - if (fName_ && !tok->loc_.filename_) { - assert(false); - } - } -} - - -// TODO(wgtdkp): add predefined macros -void Preprocessor::Process(TokenSequence& os) { - TokenSequence is; - // Add source file - if(fName_) - IncludeFile(is, fName_); - else - IncludeSrc(is, fSrc_, nullptr); - // Expand - Expand(os, is); - Finalize(os); -} - - -const Token* Preprocessor::ParseActualParam(TokenSequence& is, - Macro* macro, - ParamMap& paramMap) { - const Token* ret; - if (macro->Params().size() == 0 && !macro->Variadic()) { - ret = is.Next(); - if (ret->tag_ != ')') - Error(ret, "too many arguments"); - return ret; - } - - auto fp = macro->Params().begin(); - TokenSequence ap; - - int cnt = 1; - while (cnt > 0) { - if (is.Empty()) - Error(is.Peek(), "premature end of input"); - else if (is.Test('(')) - ++cnt; - else if (is.Test(')')) - --cnt; - - if ((is.Test(',') && cnt == 1) || cnt == 0) { - - if (fp == macro->Params().end()) { - if (!macro->Variadic()) - Error(is.Peek(), "too many arguments"); - if (cnt == 0) - paramMap.insert(std::make_pair("__VA_ARGS__", ap)); - else - ap.InsertBack(is.Peek()); - } else { - paramMap.insert(std::make_pair(*fp, ap)); - ap = TokenSequence(); - ++fp; - } - } else { - ap.InsertBack(is.Peek()); - } - ret = is.Next(); - } - - if (fp != macro->Params().end()) - Error(is.Peek(), "too few params"); - return ret; -} - - -const Token* Preprocessor::EvalDefOp(TokenSequence& is) { - auto hasPar = is.Try('('); - auto macro = is.Expect(Token::IDENTIFIER); - auto cons = Token::New(*macro); - if (hasPar) is.Expect(')'); - cons->tag_ = Token::I_CONSTANT; - cons->str_ = FindMacro(macro->str_) ? "1": "0"; - return cons; -} - - -void Preprocessor::ReplaceIdent(TokenSequence& is) { - TokenSequence os; - while (!is.Empty()) { - auto tok = is.Next(); - if (tok->tag_ == Token::IDENTIFIER) { - auto cons = Token::New(*tok); - cons->tag_ = Token::I_CONSTANT; - cons->str_ = "0"; - os.InsertBack(cons); - } else { - os.InsertBack(tok); - } - } - is = os; -} - - -int Preprocessor::GetDirective(TokenSequence& is) { - if (!is.Test('#') || !is.IsBeginOfLine()) - return Token::INVALID; - - is.Next(); - if (is.IsBeginOfLine()) - return Token::PP_EMPTY; - - auto tag = is.Peek()->tag_; - if (tag == Token::IDENTIFIER || Token::IsKeyWord(tag)) { - auto str = is.Peek()->str_; - auto res = directiveMap.find(str); - if (res == directiveMap.end()) - return Token::PP_NONE; - return res->second; - } - return Token::PP_NONE; -} - - -void Preprocessor::ParseDirective(TokenSequence& os, - TokenSequence& is, - int directive) { - if (directive == Token::PP_EMPTY) - return; - auto ls = is.GetLine(); - switch(directive) { - case Token::PP_IF: - ParseIf(ls); break; - case Token::PP_IFDEF: - ParseIfdef(ls); break; - case Token::PP_IFNDEF: - ParseIfndef(ls); break; - case Token::PP_ELIF: - ParseElif(ls); break; - case Token::PP_ELSE: - ParseElse(ls); break; - case Token::PP_ENDIF: - ParseEndif(ls); break; - case Token::PP_INCLUDE: - if (NeedExpand()) - ParseInclude(is, ls); - break; - case Token::PP_DEFINE: - if (NeedExpand()) - ParseDef(ls); - break; - case Token::PP_UNDEF: - if (NeedExpand()) - ParseUndef(ls); - break; - case Token::PP_LINE: - if (NeedExpand()) - ParseLine(ls); - break; - case Token::PP_ERROR: - if (NeedExpand()) - ParseError(ls); - break; - case Token::PP_PRAGMA: - if (NeedExpand()) - ParsePragma(ls); - break; - case Token::PP_NONE: - break; - default: - assert(false); - } -} - - -void Preprocessor::ParsePragma(TokenSequence ls) { - // TODO(wgtdkp): - ls.Next(); -} - - -void Preprocessor::ParseError(TokenSequence ls) { - ls.Next(); - const auto& literal = Stringize(ls); - std::string msg; - Scanner(literal).ScanLiteral(msg); - Error(ls.Peek(), "%s", msg.c_str()); -} - - -void Preprocessor::ParseLine(TokenSequence ls) { - auto directive = ls.Next(); // Skip directive 'line' - TokenSequence ts; - Expand(ts, ls); - auto tok = ts.Expect(Token::I_CONSTANT); - - int line = 0; - size_t end = 0; - try { - line = stoi(tok->str_, &end, 10); - } catch (const std::out_of_range& oor) { - Error(tok, "line number out of range"); - } - if (line == 0 || end != tok->str_.size()) { - Error(tok, "illegal line number"); - } - - curLine_ = line; - lineLine_ = directive->loc_.line_; - if (ts.Empty()) - return; - tok = ts.Expect(Token::LITERAL); - - // Enusure "s-char-sequence" - if (tok->str_.front() != '"' || tok->str_.back() != '"') { - Error(tok, "expect s-char-sequence"); - } -} - - -void Preprocessor::ParseIf(TokenSequence ls) { - if (!NeedExpand()) { - ppCondStack_.push({Token::PP_IF, false, false}); - return; - } - - auto tok = ls.Next(); // Skip the directive - - if (ls.Empty()) { - Error(tok, "expect expression in 'if' directive"); - } - - TokenSequence ts; - Expand(ts, ls, true); - ReplaceIdent(ts); - - Parser parser(ts); - auto expr = parser.ParseExpr(); - if (!parser.ts().Empty()) { - Error(parser.ts().Peek(), "unexpected extra expression"); - } - bool cond; - if (expr->Type()->IsFloat()) { - cond = static_cast(Evaluator().Eval(expr)); - } else { - cond = static_cast(Evaluator().Eval(expr)); - } - ppCondStack_.push({Token::PP_IF, NeedExpand(), cond}); -} - - -void Preprocessor::ParseIfdef(TokenSequence ls) { - if (!NeedExpand()) { - ppCondStack_.push({Token::PP_IFDEF, false, false}); - return; - } - - ls.Next(); - auto ident = ls.Expect(Token::IDENTIFIER); - if (!ls.Empty()) { - Error(ls.Peek(), "expect new line"); - } - - auto cond = FindMacro(ident->str_) != nullptr; - ppCondStack_.push({Token::PP_IFDEF, NeedExpand(), cond}); -} - - -void Preprocessor::ParseIfndef(TokenSequence ls) { - ParseIfdef(ls); - auto top = ppCondStack_.top(); - ppCondStack_.pop(); - top.tag_ = Token::PP_IFNDEF; - top.cond_ = !top.cond_; - - ppCondStack_.push(top); -} - - -void Preprocessor::ParseElif(TokenSequence ls) { - auto directive = ls.Next(); // Skip the directive - - if (ppCondStack_.empty()) - Error(directive, "unexpected 'elif' directive"); - auto top = ppCondStack_.top(); - if (top.tag_ == Token::PP_ELSE) - Error(directive, "unexpected 'elif' directive"); - - while (!ppCondStack_.empty()) { - top = ppCondStack_.top(); - if (top.tag_ == Token::PP_IF || - top.tag_ == Token::PP_IFDEF || - top.tag_ == Token::PP_IFNDEF || - top.cond_) { - break; - } - ppCondStack_.pop(); - } - if (ppCondStack_.empty()) - Error(directive, "unexpected 'elif' directive"); - auto enabled = top.enabled_; - if (!enabled) { - ppCondStack_.push({Token::PP_ELIF, false, false}); - return; - } - - if (ls.Empty()) { - Error(ls.Peek(), "expect expression in 'elif' directive"); - } - - TokenSequence ts; - Expand(ts, ls, true); - ReplaceIdent(ts); - - Parser parser(ts); - auto expr = parser.ParseExpr(); - if (!parser.ts().Empty()) { - Error(parser.ts().Peek(), "unexpected extra expression"); - } - bool cond; - if (expr->Type()->IsFloat()) { - std::cout << Evaluator().Eval(expr) << std::endl; - cond = static_cast(Evaluator().Eval(expr)); - } else { - cond = static_cast(Evaluator().Eval(expr)); - } - cond = cond && !top.cond_; - ppCondStack_.push({Token::PP_ELIF, true, cond}); -} - - -void Preprocessor::ParseElse(TokenSequence ls) { - auto directive = ls.Next(); - if (!ls.Empty()) - Error(ls.Peek(), "expect new line"); - - if (ppCondStack_.empty()) - Error(directive, "unexpected 'else' directive"); - auto top = ppCondStack_.top(); - if (top.tag_ == Token::PP_ELSE) - Error(directive, "unexpected 'else' directive"); - - while (!ppCondStack_.empty()) { - top = ppCondStack_.top(); - if (top.tag_ == Token::PP_IF || - top.tag_ == Token::PP_IFDEF || - top.tag_ == Token::PP_IFNDEF || - top.cond_) { - break; - } - ppCondStack_.pop(); - } - if (ppCondStack_.empty()) - Error(directive, "unexpected 'else' directive"); - - auto cond = !top.cond_; - auto enabled = top.enabled_; - ppCondStack_.push({Token::PP_ELSE, enabled, cond}); -} - - -void Preprocessor::ParseEndif(TokenSequence ls) { - auto directive = ls.Next(); - if (!ls.Empty()) - Error(ls.Peek(), "expect new line"); - - while ( !ppCondStack_.empty()) { - auto top = ppCondStack_.top(); - ppCondStack_.pop(); - - if (top.tag_ == Token::PP_IF - || top.tag_ == Token::PP_IFDEF - || top.tag_ == Token::PP_IFNDEF) { - return; - } - } - - if (ppCondStack_.empty()) - Error(directive, "unexpected 'endif' directive"); -} - - -// Have Read the '#' -void Preprocessor::ParseInclude(TokenSequence& is, TokenSequence ls) { - bool next = ls.Next()->str_ == "include_next"; // Skip 'include' - if (!ls.Test(Token::LITERAL) && !ls.Test('<')) { - TokenSequence ts; - Expand(ts, ls, true); - ls = ts; - } - - auto tok = ls.Next(); - if (tok->tag_ == Token::LITERAL) { - if (!ls.Empty()) { - Error(ls.Peek(), "expect new line"); - } - std::string filename; - Scanner(tok).ScanLiteral(filename); - auto fullPath = SearchFile(filename, false, next, *tok->loc_.filename_); - if (fullPath == nullptr) - Error(tok, "%s: No such file or directory", filename.c_str()); - - IncludeFile(is, fullPath); - } else if (tok->tag_ == '<') { - auto lhs = tok; - auto rhs = tok; - int cnt = 1; - while (!(rhs = ls.Next())->IsEOF()) { - if (rhs->tag_ == '<') - ++cnt; - else if (rhs->tag_ == '>') - --cnt; - if (cnt == 0) - break; - } - if (cnt != 0) - Error(rhs, "expect '>'"); - if (!ls.Empty()) - Error(ls.Peek(), "expect new line"); - - const auto& filename = Scanner::ScanHeadName(lhs, rhs); - auto fullPath = SearchFile(filename, true, next, *tok->loc_.filename_); - if (fullPath == nullptr) { - Error(tok, "%s: No such file or directory", filename.c_str()); - } - IncludeFile(is, fullPath); - } else { - Error(tok, "expect filename(string or in '<>')"); - } -} - - -void Preprocessor::ParseUndef(TokenSequence ls) { - ls.Next(); // Skip directive - - auto ident = ls.Expect(Token::IDENTIFIER); - if (!ls.Empty()) - Error(ls.Peek(), "expect new line"); - - RemoveMacro(ident->str_); -} - - -void Preprocessor::ParseDef(TokenSequence ls) { - ls.Next(); - auto ident = ls.Expect(Token::IDENTIFIER); - if (ident->str_ == "defined") { - Error(ident, "'defined' cannot be used as a macro name"); - } - auto tok = ls.Peek(); - if (tok->tag_ == '(' && !tok->ws_) { - // There is no white space between ident and '(' - // Hence, we are defining function-like macro - - // Parse Identifier list - ls.Next(); // Skip '(' - ParamList params; - auto variadic = ParseIdentList(params, ls); - const auto& macro = Macro(variadic, params, ls); - AddMacro(ident->str_, macro); - } else { - AddMacro(ident->str_, Macro(ls)); - } -} - - -bool Preprocessor::ParseIdentList(ParamList& params, TokenSequence& is) { - const Token* tok = is.Peek(); - while (!is.Empty()) { - tok = is.Next(); - if (tok->tag_ == ')') { - return false; - } else if (tok->tag_ == Token::ELLIPSIS) { - is.Expect(')'); - return true; - } else if (tok->tag_ != Token::IDENTIFIER) { - Error(tok, "expect identifier"); - } - - for (const auto& param: params) { - if (param == tok->str_) - Error(tok, "duplicated param"); - } - params.push_back(tok->str_); - - if (!is.Try(',')) { - is.Expect(')'); - return false; - } - } - - Error(tok, "unexpected end of line"); -} - -void Preprocessor::IncludeSrc(TokenSequence& is, - const std::string* text, - const std::string* filename) { - TokenSequence ts {is.tokList_, is.begin_, is.begin_}; - Scanner scanner(text, filename); - scanner.Tokenize(ts); - - // We done including header file - is.begin_ = ts.begin_; -} - -void Preprocessor::IncludeFile(TokenSequence& is, - const std::string* filename) { - IncludeSrc(is, ReadFile(*filename), filename); -} - - -static std::string GetDir(const std::string& path) { - auto pos = path.rfind('/'); - if (pos == std::string::npos) - return "./"; - return path.substr(0, pos + 1); -} - - -std::string* Preprocessor::SearchFile(const std::string& name, - const bool libHeader, - bool next, - const std::string& curPath) { - if (libHeader && !next) { - searchPaths_.push_back(GetDir(curPath)); - } else { - searchPaths_.push_front(GetDir(curPath)); - } - - auto iter = searchPaths_.begin(); - for (; iter != searchPaths_.end(); ++iter) { - auto dd = open(iter->c_str(), O_RDONLY); - if (dd == -1) // TODO(wgtdkp): or ensure it before preprocessing - continue; - auto fd = openat(dd, name.c_str(), O_RDONLY); - close(dd); - if (fd != -1) { - // Intentional, so that recursive include - // will result in running out of file descriptor - //close(fd); - auto path = *iter + name; - if (next) { - if (path != curPath) - continue; - else - next = false; - } else { - if (path == curPath) - continue; - if (libHeader && !next) - searchPaths_.pop_back(); - else - searchPaths_.pop_front(); - return new std::string(path); - } - } else if (errno == EMFILE) { - Error("may recursive include"); - } - } - return nullptr; -} - - -void Preprocessor::AddMacro(const std::string& name, - std::string* text, - bool preDef) { - TokenSequence ts; - Scanner scanner(text); - scanner.Tokenize(ts); - Macro macro(ts, preDef); - - AddMacro(name, macro); -} - - -static std::string* Date() { - time_t t = time(NULL); - struct tm* tm = localtime(&t); - char buf[14]; - strftime(buf, sizeof buf, "\"%a %M %Y\"", tm); - return new std::string(buf); -} - - -void Preprocessor::Init() { - // Preinclude search paths - AddSearchPath("/usr/local/include/"); - AddSearchPath("/usr/include/x86_64-linux-gnu/"); - AddSearchPath("/usr/include/linux/"); - AddSearchPath("/usr/include/"); - AddSearchPath("/usr/local/include/"); - - // The __FILE__ and __LINE__ macro is empty - // They are handled seperately - AddMacro("__FILE__", Macro(TokenSequence(), true)); - AddMacro("__LINE__", Macro(TokenSequence(), true)); - - AddMacro("__DATE__", Date(), true); - AddMacro("__STDC__", new std::string("1"), true); - AddMacro("__STDC__HOSTED__", new std::string("0"), true); - AddMacro("__STDC_VERSION__", new std::string("201103L"), true); -} - - -void Preprocessor::HandleTheFileMacro(TokenSequence& os, const Token* macro) { - auto file = Token::New(*macro); - file->tag_ = Token::LITERAL; - file->str_ = "\"" + *macro->loc_.filename_ + "\""; - os.InsertBack(file); -} - - -void Preprocessor::HandleTheLineMacro(TokenSequence& os, const Token* macro) { - auto line = Token::New(*macro); - line->tag_ = Token::I_CONSTANT; - line->str_ = std::to_string(macro->loc_.line_); - os.InsertBack(line); -} - - -void Preprocessor::UpdateFirstTokenLine(TokenSequence ts) { - auto loc = ts.Peek()->loc_; - loc.line_ = curLine_ + loc.line_ - lineLine_ - 1; - ts.UpdateHeadLocation(loc); -} - - -TokenSequence Macro::RepSeq(const std::string* filename, unsigned line) { - // Update line - TokenList tl; - TokenSequence ret(&tl); - ret.Copy(repSeq_); - auto ts = ret; - while (!ts.Empty()) { - auto loc = ts.Peek()->loc_; - loc.filename_ = filename; - loc.line_ = line; - ts.UpdateHeadLocation(loc); - ts.Next(); - } - return ret; -} - - -void Preprocessor::AddSearchPath(std::string path) { - if (path.back() != '/') - path += "/"; - if (path[0] != '/') - path = "./" + path; - searchPaths_.push_front(path); -} diff --git a/lib/lang/encoding.cc b/lib/lang/encoding.cc deleted file mode 100644 index 931e4fc30..000000000 --- a/lib/lang/encoding.cc +++ /dev/null @@ -1,42 +0,0 @@ -#include "triton/lang/encoding.h" - -#include -#include -#include -#include - - -static void Append16LE(std::string& str, char16_t c) { - str.push_back(c & UCHAR_MAX); - str.push_back((c >> 8) & UCHAR_MAX); -} - - -static void Append32LE(std::string& str, char32_t c) { - Append16LE(str, c & USHRT_MAX); - Append16LE(str, (c >> 16) & USHRT_MAX); -} - - -void ConvertToUTF16(std::string& str) { - std::wstring_convert, char16_t> utf8_ucs2_cvt; - auto str16 = utf8_ucs2_cvt.from_bytes(str); - str.resize(0); - for (auto c16: str16) - Append16LE(str, c16); -} - - -void ConvertToUTF32(std::string& str) { - std::wstring_convert, char32_t> utf8_ucs4_cvt; - auto str32 = utf8_ucs4_cvt.from_bytes(str); - str.resize(0); - for (auto c32: str32) - Append32LE(str, c32); -} - - -void AppendUCN(std::string& str, int c) { - std::wstring_convert, char32_t> utf8_ucs4_cvt; - str += utf8_ucs4_cvt.to_bytes(static_cast(c)); -} diff --git a/lib/lang/error.cc b/lib/lang/error.cc deleted file mode 100644 index 3c2d8b339..000000000 --- a/lib/lang/error.cc +++ /dev/null @@ -1,91 +0,0 @@ -#include "triton/lang/error.h" - -#include "triton/lang/ast.h" -#include "triton/lang/token.h" - -#include -#include -#include -#include - - -#define ANSI_COLOR_RED "\x1b[31m" -#define ANSI_COLOR_GREEN "\x1b[32m" -#define ANSI_COLOR_YELLOW "\x1b[33m" -#define ANSI_COLOR_BLUE "\x1b[34m" -#define ANSI_COLOR_MAGENTA "\x1b[35m" -#define ANSI_COLOR_CYAN "\x1b[36m" -#define ANSI_COLOR_RESET "\x1b[0m" - - -void Error(const char* format, ...) { - fprintf(stderr, - ANSI_COLOR_RED "error: " ANSI_COLOR_RESET); - - va_list args; - va_start(args, format); - vfprintf(stderr, format, args); - va_end(args); - - fprintf(stderr, "\n"); - - exit(-1); -} - - -[[noreturn]] -static void VError(const SourceLocation& loc, - const char* format, - va_list args) { - const char* filename = nullptr; - if(loc.filename_) - filename = loc.filename_->c_str(); - fprintf(stderr, - "%s:%d:%d: " ANSI_COLOR_RED "error: " ANSI_COLOR_RESET, - filename, - loc.line_, - loc.column_); - vfprintf(stderr, format, args); - fprintf(stderr, "\n "); - - bool sawNoSpace = false; - int nspaces = 0; - for (auto p = loc.lineBegin_; *p != '\n' && *p != 0; p++) { - if (!sawNoSpace && (*p == ' ' || *p == '\t')) { - ++nspaces; - } else { - sawNoSpace = true; - fputc(*p, stderr); - } - } - - fprintf(stderr, "\n "); - for (unsigned i = 1; i + nspaces < loc.column_; ++i) - fputc(' ', stderr); - fprintf(stderr, "^\n"); - exit(-1); -} - - -void Error(const SourceLocation& loc, const char* format, ...) { - va_list args; - va_start(args, format); - VError(loc, format, args); - va_end(args); -} - - -void Error(const Token* tok, const char* format, ...) { - va_list args; - va_start(args, format); - VError(tok->loc_, format, args); - va_end(args); -} - - -void Error(const Expr* expr, const char* format, ...) { - va_list args; - va_start(args, format); - VError(expr->Tok()->loc_, format, args); - va_end(args); -} diff --git a/lib/lang/evaluator.cc b/lib/lang/evaluator.cc deleted file mode 100644 index 0123f4239..000000000 --- a/lib/lang/evaluator.cc +++ /dev/null @@ -1,206 +0,0 @@ -#include "triton/lang/evaluator.h" -#include "triton/lang/ast.h" -#include "triton/lang/token.h" - - -template -void Evaluator::VisitBinaryOp(BinaryOp* binary) { -#define L Evaluator().Eval(binary->lhs_) -#define R Evaluator().Eval(binary->rhs_) -#define LL Evaluator().Eval(binary->lhs_) -#define LR Evaluator().Eval(binary->rhs_) - - if (binary->Type()->ToPointer()) { - auto val = Evaluator().Eval(binary); - if (val.label_.size()) { - Error(binary, "expect constant integer expression"); - } - val_ = static_cast(val.offset_); - return; - } - - switch (binary->op_) { - case '+': val_ = L + R; break; - case '-': val_ = L - R; break; - case '*': val_ = L * R; break; - case '/': { - auto l = L, r = R; - if (r == 0) - Error(binary, "division by zero"); - val_ = l / r; - } break; - case '%': { - auto l = LL, r = LR; - if (r == 0) - Error(binary, "division by zero"); - val_ = l % r; - } break; - // Bitwise operators that do not accept float - case '|': val_ = LL | LR; break; - case '&': val_ = LL & LR; break; - case '^': val_ = LL ^ LR; break; - case Token::LEFT: val_ = LL << LR; break; - case Token::RIGHT: val_ = LL >> LR; break; - - case '<': val_ = L < R; break; - case '>': val_ = L > R; break; - case Token::LOGICAL_AND: val_ = L && R; break; - case Token::LOGICAL_OR: val_ = L || R; break; - case Token::EQ: val_ = L == R; break; - case Token::NE: val_ = L != R; break; - case Token::LE: val_ = L <= R; break; - case Token::GE: val_ = L >= R; break; - case '=': case ',': val_ = R; break; - case '.': { - auto addr = Evaluator().Eval(binary); - if (addr.label_.size()) - Error(binary, "expect constant expression"); - val_ = addr.offset_; - } - default: assert(false); - } - -#undef L -#undef R -#undef LL -#undef LR -} - - -template -void Evaluator::VisitUnaryOp(UnaryOp* unary) { -#define VAL Evaluator().Eval(unary->operand_) -#define LVAL Evaluator().Eval(unary->operand_) - - switch (unary->op_) { - case Token::PLUS: val_ = VAL; break; - case Token::MINUS: val_ = -VAL; break; - case '~': val_ = ~LVAL; break; - case '!': val_ = !VAL; break; - case Token::CAST: - if (unary->Type()->IsInteger()) - val_ = static_cast(VAL); - else - val_ = VAL; - break; - case Token::ADDR: { - auto addr = Evaluator().Eval(unary->operand_); - if (addr.label_.size()) - Error(unary, "expect constant expression"); - val_ = addr.offset_; - } break; - default: Error(unary, "expect constant expression"); - } - -#undef LVAL -#undef VAL -} - - -template -void Evaluator::VisitConditionalOp(ConditionalOp* condOp) { - bool cond; - auto condType = condOp->cond_->Type(); - if (condType->IsInteger()) { - auto val = Evaluator().Eval(condOp->cond_); - cond = val != 0; - } else if (condType->IsFloat()) { - auto val = Evaluator().Eval(condOp->cond_); - cond = val != 0.0; - } else if (condType->ToPointer()) { - auto val = Evaluator().Eval(condOp->cond_); - cond = val.label_.size() || val.offset_; - } else { - assert(false); - } - - if (cond) { - val_ = Evaluator().Eval(condOp->exprTrue_); - } else { - val_ = Evaluator().Eval(condOp->exprFalse_); - } -} - - -void Evaluator::VisitBinaryOp(BinaryOp* binary) { -#define LR Evaluator().Eval(binary->rhs_) -#define R Evaluator().Eval(binary->rhs_) - - auto l = Evaluator().Eval(binary->lhs_); - - int width = 1; - auto pointerType = binary->Type()->ToPointer(); - if (pointerType) - width = pointerType->Derived()->Width(); - - switch (binary->op_) { - case '+': - assert(pointerType); - addr_.label_ = l.label_; - addr_.offset_ = l.offset_ + LR * width; - break; - case '-': - assert(pointerType); - addr_.label_ = l.label_; - addr_.offset_ = l.offset_ + LR * width; - break; - case '.': { - addr_.label_ = l.label_; - auto type = binary->lhs_->Type()->ToStruct(); - auto offset = type->GetMember(binary->rhs_->tok_->str_)->Offset(); - addr_.offset_ = l.offset_ + offset; - break; - } - default: assert(false); - } -#undef LR -#undef R -} - - -void Evaluator::VisitUnaryOp(UnaryOp* unary) { - auto addr = Evaluator().Eval(unary->operand_); - - switch (unary->op_) { - case Token::CAST: - case Token::ADDR: - case Token::DEREF: - addr_ = addr; break; - default: assert(false); - } -} - - -void Evaluator::VisitConditionalOp(ConditionalOp* condOp) { - bool cond; - auto condType = condOp->cond_->Type(); - if (condType->IsInteger()) { - auto val = Evaluator().Eval(condOp->cond_); - cond = val != 0; - } else if (condType->IsFloat()) { - auto val = Evaluator().Eval(condOp->cond_); - cond = val != 0.0; - } else if (condType->ToPointer()) { - auto val = Evaluator().Eval(condOp->cond_); - cond = val.label_.size() || val.offset_; - } else { - assert(false); - } - - if (cond) { - addr_ = Evaluator().Eval(condOp->exprTrue_); - } else { - addr_ = Evaluator().Eval(condOp->exprFalse_); - } -} - - -void Evaluator::VisitConstant(Constant* cons) { - if (cons->Type()->IsInteger()) { - addr_ = {"", static_cast(cons->IVal())}; - } else if (cons->Type()->ToArray()) { - assert(false); - } else { - assert(false); - } -} diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc deleted file mode 100644 index 8f6ad617f..000000000 --- a/lib/lang/parser.cc +++ /dev/null @@ -1,2799 +0,0 @@ -#include "triton/lang/parser.h" - -#include "triton/lang/cpp.h" -#include "triton/lang/encoding.h" -#include "triton/lang/error.h" -#include "triton/lang/evaluator.h" -#include "triton/lang/scope.h" -#include "triton/lang/type.h" - -#include -#include -#include -#include - - -FuncType* Parser::vaStartType_ {nullptr}; -FuncType* Parser::vaArgType_ {nullptr}; - - -FuncDef* Parser::EnterFunc(Identifier* ident) { - curFunc_ = FuncDef::New(ident, LabelStmt::New()); - return curFunc_; -} - - -void Parser::ExitFunc() { - // Resolve 那些待定的jump; - // 如果有jump无法resolve,也就是有未定义的label,报错; - for (auto iter = unresolvedJumps_.begin(); - iter != unresolvedJumps_.end(); ++iter) { - auto label = iter->first; - auto labelStmt = FindLabel(label->str_); - if (labelStmt == nullptr) { - Error(label, "label '%s' used but not defined", - label->str_.c_str()); - } - - iter->second->SetLabel(labelStmt); - } - - unresolvedJumps_.clear(); //清空未定的 jump 动作 - curLabels_.clear(); //清空 label map - - curFunc_ = nullptr; -} - - -void Parser::EnterBlock(FuncType* funcType) { - curScope_ = new Scope(curScope_, S_BLOCK); - if (funcType) { - // Merge elements in param scope into current block scope - for (auto param: funcType->Params()) - curScope_->Insert(param); - } -} - - -void Parser::Parse() { - DefineBuiltins(); - ParseTranslationUnit(); -} - - -void Parser::ParseTranslationUnit() { - while (!ts_.Peek()->IsEOF()) { - if (ts_.Try(Token::STATIC_ASSERT)) { - ParseStaticAssert(); - continue; - } else if (ts_.Try(';')) { - continue; - } - - int storageSpec, funcSpec, align; - auto declType = ParseDeclSpec(&storageSpec, &funcSpec, &align); - auto declInfo = ParseDeclarator(declType); - - auto tok = declInfo.tok; - auto type = declInfo.type; - auto attrs = declInfo.attrs; - - if (tok == nullptr) { - ts_.Expect(';'); - continue; - } - - auto ident = ProcessDeclarator(tok, type, attrs, storageSpec, funcSpec, align); - type = ident->Type(); - - if (tok && type->ToFunc() && ts_.Try('{')) { // Function definition - unit_->Add(ParseFuncDef(ident)); - } else { // Declaration - auto decl = ParseInitDeclarator(ident); - if (decl) unit_->Add(decl); - - while (ts_.Try(',')) { - auto ident = ParseDirectDeclarator(declType, storageSpec, - funcSpec, align); - decl = ParseInitDeclarator(ident); - if (decl) unit_->Add(decl); - } - // GNU extension: function/type/variable attributes - TryAttributeSpecList(); - ts_.Expect(';'); - } - } -} - - -FuncDef* Parser::ParseFuncDef(Identifier* ident) { - auto funcDef = EnterFunc(ident); - if (funcDef->FuncType()->Complete()) { - Error(ident, "redefinition of '%s'", funcDef->Name().c_str()); - } - // TODO(wgtdkp): param checking - auto funcType = ident->Type()->ToFunc(); - funcType->SetComplete(true); - for (auto param: funcType->Params()) { - if (param->Anonymous()) - Error(param, "param name omitted"); - } - funcDef->SetBody(ParseCompoundStmt(funcType)); - ExitFunc(); - - return funcDef; -} - - -Expr* Parser::ParseExpr() { - return ParseCommaExpr(); -} - - -Expr* Parser::ParseCommaExpr() { - auto lhs = ParseAssignExpr(); - auto tok = ts_.Peek(); - while (ts_.Try(',')) { - auto rhs = ParseAssignExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Peek(); - } - return lhs; -} - - -Expr* Parser::ParsePrimaryExpr() { - if (ts_.Empty()) { - Error(ts_.Peek(), "premature end of input"); - } - - auto tok = ts_.Next(); - if (tok->tag_ == '(') { - auto expr = ParseExpr(); - ts_.Expect(')'); - return expr; - } - - if (tok->IsIdentifier()) { - auto ident = curScope_->Find(tok); - if (ident) return ident; - if (IsBuiltin(tok->str_)) return GetBuiltin(tok); - Error(tok, "undefined symbol '%s'", tok->str_.c_str()); - } else if (tok->IsConstant()) { - return ParseConstant(tok); - } else if (tok->IsLiteral()) { - return ConcatLiterals(tok); - } else if (tok->tag_ == Token::GENERIC) { - return ParseGeneric(); - } - - Error(tok, "'%s' unexpected", tok->str_.c_str()); - return nullptr; // Make compiler happy -} - - -static void ConvertLiteral(std::string& val, Encoding enc) { - switch (enc) { - case Encoding::NONE: - case Encoding::UTF8: break; - case Encoding::CHAR16: ConvertToUTF16(val); break; - case Encoding::CHAR32: - case Encoding::WCHAR: ConvertToUTF32(val); break; - } -} - - -Constant* Parser::ConcatLiterals(const Token* tok) { - auto val = new std::string; - auto enc = Scanner(tok).ScanLiteral(*val); - ConvertLiteral(*val, enc); - while (ts_.Test(Token::LITERAL)) { - auto nextTok = ts_.Next(); - std::string nextVal; - auto nextEnc = Scanner(nextTok).ScanLiteral(nextVal); - ConvertLiteral(nextVal, nextEnc); - if (enc == Encoding::NONE) { - ConvertLiteral(*val, nextEnc); - enc = nextEnc; - } - if (nextEnc != Encoding::NONE && nextEnc != enc) - Error(nextTok, "cannot concat lietrals with different encodings"); - *val += nextVal; - } - - int tag = T_CHAR; - switch (enc) { - case Encoding::NONE: - case Encoding::UTF8: - tag = T_CHAR; val->append(1, '\0'); break; - case Encoding::CHAR16: - tag = T_UNSIGNED | T_SHORT; val->append(2, '\0'); break; - case Encoding::CHAR32: - case Encoding::WCHAR: - tag = T_UNSIGNED | T_INT; val->append(4, '\0'); break; - } - - return Constant::New(tok, tag, val); -} - - -Encoding Parser::ParseLiteral(std::string& str, const Token* tok) { - return Scanner(tok).ScanLiteral(str); -} - - -Constant* Parser::ParseConstant(const Token* tok) { - assert(tok->IsConstant()); - if (tok->tag_ == Token::I_CONSTANT) { - return ParseInteger(tok); - } else if (tok->tag_ == Token::C_CONSTANT) { - return ParseCharacter(tok); - } else { - return ParseFloat(tok); - } -} - - -Constant* Parser::ParseFloat(const Token* tok) { - const auto& str = tok->str_; - size_t end = 0; - double val = 0.0; - try { - val = stod(str, &end); - } catch (const std::out_of_range& oor) { - Error(tok, "float out of range"); - } - - int tag = T_DOUBLE; - if (str[end] == 'f' || str[end] == 'F') { - tag = T_FLOAT; - ++end; - } else if (str[end] == 'l' || str[end] == 'L') { - tag = T_LONG | T_DOUBLE; - ++end; - } - if (str[end] != 0) - Error(tok, "invalid suffix"); - return Constant::New(tok, tag, val); -} - - -Constant* Parser::ParseCharacter(const Token* tok) { - int val; - auto enc = Scanner(tok).ScanCharacter(val); - - int tag; - switch (enc) { - case Encoding::NONE: - val = (char)val; - tag = T_INT; break; - case Encoding::CHAR16: - val = (char16_t)val; - tag = T_UNSIGNED | T_SHORT; break; - case Encoding::WCHAR: - case Encoding::CHAR32: tag = T_UNSIGNED | T_INT; break; - default: assert(false); - } - return Constant::New(tok, tag, static_cast(val)); -} - - -Constant* Parser::ParseInteger(const Token* tok) { - const auto& str = tok->str_; - size_t end = 0; - long val = 0; - try { - val = stoull(str, &end, 0); - } catch (const std::out_of_range& oor) { - Error(tok, "integer out of range"); - } - - int tag = 0; - for (; str[end]; ++end) { - if (str[end] == 'u' || str[end] == 'U') { - if (tag & T_UNSIGNED) - Error(tok, "invalid suffix"); - tag |= T_UNSIGNED; - } else { - if ((tag & T_LONG) || (tag & T_LLONG)) - Error(tok, "invalid suffix"); - if (str[end + 1] == 'l' || str[end + 1] =='L') { - tag |= T_LLONG; - ++end; - } else { - tag |= T_LONG; - } - } - } - - bool decimal = ('1' <= str[0] && str[0] <= '9'); - if (decimal) { - switch (tag) { - case 0: - tag |= !(val & ~(long)INT_MAX) ? T_INT: T_LONG; break; - case T_UNSIGNED: - tag |= !(val & ~(long)UINT_MAX) ? T_INT: T_LONG; break; - case T_LONG: break; - case T_UNSIGNED | T_LONG: break; - } - } else { - switch (tag) { - case 0: - tag |= !(val & ~(long)INT_MAX) ? T_INT - : !(val & ~(long)UINT_MAX) ? T_UNSIGNED - : !(val & ~(long)LONG_MAX) ? T_LONG - : T_UNSIGNED | T_LONG; break; - case T_UNSIGNED: - tag |= !(val & ~(long)UINT_MAX) ? T_INT: T_LONG; break; - case T_LONG: - tag |= !(val & ~(long)LONG_MAX) ? 0: T_UNSIGNED; break; - case T_UNSIGNED | T_LONG: - break; - } - } - - return Constant::New(tok, tag, val); -} - - -Expr* Parser::ParseGeneric() { - ts_.Expect('('); - auto controlExpr = ParseAssignExpr(); - ts_.Expect(','); - Expr* selectedExpr = nullptr; - bool isDefault = false; - while (true) { - if (ts_.Try(Token::DEFAULT)) { - ts_.Expect(':'); - auto defaultExpr = ParseAssignExpr(); - if (!selectedExpr) { - selectedExpr = defaultExpr; - isDefault = true; - } - } else { - auto tok = ts_.Peek(); - auto type = ParseTypeName(); - ts_.Expect(':'); - auto expr = ParseAssignExpr(); - if (type->Compatible(*controlExpr->Type())) { - if (selectedExpr && !isDefault) { - Error(tok, "more than one generic association" - " are compatible with control expression"); - } - selectedExpr = expr; - isDefault = false; - } - } - if (!ts_.Try(',')) { - ts_.Expect(')'); - break; - } - } - - if (!selectedExpr) - Error(ts_.Peek(), "no compatible generic association"); - return selectedExpr; -} - - -QualType Parser::TryCompoundLiteral() { - auto mark = ts_.Mark(); - if (ts_.Try('(') && IsTypeName(ts_.Peek())) { - auto type = ParseTypeName(); - if (ts_.Try(')') && ts_.Test('{')) - return type; - } - ts_.ResetTo(mark); - return nullptr; -} - - -Expr* Parser::ParsePostfixExpr() { - if (ts_.Peek()->IsEOF()) { - Error(ts_.Peek(), "premature end of input"); - } - - auto type = TryCompoundLiteral(); - if (type) { - auto anony = ParseCompoundLiteral(type); - return ParsePostfixExprTail(anony); - } - - Expr* primExpr; - //FIXME: merge into generic array functions - if(ts_.Try(Token::EXP)) - primExpr = ParseUnaryIntrinsicOp(Token::EXP); - else if(ts_.Try(Token::SQRTF)) - primExpr = ParseUnaryIntrinsicOp(Token::SQRTF); - else if(ts_.Try(Token::LOG)) - primExpr = ParseUnaryIntrinsicOp(Token::LOG); - else - primExpr = ParsePrimaryExpr(); - return ParsePostfixExprTail(primExpr); -} - - -Object* Parser::ParseCompoundLiteral(QualType type) { - auto linkage = curScope_->Type() == S_FILE ? L_INTERNAL: L_NONE; - auto anony = Object::NewAnony(ts_.Peek(), type, 0, linkage); - auto decl = ParseInitDeclaratorSub(anony); - - // Just for generator to find the compound literal - if (curScope_->Type() == S_FILE) { - unit_->Add(decl); - } else { - curScope_->Insert(anony->Repr(), anony); - } - return anony; -} - - -// Return the constructed postfix expression -Expr* Parser::ParsePostfixExprTail(Expr* lhs) { - while (true) { - auto tok = ts_.Next(); - - switch (tok->tag_) { - case '[': lhs = ParseSubScripting(lhs); break; - case '(': lhs = ParseFuncCall(lhs); break; - case Token::PTR: lhs = UnaryOp::New(Token::DEREF, lhs); - // Fall through - case '.': lhs = ParseMemberRef(tok, '.', lhs); break; - case Token::INC: - case Token::DEC: lhs = ParsePostfixIncDec(tok, lhs); break; - default: ts_.PutBack(); return lhs; - } - } -} - - -Expr* Parser::ParseSubScripting(Expr* lhs) { - auto lhsTile = lhs->Type()->ToTile(); - if(lhsTile == nullptr) - Error(lhs, "tile expected"); - TileType::ShapeInt lhsShape = lhsTile->Shape(); - QualType lhsQual = lhsTile->Derived(); - // create ret shape - TileType::ShapeInt shape; - TileType::ShapeInt axVec; - size_t i = 0; - const Token* tok; - std::vector> redInfo; - do { - tok = ts_.Next(); - switch(tok->tag_) { - case ':': - shape.push_back(lhsShape[i++]); - break; - case Token::NEWAXIS: - shape.push_back(1); - break; - case Token::ADD: - case Token::SUB: - case Token::MAX: - case Token::MIN:{ - int info = UnaryOp::encodeRed(i, tok->tag_); - redInfo.push_back({i, info}); - shape.push_back(lhsShape[i++]); - break; - } - case '^':{ - Expr* expr = ParseConditionalExpr(); - EnsureInteger(expr); - int ax = Evaluator().Eval(expr); - axVec.push_back(ax); - if(ax < 0 || ax >= lhsShape.size()) - Error(tok, "unknown axis %d in transposition", ax); - shape.push_back(lhsShape[ax]); - i++; - break; - } - - default: - Error(tok, "Unexpected subscript symbol encountered at dimension %d", i); - break; - } - }while(ts_.Try(',')); - ts_.Expect(']'); - - // transposition mode - std::set axSet(axVec.begin(), axVec.end()); - if(!axSet.empty()){ - if(axSet.size()!=lhsShape.size()) - Error(tok, "transposition must address all axes of input array"); - return TransOp::New(axVec, lhs); - } - - // broadcasting mode - if(lhsShape.size() > i) - Error(tok, "broadcasting not using all operand axes"); - - // create ret tile - Expr* res = lhs; - for(auto r: redInfo){ - shape.erase(shape.begin() + r.first); - Type *retType; - if(shape.empty()) - retType = lhsQual.GetPtr(); - else - retType = TileType::New(shape, lhsQual); - res = UnaryOp::New(Token::REDUCE, res, retType, r.second); - } - if(!shape.empty()){ - TileType *retType = TileType::New(shape, lhsQual); - res = UnaryOp::New(Token::CAST, res, retType); - } - return res; -} - - -BinaryOp* Parser::ParseMemberRef(const Token* tok, int op, Expr* lhs) { - auto memberName = ts_.Peek()->str_; - ts_.Expect(Token::IDENTIFIER); - - auto structUnionType = lhs->Type()->ToStruct(); - if (structUnionType == nullptr) { - Error(tok, "an struct/union expected"); - } - - auto rhs = structUnionType->GetMember(memberName); - if (rhs == nullptr) { - Error(tok, "'%s' is not a member of '%s'", - memberName.c_str(), "[obj]"); - } - - return BinaryOp::New(tok, op, lhs, rhs); -} - - -UnaryOp* Parser::ParsePostfixIncDec(const Token* tok, Expr* operand) { - auto op = tok->tag_ == Token::INC ? - Token::POSTFIX_INC: Token::POSTFIX_DEC; - return UnaryOp::New(op, operand); -} - - -FuncCall* Parser::ParseFuncCall(Expr* designator) { - FuncCall::ArgList args; - while (!ts_.Try(')')) { - args.push_back(Expr::MayCast(ParseAssignExpr())); - if (!ts_.Test(')')) - ts_.Expect(','); - } - return FuncCall::New(designator, args); -} - - -Expr* Parser::ParseUnaryExpr() { - auto tok = ts_.Next(); - switch (tok->tag_) { - case Token::ALIGNOF: return ParseAlignof(); - case Token::SIZEOF: return ParseSizeof(); - case Token::INC: return ParsePrefixIncDec(tok); - case Token::DEC: return ParsePrefixIncDec(tok); - case '&': return ParseUnaryOp(tok, Token::ADDR); - case '*': return ParseDerefOp(tok); - case '+': return ParseUnaryOp(tok, Token::PLUS); - case '-': return ParseUnaryOp(tok, Token::MINUS); - case '~': return ParseUnaryOp(tok, '~'); - case '!': return ParseUnaryOp(tok, '!'); - case '^': { - auto operand = ParseCastExpr(); - TileType::ShapeInt shape = operand->Type()->ToTile()->Shape(); - TransOp::PermInt perm(shape.size()); - for(int d = 0; d < shape.size(); d++) - perm[d] = d; - std::rotate(perm.begin(), perm.begin() + 1, perm.end()); - return TransOp::New(perm, operand); - } - default: - ts_.PutBack(); - return ParsePostfixExpr(); - } -} - - -Constant* Parser::ParseSizeof() { - QualType type(nullptr); - auto tok = ts_.Next(); - if (tok->tag_ == '(' && IsTypeName(ts_.Peek())) { - type = ParseTypeName(); - ts_.Expect(')'); - } else { - ts_.PutBack(); - auto expr = ParseUnaryExpr(); - type = expr->Type(); - } - - if (type->ToFunc() || type->ToVoid()) { - } else if (!type->Complete()) { - Error(tok, "sizeof(incomplete type)"); - } - long val = type->Width(); - return Constant::New(tok, T_UNSIGNED | T_LONG, val); -} - - -Constant* Parser::ParseAlignof() { - ts_.Expect('('); - auto tok = ts_.Peek(); - auto type = ParseTypeName(); - ts_.Expect(')'); - - long val = type->Align(); - return Constant::New(tok, T_UNSIGNED| T_LONG, val); -} - - -UnaryOp* Parser::ParsePrefixIncDec(const Token* tok) { - assert(tok->tag_ == Token::INC || tok->tag_ == Token::DEC); - - auto op = tok->tag_ == Token::INC ? - Token::PREFIX_INC: Token::PREFIX_DEC; - auto operand = ParseUnaryExpr(); - return UnaryOp::New(op, operand); -} - -UnaryOp* Parser::ParseUnaryIntrinsicOp(int op) { - ts_.Expect('('); - auto operand = ParseExpr(); - ts_.Expect(')'); - auto ret = UnaryOp::New(op, operand); - return ret; -} - -UnaryOp* Parser::ParseUnaryOp(const Token* tok, int op) { - auto operand = ParseCastExpr(); - return UnaryOp::New(op, operand); -} - -Expr* Parser::ParseDerefOp(const Token* tok) { - Expr* pred = nullptr; - if(ts_.Try('?')){ - ts_.Expect('('); - pred = ParseExpr(); - ts_.Expect(')'); - } - Expr* addr = ParseCastExpr(); - if(pred) - return BinaryOp::New(tok, Token::MASKED_DEREF, pred, addr); - else - return UnaryOp::New(Token::DEREF, addr); -} - -QualType Parser::ParseTypeName() { - auto type = ParseSpecQual(); - if (ts_.Test('*') || ts_.Test('(') || ts_.Test('[')) // abstract-declarator FIRST set - return ParseAbstractDeclarator(type); - return type; -} - - -Expr* Parser::ParseCastExpr() { - auto tok = ts_.Next(); - // bitcast - if (tok->tag_ == Token::BITCAST) { - ts_.Expect('<'); - auto type = ParseTypeName(); - ts_.Expect('>'); - ts_.Expect('('); - auto operand = ParseExpr(); - ts_.Expect(')'); - return UnaryOp::New(Token::BITCAST, operand, type); - } - // semantic cast - if (tok->tag_ == '(' && IsTypeName(ts_.Peek())) { - auto type = ParseTypeName(); - ts_.Expect(')'); - if (ts_.Test('{')) { - auto anony = ParseCompoundLiteral(type); - return ParsePostfixExprTail(anony); - } - auto operand = ParseCastExpr(); - return UnaryOp::New(Token::CAST, operand, type); - } - - ts_.PutBack(); - return ParseUnaryExpr(); -} - -Expr* Parser::ParseRangeExpr() { - auto lhs = ParseCastExpr(); - auto tok = ts_.Next(); - while (tok->tag_ == Token::ELLIPSIS) { - auto rhs = ParseCastExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - tok = ts_.Next(); - } - ts_.PutBack(); - return lhs; -} - -Expr* Parser::ParseMatmulExpr() { - auto lhs = ParseRangeExpr(); - auto tok = ts_.Next(); - while (tok->tag_ == Token::MATMUL) { - auto rhs = ParseRangeExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - tok = ts_.Next(); - } - ts_.PutBack(); - return lhs; -} - -Expr* Parser::ParseMultiplicativeExpr() { - auto lhs = ParseMatmulExpr(); - auto tok = ts_.Next(); - while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') { - auto rhs = ParseMatmulExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - tok = ts_.Next(); - } - ts_.PutBack(); - return lhs; -} - - -Expr* Parser::ParseAdditiveExpr() { - auto lhs = ParseMultiplicativeExpr(); - auto tok = ts_.Next(); - while (tok->tag_ == '+' || tok->tag_ == '-') { - auto rhs = ParseMultiplicativeExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Next(); - } - - ts_.PutBack(); - return lhs; -} - - -Expr* Parser::ParseShiftExpr() { - auto lhs = ParseAdditiveExpr(); - auto tok = ts_.Next(); - while (tok->tag_ == Token::LEFT || tok->tag_ == Token::RIGHT) { - auto rhs = ParseAdditiveExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Next(); - } - - ts_.PutBack(); - return lhs; -} - - -Expr* Parser::ParseRelationalExpr() { - auto lhs = ParseShiftExpr(); - auto tok = ts_.Next(); - while (tok->tag_ == Token::LE || tok->tag_ == Token::GE - || tok->tag_ == '<' || tok->tag_ == '>') { - auto rhs = ParseShiftExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Next(); - } - - ts_.PutBack(); - return lhs; -} - - -Expr* Parser::ParseEqualityExpr() { - auto lhs = ParseRelationalExpr(); - auto tok = ts_.Next(); - while (tok->tag_ == Token::EQ || tok->tag_ == Token::NE) { - auto rhs = ParseRelationalExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Next(); - } - - ts_.PutBack(); - return lhs; -} - - -Expr* Parser::ParseBitiwiseAndExpr() { - auto lhs = ParseEqualityExpr(); - auto tok = ts_.Peek(); - while (ts_.Try('&')) { - auto rhs = ParseEqualityExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Peek(); - } - - return lhs; -} - - -Expr* Parser::ParseBitwiseXorExpr() { - auto lhs = ParseBitiwiseAndExpr(); - auto tok = ts_.Peek(); - while (ts_.Try('^')) { - auto rhs = ParseBitiwiseAndExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Peek(); - } - - return lhs; -} - - -Expr* Parser::ParseBitwiseOrExpr() { - auto lhs = ParseBitwiseXorExpr(); - auto tok = ts_.Peek(); - while (ts_.Try('|')) { - auto rhs = ParseBitwiseXorExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Peek(); - } - - return lhs; -} - - -Expr* Parser::ParseLogicalAndExpr() { - auto lhs = ParseBitwiseOrExpr(); - auto tok = ts_.Peek(); - while (ts_.Try(Token::LOGICAL_AND)) { - auto rhs = ParseBitwiseOrExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Peek(); - } - - return lhs; -} - - -Expr* Parser::ParseLogicalOrExpr() { - auto lhs = ParseLogicalAndExpr(); - auto tok = ts_.Peek(); - while (ts_.Try(Token::LOGICAL_OR)) { - auto rhs = ParseLogicalAndExpr(); - lhs = BinaryOp::New(tok, lhs, rhs); - - tok = ts_.Peek(); - } - - return lhs; -} - - -Expr* Parser::ParseConditionalExpr() { - auto cond = ParseLogicalOrExpr(); - auto tok = ts_.Peek(); - if (ts_.Try('?')) { - // Non-standard GNU extension - // a ?: b equals a ? a: c - auto exprTrue = ts_.Test(':') ? cond: ParseExpr(); - ts_.Expect(':'); - auto exprFalse = ParseConditionalExpr(); - - return ConditionalOp::New(tok, cond, exprTrue, exprFalse); - } - return cond; -} - - -Expr* Parser::ParseAssignExpr() { - // Yes, I know the lhs should be unary expression, - // let it handled by type checking - Expr* lhs = ParseConditionalExpr(); - Expr* rhs; - - auto tok = ts_.Next(); - switch (tok->tag_) { - case Token::MUL_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '*', lhs, rhs); - break; - - case Token::DIV_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '/', lhs, rhs); - break; - - case Token::MOD_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '%', lhs, rhs); - break; - - case Token::ADD_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '+', lhs, rhs); - break; - - case Token::SUB_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '-', lhs, rhs); - break; - - case Token::LEFT_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, Token::LEFT, lhs, rhs); - break; - - case Token::RIGHT_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, Token::RIGHT, lhs, rhs); - break; - - case Token::AND_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '&', lhs, rhs); - break; - - case Token::XOR_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '^', lhs, rhs); - break; - - case Token::OR_ASSIGN: - rhs = ParseAssignExpr(); - rhs = BinaryOp::New(tok, '|', lhs, rhs); - break; - - case '=': - rhs = ParseAssignExpr(); - break; - - default: - ts_.PutBack(); - return lhs; // Could be constant - } - - return BinaryOp::New(tok, '=', lhs, rhs); -} - - -void Parser::ParseStaticAssert() { - ts_.Expect('('); - auto condExpr = ParseAssignExpr(); - ts_.Expect(','); - auto msg = ConcatLiterals(ts_.Expect(Token::LITERAL)); - ts_.Expect(')'); - ts_.Expect(';'); - if (!Evaluator().Eval(condExpr)) { - Error(ts_.Peek(), "static assertion failed: %s\n", - msg->SVal()->c_str()); - } -} - - -// Return: list of declarations -CompoundStmt* Parser::ParseDecl() { - StmtList stmts; - if (ts_.Try(Token::STATIC_ASSERT)) { - ParseStaticAssert(); - } else { - int storageSpec, funcSpec, align; - auto type = ParseDeclSpec(&storageSpec, &funcSpec, &align); - if (!ts_.Test(';')) { - do { - auto ident = ParseDirectDeclarator(type, storageSpec, funcSpec, align); - auto init = ParseInitDeclarator(ident); - if (init) stmts.push_back(init); - } while (ts_.Try(',')); - } - ts_.Expect(';'); - } - - return CompoundStmt::New(stmts); -} - - -// For state machine -enum { - // Compatibility for these key words - COMP_SIGNED = T_SHORT | T_INT | T_LONG | T_LLONG, - COMP_UNSIGNED = T_SHORT | T_INT | T_LONG | T_LLONG, - COMP_CHAR = T_SIGNED | T_UNSIGNED, - COMP_SHORT = T_SIGNED | T_UNSIGNED | T_INT, - COMP_INT = T_SIGNED | T_UNSIGNED | T_LONG | T_SHORT | T_LLONG, - COMP_LONG = T_SIGNED | T_UNSIGNED | T_LONG | T_INT, - COMP_DOUBLE = T_LONG | T_COMPLEX, - COMP_COMPLEX = T_FLOAT | T_DOUBLE | T_LONG, - - COMP_THREAD = S_EXTERN | S_STATIC, -}; - - -static inline void TypeLL(int& typeSpec) { - if (typeSpec & T_LONG) { - typeSpec &= ~T_LONG; - typeSpec |= T_LLONG; - } else { - typeSpec |= T_LONG; - } -} - - -QualType Parser::ParseSpecQual() { - return ParseDeclSpec(nullptr, nullptr, nullptr); -} - - -static void EnsureAndSetStorageSpec(const Token* tok, int* storage, int spec) { - if (!storage) - Error(tok, "unexpected storage specifier"); - if (*storage != 0) - Error(tok, "duplicated storage specifier"); - *storage |= spec; -} - - -/* - * param: storage: null, only type specifier and qualifier accepted; - */ -QualType Parser::ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec) { -#define ERR_FUNC_SPEC ("unexpected function specifier") -#define ERR_STOR_SPEC ("unexpected storage specifier") -#define ERR_DECL_SPEC ("two or more data types in declaration specifiers") - - QualType type(nullptr); - int qualSpec = 0; - int typeSpec = 0; - - if (storageSpec) *storageSpec = 0; - if (funcSpec) *funcSpec = 0; - if (alignSpec) *alignSpec = 0; - - const Token* tok; - for (; ;) { - tok = ts_.Next(); - switch (tok->tag_) { - // Function specifier - case Token::INLINE: - if (!funcSpec) - Error(tok, ERR_FUNC_SPEC); - *funcSpec |= F_INLINE; - break; - - case Token::NORETURN: - if (!funcSpec) - Error(tok, ERR_FUNC_SPEC); - *funcSpec |= F_NORETURN; - break; - - // Alignment specifier - case Token::ALIGNAS: { - if (!alignSpec) - Error(tok, "unexpected alignment specifier"); - auto align = ParseAlignas(); - if (align) - *alignSpec = align; - break; - } - // Storage specifier - // TODO(wgtdkp): typedef needs more constraints - case Token::TYPEDEF: - EnsureAndSetStorageSpec(tok, storageSpec, S_TYPEDEF); - break; - - case Token::EXTERN: - EnsureAndSetStorageSpec(tok, storageSpec, S_EXTERN); - break; - - case Token::GLOBAL: - EnsureAndSetStorageSpec(tok, storageSpec, S_GLOBAL); - break; - - case Token::STATIC: - if (!storageSpec) - Error(tok, ERR_FUNC_SPEC); - if (*storageSpec & ~S_THREAD) - Error(tok, "duplicated storage specifier"); - *storageSpec |= S_STATIC; - break; - - case Token::THREAD: - if (!storageSpec) - Error(tok, ERR_FUNC_SPEC); - if (*storageSpec & ~COMP_THREAD) - Error(tok, "duplicated storage specifier"); - *storageSpec |= S_THREAD; - break; - - - // Type qualifier - case Token::CONST: qualSpec |= Qualifier::CONST; break; - case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break; - case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break; - case Token::CMEM: qualSpec |= Qualifier::CMEM; break; - - // Type specifier - case Token::SIGNED: - if (typeSpec & ~COMP_SIGNED) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_SIGNED; - break; - - case Token::UNSIGNED: - if (typeSpec & ~COMP_UNSIGNED) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_UNSIGNED; - break; - - case Token::VOID: - if (typeSpec & ~0) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_VOID; - break; - - case Token::CHAR: - if (typeSpec & ~COMP_CHAR) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_CHAR; - break; - - case Token::SHORT: - if (typeSpec & ~COMP_SHORT) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_SHORT; - break; - - case Token::INT: - if (typeSpec & ~COMP_INT) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_INT; - break; - - case Token::LONG: - if (typeSpec & ~COMP_LONG) - Error(tok, ERR_DECL_SPEC); - TypeLL(typeSpec); - break; - - case Token::HALF: - if(typeSpec & ~T_COMPLEX) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_HALF; - break; - - case Token::FLOAT: - if (typeSpec & ~T_COMPLEX) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_FLOAT; - break; - - case Token::DOUBLE: - if (typeSpec & ~COMP_DOUBLE) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_DOUBLE; - break; - - case Token::BOOL: - if (typeSpec != 0) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_BOOL; - break; - - case Token::COMPLEX: - if (typeSpec & ~COMP_COMPLEX) - Error(tok, ERR_DECL_SPEC); - typeSpec |= T_COMPLEX; - break; - - case Token::STRUCT: - case Token::UNION: - if (typeSpec & ~0) - Error(tok, ERR_DECL_SPEC); - type = ParseStructUnionSpec(Token::STRUCT == tok->tag_); - typeSpec |= T_STRUCT_UNION; - break; - - case Token::ENUM: - if (typeSpec != 0) - Error(tok, ERR_DECL_SPEC); - type = ParseEnumSpec(); - typeSpec |= T_ENUM; - break; - - case Token::ATOMIC: - Error(tok, "atomic not supported"); - break; - - default: - if (typeSpec == 0 && IsTypeName(tok)) { - auto ident = curScope_->Find(tok); - type = ident->Type(); - // We may change the length of a array type by initializer, - // thus, make a copy of this type. - auto arrType = type->ToArray(); - if (arrType && !type->Complete()) - type = ArrayType::New(arrType->Len(), arrType->Derived()); - typeSpec |= T_TYPEDEF_NAME; - } else { - goto end_of_loop; - } - } - } - -end_of_loop: - ts_.PutBack(); - switch (typeSpec) { - case 0: - Error(tok, "expect type specifier"); - break; - - case T_VOID: - type = VoidType::New(); - break; - - case T_STRUCT_UNION: - case T_ENUM: - case T_TYPEDEF_NAME: - break; - - default: - type = ArithmType::New(typeSpec); - break; - } - // GNU extension: type attributes - //if (storageSpec && (*storageSpec & S_TYPEDEF)) - // TryAttributeSpecList(); - - return QualType(type.GetPtr(), qualSpec | type.Qual()); - -#undef ERR_FUNC_SPEC -#undef ERR_STOR_SPEC -#undef ERR_DECL_SPEC -} - - -int Parser::ParseAlignas() { - int align; - ts_.Expect('('); - auto tok = ts_.Peek(); - if (IsTypeName(ts_.Peek())) { - auto type = ParseTypeName(); - ts_.Expect(')'); - align = type->Align(); - } else { - auto expr = ParseExpr(); - align = Evaluator().Eval(expr); - ts_.Expect(')'); - } - if (align < 0 || ((align - 1) & align)) - Error(tok, "requested alignment is not a positive power of 2"); - return align; -} - - -Type* Parser::ParseEnumSpec() { - // GNU extension: type attributes - TryAttributeSpecList(); - - std::string tagName; - auto tok = ts_.Peek(); - if (ts_.Try(Token::IDENTIFIER)) { - tagName = tok->str_; - if (ts_.Try('{')) { - // 定义enum类型 - auto tagIdent = curScope_->FindTagInCurScope(tok); - if (!tagIdent) { - auto type = ArithmType::New(T_INT); - auto ident = Identifier::New(tok, type, L_NONE); - curScope_->InsertTag(ident); - return ParseEnumerator(type); // 处理反大括号: '}' - } - - if (!tagIdent->Type()->IsInteger()) // struct/union tag - Error(tok, "redefinition of enumeration tag '%s'", tagName.c_str()); - return ParseEnumerator(tagIdent->Type()->ToArithm()); - } else { - auto tagIdent = curScope_->FindTag(tok); - if (tagIdent) { - return tagIdent->Type(); - } - auto type = ArithmType::New(T_INT); - auto ident = Identifier::New(tok, type, L_NONE); - curScope_->InsertTag(ident); - return type; - } - } - - ts_.Expect('{'); - auto type = ArithmType::New(T_INT); - return ParseEnumerator(type); // 处理反大括号: '}' -} - - -Type* Parser::ParseEnumerator(ArithmType* type) { - assert(type && type->IsInteger()); - int val = 0; - do { - auto tok = ts_.Expect(Token::IDENTIFIER); - // GNU extension: enumerator attributes - TryAttributeSpecList(); - - const auto& enumName = tok->str_; - auto ident = curScope_->FindInCurScope(tok); - if (ident) { - Error(tok, "redefinition of enumerator '%s'", enumName.c_str()); - } - if (ts_.Try('=')) { - auto expr = ParseAssignExpr(); - val = Evaluator().Eval(expr); - } - auto enumer = Enumerator::New(tok, val); - ++val; - curScope_->Insert(enumer); - ts_.Try(','); - } while (!ts_.Try('}')); - - type->SetComplete(true); - return type; -} - - -/* - * 四种 name space: - * 1.label, 如 goto end; 它有函数作用域 - * 2.struct/union/enum 的 tag - * 3.struct/union 的成员 - * 4.其它的普通的变量 - */ -Type* Parser::ParseStructUnionSpec(bool isStruct) { - // GNU extension: type attributes - TryAttributeSpecList(); - - std::string tagName; - auto tok = ts_.Peek(); - if (ts_.Try(Token::IDENTIFIER)) { - tagName = tok->str_; - if (ts_.Try('{')) { - // 看见大括号,表明现在将定义该struct/union类型 - // 我们不用关心上层scope是否定义了此tag,如果定义了,那么就直接覆盖定义 - auto tagIdent = curScope_->FindTagInCurScope(tok); - if (!tagIdent) { - // 现在是在当前scope第一次看到name,所以现在是第一次定义,连前向声明都没有; - auto type = StructType::New(isStruct, tagName.size(), curScope_); - auto ident = Identifier::New(tok, type, L_NONE); - curScope_->InsertTag(ident); - return ParseStructUnionDecl(type); // 处理反大括号: '}' - } - - - // 在当前scope找到了类型,但可能只是声明;注意声明与定义只能出现在同一个scope; - // 1.如果声明在定义的外层scope,那么即使在内层scope定义了完整的类型,此声明仍然是无效的; - // 因为如论如何,编译器都不会在内部scope里面去找定义,所以声明的类型仍然是不完整的; - // 2.如果声明在定义的内层scope,(也就是先定义,再在内部scope声明),这时,不完整的声明会覆盖掉完整的定义; - // 因为编译器总是向上查找符号,不管找到的是完整的还是不完整的,都要; - if (!tagIdent->Type()->Complete()) { - // 找到了此tag的前向声明,并更新其符号表,最后设置为complete type - return ParseStructUnionDecl(tagIdent->Type()->ToStruct()); - } else { - // 在当前作用域找到了完整的定义,并且现在正在定义同名的类型,所以报错; - Error(tok, "redefinition of struct tag '%s'", tagName.c_str()); - } - } else { - // 没有大括号,表明不是定义一个struct/union;那么现在只可能是在: - // 1.声明; - // 2.声明的同时,定义指针(指针允许指向不完整类型) (struct Foo* p; 是合法的) 或者其他合法的类型; - // 如果现在索引符号表,那么: - // 1.可能找到name的完整定义,也可能只找得到不完整的声明;不管name指示的是不是完整类型,我们都只能选择name指示的类型; - // 2.如果我们在符号表里面压根找不到name,那么现在是name的第一次声明,创建不完整的类型并插入符号表; - auto tagIdent = curScope_->FindTag(tok); - - // 如果tag已经定义或声明,那么直接返回此定义或者声明 - if (tagIdent) { - return tagIdent->Type(); - } - // 如果tag尚没有定义或者声明,那么创建此tag的声明(因为没有见到‘{’,所以不会是定义) - auto type = StructType::New(isStruct, true, curScope_); - - // 因为有tag,所以不是匿名的struct/union, 向当前的scope插入此tag - auto ident = Identifier::New(tok, type, L_NONE); - curScope_->InsertTag(ident); - return type; - } - } - // 没见到identifier,那就必须有struct/union的定义,这叫做匿名struct/union; - ts_.Expect('{'); - - // 现在,如果是有tag,那它没有前向声明;如果是没有tag,那更加没有前向声明; - // 所以现在是第一次开始定义一个完整的struct/union类型 - auto type = StructType::New(isStruct, tagName.size(), curScope_); - return ParseStructUnionDecl(type); // 处理反大括号: '}' -} - - -StructType* Parser::ParseStructUnionDecl(StructType* type) { -#define ADD_MEMBER() { \ - auto member = Object::New(tok, memberType); \ - if (align > 0) \ - member->SetAlign(align); \ - type->AddMember(member); \ -} - - // 既然是定义,那输入肯定是不完整类型,不然就是重定义了 - assert(type && !type->Complete()); - - auto scopeBackup = curScope_; - curScope_ = type->MemberMap(); // Internal symbol lookup rely on curScope_ - while (!ts_.Try('}')) { - if (ts_.Empty()) { - Error(ts_.Peek(), "premature end of input"); - } - - if(ts_.Try(Token::STATIC_ASSERT)) { - ParseStaticAssert(); - continue; - } - - // 解析type specifier/qualifier, 不接受storage等 - int align; - auto baseType = ParseDeclSpec(nullptr, nullptr, &align); - do { - auto declInfo = ParseDeclarator(baseType); - auto tok = declInfo.tok; - auto memberType = declInfo.type; - - if (ts_.Try(':')) { - ParseBitField(type, tok, memberType); - continue; - } - - if (tok == nullptr) { - auto suType = memberType->ToStruct(); - if (suType && !suType->HasTag()) { - auto anony = Object::NewAnony(ts_.Peek(), suType); - type->MergeAnony(anony); - continue; - } else { - Error(ts_.Peek(), "declaration does not declare anything"); - } - } - - const auto& name = tok->str_; - if (type->GetMember(name)) { - Error(tok, "duplicate member '%s'", name.c_str()); - } else if (!memberType->Complete()) { - // C11 6.7.2.1 [3]: - if (type->IsStruct() && - // Struct has more than one named member - type->MemberMap()->size() > 0 && - memberType->ToArray()) { - ts_.Expect(';'); ts_.Expect('}'); - ADD_MEMBER(); - goto finalize; - } else { - Error(tok, "field '%s' has incomplete type", name.c_str()); - } - } else if (memberType->ToFunc()) { - Error(tok, "field '%s' declared as a function", name.c_str()); - } - - ADD_MEMBER(); - } while (ts_.Try(',')); - ts_.Expect(';'); - } -finalize: - // GNU extension: type attributes - TryAttributeSpecList(); - - // struct/union定义结束,设置其为完整类型 - type->Finalize(); - type->SetComplete(true); - // TODO(wgtdkp): we need to export tags defined inside struct - const auto& tags = curScope_->AllTagsInCurScope(); - for (auto tag: tags) { - if (scopeBackup->FindTag(tag->Tok())) - Error(tag, "redefinition of tag '%s'\n", tag->Name().c_str()); - scopeBackup->InsertTag(tag); - } - curScope_ = scopeBackup; - - return type; -} - - -void Parser::ParseBitField(StructType* structType, - const Token* tok, - QualType type) { - if (!type->IsInteger()) { - Error(tok ? tok: ts_.Peek(), "expect integer type for bitfield"); - } - - auto expr = ParseAssignExpr(); - auto width = Evaluator().Eval(expr); - if (width < 0) { - Error(expr, "expect non negative value"); - } else if (width == 0 && tok) { - Error(tok, "no declarator expected for a bitfield with width 0"); - } else if (width > type->Width() * 8) { - Error(expr, "width exceeds its type"); - } - - auto offset = structType->Offset() - type->Width(); - // C11 6.7.5 [2]: alignment attribute shall not be specified in declaration of a bit field - // so here is ok to use type->Align() - offset = Type::MakeAlign(std::max(offset, 0), type->Align()); - - int bitFieldOffset; - unsigned char begin; - - if (!structType->IsStruct()) { - begin = 0; - bitFieldOffset = 0; - } else if (structType->Members().size() == 0) { - begin = 0; - bitFieldOffset = 0; - } else { - auto last = structType->Members().back(); - auto totalBits = last->Offset() * 8; - if (last->BitFieldWidth()) { - totalBits += last->BitFieldEnd(); - } else { // Is not bit field - totalBits += last->Type()->Width() * 8; - } - - if (width == 0) - width = type->Width() * 8 - totalBits; // So posterior bitfield would be packed - if (width == 0) // A bitfield with zero width is never added to member list - return; // Because we use bitfield width to tell if a member is bitfield or not. - if (width + totalBits <= type->Width() * 8) { - begin = totalBits % 8; - bitFieldOffset = totalBits / 8; - } else { - begin = 0; - bitFieldOffset = Type::MakeAlign(structType->Offset(), type->Width()); - } - } - - Object* bitField; - if (tok) { - bitField = Object::New(tok, type, 0, L_NONE, begin, width); - } else { - bitField = Object::NewAnony(ts_.Peek(), type, 0, L_NONE, begin, width); - } - structType->AddBitField(bitField, bitFieldOffset); -} - - -int Parser::ParseQual() { - int qualSpec = 0; - for (; ;) { - auto tok = ts_.Next(); - switch (tok->tag_) { - case Token::CONST: qualSpec |= Qualifier::CONST; break; - case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break; - case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break; - case Token::CMEM: qualSpec |= Qualifier::CMEM; break; - case Token::ATOMIC: Error(tok, "do not support 'atomic'"); break; - default: ts_.PutBack(); return qualSpec; - } - } -} - - -QualType Parser::ParsePointer(QualType typePointedTo) { - while (ts_.Try('*')) { - auto t = PointerType::New(typePointedTo); - typePointedTo = QualType(t, ParseQual()); - } - return typePointedTo; -} - - -static QualType ModifyBase(QualType type, QualType base, QualType newBase) { - if (type == base) - return newBase; - - auto ty = type->ToDerived(); - ty->SetDerived(ModifyBase(ty->Derived(), base, newBase)); - - return ty; -} - - -/* - * Return: pair of token(must be identifier) and it's type - * if token is nullptr, then we are parsing abstract declarator - * else, parsing direct declarator. - */ -DeclInfo Parser::ParseDeclarator(QualType base) { - // May be pointer - auto pointerType = ParsePointer(base); - - if (ts_.Try('(')) { - // 现在的 pointerType 并不是正确的 base type - auto declInfo = ParseDeclarator(pointerType); - auto tok = declInfo.tok; - auto type = declInfo.type; - - ts_.Expect(')'); - - auto newBase = ParseArrayFuncDeclarator(tok, pointerType); - - // 修正 base type - auto retType = ModifyBase(type, pointerType, newBase); - return DeclInfo(declInfo.tok, retType); - } else if (ts_.Peek()->IsIdentifier()) { - auto tok = ts_.Next(); - // GNU extension: variable attributes - ASTNode::AttrList attrList = TryAttributeSpecList(); - auto retType = ParseArrayFuncDeclarator(tok, pointerType); - return DeclInfo(tok, retType, attrList); - } else { - errTok_ = ts_.Peek(); - auto retType = ParseArrayFuncDeclarator(nullptr, pointerType); - return DeclInfo(nullptr, retType); - } -} - - -Identifier* Parser::ProcessDeclarator(const Token* tok, - QualType type, - const ASTNode::AttrList& attrs, - int storageSpec, - int funcSpec, - int align) { - assert(tok); - - // 检查在同一 scope 是否已经定义此变量 - // 如果 storage 是 typedef,那么应该往符号表里面插入 type - // 定义 void 类型变量是非法的,只能是指向void类型的指针 - // 如果 funcSpec != 0, 那么现在必须是在定义函数,否则出错 - const auto& name = tok->str_; - Identifier* ident; - - if (storageSpec & S_TYPEDEF) { - // C11 6.7.5 [2]: alignment specifier - if (align > 0) - Error(tok, "alignment specified for typedef"); - - ident = curScope_->FindInCurScope(tok); - if (ident) { // There is prio declaration in the same scope - // The same declaration, simply return the prio declaration - if (!type->Compatible(*ident->Type())) - Error(tok, "conflicting types for '%s'", name.c_str()); - - // TODO(wgtdkp): add previous declaration information - return ident; - } - - if(!attrs.empty()) { - Error(tok, "typedef attributes not allowed"); - } - - ident = Identifier::New(tok, type, L_NONE); - curScope_->Insert(ident); - return ident; - } - - if (type->ToVoid()) { - Error(tok, "variable or field '%s' declared void", - name.c_str()); - } - - if (type->ToFunc() && curScope_->Type() != S_FILE - && (storageSpec & S_STATIC)) { - Error(tok, "invalid storage class for function '%s'", name.c_str()); - } - - Linkage linkage; - // Identifiers in function prototype have no linkage - if (curScope_->Type() == S_PROTO) { - linkage = L_NONE; - } else if (curScope_->Type() == S_FILE) { - linkage = L_EXTERNAL; // Default linkage for file scope identifiers - if (storageSpec & S_STATIC) - linkage = L_INTERNAL; - } else if (!(storageSpec & S_EXTERN)) { - linkage = L_NONE; // Default linkage for block scope identifiers - if (type->ToFunc()) - linkage = L_EXTERNAL; - } else { - linkage = L_EXTERNAL; - } - - ident = curScope_->FindInCurScope(tok); - if (ident) { // There is prio declaration in the same scope - if (!type->Compatible(*ident->Type())) { - Error(tok, "conflicting types for '%s'", name.c_str()); - } - - // The same scope prio declaration has no linkage, - // there is a redeclaration error - if (linkage == L_NONE) { - Error(tok, "redeclaration of '%s' with no linkage", - name.c_str()); - } else if (linkage == L_EXTERNAL) { - if (ident->Linkage() == L_NONE) { - Error(tok, "conflicting linkage for '%s'", name.c_str()); - } - } else { - if (ident->Linkage() != L_INTERNAL) { - Error(tok, "conflicting linkage for '%s'", name.c_str()); - } - } - // The same declaration, simply return the prio declaration - if (!ident->Type()->Complete()) - ident->Type()->SetComplete(type->Complete()); - // Prio declaration of a function may omit the param name - if (type->ToFunc()) - ident->Type()->ToFunc()->SetParams(type->ToFunc()->Params()); - else if (ident->ToObject() && !(storageSpec & S_EXTERN)) - ident->ToObject()->SetStorage(ident->ToObject()->Storage() & ~S_EXTERN); - return ident; - } else if (linkage == L_EXTERNAL) { - ident = curScope_->Find(tok); - if (ident) { - if (!type->Compatible(*ident->Type())) { - Error(tok, "conflicting types for '%s'", name.c_str()); - } - if (ident->Linkage() != L_NONE) { - linkage = ident->Linkage(); - } - // Don't return, override it - } else { - ident = externalSymbols_->FindInCurScope(tok); - if (ident) { - if (!type->Compatible(*ident->Type())) { - Error(tok, "conflicting types for '%s'", name.c_str()); - } - // TODO(wgtdkp): ??????? - // Don't return - // To stop later declaration with the same name in the same scope overriding this declaration - - // Useless here, just keep it - if (!ident->Type()->Complete()) - ident->Type()->SetComplete(type->Complete()); - //return ident; - } - } - } - - Identifier* ret; - // TODO(wgtdkp): Treat function as object ? - if (type->ToFunc()) { - // C11 6.7.5 [2]: alignment specifier - if (align > 0) - Error(tok, "alignment specified for function"); - ret = Identifier::New(tok, type, linkage, attrs); - } else { - auto obj = Object::New(tok, type, storageSpec, linkage, 0, 0, attrs); - if (align > 0) - obj->SetAlign(align); - ret = obj; - } - curScope_->Insert(ret); - if (linkage == L_EXTERNAL && ident == nullptr) { - externalSymbols_->Insert(ret); - } - - return ret; -} - - -QualType Parser::ParseArrayFuncDeclarator(const Token* ident, QualType base) { - if (ts_.Try('[')) { - if(!base->IsScalar()) { - Error(ts_.Peek(), "tiles must have scalar elements"); - } - auto shape = ParseTileShape(); - ts_.Expect(']'); - base = ParseArrayFuncDeclarator(ident, base); - if (!base->Complete()) { - Error(ident, "'%s' has incomplete element type", ident->str_.c_str()); - } - // return a pointer for tiles in constant memory: - TileType* ret = TileType::New(shape, base); - if(!ret->CheckPow2NumEl()) - Error(ts_.Peek(), "tile must have power of 2 number of elements"); - return ret; - } else if (ts_.Try('(')) { // Function declaration - if (base->ToFunc()) { - Error(ts_.Peek(), - "the return value of function cannot be function"); - } else if (nullptr != base->ToArray()) { - Error(ts_.Peek(), - "the return value of function cannot be array"); - } - - FuncType::ParamList params; - EnterProto(); - auto variadic = ParseParamList(params); - ExitProto(); - - ts_.Expect(')'); - base = ParseArrayFuncDeclarator(ident, base); - - return FuncType::New(base, 0, variadic, params); - } - - - return base; -} - - -/* - * Return: -1, length not specified - */ -int Parser::ParseArrayLength() { - auto hasStatic = ts_.Try(Token::STATIC); - auto qual = ParseQual(); - if (0 != qual) - hasStatic = ts_.Try(Token::STATIC); - - // 不支持变长数组 - if (!hasStatic && ts_.Test(']')) - return -1; - - auto expr = ParseAssignExpr(); - EnsureInteger(expr); - auto ret = Evaluator().Eval(expr); - if (ret < 0) { - Error(expr, "size of array is negative"); - } - return ret; -} - -TileType::ShapeInt Parser::ParseTileShape() { - TileType::ShapeInt ret; - size_t i = 0; - do { - Expr* expr = ParseConditionalExpr(); - EnsureInteger(expr); - int dim = Evaluator().Eval(expr); - if (dim < 0) - Error(expr, "shape %d of tile is negative", i); - ret.push_back(dim); - i++; - }while(ts_.Try(',')); - return ret; -} - -/* - * Return: true, variadic; - */ -bool Parser::ParseParamList(FuncType::ParamList& params) { - if (ts_.Test(')')) - return false; - auto param = ParseParamDecl(); - if (param->Type()->ToVoid()) - return false; - params.push_back(param); - - while (ts_.Try(',')) { - if (ts_.Try(Token::ELLIPSIS)) - return true; - param = ParseParamDecl(); - if (param->Type()->ToVoid()) - Error(param, "'void' must be the only parameter"); - params.push_back(param); - } - return false; -} - - -Object* Parser::ParseParamDecl() { - int storageSpec, funcSpec; - // C11 6.7.5 [2]: alignment specifier cannot be specified in params - auto type = ParseDeclSpec(&storageSpec, &funcSpec, nullptr); - auto tokTypePair = ParseDeclarator(type); - auto tok = tokTypePair.tok; - QualType fullType(tokTypePair.type.GetPtr(), type.Qual()); - type = Type::MayCast(fullType, true); - auto attrs = tokTypePair.attrs; - if (!tok) { // Abstract declarator - return Object::NewAnony(ts_.Peek(), type, 0, Linkage::L_NONE); - } - - // Align set to non positive, stands for not specified - auto ident = ProcessDeclarator(tok, type, attrs, storageSpec, funcSpec, -1); - if (!ident->ToObject()) - Error(ident, "expect object in param list"); - - return ident->ToObject(); -} - - -QualType Parser::ParseAbstractDeclarator(QualType type) { - auto declInfo = ParseDeclarator(type); - auto tok = declInfo.tok; - type = declInfo.type; - if (tok) { // Not a abstract declarator! - Error(tok, "unexpected identifier '%s'", tok->str_.c_str()); - } - return type; -} - - -Identifier* Parser::ParseDirectDeclarator(QualType type, - int storageSpec, - int funcSpec, - int align) { - auto declInfo = ParseDeclarator(type); - auto tok = declInfo.tok; - type = declInfo.type; - auto attrs = declInfo.attrs; - if (tok == nullptr) { - Error(errTok_, "expect identifier or '('"); - } - - return ProcessDeclarator(tok, type, attrs, storageSpec, funcSpec, align); -} - - -Declaration* Parser::ParseInitDeclarator(Identifier* ident) { - auto obj = ident->ToObject(); - if (!obj) { // Do not record function Declaration - return nullptr; - } - - const auto& name = obj->Name(); - if (ts_.Try('=')) { - return ParseInitDeclaratorSub(obj); - } - - if (!obj->Type()->Complete()) { - if (obj->Linkage() == L_NONE) { - Error(obj, "storage size of '%s' isn’t known", name.c_str()); - } - // FIXME(wgtdkp): - // Discards the incomplete object declarations - // It causes linking failure of forward-declared objects with imcomplete type - return nullptr; - } - - if (!obj->Decl()) { - auto decl = Declaration::New(obj); - obj->SetDecl(decl); - return decl; - } - - return nullptr; -} - - -Declaration* Parser::ParseInitDeclaratorSub(Object* obj) { - const auto& name = obj->Name(); - if ((curScope_->Type() != S_FILE) && obj->Linkage() != L_NONE) { - Error(obj, "'%s' has both 'extern' and initializer", name.c_str()); - } - - if (!obj->Type()->Complete() && !obj->Type()->ToArray()) { - Error(obj, "variable '%s' has initializer but incomplete type", - name.c_str()); - } - - if (obj->HasInit()) { - Error(obj, "redefinition of variable '%s'", name.c_str()); - } - - // There could be more than one declaration for - // an object in the same scope. - // But it must has external or internal linkage. - // So, for external/internal objects, - // the initialization will always go to - // the first declaration. As the initialization - // is evaluated at compile time, - // the order doesn't matter. - // For objects with no linkage, there is - // always only one declaration. - // Once again, we need not to worry about - // the order of the initialization. - if (obj->Decl()) { - ParseInitializer(obj->Decl(), obj->Type(), 0, false, true); - return nullptr; - } else { - auto decl = Declaration::New(obj); - ParseInitializer(decl, obj->Type(), 0, false, true); - obj->SetDecl(decl); - return decl; - } -} - - -void Parser::ParseInitializer(Declaration* decl, - QualType type, - int offset, - bool designated, - bool forceBrace, - unsigned char bitFieldBegin, - unsigned char bitFieldWidth) { - if (designated && !ts_.Test('.') && !ts_.Test('[')) { - ts_.Expect('='); - } - -// std::cout << "parsing initialized " << decl->Obj()->Name() << std::endl; - Expr* expr; - auto arrType = type->ToArray(); - auto structType = type->ToStruct(); - // A compound literal in initializer is reduced to a initializer directly - // It means that the compound literal will never be created - //auto literalType = TryCompoundLiteral(); - //if (literalType && !literalType->Compatible(*type)) - // Error("incompatible type of initializer"); - if (arrType) { - if (forceBrace && !ts_.Test('{') && !ts_.Test(Token::LITERAL)) { - ts_.Expect('{'); - } else if (!ParseLiteralInitializer(decl, arrType, offset)) { - ParseArrayInitializer(decl, arrType, offset, designated); - arrType->SetComplete(true); - } - return; - } else if (structType) { - if (!ts_.Test('.') && !ts_.Test('{')) { - auto mark = ts_.Mark(); - expr = ParseAssignExpr(); - if (structType->Compatible(*expr->Type())) { - decl->AddInit({structType, offset, expr}); - return; - } - ts_.ResetTo(mark); - if (forceBrace) - ts_.Expect('{'); - } - return ParseStructInitializer(decl, structType, offset, designated); - } - - // Scalar type - auto hasBrace = ts_.Try('{'); - expr = ParseAssignExpr(); - if (hasBrace) { - ts_.Try(','); - ts_.Expect('}'); - } - decl->AddInit({type.GetPtr(), offset, expr, bitFieldBegin, bitFieldWidth}); -} - - -bool Parser::ParseLiteralInitializer(Declaration* decl, - ArrayType* type, - int offset) { - if (!type->Derived()->IsInteger()) - return false; - - auto hasBrace = ts_.Try('{'); - if (!ts_.Test(Token::LITERAL)) { - if (hasBrace) ts_.PutBack(); - return false; - } - auto literal = ConcatLiterals(ts_.Next()); - auto tok = literal->Tok(); - - if (hasBrace) { - ts_.Try(','); - ts_.Expect('}'); - } - - if (!type->Complete()) { - type->SetLen(literal->Type()->ToArray()->Len()); - type->SetComplete(true); - } - - auto width = std::min(type->Width(), literal->Type()->Width()); - auto str = literal->SVal()->c_str(); - - for (; width >= 8; width -= 8) { - auto p = reinterpret_cast(str); - auto type = ArithmType::New(T_LONG); - auto val = Constant::New(tok, T_LONG, static_cast(*p)); - decl->AddInit({type, offset, val}); - offset += 8; - str += 8; - } - - for (; width >= 4; width -= 4) { - auto p = reinterpret_cast(str); - auto type = ArithmType::New(T_INT); - auto val = Constant::New(tok, T_INT, static_cast(*p)); - decl->AddInit({type, offset, val}); - offset += 4; - str += 4; - } - - for (; width >= 2; width -= 2) { - auto p = reinterpret_cast(str); - auto type = ArithmType::New(T_SHORT); - auto val = Constant::New(tok, T_SHORT, static_cast(*p)); - decl->AddInit({type, offset, val}); - offset += 2; - str += 2; - } - - for (; width >= 1; --width) { - auto p = str; - auto type = ArithmType::New(T_CHAR); - auto val = Constant::New(tok, T_CHAR, static_cast(*p)); - decl->AddInit({type, offset, val}); - offset++; - str++; - } - - return true; -} - - -void Parser::ParseArrayInitializer(Declaration* decl, - ArrayType* type, - int offset, - bool designated) { - assert(type); - - if (!type->Complete()) - type->SetLen(0); - - int idx = 0; - auto width = type->Derived()->Width(); - auto hasBrace = ts_.Try('{'); - while (true) { - if (ts_.Test('}')) { - if (hasBrace) - ts_.Next(); - return; - } - - if (!designated && !hasBrace && (ts_.Test('.') || ts_.Test('['))) { - ts_.PutBack(); // Put the read comma(',') back - return; - } else if ((designated = ts_.Try('['))) { - auto expr = ParseAssignExpr(); - EnsureInteger(expr); - idx = Evaluator().Eval(expr); - ts_.Expect(']'); - - if (idx < 0 || (type->Complete() && idx >= type->Len())) { - Error(ts_.Peek(), "excess elements in array initializer"); - } - } - - ParseInitializer(decl, type->Derived(), offset + idx * width, designated); - designated = false; - ++idx; - - if (type->Complete() && idx >= type->Len()) { - break; - } else if (!type->Complete()) { - type->SetLen(std::max(idx, type->Len())); - } - - // Needless comma at the end is legal - if (!ts_.Try(',')) { - if (hasBrace) - ts_.Expect('}'); - return; - } - } - - if (hasBrace) { - ts_.Try(','); - if (!ts_.Try('}')) { - Error(ts_.Peek(), "excess elements in array initializer"); - } - } -} - - -StructType::Iterator Parser::ParseStructDesignator(StructType* type, - const std::string& name) { - auto iter = type->Members().begin(); - for (; iter != type->Members().end(); ++iter) { - if ((*iter)->Anonymous()) { - auto anonyType = (*iter)->Type()->ToStruct(); - assert(anonyType); - if (anonyType->GetMember(name)) { - return iter; // ParseStructDesignator(anonyType); - } - } else if ((*iter)->Name() == name) { - return iter; - } - } - assert(false); - return iter; -} - - -void Parser::ParseStructInitializer(Declaration* decl, - StructType* type, - int offset, - bool designated) { - assert(type); - - auto hasBrace = ts_.Try('{'); - auto member = type->Members().begin(); - while (true) { - if (ts_.Test('}')) { - if (hasBrace) - ts_.Next(); - return; - } - - if (!designated && !hasBrace && (ts_.Test('.') || ts_.Test('['))) { - ts_.PutBack(); // Put the read comma(',') back - return; - } - - if ((designated = ts_.Try('.'))) { - auto tok = ts_.Expect(Token::IDENTIFIER); - const auto& name = tok->str_; - if (!type->GetMember(name)) { - Error(tok, "member '%s' not found", name.c_str()); - } - member = ParseStructDesignator(type, name); - } - if (member == type->Members().end()) - break; - - if ((*member)->Anonymous()) { - if (designated) { // Put back '.' and member name. - ts_.PutBack(); - ts_.PutBack(); - } - // Because offsets of member of anonymous struct/union are based - // directly on external struct/union - ParseInitializer(decl, (*member)->Type(), offset, designated, false, - (*member)->BitFieldBegin(), (*member)->BitFieldWidth()); - } else { - ParseInitializer(decl, (*member)->Type(), - offset + (*member)->Offset(), designated, false, - (*member)->BitFieldBegin(), (*member)->BitFieldWidth()); - } - designated = false; - ++member; - - // Union, just init the first member - if (!type->IsStruct()) - break; - - if (!hasBrace && member == type->Members().end()) - break; - - // Needless comma at the end is allowed - if (!ts_.Try(',')) { - if (hasBrace) - ts_.Expect('}'); - return; - } - } - - if (hasBrace) { - ts_.Try(','); - if (!ts_.Try('}')) { - Error(ts_.Peek(), "excess members in struct initializer"); - } - } -} - - -/* - * Statements - */ - -Stmt* Parser::ParseStmt() { - auto tok = ts_.Next(); - if (tok->IsEOF()) - Error(tok, "premature end of input"); - - switch (tok->tag_) { - // GNU extension: statement attributes - case Token::ATTRIBUTE: - TryAttributeSpecList(); - case ';': - return EmptyStmt::New(); - case '{': - return ParseCompoundStmt(); - case Token::IF: - return ParseIfStmt(); - case Token::SWITCH: - return ParseSwitchStmt(); - case Token::WHILE: - return ParseWhileStmt(); - case Token::DO: - return ParseDoStmt(); - case Token::FOR: - return ParseForStmt(); - case Token::GOTO: - return ParseGotoStmt(); - case Token::CONTINUE: - return ParseContinueStmt(); - case Token::BREAK: - return ParseBreakStmt(); - case Token::RETURN: - return ParseReturnStmt(); - case Token::CASE: - return ParseCaseStmt(); - case Token::DEFAULT: - return ParseDefaultStmt(); - } - - if (tok->IsIdentifier() && ts_.Try(':')) { - // GNU extension: label attributes - TryAttributeSpecList(); - return ParseLabelStmt(tok); - } - - ts_.PutBack(); - auto expr = ParseExpr(); - ts_.Expect(';'); - - return expr; -} - - -CompoundStmt* Parser::ParseCompoundStmt(FuncType* funcType) { - EnterBlock(funcType); - - std::list stmts; - - while (!ts_.Try('}')) { - if (ts_.Peek()->IsEOF()) { - Error(ts_.Peek(), "premature end of input"); - } - - if (IsType(ts_.Peek())) { - stmts.push_back(ParseDecl()); - } else { - stmts.push_back(ParseStmt()); - } - } - - auto scope = curScope_; - ExitBlock(); - - return CompoundStmt::New(stmts, scope); -} - - -IfStmt* Parser::ParseIfStmt() { - ts_.Expect('('); - auto tok = ts_.Peek(); - auto cond = ParseExpr(); - if (!cond->Type()->IsScalar()) { - Error(tok, "expect scalar"); - } - ts_.Expect(')'); - - auto then = ParseStmt(); - Stmt* els = nullptr; - if (ts_.Try(Token::ELSE)) - els = ParseStmt(); - - return IfStmt::New(cond, then, els); -} - - -/* - * for 循环结构: - * for (declaration; expression1; expression2) statement - * 展开后的结构: - * declaration - * cond: if (expression1) then empty - * else goto end - * statement - * step: expression2 - * goto cond - * next: - */ - -#define ENTER_LOOP_BODY(breakDest, continueDest) \ -{ \ - LabelStmt* breakDestBackup = breakDest_; \ - LabelStmt* continueDestBackup = continueDest_; \ - breakDest_ = breakDest; \ - continueDest_ = continueDest; - -#define EXIT_LOOP_BODY() \ - breakDest_ = breakDestBackup; \ - continueDest_ = continueDestBackup; \ -} - -ForStmt* Parser::ParseForStmt() { - EnterBlock(); - ts_.Expect('('); - // init - Stmt* init = nullptr; - if (IsType(ts_.Peek())) { - init = ParseDecl(); - } else if (!ts_.Try(';')) { - init = ParseExpr(); - ts_.Expect(';'); - } - // cond - Expr* cond = nullptr; - if (!ts_.Try(';')) { - cond = ParseExpr(); - ts_.Expect(';'); - } - // step - Expr* step = nullptr; - if (!ts_.Try(')')) { - step = ParseExpr(); - ts_.Expect(')'); - } - // body - Stmt* body = ParseStmt(); - ExitBlock(); - return ForStmt::New(body, init, cond, step); -} - - -/* - * while 循环结构: - * while (expression) statement - * 展开后的结构: - * cond: if (expression) then empty - * else goto end - * statement - * goto cond - * end: - */ -CompoundStmt* Parser::ParseWhileStmt() { - std::list stmts; - ts_.Expect('('); - auto tok = ts_.Peek(); - auto condExpr = ParseExpr(); - ts_.Expect(')'); - - if (!condExpr->Type()->IsScalar()) { - Error(tok, "scalar expression expected"); - } - - auto condLabel = LabelStmt::New(); - auto endLabel = LabelStmt::New(); - auto gotoEndStmt = JumpStmt::New(endLabel); - auto ifStmt = IfStmt::New(condExpr, EmptyStmt::New(), gotoEndStmt); - stmts.push_back(condLabel); - stmts.push_back(ifStmt); - - Stmt* bodyStmt; - ENTER_LOOP_BODY(endLabel, condLabel) - bodyStmt = ParseStmt(); - EXIT_LOOP_BODY() - - stmts.push_back(bodyStmt); - stmts.push_back(JumpStmt::New(condLabel)); - stmts.push_back(endLabel); - - return CompoundStmt::New(stmts); -} - - -/* - * do-while 循环结构: - * do statement while (expression) - * 展开后的结构: - * begin: statement - * cond: if (expression) then goto begin - * else goto end - * end: - */ -CompoundStmt* Parser::ParseDoStmt() { - auto beginLabel = LabelStmt::New(); - auto condLabel = LabelStmt::New(); - auto endLabel = LabelStmt::New(); - - Stmt* bodyStmt; - ENTER_LOOP_BODY(endLabel, beginLabel) - bodyStmt = ParseStmt(); - EXIT_LOOP_BODY() - - ts_.Expect(Token::WHILE); - ts_.Expect('('); - auto condExpr = ParseExpr(); - ts_.Expect(')'); - ts_.Expect(';'); - - auto gotoBeginStmt = JumpStmt::New(beginLabel); - auto gotoEndStmt = JumpStmt::New(endLabel); - auto ifStmt = IfStmt::New(condExpr, gotoBeginStmt, gotoEndStmt); - - std::list stmts; - stmts.push_back(beginLabel); - stmts.push_back(bodyStmt); - stmts.push_back(condLabel); - stmts.push_back(ifStmt); - stmts.push_back(endLabel); - - return CompoundStmt::New(stmts); -} - - -#undef ENTER_LOOP_BODY -#undef EXIT_LOOP_BODY - - -#define ENTER_SWITCH_BODY(breakDest, caseLabels) \ -{ \ - CaseLabelList* caseLabelsBackup = caseLabels_; \ - LabelStmt* defaultLabelBackup = defaultLabel_; \ - LabelStmt* breakDestBackup = breakDest_; \ - breakDest_ = breakDest; \ - caseLabels_ = &caseLabels; \ - defaultLabel_ = nullptr; - -#define EXIT_SWITCH_BODY() \ - caseLabels_ = caseLabelsBackup; \ - breakDest_ = breakDestBackup; \ - defaultLabel_ = defaultLabelBackup; \ -} - - -/* - * switch - * jump stmt (skip case labels) - * case labels - * jump stmts - * default jump stmt - */ -CompoundStmt* Parser::ParseSwitchStmt() { - std::list stmts; - ts_.Expect('('); - auto tok = ts_.Peek(); - auto expr = ParseExpr(); - ts_.Expect(')'); - - if (!expr->Type()->IsInteger()) { - Error(tok, "switch quantity not an integer"); - } - - auto testLabel = LabelStmt::New(); - auto endLabel = LabelStmt::New(); - auto t = TempVar::New(expr->Type()); - auto assign = BinaryOp::New(tok, '=', t, expr); - stmts.push_back(assign); - stmts.push_back(JumpStmt::New(testLabel)); - - CaseLabelList caseLabels; - ENTER_SWITCH_BODY(endLabel, caseLabels); - - auto bodyStmt = ParseStmt(); // Fill caseLabels and defaultLabel - stmts.push_back(bodyStmt); - stmts.push_back(JumpStmt::New(endLabel)); - stmts.push_back(testLabel); - - for (auto iter = caseLabels.begin(); - iter != caseLabels.end(); ++iter) { - auto cond = BinaryOp::New(tok, Token::EQ, t, iter->first); - auto then = JumpStmt::New(iter->second); - auto ifStmt = IfStmt::New(cond, then, nullptr); - stmts.push_back(ifStmt); - } - if (defaultLabel_) - stmts.push_back(JumpStmt::New(defaultLabel_)); - EXIT_SWITCH_BODY(); - - stmts.push_back(endLabel); - - return CompoundStmt::New(stmts); -} - - -#undef ENTER_SWITCH_BODY -#undef EXIT_SWITCH_BODY - - -CompoundStmt* Parser::ParseCaseStmt() { - auto tok = ts_.Peek(); - - // Case ranges: Non-standard GNU extension - long begin, end; - begin = Evaluator().Eval(ParseAssignExpr()); - if (ts_.Try(Token::ELLIPSIS)) - end = Evaluator().Eval(ParseAssignExpr()); - else - end = begin; - ts_.Expect(':'); - - auto labelStmt = LabelStmt::New(); - for (auto val = begin; val <= end; ++val) { - if (val > INT_MAX) - Error(tok, "case range exceed range of int"); - auto cons = Constant::New(tok, T_INT, val); - caseLabels_->push_back(std::make_pair(cons, labelStmt)); - } - - std::list stmts; - stmts.push_back(labelStmt); - stmts.push_back(ParseStmt()); - - return CompoundStmt::New(stmts); -} - - -CompoundStmt* Parser::ParseDefaultStmt() { - auto tok = ts_.Peek(); - ts_.Expect(':'); - if (defaultLabel_) { // There is a 'default' stmt - Error(tok, "multiple default labels in one switch"); - } - auto labelStmt = LabelStmt::New(); - defaultLabel_ = labelStmt; - - std::list stmts; - stmts.push_back(labelStmt); - stmts.push_back(ParseStmt()); - - return CompoundStmt::New(stmts); -} - - -JumpStmt* Parser::ParseContinueStmt() { - auto tok = ts_.Peek(); - ts_.Expect(';'); - if (continueDest_ == nullptr) { - Error(tok, "'continue' is allowed only in loop"); - } - - return JumpStmt::New(continueDest_); -} - - -JumpStmt* Parser::ParseBreakStmt() { - auto tok = ts_.Peek(); - ts_.Expect(';'); - if (breakDest_ == nullptr) { - Error(tok, "'break' is allowed only in switch/loop"); - } - - return JumpStmt::New(breakDest_); -} - - -ReturnStmt* Parser::ParseReturnStmt() { - Expr* expr; - - if (ts_.Try(';')) { - expr = nullptr; - } else { - expr = ParseExpr(); - ts_.Expect(';'); - - auto retType = curFunc_->FuncType()->Derived(); - expr = Expr::MayCast(expr, retType); - } - - return ReturnStmt::New(expr); -} - - -JumpStmt* Parser::ParseGotoStmt() { - auto label = ts_.Peek(); - ts_.Expect(Token::IDENTIFIER); - ts_.Expect(';'); - - auto labelStmt = FindLabel(label->str_); - if (labelStmt) { - return JumpStmt::New(labelStmt); - } - - auto unresolvedJump = JumpStmt::New(nullptr); - unresolvedJumps_.push_back(std::make_pair(label, unresolvedJump)); - - return unresolvedJump; -} - - -CompoundStmt* Parser::ParseLabelStmt(const Token* label) { - const auto& labelStr = label->str_; - auto stmt = ParseStmt(); - if (nullptr != FindLabel(labelStr)) { - Error(label, "redefinition of label '%s'", labelStr.c_str()); - } - - auto labelStmt = LabelStmt::New(); - AddLabel(labelStr, labelStmt); - std::list stmts; - stmts.push_back(labelStmt); - stmts.push_back(stmt); - - return CompoundStmt::New(stmts); -} - - -bool Parser::IsBuiltin(const std::string& name) { - return name == "__builtin_va_arg" || - name == "__builtin_va_start"; -} - - -bool Parser::IsBuiltin(FuncType* type) { - assert(vaStartType_ && vaArgType_); - return type == vaStartType_ || type == vaArgType_; -} - - -// Builtin functions will be inlined -void Parser::DefineBuiltins() { - // FIXME: potential bug: using same object for params!!! - auto voidPtr = PointerType::New(VoidType::New()); - auto param = Object::New(nullptr, voidPtr); - FuncType::ParamList pl; - pl.push_back(param); - pl.push_back(param); - vaStartType_ = FuncType::New(VoidType::New(), F_INLINE, false, pl); - vaArgType_ = FuncType::New(voidPtr, F_INLINE, false, pl); -} - - -Identifier* Parser::GetBuiltin(const Token* tok) { - assert(vaStartType_ && vaArgType_); - static Identifier* vaStart = nullptr; - static Identifier* vaArg = nullptr; - const auto& name = tok->str_; - if (name == "__builtin_va_start") { - if (!vaStart) - vaStart = Identifier::New(tok, vaStartType_, Linkage::L_EXTERNAL); - return vaStart; - } else if (name == "__builtin_va_arg") { - if (!vaArg) - vaArg = Identifier::New(tok, vaArgType_, Linkage::L_EXTERNAL); - return vaArg; - } - assert(false); - return nullptr; -} - - -/* - * GNU extensions - */ - -// Attribute -ASTNode::AttrList Parser::TryAttributeSpecList() { - ASTNode::AttrList attrList; - while (ts_.Try(Token::ATTRIBUTE)) - ParseAttributeSpec(attrList); - return attrList; -} - - -void Parser::ParseAttributeSpec(ASTNode::AttrList& attrList) { - ts_.Expect('('); - ts_.Expect('('); - - while (!ts_.Try(')')) { - attrList.push_back(ParseAttribute()); - if (!ts_.Try(',')) { - ts_.Expect(')'); - break; - } - } - ts_.Expect(')'); -} - - -ASTNode::Attr Parser::ParseAttribute() { - ASTNode::Attr ret; - if (!ts_.Test(Token::IDENTIFIER)) - return ret; - auto tok = ts_.Next(); - std::string name = tok->str_; - // set kind - if(name == "aligned") - ret.kind = ASTNode::Attr::ALIGNED; - else if(name == "readonly") - ret.kind = ASTNode::Attr::READONLY; - else if(name == "writeonly") - ret.kind = ASTNode::Attr::WRITEONLY; - else if(name == "multipleof") - ret.kind = ASTNode::Attr::MULTIPLEOF; - else if(name == "noalias") - ret.kind = ASTNode::Attr::NOALIAS; - else if(name == "retune") - ret.kind = ASTNode::Attr::RETUNE; - else - Error(tok, "unknown attribute kind"); - // set exprs - if (ts_.Try('(')) { - if (ts_.Try(')')) - return ret; - ret.vals.push_back(ParseExpr()); - if (ts_.Test(',')) { - while (ts_.Try(',')) {} - } - ts_.Try(')'); - } - return ret; -} diff --git a/lib/lang/scanner.cc b/lib/lang/scanner.cc deleted file mode 100644 index 9c394ecfd..000000000 --- a/lib/lang/scanner.cc +++ /dev/null @@ -1,452 +0,0 @@ -#include "triton/lang/scanner.h" - -#include -#include - - -void Scanner::Tokenize(TokenSequence& ts) { - while (true) { - auto tok = Scan(); - if (tok->tag_ == Token::END) { - if (ts.Empty() || (ts.Back()->tag_ != Token::NEW_LINE)) { - auto t = Token::New(*tok); - t->tag_ = Token::NEW_LINE; - t->str_ = "\n"; - ts.InsertBack(t); - } - break; - } else { - if (!ts.Empty() && ts.Back()->tag_ == Token::NEW_LINE) - tok->ws_ = true; - ts.InsertBack(tok); - } - } -} - - -std::string Scanner::ScanHeadName(const Token* lhs, const Token* rhs) { - std::string str; - const char* begin = lhs->loc_.Begin() + 1; - const char* end = rhs->loc_.Begin(); - for (; begin != end; ++begin) { - if (*begin == '\n' && str.back() == '\\') - str.pop_back(); - else - str.push_back(*begin); - } - return str; -} - - -Token* Scanner::Scan(bool ws) { - tok_.ws_ = ws; - SkipWhiteSpace(); - - Mark(); - - if (Test('\n')) { - auto ret = MakeNewLine(); - Next(); - return ret; - } - auto c = Next(); - switch (c) { - case '#': return MakeToken(Try('#') ? Token::DSHARP: c); - case ':': return MakeToken(Try('>') ? ']': c); - case '(': case ')': case '[': case ']': - case '?': case ',': case '{': case '}': - case '~': case ';': case '@': - return MakeToken(c); - case '-': - if (Try('>')) return MakeToken(Token::PTR); - if (Try('-')) return MakeToken(Token::DEC); - if (Try('=')) return MakeToken(Token::SUB_ASSIGN); - return MakeToken(c); - case '+': - if (Try('+')) return MakeToken(Token::INC); - if (Try('=')) return MakeToken(Token::ADD_ASSIGN); - return MakeToken(c); - case '<': - if (Try('<')) return MakeToken(Try('=') ? Token::LEFT_ASSIGN: Token::LEFT); - if (Try('=')) return MakeToken(Token::LE); - if (Try(':')) return MakeToken('['); - if (Try('%')) return MakeToken('{'); - return MakeToken(c); - case '%': - if (Try('=')) return MakeToken(Token::MOD_ASSIGN); - if (Try('>')) return MakeToken('}'); - if (Try(':')) { - if (Try('%')) { - if (Try(':')) return MakeToken(Token::DSHARP); - PutBack(); - } - return MakeToken('#'); - } - return MakeToken(c); - case '>': - if (Try('>')) return MakeToken(Try('=') ? Token::RIGHT_ASSIGN: Token::RIGHT); - if (Try('=')) return MakeToken(Token::GE); - return MakeToken(c); - case '=': return MakeToken(Try('=') ? Token::EQ: c); - case '!': return MakeToken(Try('=') ? Token::NE: c); - case '&': - if (Try('&')) return MakeToken(Token::LOGICAL_AND); - if (Try('=')) return MakeToken(Token::AND_ASSIGN); - return MakeToken(c); - case '|': - if (Try('|')) return MakeToken(Token::LOGICAL_OR); - if (Try('=')) return MakeToken(Token::OR_ASSIGN); - return MakeToken(c); - case '*': return MakeToken(Try('=') ? Token::MUL_ASSIGN: c); - case '/': - if (Test('/') || Test('*')) { - SkipComment(); - return Scan(true); - } - return MakeToken(Try('=') ? Token::DIV_ASSIGN: c); - case '^': return MakeToken(Try('=') ? Token::XOR_ASSIGN: c); - case '.': - if (isdigit(Peek())) return SkipNumber(); - if (Try('.')) { - if (Try('.')) return MakeToken(Token::ELLIPSIS); - PutBack(); - return MakeToken('.'); - } - return MakeToken(c); - case '0' ... '9': return SkipNumber(); - case 'u': case 'U': case 'L': { - /*auto enc = */ScanEncoding(c); - if (Try('\'')) return SkipCharacter(); - if (Try('\"')) return SkipLiteral(); - return SkipIdentifier(); - } - case '\'': return SkipCharacter(); - case '\"': return SkipLiteral(); - case 'a' ... 't': case 'v' ... 'z': case 'A' ... 'K': - case 'M' ... 'T': case 'V' ... 'Z': case '_': case '$': - case 0x80 ... 0xfd: - return SkipIdentifier(); - case '\\': - // Universal character name is allowed in identifier - if (Test('u') || Test('U')) - return SkipIdentifier(); - return MakeToken(Token::INVALID); - case '\0': return MakeToken(Token::END); - default: return MakeToken(Token::INVALID); - } -} - - -void Scanner::SkipWhiteSpace() { - while (isspace(Peek()) && Peek() != '\n') { - tok_.ws_ = true; - Next(); - } -} - - -void Scanner::SkipComment() { - if (Try('/')) { - // Line comment terminated an newline or eof - while (!Empty()) { - if (Peek() == '\n') - return; - Next(); - } - return; - } else if (Try('*')) { - while (!Empty()) { - auto c = Next(); - if (c == '*' && Peek() == '/') { - Next(); - return; - } - } - Error(loc_, "unterminated block comment"); - } - assert(false); -} - - -std::string Scanner::ScanIdentifier() { - std::string val; - while (!Empty()) { - auto c = Next(); - if (IsUCN(c)) { - c = ScanEscaped(); // Call ScanUCN() - AppendUCN(val, c); - } else { - val.push_back(c); - } - } - return val; -} - - -Token* Scanner::SkipIdentifier() { - PutBack(); - auto c = Next(); - while (isalnum(c) - || (0x80 <= c && c <= 0xfd) - || c == '_' - || c == '$' - || IsUCN(c)) { - if (IsUCN(c)) - c = ScanEscaped(); // Just read it - c = Next(); - } - PutBack(); - return MakeToken(Token::IDENTIFIER); -} - - -// Scan PP-Number -Token* Scanner::SkipNumber() { - PutBack(); - bool sawHexPrefix = false; - int tag = Token::I_CONSTANT; - auto c = Next(); - while (c == '.' || isdigit(c) || isalpha(c) || c == '_' || IsUCN(c)) { - if (c == 'e' || c =='E' || c == 'p' || c == 'P') { - if (!Try('-')) Try('+'); - if (!((c == 'e' || c == 'E') && sawHexPrefix)) - tag = Token::F_CONSTANT; - } else if (IsUCN(c)) { - ScanEscaped(); - } else if (c == '.') { - tag = Token::F_CONSTANT; - } else if (c == 'x' || c == 'X') { - sawHexPrefix = true; - } - c = Next(); - } - PutBack(); - return MakeToken(tag); -} - - -Encoding Scanner::ScanLiteral(std::string& val) { - auto enc = Test('\"') ? Encoding::NONE: ScanEncoding(Next()); - Next(); - val.resize(0); - while (!Test('\"')) { - auto c = Next(); - bool isucn = IsUCN(c); - if (c == '\\') - c = ScanEscaped(); - if (isucn) - AppendUCN(val, c); - else - val.push_back(c); - } - return enc; -} - - -Token* Scanner::SkipLiteral() { - auto c = Next(); - while (c != '\"' && c != '\n' && c != '\0') { - if (c == '\\') Next(); - c = Next(); - } - if (c != '\"') - Error(loc_, "unterminated string literal"); - return MakeToken(Token::LITERAL); -} - - -Encoding Scanner::ScanCharacter(int& val) { - auto enc = Test('\'') ? Encoding::NONE: ScanEncoding(Next()); - Next(); - val = 0; - while (!Test('\'')) { - auto c = Next(); - if (c == '\\') - c = ScanEscaped(); - if (enc == Encoding::NONE) - val = (val << 8) + c; - else - val = c; - } - return enc; -} - - -Token* Scanner::SkipCharacter() { - auto c = Next(); - while (c != '\'' && c != '\n' && c != '\0') { - if (c == '\\') Next(); - c = Next(); - } - if (c != '\'') - Error(loc_, "unterminated character constant"); - return MakeToken(Token::C_CONSTANT); -} - - -int Scanner::ScanEscaped() { - auto c = Next(); - switch (c) { - case '\\': case '\'': case '\"': case '\?': - return c; - case 'a': return '\a'; - case 'b': return '\b'; - case 'f': return '\f'; - case 'n': return '\n'; - case 'r': return '\r'; - case 't': return '\t'; - case 'v': return '\v'; - // Non-standard GCC extention - case 'e': return '\033'; - case 'x': return ScanHexEscaped(); - case '0' ... '7': return ScanOctEscaped(c); - case 'u': return ScanUCN(4); - case 'U': return ScanUCN(8); - default: Error(loc_, "unrecognized escape character '%c'", c); - } - return c; // Make compiler happy -} - - -int Scanner::ScanHexEscaped() { - int val = 0, c = Peek(); - if (!isxdigit(c)) - Error(loc_, "expect xdigit, but got '%c'", c); - while (isxdigit(c)) { - val = (val << 4) + XDigit(c); - Next(); - c = Peek(); - } - return val; -} - - -int Scanner::ScanOctEscaped(int c) { - int val = XDigit(c); - c = Peek(); - if (!IsOctal(c)) - return val; - val = (val << 3) + XDigit(c); - Next(); - - c = Peek(); - if (!IsOctal(c)) - return val; - val = (val << 3) + XDigit(c); - Next(); - return val; -} - - -int Scanner::ScanUCN(int len) { - assert(len == 4 || len == 8); - int val = 0; - for (auto i = 0; i < len; ++i) { - auto c = Next(); - if (!isxdigit(c)) - Error(loc_, "expect xdigit, but got '%c'", c); - val = (val << 4) + XDigit(c); - } - return val; -} - - -int Scanner::XDigit(int c) { - switch (c) { - case '0' ... '9': return c - '0'; - case 'a' ... 'z': return c - 'a' + 10; - case 'A' ... 'Z': return c - 'A' + 10; - default: assert(false); return c; - } -} - - -Encoding Scanner::ScanEncoding(int c) { - switch (c) { - case 'u': return Try('8') ? Encoding::UTF8: Encoding::CHAR16; - case 'U': return Encoding::CHAR32; - case 'L': return Encoding::WCHAR; - default: assert(false); return Encoding::NONE; - } -} - - -std::string* ReadFile(const std::string& filename) { - FILE* f = fopen(filename.c_str(), "r"); - if (!f) Error("%s: No such file or directory", filename.c_str()); - auto text = new std::string; - int c; - while (EOF != (c = fgetc(f))) - text->push_back(c); - fclose(f); - return text; -} - - -int Scanner::Next() { - int c = Peek(); - ++p_; - if (c == '\n') { - ++loc_.line_; - loc_.column_ = 1; - loc_.lineBegin_ = p_; - } else { - ++loc_.column_; - } - return c; -} - - -int Scanner::Peek() { - int c = (uint8_t)(*p_); - if (c == '\\' && p_[1] == '\n') { - p_ += 2; - ++loc_.line_; - loc_.column_ = 1; - loc_.lineBegin_ = p_; - return Peek(); - } - return c; -} - - -// There couldn't be more than one PutBack() that -// cross two line, so just leave lineBegin, because -// we never care about the pos of newline token -void Scanner::PutBack() { - int c = *--p_; - if (c == '\n' && p_[-1] == '\\') { - --loc_.line_; - --p_; - return PutBack(); - } else if (c == '\n') { - --loc_.line_; - } else { - --loc_.column_; - } -} - - -Token* Scanner::MakeToken(int tag) { - tok_.tag_ = tag; - auto& str = tok_.str_; - str.resize(0); - const char* p = tok_.loc_.lineBegin_ + tok_.loc_.column_ - 1; - for (; p < p_; ++p) { - if (p[0] == '\n' && p[-1] == '\\') - str.pop_back(); - else - str.push_back(p[0]); - } - return Token::New(tok_); -} - - -/* - * New line is special, it is generated before reading the character '\n' - */ -Token* Scanner::MakeNewLine() { - tok_.tag_ = '\n'; - tok_.str_ = std::string(p_, p_ + 1); - return Token::New(tok_); -} diff --git a/lib/lang/scope.cc b/lib/lang/scope.cc deleted file mode 100644 index 9e487deba..000000000 --- a/lib/lang/scope.cc +++ /dev/null @@ -1,111 +0,0 @@ -#include "triton/lang/scope.h" - -#include "triton/lang/ast.h" - -#include -#include - - -Identifier* Scope::Find(const Token* tok) { - auto ret = Find(tok->str_); - if (ret) ret->SetTok(tok); - return ret; -} - - -Identifier* Scope::FindInCurScope(const Token* tok) { - auto ret = FindInCurScope(tok->str_); - if (ret) ret->SetTok(tok); - return ret; -} - - -Identifier* Scope::FindTag(const Token* tok) { - auto ret = FindTag(tok->str_); - if (ret) ret->SetTok(tok); - return ret; -} - - -Identifier* Scope::FindTagInCurScope(const Token* tok) { - auto ret = FindTagInCurScope(tok->str_); - if (ret) ret->SetTok(tok); - return ret; -} - - -void Scope::Insert(Identifier* ident) { - Insert(ident->Name(), ident); -} - - -void Scope::InsertTag(Identifier* ident) { - Insert(TagName(ident->Name()), ident); -} - - -Identifier* Scope::Find(const std::string& name) { - auto ident = identMap_.find(name); - if (ident != identMap_.end()) - return ident->second; - if (type_ == S_FILE || parent_ == nullptr) - return nullptr; - return parent_->Find(name); -} - - -Identifier* Scope::FindInCurScope(const std::string& name) { - auto ident = identMap_.find(name); - if (ident == identMap_.end()) - return nullptr; - return ident->second; -} - - -void Scope::Insert(const std::string& name, Identifier* ident) { - assert(FindInCurScope(name) == nullptr); - identMap_[name] = ident; -} - - -Identifier* Scope::FindTag(const std::string& name) { - auto tag = Find(TagName(name)); - if (tag) assert(tag->ToTypeName()); - return tag; -} - - -Identifier* Scope::FindTagInCurScope(const std::string& name) { - auto tag = FindInCurScope(TagName(name)); - assert(tag == nullptr || tag->ToTypeName()); - return tag; -} - - -Scope::TagList Scope::AllTagsInCurScope() const { - TagList tags; - for (auto& kv: identMap_) { - if (IsTagName(kv.first)) - tags.push_back(kv.second); - } - return tags; -} - - -void Scope::Print() { - std::cout << "scope: " << this << std::endl; - - auto iter = identMap_.begin(); - for (; iter != identMap_.end(); ++iter) { - auto name = iter->first; - auto ident = iter->second; - if (ident->ToTypeName()) { - std::cout << name << "\t[type:\t" - << ident->Type()->Str() << "]" << std::endl; - } else { - std::cout << name << "\t[object:\t" - << ident->Type()->Str() << "]" << std::endl; - } - } - std::cout << std::endl; -} diff --git a/lib/lang/token.cc b/lib/lang/token.cc deleted file mode 100644 index 5e9b535b6..000000000 --- a/lib/lang/token.cc +++ /dev/null @@ -1,271 +0,0 @@ -#include "triton/lang/token.h" - -#include "triton/lang/mem_pool.h" -#include "triton/lang/parser.h" - - -static MemPoolImp tokenPool; - -const std::unordered_map Token::kwTypeMap_ { - { "__constant__", Token::CMEM }, - { "__global__", Token::GLOBAL }, - { "auto", Token::AUTO }, - { "break", Token::BREAK }, - { "case", Token::CASE }, - { "char", Token::CHAR }, - { "const", Token::CONST }, - { "continue", Token::CONTINUE }, - { "default", Token::DEFAULT }, - { "do", Token::DO }, - { "double", Token::DOUBLE }, - { "else", Token::ELSE }, - { "enum", Token::ENUM }, - { "extern", Token::EXTERN }, - { "float", Token::FLOAT }, - { "for", Token::FOR }, - { "goto", Token::GOTO }, - { "half", Token::HALF }, - { "if", Token::IF }, - { "inline", Token::INLINE }, - { "int", Token::INT }, - { "long", Token::LONG }, - { "newaxis", Token::NEWAXIS }, - { "signed", Token::SIGNED }, - { "unsigned", Token::UNSIGNED }, - { "restrict", Token::RESTRICT }, - { "return", Token::RETURN }, - { "short", Token::SHORT }, - { "sizeof", Token::SIZEOF }, - { "static", Token::STATIC }, - { "struct", Token::STRUCT }, - { "switch", Token::SWITCH }, - { "typedef", Token::TYPEDEF }, - { "union", Token::UNION }, - { "void", Token::VOID }, - { "volatile", Token::VOLATILE }, - { "while", Token::WHILE }, - { "bitcast", Token::BITCAST }, - { "exp", Token::EXP }, - { "log", Token::LOG }, - { "sqrtf", Token::SQRTF }, - { "_Alignas", Token::ALIGNAS }, - { "_Alignof", Token::ALIGNOF }, - { "_Atomic", Token::ATOMIC }, - { "__attribute__", Token::ATTRIBUTE }, - { "_Bool", Token::BOOL }, - { "_Complex", Token::COMPLEX }, - { "_Generic", Token::GENERIC }, - { "_Imaginary", Token::IMAGINARY }, - { "_Noreturn", Token::NORETURN }, - { "_Static_assert", Token::STATIC_ASSERT }, - { "_Thread_local", Token::THREAD }, - { "max", Token::MAX }, - { "min", Token::MIN }, -}; - -const std::unordered_map Token::tagLexemeMap_ { - { '(', "(" }, - { ')', ")" }, - { '[', "[" }, - { ']', "]" }, - { ':', ":" }, - { ',', "," }, - { ';', ";" }, - { '+', "+" }, - { '-', "-" }, - { '*', "*" }, - { '/', "/" }, - { '|', "|" }, - { '&', "&" }, - { '<', "<" }, - { '>', ">" }, - { '=', "=" }, - { '.', "." }, - { '%', "%" }, - { '{', "{" }, - { '}', "}" }, - { '^', "^" }, - { '~', "~" }, - { '!', "!" }, - { '?', "?" }, - { '#', "#" }, - { '@', "@" }, - - { Token::DSHARP, "##" }, - { Token::PTR, "->" }, - { Token::INC, "++" }, - { Token::DEC, "--" }, - { Token::LEFT, "<<" }, - { Token::RIGHT, ">>" }, - { Token::LE, "<=" }, - { Token::GE, ">=" }, - { Token::EQ, "==" }, - { Token::NE, "!=" }, - { Token::LOGICAL_AND, "&&" }, - { Token::LOGICAL_OR, "||" }, - { Token::MUL_ASSIGN, "*=" }, - { Token::DIV_ASSIGN, "/=" }, - { Token::MOD_ASSIGN, "%=" }, - { Token::ADD_ASSIGN, "+=" }, - { Token::SUB_ASSIGN, "-=" }, - { Token::LEFT_ASSIGN, "<<=" }, - { Token::RIGHT_ASSIGN, ">>=" }, - { Token::AND_ASSIGN, "&=" }, - { Token::XOR_ASSIGN, "^=" }, - { Token::OR_ASSIGN, "|=" }, - { Token::ELLIPSIS, "..." }, - { Token::AUTO, "auto" }, - { Token::BREAK, "break" }, - { Token::CASE, "case" }, - { Token::CHAR, "char" }, - { Token::CONST, "const" }, - { Token::CONTINUE, "continue" }, - { Token::DEFAULT, "default" }, - { Token::DO, "do" }, - { Token::DOUBLE, "double" }, - { Token::ELSE, "else" }, - { Token::ENUM, "enum" }, - { Token::EXTERN, "extern" }, - { Token::FLOAT, "float" }, - { Token::FOR, "for" }, - { Token::GLOBAL, "global" }, - { Token::GOTO, "goto" }, - { Token::IF, "if" }, - { Token::INLINE, "inline" }, - { Token::INT, "int" }, - { Token::LONG, "long" }, - { Token::NEWAXIS, "newaxis" }, - { Token::SIGNED, "signed" }, - { Token::UNSIGNED, "unsigned" }, - { Token::RESTRICT, "restrict" }, - { Token::RETURN, "return" }, - { Token::SHORT, "short" }, - { Token::SIZEOF, "sizeof" }, - { Token::STATIC, "static" }, - { Token::STRUCT, "struct" }, - { Token::SWITCH, "switch" }, - { Token::TYPEDEF, "typedef" }, - { Token::UNION, "union" }, - { Token::VOID, "void" }, - { Token::VOLATILE, "volatile" }, - { Token::WHILE, "while" }, - { Token::BITCAST, "bitcast" }, - { Token::EXP, "exp" }, - { Token::LOG, "log" }, - { Token::SQRTF, "sqrtf" }, - { Token::ALIGNAS, "_Alignas" }, - { Token::ALIGNOF, "_Alignof" }, - { Token::ATOMIC, "_Atomic" }, - { Token::ATTRIBUTE, "__attribute__" }, - { Token::BOOL, "_Bool" }, - { Token::COMPLEX, "_Complex" }, - { Token::GENERIC, "_Generic" }, - { Token::IMAGINARY, "_Imaginary" }, - { Token::NORETURN, "_Noreturn" }, - { Token::STATIC_ASSERT, "_Static_assert" }, - { Token::THREAD, "_Thread_local" }, - - { Token::END, "(eof)" }, - { Token::IDENTIFIER, "(identifier)" }, - { Token::CONSTANT, "(constant)" }, - { Token::LITERAL, "(string literal)" }, -}; - - -Token* Token::New(int tag) { - return new (tokenPool.Alloc()) Token(tag); -} - - -Token* Token::New(const Token& other) { - return new (tokenPool.Alloc()) Token(other); -} - - -Token* Token::New(int tag, - const SourceLocation& loc, - const std::string& str, - bool ws) { - return new (tokenPool.Alloc()) Token(tag, loc, str, ws); -} - - -TokenSequence TokenSequence::GetLine() { - auto begin = begin_; - while (begin_ != end_ && (*begin_)->tag_ != Token::NEW_LINE) - ++begin_; - auto end = begin_; - return {tokList_, begin, end}; -} - - -/* - * If this seq starts from the begin of a line. - * Called only after we have saw '#' in the token sequence. - */ -bool TokenSequence::IsBeginOfLine() const { - if (begin_ == tokList_->begin()) - return true; - - auto pre = begin_; - --pre; - - // We do not insert a newline at the end of a source file. - // Thus if two token have different filename, the second is - // the begin of a line. - return ((*pre)->tag_ == Token::NEW_LINE || - (*pre)->loc_.filename_ != (*begin_)->loc_.filename_); -} - -const Token* TokenSequence::Peek() const { - static auto eof = Token::New(Token::END); - if (begin_ != end_ && (*begin_)->tag_ == Token::NEW_LINE) { - ++begin_; - return Peek(); - } else if (begin_ == end_) { - if (end_ != tokList_->begin()) - *eof = *Back(); - eof->tag_ = Token::END; - return eof; - } else if (parser_ && (*begin_)->tag_ == Token::IDENTIFIER && - (*begin_)->str_ == "__func__") { - auto filename = Token::New(*(*begin_)); - filename->tag_ = Token::LITERAL; - filename->str_ = "\"" + parser_->CurFunc()->Name() + "\""; - *begin_ = filename; - } - return *begin_; -} - - -const Token* TokenSequence::Expect(int expect) { - auto tok = Peek(); - if (!Try(expect)) { - Error(tok, "'%s' expected, but got '%s'", - Token::Lexeme(expect), tok->str_.c_str()); - } - return tok; -} - -void TokenSequence::Print(FILE* fp) const { - unsigned lastLine = 0; - auto ts = *this; - while (!ts.Empty()) { - auto tok = ts.Next(); - if (lastLine != tok->loc_.line_) { - fputs("\n", fp); - for (unsigned i = 0; i < tok->loc_.column_; ++i) - fputc(' ', fp); - } else if (tok->ws_) { - fputc(' ', fp); - } - fputs(tok->str_.c_str(), fp); - fflush(fp); - lastLine = tok->loc_.line_; - } - fputs("\n", fp); -} - -//void TokenSequence::Print(std::string *str) const { - -//} diff --git a/lib/lang/type.cc b/lib/lang/type.cc deleted file mode 100644 index dc0b65125..000000000 --- a/lib/lang/type.cc +++ /dev/null @@ -1,508 +0,0 @@ -#include "triton/lang/type.h" - -#include "triton/lang/ast.h" -#include "triton/lang/scope.h" -#include "triton/lang/token.h" - -#include -#include -#include - - -static MemPoolImp voidTypePool; -static MemPoolImp arrayTypePool; -static MemPoolImp tileTypePool; -static MemPoolImp funcTypePool; -static MemPoolImp pointerTypePool; -static MemPoolImp structUnionTypePool; -static MemPoolImp arithmTypePool; - - -QualType Type::MayCast(QualType type, bool inProtoScope) { - auto funcType = type->ToFunc(); - auto arrayType = type->ToArray(); - if (funcType) { - return PointerType::New(funcType); - } else if (arrayType) { - auto ret = PointerType::New(arrayType->Derived()); - // C11 6.7.6.3 [7]: qualifiers are specified in '[]' - // As we do not support qualifiers in '[]', the qualifier whould be none - return QualType(ret, inProtoScope? 0: Qualifier::CONST); - } - return type; -} - -const Type* Type::ScalarType() const { - if(IsScalar()) - return this; - if(const TileType* p = ToTile()) - return p->Derived().GetPtr(); - return nullptr; -} - -Type* Type::ScalarType() { - auto cthis = const_cast(this); - return const_cast(cthis->ScalarType()); -} - -VoidType* VoidType::New() { - static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool); - return ret; -} - - -ArithmType* ArithmType::New(int typeSpec) { -#define NEW_TYPE(tag) \ - new (arithmTypePool.Alloc()) ArithmType(&arithmTypePool, tag); - - static auto boolType = NEW_TYPE(T_BOOL); - static auto charType = NEW_TYPE(T_CHAR); - static auto ucharType = NEW_TYPE(T_UNSIGNED | T_CHAR); - static auto shortType = NEW_TYPE(T_SHORT); - static auto ushortType = NEW_TYPE(T_UNSIGNED | T_SHORT); - static auto intType = NEW_TYPE(T_INT); - static auto uintType = NEW_TYPE(T_UNSIGNED | T_INT); - static auto longType = NEW_TYPE(T_LONG); - static auto ulongType = NEW_TYPE(T_UNSIGNED | T_LONG); - static auto llongType = NEW_TYPE(T_LLONG) - static auto ullongType = NEW_TYPE(T_UNSIGNED | T_LLONG); - static auto halfType = NEW_TYPE(T_HALF); - static auto floatType = NEW_TYPE(T_FLOAT); - static auto doubleType = NEW_TYPE(T_DOUBLE); - static auto ldoubleType = NEW_TYPE(T_LONG | T_DOUBLE); - - auto tag = ArithmType::Spec2Tag(typeSpec); - switch (tag) { - case T_BOOL: return boolType; - case T_CHAR: return charType; - case T_UNSIGNED | T_CHAR: return ucharType; - case T_SHORT: return shortType; - case T_UNSIGNED | T_SHORT:return ushortType; - case T_INT: return intType; - case T_UNSIGNED: - case T_UNSIGNED | T_INT: return uintType; - case T_LONG: return longType; - case T_UNSIGNED | T_LONG: return ulongType; - case T_LLONG: return llongType; - case T_UNSIGNED | T_LLONG:return ullongType; - case T_HALF: return halfType; - case T_FLOAT: return floatType; - case T_DOUBLE: return doubleType; - case T_LONG | T_DOUBLE: return ldoubleType; - default: - assert(tag & T_COMPLEX); - Error("complex not supported yet"); - } - return nullptr; // Make compiler happy - -#undef NEW_TYPE -} - - -ArrayType* ArrayType::New(int len, QualType eleType) { - return new (arrayTypePool.Alloc()) - ArrayType(&arrayTypePool, len, eleType); -} - - -ArrayType* ArrayType::New(Expr* expr, QualType eleType) { - return new (arrayTypePool.Alloc()) - ArrayType(&arrayTypePool, expr, eleType); -} - -TileType* TileType::New(const ShapeInt &shape, QualType eleType) { - return new (tileTypePool.Alloc()) - TileType(&tileTypePool, shape, eleType); -} - -FuncType* FuncType::New(QualType derived, - int funcSpec, - bool variadic, - const ParamList& params) { - return new (funcTypePool.Alloc()) - FuncType(&funcTypePool, derived, funcSpec, variadic, params); -} - - -PointerType* PointerType::New(QualType derived) { - return new (pointerTypePool.Alloc()) - PointerType(&pointerTypePool, derived); -} - - -StructType* StructType::New(bool isStruct, - bool hasTag, - Scope* parent) { - return new (structUnionTypePool.Alloc()) - StructType(&structUnionTypePool, isStruct, hasTag, parent); -} - - -int ArithmType::Width() const { - switch (tag_) { - case T_BOOL: case T_CHAR: case T_UNSIGNED | T_CHAR: - return 1; - case T_SHORT: case T_UNSIGNED | T_SHORT: - return intWidth_ >> 1; - case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: - return intWidth_; - case T_LONG: case T_UNSIGNED | T_LONG: - return intWidth_ << 1; - case T_LLONG: case T_UNSIGNED | T_LLONG: - return intWidth_ << 1; - case T_HALF: - return intWidth_ >> 1; - case T_FLOAT: - return intWidth_; - case T_DOUBLE: - return intWidth_ << 1; - case T_LONG | T_DOUBLE: - return intWidth_ << 1; - case T_HALF | T_COMPLEX: - return intWidth_; - case T_FLOAT | T_COMPLEX: - return intWidth_ << 1; - case T_DOUBLE | T_COMPLEX: - return intWidth_ << 2; - case T_LONG | T_DOUBLE | T_COMPLEX: - return intWidth_ << 2; - default: - assert(false); - } - - return intWidth_; // Make compiler happy -} - - -int ArithmType::Rank() const { - switch (tag_) { - case T_BOOL: return 0; - case T_CHAR: case T_UNSIGNED | T_CHAR: return 1; - case T_SHORT: case T_UNSIGNED | T_SHORT: return 2; - case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3; - case T_LONG: case T_UNSIGNED | T_LONG: return 4; - case T_LLONG: case T_UNSIGNED | T_LLONG: return 5; - case T_HALF: return 6; - case T_FLOAT: return 7; - case T_DOUBLE: return 8; - case T_LONG | T_DOUBLE: return 9; - default: - assert(tag_ & T_COMPLEX); - Error("complex not supported yet"); - } - return 0; -} - - -ArithmType* ArithmType::MaxType(ArithmType* lhs, - ArithmType* rhs) { - if (lhs->IsInteger()) - lhs = ArithmType::IntegerPromote(lhs); - if (rhs->IsInteger()) - rhs = ArithmType::IntegerPromote(rhs); - auto ret = lhs->Rank() > rhs->Rank() ? lhs: rhs; - if (lhs->Width() == rhs->Width() && (lhs->IsUnsigned() || rhs->IsUnsigned())) - return ArithmType::New(T_UNSIGNED | ret->Tag()); - return ret; -} - - -/* - * Converting from type specifier to type tag - */ -int ArithmType::Spec2Tag(int spec) { - if (spec == T_SIGNED) { - return T_INT; - } - spec &= ~T_SIGNED; - if ((spec & T_SHORT) || (spec & T_LONG) - || (spec & T_LLONG)) { - spec &= ~T_INT; - } - return spec; -} - - -std::string ArithmType::Str() const { - std::string width = ":" + std::to_string(Width()); - - switch (tag_) { - case T_BOOL: - return "bool" + width; - - case T_CHAR: - return "char" + width; - - case T_UNSIGNED | T_CHAR: - return "unsigned char" + width; - - case T_SHORT: - return "short" + width; - - case T_UNSIGNED | T_SHORT: - return "unsigned short" + width; - - case T_INT: - return "int" + width; - - case T_UNSIGNED: - return "unsigned int" + width; - - case T_LONG: - return "long" + width; - - case T_UNSIGNED | T_LONG: - return "unsigned long" + width; - - case T_LLONG: - return "long long" + width; - - case T_UNSIGNED | T_LLONG: - return "unsigned long long" + width; - - case T_FLOAT: - return "float" + width; - - case T_DOUBLE: - return "double" + width; - - case T_LONG | T_DOUBLE: - return "long double" + width; - - case T_FLOAT | T_COMPLEX: - return "float complex" + width; - - case T_DOUBLE | T_COMPLEX: - return "double complex" + width; - - case T_LONG | T_DOUBLE | T_COMPLEX: - return "long double complex" + width; - - default: - assert(false); - } - - return "error"; // Make compiler happy -} - - -bool PointerType::Compatible(const Type& other) const { - // C11 6.7.6.1 [2]: pointer compatibility - auto otherPointer = other.ToPointer(); - return otherPointer && - derived_->Compatible(*otherPointer->derived_); - - // FIXME(wgtdkp): cannot loose compatible constraints - //return other.IsInteger() || - // (otherPointer && derived_->Compatible(*otherPointer->derived_)); -} - - -bool ArrayType::Compatible(const Type& other) const { - // C11 6.7.6.2 [6]: For two array type to be compatible, - // the element types must be compatible, and have same length - // if both specified. - auto otherArray = other.ToArray(); - if (!otherArray) return false; - if (!derived_->Compatible(*otherArray->derived_)) return false; - // The lengths should equal if both specified - if (complete_ && otherArray->complete_) - return len_ == otherArray->len_; - return true; -} - -TileType::TileType(MemPool* pool, const ShapeInt& shape, QualType derived) - : DerivedType(pool, derived), - shape_(shape) { - bool isComplete = true; - for(int s: shape_) - isComplete = isComplete && (s>=0); - SetComplete(isComplete); -} - -bool TileType::Compatible(const Type& other) const { - // For two tile type to be compatible, - // the element types must be compatible - // and they must have the same shapea - auto otherTile = other.ToTile(); - if(!otherTile) - return false; - if (!derived_->Compatible(*otherTile->derived_)) - return false; - // The shapes should be equal if both specified - if(complete_ && otherTile->complete_) - return shape_ == otherTile->shape_; - return true; -} - - - -bool FuncType::Compatible(const Type& other) const { - auto otherFunc = other.ToFunc(); - // The other type is not an function type - if (!otherFunc) return false; - // TODO(wgtdkp): do we need to check the type of return value when deciding - // compatibility of two function types ?? - if (!derived_->Compatible(*otherFunc->derived_)) - return false; - if (params_.size() != otherFunc->params_.size()) - return false; - - auto thisIter = params_.begin(); - auto otherIter = otherFunc->params_.begin(); - while (thisIter != params_.end()) { - if (!(*thisIter)->Type()->Compatible(*(*otherIter)->Type())) - return false; - ++thisIter; - ++otherIter; - } - - return true; -} - - -std::string FuncType::Str() const { - auto str = derived_->Str() + "("; - auto iter = params_.begin(); - for (; iter != params_.end(); ++iter) { - str += (*iter)->Type()->Str() + ", "; - } - if (variadic_) - str += "..."; - else if (params_.size()) - str.resize(str.size() - 2); - - return str + ")"; -} - - -StructType::StructType(MemPool* pool, - bool isStruct, - bool hasTag, - Scope* parent) - : Type(pool, false), - isStruct_(isStruct), - hasTag_(hasTag), - memberMap_(new Scope(parent, S_BLOCK)), - offset_(0), - width_(0), - // If a struct type has no member, it gets alignment of 1 - align_(1), - bitFieldAlign_(1) {} - - -Object* StructType::GetMember(const std::string& member) { - auto ident = memberMap_->FindInCurScope(member); - if (ident == nullptr) - return nullptr; - return ident->ToObject(); -} - - -void StructType::CalcWidth() { - width_ = 0; - auto iter = memberMap_->identMap_.begin(); - for (; iter != memberMap_->identMap_.end(); ++iter) { - width_ += iter->second->Type()->Width(); - } -} - - -bool StructType::Compatible(const Type& other) const { - return this == &other; // Pointer comparison -} - - -// TODO(wgtdkp): more detailed representation -std::string StructType::Str() const { - std::string str = isStruct_ ? "struct": "union"; - return str + ":" + std::to_string(width_); -} - - -// Remove useless unnamed bitfield members as they are just for parsing -void StructType::Finalize() { - for (auto iter = members_.begin(); iter != members_.end();) { - if ((*iter)->BitFieldWidth() && (*iter)->Anonymous()) { - members_.erase(iter++); - } else { - ++iter; - } - } -} - - -void StructType::AddMember(Object* member) { - auto offset = MakeAlign(offset_, member->Align()); - member->SetOffset(offset); - - members_.push_back(member); - memberMap_->Insert(member->Name(), member); - - align_ = std::max(align_, member->Align()); - bitFieldAlign_ = std::max(bitFieldAlign_, align_); - - if (isStruct_) { - offset_ = offset + member->Type()->Width(); - width_ = MakeAlign(offset_, align_); - } else { - assert(offset_ == 0); - width_ = std::max(width_, member->Type()->Width()); - width_ = MakeAlign(width_, align_); - } -} - - -void StructType::AddBitField(Object* bitField, int offset) { - bitField->SetOffset(offset); - members_.push_back(bitField); - if (!bitField->Anonymous()) - memberMap_->Insert(bitField->Name(), bitField); - - auto bytes = MakeAlign(bitField->BitFieldEnd(), 8) / 8; - bitFieldAlign_ = std::max(bitFieldAlign_, bitField->Align()); - // Does not aligned, default is 1 - if (isStruct_) { - offset_ = offset + bytes; - width_ = MakeAlign(offset_, std::max(bitFieldAlign_, bitField->Align())); - } else { - assert(offset_ == 0); - width_ = std::max(width_, bitField->Type()->Width()); - } -} - - -// Move members of Anonymous struct/union to external struct/union -void StructType::MergeAnony(Object* anony) { - auto anonyType = anony->Type()->ToStruct(); - auto offset = MakeAlign(offset_, anony->Align()); - - // Members in map are never anonymous - for (auto& kv: *anonyType->memberMap_) { - auto& name = kv.first; - auto member = kv.second->ToObject(); - if (member == nullptr) { - continue; - } - // Every member of anonymous struct/union - // are offseted by external struct/union - member->SetOffset(offset + member->Offset()); - - if (GetMember(name)) { - Error(member, "duplicated member '%s'", name.c_str()); - } - // Simplify anony struct's member searching - memberMap_->Insert(name, member); - } - anony->SetOffset(offset); - members_.push_back(anony); - - align_ = std::max(align_, anony->Align()); - if (isStruct_) { - offset_ = offset + anonyType->Width(); - width_ = MakeAlign(offset_, align_); - } else { - assert(offset_ == 0); - width_ = std::max(width_, anonyType->Width()); - } -} diff --git a/lib/runtime/arg.cc b/lib/runtime/arg.cc deleted file mode 100644 index e69de29bb..000000000 diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc deleted file mode 100644 index 1087571a8..000000000 --- a/lib/runtime/function.cc +++ /dev/null @@ -1,364 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/swizzle.h" -#include "triton/codegen/transform/coalesce.h" -#include "triton/codegen/transform/dce.h" -#include "triton/codegen/transform/peephole.h" -#include "triton/codegen/transform/membar.h" -#include "triton/codegen/transform/reassociate.h" -#include "triton/codegen/transform/cts.h" -#include "triton/codegen/transform/disassociate.h" -#include "triton/codegen/selection/generator.h" -#include "triton/codegen/transform/pipeline.h" -#include "triton/runtime/function.h" -#include "triton/lang/cpp.h" -#include "triton/lang/parser.h" -#include "triton/lang/code_gen.h" -#include "triton/driver/device.h" -#include "triton/driver/stream.h" -#include "triton/driver/kernel.h" -#include "triton/driver/module.h" -#include "triton/driver/error.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/print.h" -#include "triton/runtime/error.h" -#include "triton/tools/bench.hpp" -#include "triton/tools/sha1.hpp" -#include "triton/tools/sys/getenv.hpp" -#include "triton/tools/sys/mkdir.hpp" -#include "llvm/IR/Module.h" -#include -#include - - -namespace triton{ -namespace runtime { - -/* --------------------------------- */ -/* --------------------------------- */ -/* --------------------------------- */ - -std::shared_ptr kernel::src_to_ir(const std::string& _src, const options_t& opt) { - std::string src = -R"( -#define bool _Bool -#define true 1 -#define false 0 - -#define __readonly __attribute__((readonly)) -#define __writeonly __attribute__((writeonly)) -#define __noalias __attribute__((noalias)) -#define __aligned(A) __attribute__((aligned(A))) -#define __multipleof(A) __attribute__((multipleof(A))) -#define __retune __attribute__((retune)) - -#define F32_INFINITY bitcast(0x7F800000) -#define F16_INFINITY bitcast((int16)0x7C00) - -#define min(a,b) (((a)<(b))?(a):(b)) -#define max(a,b) (((a)>(b))?(a):(b)) - -#define PASTER(a, b, _) a ## _ ## b -#define EVALUATOR(a, b, _) PASTER(a, b, _) -#define atomic_add(TYPE, TM, TN) EVALUATOR(atomic_add, EVALUATOR(TYPE, EVALUATOR(TM, TN, x), _), _) -#define DECLARATION(TYPE, TM, TN) extern void atomic_add(TYPE, TM, TN)(TYPE*[TM, TN], TYPE[TM, TN], bool[TM, TN]) - -DECLARATION(float, 64, 64); -DECLARATION(float, 64, 128); -DECLARATION(float, 128, 64); -DECLARATION(float, 128, 128); -extern void atomic_add_half_1x1(half*, half, bool); - -DECLARATION(half , 64, 64); -DECLARATION(half , 64, 128); -DECLARATION(half , 128, 64); -DECLARATION(half , 128, 128); -extern void atomic_add_float_1x1(float*, float, bool); - -extern int atomic_cas(int*, int, int); -extern int atomic_xchg(int*, int); -extern int get_program_id(int); -extern void __debug_barrier(); -extern int get_num_programs(int); -extern int select(bool, int, int); -extern char __constant__ * calloc(int); - -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint32; -typedef unsigned long uint64; -typedef char int8; -typedef short int16; -typedef int int32; -typedef long int64; -)"; - src += _src; - // pre-process - TokenSequence tokens; - Preprocessor cpp(&src, true); - for(auto it: opt.defines) - cpp.AddMacro(it.first, &it.second); - cpp.Process(tokens); - // src -> ast - Parser parser(tokens); - parser.Parse(); - // ast -> triton-ir - auto ret = std::make_shared(""); - Generator gen(&parser); - gen.Gen(&*ret); - return ret; -} - -std::tuple, - std::shared_ptr, - size_t> kernel::ir_to_bin(ir::module &ir, driver::device* dev, const options_t& opt) { - // generate llvm code - llvm::LLVMContext ctx; - std::string name = ir.get_function_list()[0]->get_name(); - std::unique_ptr llvm(new llvm::Module(name, ctx)); - // optimizations - std::unique_ptr target = dev->make_target(); - bool cts_use_async = target->as_nvidia()->sm() >= 80; - // create passes - codegen::analysis::align align; - codegen::analysis::axes axes; - codegen::transform::cts cts(cts_use_async); - codegen::transform::pipeline pipeline(cts_use_async); - codegen::transform::disassociate disassociate; - codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get()); - codegen::analysis::liveness liveness(&layouts); - codegen::analysis::swizzle swizzle(&layouts, target.get()); - codegen::analysis::allocation allocation(&liveness); - codegen::transform::membar barriers(&liveness, &layouts, &allocation); - codegen::transform::dce dce; - codegen::transform::peephole peephole(target.get(), &layouts); - codegen::transform::reassociate reassociate; - codegen::transform::coalesce coalesce(&align, &layouts); - codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps); - // run passes - dce.run(ir); - peephole.run(ir); - dce.run(ir); - pipeline.run(ir); - dce.run(ir); - disassociate.run(ir); - dce.run(ir); - align.run(ir); - axes.run(ir); - layouts.run(ir); - peephole.run(ir); - dce.run(ir); -// ir::print(ir, std::cout); - if(target->is_gpu()) - cts.run(ir); - align.run(ir); - axes.run(ir); - layouts.run(ir); - coalesce.run(ir); - dce.run(ir); - align.run(ir); - dce.run(ir); - if(target->is_gpu()){ - reassociate.run(ir); - cts.run(ir); - } - dce.run(ir); - align.run(ir); - axes.run(ir); - layouts.run(ir); - peephole.run(ir); - dce.run(ir); - align.run(ir); - axes.run(ir); - layouts.run(ir); - swizzle.run(ir); - liveness.run(ir); - allocation.run(ir); - barriers.run(ir); - isel.visit(ir, *llvm); - std::shared_ptr mod(driver::module::create(dev, std::move(llvm))); - std::shared_ptr ker(driver::kernel::create(&*mod, name.c_str())); - size_t shared_mem = allocation.allocated_size(); - return std::make_tuple(mod, ker, shared_mem); -} - -kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev, const std::map &attrs): - opt(opt), dev_(dev) { - // compile to Triton IR - ir_ = src_to_ir(src, opt); - // add attributes - for(const auto&x: attrs) - ir_->get_function_list()[0]->add_attr(x.first, x.second); - // compile to binary - std::tie(mod_, ker_, shared_mem_) = ir_to_bin(*ir_, dev, opt); -} - -void kernel::operator()(const std::string& args, driver::stream *stream, const std::vector& _grid) const{ - // set grid - if(_grid.size() > 3) - throw std::runtime_error("grid size must be no greater than 3"); - std::array grid; - for(size_t i = 0; i < 3; i++) - grid[i] = (i < _grid.size()) ? _grid[i] : 1; - // enqueue - stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_); -} - -std::string kernel::get_asm(const std::string& mode) { - std::vector modes = {"llir", "ptx"}; - if(std::find(modes.begin(), modes.end(), mode) == modes.end()){ - std::string err = "Unrecognized mode. Supported values are: "; - for(std::string m: modes){ - if(m != modes[0]) - err += ", "; - err += m; - } - throw std::runtime_error(err); - } - if(mode == "llir") - return ((driver::cu_module*)mod_.get())->llir(); - if(mode == "ptx") - return ((driver::cu_module*)mod_.get())->ptx(); - assert(false); - return ""; -} -/* --------------------------------- */ -/* --------------------------------- */ -/* --------------------------------- */ - - - - -function::function(const std::string& src, const options_t &opt, driver::device *device, - const std::vector &tune_confs, const std::vector& tune_key) - : src_(src), device_(device) { - // kernel options - size_t num_opts = std::max(tune_confs.size(), (size_t)1); - opts_ = std::vector(num_opts, opt); - for(size_t i = 0; i < tune_confs.size(); i++){ - opts_[i].defines.insert(tune_confs[i].defines.begin(), tune_confs[i].defines.end()); - opts_[i].num_warps = tune_confs[i].num_warps; - } - std::shared_ptr ir = kernel::src_to_ir(src, opts_[0]); - std::vector args = ir->get_function_list()[0]->args(); - // signature - auto convert = [](ir::type *ty) { - if(ty->is_integer_ty(1)) return INT1_T; - if(ty->is_integer_ty(8)) return INT8_T; - if(ty->is_integer_ty(16)) return INT16_T; - if(ty->is_integer_ty(32)) return INT32_T; - if(ty->is_integer_ty(64)) return INT64_T; - if(ty->is_half_ty()) return HALF_T; - if(ty->is_float_ty()) return FLOAT_T; - if(ty->is_double_ty()) return DOUBLE_T; - if(ty->is_pointer_ty()) return BUFFER_T; - throw std::runtime_error("unknown type"); - }; - for(ir::argument* arg: args) - sig_.push_back(convert(arg->get_type())); - // find indices of autotune keys - for(const std::string& name: tune_key){ - auto pred = [&](ir::argument* arg) { return arg->get_name() == name; }; -// std::cout << "----" << std::endl; -// for(ir::argument* arg: args) -// std::cout << arg->get_name() << std::endl; - auto it = std::find_if(args.begin(), args.end(), pred); - if(it == args.end()) - throw std::runtime_error(name + " is not a valid argument name"); - key_idxs_.push_back(std::distance(args.begin(), it)); - } - // find indices of pointer - for(size_t i = 0; i < args.size(); i++) - if(args[i]->get_type()->is_pointer_ty() || - args[i]->get_type()->is_integer_ty()) - align_idxs_.push_back(i); - // argument size and offset - size_t curr = 0; - for(arg_type ty: sig_){ - arg_size_.push_back(size_of(ty)); - arg_off_.push_back(curr); - curr += arg_size_.back(); - } -} - -uint64_t pow2_divisor(uint64_t N){ - if(N % 16 == 0) return 16; - if(N % 8 == 0) return 8; - if(N % 4 == 0) return 4; - if(N % 2 == 0) return 2; - return 1; -} - -kernel* function::autotune(const std::string &args, const grid_fn_ty& grid_fn, driver::stream* stream) { - // align key - std::vector rt_key(align_idxs_.size(), 0); - for(size_t i = 0; i < align_idxs_.size(); i++){ - int idx = align_idxs_[i]; - uint64_t tmp = 0; - std::memcpy((void*)&tmp, (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]); - rt_key[i] = pow2_divisor(tmp); - } - // auto-tuning key - std::vector at_key(key_idxs_.size(), 0); - for(size_t i = 0; i < at_key.size(); i++){ - int idx = key_idxs_[i]; - std::memcpy((void*)&at_key[i], (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]); - } - // cache key - std::vector cache_key; - cache_key.reserve(rt_key.size() + at_key.size()); - cache_key.insert(cache_key.end(), rt_key.begin(), rt_key.end()); - cache_key.insert(cache_key.end(), at_key.begin(), at_key.end()); - auto it = cache_.find(cache_key); - if(it != cache_.end()) - return it->second; - // compile kernels - if(kernels_.find(rt_key) == kernels_.end()){ - std::map attrs; - for(size_t i = 0; i < align_idxs_.size(); i++){ - bool is_ptr = sig_[align_idxs_[i]] == BUFFER_T; - attrs.insert({align_idxs_[i] + 1, ir::attribute(is_ptr ? ir::aligned : ir::multiple_of, rt_key[i])}); - } - for(const options_t& opt: opts_) - kernels_[rt_key].emplace_back(new kernel(src_, opt, device_, attrs)); - } - // run auto-tuner - double best_ts = INFINITY; - auto& kernels = kernels_.at(rt_key); - kernel* ret = nullptr; - if(kernels.size() == 1) - ret = &*kernels.back(); - else{ - for(auto ¤t : kernels_.at(rt_key)){ - auto grid = grid_fn(current->opt); - while(grid.size() < 3) - grid.push_back(1); - double ts = tools::bench([&]() { (*current)(args, stream, grid); }, - stream, 5, 20); - ret = (ts < best_ts) ? &*current : ret; - best_ts = std::min(ts, best_ts); - } - stream->synchronize(); - } - it = cache_.insert({cache_key, ret}).first; - return it->second; -} - -void function::operator()(const std::string& args, const grid_fn_ty& grid_fn, driver::stream *stream) { - runtime::kernel* fn = autotune(args, grid_fn, stream); - (*fn)(args, stream, grid_fn(fn->opt)); -} - - -} -} diff --git a/python/setup.py b/python/setup.py index 192a74485..c2a35b32c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -49,11 +49,11 @@ class CMakeBuild(build_ext): self.build_extension(ext) def build_extension(self, ext): - # self.debug = True - self.debug = False + #self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories - llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm") + build_suffix = 'debug' if self.debug else 'release' + llvm_build_dir = os.path.join(tempfile.gettempdir(), f"llvm-{build_suffix}") if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) if not os.path.exists(llvm_build_dir): diff --git a/python/src/functions.h b/python/src/functions.h new file mode 100644 index 000000000..4d426890e --- /dev/null +++ b/python/src/functions.h @@ -0,0 +1,676 @@ +#include "triton/ir/builder.h" +#include +#include +#include + +namespace ir = triton::ir; +namespace py = pybind11; + +static const std::string _builder_doc = R"pbdoc( + :param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function + :type builder: triton.ir.builder +)pbdoc"; + +#define VA_ARGS(...) , ##__VA_ARGS__ +#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \ + MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \ + ret::reference VA_ARGS(__VA_ARGS__), "builder"_a) + +void throw_not_implemented(std::string key) { + throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side."); +} + +void throw_not_int_or_float(std::string key) { + throw std::runtime_error("`" + key + "` only supported for integer and floating point types."); +} + +enum type_code { + _bool, + int8, + int16, + int32, + int64, + float16, + float32, + float64 +}; + +ir::type *make_ir(type_code ty, ir::builder *builder) { + switch (ty) { + case float16: + return builder->get_half_ty(); + case float32: + return builder->get_float_ty(); + default: + throw_not_implemented("make_ir"); + } +} + +type_code from_ir(ir::type *ty) { + if (ty->is_half_ty()) + return float16; + if (ty->is_float_ty()) + return float32; + throw_not_implemented("from_ir"); +} + +/*---------------------------------------------- + definition of triton.cast / triton.ir.value.to + ----------------------------------------------*/ +std::string cast_docstr = R"pbdoc( + Tries to cast a block to a new data type. + + :param input: The input block. + :type input: triton.ir.value +)pbdoc"; + +ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) { + ir::type *src_ty = input->get_type(); + ir::type *dst_ty = make_ir(_dtype, builder); + if (src_ty->is_block_ty()) + dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); + ir::type *src_sca_ty = src_ty->get_scalar_ty(); + ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); + // FP Truncation + bool truncate_fp = src_sca_ty->is_floating_point_ty() && + dst_sca_ty->is_floating_point_ty() && + src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); + if (truncate_fp) + return builder->create_fp_trunc(input, dst_ty); + // FP Extension + bool ext_fp = src_sca_ty->is_floating_point_ty() && + dst_sca_ty->is_floating_point_ty() && + src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width(); + if (ext_fp) + return builder->create_fp_ext(input, dst_ty); + // Int cast + if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && + src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth()) + return builder->create_int_cast(input, dst_ty, true); + // Float -> Int + if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()) + return builder->create_fp_to_si(input, dst_ty); + // int -> Float + if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()) + return builder->create_si_to_fp(input, dst_ty); + // Ptr -> Ptr + if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) + return builder->create_cast(ir::BitCast, input, dst_ty); + // * -> Bool + if (dst_sca_ty->is_bool_ty()) { + if (src_sca_ty->is_pointer_ty()) + input = cast(input, int64, builder); + ir::value *other = builder->get_int64(0); + if (src_ty->is_bool_ty()) + other = builder->create_splat(other, src_ty->get_block_shapes()); + return builder->create_icmpNE(input, other); + } + throw_not_implemented("cast"); +} + +/*---------------------------------------------- + definition of triton.broadcast_check + ----------------------------------------------*/ +std::string try_broadcast_docstr = R"pbdoc( + Tries to broadcast two blocks to a common compatible shape. + + :param input: The first input block. + :type input: triton.ir.value + :param other: The second input block. + :type other: triton.ir.value +)pbdoc"; + +std::tuple try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) { + ir::type *lhs_ty = lhs->get_type(); + ir::type *rhs_ty = rhs->get_type(); + // make_shape_compatible(block, scalar) + if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) + rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); + // make_shape_compatible(scalar, block) + else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) + lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes()); + // make_shape_compatible(block, block) + else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) { + auto lhs_shape = lhs_ty->get_block_shapes(); + auto rhs_shape = rhs_ty->get_block_shapes(); + if (lhs_shape.size() != rhs_shape.size()) + throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank"); + ir::type::block_shapes_t ret_shape; + for (size_t i = 0; i < lhs_shape.size(); ++i) { + unsigned left = lhs_shape[i]; + unsigned right = rhs_shape[i]; + if (left == 1) + ret_shape.push_back(right); + else if (right == 1) + ret_shape.push_back(left); + else if (left == right) + ret_shape.push_back(left); + else + throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) + + ": " + std::to_string(left) + " and " + std::to_string(right)); + } + if (lhs_shape != ret_shape) + lhs = builder->create_broadcast(lhs, ret_shape); + if (rhs_shape != ret_shape) + rhs = builder->create_broadcast(rhs, ret_shape); + } + return std::make_tuple(lhs, rhs); +} + +/*---------------------------------------------- + definition of triton.broadcast_to + ----------------------------------------------*/ +std::string broadcast_to_docstr = R"pbdoc( + Tries to broadcast a block to a new shape. + + :param input: The input block. + :type input: triton.value + :param shape: The new shape. + :type shape: tuple of int +)pbdoc"; + +ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) { + if (!input->get_type()->is_block_ty()) + return builder->create_splat(input, shape); + auto src_shape = input->get_type()->get_block_shapes(); + if (src_shape.size() != shape.size()) + throw std::runtime_error("Cannot broadcast"); + return builder->create_broadcast(input, shape); +} + +/*---------------------------------------------- + definition of triton.load + ----------------------------------------------*/ +std::string load_docstr = R"pbdoc( + Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`. + + :param pointer: Pointer to the data to be loaded. + :type pointer: Block of triton.pointer + :param mask: if mask[idx] is false, do not load the data at `pointer[idx]`. + :type mask: Block of triton.bool, optional + :param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]` + :type other: Block of triton.value, optional + )pbdoc"; + +ir::value *load(ir::value *pointer, std::optional _mask, std::optional _other, ir::builder *builder) { + if (!_mask.has_value() && !_other.has_value()) + return builder->create_load(pointer); + if (!_mask.has_value()) + throw std::runtime_error("`other` cannot be provided without `mask`"); + ir::value *mask = _mask.value(); + ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty(); + auto shape = pointer->get_type()->get_block_shapes(); + ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty); + other = cast(other, from_ir(elt_ty), builder); + other = broadcast_to(other, shape, builder); + mask = broadcast_to(mask, shape, builder); + return builder->create_masked_load(pointer, mask, other); +} + +/*---------------------------------------------- + definition of triton.store + ----------------------------------------------*/ +std::string store_docstr = R"pbdoc( + Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`. + + :param pointer: The memory locations where the elements of `value` are stored. + :type pointer: Block of triton.pointer + :param value: The block of elements to be stored. + :type value: Block of triton.value + :param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`. + :type mask: Block of triton.bool, optional + )pbdoc"; +ir::value *store(ir::value *ptr, ir::value *val, std::optional _mask, ir::builder *builder) { + if (!_mask.has_value()) + return builder->create_store(ptr, val); + ir::value *mask = _mask.value(); + return builder->create_masked_store(ptr, val, mask); +} + +/*---------------------------------------------- + definition of triton.dot + ----------------------------------------------*/ +std::string dot_docstr = R"pbdoc( + Returns the matrix product of two blocks. + The two blocks must be two dimensionals and have compatible inner dimensions. + + :param input: The first block to be multiplied. + :type input: 2D block of scalar-type in {`float16`, `float32`} + :param other: The second block to be multiplied. + :type other: 2D block of scalar-type in {`float16`, `float32`} + )pbdoc"; +ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) { + ir::value *_0 = builder->get_float32(0); + unsigned M = lhs->get_type()->get_block_shapes()[0]; + unsigned N = rhs->get_type()->get_block_shapes()[1]; + _0 = builder->create_splat(_0, {M, N}); + return builder->create_dot(lhs, rhs, _0); +} + +/*---------------------------------------------- + definition of triton.where + ----------------------------------------------*/ +std::string where_docstr = R"pbdoc( + Returns a block of elements from either `x` or `y`, depending on `condition`. + Note that `x` and `y` are always evaluated regardless of the value of `condition`. + If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + )pbdoc"; +ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) { + return builder->create_select(condition, x, y); +}; + +/*---------------------------------------------- + definition of triton.arange + ----------------------------------------------*/ +std::string arange_docstr = R"pbdoc( + Returns contiguous values within the open interval [start, end). + + :param start: Start of the interval. + :type start: int + :param stop: End of the interval. + :type stop: int + )pbdoc"; +ir::value *arange(int start, int end, ir::builder *builder) { + return builder->get_range(start, end); +}; + +/*---------------------------------------------- + definition of triton.program_id + ----------------------------------------------*/ +std::string program_id_docstr = R"pbdoc( + Returns the id of the current program instance along the given `axis`. + Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s. + + :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. + :type axis: int + )pbdoc"; +ir::value *program_id(int axis, ir::builder *builder) { + return builder->create_get_program_id(axis); +}; + +/*---------------------------------------------- + definition of triton.num_programs + ----------------------------------------------*/ +std::string num_programs_docstr = R"pbdoc( + Returns the number of program instances launched along the given `axis`. + + :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. + :type axis: int + )pbdoc"; +ir::value *num_programs(int axis, ir::builder *builder) { + return builder->create_get_num_programs(axis); +}; + +/*---------------------------------------------- + definition of triton.zeros + ----------------------------------------------*/ +std::string zeros_docstr = R"pbdoc( + Returns a block filled with the scalar value 0 and the given shape. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., triton.float16 + :type dtype: triton.ir.dtype + )pbdoc"; +ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) { + ir::type *dtype = make_ir(_dtype, builder); + ir::value *_0 = ir::constant::get_null_value(dtype); + return builder->create_splat(_0, shape); +}; + +/*---------------------------------------------- + definition of triton.exp + ----------------------------------------------*/ +std::string _exp_docstr = R"pbdoc( + Returns the element-wise exponential of `input`. + )pbdoc"; +ir::value *_exp(ir::value *input, ir::builder *builder) { + return builder->create_exp(input); +}; + +/*---------------------------------------------- + definition of triton.log + ----------------------------------------------*/ +std::string _log_docstr = R"pbdoc( + Returns the element-wise natural logarithm of `input`. + )pbdoc"; +ir::value *_log(ir::value *input, ir::builder *builder) { + return builder->create_log(input); +}; + +/*---------------------------------------------- + definition of triton.sqrt + ----------------------------------------------*/ +std::string sqrt_docstr = R"pbdoc( + Returns the element-wise square root of `input`. + )pbdoc"; +ir::value *sqrt(ir::value *input, ir::builder *builder) { + return builder->create_sqrt(input); +}; + +/*---------------------------------------------- + definition of triton.min + ----------------------------------------------*/ +ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, + ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + if (scalar_ty->is_floating_point_ty()) + return builder->create_reduce(input, FLOAT_OP, axis); + else if (scalar_ty->is_integer_ty()) + return builder->create_reduce(input, INT_OP, axis); + else + throw_not_int_or_float(name); +} + +std::string min_docstr = R"pbdoc( + Returns the minimum value of `input`. + )pbdoc"; +ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); +}; + +/*---------------------------------------------- + definition of triton.max + ----------------------------------------------*/ +std::string max_docstr = R"pbdoc( + Returns the maximum value of `input`. + )pbdoc"; +ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); +}; + +/*---------------------------------------------- + definition of triton.sum + ----------------------------------------------*/ +std::string sum_docstr = R"pbdoc( + Returns the sum of `input`. + )pbdoc"; +ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); +}; + +/*---------------------------------------------- + definition of triton.atomic_cas + ----------------------------------------------*/ +std::string atomic_cas_docstr = R"pbdoc( + Atomic compare-and-swap. + )pbdoc"; +ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) { + return builder->create_atomic_cas(ptr, cmp, val); +}; + +/*---------------------------------------------- + definition of triton.atomic_xchg + ----------------------------------------------*/ +std::string atomic_xchg_docstr = R"pbdoc( + Atomic exchange. + )pbdoc"; +ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) { + return builder->create_atomic_exch(ptr, val); +}; + +/*---------------------------------------------- + debug barrier + ----------------------------------------------*/ +std::string debug_barrier_docstr = R"pbdoc( + Temporary hacky fixup for when the compiler forgets to insert sync barriers +)pbdoc"; +ir::value *debug_barrier(ir::builder *builder) { + return builder->create_barrier(); +} + +#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \ + MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \ + ret::reference VA_ARGS(__VA_ARGS__), "builder"_a) + +template +std::function +binary_op(const FN &fn) { + auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) { + //std::tie(self, other) = try_broadcast(self, other, builder); + return fn(self, other, builder); + }; + return ret; +} + +/*---------------------------------------------- + definition of self + other + ----------------------------------------------*/ +std::string add_docstr = R"pbdoc( + Returns self + other, element-wise. +)pbdoc"; +ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // ptr + offset + if (scalar_ty->is_pointer_ty()) + return builder->create_gep(self, {other}); + // float + float + else if (scalar_ty->is_floating_point_ty()) + return builder->create_fadd(self, other); + // int + int + else if (scalar_ty->is_integer_ty()) + return builder->create_add(self, other); + throw_not_implemented("add"); +} + +/*---------------------------------------------- + definition of self - other + ----------------------------------------------*/ +std::string sub_docstr = R"pbdoc( + Returns self - other, element-wise. +)pbdoc"; +ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // ptr + offset + if (scalar_ty->is_pointer_ty()) + return builder->create_gep(self, {other}); + // float + float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fsub(self, other); + // int + int + else if (scalar_ty->is_integer_ty()) + return builder->create_sub(self, other); + throw_not_implemented("sub"); +} + +/*---------------------------------------------- + definition of self * other + ----------------------------------------------*/ +std::string mul_docstr = R"pbdoc( + Returns self * other, element-wise. +)pbdoc"; +ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float * float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fmul(self, other); + // int * int + else if (scalar_ty->is_integer_ty()) + return builder->create_mul(self, other); + throw_not_implemented("mul"); +} + +/*---------------------------------------------- + definition of self > other + ----------------------------------------------*/ +std::string greater_than_docstr = R"pbdoc( + Returns self > other, element-wise. +)pbdoc"; +ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float > float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOGT(self, other); + // int > int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSGT(self, other); + throw_not_implemented("greater_than"); +} + +/*---------------------------------------------- + definition of self >= other + ----------------------------------------------*/ +std::string greater_equal_docstr = R"pbdoc( + Returns self >= other, element-wise. +)pbdoc"; +ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float >= float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOGE(self, other); + // int >= int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSGE(self, other); + throw_not_implemented("greater_equal"); +} + +/*---------------------------------------------- + definition of self < other + ----------------------------------------------*/ +std::string less_than_docstr = R"pbdoc( + Returns self < other, element-wise. +)pbdoc"; +ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float < float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOLT(self, other); + // int < int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSLT(self, other); + throw_not_implemented("less_than"); +} + +/*---------------------------------------------- + definition of self <= other + ----------------------------------------------*/ +std::string less_equal_docstr = R"pbdoc( + Returns self <= other, element-wise. +)pbdoc"; +ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float < float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOLE(self, other); + // int < int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpSLE(self, other); + throw_not_implemented("less_equal"); +} + +/*---------------------------------------------- + definition of self == other + ----------------------------------------------*/ +std::string equal_docstr = R"pbdoc( + Returns self == other, element-wise. +)pbdoc"; +ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float == float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOEQ(self, other); + // int == int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpEQ(self, other); + throw_not_implemented("equal"); +} + +/*---------------------------------------------- + definition of self / other + ----------------------------------------------*/ +std::string _div_docstr = R"pbdoc( + Returns self / other, element-wise. +)pbdoc"; +ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float / float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fdiv(self, other); + // int / int + else if (scalar_ty->is_integer_ty()) + return builder->create_sdiv(self, other); + throw_not_implemented("div"); +} + +/*---------------------------------------------- + definition of self % other + ----------------------------------------------*/ +std::string mod_docstr = R"pbdoc( + Returns self % other, element-wise. +)pbdoc"; +ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) { + ir::type *scalar_ty = self->get_type()->get_scalar_ty(); + // float % int + if (scalar_ty->is_floating_point_ty()) + return builder->create_frem(self, other); + // int % int + else if (scalar_ty->is_integer_ty()) + return builder->create_srem(self, other); + throw_not_implemented("mod"); +} + +/*---------------------------------------------- + definition of self & other + ----------------------------------------------*/ +std::string _and_docstr = R"pbdoc( + Returns self & other, element-wise. +)pbdoc"; +ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) { + return builder->create_and(self, other); +} + +/*---------------------------------------------- + definition of minimum(self, other) + ----------------------------------------------*/ +std::string minimum_docstr = R"pbdoc( + Returns element-wise minimum of self and other +)pbdoc"; +ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) { + return where(less_than(self, other, builder), self, other, builder); +} + +/*---------------------------------------------- + definition of self[slices] + ----------------------------------------------*/ + +enum slice_mode_t { + NEWAXIS, + ALL +}; + +std::string subscript_docstr = R"pbdoc( + returns self[slices]. + + :param slices: The slices to subscript with. + :type slices: List of `None` or `:` slices. +)pbdoc"; +ir::value *subscript(ir::value *self, std::vector slices, ir::builder *builder) { + std::vector modes; + for (py::object slice : slices) { + py::object none = py::none(); + py::object all = py::make_tuple(none, none, none); + if (slice.is(none)) + modes.push_back(NEWAXIS); + else if (all.attr("__eq__")(slice)) + modes.push_back(ALL); + else + throw std::runtime_error("slice must be None or (None, None, None)"); + } + + ir::type::block_shapes_t shape; + size_t curr = 0; + for (slice_mode_t mode : modes) { + if (mode == NEWAXIS) + shape.push_back(1); + else { + assert(mode == ALL); + shape.push_back(self->get_type()->get_block_shapes()[curr++]); + } + } + return builder->create_reshape(self, shape); +} \ No newline at end of file diff --git a/python/src/triton.cc b/python/src/triton.cc index 3d8a06e44..52468a0bf 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,5 +1,13 @@ -#include "triton/driver/stream.h" -#include "triton/runtime/function.h" +#include "triton/codegen/pass.h" +#include "triton/driver/kernel.h" +#include "triton/driver/module.h" +#include "triton/driver/stream.h" +#include "triton/ir/builder.h" +#include "triton/ir/dispatch.h" +#include "triton/ir/enums.h" +#include "triton/ir/function.h" +#include "triton/ir/module.h" +#include #include #include #include @@ -8,78 +16,9 @@ #include namespace py = pybind11; - -using namespace triton; -namespace rt = triton::runtime; +namespace ir = triton::ir; namespace drv = triton::driver; -/*****************************************************************************/ -/* Python bindings for triton::tools */ -/*****************************************************************************/ - -/*! - @brief Function for extracting kernels out of a given source-string - - This can be important to enable pre-processor macros (or tunable parameters) that should only - be defined within the scope of a single kernel function -*/ -std::string extract_kernels(const std::string &str, const std::vector &names) { - if (names.empty()) - return str; - // search for all regex matches of kernel_regex in str - std::smatch matches; - std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})"); - std::sregex_iterator it(str.begin(), str.end(), regex); - std::sregex_iterator end; - std::vector> kernels; - for (; it != end; ++it) { - int pos = it->position(); - int len = it->length(); - std::string name = it->str(1); - kernels.push_back(std::make_tuple(name, pos, len)); - } - // check that all the kernels provided actually exist - for (const std::string &name : names) { - auto pred = [&name](const std::tuple &t) { return std::get<0>(t) == name; }; - bool found = std::any_of(kernels.begin(), kernels.end(), pred); - if (!found) - throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str); - } - // simple parsing logic to extract the declaration and body of each specified kernel - std::string ret; - for (const auto &k : kernels) { - std::string name; - int pos, len; - std::tie(name, pos, len) = k; - if (std::find(names.begin(), names.end(), name) == names.end()) - continue; - std::string def = str.substr(pos, str.size() - pos); - // skip over declaration - // by finding matching ')' for first '(' - int count = 1; - pos = def.find('('); - while (!(def[pos++] == ')' && count == 0) && pos < def.size()) { - count += def[pos] == '('; - count -= def[pos] == ')'; - } - // skip over definition - // by finding matching '{' for first '}' - count = 1; - pos = def.find('{', pos); - while (!(def[pos++] == '}' && count == 0) && pos < def.size()) { - count += def[pos] == '{'; - count -= def[pos] == '}'; - } - ret += def.substr(0, pos); - ret += '\n'; - } - return ret; -} - -void init_triton_tools(py::module &&m) { - m.def("extract_kernels", &extract_kernels); -} - /*****************************************************************************/ /* Python bindings for triton::driver */ /*****************************************************************************/ @@ -88,14 +27,14 @@ void init_triton_driver(py::module &&m) { // base device py::class_(m, "device"); // cuda device - py::class_(m, "cu_device") + py::class_(m, "cu_device") .def(py::init([](int dev_id, bool take_ownership) { CUdevice handle; drv::dispatch::cuDeviceGet(&handle, dev_id); return new drv::cu_device(handle, take_ownership); })); // host device - py::class_(m, "host_device") + py::class_(m, "host_device") .def(py::init<>()); // base stream @@ -108,54 +47,236 @@ void init_triton_driver(py::module &&m) { // py doesn't support opaque pointer (e.g., CUstream) so // we assume it has been converted to uint64_t .def(py::init([](uint64_t handle, bool take_ownership) { - return std::unique_ptr(new driver::cu_stream((CUstream)handle, take_ownership)); - })); + return std::unique_ptr(new drv::cu_stream((CUstream)handle, take_ownership)); + })) + .def("enqueue", [](drv::cu_stream *self, drv::kernel *kernel, + size_t grid_0, size_t grid_1, size_t grid_2, + size_t block_0, size_t block_1, size_t block_2, + const std::string &args, + size_t shared_mem) { + return self->enqueue(kernel, {grid_0, grid_1, grid_2}, {block_0, block_1, block_2}, + (void *)args.data(), args.size(), shared_mem); + }); + + py::class_(m, "module"); + //py::class_(m, "cu_module"); + + py::class_(m, "kernel"); } /*****************************************************************************/ -/* Python bindings for triton::runtime */ +/* Python bindings for triton::codegen */ /*****************************************************************************/ -void init_triton_runtime(py::module &&m) { - // argument type - py::enum_(m, "arg_type") - .value("int1", rt::INT1_T) - .value("int8", rt::INT8_T) - .value("int16", rt::INT16_T) - .value("int32", rt::INT32_T) - .value("int64", rt::INT64_T) - .value("half", rt::HALF_T) - .value("float", rt::FLOAT_T) - .value("double", rt::DOUBLE_T) - .value("buffer", rt::BUFFER_T); - // compilation options - py::class_(m, "options", py::dynamic_attr()) - .def(py::init<>()) - .def_readwrite("defines", &rt::options_t::defines) - .def_readwrite("num_warps", &rt::options_t::num_warps) - .def("__getattr__", [](rt::options_t *opt, const std::string &name) { - return opt->D(name); - }); - // kernel - py::class_(m, "kernel") - .def("__call__", &rt::kernel::operator()) - .def_readonly("opt", &rt::kernel::opt) - .def("asm", &rt::kernel::get_asm); - // tune conf - py::class_(m, "config") - .def(py::init, int>(), - py::arg("defines") = std::map(), - py::arg("num_warps")); - // function - py::class_(m, "function") - .def(py::init &, const std::vector &>()) - .def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal) - .def("signature", &rt::function::get_signature); +void init_triton_codegen(py::module &&m) { + m.def( + "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) { + drv::module *mod; + drv::kernel *ker; + size_t shared_mem; + triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem); + return std::make_tuple(mod, ker, shared_mem); + }, + py::return_value_policy::take_ownership); +} + +/*****************************************************************************/ +/* User-facing language features */ +/*****************************************************************************/ + +void init_triton_frontend(py::module &&m) { + using ret = py::return_value_policy; + + // programming model + m.def("program_id", &ir::dispatch::program_id, ret::reference); + m.def("num_programs", &ir::dispatch::num_programs, ret::reference); + // binary + m.def("add", &ir::dispatch::add, ret::reference); + m.def("sub", &ir::dispatch::sub, ret::reference); + m.def("mul", &ir::dispatch::mul, ret::reference); + m.def("truediv", &ir::dispatch::truediv, ret::reference); + m.def("floordiv", &ir::dispatch::floordiv, ret::reference); + m.def("mod", &ir::dispatch::mod, ret::reference); + m.def("and_", &ir::dispatch::and_, ret::reference); + m.def("or_", &ir::dispatch::or_, ret::reference); + m.def("xor_", &ir::dispatch::xor_, ret::reference); + m.def("lshr", &ir::dispatch::lshr, ret::reference); + m.def("shl", &ir::dispatch::shl, ret::reference); + // unary + m.def("plus", &ir::dispatch::plus, ret::reference); + m.def("minus", &ir::dispatch::minus, ret::reference); + m.def("invert", &ir::dispatch::invert, ret::reference); + // comparison + m.def("greater_than", &ir::dispatch::greater_than, ret::reference); + m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference); + m.def("less_than", &ir::dispatch::less_than, ret::reference); + m.def("less_equal", &ir::dispatch::less_equal, ret::reference); + m.def("equal", &ir::dispatch::equal, ret::reference); + m.def("not_equal", &ir::dispatch::not_equal, ret::reference); + // block creation + m.def("arange", &ir::dispatch::arange, ret::reference); + m.def("zeros", &ir::dispatch::zeros, ret::reference); + // type manipuatation + m.def("reshape", &ir::dispatch::reshape, ret::reference); + typedef std::tuple (*broadcast_ty)(ir::value *, ir::value *, ir::builder *); + typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *); + m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference); + m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference); + m.def("cast", &ir::dispatch::cast, ret::reference); + // memory + m.def("load", &ir::dispatch::load, ret::reference); + m.def("store", &ir::dispatch::store, ret::reference); + m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference); + m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference); + // linear algebra + m.def("dot", &ir::dispatch::dot, ret::reference); + // indexing + m.def("where", &ir::dispatch::where, ret::reference); + // reduction + m.def("min", &ir::dispatch::min, ret::reference); + m.def("max", &ir::dispatch::max, ret::reference); + m.def("sum", &ir::dispatch::sum, ret::reference); + // math + m.def("exp", &ir::dispatch::exp, ret::reference); + m.def("log", &ir::dispatch::log, ret::reference); + m.def("sqrt", &ir::dispatch::sqrt, ret::reference); + // internal (debugging only) + m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); + m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); +} + +/*****************************************************************************/ +/* Python bindings for triton::ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::class_(m, "context") + .def(py::init<>()); + + auto value = py::class_(m, "value"); + value.def_property("name", &ir::value::get_name, &ir::value::set_name); + value.def_property_readonly("type", &ir::value::get_type); + + py::class_(m, "user"); + + py::class_(m, "constant"); + + py::class_(m, "undef") + .def("get", &ir::undef_value::get, ret::reference); + + py::class_(m, "constant_int") + .def_property_readonly("value", &ir::constant_int::get_value) + .def("__int__", [](ir::constant_int *self) { return self->get_value(); }); + + py::class_(m, "constant_float") + .def_property_readonly("value", &ir::constant_fp::get_value); + + py::class_(m, "type") + .def("is_ptr", &ir::type::is_pointer_ty) + .def("is_int", static_cast(&ir::type::is_integer_ty)) + .def("is_floating", &ir::type::is_floating_point_ty) + .def("is_block", &ir::type::is_block_ty) + .def("make_ptr", &ir::pointer_type::get, ret::reference) + .def("make_function", &ir::function_type::get, ret::reference) + .def("make_block", &ir::block_type::get, ret::reference) + .def("get_void", &ir::type::get_void_ty, ret::reference) + .def("get_fp16", &ir::type::get_half_ty, ret::reference) + .def("get_fp32", &ir::type::get_float_ty, ret::reference) + .def("get_fp64", &ir::type::get_double_ty, ret::reference) + .def("get_int1", &ir::type::get_int1_ty, ret::reference) + .def("get_int8", &ir::type::get_int8_ty, ret::reference) + .def("get_int16", &ir::type::get_int16_ty, ret::reference) + .def("get_int32", &ir::type::get_int32_ty, ret::reference) + .def("get_int64", &ir::type::get_int64_ty, ret::reference) + + .def("is_void", &ir::type::is_void_ty) + .def("is_fp16", &ir::type::is_half_ty) + .def("is_fp32", &ir::type::is_float_ty) + .def("is_fp64", &ir::type::is_double_ty) + .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) + .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) + .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) + .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) + .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) + + .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) + .def_property_readonly("scalar", &ir::type::get_scalar_ty) + .def_property_readonly("context", &ir::type::get_context, ret::reference); + + py::class_(m, "pointer_type") + .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); + + py::class_(m, "function_type"); + py::class_(m, "integer_type"); + py::class_(m, "block_type") + .def_property_readonly("shape", &ir::block_type::get_shapes) + .def_property_readonly("numel", &ir::type::get_tile_num_elements); + + py::class_(m, "scope") + .def(py::init<>()) + .def_property_readonly("values", &ir::scope::get_values) + .def("set_type", &ir::scope::set_type); + + py::class_(m, "module") + .def(py::init()) + .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) + .def("add_new_scope", &ir::module::add_new_scope, ret::reference) + .def("seal_block", &ir::module::seal_block) + .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) + .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) + .def("pop_scope", &ir::module::pop_scope) + .def_property_readonly("scope", &ir::module::get_scope, ret::reference) + .def_property_readonly("builder", &ir::module::get_builder, ret::reference); + + using eattr = ir::attribute_kind_t; + py::enum_(m, "attribute_kind") + .value("readonly", eattr::readonly) + .value("writeonly", eattr::writeonly) + .value("noalias", eattr::noalias) + .value("aligned", eattr::aligned) + .value("multiple_of", eattr::multiple_of) + .value("retune", eattr::retune) + .value("not_implemented", eattr::not_implemented); + + py::class_(m, "attribute") + .def(py::init()); + + py::class_(m, "function") + .def_property_readonly("args", &ir::function::args) + .def_property_readonly("attrs", &ir::function::attrs) + .def("add_attr", &ir::function::add_attr); + + py::class_(m, "argument"); + + py::class_(m, "basic_block") + .def("create", &ir::basic_block::create, ret::reference) + .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); + + py::class_(m, "builder", py::dynamic_attr()) + .def(py::init()) + // getters + .def_property_readonly("context", &ir::builder::get_context, ret::reference) + // control flow + .def("br", &ir::builder::create_br, ret::reference) + .def("cond_br", &ir::builder::create_cond_br, ret::reference) + .def("ret_void", &ir::builder::create_ret_void, ret::reference) + .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) + .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) + // constants + .def("get_int1", &ir::builder::get_int1, ret::reference) + .def("get_int32", &ir::builder::get_int32, ret::reference) + .def("get_float16", &ir::builder::get_float16, ret::reference) + .def("get_float32", &ir::builder::get_float32, ret::reference) + .def("get_range", &ir::builder::get_range, ret::reference); } void init_triton(py::module &m) { py::module subm = m.def_submodule("triton"); + init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_driver(std::move(subm.def_submodule("driver"))); - init_triton_runtime(std::move(subm.def_submodule("runtime"))); - init_triton_tools(std::move(subm.def_submodule("tools"))); + init_triton_ir(std::move(subm.def_submodule("ir"))); + init_triton_frontend(std::move(subm.def_submodule("frontend"))); } diff --git a/python/test/test_code_gen.py b/python/test/test_code_gen.py new file mode 100644 index 000000000..140105c27 --- /dev/null +++ b/python/test/test_code_gen.py @@ -0,0 +1,209 @@ +import torch +import triton +import copy +import pytest +import ast + +torch.manual_seed(0) + +# convert from string to torch.dtype +# Necessary because doesn't print torch.dtype properly +cvt = { + 'bool': torch.bool, + 'int8': torch.int8, + 'int16': torch.int16, + 'int32': torch.int32, + 'int64': torch.int64, + 'float16': torch.float16, + 'float32': torch.float32, + 'float64': torch.float64, +} + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +float_dtypes = ['float16', 'float32', 'float64'] +dtypes = int_dtypes + float_dtypes + + +def patch_kernel(template, to_replace): + kernel = copy.deepcopy(template) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +# generic test functions +def _test_unary(dtype_x, expr, device='cuda'): + SIZE = 128 + # define the kernel / launch-grid + @triton.jit + def kernel(Z, X, **meta): + off = triton.arange(0, meta['SIZE']) + x = triton.load(X + off) + z = GENERATE_TEST_HERE + triton.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) + # reference result + z_ref = eval(expr) + # triton result + z_tri = torch.empty_like(z_ref) + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + # compare + triton.testing.assert_allclose(z_ref, z_tri) + + +def _test_binary(dtype_x, dtype_y, expr, device='cuda'): + SIZE = 128 + # define the kernel / launch-grid + @triton.jit + def kernel(Z, X, Y, **meta): + off = triton.arange(0, meta['SIZE']) + x = triton.load(X + off) + y = triton.load(Y + off) + z = GENERATE_TEST_HERE + triton.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) + y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device) + # reference result + z_ref = eval(expr) + # triton result + z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device) + kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4) + # compare + triton.testing.assert_allclose(z_ref, z_tri) + + +# --------------- +# test binary ops +# --------------- +@pytest.mark.parametrize("dtype_x, dtype_y, expr", [ + (dtype_x, dtype_y, f' x {op} y') \ + for op in ['+', '-', '*', '/', '%'] \ + for dtype_x in dtypes \ + for dtype_y in dtypes +]) +def test_bin_op(dtype_x, dtype_y, expr, device='cuda'): + _test_binary(dtype_x, dtype_y, expr, device=device) + + +# --------------- +# test bitwise ops +# --------------- +@pytest.mark.parametrize("dtype_x, dtype_y, expr", [ + (dtype_x, dtype_y, f' x {op} y') \ + for op in ['&', '|', '^'] \ + for dtype_x in dtypes \ + for dtype_y in dtypes +]) +def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'): + if 'float' in dtype_x + dtype_y: + with pytest.raises(RuntimeError): + _test_binary(dtype_x, dtype_y, expr, device=device) + else: + _test_binary(dtype_x, dtype_y, expr, device=device) + + +# --------------- +# test compare ops +# --------------- +@pytest.mark.parametrize("dtype_x, dtype_y, expr", [ + (dtype_x, dtype_y, f' x {op} y') \ + for op in ['==', '!=', '>', '<', '>=', '<='] \ + for dtype_x in dtypes \ + for dtype_y in dtypes +]) +def test_compare_op(dtype_x, dtype_y, expr, device='cuda'): + _test_binary(dtype_x, dtype_y, expr, device=device) + + +# --------------- +# test unary ops +# --------------- +@pytest.mark.parametrize("dtype_x, expr", [ + (dtype_x, f' -x') for dtype_x in float_dtypes +] + [\ + (dtype_x, f' ~x') for dtype_x in int_dtypes + ]) +def test_unary_op(dtype_x, expr, device='cuda'): + _test_unary(dtype_x, expr, device=device) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +@pytest.mark.parametrize("expr", [f'x[{s}]' for s in + ['None, :', ':, None',\ + 'None, :, :', ':, :, None']\ +]) +def test_index1d(expr, device='cuda'): + dtype = torch.int32 + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, **meta): + SIZE = meta['SIZE'] + m = triton.arange(0, SIZE) + n = triton.arange(0, SIZE) + x = triton.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + triton.store(Z_PTR_EXPR, z) + + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + kernel = patch_kernel(kernel, to_replace) + + # torch result + x = triton.testing.random(shape_x, dtype=dtype, device=device) + y = torch.zeros(shape_z, dtype=dtype, device=device) + z_ref = eval(expr) + y + # triton result + z_tri = torch.empty_like(z_ref) + kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0]) + # compare + triton.testing.assert_allclose(z_ref, z_tri) + + +# --------------- +# test load +# --------------- + +# --------------- +# test store +# --------------- + +# --------------- +# test if +# --------------- + +# --------------- +# test for +# --------------- + +# --------------- +# test while +# --------------- diff --git a/python/test/test_conv.py b/python/test/test_conv.py deleted file mode 100644 index 46cabd3d4..000000000 --- a/python/test/test_conv.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch -import triton - - -def test_op(): - torch.manual_seed(0) - DTYPE = torch.float16 - N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3 - pad, stride, = (1, 1), (1, 1) - dilation = (1, 1) - a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5 - b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5 - th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation) - tt_c = triton.ops.conv(a, b, pad, stride) - rtol, atol = {torch.float32: (1e-4, 1e-5), - torch.float16: (1e-2, 1e-3)}[DTYPE] - assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol) \ No newline at end of file diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index a621d347d..163269dc3 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -3,66 +3,74 @@ import itertools import triton import torch + @pytest.mark.parametrize( - "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE", - itertools.chain(*[ - [ - # 1 warp - (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), - (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE), - # # 2 warp - (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE), - (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE), - # # 4 warp - (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE), - (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE), - (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE), - # 8 warp - # (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE), - # (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE), - # (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE), - # split-k - (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE), - # variable input - (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE), - ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] - ]), + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE", + itertools.chain( + *[ + [ + # 1 warp + (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), + (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE), + # 2 warp + (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE), + (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE), + # 4 warp + (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE), + (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE), + # 8 warp + (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE), + # # split-k + (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE), + # # variable input + (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE), + ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] + ] + ), ) -def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE): - DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE): torch.manual_seed(0) - defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)} - triton.ops._matmul._kernels = dict() - triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)] - if M is None: - M = TM - if N is None: - N = TN - if K is None: - K = TK * SPLITK + # nuke kernel decorators -- will set meta-parameters manually + META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8} + configs = [triton.Config(meta=META, num_warps=NWARP)] + kernel = triton.ops._matmul.kernel + decorators = kernel.kernel_decorators + kernel.kernel_decorators = [] + triton.autotune(configs, [])(kernel) + kernel.kernel_decorators += decorators[1:] + # get matrix shape + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K * SPLIT_K if K is None else K + # allocate/transpose inputs + DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) a = a.t() if AT else a b = b.t() if BT else b + # run test th_c = torch.matmul(a, b) tt_c = triton.ops.matmul(a, b) assert triton.testing.allclose(th_c, tt_c) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 7841be5c0..ebea7548b 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -2,9 +2,10 @@ # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from . import testing -from .kernel import * -from . import ops +from .code_gen import jit, autotune, heuristics, Config, Autotuner +from .core import * +from . import testing +from . import ops # version __version__ = '1.0.0' \ No newline at end of file diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py new file mode 100644 index 000000000..b1a2eff11 --- /dev/null +++ b/python/triton/code_gen.py @@ -0,0 +1,648 @@ +import inspect +import struct +import enum +import types +import torch +import ast +import builtins +import triton._C.libtriton.triton as _triton +import triton +import sys +import textwrap +from abc import ABC, abstractmethod + + +class CodeGenerator(ast.NodeVisitor): + def get_value(self, name): + # search node.id in local scope + ret = None + if name in self.lscope: + ret = self.lscope[name] + # search node.id in global scope + elif name in self.gscope: + ret = self.gscope[name] + # search node.id in builtins + elif name in self.builtins: + ret = self.builtins[name] + else: + raise ValueError(f'{name} is not defined') + if isinstance(ret, triton.block): + handle = self.module.get_value(name) + return triton.block(handle) + return ret + + def set_value(self, name, value): + if isinstance(value, _triton.ir.value): + value = triton.block(value) + if isinstance(value, triton.block): + self.module.set_value(name, value.handle) + self.module.scope.set_type(name, value.handle.type) + self.lscope[name] = value + + def is_triton_object(self, value): + return isinstance(value, triton.block) + + def visit_compound_statement(self, stmts, add_scope=False): + if add_scope: + self.module.add_new_scope() + for stmt in stmts: + self.last_ret = self.visit(stmt) + if isinstance(stmt, ast.Return): + break + if add_scope: + self.module.pop_scope() + return self.last_ret + + def __init__(self, context, prototype, gscope, attributes, constants, kwargs): + self.builder = _triton.ir.builder(context) + self.module = _triton.ir.module('', self.builder) + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.kwargs = kwargs + self.last_node = None + self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr} + + def visit_Module(self, node): + self.module.add_new_scope() + ast.NodeVisitor.generic_visit(self, node) + self.module.pop_scope() + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + return self.visit(node.value) + + def visit_FunctionDef(self, node, inline=False, arg_values=None): + arg_names, kwarg_names = self.visit(node.args) + # store keyword arguments in local scope + self.lscope[kwarg_names] = self.kwargs + # initialize function + if inline: + pass + else: + fn = self.module.get_or_insert_function(node.name, self.prototype) + arg_values = [] + for i, arg_name in enumerate(arg_names): + if i in self.constants: + arg_values.append(self.constants[i]) + else: + if i in self.attributes: + is_ptr = fn.args[i].type.is_ptr() + attr = 'aligned' if is_ptr else 'multiple_of' + attr = getattr(_triton.ir.attribute_kind, attr) + attr = _triton.ir.attribute(attr, self.attributes[i]) + fn.add_attr(i + 1, attr) + fn.args[i].name = arg_name + arg_values.append(fn.args[i]) + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + if inline: + return self.visit_compound_statement(node.body, add_scope=True) + else: + entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) + self.module.seal_block(entry) + self.builder.set_insert_block(entry) + # visit function body + self.visit_compound_statement(node.body, add_scope=True) + # finalize function + self.builder.ret_void() + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + assert len(names) == 1 + name = names[0] + value = self.visit(node.value) + self.set_value(names[0], value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.get_value(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.get_value(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + fn = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + }[type(node.op)] + kws = dict() + + if self.is_triton_object(lhs): + kws['builder'] = self.builder + ret = getattr(lhs, fn)(rhs, **kws) + if ret is NotImplemented: + if self.is_triton_object(rhs): + kws['builder'] = self.builder + fn = fn[:2] + 'r' + fn[2:] + ret = getattr(rhs, fn)(lhs, **kws) + return ret + + def visit_If(self, node): + cond = self.visit(node.test) + if self.is_triton_object(cond): + current_bb = self.builder.get_insert_block() + then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) + else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None + endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) + self.module.seal_block(then_bb) + if else_bb: + self.module.seal_block(else_bb) + self.builder.cond_br(cond.handle, then_bb, else_bb) + else: + self.builder.cond_br(cond.handle, then_bb, endif_bb) + self.builder.set_insert_block(then_bb) + self.visit_compound_statement(node.body, add_scope=True) + # TODO: last statement is a terminator? + self.builder.br(endif_bb) + if else_bb: + self.builder.set_insert_block(else_bb) + self.visit_compound_statement(node.orelse, add_scope=True) + #TODO: last statement is a terminator? + self.builder.br(endif_bb) + self.module.seal_block(endif_bb) + self.builder.set_insert_block(endif_bb) + else: + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + assert len(node.comparators) == 1 + assert len(node.ops) == 1 + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + fn = { + ast.Eq: '__eq__', + ast.NotEq: '__ne__', + ast.Lt: '__lt__', + ast.LtE: '__le__', + ast.Gt: '__gt__', + ast.GtE: '__ge__', + ast.Is: '__eq__', + ast.IsNot: '__ne__', + }[type(node.ops[0])] + if self.is_triton_object(lhs) or self.is_triton_object(rhs): + return getattr(lhs, fn)(rhs, builder=self.builder) + return getattr(lhs, fn)(rhs) + + def visit_UnaryOp(self, node): + op = self.visit(node.operand) + fn = { + ast.USub: '__neg__', + ast.UAdd: '__pos__', + ast.Invert: '__invert__', + }[type(node.op)] + if self.is_triton_object(op): + return getattr(op, fn)(builder=self.builder) + return getattr(op, fn)() + + def visit_While(self, node): + current_bb = self.builder.get_insert_block() + loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + + def continue_fn(): + cond = self.visit(node.test) + return self.builder.cond_br(cond.handle, loop_bb, next_bb) + + continue_fn() + self.builder.set_insert_block(loop_bb) + self.visit_compound_statement(node.body, add_scope=True) + continue_fn() + stop_bb = self.builder.get_insert_block() + self.module.seal_block(stop_bb) + self.module.seal_block(loop_bb) + self.module.seal_block(next_bb) + self.builder.set_insert_block(next_bb) + + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Str(self, node): + return ast.literal_eval(node) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if self.is_triton_object(lhs): + return lhs.__getitem__(slices, builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + iterator = self.visit(node.iter.func) + assert iterator == self.builtins['range'] + # create nodes + st_target = ast.Name(id=node.target.id, ctx=ast.Store()) + ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) + init_node = ast.Assign(targets=[st_target], value=node.iter.args[0]) + pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]]) + neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]]) + pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)]) + build_cond = lambda: triton.where(self.visit(pos_step_node),\ + self.visit(pos_cond_node),\ + self.visit(neg_cond_node),\ + builder=self.builder) + #cond_node = neg_cond_node + step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2]) + # code generation + current_bb = self.builder.get_insert_block() + loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + + def continue_fn(): + self.visit(step_node) + cond = build_cond() + return self.builder.cond_br(cond.handle, loop_bb, next_bb) + + self.visit(init_node) + cond = build_cond() + self.builder.cond_br(cond.handle, loop_bb, next_bb) + self.builder.set_insert_block(loop_bb) + self.visit_compound_statement(node.body, add_scope=True) + # TODO: handle case where body breaks control flow + continue_fn() + stop_bb = self.builder.get_insert_block() + self.module.seal_block(stop_bb) + self.module.seal_block(loop_bb) + self.module.seal_block(next_bb) + self.builder.set_insert_block(next_bb) + + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_NameConstant(self, node): + return node.value + + def visit_keyword(self, node): + return {node.arg: self.visit(node.value)} + + def visit_Call(self, node): + fn = self.visit(node.func) + kws = dict() + for keyword in node.keywords: + kws.update(self.visit(keyword)) + args = [self.visit(arg) for arg in node.args] + if isinstance(fn, JITFunction): + return fn(*args, generator=self, **kws) + if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ + sys.modules[fn.__module__] is triton.core: + return fn(*args, builder=self.builder, **kws) + return fn(*args, **kws) + + def visit_Num(self, node): + return node.n + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit(self, node): + if node is not None: + self.last_node = node + return super().visit(node) + + def generic_visit(self, node): + typename = type(node).__name__ + raise NotImplementedError("Unsupported node: {}".format(typename)) + + +class Binary: + def __init__(self, module, kernel, num_warps, shared_mem): + self.module = module + self.kernel = kernel + self.shared_mem = shared_mem + self.num_warps = num_warps + + def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): + stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem) + + +class CompilationError(Exception): + def __init__(self, src, node, err): + self.message = '\n'.join(src.split('\n')[:node.lineno]) + self.message += '\n' + ' ' * node.col_offset + '^' + self.message += '\n Error: ' + str(err) + super().__init__(self.message) + + +class Kernel: + + type_names = { + int: 'I', + float: 'f', + bool: 'B', + torch.float16: 'f16', + torch.float32: 'f32', + torch.float64: 'f64', + torch.bool: 'i1', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + } + + @staticmethod + def _to_triton_ir(context, obj): + type_map = { + 'I': _triton.ir.type.get_int32, + 'f': _triton.ir.type.get_fp32, + 'B': _triton.ir.type.get_int1, + 'f16': _triton.ir.type.get_fp16, + 'f32': _triton.ir.type.get_fp32, + 'f64': _triton.ir.type.get_fp64, + 'i1': _triton.ir.type.get_int1, + 'i8': _triton.ir.type.get_int8, + 'i16': _triton.ir.type.get_int16, + 'i32': _triton.ir.type.get_int32, + 'i64': _triton.ir.type.get_int64, + } + # convert torch.Tensor to Triton IR pointers + if isinstance(obj, torch.Tensor): + name = Kernel.type_names[obj.dtype] + elt_ty = type_map[name](context) + return _triton.ir.type.make_ptr(elt_ty, 1) + # default path returns triton.ir.type directly + name = Kernel.type_names[obj.__class__] + return type_map[name](context) + + @staticmethod + def _types_key(*wargs, tensor_idxs): + # type inference + types_key = [None] * len(wargs) + for i, arg in enumerate(wargs): + prefix = 'P' if i in tensor_idxs else '' + suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__] + types_key[i] = prefix + suffix + return tuple(types_key) + + @staticmethod + def pow2_divisor(N): + if N % 16 == 0: return 16 + if N % 8 == 0: return 8 + if N % 4 == 0: return 4 + if N % 2 == 0: return 2 + return 1 + + def __init__(self, fn): + self.fn = fn + + def _compile(self, *wargs, device, attributes, constants, num_warps, **meta): + # explicitly set device + torch.cuda.set_device(device.index) + # create IR module + context = _triton.ir.context() + # get just-in-time proto-type of kernel + arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] + ret_type = _triton.ir.type.get_void(context) + prototype = _triton.ir.type.make_function(ret_type, arg_types) + # generate Triton-IR + # export symbols visible from self.fn into code-generator object + gscope = sys.modules[self.fn.module].__dict__ + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta) + try: + generator.visit(self.fn.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(self.fn.src, node, e) + tt_device = _triton.driver.cu_device(device.index, False) + # Compile to machine code + mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps) + return Binary(mod, ker, num_warps, shared_mem) + + def __call__(self, *wargs, grid, num_warps=4, **meta): + # device inference + tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)] + if len(tensor_idxs) == 0: + raise ValueError("No Tensor argument found.") + device = wargs[tensor_idxs[0]].device + # attributes + args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)} + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + # determine if we need to re-compile + types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) + attr_key = frozenset(attributes.items()) + meta_key = frozenset(meta.items()) + const_key = frozenset(constants.items()) + key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key) + cache = self.fn.cache + if key not in cache: + # compile and cache configuration if necessary + cache[key] = self._compile( + *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta + ) + # pack arguments + fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)]) + params = struct.pack(fmt, *args) + # enqueue cached function into stream + binary = cache[key] + cu_stream = torch.cuda.current_stream(device.index).cuda_stream + stream = _triton.driver.cu_stream(cu_stream, False) + grid = grid(meta) if hasattr(grid, '__call__') else grid + binary(stream, params, *grid) + + +class Launcher: + def __init__(self, kernel, grid): + self.kernel = kernel + self.grid = grid + + def __call__(self, *wargs, **kwargs): + self.kernel(*wargs, **kwargs, grid=self.grid) + + +class Autotuner: + def __init__(self, kernel, arg_names, configs, key): + if not configs: + self.configs = [Config(dict(), num_warps=4)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = dict() + self.kernel = kernel + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.meta.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.meta) + kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current) + return triton.testing.do_bench(kernel_call) + + def __call__(self, *args, **meta): + if len(self.configs) > 1: + key = tuple([args[i] for i in self.key_idx]) + if key not in self.cache: + timings = {config: self._bench(*args, config=config, **meta) \ + for config in self.configs} + self.cache[key] = builtins.min(timings, key=timings.get) + config = self.cache[key] + else: + config = self.configs[0] + self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta) + + +class JITFunction: + def __init__(self, fn): + self.module = fn.__module__ + self.arg_names = inspect.getfullargspec(fn).args + self.cache = dict() + self.kernel_decorators = [] + self.src = textwrap.dedent(inspect.getsource(fn)) + self.kernel = None + + # we do not parse in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Some unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, generator: CodeGenerator, **meta): + try: + return generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(self.src, node, e) + + def _init_kernel(self): + if self.kernel is None: + self.kernel = Kernel(self) + for decorator in reversed(self.kernel_decorators): + self.kernel = decorator(self.kernel) + return self.kernel + + def __getitem__(self, grid): + return Launcher(self._init_kernel(), grid) + + +class Config: + def __init__(self, meta, num_warps=4): + self.meta = meta + self.num_warps = num_warps + + +def autotune(configs, key): + def decorator(fn): + def wrapper(kernel): + return Autotuner(kernel, fn.arg_names, configs, key) + + fn.kernel_decorators.append(wrapper) + return fn + + return decorator + + +def heuristics(values): + def decorator(fn): + def wrapper(kernel): + def fun(*args, **meta): + for v, heur in values.items(): + assert v not in meta + meta[v] = heur(*args, **meta) + return kernel(*args, **meta) + + return fun + + fn.kernel_decorators.append(wrapper) + return fn + + return decorator + + +def jit(fn): + return JITFunction(fn) diff --git a/python/triton/core.py b/python/triton/core.py new file mode 100644 index 000000000..8325990c1 --- /dev/null +++ b/python/triton/core.py @@ -0,0 +1,499 @@ +from triton._C.libtriton.triton import ir +from triton._C.libtriton.triton import frontend +import triton +from functools import wraps + + +def _patch(fn): + + # convert block/dtype to ir values + def _to_ir(x, builder): + if isinstance(x, bool): + return builder.get_int1(x) + elif isinstance(x, int): + return builder.get_int32(x) + elif isinstance(x, float): + return builder.get_float32(x) + if isinstance(x, block): + return x.handle + if isinstance(x, dtype): + return x.handle(builder) + return x + + def _from_ir(x): + if isinstance(x, ir.value): + if x.type.is_void(): + return None + return block(x) + return x + + def wrapper(*args, **kwargs): + builder = args[-1] + assert isinstance(builder, ir.builder) + args = [_to_ir(x, builder) for x in args] + kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} + ret = fn(*args, **kwargs) + if isinstance(ret, tuple): + return map(_from_ir, ret) + return _from_ir(ret) + + return wrapper + + +for name in dir(frontend): + fn = getattr(frontend, name) + if callable(fn): + setattr(frontend, name, _patch(fn)) + + +def builtin(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if 'builder' not in kwargs or \ + kwargs['builder'] is None: + raise ValueError("Builder argument must be provided outside of JIT functions") + return fn(*args, **kwargs) + + if wrapper.__doc__: + wrapper.__doc__ += """\ +:param builder: IR builder to generate code into, optional from within @triton.jit functions + :type builder: triton.ir.builder +""" + return wrapper + + +class dtype: + def __init__(self, init): + self.init = init + + def handle(self, builder): + ctx = builder.context + return self.init(ctx) + + +class pointer_dtype: + def __init__(self, element_ty): + self.element_ty = element_ty + + def handle(self, builder): + return ir.type.make_ptr(self.element_ty, 1) + + +int1 = dtype(ir.type.get_int1) +int8 = dtype(ir.type.get_int8) +int16 = dtype(ir.type.get_int16) +int32 = dtype(ir.type.get_int32) +int64 = dtype(ir.type.get_int64) +float16 = dtype(ir.type.get_fp16) +float32 = dtype(ir.type.get_fp32) +float64 = dtype(ir.type.get_fp64) + + +class block: + @staticmethod + def _init_dtype(ir_type): + # primitive type + if ir_type.is_int1(): return int1 + if ir_type.is_int8(): return int8 + if ir_type.is_int16(): return int16 + if ir_type.is_int32(): return int32 + if ir_type.is_int64(): return int64 + if ir_type.is_fp16(): return float16 + if ir_type.is_fp32(): return float32 + if ir_type.is_fp64(): return float64 + # pointer type + if ir_type.is_ptr(): + element_ty = block._init_dtype(ir_type.element) + return pointer_dtype(element_ty) + raise ValueError(f"Unsupported type {ir_type}") + + def __init__(self, handle): + # IR handle + self.handle = handle + # Block shape + self.shape = (1, ) + if self.handle.type.is_block(): + self.shape = self.handle.type.shape + # Data-type wrapper + self.dtype = block._init_dtype(self.handle.type.scalar) + + @builtin + def __add__(self, other, builder=None): + return frontend.add(self, other, builder) + + def __radd__(self, other, builder=None): + return self.__add__(other, builder=builder) + + @builtin + def __sub__(self, other, builder=None): + return frontend.sub(self, other, builder) + + @builtin + def __mul__(self, other, builder=None): + return frontend.mul(self, other, builder) + + def __rmul__(self, other, builder=None): + return self.__mul__(other, builder=builder) + + @builtin + def __truediv__(self, other, builder=None): + return frontend.truediv(self, other, builder) + + def __rtruediv__(self, other, builder=None): + return frontend.truediv(other, self, builder) + + @builtin + def __floordiv__(self, other, builder=None): + return frontend.floordiv(self, other, builder) + + @builtin + def __mod__(self, other, builder=None): + return frontend.mod(self, other, builder) + + # unary operators + @builtin + def __neg__(self, builder=None): + return frontend.minus(self, builder) + + @builtin + def __invert__(self, builder=None): + return frontend.invert(self, builder) + + # bitwise operators + + @builtin + def __and__(self, other, builder=None): + return frontend.and_(self, other, builder) + + @builtin + def __or__(self, other, builder=None): + return frontend.or_(self, other, builder) + + @builtin + def __xor__(self, other, builder=None): + return frontend.xor_(self, other, builder) + + @builtin + def __lshift__(self, other, builder=None): + return frontend.shl(self, other, builder) + + @builtin + def __rshift__(self, other, builder=None): + return frontend.lshr(self, other, builder) + + # comparison operators + + @builtin + def __gt__(self, other, builder=None): + return frontend.greater_than(self, other, builder) + + @builtin + def __ge__(self, other, builder=None): + return frontend.greater_equal(self, other, builder) + + @builtin + def __lt__(self, other, builder=None): + return frontend.less_than(self, other, builder) + + @builtin + def __le__(self, other, builder=None): + return frontend.less_equal(self, other, builder) + + @builtin + def __eq__(self, other, builder=None): + return frontend.equal(self, other, builder) + + @builtin + def __ne__(self, other, builder=None): + return frontend.not_equal(self, other, builder) + + @builtin + def __getitem__(self, slices, builder=None): + if isinstance(slices, slice): + slices = [slices] + src_shape = self.shape + dst_shape = [] + curr = 0 + for sl in slices: + if sl == None: + dst_shape.append(1) + elif sl == slice(None, None, None): + dst_shape.append(src_shape[curr]) + curr += 1 + ret = frontend.reshape(self, dst_shape, builder) + return ret + + @builtin + def to(self, dtype, builder=None): + return frontend.cast(self, dtype.handle(builder), builder) + + +# ----------------------- +# SPMD Programming Model +# ----------------------- + + +@builtin +def program_id(axis, builder=None): + """ + Returns the id of the current program instance along the given `axis`. + Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s. + + :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. + :type axis: int + """ + return frontend.program_id(axis, builder) + + +@builtin +def num_programs(axis, builder=None): + """ + Returns the number of program instances launched along the given `axis`. + + :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. + :type axis: int + """ + return frontend.num_programs(axis, builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, builder=None): + """ + Returns contiguous values within the open interval [start, end). + + :param start: Start of the interval. + :type start: int + :param stop: End of the interval. + :type stop: int + """ + return frontend.arange(start, end, builder) + + +@builtin +def zeros(shape, dtype, builder=None): + """ + Returns a block filled with the scalar value 0 and the given shape. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., triton.float16 + :type dtype: triton.ir.dtype + """ + return frontend.zeros(shape, dtype, builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, builder=None): + """ + Tries to broadcast two blocks to a common compatible shape. + + :param input: The first input block. + :type input: triton.ir.value + :param other: The second input block. + :type other: triton.ir.value + """ + return frontend.broadcast(input, other, builder) + + +@builtin +def broadcast_to(input, shape, builder=None): + """ + Tries to broadcast a block to a new shape. + + :param input: The input block. + :type input: triton.value + :param shape: The new shape. + :type shape: tuple of int + """ + return frontend.broadcast_to(input, shape, builder) + + +@builtin +def reshape(input, shape, builder=None): + """ + Reshapes a block to a new shape. + """ + return frontend.reshape(input, shape, builder) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, builder=None): + """ + Returns the matrix product of two blocks. + The two blocks must be two dimensionals and have compatible inner dimensions. + + :param input: The first block to be multiplied. + :type input: 2D block of scalar-type in {`float16`, `float32`} + :param other: The second block to be multiplied. + :type other: 2D block of scalar-type in {`float16`, `float32`} + """ + return frontend.dot(input, other, builder) + + +# ----------------------- +# Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, builder=None): + """ + Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`. + + :param pointer: Pointer to the data to be loaded. + :type pointer: Block of triton.pointer + :param mask: if mask[idx] is false, do not load the data at `pointer[idx]`. + :type mask: Block of triton.bool, optional + :param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]` + :type other: Block of triton.value, optional + """ + return frontend.load(pointer, mask, other, builder) + + +@builtin +def store(pointer, value, mask=None, builder=None): + """ + Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`. + + :param pointer: The memory locations where the elements of `value` are stored. + :type pointer: Block of triton.pointer + :param value: The block of elements to be stored. + :type value: Block of triton.value + :param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`. + :type mask: Block of triton.bool, optional + """ + return frontend.store(pointer, value, mask, builder) + + +@builtin +def atomic_cas(ptr, cmp, val, builder=None): + return frontend.atomic_cas(ptr, cmp, val, builder) + + +@builtin +def atomic_xchg(ptr, val, builder=None): + return frontend.atomic_xchg(ptr, val, builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, builder=None): + """ + Returns a block of elements from either `x` or `y`, depending on `condition`. + Note that `x` and `y` are always evaluated regardless of the value of `condition`. + If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead. + The shape of `x` and `y` are both broadcast to the shape of `condition`. + `x` and `y` must have the data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + return frontend.where(condition, x, y, builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def exp(x, builder=None): + return frontend.exp(x, builder) + + +@builtin +def log(x, builder=None): + return frontend.log(x, builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +@builtin +def max(input, axis, builder=None): + return frontend.max(input, axis, builder) + + +@builtin +def min(input, axis, builder=None): + return frontend.min(input, axis, builder) + + +@builtin +def sum(input, axis, builder=None): + return frontend.sum(input, axis, builder) + + +# ----------------------- +# Internal for debugging +# ----------------------- + + +@builtin +def debug_barrier(builder=None): + return frontend.debug_barrier(builder) + + +@builtin +def multiple_of(x, value, builder=None): + return frontend.multiple_of(x, value, builder) + + +# ----------------------- +# Standard library +# ----------------------- + + +@triton.jit +def minimum(x, y): + return triton.where(x < y, x, y) + + +@triton.jit +def maximum(x, y): + return triton.where(x > y, x, y) + + +@triton.jit +def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + +@triton.jit +def ravel(x): + return triton.reshape(x, [x.type.numel]) + + +@triton.jit +def softmax(x): + z = x - triton.max(x, 0) + num = triton.exp(z) + den = triton.sum(num, 0) + return num / den + + +def cdiv(x, y): + return (x + y - 1) // y diff --git a/python/triton/kernel.py b/python/triton/kernel.py deleted file mode 100644 index 04237b902..000000000 --- a/python/triton/kernel.py +++ /dev/null @@ -1,119 +0,0 @@ -import os -import struct -from typing import Optional, Dict, List, Callable -import torch -import triton._C.libtriton.triton as _triton - -codes = { - _triton.runtime.arg_type.int1: 'B', - _triton.runtime.arg_type.int8: 'B', - _triton.runtime.arg_type.int32: 'I', - _triton.runtime.arg_type.int64: 'Q', - _triton.runtime.arg_type.half: 'H', - _triton.runtime.arg_type.float: 'f', - _triton.runtime.arg_type.double: 'd', - _triton.runtime.arg_type.buffer: 'P', -} - - -def th_to_triton(obj): - """ Convert a `torch.dtype` to a Triton-C type string. """ - tys = { - torch.int8: 'char', - torch.int16: 'short', - torch.int32: 'int', - torch.int64: 'long', - torch.float16: 'half', - torch.float32: 'float', - torch.float64: 'double', - } - if isinstance(obj, torch.dtype): - return tys[obj] - return str(obj) - - -def cdiv(a: int, b: int) -> int: - """ Ceil division (a + b - 1) // b""" - return (a + b - 1) // b - - -def read(path: str, kernel_names: Optional[List] = None) -> str: - """ Extracts the source code for `kernel_names` from the given `path` file.""" - if kernel_names is None: - kernel_names = [] - with open(path, 'r') as f: - source = f.read() - source = _triton.tools.extract_kernels(source, kernel_names) - return source - - -config = _triton.runtime.config - - -class kernel: - """ - A class used to represent a Triton kernel. - """ - def __init__( - self, - src: str, - device: torch.device, - defines: Optional[Dict] = None, - num_warps: int = 4, - autotune_configs: Optional[List] = None, - autotune_key: Optional[List] = None - ): - """ - :param src: The source code of the kernel. - :param device: The device to compile the kernel for. - :param defines: A dictionary of preprocessor #define for the compiler. - :param num_warps: Optimization flag for the compiler's internal auto-parallelization engine. - :param autotune_configs: A list of triton.config objects for the autotuner to try. - :param autotune_key: A list of kernel argument names whose change in value should trigger the autotuner to re-run. - """ - - if defines is None: - defines = {} - if autotune_configs is None: - autotune_configs = [] - if autotune_key is None: - autotune_key = [] - # check if src is empty - if src == '': - raise ValueError('Kernel source code is empty') - self.src = src - # device - assert device.type in ['cuda', 'cpu'] - if device.type == 'cuda': - self.device_id = torch.cuda.current_device() if device.index is None else device.index - self.device = _triton.driver.cu_device(self.device_id, False) - cu_stream = torch.cuda.current_stream(self.device_id).cuda_stream - self.stream = _triton.driver.cu_stream(cu_stream, False) - if device.type == 'cpu': - self.device_id = -1 - self.device = _triton.driver.host_device() - self.device = _triton.driver.host_stream() - torch.cuda.set_device(self.device_id) - # function - self.opt = _triton.runtime.options() - self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()} - self.opt.num_warps = num_warps - # autotune_configs = [({}, 4)] - self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_configs, autotune_key) - self.tys = ''.join([codes[x] for x in self.fn.signature()]) - - def __call__(self, *args, grid: Callable[[_triton.runtime.options], tuple]): - """ - Runs the kernel on the given arguments and launch grid. - :param args: The arguments to the kernel in the orders that they appear in the Triton-C source. - :param grid: The launch grid for the kernel, i.e., callable that transform compilation options into a tuple of at most 3 integers. - :return: None - """ - # make sure that the executing thread is on the right device - torch.cuda.set_device(self.device_id) - # pack parameters into a byte buffer - params = struct.pack(self.tys, *args) - kernel = self.fn.autotune(params, grid, self.stream) - # run kernel - grid = grid(kernel.opt) - kernel(params, self.stream, grid) diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py index 425c42d8f..ca6ca61f8 100644 --- a/python/triton/ops/__init__.py +++ b/python/triton/ops/__init__.py @@ -1,4 +1,4 @@ -from .conv import _conv, conv +#from .conv import _conv, conv from .matmul import _matmul, matmul from .cross_entropy import _cross_entropy, cross_entropy from . import blocksparse \ No newline at end of file diff --git a/python/triton/ops/blocksparse/matmul.c b/python/triton/ops/blocksparse/matmul.c deleted file mode 100644 index e3522ec29..000000000 --- a/python/triton/ops/blocksparse/matmul.c +++ /dev/null @@ -1,199 +0,0 @@ -__global__ void NAME(TYPE *A __readonly __noalias, - TYPE *B __readonly __noalias, - TYPE *C __noalias, - int lda, - int ldb, - int ldc, - long stride_za, - long stride_zb, - long stride_zc, - long stride_ha, - long stride_hb, - long stride_hc, - int DS0, int DS1, - int SDD_K, - int SDD_off_width, - int *lut, int *locks, int nlocks) { - /* ---------------- */ - /* Prologue */ - /* ---------------- */ - // program ids - int pid0 = get_program_id(0); - int pid1 = get_program_id(1); - int pidz = get_program_id(2); -#ifdef SDD - // load LUT header - pid1 = pid1 + SDD_off_width; - int blockidm[TM] = (0 ... TM) / BLOCK; - int blockidn[TN] = (0 ... TN) / BLOCK; - int offlutm[TM] = blockidm * (TN / BLOCK) * 4; - int offlutn[TN] = blockidn * 4; - int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4; - int z = *(header + 0); - int i[TM] = *(header + 1 + offlutm); - int j[TN] = *(header + 2 + offlutn); - int AS1 = SDD_K / TZ; - int lockid = select(TZ > 1, 1, 0); - int offka = pid0 * AS1; - int offkb = pid0 * AS1; - int offmc = 0; - int offnc = 0; - int offpa = 0; - int offpb = 0; - int maxid = TZ; - int offhc = 0; - int offha = z; - int offhb = z; - int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK); - int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK); -#else - // load LUT header - int *header = lut + pid0 * 6; - int offset = *(header + 0); - int AS1 = *(header + 1); - int column = *(header + 2); - int depth = *(header + 3); - int lockid = *(header + 4); - int maxid = *(header + 5); - int *pinc = lut + offset; - int offhc = depth; -#ifdef DSD - // output offset - int offnc = pid1 * TN; - int offmc = column * TM; - int offpc = 0; - // dense input offset - int offnb = pid1 * TN; - int offkb __multipleof(8) = *pinc; - int offpb = 0; - // sparse input offset - int offma = 0; - int offka = 0; - long offpa __multipleof(8) = *(pinc + 1); - offpa = offpa * BLOCK * BLOCK; - int offha = 0; - int offhb = depth; -#endif -#ifdef DDS - // output offset - int offmc = pid1 * TM; - int offnc = column * TN; - int offpc = 0; - // dense input offset - int offma = pid1 * TM; - int offka __multipleof(8) = *pinc; - int offpa = 0; - // sparse input offset - int offnb = 0; - int offkb = 0; - long offpb __multipleof(8) = *(pinc + 1); - offpb = offpb * BLOCK * BLOCK; - int offha = depth; - int offhb = 0; -#endif - int ram[TM] = offma + 0 ... TM; - int rbn[TN] = offnb + 0 ... TN; -#endif - // initialize a, b pointers - int rka[TK] = offka + 0 ... TK; - int rkb[TK] = offkb + 0 ... TK; - TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK; - TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK; - // pre-fetch -#ifdef DDS - bool checkam[TM, TK] = ram[:, newaxis] < DS0; -#else - bool checkam[TM, TK] = AS1 > 0; -#endif -#ifdef DSD - bool checkbn[TK, TN] = rbn [newaxis, :] < DS0; -#else - bool checkbn[TK, TN] = AS1 > 0; -#endif - TYPE a[TM, TK] = checkam ? *pa : 0; - TYPE b[TK, TN] = checkbn ? *pb : 0; - - /* ---------------- */ - /* Inner Loop */ - /* ---------------- */ - // create result tile - float acc[TM, TN] = 0; - int step = TK; - for (int k = AS1; k > 0; k -= step) { - acc += a @b; - // update pointers -#ifdef SDD - int inc_a = TK * STRIDE_AK; - int inc_b = TK * STRIDE_BK; -#else - pinc += 2; -#ifdef DSD - int inc_b __multipleof(8) = *pinc; - int inc_a __multipleof(8) = *(pinc + 1); - inc_b = inc_b * STRIDE_BK; -#endif -#ifdef DDS - int inc_a __multipleof(8) = *pinc; - int inc_b __multipleof(8) = *(pinc + 1); - inc_a = inc_a * STRIDE_AK; -#endif -#endif - pa += inc_a; - pb += inc_b; - // pre-fetch - bool checkak[TM, TK] = k > TK; - bool checkbk[TK, TN] = k > TK; - bool checka[TM, TK] = checkam && checkak; - bool checkb[TK, TN] = checkbk && checkbn; - a = *? (checka)pa; - b = *? (checkb)pb; - } - TYPE c[TM, TN] = acc; - - /* ---------------- */ - /* Epilogue */ - /* ---------------- */ - // initialize c pointers -#ifdef SDD - bool checkc[TM, TN] = 1; - // rematerialize - int rr_blockidm[TM] = (0 ... TM) / BLOCK; - int rr_blockidn[TN] = (0 ... TN) / BLOCK; - int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4; - int rr_offlutn[TN] = rr_blockidn * 4; - int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :]; - int bkid[TM, TN] = *(header + off_bkid); - long offpc[TM, TN] = bkid * BLOCK * BLOCK; - // range within blocks - int rcm[TM] = (0 ... TM) % BLOCK; - int rcn[TN] = (0 ... TN) % BLOCK; -#else - int rcm[TM] = offmc + 0 ... TM; - int rcn[TN] = offnc + 0 ... TN; -#ifdef DSD - bool checkc[TM, TN] = rcn [newaxis, :] < DS0; -#endif -#ifdef DDS - bool checkc[TM, TN] = rcm[:, newaxis] < DS0; -#endif -#endif - TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN; - // write-back directly - if (lockid == 0) { - *? (checkc)pc = c; - } - // accumulate partial result using spin-locks - else { - int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1; - int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks; - for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)) - ; - int count = *pcount; - if (count == 0) - *? (checkc)pc = c; - else - *? (checkc)pc = c + *? (checkc)pc; - atomic_xchg(pcount, (count + 1) % maxid); - atomic_xchg(plock, 0); - } -} \ No newline at end of file diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 3eff88060..50af3c564 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -4,7 +4,183 @@ import torch import os import math -src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) + +@triton.jit +def _kernel( + A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc, + stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta +): + TM = meta['TM'] + TN = meta['TN'] + TK = meta['TK'] + TZ = meta['TZ'] + BLOCK = meta['BLOCK'] + #------------# + #- Prologue -# + #------------# + pid0 = triton.program_id(0) + pid1 = triton.program_id(1) + pidz = triton.program_id(2) + if meta['SDD']: + pid1 = pid1 + SDD_off_width + blockidm = triton.arange(0, TM) // BLOCK + blockidn = triton.arange(0, TN) // BLOCK + offlutm = blockidm * (TN // BLOCK) * 4 + offlutn = blockidn * 4 + header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4 + z = triton.load(header + 0) + i = triton.load(header + 1 + offlutm) + j = triton.load(header + 2 + offlutn) + AS1 = SDD_K // TZ + lockid = triton.where(TZ > 1, 1, 0) + offka = pid0 * AS1 + offkb = pid0 * AS1 + offmc = 0 + offnc = 0 + offpa = 0 + offpb = 0 + maxid = TZ + offhc = 0 + offha = z + offhb = z + ram = i * BLOCK + (triton.arange(0, TM) % BLOCK) + rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK) + else: + header = lut + pid0 * 6 + offset = triton.load(header + 0) + AS1 = triton.load(header + 1) + column = triton.load(header + 2) + depth = triton.load(header + 3) + lockid = triton.load(header + 4) + maxid = triton.load(header + 5) + pinc = lut + offset + offhc = depth + if meta['DSD']: + # output offset + offnc = pid1 * TN + offmc = column * TM + offpc = 0 + # dense input offset + offnb = pid1 * TN + offkb = triton.load(pinc) + offkb = triton.multiple_of(offkb, 8) # compiler hint + offpb = 0 + # sparse input offset + offma = 0 + offka = 0 + offpa = triton.load(pinc + 1) + offpa = triton.multiple_of(offpa, 8) # compiler hint + offpa = offpa * BLOCK * BLOCK + offha = 0 + offhb = depth + else: + # output offset + offmc = pid1 * TM + offnc = column * TN + offpc = 0 + # dense input offset + offma = pid1 * TM + offka = triton.load(pinc) + offka = triton.multiple_of(offka, 8) # compiler hint + offpa = 0 + # sparse input offset + offnb = 0 + offkb = 0 + offpb = triton.load(pinc + 1) + offpb = triton.multiple_of(offpb, 8) # compiler hint + offpb = offpb * BLOCK * BLOCK + offha = depth + offhb = 0 + ram = offma + triton.arange(0, TM) + rbn = offnb + triton.arange(0, TN) + + # initialize a, b pointers + rka = offka + triton.arange(0, TK) + rkb = offkb + triton.arange(0, TK) + pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka + pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb + if meta['DDS']: + checkam = ram[:, None] < DS0 + else: + checkam = AS1 > 0 + if meta['DSD']: + checkbn = rbn[None, :] < DS0 + else: + checkbn = AS1 > 0 + a = triton.load(pa, mask=checkam, other=0.) + b = triton.load(pb, mask=checkbn, other=0.) + + ## ---------------- ## + ## Inner Loop ## + ## ---------------- ## + acc = triton.zeros((TM, TN), dtype=triton.float32) + for k in range(AS1, 0, -TK): + acc += triton.dot(a, b) + if meta['SDD']: + inc_a = TK * stride_ka + inc_b = TK * stride_kb + else: + pinc += 2 + if meta['DSD']: + inc_b = triton.load(pinc) + inc_a = triton.load(pinc + 1) + inc_b = triton.multiple_of(inc_b, 8) + inc_a = triton.multiple_of(inc_a, 8) + inc_b = inc_b * stride_kb + if meta['DDS']: + inc_a = triton.load(pinc) + inc_b = triton.load(pinc + 1) + inc_a = triton.multiple_of(inc_a, 8) + inc_b = triton.multiple_of(inc_b, 8) + inc_a = inc_a * stride_ka + pa += inc_a + pb += inc_b + # pre-fetch + checkak = k > TK + checkbk = k > TK + checka = checkam & checkak + checkb = checkbn & checkbk + a = triton.load(pa, mask=checka) + b = triton.load(pb, mask=checkb) + c = acc.to(C.dtype.element_ty) + + if meta['SDD']: + checkc = True + rr_blockidm = triton.arange(0, TM) // BLOCK + rr_blockidn = triton.arange(0, TN) // BLOCK + rr_offlutm = rr_blockidm * (TN // BLOCK) * 4 + rr_offlutn = rr_blockidn * 4 + off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :] + bkid = triton.load(header + off_bkid) + offpc = bkid * BLOCK * BLOCK + rcm = triton.arange(0, TM) % BLOCK + rcn = triton.arange(0, TN) % BLOCK + else: + rcm = offmc + triton.arange(0, TM) + rcn = offnc + triton.arange(0, TN) + if meta['DSD']: + checkc = rcn[None, :] < DS0 + if meta['DDS']: + checkc = rcm[:, None] < DS0 + + pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc + # write-back directly + if lockid == 0: + triton.store(pc, c, mask=checkc) + # accumulate partial results using spin-locks + else: + plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1 + pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks + while triton.atomic_cas(plock, 0, 1) == 1: + pass + count = triton.load(pcount) + if count == 0: + triton.store(pc, c, mask=checkc) + else: + d = triton.load(pc, mask=checkc) + triton.store(pc, d + c, mask=checkc) + triton.atomic_xchg(pcount, (count + 1) % maxid) + triton.atomic_xchg(plock, 0) ############## @@ -118,31 +294,11 @@ class _matmul(torch.autograd.Function): raise ValueError('Reduction size for SDD must be a multiple of 16') # create kernel total_width = sum([width * pack * pack for width, pack in zip(widths, packs)]) - c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device) + c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device) for lut, width, pack in zip(luts, widths, packs): num_lock = 1 - key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple) - if key not in _matmul.sdd_cache: - defines = { - 'TM': block * pack, - 'TN': block * pack, - 'TMN': block * block * pack * pack, - 'BLOCK': block, - 'TK': 32, - 'TYPE': dtype, - 'STRIDE_AM': '1' if trans_a else 'lda', - 'STRIDE_AK': 'lda' if trans_a else '1', - 'STRIDE_BN': 'ldb' if trans_b else '1', - 'STRIDE_BK': '1' if trans_b else 'ldb', - 'STRIDE_CM': 'ldc', - 'STRIDE_CN': '1', - 'SDD': True, - 'TZ': 1, - 'NAME': 'sdd_kernel' - } - _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines) - - kernel = _matmul.sdd_cache[key] + meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \ + 'SDD': True, 'DSD': False, 'DDS': False} # create output locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device) # maximum grid size is 65535 @@ -150,27 +306,32 @@ class _matmul(torch.autograd.Function): # kernel calls max_width = 49152 for off_width in range(0, width, max_width): - kernel( - a.data_ptr(), - b.data_ptr(), - c.data_ptr(), - a.stride(2), - b.stride(2), - block, + grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0] + _kernel[grid]( + a, + b, + c, a.stride(0), - b.stride(0), - c.stride(0), a.stride(1), + a.stride(3 if trans_a else 2), + a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), + b.stride(3 if trans_b else 2), + b.stride(2 if trans_b else 3), c.stride(0), + c.stride(0), + c.stride(2), + c.stride(3), AS2, AS2, AS3, off_width, - lut.data_ptr(), - locks.data_ptr(), + lut, + locks, num_lock, - grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0] + num_warps=4, + **meta ) # save for backward pass return c @@ -282,25 +443,8 @@ class _matmul(torch.autograd.Function): BS2 = block * spdims[1 if trans_b else 2] dtype = a.dtype # kernel - key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) - if key not in _matmul.dds_cache: - defines = { - 'TM': 128, - 'TN': block, - 'TK': 16, - 'BLOCK': block, - 'TYPE': dtype, - 'STRIDE_AM': 1 if trans_a else 'lda', - 'STRIDE_AK': 'lda' if trans_a else 1, - 'STRIDE_BN': block if trans_b else 1, - 'STRIDE_BK': 1 if trans_b else block, - 'STRIDE_CM': '1' if trans_c else 'ldc', - 'STRIDE_CN': 'ldc' if trans_c else '1', - 'NAME': 'dds_kernel', - 'DDS': True - } - _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines) - kernel = _matmul.dds_cache[key] + meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\ + 'SDD': False, 'DSD': False, 'DDS': True} # output CS0 = AS0 CS1 = AS1 @@ -308,27 +452,32 @@ class _matmul(torch.autograd.Function): CS3 = AS2 if trans_c else BS2 locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device) c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - kernel( - a.data_ptr(), - b.data_ptr(), - c.data_ptr(), - a.stride(2), - block, - c.stride(2), + grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0] + _kernel[grid]( + a, + b, + c, a.stride(0), - b.stride(0), - c.stride(0), a.stride(1), + a.stride(3 if trans_a else 2), + a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), + b.stride(3 if trans_b else 2), + b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), + c.stride(3 if trans_c else 2), + c.stride(2 if trans_c else 3), AS2, BS2, 0, 0, - lut.data_ptr(), - locks.data_ptr(), + lut, + locks, num_locks, - grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0] + num_warps=4, + **meta ) return c @@ -344,25 +493,8 @@ class _matmul(torch.autograd.Function): BS3 = b.size(2 if trans_b else 3) dtype = a.dtype # kernel - key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) - if key not in _matmul.dsd_cache: - defines = { - 'TM': block, - 'TN': 128, - 'TK': 16, - 'BLOCK': block, - 'TYPE': dtype, - 'STRIDE_AM': 1 if trans_a else block, - 'STRIDE_AK': block if trans_a else 1, - 'STRIDE_BN': 'ldb' if trans_b else '1', - 'STRIDE_BK': '1' if trans_b else 'ldb', - 'STRIDE_CM': '1' if trans_c else 'ldc', - 'STRIDE_CN': 'ldc' if trans_c else '1', - 'NAME': 'dsd_kernel', - 'DSD': True - } - _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines) - kernel = _matmul.dsd_cache[key] + meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\ + 'SDD': False, 'DSD': True, 'DDS': False} # output CS0 = BS0 CS1 = BS1 @@ -370,27 +502,32 @@ class _matmul(torch.autograd.Function): CS3 = AS1 if trans_c else BS3 locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device) c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - kernel( - a.data_ptr(), - b.data_ptr(), - c.data_ptr(), - block, - b.stride(2), - c.stride(2), + grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0] + _kernel[grid]( + a, + b, + c, a.stride(0), - b.stride(0), - c.stride(0), a.stride(1), + a.stride(3 if trans_a else 2), + a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), + b.stride(3 if trans_b else 2), + b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), + c.stride(2), + c.stride(3), BS3, AS1, 0, 0, - lut.data_ptr(), - locks.data_ptr(), + lut, + locks, num_locks, - grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0] + num_warps=4, + **meta ) return c diff --git a/python/triton/ops/blocksparse/softmax.c b/python/triton/ops/blocksparse/softmax.c deleted file mode 100644 index 8b8c9506a..000000000 --- a/python/triton/ops/blocksparse/softmax.c +++ /dev/null @@ -1,135 +0,0 @@ -__global__ void forward(TYPE *X __readonly __noalias, - float scale, - int *LUT __readonly __noalias, - TYPE *RPE __readonly __noalias, - TYPE *KP_M __readonly __noalias, - TYPE *ATTN_M __readonly __noalias, - int sizemax, - long stride_zx, - long stride_zrpe, - int stride_hrpe, - int stride_srpe, - int stride_zkpm, - int stride_zattnm) { - int pidhm = get_program_id(0); - int pidz = get_program_id(1); - // create index ranges - int rxm = pidhm % BLOCK; - int rbm = pidhm / BLOCK; - int rxn[TN] = (0 ... TN) % BLOCK; - int rbn[TN] = (0 ... TN) / BLOCK; - // extract information from look-up table - int *header = LUT + rbm * 2; - int size = *(header + 0); - int offset = *(header + 1); - bool check[TN] = rbn < size; - int rbmn[TN] = check ? rbn : size - 1; - // block id and column id - long blockid[TN] = *(LUT + offset + rbmn * 4 + 0); - long columnid[TN] = *(LUT + offset + rbmn * 4 + 1); - long rowid[TN] = *(LUT + offset + rbmn * 4 + 2); - long headid[TN] = *(LUT + offset + rbmn * 4 + 3); - // pointers to X - TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn; -#ifdef APPLY_RPE - // pointers to relative position embedding - TYPE *prpe[TN] = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn; -#endif -#ifdef APPLY_KP_MASK - // pointers to key padding mask - TYPE *pkp_m[TN] = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn; -#endif -#ifdef APPLY_ATTN_MASK - // pointers to attention mask - TYPE *pattn_m[TN] = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn; -#endif - - // load input - TYPE x[TN] = check ? *px : -INFINITY; -#ifdef APPLY_RPE - // load relative position embedding - TYPE rpe[TN] = check ? *prpe : 0; -#endif -#ifdef APPLY_KP_MASK - // load key-padding mask - TYPE kp_m[TN] = check ? *pkp_m : -INFINITY; -#endif -#ifdef APPLY_ATTN_MASK - // load attention mask - TYPE attn_m[TN] = check ? *pattn_m : -INFINITY; -#endif - // compute softmax in float -#ifdef APPLY_RPE - float Frpe[TN] = rpe; -#endif -#ifdef APPLY_KP_MASK - float Fkp_m[TN] = kp_m; -#endif -#ifdef APPLY_ATTN_MASK - float Fattn_m[TN] = attn_m; -#endif -#ifdef KP_MASK_MUL - Fkp_m = (Fkp_m == 0) ? (float[TN]) - INFINITY : 0; -#endif -#ifdef ATTN_MASK_MUL - Fattn_m = (Fattn_m == 0) ? (float[TN]) - INFINITY : 0; -#endif - float Fx[TN] = x; -#ifdef APPLY_SCALE - Fx = Fx * scale; // apply scale -#endif -#ifdef APPLY_RPE - Fx = Fx + Frpe; // apply relative position embedding -#endif -#ifdef APPLY_KP_MASK - Fx = Fx + Fkp_m; // apply key padding mask -#endif -#ifdef APPLY_ATTN_MASK - Fx = Fx + Fattn_m; // apply attention mask -#endif - float Fxmax = Fx[max]; - float Fy[TN] = exp(Fx - Fxmax); - float Fysum = (check ? Fy : 0)[+]; - // write-back in half/float - TYPE y[TN] = Fy; - TYPE ysum = Fysum; - *? (check)px = y / ysum; -} - -__global__ void backward(TYPE *X __readonly __noalias, - float scale, - TYPE *DX __readonly __noalias, - int *LUT, - int sizemax, - long stride_zx, - long stride_zdx) { - int pidhm = get_program_id(0); - int pidz = get_program_id(1); - // create index ranges - int rxm = pidhm % BLOCK; - int rbm = pidhm / BLOCK; - int rxn[TN] = (0 ... TN) % BLOCK; - int rbn[TN] = (0 ... TN) / BLOCK; - // extract information from look-up table - int *header = LUT + rbm * 2; - int size = *(header + 0); - int offset = *(header + 1); - // bounds checking on lut - bool check[TN] = rbn < size; - int rbmn[TN] = check ? rbn : size - 1; - // initialize pointers to block-sparse input - long blockid[TN] = *(LUT + offset + rbmn * 4); - TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn; - TYPE *pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn; - // compute fused softmax backward - TYPE x[TN] = check ? *px : 0; - TYPE dx[TN] = check ? *pdx : 0; - float Fdx[TN] = dx; - float Fx[TN] = x; - float Fxdx[TN] = Fdx * Fx; - float Fxdxsum = Fxdx[+]; - float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale; - TYPE y[TN] = Fy; - // write-back - *? (check)pdx = y; -} diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 2b0d904fa..55d86bbc0 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -2,24 +2,118 @@ import triton import torch import os -fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward']) -fwd_kernels = dict() -bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward']) -bwd_kernels = dict() +def next_power_of_2(n): + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n += 1 + return n + + +def num_warps(n): + if n < 512: + return 4 + if n < 2048: + return 8 + return 16 + + +@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])}) +@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])}) +@triton.jit +def _forward( + X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, + **meta +): + TN = meta['TN'] + BLOCK = meta['BLOCK'] + pidhm = triton.program_id(0) + pidz = triton.program_id(1) + # create index ranges + rxm = pidhm % BLOCK + rbm = pidhm // BLOCK + rxn = triton.arange(0, TN) % BLOCK + rbn = triton.arange(0, TN) // BLOCK + # extract information from LUT + header = LUT + rbm * 2 + size = triton.load(header + 0) + offset = triton.load(header + 1) + check = rbn < size + rbmn = triton.where(check, rbn, size - 1) + # block id and column id + blockid = triton.load(LUT + offset + rbmn * 4 + 0) + columnid = triton.load(LUT + offset + rbmn * 4 + 1) + rowid = triton.load(LUT + offset + rbmn * 4 + 2) + headid = triton.load(LUT + offset + rbmn * 4 + 3) + # pointers to X + px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn + x = triton.load(px, mask=check, other=-float('inf')) + x = x.to(triton.float32) + # apply scale + if meta['APPLY_SCALE']: + x = x * scale + # apply RPE + if meta['APPLY_RPE']: + prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn + rpe = triton.load(prpe, mask=check, other=0) + x = x + rpe + # apply key-padding mask + if meta['APPLY_KP_MASK']: + pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn + kp_m = triton.load(pkp_m, mask=check, other=-float('inf')) + if meta['KP_MASK_MUL']: + kp_m = triton.where(kp_m == 0, -float('inf'), 0.) + x = x + kp_m + # apply attention mask + if meta['APPLY_ATTN_MASK']: + pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn + attn_m = triton.load(pattn_m, mask=check, other=-float('inf')) + if meta['ATTN_MASK_MUL']: + attn_m = triton.where(attn_m == 0, -float('inf'), 0.) + x = x + attn_m + # computation + x = triton.softmax(x) + triton.store(px, x, mask=check) + + +@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) +@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']}) +@triton.jit +def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta): + pidhm = triton.program_id(0) + pidz = triton.program_id(1) + TN = meta['TN'] + BLOCK = meta['BLOCK'] + # create index ranges + rxm = pidhm % BLOCK + rbm = pidhm // BLOCK + rxn = triton.arange(0, TN) % BLOCK + rbn = triton.arange(0, TN) // BLOCK + # extract information from look-up table + header = LUT + rbm * 2 + size = triton.load(header + 0) + offset = triton.load(header + 1) + # bounds checking on lut + check = rbn < size + rbmn = triton.where(check, rbn, size - 1) + # initialize pointers to block-sparse input + blockid = triton.load(LUT + offset + rbmn * 4) + X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn + DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn + # compute fused softmax backward + x = triton.load(X, mask=check, other=0) + dx = triton.load(DX, mask=check, other=0) + x = x.to(triton.float32) + dx = dx.to(triton.float32) + y = x * (dx - triton.sum(x * dx, 0)) * scale + triton.store(DX, y, mask=check) + class _softmax(torch.autograd.Function): - @staticmethod - def next_power_of_2(n): - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n += 1 - return n - @staticmethod def make_lut(layout, block, device): _empty = torch.tensor([], dtype=torch.int64, device=layout.device) @@ -43,40 +137,9 @@ class _softmax(torch.autograd.Function): return lut, int(sizes.max()) @staticmethod - def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, - kp_mask_mode, attn_mask_mode): - if max_k >= 32768: - raise NotImplementedError('Reductions larger than 32768 elements '\ - 'are not yet implemented') - num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16) - TN = _softmax.next_power_of_2(max_k) - # just-in-time compile kernel - key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, - kp_mask_mode, attn_mask_mode) - if key not in cache: - defines = { - 'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY': - {torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype] - } - if apply_scale: - defines['APPLY_SCALE'] = True - if apply_rpe: - defines['APPLY_RPE'] = True - if apply_kp_mask: - defines['APPLY_KP_MASK'] = True - if kp_mask_mode == 'mul': - defines['KP_MASK_MUL'] = True - if apply_attn_mask: - defines['APPLY_ATTN_MASK'] = True - if attn_mask_mode == 'mul': - defines['ATTN_MASK_MUL'] = True - kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps) - cache[key] = kernel - return cache[key] - - @staticmethod - def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, - maxlut, bench, time): + def forward( + ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, maxlut, bench, time + ): apply_scale = False if scale == 1.0 else True # handle None rpe @@ -107,26 +170,20 @@ class _softmax(torch.autograd.Function): stride_zattnm = attn_mask.stride(0) # run kernel - kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale, - apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode) M = x.shape[0] + meta = { + 'BLOCK': block, + 'APPLY_SCALE': apply_scale, + 'APPLY_RPE': apply_rpe, + 'APPLY_KP_MASK': apply_kp_mask, + 'APPLY_ATTN_MASK': apply_attn_mask, + 'KP_MASK_MUL': kp_mask_mode == 'mul', + 'ATTN_MASK_MUL': attn_mask_mode == 'mul', + } grid = lambda opt: [spdims[0] * spdims[1] * block, M] + _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\ + stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta) - # run kernel - kernel(x.data_ptr(), - scale, - lut.data_ptr(), - rpe.data_ptr(), - key_padding_mask.data_ptr(), - attn_mask.data_ptr(), - maxlut, - x.stride(0), - stride_zrpe, - stride_hrpe, - stride_srpe, - stride_zkpm, - stride_zattnm, - grid=grid) # save to context ctx.mark_dirty(x) ctx.save_for_backward(x, lut) @@ -147,14 +204,12 @@ class _softmax(torch.autograd.Function): # retrieve from context x, lut = ctx.saved_tensors # run kernel - kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block, - ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask, - ctx.kp_mask_mode, ctx.attn_mask_mode) M = x.shape[0] grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M] - kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid) + _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block) return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None + class softmax: apply_softmax = _softmax.apply @@ -172,14 +227,9 @@ class softmax: self.bench = bench self.lut_cache = dict() - def __call__(self, - x, - scale=1., - rpe=None, - key_padding_mask=None, - attn_mask=None, - key_padding_mask_mode='add', - attn_mask_mode='add'): + def __call__( + self, x, scale=1., rpe=None, key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add' + ): time_y = [None] if rpe is not None and rpe.dtype != x.dtype: raise ValueError('relative position embedding must be %s' % x.dtype) @@ -188,6 +238,8 @@ class softmax: if key_padding_mask is not None and key_padding_mask.dtype != x.dtype: raise ValueError('Key padding mask must be %s' % x.dtype) lut, maxlut = self.make_lut(x.device) - x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, - self.spdims, self.block, lut, maxlut, self.bench, time_y) + x = softmax.apply_softmax( + x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, self.spdims, self.block, lut, + maxlut, self.bench, time_y + ) return x \ No newline at end of file diff --git a/python/triton/ops/conv.c b/python/triton/ops/conv.c deleted file mode 100644 index 6a8877895..000000000 --- a/python/triton/ops/conv.c +++ /dev/null @@ -1,123 +0,0 @@ -__global__ void conv(TYPE *A __noalias __readonly, - TYPE *B __noalias __readonly, - TYPE *C __noalias, - float alpha, - // equivalent matmul - int M, int N, int K, - // convolution properties - int pad_h, int pad_w, int stride_h, int stride_w, - // pointer increment - int *ADELTA, - // memory strides - int lda_z, int lda_ci, int lda_h, int lda_w, - int ldb_ci, int ldb_r, int ldb_s, int ldb_co, - int ldc_z, int ldc_co, int ldc_p, int ldc_q) -{ - // prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int ridz = get_program_id(2); - int gridx = M / TM; - int gridy = N / TN; - int rid = ridx + ridy * gridx; - ridx = rid / gridy; - ridy = rid % gridy; - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... TN; - // reduction splitting - K = K / TZ; - int rk[TK] = ridz * K + 0 ... TK; - - // unpack aggregate rows - // m = (z, p, q) - int rq[TM] = rm % QQ; - int rzp[TM] = rm / QQ; - int rp[TM] = rzp % PP; - int rz[TM] = rzp / PP; - // unpack aggregate reduction - // k = (ci, r, s) - int rs[TK] = rk % SS; - int rcir[TK] = rk / SS; - int rr[TK] = rcir % RR; - int rci[TK] = rcir / RR; - - // padding / striding - int rh_0[TM] = rp * stride_h - pad_h; - int rw_0[TM] = rq * stride_w - pad_w; - int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :]; - int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :]; - - // pointers to lhs - int offa[TM, TK] = rz[:, newaxis] * lda_z + rci [newaxis, :] * lda_ci + - rh * lda_h + rw * 1; - TYPE *pa[TM, TK] = A + offa; - int *padelta[TK] = ADELTA + rk; - // pointers to rhs - int offb[TK, TN] = rci[:, newaxis] * ldb_ci + rr[:, newaxis] * ldb_r + - rs[:, newaxis] * ldb_s + rn [newaxis, :] * 1; - TYPE *pb[TK, TN] = B + offb; - - // prefetches operands - bool checkam[TM, TK] = rm[:, newaxis] < M; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - bool checkb[TK, TN] = rk[:, newaxis] < K; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - int total = 0; - - // reduction loop - float acc[TM, TN] = 0; - for (int k = K; k > 0; k -= TK) - { - acc += a @b; - // increment A - int adelta[TK] = *padelta; - padelta += TK; - pa += adelta [newaxis, :]; - // bounds-checking A - rk += TK; - rs = rk % SS; - rcir = rk / SS; - rr = rcir % RR; - rh = rh_0[:, newaxis] + rr [newaxis, :]; - rw = rw_0[:, newaxis] + rs [newaxis, :]; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - // increment B - pb += TK * ldb_s; - // bounds-checking B - bool checkb[TK, TN] = k > TK; - a = checka ? *pa : 0; - b = *? (checkb)pb; - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; - - // epilogue - rm = ridx * TM + 0 ... TM; - rn = ridy * TN + 0 ... TN; - rq = rm % QQ; - rzp = rm / QQ; - rp = rzp % PP; - rz = rzp / PP; - int offc[TM, TN] = rz[:, newaxis] * ldc_z + rn [newaxis, :] * ldc_co + - rp[:, newaxis] * ldc_p + rq[:, newaxis] * 1; - TYPE *pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N; - -#if (TZ == 1) - *? (checkc)pc = c; -#else - // accumulate partial result using spin-locks - int *plock = locks + rid; - int *pcount = plock + get_num_programs(0) * get_num_programs(1); - for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)) - ; - int count = *pcount; - if (count == 0) - *? (checkc)pc = c; - else - *? (checkc)pc = c + *? (checkc)pc; - atomic_xchg(pcount, (count + 1) % TZ); - atomic_xchg(plock, 0); -#endif -} \ No newline at end of file diff --git a/python/triton/ops/conv.py b/python/triton/ops/conv.py deleted file mode 100644 index 1725d9ca0..000000000 --- a/python/triton/ops/conv.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import triton -import os - -class _conv(torch.autograd.Function): - src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c')) - kernel = dict() - - @staticmethod - def unpack(IDX, CI, R, S): - s = IDX % S - cr = IDX // S - r = cr % R - ci = cr // R - return ci, r, s - - @staticmethod - def forward(ctx, a, b, pad, stride): - # create kernel if necessary - dtype = a.dtype - device = a.device - # shapes - Z, CI, H, W = a.shape - _, R, S, CO = b.shape - P = (H + 2 * pad[0] - R) // stride[0] + 1 - Q = (W + 2 * pad[1] - S) // stride[1] + 1 - # compile kernel - if (dtype, device) not in _conv.kernel: - TK = 16 - defines = { - 'TYPE': dtype, - 'TM': 64, - 'TN': 64, - 'TK': TK, - 'TZ': 1, - 'HH': H, - 'WW': W, - 'PP': P, - 'QQ': Q, - 'SS': S, - 'RR': R, - } - idx = torch.arange(CI * R * S) - ci, r, s = _conv.unpack(idx, CI, R, S) - nci, nr, ns = _conv.unpack(idx + TK, CI, R, S) - delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3) - delta = delta.type(torch.int32).cuda() - _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines)) - delta, kernel = _conv.kernel[dtype] - # allocate output - c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device) - # enqueue - kernel( - a.data_ptr(), - b.data_ptr(), - c.data_ptr(), - 1., - Z * P * Q, - CO, - CI * R * S, - pad[0], - pad[1], - stride[0], - stride[1], - delta.data_ptr(), - a.stride(0), - a.stride(1), - a.stride(2), - a.stride(3), - b.stride(0), - b.stride(1), - b.stride(2), - b.stride(3), - c.stride(0), - c.stride(1), - c.stride(2), - c.stride(3), - grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)]) - return c - -conv = _conv.apply \ No newline at end of file diff --git a/python/triton/ops/cross_entropy.c b/python/triton/ops/cross_entropy.c deleted file mode 100644 index b906c8a05..000000000 --- a/python/triton/ops/cross_entropy.c +++ /dev/null @@ -1,35 +0,0 @@ -__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) { - int row = get_program_id(0); - - bool check[TILE] = ((0 ... TILE) < n_cols); - int offset[TILE] = row * n_cols + 0 ... TILE; - TYPE *px[TILE] = logit + offset; - TYPE *pmodified[TILE] = modified_logit + offset; - long local_ind = *(indices + row); - - TYPE F16[TILE] = check ? *px : -INFINITY; - float shifted_logit[TILE] = F16 - F16[max]; - float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit; - *? (check)pmodified = neg_logprob; - __debug_barrier(); - *(result + row) = *(modified_logit + (local_ind + n_cols * row)); -} - -__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) { - - int row = get_program_id(0); - // pointer arithmetic - bool check[TILE] = ((0 ... TILE) < n_cols); - int offset[TILE] = row * n_cols + 0 ... TILE; - TYPE *px[TILE] = neg_logprobs + offset; - long local_ind = *(indices + row); - TYPE local_dn = *(dneg_logprobs + row); - // We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] - // and we have -log(p[k]) stored, so this is easy - TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * px) : 0; - // selected_logit_idx is selected logit index for our token - bool find_one[TILE] = ((0 ... TILE) == local_ind); - intermediate = intermediate - ((TYPE[TILE])find_one); - // multiply by dneg_logprobs - *? (check)px = intermediate * local_dn; -} \ No newline at end of file diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 75ee03e59..e69ad2038 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -2,6 +2,7 @@ import os import triton import torch + def next_power_of_2(n): n -= 1 n |= n >> 1 @@ -12,34 +13,61 @@ def next_power_of_2(n): n += 1 return n -def largest_pow2_divisor(N): - if N % 8 == 0: return 8 - if N % 4 == 0: return 4 - if N % 2 == 0: return 2 - return 1 -def make_kernel(device, dtype, n_cols, cache, name): - rounded = next_power_of_2(n_cols) - div = largest_pow2_divisor(n_cols) - key = (dtype, rounded, div) - if key not in cache: - fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c") - src = triton.read(fname, kernel_names=[name]) - infinities = { - torch.float16: "F16_INFINITY", - torch.float32: "F32_INFINITY", - } - defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div} - cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4) - return cache[key] +def num_warps(N): + if N < 2048: + return 4 + elif N < 8192: + return 8 + return 16 -# forward kernel -fwd_kernels = dict() -make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward") -# backward kernel -bwd_kernels = dict() -make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward") +@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])}) +@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])}) +@triton.jit +def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): + BLOCK = meta['BLOCK'] + row = triton.program_id(0) + cols = triton.arange(0, BLOCK) + idx = triton.load(IDX + row) + # pointers to logit and probs + LOGITS = LOGITS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + READ_PROBS = PROBS + row * N + idx + # write-back negative log-probs + logits = triton.load(LOGITS, mask=cols < N, other=-float('inf')) + logits = logits.to(triton.float32) + logits = logits - triton.max(logits, 0) + probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits + triton.store(WRIT_PROBS, probs, mask=cols < N) + # There is a bug in the compiler, which fails to insert a barrier here. + # We add it explicitly for now. Will be fixed soon. + triton.debug_barrier() + # write-back loss + probs = triton.load(READ_PROBS) + triton.store(LOSS + row, probs) + + +@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])}) +@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])}) +@triton.jit +def _backward(PROBS, IDX, DPROBS, N, **meta): + BLOCK = meta['BLOCK'] + row = triton.program_id(0) + cols = triton.arange(0, BLOCK) + idx = triton.load(IDX + row) + # pointers to probs + PROBS = PROBS + row * N + cols + # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + # and we have -log(p[k]) stored in PROBS, so this is easy + probs = -triton.load(PROBS, mask=cols < N, other=float('inf')) + probs = triton.exp(probs.to(triton.float32)) + delta = cols == idx + # write result in-place in PROBS + dout = triton.load(DPROBS + row) + din = (probs - delta) * dout + triton.store(PROBS, din.to(triton.float16), mask=cols < N) + class _cross_entropy(torch.autograd.Function): @classmethod @@ -49,16 +77,11 @@ class _cross_entropy(torch.autograd.Function): # make kernel device, dtype = logits.device, logits.dtype n_cols = logits.shape[-1] - kernel = make_fwd_kernel(device, dtype, n_cols) # run the kernel result = torch.empty_like(indices, dtype=dtype, device=device) neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device) - kernel(logits.data_ptr(), - neg_logprobs.data_ptr(), - indices.data_ptr(), - result.data_ptr(), - n_cols, - grid=lambda opt: (logits.numel() // n_cols, )) + grid = lambda opt: (logits.numel() // n_cols, ) + _forward[grid](logits, neg_logprobs, indices, result, n_cols) # save for backward ctx.save_for_backward(neg_logprobs, indices) return result @@ -75,14 +98,11 @@ class _cross_entropy(torch.autograd.Function): # make kernel device, dtype = neg_logprobs.device, neg_logprobs.dtype n_cols = neg_logprobs.shape[-1] - kernel = make_bwd_kernel(device, dtype, n_cols) # run the kernel # neg_logprobs will be modified in place to become our gradient: - kernel(neg_logprobs.data_ptr(), - indices.data_ptr(), - dneg_logprobs.data_ptr(), - n_cols, - grid=lambda opt: (neg_logprobs.numel() // n_cols, )) + grid = lambda opt: (neg_logprobs.numel() // n_cols, ) + _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) return neg_logprobs, None + cross_entropy = _cross_entropy.apply \ No newline at end of file diff --git a/python/triton/ops/matmul.c b/python/triton/ops/matmul.c deleted file mode 100644 index 92ea9f2ad..000000000 --- a/python/triton/ops/matmul.c +++ /dev/null @@ -1,94 +0,0 @@ -#define STM 1 -#define STN 1 - -__global__ void matmul(TYPE *A __noalias __readonly, - TYPE *B __noalias __readonly, - TYPE *C __noalias, - float alpha, - int M, int N, int K, - int lda, int ldb, int ldc, - int *locks) { - // prologue - int pid = get_program_id(0); - int pidz = get_program_id(2); - int gridm = (M + TM - 1) / TM; - int gridn = (N + TN - 1) / TN; - - // swizzle for better L2 performance - int width = STM * gridn; - int stm = pid / width; - int RSTM = min(gridm - stm * STM, STM); - int stn = (pid % width) / (RSTM * STN); - int RSTN = min(gridn - stn * STN, STN); - int laneid = pid % (RSTM * RSTN); - int lanem = laneid / RSTN; - int lanen = laneid % RSTN; - int pidm = stm * STM + lanem; - int pidn = stn * STN + lanen; - int rm[TM] = pidm * TM + 0 ... TM; - int rn[TN] = pidn * TN + 0 ... TN; - - // split-k for better parrallelism - K = K / SPLITK; - int rk[TK] = 0 ... TK; - // pointers to operands - int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; - int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN; - TYPE *pa[TM, TK] = A + offa; - TYPE *pb[TK, TN] = B + offb; - - // prefetches operands - bool checka[TM, TK] = rk [newaxis, :] < K; - bool checkb[TK, TN] = rk[:, newaxis] < K; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - - // reduction loop - float acc[TM, TN] = 0; - for (int k = K; k > 0; k -= TK) { -#if (IS_TK_DIV_K == 1) - bool checkk[TK] = k > TK; -#else - bool checkk[TK] = rk < k - TK; -#endif - bool checka[TM, TK] = checkk [newaxis, :]; - bool checkb[TK, TN] = checkk[:, newaxis]; - acc += a @b; -#if (IS_TK_DIV_K == 1) - a = *? (checka)pa; - b = *? (checkb)pb; -#else - a = checka ? *pa : 0; - b = checkb ? *pb : 0; -#endif - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; - - // epilogue - int rcm[TM] = pidm * TM + 0 ... TM; - int rcn[TN] = pidn * TN + 0 ... TN; - int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :]; - TYPE *pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N; -#if (SPLITK == 1) - *? (checkc)pc = c; -#else - // accumulate partial result using spin-locks - int *plock = locks + pid; - int *pcount = plock + get_num_programs(0); - for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)) - ; - int count = *pcount; - if (count == 0) - *? (checkc)pc = c; - else - *? (checkc)pc = c + *? (checkc)pc; - atomic_xchg(pcount, (count + 1) % SPLITK); - atomic_xchg(plock, 0); -#endif -} \ No newline at end of file diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 29c7a8ab2..9af671a88 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -1,108 +1,117 @@ import torch import triton -import os + + +@triton.heuristics({ + 'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0, +}) +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4), + # triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ + # triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ + # triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ + # triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),\ + # triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2), + ], + key=['M', 'N', 'K'] +) +@triton.jit +def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, LOCKS, **META): + # extract meta-parameters + BLOCK_M = META['BLOCK_M'] + BLOCK_N = META['BLOCK_N'] + BLOCK_K = META['BLOCK_K'] + GROUP_M = META['GROUP_M'] + SPLIT_K = META['SPLIT_K'] + # matrix multiplication + pid = triton.program_id(0) + pid_z = triton.program_id(1) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) + rk = triton.arange(0, BLOCK_K) + # pointers + K = K // SPLIT_K + A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn) + acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32) + for k in range(K, 0, -BLOCK_K): + if META['EVEN_K']: + a = triton.load(A) + b = triton.load(B) + else: + a = triton.load(A, mask=rk[None, :] < k, other=0.) + b = triton.load(B, mask=rk[:, None] < k, other=0.) + acc += triton.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(triton.float16) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + triton.store(C, acc, mask=mask) + else: + LOCKS = LOCKS + triton.program_id(0) + COUNT = LOCKS + triton.num_programs(0) + while triton.atomic_cas(LOCKS, 0, 1) == 1: + pass + count = triton.load(COUNT) + if count == 0: + triton.store(C, acc, mask=mask) + else: + curr = triton.load(C, mask=mask, other=0.) + triton.store(C, acc + curr, mask=mask) + triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K) + triton.atomic_xchg(LOCKS, 0) class _matmul(torch.autograd.Function): - src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c")) - - _DEFAULT_CONFIGS = [ - triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4), - triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4), - triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4), - triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4), - triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4), - triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4), - triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2), - triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2), - triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4), - triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4), - triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4), - triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4), - ] - _CONFIGS = _DEFAULT_CONFIGS - - @staticmethod - def largest_pow2_divisor(N): - if N % 8 == 0: - return 8 - if N % 4 == 0: - return 4 - if N % 2 == 0: - return 2 - return 1 + kernel = _kernel _locks = dict() - _kernels = dict() @staticmethod def _call(a, b): - dtype = a.dtype device = a.device - # allocate output - M, K = a.shape - K, N = b.shape - c = torch.empty((M, N), dtype=dtype, device=device) # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous() if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous() - # kernel hash - is_a_row = a.stride(1) == 1 - is_b_row = b.stride(1) == 1 - lda = a.stride(0) if is_a_row else a.stride(1) - ldb = b.stride(0) if is_b_row else b.stride(1) - ldc = c.stride(0) - lda_pow2_div = _matmul.largest_pow2_divisor(lda) - ldb_pow2_div = _matmul.largest_pow2_divisor(ldb) - ldc_pow2_div = _matmul.largest_pow2_divisor(ldc) - is_tk_div_k = K % 64 == 0 - key = ( - device, - dtype, - is_a_row, - is_b_row, - lda_pow2_div, - ldb_pow2_div, - ldc_pow2_div, - is_tk_div_k, - ) - if key not in _matmul._kernels: - defines = { - "TYPE": dtype, - "STRIDE_AM": "lda" if is_a_row else "1", - "STRIDE_AK": "1" if is_a_row else "lda", - "STRIDE_BK": "ldb" if is_b_row else "1", - "STRIDE_BN": "1" if is_b_row else "ldb", - "LDA_POW2_DIV": lda_pow2_div, - "LDB_POW2_DIV": ldb_pow2_div, - "LDC_POW2_DIV": ldc_pow2_div, - "IS_TK_DIV_K": int(is_tk_div_k), - } - _matmul._kernels[key] = triton.kernel( - _matmul.src, - device, - defines=defines, - autotune_configs=_matmul._CONFIGS, - autotune_key=["M", "N", "K"], - ) - kernel = _matmul._kernels[key] - # # locks for split-k - if device not in _matmul._locks: + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=a.dtype) + # allocate locks for split-k + if a.device not in _matmul._locks: _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device) locks = _matmul._locks[device] - # enqueue - alpha = 1.0 - args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()] - grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.SPLITK] - kernel(*args, grid=grid) + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), locks) + # done return c @staticmethod def forward(ctx, a, b): - c = _matmul._call(a, b) - return c + return _matmul._call(a, b) matmul = _matmul.apply diff --git a/python/triton/testing.py b/python/triton/testing.py index ad57fa7f3..317730781 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -47,13 +47,37 @@ def mask_tensor(x, mask, block, value=0): def allclose(x, y, tol=1e-2): - assert x.dtype == y.dtype + if x.dtype != y.dtype: + raise RuntimeError(f'{x.dtype} did not match with {x.dtype}') + if x.shape != y.shape: + raise RuntimeError(f'{x.shape} did not match with {y.shape}') + if x.dtype == torch.bool: + return torch.sum(x ^ y) == 0 + if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + tol = 0 diff = abs(x - y) x_max = torch.max(x) y_max = torch.max(y) tol = 1e-2 err = torch.max(diff) / torch.max(x_max, y_max) - return err < tol + return err <= tol + + +def assert_allclose(x, y, tol=1e-2): + assert x.dtype == y.dtype + assert allclose(x, y, tol) + + +def random(shape, dtype, device): + if isinstance(shape, int): + shape = (shape, ) + if dtype == torch.bool: + return torch.randint(0, 2, shape, dtype=dtype, device=device) + if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + return torch.randint(1, 32, shape, dtype=dtype, device=device) + if dtype in [torch.float16, torch.float32, torch.float64]: + return torch.randn(shape, dtype=dtype, device=device) + raise RuntimeError(f'Unknown dtype {dtype}') def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]): diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 923c89038..b1c82a34a 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -1,139 +1,71 @@ +import torch +import triton """ Vector Addition ================= In this tutorial, you will write a simple vector addition using Triton and learn about: -- The basic syntax of the Triton programming language -- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API +- The basic programming model used by Triton +- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels. - The best practices for validating and benchmarking custom ops against native reference implementations """ # %% # Compute Kernel # -------------------------- -# -# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel -# on different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_) -# programming model for more details). -# -# .. code-block:: C -# -# __global__ void add(float* z, float* x, float* y, int N){ -# // The `get_program_id(i)` returns the i-th coordinate -# // of the program in the overaching SPMD context -# // (a.k.a launch grid). This is what allows us to process -# // different chunks of data in parallel. -# // For those similar with CUDA, `get_program_id({0,1,2})` -# // is similar to blockIdx.{x,y,z} -# int pid = get_program_id(0); -# // In Triton, arrays are first-class citizen. In other words, -# // they are primitives data-types and are -- contrary to C and -# // CUDA -- not implemented as pointers to contiguous chunks of -# // memory. -# // In the few lines below, we create an array of `BLOCK` pointers -# // whose memory values are, e.g.: -# // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1] -# // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time -# int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK; -# float* pz [BLOCK] = z + offset; -# float* px [BLOCK] = x + offset; -# float* py [BLOCK] = y + offset; -# // Simple element-wise control-flow for load/store operations can -# // be achieved using the the ternary operator `cond ? val_true : val_false` -# // or the conditional dereferencing operator `*?(cond)ptr -# // Here, we make sure that we do not access memory out-of-bounds when we -# // write-back `z` -# bool check[BLOCK] = offset < N; -# *?(check)pz = *?(check)px + *?(check)py; -# } -# -# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper `_. + + +@triton.jit +def _add( + X, # *Pointer* to first input vector + Y, # *Pointer* to second input vector + Z, # *Pointer* to output vector + N, # Size of the vector + **meta # Optional meta-parameters for the kernel +): + pid = triton.program_id(0) + # Create an offset for the blocks of pointers to be + # processed by this program instance + offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK']) + # Create a mask to guard memory operations against + # out-of-bounds accesses + mask = offsets < N + # Load x + x = triton.load(X + offsets, mask=mask) + y = triton.load(Y + offsets, mask=mask) + # Write back x + y + z = x + y + triton.store(Z + offsets, z) + # %% -# Torch Bindings -# -------------------------- -# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things: -# -# - :code:`source: string`: the source-code of the kernel you want to create -# - :code:`device: torch.device`: the device you want to compile this code for -# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you - -import torch -import triton - -# source-code for Triton compute kernel -# here we just copy-paste the above code without the extensive comments. -# you may prefer to store it in a .c file and load it from there instead. -_src = """ -__global__ void add(float* z, float* x, float* y, int N){ - // program id - int pid = get_program_id(0); - // create arrays of pointers - int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK; - float* pz[BLOCK] = z + offset; - float* px[BLOCK] = x + offset; - float* py[BLOCK] = y + offset; - // bounds checking - bool check[BLOCK] = offset < N; - // write-back - *?(check)pz = *?(check)px + *?(check)py; -} - """ +# We can also declara a helper function that handles allocating the output vector +# and enqueueing the kernel. -# This function returns a callable `triton.kernel` object created from the above source code. -# For portability, we maintain a cache of kernels for different `torch.device` -# We compile the kernel with -DBLOCK=1024 -def make_add_kernel(device): - cache = make_add_kernel.cache - if device not in cache: - defines = {'BLOCK': 1024} - cache[device] = triton.kernel(_src, device=device, defines=defines) - return cache[device] +def add(x, y): + z = torch.empty_like(x) + N = z.shape[0] + # The SPMD launch grid denotes the number of kernel instances that should execute in parallel. + # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int] + grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), ) + # NOTE: + # - torch.tensor objects are implicitly converted to pointers to their first element. + # - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel + # - don't forget to pass meta-parameters as keywords arguments + _add[grid](x, y, z, N, BLOCK=1024) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously. + return z -make_add_kernel.cache = dict() - - -# This is a standard torch custom autograd Function; -# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.` -class _add(torch.autograd.Function): - @staticmethod - def forward(ctx, x, y): - # constraints of the op - assert x.dtype == torch.float32 - # *allocate output* - z = torch.empty_like(x) - # *create launch grid*: - # this is a function which takes compilation parameters `opt` - # as input and returns a tuple of int (i.e., launch grid) for the kernel. - # triton.cdiv is a shortcut for ceil division: - # triton.cdiv(a, b) = (a + b - 1) // b - N = z.shape[0] - grid = lambda opt: (triton.cdiv(N, opt.BLOCK), ) - # *launch kernel*: - # pointer to the data of torch tensors can be retrieved with - # the `.data_ptr()` method - kernel = make_add_kernel(z.device) - kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid) - return z - - -# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function -add = _add.apply - # %% -# We can now use the above function to compute the sum of two `torch.tensor` objects: - -# %% -# Unit Test -# ----------- -# -# Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below: +# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results: torch.manual_seed(0) -x = torch.rand(98432, device='cuda') -y = torch.rand(98432, device='cuda') +size = 98432 +x = torch.rand(size, device='cuda') +y = torch.rand(size, device='cuda') za = x + y zb = add(x, y) print(za) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index c1206710d..3c5d674c2 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -4,8 +4,7 @@ Fused Softmax In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about: - The benefits of kernel fusion for bandwidth-bound operations. -- The syntax and usage of reduction operators in Triton. -- The automatic vectorization capabilities of the Triton compiler. +- The reduction operators in Triton. """ # %% @@ -36,79 +35,45 @@ def naive_softmax(x): # %% # When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements. # This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. -# In this case, we would be reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`). +# This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`). # In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory. # %% # Compute Kernel # ---------------- -# Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y. +# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y. # Note that one important limitation of Triton is that each block must have a power-of-two number of elements, # so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes: -# -# .. code-block:: C -# -# __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){ -# // row index -# int m = get_program_id(0); -# // column indices -# int n [BLOCK] = 0 ... BLOCK; -# // the memory address of all the elements -# // that we want to load can be computed as follows -# float* px [BLOCK] = X + m*stride_xm + n; -# // because BLOCK has to be a power of two -# // (per Triton-C specs), it is important -# // to guard each memory operation with predicates -# // or we will read out of bounds -# bool check[BLOCK] = n < N; -# float x [BLOCK] = check ? *px : -F32_INFINITY; -# // syntax for reduction in Triton is: -# // x[:, :, OPERATOR, :, :] -# // ^ -# // index -# // where operator is in {min, max, +} -# // for 1D vectors, this is just x[OPERATOR]. -# float z [BLOCK] = x - x[max]; -# // Note that exponentials in Triton are fast -# // but approximate (i.e., think __expf in CUDA) -# float num [BLOCK] = exp(z); -# float denom = num[+]; -# // The result of the reduction is now stored in y -# float y [BLOCK] = num / denom; -# // We write it back -# float* py [BLOCK] = Y + m*stride_ym + n; -# *?(check)py = y; -# } -# %% -# Torch Bindings -# --------------- -# Here our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial. -# We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix. -# This means that different values of BLOCK will result in different kernels - -import torch import triton -# Source code for the Triton kernel -_src = """ -__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){ - int m = get_program_id(0); - int n [BLOCK] = 0 ... BLOCK; - float* px [BLOCK] = X + m*stride_xm + n; - bool check[BLOCK] = n < N; - float x [BLOCK] = check ? *px : -F32_INFINITY; - float z [BLOCK] = x - x[max]; - float num [BLOCK] = exp(z); - float denom = num[+]; - float y [BLOCK] = num / denom; - float* py [BLOCK] = Y + m*stride_ym + n; - *?(check)py = y; -} -""" + +@triton.jit +def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta): + # row index + m = triton.program_id(0) + # col indices + n = triton.arange(0, meta['BLOCK']) + # the memory address of all the elements + # that we want to load can be computed as follows + X = X + m * stride_xm + n + x = triton.load(X, mask=n < N, other=-float('inf')) + # Substract maximum for numerical stability + z = x - triton.max(x, axis=0) + # Note that exponentials in Triton are fast + # but approximate (i.e., think __expf in CUDA) + num = triton.exp(z) + denom = triton.sum(num, axis=0) + y = num / denom + # Write back to Y + Y = Y + m * stride_ym + n + triton.store(Y, y, mask=n < N) + + +# %% +# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. -# helper function to get the smaller power-of-two larger than a given number def next_power_of_2(n): n -= 1 n |= n >> 1 @@ -120,11 +85,9 @@ def next_power_of_2(n): return n -# kernel caching mechanism -def make_kernel(N, device): - cache = make_kernel.cache - # Now are kernels are indexed not only by the provided device but also - # by the rounded number of columns in the input matrix +def softmax(x): + M, N = x.shape + # The block size is the smallest power of two greater than the number of columns in `x` BLOCK = next_power_of_2(N) # Another trick we can use is to ask the compiler to parallelize each # row-normalization more aggressively -- i.e., with more warps -- vectors @@ -134,37 +97,13 @@ def make_kernel(N, device): num_warps = 4 if BLOCK >= 2048: num_warps = 8 if BLOCK >= 4096: num_warps = 16 - # Each (BLOCK, num_warps, device) results in a different kernel - key = (BLOCK, num_warps, device) - if key not in cache: - defines = {'BLOCK': BLOCK} - cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps) - return cache[key] + # Allocate output + y = torch.empty_like(x) + # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix + _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK) + return y -make_kernel.cache = dict() - - -class _softmax(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - # constraints of the op - assert x.dtype == torch.float32 - y = torch.empty_like(x) - # The launch grid is simple: we have one kernel instance per row of the input matrix - M, N = y.shape - grid = lambda opt: (M, ) - # Launch kernel - kernel = make_kernel(N, y.device) - kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid) - return y - - -softmax = _softmax.apply - -# %% -# We can use the above softmax function to compute the row-wise softmax of a given matrix. - # %% # Unit Test # ---------- diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 10534e874..8f65a91b4 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -1,10 +1,10 @@ """ Matrix Multiplication ====================== -In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance. +In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs. You will specifically learn about: -- The block-level matrix multiplication operator `@` +- Block-level matrix multiplications - Multi-dimensional pointer arithmetic - Program re-ordering for improved L2 cache hit rate - Automatic performance tuning @@ -15,7 +15,7 @@ You will specifically learn about: # ------------- # Matrix multiplications are a key building block of most modern high-performance computing systems. # They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). -# Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.). +# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.). # For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend. # # Roughly speaking, the kernel that we will write will implement the following blocked algorithm: @@ -23,322 +23,212 @@ You will specifically learn about: # .. code-block:: python # # # do in parallel -# for m in range(0, M, MB): +# for m in range(0, M, BLOCK_M): # # do in parallel -# for n in range(0, N, NB): -# acc = zeros((MB, NB), dtype=float32) -# for k in range(0, K, KB): -# acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB] -# C[m : m+MB, n : n+NB] = acc; +# for n in range(0, N, BLOCK_N): +# acc = zeros((BLOCK_M, BLOCK_N), dtype=float32) +# for k in range(0, K, BLOCK_K): +# a = A[m : m+BLOCK_M, k : k+BLOCK_K] +# b = B[k : k+BLOCK_K, n : n+BLOCK_N] +# acc += dot(a, b) +# C[m : m+BLOCK_M, n : n+BLOCK_N] = acc; # -# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance. +# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance. # %% # Compute Kernel # ---------------- # -# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication. -# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop. +# The above algorithm is actually fairly straightforward to implement in Triton. +# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop. # # Pointer Arithmetics # ~~~~~~~~~~~~~~~~~~~~ # -# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`. -# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as: +# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`. +# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as: # # .. code-block:: python # -# &A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :]; -# &B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :]; +# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]; +# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]; # # Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as: # -# .. code-block:: C +# .. code-block:: python # :force: -# -# int rm[MB] = program_id_m * MB + 0 ... MB; -# int rn[NB] = program_id_n * NB + 0 ... NB; -# int rk[KB] = 0 ... KB; -# TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1); -# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1); +# pid_m = triton.program_id(0) +# pid_n = triton.program_id(1) +# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) +# rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) +# rk = triton.arange(0, BLOCK_K) +# // pointer for A operand +# pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1); +# // pointer for B operand +# pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1); # # These pointers can then be updated in the inner loop as: # -# .. code-block:: C +# .. code-block:: python # -# pa += KB * 1; -# pb += KB * ldb; +# pa += BLOCK_K * stride_a_1; +# pb += BLOCK_K * stride_b_0; # # # L2 Cache Optimizations # ~~~~~~~~~~~~~~~~~~~~~~~~ # -# As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`. +# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`. # However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program. # This means that a naive row-major ordering: # -# .. code-block:: C +# .. code-block:: Python # -# int program_id = get_program_id(0); -# int grid_m = (M + MB - 1) / MB; -# int grid_n = (N + NB - 1) / NB; -# int program_id_m = program_id / grid_n; -# int program_id_n = program_id % grid_n; +# pid = triton.program_id(0); +# grid_m = (M + BLOCK_M - 1) / BLOCK_M; +# grid_n = (N + BLOCK_N - 1) / BLOCK_N; +# pid_m = pid / grid_n; +# pid_n = pid % grid_n; # # is unlikely to result in optimal performance. # # One possible solution is to launch blocks in an order that promotes data reuse. -# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column: +# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column: # # .. code-block:: C # -# int program_id = get_program_id(0); -# int width = GROUP_SIZE * grid_n; -# int group_id = pid / width; -# // we need to handle the case where M % (GROUP_SIZE*BM) != 0 -# int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE); -# int pid_m = group_id * GROUP_SIZE + (pid % group_size); -# int pid_n = (pid % width) / (group_size); +# pid = triton.program_id(0); +# width = GROUP_M * grid_n; +# group_id = pid / width; +# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0 +# group_size = min(grid_m - group_id * GROUP_M, GROUP_M); +# pid_m = group_id * GROUP_M + (pid % group_size); +# pid_n = (pid % width) / (group_size); # # In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). # -# Final Result -# ~~~~~~~~~~~~~~ -# -# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication. -# Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure. -# This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler. -# -# .. code-block:: C -# :force: -# -# #define MAX_GROUP_SIZE 8 -# -# __global__ void dot(TYPE* A, TYPE* B, TYPE* C, -# int M, int N, int K, -# int stride_a_0, int stride_b_0, int stride_c_0) { -# // prologue -# int pid = get_program_id(0); -# int grid_m = (M + MB - 1) / MB; -# int grid_n = (N + NB - 1) / NB; -# // re-order program ID for better L2 performance -# int width = MAX_GROUP_SIZE * grid_n; -# int group_id = pid / width; -# int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE); -# int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size); -# int pid_n = (pid % width) / (group_size); -# // pointers to operands -# // note the parentheses here; they force the offset -# // computation to happen in typeof(stride_a_0) = int32 rather than -# // typeof(A) = int64 -# int rm[MB] = pid_m * MB + 0 ... MB; -# int rn[NB] = pid_n * NB + 0 ... NB; -# int rk[KB] = 0 ... KB; -# TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0); -# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1); -# // reduction loop -# float acc[MB, NB] = 0; -# for (int k = K; k > 0; k -= KB) { -# acc += (*pa) @ (*pb); -# pa += KB * 1; -# pb += KB * stride_b_0; -# } -# // pointers to output -# // here we rematerialize `rm` and `rn` so that they are not live through -# // the above reduction loop. In the future, the compiler should be able to -# // do this automatically. -# rm = pid_m * MB + 0 ... MB; -# rn = pid_n * NB + 0 ... NB; -# TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]); -# // we write back using *?() operator. `acc` gets casted to `float32` implicitly. -# *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc; -# } -# -# Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code. -# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial. -# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16. -# # %% -# Torch Bindings -# ---------------- +# Final Result +# ------------- # -# Auto-Tuning -# ~~~~~~~~~~~~~~ -# -# In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows: import torch import triton -autotune_configs = [ - triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4), - triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4), - triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4), - triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4), - triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4), - triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4), - triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2), - triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2) -] +# % +# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try +# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs + + +@triton.jit +def sigmoid(x): + ret_true = 1 / (1 + triton.exp(-x)) + ret_false = triton.exp(x) / (1 + triton.exp(x)) + return triton.where(x >= 0, ret_true, ret_false) + + +@triton.jit +def swish(x): + return x * sigmoid(x) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4), + ], + key=['M', 'N', 'K'], +) +# % +# We can now define our kernel as normal, using all the techniques presented above +@triton.jit +def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META): + # extract meta-parameters + BLOCK_M = META['BLOCK_M'] + BLOCK_N = META['BLOCK_N'] + BLOCK_K = META['BLOCK_K'] + GROUP_M = 8 + # matrix multiplication + pid = triton.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) + rk = triton.arange(0, BLOCK_K) + A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) + acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32) + for k in range(K, 0, -BLOCK_K): + a = triton.load(A) + b = triton.load(B) + acc += triton.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + # triton can accept arbitrary activation function + # via metaparameters! + if META['ACTIVATION']: + acc = META['ACTIVATION'](acc) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm[:, None] < M) & (rn[None, :] < N) + triton.store(C, acc, mask=mask) + # %% -# we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in. -# Here, we want to re-tune our kernel only when the shape of input matrices changes. - -autotune_key = ["M", "N", "K"] - -# %% -# We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class. - -src = """ -#define MAX_GROUP_SIZE 8 - -__global__ void dot(TYPE* A, TYPE* B, TYPE* C, - int M, int N, int K, - int lda, int ldb, int ldc) { - int pid = get_program_id(0); - int grid_m = (M + MB - 1) / MB; - int grid_n = (N + NB - 1) / NB; - int width = MAX_GROUP_SIZE * grid_n; - int group_id = pid / width; - int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE); - int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size); - int pid_n = (pid % width) / (group_size); - int rm[MB] = pid_m * MB + 0 ... MB; - int rn[NB] = pid_n * NB + 0 ... NB; - int rk[KB] = 0 ... KB; - TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda); - TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1); - float acc[MB, NB] = 0; - for (int k = K; k > 0; k -= KB) { - acc += (*pa) @ (*pb); - pa += KB * 1; - pb += KB * ldb; - } - rm = pid_m * MB + 0 ... MB; - rn = pid_n * NB + 0 ... NB; - TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]); - *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc; -} -""" +# We can also create a convenience wrapper function that only takes two input tensors +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel -def make_kernel(device, dtype): - key = (device, dtype) - cache = make_kernel.cache - if key not in cache: - defines = {'TYPE': dtype} - cache[key] = triton.kernel( - src, - device=device, - defines=defines, - autotune_configs=autotune_configs, - autotune_key=autotune_key, - ) - return cache[key] +def matmul(a, b, activation=None): + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + assert a.is_contiguous(), "matrix A must be contiguous" + assert b.is_contiguous(), "matrix B must be contiguous" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) + _matmul[grid]( + a, b, c, M, N, K, \ + a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\ + ACTIVATION = activation + ) + # return output + return c -make_kernel.cache = dict() - -# %% -# Autograd Function -# ~~~~~~~~~~~~~~~~~~ -# -# Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`. -# To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output. - - -class _dot(torch.autograd.Function): - @staticmethod - def forward(ctx, a, b): - M, Ka = a.shape - Kb, N = b.shape - assert Ka == Kb, "incompatible dimensions" - assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous" - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - kernel = make_kernel(a.device, a.dtype) - grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), ) - kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \ - M, N, Ka, \ - a.stride(0), b.stride(0), c.stride(0), \ - grid=grid) - return c - - -dot = _dot.apply - # %% # Unit Test # ----------- # -# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`). -# Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors. +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel) -a = torch.rand((512, 768), device='cuda', dtype=torch.float16) -b = torch.rand((768, 896), device='cuda', dtype=torch.float16) -c_0 = dot(a, b) -c_1 = torch.matmul(a, b) +#torch.manual_seed(0) +a = torch.randn((512, 512), device='cuda', dtype=torch.float16) +b = torch.randn((512, 512), device='cuda', dtype=torch.float16) +c_0 = matmul(a, b, activation=swish) +c_1 = torch.nn.SiLU()(torch.matmul(a, b)) print(c_0) print(c_1) -print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3)) +print(triton.testing.allclose(c_0, c_1)) # %% # Benchmark # -------------- # -# Installing The CUTLASS Bindings -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# -# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools. -# For this reason, we will instead compare the performance of our kernel against `CUTLASS `_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._ -# To install CUTLASS, you need a recent version of cmake: -# -# .. code-block:: bash -# -# cd /path/to/cutlass/ -# git clone https://github.com/NVIDIA/cutlass.git -# cd cutlass -# mkdir build -# cd build -# wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz -# tar xzvf *.tar.gz -# -# You can then install CUTLASS as follows for V100 -# -# .. code-block:: bash -# -# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8 -# make -j8 install -# -# Or as follows for A100: -# -# .. code-block:: bash -# -# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8 -# make -j8 install -# -# Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance. -# Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process. -# To re-install Triton with the updated CUTLASS bindings, run the following command: -# -# .. code-block:: bash -# -# export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/ -# export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/ -# pip uninstall -y triton -# pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python" -# -# Which we can test as follows: - -import triton -c_2 = triton.testing.cutlass_matmul(a, b) -print(c_2) -print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3)) - -# %% -# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready. -# # Square Matrix Performance # ~~~~~~~~~~~~~~~~~~~~~~~~~~ # We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.# @@ -347,29 +237,25 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3)) @triton.testing.perf_report( triton.testing.Benchmark( x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot - x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name` + x_vals=[8192], # different possible values for `x_name` y_name='provider', # argument name whose value corresponds to a different line in the plot - y_vals=['cublas', 'triton', 'cutlass'], # possible keys for `y_name` - y_lines=["cuBLAS", "Triton", 'CUTLASS'], # label name for the lines + y_vals=['cublas', 'triton'], # possible keys for `y_name` + y_lines=["cuBLAS", "Triton"], # label name for the lines ylabel="TFLOPS", # label name for the y-axis plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot. args={} ) ) def benchmark(M, N, K, provider): + silu = torch.nn.SiLU() a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) if provider == 'cublas': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b)) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b)) - if provider == 'cutlass': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b)) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) -benchmark.run(show_plots=True) - -# %% -# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write. \ No newline at end of file +benchmark.run(print_data=True) \ No newline at end of file