Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -22,7 +22,7 @@ endif()
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++14")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++17")

View File

@@ -35,6 +35,13 @@ extensions = []
# Math Jax
extensions += ['sphinx.ext.mathjax']
# Auto Doc
import sys
import os
sys.path.insert(0, os.path.abspath('../python/'))
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
autosummary_generate = True
# Sphinx gallery
extensions += ['sphinx_gallery.gen_gallery']
from sphinx_gallery.sorting import FileNameSortKey

View File

@@ -17,15 +17,27 @@ Getting Started
getting-started/installation
getting-started/tutorials/index
Programming Guide
Language Reference
-------------------
- Checkout the :doc:`Python API Documentation <language-reference/python-api/index>`
.. toctree::
:maxdepth: 1
:caption: Language Reference
:hidden:
language-reference/python-api/index
Going Further
------------------
Check out the following documents to learn more about Triton and how it compares against other DSLs for DNNs:
- Chapter 1: :doc:`Introduction <programming-guide/chapter-1/introduction>`
- Chapter 2: :doc:`Related Work <programming-guide/chapter-2/related-work>`
- Chapter 3: :doc:`The Triton-C Language <programming-guide/chapter-3/triton-c>`
- Chapter 4: :doc:`The Triton-IR Intermediate Representation <programming-guide/chapter-4/triton-ir>`
.. toctree::
:maxdepth: 1
@@ -34,5 +46,3 @@ Check out the following documents to learn more about Triton and how it compares
programming-guide/chapter-1/introduction
programming-guide/chapter-2/related-work
programming-guide/chapter-3/triton-c
programming-guide/chapter-4/triton-ir

View File

@@ -0,0 +1,117 @@
Python API
===========
.. currentmodule:: triton
Programming Model
-------------------
.. autosummary::
:toctree: generated
:nosignatures:
program_id
num_programs
Creation Ops
-------------
.. autosummary::
:toctree: generated
:nosignatures:
arange
zeros
Shape Manipulation Ops
-----------------------
.. autosummary::
:toctree: generated
:nosignatures:
broadcast_to
reshape
ravel
Linear Algebra Ops
-------------------
.. autosummary::
:toctree: generated
:nosignatures:
dot
Memory Ops
--------------------
.. autosummary::
:toctree: generated
:nosignatures:
load
store
atomic_cas
atomic_xchg
Indexing Ops
--------------
.. autosummary::
:toctree: generated
:nosignatures:
where
Math Ops
----------
.. autosummary::
:toctree: generated
:nosignatures:
exp
log
sigmoid
softmax
Reduction Ops
---------------
.. autosummary::
:toctree: generated
:nosignatures:
max
min
sum
Comparison ops
---------------
.. autosummary::
:toctree: generated
:nosignatures:
minimum
maximum
Compiler Hint Ops
-------------------
.. autosummary::
:toctree: generated
:nosignatures:
multiple_of

View File

@@ -1,84 +0,0 @@
=======================
The Triton-C Language
=======================
In the introduction, we stressed the importance of blocked algorithms and described their core principles in pseudo-code. To facilitate their implementation on modern GPU hardware, we present Triton-C, a single-threaded imperative kernel language in which block variables are first-class citizen. This language may be used either directly by developers familiar with C, or as an intermediate language for existing (and future) transcompilers. In this chapter, we describe its differences with C, its Numpy-like semantics and its "Single-Program, Multiple-Data" (SPMD) programming model.
-------------------
Differences with C
-------------------
The syntax of Triton-C is based on that of ANSI C, but was modified and extended to accomodate the semantics and programming model described in the next two subsections. These changes fall into the following categories:
+++++++++++
Extensions
+++++++++++
**Variable declarations**: Triton adds special-purpose syntax for multi-dimensional array declarations (e.g., :code:`int block[16, 16]`), which purposely differs from that of nested arrays (i.e., arrays of pointers) found in ANSI C (e.g., :code:`int block[16][16]`). Block dimensions must be constant but can also be made parametric with the use of pre-processor macros. One-dimensional blocks of integers may be initialized using ellipses (e.g., :code:`int range[16] = 0 ... 16`).
**Primitive types**: Triton-C supports the following primitive data-types: :code:`bool`, :code:`uint8`, :code:`uint16`, :code:`uint32`, :code:`uint64`, :code:`int8`, :code:`int16`, :code:`int32`, :code:`int64`, :code:`half`, :code:`float`, :code:`double`.
**Operators and built-in function**: The usual C operators were extended to support element-wise array operations (:code:`+`, :code:`-`, :code:`&&`, :code:`*`, etc.) and complex array operations(:code:`@` for matrix multiplication). Additionally, some built-in functions were added for concurrency (:code:`get_program_id`, :code:`atomic_add`).
**Slicing and broadcasting**: Multi-dimensional blocks can be broadcast along any particular dimension using numpy-like slicing syntax (e.g., :code:`int array[8, 8] = range[:, newaxis]` for stacking columns). Note that, as of now, slicing blocks to retrieve sub-blocks (or scalars) is forbidden as it is incompatible with the automatic parallelization methods used by our JIT. Reductions can be achieved using a syntax similar to slicing (e.g., :code:`array[+]` for summing an array, or :code:`array[:, max]` for row-wise maximum). Currently supported reduction operators are :code:`+`, :code:`min`, :code:`max`.
**Masked pointer dereferencement**: Block-level operations in Triton-C are "atomic", in the sense that they execute either completely or not at all. Basic element-wise control-flow for block-level operations can nonetheless be achieved using ternary operators and the *masked pointer dereferencement* operator exemplified below:
.. code-block:: C
:force:
// create mask
bool mask[16, 16] = ...;
// conditional addition
float x[16, 16] = mask ? a + b : 0;
// conditional load
float y[16] 16] = mask ? *ptr : 0;
// conditional store
*?(mask)ptr = y;
\end{lstlisting}
+++++++++++++
Restrictions
+++++++++++++
The Triton project is still in its infancy. As such, there are quite a few features of ANSI C that are not supported:
**Non-kernel functions**: Right now, all function definitions must be kernels, i.e. be preceded with the :code:`__global__` attribute. We are aware that this is a severe limitations, and the reason why it exists is because our automatic parallelization engine would not be capable of handling array parameter arguments.
**Non-primitive types**: Non-primitive types defined with :code:`struct` and :code:`union` are currently not supported, again because it is unclear at this point how these constructs would hook into our block-level data-flow analysis passes.
**While loops**: We just haven't had time to implement those yet.
----------------
Semantics
----------------
The existence of built-in **blocked** types, variable and operations in Triton-C offers two main benefits. First, it simplifies the structure of blocked programs by hiding important details pertaining to concurrent programming such as memory coalescing, cache management and specialized tensor instrinsics. Second, it opens the door for compilers to perform these optimizations automatically. However, it also means that programs have some kind of *block-level semantics* that does not exist in C. Though some aspects of it (e.g., the :code:`@` operator) are pretty intuitive, one in particular might be puzzling to some GPU programmers: broadcasting semantics.
+++++++++++++++++++++++
Broadcasting Semantics
+++++++++++++++++++++++
Block variables in Triton are strongly typed, meaning that certain instructions statically require their operands to satisfy strict shape constraints. For example, a scalar may not be added to an array unless it is first appropriately broadcast. *Broadcasting semantics* (first introduced in `Numpy <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_) provides two formal rules for performing these conversions automatically in the case of binary operators: (1) the shape of the lowest-dimension operand is left-padded with ones until both operands have the same dimensionality; and (2) the content of both operands is replicated as many times as needed until their shape is identical. An error is emitted if this cannot be done.
.. code-block:: C
int a[16], b[32, 16], c[16, 1];
// a is first reshaped to [1, 16]
// and then broadcast to [32, 16]
int x_1[32, 16] = a[newaxis, :] + b;
// Same as above but implicitly
int x_2[32, 16] = a + b;
// a is first reshaped to [1, 16]
// a is broadcast to [16, 16]
// c is broadcast to [16, 16]
int y[16, 16] = a + c;
------------------
Programming Model
------------------
As discussed in the `CUDA documentation <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_, The execution of CUDA code on GPUs is supported by an `SPMD <https://en.wikipedia.org/wiki/SPMD>`_ programming model in which each kernel instance is associated with an identifiable *thread-block*, itself decomposed into *warps* of 32 *threads*. The Triton programming model is similar, but each kernel is *single-threaded* -- though automatically parallelized -- and associated with a global :code:`program id` which varies from instance to instance. This approach leads to simpler kernels in which CUDA-like concurrency primitives (shared memory synchronization, inter-thread communication, etc.) do not exist. The global program ids associated with each kernel instance can be queried using the :code:`get_program_id(axis)` built-in function where :code:`0 <= axis <= 2`. This is, for example, useful to create e.g., blocks of pointers as shown in the tutorials.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.6 KiB

View File

@@ -1,82 +0,0 @@
==========================================
The Triton-IR Intermediate Representation
==========================================
Triton-IR is an LLVM-based Intermediate Representation (IR) whose purpose is to provide an environment suitable for block-level program analysis, transformation and optimization.
In our implementation, Triton-IR programs are constructed directly from Triton-C after parsing, but they could also be formed directly by higher-level DSLs in the future.
Triton-IR and LLVM-IR programs share the same high-level structure, but the former also includes a number of extensions necessary for block-level data-flow analysis.
These extensions are crucial for carrying out the optimizations outlined in the next chapter of this document.
---------------------------------
Structure of a Triton-IR Program
---------------------------------
++++++++
Modules
++++++++
At the highest level, Triton-IR programs consist of one or multiple basic units of compilation known as *modules*. These modules are compiled independently from one another, and eventually aggregated by a linker whose role is to resolve forward declarations and adequately merge global definitions. Each module itself is composed of functions, global variables, constants and other miscellaneous symbols such as metadata and attributes.
++++++++++
Functions
++++++++++
Triton-IR function definitions consist of a return type, a name and a potentially empty arguments list. Additional visibility, alignment and linkage specifiers can be added if desired. Function attributes (such as inlining hints) and parameter attributes (such as "readonly", aliasing hints) can also be specified, allowing compiler backends to perform more aggressive optimizations by, for instance, making better use of non-coherent caches found on NVIDIA GPUs. This header is followed by a body composed of a list of basic blocks whose interdependencies form the Control Flow Graph (CFG) of the function.
+++++++++++++
Basic Blocks
+++++++++++++
Basic blocks are straight-line code sequences that may only contain so-called *terminator* instructions (i.e., branching, return) at their end. To simplify program analysis, Triton-IR uses the Static Single Assignment (SSA) form, meaning that each variable in each basic block must be (1) assigned to only once and (2) defined before being used. In so doing, each basic block implicitly defines a Data-Flow Graph (DFG). In our case, the SSA form is created directly from Triton-C's Abstract Syntax Trees (ASTs) using an algorithm from the literature [BRAUN13]_.
---------------------------------
Block-Level Dataflow Analysis
---------------------------------
+++++++
Types
+++++++
Multi-dimensional blocks are at the center of data-flow analysis in Triton-JIT. They can be declared using syntax similar to vector declarations in LLVM-IR. For example, :code:`i32<8, 8>` is the type corresponding to :math:`8 \times 8` blocks of 32-bit integers. Note that there is no preprocessor in Triton-IR, hence parametric shape values must be resolved before programs are generated. In our case, this is done by Triton-JIT's auto-tuner.
+++++++++++++
Instructions
+++++++++++++
Triton-IR introduces a set of *reblocking* instructions whose purpose is to support broadcasting semantics as described in the previous chapter. The :code:`reshape` instruction creates a block of the specified shape using the raw data from its input argument. This is particularly useful to re-interpret variables as higher-dimensional arrays by padding their input shapes with ones in preparation for broadcasting. The :code:`broadcast` instruction creates a block of the specified shapes by replicating its input argument as many times as necessary along dimensions of size 1 -- as shown below for the :code:`broadcast<3,3>` instruction.
|pic1| and |pic2|
.. |pic1| image:: broadcast-1.png
:width: 40%
.. |pic2| image:: broadcast-2.png
:width: 40%
Usual scalar instructions (:code:`cmp`, :code:`getelementptr`, :code:`add`, :code:`load`...) were preserved and extended to signify element-wise operations when applicable. Finally, Triton-IR also exposes specialized arithmetic instructions for reductions (:code:`reduce`) and matrix multiplications (:code:`dot`).
----------------------------------
Block-Level Control Flow Analysis
----------------------------------
In Triton-IR, operations on block variables are atomic: they execute either in full or not at all. As a result, traditional control flow structures (e.g., conditional, loops) are not applicable to individual block elements. This is problematic, since a program may need to e.g., partially guard blocked loads against memory access violations.
This could be potentially solved through the use of the Predicated SSA (PSSA) [CARTER99]_ [STOUTCHININ01]_ form for Triton-IR. However, this would create a lot of unnecessary complexity for GPUs, where the benefits of PSSA are close to none as divergent program paths within warps are serialized anyway. Therefore, recent versions of Triton handle intra-block control flow in a much simpler way, using conditional instructions such as :code:`select`, :code:`masked_load` and :code:`masked_store`:
.. code-block:: C
// For all indices [idx], return cond[idx] ? true_value[idx] : false_value[idx];
select TYPE<TS1, ..., TSN> cond, true_value, false_value;
// For all indices [idx], return cond[idx] ? *true_addr[idx] : false_value[idx];
masked_load TYPE<TS1, ..., TSN> cond, true_addr, false_value;
// For all indices [idx], execute *true_addr[idx] = true_value[idx] if cond[idx]
masked_store TYPE<TS1, ..., TSN> cond, true_addr, true_value;
------------
References
------------
.. [BRAUN13] M. Braun et al., "Simple and Efficient Construction of Static Single Assignment Form", CC 2013
.. [CARTER99] L. Carter et al., "Predicated Static Single Assignment", PACT 1999
.. [STOUTCHININ01] A. Stoutchinin et al., "Efficient Static Single Assignment Form for Predication", MICRO 2001

View File

@@ -1,30 +1,31 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_
#include <list>
#include <memory>
namespace triton{
namespace ir{
class module;
}
namespace driver{
class device;
class module;
class kernel;
}
}
namespace triton{
namespace codegen{
class pass {
public:
virtual void run(ir::module& m);
};
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps,
driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
class pass_manager {
public:
void add(pass* p);
void run(ir::module& m);
private:
std::list<pass*> passes;
};
}
}
#endif

View File

@@ -119,7 +119,7 @@ public:
void visit_exp_inst(ir::exp_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_program_inst(ir::get_num_program_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
void visit_atomic_add_inst(ir::atomic_add_inst*);

View File

@@ -59,7 +59,7 @@ public:
// CUDA
class cu_module: public module {
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
void init_from_ptx(const std::string& ptx);
public:

View File

@@ -1,4 +1,4 @@
#pragma once
#pragma once
#ifndef _TRITON_IR_BUILDER_H_
#define _TRITON_IR_BUILDER_H_
@@ -27,6 +27,8 @@ class builder{
public:
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
@@ -38,6 +40,9 @@ public:
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();
type *get_int1_ty();
@@ -50,11 +55,10 @@ public:
type *get_double_ty();
// Insert
template<typename InstTy>
InstTy* insert(InstTy *inst, const std::string &name = ""){
InstTy* insert(InstTy *inst){
assert(block_);
block_->get_inst_list().insert(insert_point_, inst);
inst->set_parent(block_);
inst->set_name(name);
// for(ir::value* op: inst->ops())
// op->add_use(inst);
return inst;
@@ -64,91 +68,87 @@ public:
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Cast instructions
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
value* create_ptr_to_int(value *src, type *dst_ty, const std::string &name = "");
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_to_ui(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_ext(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = "");
value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = "");
value *create_downcast(value *arg, const std::string &name = "");
value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty);
value* create_fp_to_si(value *src, type *dst_ty);
value* create_fp_to_ui(value *src, type *dst_ty);
value* create_fp_ext(value *src, type *dst_ty);
value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg);
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs, const std::string &name = "");
value *create_fdiv(value *lhs, value *rhs, const std::string &name = "");
value *create_frem(value *lhs, value *rhs, const std::string &name = "");
value *create_fadd(value *lhs, value *rhs, const std::string &name = "");
value *create_fsub(value *lhs, value *rhs, const std::string &name = "");
value *create_mul(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs, const std::string &name = "");
value *create_udiv(value *lhs, value *rhs, const std::string &name = "");
value *create_srem(value *lhs, value *rhs, const std::string &name = "");
value *create_urem(value *lhs, value *rhs, const std::string &name = "");
value *create_add(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs);
value *create_fdiv(value *lhs, value *rhs);
value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
// GEP
value *create_gep(value *ptr, const std::vector<value*>& idx_list, const std::string &name = "");
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
// Comparison (int)
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSGE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSGT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpULE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpULT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpUGE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpUGT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpEQ(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpNE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_icmpSLE(value *lhs, value *rhs);
value *create_icmpSLT(value *lhs, value *rhs);
value *create_icmpSGE(value *lhs, value *rhs);
value *create_icmpSGT(value *lhs, value *rhs);
value *create_icmpULE(value *lhs, value *rhs);
value *create_icmpULT(value *lhs, value *rhs);
value *create_icmpUGE(value *lhs, value *rhs);
value *create_icmpUGT(value *lhs, value *rhs);
value *create_icmpEQ(value *lhs, value *rhs);
value *create_icmpNE(value *lhs, value *rhs);
// Comparison (float)
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOGT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOGE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOEQ(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpONE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_fcmpOLT(value *lhs, value *rhs);
value *create_fcmpOGT(value *lhs, value *rhs);
value *create_fcmpOLE(value *lhs, value *rhs);
value *create_fcmpOGE(value *lhs, value *rhs);
value *create_fcmpOEQ(value *lhs, value *rhs);
value *create_fcmpONE(value *lhs, value *rhs);
// Logical
value *create_and(value *lhs, value *rhs, const std::string &name = "");
value *create_xor(value *lhs, value *rhs, const std::string &name = "");
value *create_or(value *lhs, value *rhs, const std::string &name = "");
// Unary
// value *create_fneg(value *arg, const std::string &name = "");
// value *create_neg(value *arg, const std::string &name = "");
// value *create_not(value *arg, const std::string &name = "");
value *create_and(value *lhs, value *rhs);
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, const std::string &name = "");
value *create_store(value *ptr, value *val, const std::string &name = "");
value *create_masked_load(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_masked_store(value *ptr, value *val, value *mask, const std::string &name = "");
// Tile instruction
value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_load(value *arg);
value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value);
value *create_masked_store(value *ptr, value *val, value *mask);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Built-in instruction
value *create_get_program_id(unsigned axis, const std::string &name = "");
value *create_get_num_program(unsigned axis, const std::string &name = "");
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
value *create_exp(value* arg, const std::string &name = "");
value *create_log(value* arg, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
value *create_sqrt(value *A, const std::string &name = "");
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_exch(value *ptr, value *val);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);
@@ -158,6 +158,7 @@ private:
iterator insert_point_;
};
}
}

View File

@@ -9,6 +9,7 @@
namespace triton{
namespace ir{
class builder;
class type;
class context_impl;
@@ -16,8 +17,11 @@ class context_impl;
class context {
public:
context();
context(const context&) = delete;
context& operator=(const context&) = delete;
public:
ir::builder* builder = nullptr;
std::shared_ptr<context_impl> p_impl;
};

View File

@@ -28,13 +28,16 @@ public:
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
std::map<std::pair<type*, type::tile_shapes_t>, tile_type*> tile_tys;
// Block types
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
// Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
// undef values
std::map<type*, undef_value*> uv_constants_;
};
}

View File

@@ -0,0 +1,97 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
// math
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -35,7 +35,7 @@ private:
/* Attribute */
enum attribute_kind_t {
readonly,
readonly = 0,
writeonly,
noalias,
aligned,
@@ -71,7 +71,7 @@ public:
case writeonly: return ".writeonly";
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".readonly";
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
case retune: return ".retunr";
default: break;
}
@@ -102,7 +102,7 @@ private:
public:
// accessors
const args_t &args() { return args_; }
const args_t &args() const { return args_; }
function_type* get_fn_type() { return fn_ty_; }
// factory methods

View File

@@ -514,7 +514,7 @@ public:
class retile_inst: public unary_inst {
protected:
retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
};
// reshape
@@ -525,7 +525,7 @@ private:
std::string repr_impl() const { return "reshape"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(reshape_inst)
_TRITON_DEFINE_ACCEPT(reshape_inst)
@@ -539,7 +539,7 @@ private:
std::string repr_impl() const { return "splat"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(splat_inst)
_TRITON_DEFINE_ACCEPT(splat_inst)
@@ -553,7 +553,7 @@ private:
std::string repr_impl() const { return "broadcast"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(broadcast_inst)
_TRITON_DEFINE_ACCEPT(broadcast_inst)
@@ -597,16 +597,16 @@ private:
unsigned axis_;
};
class get_num_program_inst: public builtin_inst {
class get_num_programs_inst: public builtin_inst {
private:
get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; }
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_num_program_inst)
_TRITON_DEFINE_ACCEPT(get_num_program_inst)
_TRITON_DEFINE_CLONE(get_num_programs_inst)
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
private:
unsigned axis_;

View File

@@ -36,6 +36,11 @@ class alloc_const;
/* Module */
struct scope {
public:
const std::map<std::string, ir::value*>& get_values() { return values; }
void set_type(const std::string& name, ir::type* ty) { types[name] = ty; }
ir::type* get_type(const std::string& name) { return types.at(name); }
private:
std::map<std::string, ir::type*> types;
std::map<std::string, ir::value*> values;
};
@@ -61,8 +66,7 @@ private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name);
context& get_context();
module(const std::string &name, builder& builder);
builder& get_builder();
// Setters
void set_value(const std::string& name, basic_block* block, value *x);
@@ -95,8 +99,7 @@ public:
private:
std::string name_;
context context_;
builder builder_;
builder& builder_;
std::map<val_key_t, value*> values_;
std::map<val_key_t, type*> types_;
std::set<std::string> const_;

View File

@@ -18,7 +18,7 @@ class constant_int;
/* Type */
class type {
public:
typedef std::vector<unsigned> tile_shapes_t;
typedef std::vector<unsigned> block_shapes_t;
protected:
typedef std::vector<type*> contained_tys_vec_t;
@@ -43,7 +43,7 @@ public:
FunctionTyID, ///< 11: Functions
PointerTyID, ///< 12: Pointers
StructTyID, ///< 13: Struct
TileTyID, ///< 14: Tile
BlockTyID, ///< 14: Tile
};
public:
@@ -62,7 +62,7 @@ public:
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
const tile_shapes_t& get_tile_shapes() const;
block_shapes_t get_block_shapes() const;
const size_t get_tile_rank() const;
const size_t get_tile_ranks1() const;
unsigned get_tile_num_elements() const;
@@ -83,7 +83,7 @@ public:
get_integer_bitwidth() == bitwidth;}
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_tile_ty() const { return id_ == TileTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
@@ -110,7 +110,7 @@ public:
// repr
std::string tile_repr() const {
std::string res = get_tile_element_ty()->repr();
auto shapes = get_tile_shapes();
auto shapes = get_block_shapes();
res += "<";
for(size_t i = 0; i < shapes.size(); i++){
if(i > 0)
@@ -137,7 +137,7 @@ public:
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
case TileTyID: return tile_repr();
case BlockTyID: return tile_repr();
default: break;
}
assert(false);
@@ -180,23 +180,23 @@ public:
type* get_type_at_index(value *idx) const;
};
class tile_type: public composite_type {
class block_type: public composite_type {
private:
tile_type(type *ty, const tile_shapes_t &shapes);
block_type(type *ty, const block_shapes_t &shapes);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
const tile_shapes_t& get_shapes() const { return shapes_; }
const block_shapes_t& get_shapes() const { return shapes_; }
unsigned get_num_elements() const;
unsigned get_bitwidth() const;
// factory methods
static tile_type* get(type *ty, const tile_shapes_t &shapes);
static tile_type* get_same_shapes(type *ty, type *ref);
static block_type* get(type *ty, const block_shapes_t &shapes);
static block_type* get_same_shapes(type *ty, type *ref);
private:
tile_shapes_t shapes_;
block_shapes_t shapes_;
};
class pointer_type: public type {

View File

@@ -52,7 +52,7 @@ class exp_inst;
class log_inst;
class get_program_id_inst;
class get_num_program_inst;
class get_num_programs_inst;
class atomic_cas_inst;
class atomic_exch_inst;
class atomic_add_inst;
@@ -128,7 +128,7 @@ public:
virtual void visit_downcast_inst(downcast_inst*) = 0;
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
virtual void visit_get_num_program_inst(get_num_program_inst*) = 0;
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;

View File

@@ -1,823 +0,0 @@
#pragma once
#ifndef _WGTCC_AST_H_
#define _WGTCC_AST_H_
#include "error.h"
#include "token.h"
#include "type.h"
#include <cassert>
#include <list>
#include <memory>
#include <string>
class Visitor;
template<typename T> class Evaluator;
class AddrEvaluator;
class Generator;
class Scope;
class Parser;
class ASTNode;
class Token;
class TokenSequence;
// Expressions
class Expr;
class BinaryOp;
class UnaryOp;
class ConditionalOp;
class FuncCall;
class TempVar;
class Constant;
class Identifier;
class Object;
struct Initializer;
class Declaration;
class Enumerator;
// Statements
class Stmt;
class IfStmt;
class ForStmt;
class JumpStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
/*
* AST Node
*/
class ASTNode {
public:
struct Attr{
enum KindT{
MULTIPLEOF,
ALIGNED,
NOALIAS,
READONLY,
WRITEONLY,
RETUNE,
};
KindT kind;
std::vector<Expr*> vals;
};
using AttrList = std::vector<Attr>;
public:
virtual ~ASTNode() {}
virtual void Accept(Visitor* v) = 0;
protected:
ASTNode() {}
MemPool* pool_ {nullptr};
};
using ExtDecl = ASTNode;
/*
* Statements
*/
class Stmt : public ASTNode {
public:
virtual ~Stmt() {}
protected:
Stmt() {}
};
class EmptyStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static EmptyStmt* New();
virtual ~EmptyStmt() {}
virtual void Accept(Visitor* v);
protected:
EmptyStmt() {}
};
class LabelStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static LabelStmt* New();
~LabelStmt() {}
virtual void Accept(Visitor* v);
std::string Repr() const { return ".L" + std::to_string(tag_); }
protected:
LabelStmt(): tag_(GenTag()) {}
private:
static int GenTag() {
static int tag = 0;
return ++tag;
}
int tag_; // 使用整型的tag值而不直接用字符串
};
class IfStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static IfStmt* New(Expr* cond, Stmt* then, Stmt* els=nullptr);
virtual ~IfStmt() {}
virtual void Accept(Visitor* v);
protected:
IfStmt(Expr* cond, Stmt* then, Stmt* els = nullptr)
: cond_(cond), then_(then), else_(els) {}
private:
Expr* cond_;
Stmt* then_;
Stmt* else_;
};
class ForStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr);
virtual ~ForStmt() {}
virtual void Accept(Visitor* v);
protected:
ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr)
: body_(body), init_(init), cond_(cond), step_(step) {}
private:
Stmt* body_;
Stmt* init_;
Expr* cond_;
Expr* step_;
};
class JumpStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static JumpStmt* New(LabelStmt* label);
virtual ~JumpStmt() {}
virtual void Accept(Visitor* v);
void SetLabel(LabelStmt* label) { label_ = label; }
protected:
JumpStmt(LabelStmt* label): label_(label) {}
private:
LabelStmt* label_;
};
class ReturnStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ReturnStmt* New(Expr* expr);
virtual ~ReturnStmt() {}
virtual void Accept(Visitor* v);
protected:
ReturnStmt(::Expr* expr): expr_(expr) {}
private:
::Expr* expr_;
};
using StmtList = std::list<Stmt*>;
class CompoundStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static CompoundStmt* New(StmtList& stmts, ::Scope* scope=nullptr);
virtual ~CompoundStmt() {}
virtual void Accept(Visitor* v);
StmtList& Stmts() { return stmts_; }
::Scope* Scope() { return scope_; }
protected:
CompoundStmt(const StmtList& stmts, ::Scope* scope=nullptr)
: stmts_(stmts), scope_(scope) {}
private:
StmtList stmts_;
::Scope* scope_;
};
struct Initializer {
Initializer(Type* type,
int offset,
Expr* expr,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0)
: type_(type),
offset_(offset),
bitFieldBegin_(bitFieldBegin),
bitFieldWidth_(bitFieldWidth),
expr_(expr) {}
bool operator<(const Initializer& rhs) const;
// It could be the object it self or, it will be the member
// that was initialized
Type* type_;
int offset_;
unsigned char bitFieldBegin_;
unsigned char bitFieldWidth_;
Expr* expr_;
};
using InitList = std::set<Initializer>;
class Declaration: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Declaration* New(Object* obj);
virtual ~Declaration() {}
virtual void Accept(Visitor* v);
InitList& Inits() { return inits_; }
Object* Obj() { return obj_; }
void AddInit(Initializer init);
protected:
Declaration(Object* obj): obj_(obj) {}
Object* obj_;
InitList inits_;
};
/*
* Expr
* BinaryOp
* UnaryOp
* ConditionalOp
* FuncCall
* Constant
* Identifier
* Object
* TempVar
*/
class Expr : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
virtual ~Expr() {}
::Type* Type() { return type_.GetPtr(); }
virtual bool IsLVal() = 0;
virtual void TypeChecking() = 0;
void EnsureCompatible(const QualType lhs, const QualType rhs) const;
void EnsureCompatibleOrVoidPointer(const QualType lhs,
const QualType rhs) const;
const Token* Tok() const { return tok_; }
void SetTok(const Token* tok) { tok_ = tok; }
static Expr* MayCast(Expr* expr);
static Expr* MayCast(Expr* expr, QualType desType);
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
virtual bool IsNullPointerConstant() const { return false; }
bool IsConstQualified() const { return type_.IsConstQualified(); }
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
bool IsVolatileQualified() const { return type_.IsVolatileQualified(); }
protected:
// You can construct a expression without specifying a type,
// then the type should be evaluated in TypeChecking()
Expr(const Token* tok, QualType type): tok_(tok), type_(type) {}
const Token* tok_;
QualType type_;
};
/*
* '+', '-', '*', '/', '%', '<', '>', '<<', '>>', '|', '&', '^'
* '=',(复合赋值运算符被拆分为两个运算)
* '==', '!=', '<=', '>=',
* '&&', '||'
* '['(下标运算符), '.'(成员运算符)
* ','(逗号运算符),
*/
class BinaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
friend class Declaration;
public:
static BinaryOp* New(const Token* tok, Expr* lhs, Expr* rhs);
static BinaryOp* New(const Token* tok, int op, Expr* lhs, Expr* rhs);
virtual ~BinaryOp() {}
virtual void Accept(Visitor* v);
// Member ref operator is a lvalue
virtual bool IsLVal() {
switch (op_) {
case '.': return !Type()->ToArray() && lhs_->IsLVal();
case ']': return !Type()->ToArray();
case Token::MASKED_DEREF: return true;
default: return false;
}
}
ArithmType* Convert();
static void Broadcast(Expr* loc, Expr*& lhs, Expr*& rhs, QualType &type);
virtual void TypeChecking();
void SubScriptingOpTypeChecking();
void MemberRefOpTypeChecking();
void MultiOpTypeChecking();
void AdditiveOpTypeChecking();
void ShiftOpTypeChecking();
void RangeOpTypeChecking();
void MatmulOpTypeChecking();
void MaskedDerefOpTypeChecking();
void RelationalOpTypeChecking();
void EqualityOpTypeChecking();
void BitwiseOpTypeChecking();
void LogicalOpTypeChecking();
void AssignOpTypeChecking();
void CommaOpTypeChecking();
protected:
BinaryOp(const Token* tok, int op, Expr* lhs, Expr* rhs)
: Expr(tok, nullptr), op_(op) {
lhs_ = lhs, rhs_ = rhs;
if (op != '.') {
lhs_ = MayCast(lhs);
rhs_ = MayCast(rhs);
}
}
int op_;
Expr* lhs_;
Expr* rhs_;
};
/*
* Unary Operator:
* '++' (prefix/postfix)
* '--' (prefix/postfix)
* '&' (ADDR)
* '*' (DEREF)
* '+' (PLUS)
* '-' (MINUS)
* '~'
* '!'
* CAST // like (int)3
*/
class UnaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
virtual ~UnaryOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal();
::Type *Convert();
static int encodeRed(int ax, int tag);
static void decodeRed(int info, int& ax, int& tag);
void TypeChecking();
void IncDecOpTypeChecking();
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void ReduceOpTypeChecking();
void UnaryArithmOpTypeChecking();
void BitcastOpTypeChecking();
void CastOpTypeChecking();
void IntrinsicOpTypeChecking();
protected:
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
: Expr(operand->Tok(), type), op_(op), info_(info) {
operand_ = operand;
if (op_ != Token::CAST && op_ != Token::ADDR) {
operand_ = MayCast(operand);
}
}
int op_;
int info_;
Expr* operand_;
};
class TransOp: public Expr {
friend class Generator;
public:
using PermInt = std::vector<int>;
public:
static TransOp* New(const PermInt& perm, Expr* operand);
const PermInt& getPerm() const { return perm_; }
void Accept(Visitor* v);
bool IsLVal() { return false; }
void TypeChecking();
protected:
TransOp(const PermInt& perm, Expr* operand)
: Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {}
private:
Expr* operand_;
PermInt perm_;
};
// cond ? true false
class ConditionalOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ConditionalOp* New(const Token* tok,
Expr* cond, Expr* exprTrue, Expr* exprFalse);
virtual ~ConditionalOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
ArithmType* Convert();
virtual void TypeChecking();
protected:
ConditionalOp(Expr* cond, Expr* exprTrue, Expr* exprFalse)
: Expr(cond->Tok(), nullptr), cond_(MayCast(cond)),
exprTrue_(MayCast(exprTrue)), exprFalse_(MayCast(exprFalse)) {}
private:
Expr* cond_;
Expr* exprTrue_;
Expr* exprFalse_;
};
class FuncCall : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
using ArgList = std::vector<Expr*>;
public:
static FuncCall* New(Expr* designator, const ArgList& args);
~FuncCall() {}
virtual void Accept(Visitor* v);
// A function call is ofcourse not lvalue
virtual bool IsLVal() { return false; }
ArgList* Args() { return &args_; }
Expr* Designator() { return designator_; }
const std::string& Name() const { return tok_->str_; }
::FuncType* FuncType() { return designator_->Type()->ToFunc(); }
virtual void TypeChecking();
protected:
FuncCall(Expr* designator, const ArgList& args)
: Expr(designator->Tok(), nullptr),
designator_(designator), args_(args) {}
Expr* designator_;
ArgList args_;
};
class Constant: public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Constant* New(const Token* tok, int tag, long val);
static Constant* New(const Token* tok, int tag, double val);
static Constant* New(const Token* tok, int tag, const std::string* val);
~Constant() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
virtual void TypeChecking() {}
long IVal() const { return ival_; }
double FVal() const { return fval_; }
const std::string* SVal() const { return sval_; }
std::string SValRepr() const;
std::string Repr() const { return std::string(".LC") + std::to_string(id_); }
protected:
Constant(const Token* tok, QualType type, long val)
: Expr(tok, type), ival_(val) {}
Constant(const Token* tok, QualType type, double val)
: Expr(tok, type), fval_(val) {}
Constant(const Token* tok, QualType type, const std::string* val)
: Expr(tok, type), sval_(val) {}
union {
long ival_;
double fval_;
struct {
long id_;
const std::string* sval_;
};
};
};
class TempVar : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static TempVar* New(QualType type);
virtual ~TempVar() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return true; }
virtual void TypeChecking() {}
protected:
TempVar(QualType type): Expr(nullptr, type), tag_(GenTag()) {}
private:
static int GenTag() {
static int tag = 0;
return ++tag;
}
int tag_;
};
enum Linkage {
L_NONE,
L_EXTERNAL,
L_INTERNAL,
};
class Identifier: public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static Identifier* New(const Token* tok, QualType type, Linkage linkage, const AttrList& attrList={});
virtual ~Identifier() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
virtual Object* ToObject() { return nullptr; }
virtual Enumerator* ToEnumerator() { return nullptr; }
// An identifer can be:
// object, sturct/union/enum tag, typedef name, function, label.
Identifier* ToTypeName() {
// A typename has no linkage
// And a function has external or internal linkage
if (ToObject() || ToEnumerator() || linkage_ != L_NONE)
return nullptr;
return this;
}
virtual const std::string Name() const { return tok_->str_; }
enum Linkage Linkage() const { return linkage_; }
void SetLinkage(enum Linkage linkage) { linkage_ = linkage; }
virtual void TypeChecking() {}
protected:
Identifier(const Token* tok, QualType type, enum Linkage linkage, const AttrList& attrList={})
: Expr(tok, type), linkage_(linkage), attrList_(attrList) {}
// An identifier has property linkage
enum Linkage linkage_;
AttrList attrList_;
};
class Enumerator: public Identifier {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Enumerator* New(const Token* tok, int val);
virtual ~Enumerator() {}
virtual void Accept(Visitor* v);
virtual Enumerator* ToEnumerator() { return this; }
int Val() const { return cons_->IVal(); }
protected:
Enumerator(const Token* tok, int val)
: Identifier(tok, ArithmType::New(T_INT), L_NONE),
cons_(Constant::New(tok, T_INT, (long)val)) {}
Constant* cons_;
};
class Object : public Identifier {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static Object* New(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={});
static Object* NewAnony(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={});
~Object() {}
virtual void Accept(Visitor* v);
virtual Object* ToObject() { return this; }
virtual bool IsLVal() {
// TODO(wgtdkp): not all object is lval?
return true;
}
bool IsStatic() const {
return (Storage() & S_STATIC) || (Linkage() != L_NONE);
}
int Storage() const { return storage_; }
void SetStorage(int storage) { storage_ = storage; }
int Align() const { return align_; }
void SetAlign(int align) {
assert(align > 0);
// Allowing reduce alignment to implement __attribute__((packed))
//if (align < align_)
// Error(this, "alignment specifier cannot reduce alignment");
align_ = align;
}
int Offset() const { return offset_; }
void SetOffset(int offset) { offset_ = offset; }
Declaration* Decl() { return decl_; }
void SetDecl(Declaration* decl) { decl_ = decl; }
const AttrList& GetAttrList() const { return attrList_; }
unsigned char BitFieldBegin() const { return bitFieldBegin_; }
unsigned char BitFieldEnd() const { return bitFieldBegin_ + bitFieldWidth_; }
unsigned char BitFieldWidth() const { return bitFieldWidth_; }
static unsigned long BitFieldMask(Object* bitField) {
return BitFieldMask(bitField->bitFieldBegin_, bitField->bitFieldWidth_);
}
static unsigned long BitFieldMask(unsigned char begin, unsigned char width) {
auto end = begin + width;
return ((0xFFFFFFFFFFFFFFFFUL << (64 - end)) >> (64 - width)) << begin;
}
bool HasInit() const { return decl_ && decl_->Inits().size(); }
bool Anonymous() const { return anonymous_; }
virtual const std::string Name() const { return Identifier::Name(); }
std::string Repr() const {
assert(IsStatic() || anonymous_);
if (anonymous_)
return "anonymous." + std::to_string(id_);
if (linkage_ == L_NONE)
return Name() + "." + std::to_string(id_);
return Name();
}
protected:
Object(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={})
: Identifier(tok, type, linkage),
storage_(storage),
offset_(0),
align_(type->Align()),
decl_(nullptr),
bitFieldBegin_(bitFieldBegin),
bitFieldWidth_(bitFieldWidth),
anonymous_(false),
attrList_(attrList){}
private:
int storage_;
int offset_;
int align_;
Declaration* decl_;
unsigned char bitFieldBegin_;
// 0 means it's not a bitfield
unsigned char bitFieldWidth_;
bool anonymous_;
long id_ {0};
AttrList attrList_;
};
/*
* Declaration
*/
class FuncDef : public ExtDecl {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
using ParamList = std::vector<Object*>;
public:
static FuncDef* New(Identifier* ident, LabelStmt* retLabel);
virtual ~FuncDef() {}
virtual void Accept(Visitor* v);
::FuncType* FuncType() { return ident_->Type()->ToFunc(); }
CompoundStmt* Body() { return body_; }
void SetBody(CompoundStmt* body) { body_ = body; }
std::string Name() const { return ident_->Name(); }
enum Linkage Linkage() { return ident_->Linkage(); }
protected:
FuncDef(Identifier* ident, LabelStmt* retLabel)
: ident_(ident), retLabel_(retLabel) {}
private:
Identifier* ident_;
LabelStmt* retLabel_;
CompoundStmt* body_;
};
using ExtDeclList = std::list<ExtDecl*>;
class TranslationUnit : public ASTNode {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static TranslationUnit* New() { return new TranslationUnit();}
virtual ~TranslationUnit() {}
virtual void Accept(Visitor* v);
void Add(ExtDecl* extDecl) { extDecls_.push_back(extDecl); }
ExtDeclList& ExtDecls() { return extDecls_; }
const ExtDeclList& ExtDecls() const { return extDecls_; }
private:
TranslationUnit() {}
ExtDeclList extDecls_;
};
#endif

View File

@@ -1,167 +0,0 @@
#pragma once
#ifndef _WGTCC_CODE_GEN_H_
#define _WGTCC_CODE_GEN_H_
#include "ast.h"
#include "visitor.h"
#include <stack>
namespace triton{
namespace ir{
class value;
class module;
class type;
class context;
class builder;
class attribute;
}
}
using namespace triton;
class Parser;
struct Addr;
template<> class Evaluator<Addr>;
struct StaticInitializer;
class LValAssigner;
using TypeList = std::vector<Type*>;
using LocationList = std::vector<std::string>;
using StaticInitList = std::vector<StaticInitializer>;
// Error
inline void should_not_happen(const std::string& suffix) { throw std::runtime_error("internal compiler error: " + suffix); }
inline void error_not_implemented(const std::string& msg) { throw std::runtime_error(msg); }
class Generator: public Visitor {
friend class Evaluator<Addr>;
friend class LValAssigner;
protected:
struct scope {
std::map<std::string, ir::type*> types;
std::map<std::string, ir::value*> values;
};
void set_ret(ir::value* value);
ir::value *GenUnaryMinus(ir::value* arg);
ir::value *GenUnaryInc(UnaryOp* arg, bool is_postfix, bool is_inc);
public:
Generator(Parser* parser) : parser_(parser) {}
void Visit(ASTNode* node) { node->Accept(this); }
void VisitExpr(Expr* expr) { expr->Accept(this); }
void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
// Expression
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitTransOp(TransOp* transOp);
void VisitConditionalOp(ConditionalOp* condOp);
void VisitFuncCall(FuncCall* funcCall);
void VisitObject(Object* obj);
void VisitEnumerator(Enumerator* enumer);
void VisitIdentifier(Identifier* ident);
void VisitConstant(Constant* cons);
void VisitTempVar(TempVar* tempVar);
// Statement
void VisitDeclaration(Declaration* init);
void VisitEmptyStmt(EmptyStmt* emptyStmt);
void VisitIfStmt(IfStmt* ifStmt);
void VisitForStmt(ForStmt* ifStmt);
void VisitJumpStmt(JumpStmt* jumpStmt);
void VisitReturnStmt(ReturnStmt* returnStmt);
void VisitLabelStmt(LabelStmt* labelStmt);
void VisitCompoundStmt(CompoundStmt* compoundStmt);
void VisitFuncDef(FuncDef* funcDef);
void VisitTranslationUnit(TranslationUnit* unit);
void Gen(ir::module *mod);
protected:
// Triton-IR attributes
ir::attribute GenIRAttr(ASTNode::Attr attr);
// Triton-IR metadata
void SetIRMetadata(ASTNode::Attr attr, ir::value *rhs);
// Triton-IR values
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty);
ir::value* GenSemCastOp(ir::value* op, ir::type* type);
ir::value* GenBitCastOp(ir::value* src, ir::type* dst_ty);
// Triton-IR types
static ir::type* GenIRType(::Type* type, ir::context &ctx);
static ir::type* GenIRArithmType(ArithmType* type, ir::context& ctx);
static ir::type* GenIRArrayType(ArrayType* type, ir::context& ctx);
static ir::type* GenIRTileType(TileType* type, ir::context& ctx);
static ir::type* GenIRFuncType(FuncType* type, ir::context& ctx);
static ir::type* GenIRPointerType(PointerType* type, ir::context& ctx);
static ir::type* GenIRStructType(StructType* type, ir::context& ctx);
void AllocObjects(Scope* scope, const FuncDef::ParamList& params=FuncDef::ParamList());
// SSA
void pushScope();
void popScope();
private:
Parser* parser_;
ir::value* ret_;
ir::builder* bld_;
ir::context* ctx_;
ir::module* mod_;
private:
// std::stack<scope> scopes_;
LValAssigner* assign_;
};
class LValAssigner: public Visitor {
public:
LValAssigner(Generator* gen): gen_(gen) {}
// Expression
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitObject(Object* obj);
void VisitIdentifier(Identifier* ident);
void VisitConditionalOp(ConditionalOp*) { should_not_happen("conditional cannot be lvalue"); }
void VisitFuncCall(FuncCall*) { should_not_happen("funccall cannot be lvalue"); }
void VisitTransOp(TransOp*) { should_not_happen("transop cannot be lvalue"); }
void VisitEnumerator(Enumerator*) { should_not_happen("enumerator cannot be lvalue"); }
void VisitConstant(Constant*) { should_not_happen("constant cannot be lvalue"); }
void VisitTempVar(TempVar*) { should_not_happen("tempvar cannot be lvalue"); }
void VisitDeclaration(Declaration*) { should_not_happen("declaration cannot be lvalue"); }
void VisitEmptyStmt(EmptyStmt*) { should_not_happen("empty statement cannot be lvalue"); }
void VisitIfStmt(IfStmt*) { should_not_happen("if statement cannot be lvalue"); }
void VisitForStmt(ForStmt*) { should_not_happen("for statement cannot be lvalue"); }
void VisitJumpStmt(JumpStmt*) { should_not_happen("jump statement cannot be lvalue"); }
void VisitReturnStmt(ReturnStmt*) { should_not_happen("return statement cannot be lvalue"); }
void VisitLabelStmt(LabelStmt*) { should_not_happen("label statement cannot be lvalue"); }
void VisitCompoundStmt(CompoundStmt*) { should_not_happen("compound statement cannot be lvalue"); }
void VisitFuncDef(FuncDef*) { should_not_happen("function definition cannot be lvalue"); }
void VisitTranslationUnit(TranslationUnit*) { should_not_happen("translation unit cannot be lvalue"); }
ir::value* GenExpr(Expr* expr, ir::value* rhs) {
rhs_ = rhs;
expr->Accept(this);
return ret_;
}
private:
ir::value* ret_;
ir::value* rhs_;
Generator* gen_;
};
#endif

View File

@@ -1,164 +0,0 @@
#pragma once
#ifndef _WGTCC_CPP_H_
#define _WGTCC_CPP_H_
#include "scanner.h"
#include <cstdio>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
class Macro;
struct CondDirective;
using MacroMap = std::map<std::string, Macro>;
using ParamList = std::list<std::string>;
using ParamMap = std::map<std::string, TokenSequence>;
using PPCondStack = std::stack<CondDirective>;
using PathList = std::list<std::string>;
class Macro {
public:
Macro(const TokenSequence& repSeq, bool preDef=false)
: funcLike_(false), variadic_(false),
preDef_(preDef), repSeq_(repSeq) {}
Macro(bool variadic, ParamList& params,
TokenSequence& repSeq, bool preDef=false)
: funcLike_(true), variadic_(variadic), preDef_(preDef),
params_(params), repSeq_(repSeq) {}
~Macro() {}
bool FuncLike() { return funcLike_; }
bool ObjLike() { return !FuncLike(); }
bool Variadic() { return variadic_; }
bool PreDef() { return preDef_; }
ParamList& Params() { return params_; }
TokenSequence RepSeq(const std::string* filename, unsigned line);
private:
bool funcLike_;
bool variadic_;
bool preDef_;
ParamList params_;
TokenSequence repSeq_;
};
struct CondDirective {
int tag_;
bool enabled_;
bool cond_;
};
class Preprocessor {
public:
Preprocessor(const std::string* str, bool isSrc = true)
: curLine_(1), lineLine_(0), curCond_(true), fName_(nullptr), fSrc_(nullptr) {
if(isSrc)
fSrc_ = str;
else
fName_ = str;
// Add predefined
Init();
}
~Preprocessor() {}
void Finalize(TokenSequence os);
void Process(TokenSequence& os);
void Expand(TokenSequence& os, TokenSequence is, bool inCond=false);
void Subst(TokenSequence& os, TokenSequence is,
bool leadingWS, const HideSet& hs, ParamMap& params);
void Glue(TokenSequence& os, TokenSequence is);
void Glue(TokenSequence& os, const Token* tok);
const Token* Stringize(TokenSequence is);
void Stringize(std::string& str, TokenSequence is);
const Token* ParseActualParam(TokenSequence& is, Macro* macro, ParamMap& paramMap);
int GetDirective(TokenSequence& is);
const Token* EvalDefOp(TokenSequence& is);
void ReplaceIdent(TokenSequence& is);
void ParseDirective(TokenSequence& os, TokenSequence& is, int directive);
void ParseIf(TokenSequence ls);
void ParseIfdef(TokenSequence ls);
void ParseIfndef(TokenSequence ls);
void ParseElif(TokenSequence ls);
void ParseElse(TokenSequence ls);
void ParseEndif(TokenSequence ls);
void ParseInclude(TokenSequence& is, TokenSequence ls);
void ParseDef(TokenSequence ls);
void ParseUndef(TokenSequence ls);
void ParseLine(TokenSequence ls);
void ParseError(TokenSequence ls);
void ParsePragma(TokenSequence ls);
void IncludeSrc(TokenSequence& is, const std::string* text, const std::string* filename);
void IncludeFile(TokenSequence& is, const std::string* filename);
bool ParseIdentList(ParamList& params, TokenSequence& is);
Macro* FindMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return nullptr;
return &res->second;
}
void AddMacro(const std::string& name,
std::string* text, bool preDef=false);
void AddMacro(const std::string& name, const Macro& macro) {
auto res = macroMap_.find(name);
if (res != macroMap_.end()) {
// TODO(wgtdkp): give warning
macroMap_.erase(res);
}
macroMap_.insert(std::make_pair(name, macro));
}
void RemoveMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return;
if(res->second.PreDef()) // Cannot undef predefined macro
return;
macroMap_.erase(res);
}
std::string* SearchFile(const std::string& name,
const bool libHeader,
bool next,
const std::string& curPath);
void AddSearchPath(std::string path);
void HandleTheFileMacro(TokenSequence& os, const Token* macro);
void HandleTheLineMacro(TokenSequence& os, const Token* macro);
void UpdateFirstTokenLine(TokenSequence ts);
bool NeedExpand() const {
if (ppCondStack_.empty())
return true;
auto top = ppCondStack_.top();
return top.enabled_ && top.cond_;
}
private:
void Init();
PPCondStack ppCondStack_;
unsigned curLine_;
unsigned lineLine_;
bool curCond_;
MacroMap macroMap_;
PathList searchPaths_;
const std::string* fName_;
const std::string* fSrc_;
};
#endif

View File

@@ -1,22 +0,0 @@
#pragma once
#ifndef _WGTCC_ENCODING_H_
#define _WGTCC_ENCODING_H_
#include <string>
enum class Encoding {
NONE,
CHAR16,
CHAR32,
UTF8,
WCHAR
};
void ConvertToUTF16(std::string& str);
void ConvertToUTF32(std::string& str);
void AppendUCN(std::string& str, int c);
#endif

View File

@@ -1,17 +0,0 @@
#pragma once
#ifndef _WGTCC_ERROR_H_
#define _WGTCC_ERROR_H_
struct SourceLocation;
class Token;
class Expr;
[[noreturn]] void Error(const char* format, ...);
[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...);
[[noreturn]] void Error(const Token* tok, const char* format, ...);
[[noreturn]] void Error(const Expr* expr, const char* format, ...);
#endif

View File

@@ -1,130 +0,0 @@
#pragma once
#ifndef _WGTCC_EVALUATOR_H_
#define _WGTCC_EVALUATOR_H_
#include "ast.h"
#include "error.h"
#include "visitor.h"
class Expr;
template<typename T>
class Evaluator: public Visitor {
public:
Evaluator() {}
virtual ~Evaluator() {}
virtual void VisitBinaryOp(BinaryOp* binary);
virtual void VisitUnaryOp(UnaryOp* unary);
virtual void VisitConditionalOp(ConditionalOp* cond);
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
val_ = static_cast<T>(enumer->Val());
}
virtual void VisitIdentifier(Identifier* ident) {
Error(ident, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitObject(Object* obj) {
Error(obj, "expect constant expression");
}
virtual void VisitConstant(Constant* cons) {
if (cons->Type()->IsFloat()) {
val_ = static_cast<T>(cons->FVal());
} else if (cons->Type()->IsInteger()) {
val_ = static_cast<T>(cons->IVal());
} else {
assert(false);
}
}
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
virtual void VisitFuncDef(FuncDef* funcDef) {}
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
T Eval(Expr* expr) {
expr->Accept(this);
return val_;
}
private:
T val_;
};
struct Addr {
std::string label_;
int offset_;
};
template<>
class Evaluator<Addr>: public Visitor {
public:
Evaluator<Addr>() {}
virtual ~Evaluator<Addr>() {}
virtual void VisitBinaryOp(BinaryOp* binary);
virtual void VisitUnaryOp(UnaryOp* unary);
virtual void VisitConditionalOp(ConditionalOp* cond);
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
addr_.offset_ = enumer->Val();
}
virtual void VisitIdentifier(Identifier* ident) {
addr_.label_ = ident->Name();
addr_.offset_ = 0;
}
virtual void VisitObject(Object* obj) {
if (!obj->IsStatic()) {
Error(obj, "expect static object");
}
addr_.label_ = obj->Repr();
addr_.offset_ = 0;
}
virtual void VisitConstant(Constant* cons);
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
virtual void VisitFuncDef(FuncDef* funcDef) {}
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
Addr Eval(Expr* expr) {
expr->Accept(this);
return addr_;
}
private:
Addr addr_;
};
#endif

View File

@@ -1,103 +0,0 @@
#pragma once
#ifndef _WGTCC_MEM_POOL_H_
#define _WGTCC_MEM_POOL_H_
#include <cstddef>
#include <vector>
class MemPool {
public:
MemPool(): allocated_(0) {}
virtual ~MemPool() {}
MemPool(const MemPool& other) = delete;
MemPool& operator=(const MemPool& other) = delete;
virtual void* Alloc() = 0;
virtual void Free(void* addr) = 0;
virtual void Clear() = 0;
protected:
size_t allocated_;
};
template <class T>
class MemPoolImp: public MemPool {
public:
MemPoolImp() : root_(nullptr) {}
virtual ~MemPoolImp() {}
MemPoolImp(const MemPool& other) = delete;
MemPoolImp& operator=(MemPool& other) = delete;
virtual void* Alloc();
virtual void Free(void* addr);
virtual void Clear();
private:
enum {
COUNT = (4 * 1024) / sizeof(T)
};
union Chunk {
Chunk* next_;
char mem_[sizeof(T)];
};
struct Block {
Block() {
for (size_t i = 0; i < COUNT - 1; ++i)
chunks_[i].next_ = &chunks_[i+1];
chunks_[COUNT-1].next_ = nullptr;
}
Chunk chunks_[COUNT];
};
std::vector<Block*> blocks_;
Chunk* root_;
};
template <class T>
void* MemPoolImp<T>::Alloc() {
if (nullptr == root_) { // 空间不够,需要分配空间
auto block = new Block();
root_ = block->chunks_;
// 如果blocks实现为std::list, 那么push_back实际的overhead更大
// 这也表明,即使我们不需要随机访问功能(那么std::vector的拷贝是一种overhead)
// 仍然倾向于使用std::vector
// 当然std::vector的指数级capacity增长会造成内存浪费。
blocks_.push_back(block);
}
auto ret = root_;
root_ = root_->next_;
++allocated_;
return ret;
}
template <class T>
void MemPoolImp<T>::Free(void* addr) {
if (nullptr == addr)
return;
auto chunk = static_cast<Chunk*>(addr);
chunk->next_ = root_;
root_ = chunk;
--allocated_;
}
template <class T>
void MemPoolImp<T>::Clear() {
for (auto block: blocks_)
delete block;
blocks_.resize(0);
root_ = nullptr;
allocated_ = 0;
}
#endif

View File

@@ -1,260 +0,0 @@
#pragma once
#ifndef _PARSER_H_
#define _PARSER_H_
#include "ast.h"
#include "encoding.h"
#include "error.h"
#include "mem_pool.h"
#include "scope.h"
#include "token.h"
#include <cassert>
#include <memory>
#include <stack>
class Preprocessor;
struct DeclInfo {
DeclInfo(const Token* _tok,
QualType _type,
ASTNode::AttrList _attrs = {})
: tok(_tok), type(_type), attrs(_attrs) {}
const Token* tok;
QualType type;
ASTNode::AttrList attrs;
};
class Parser {
using LiteralList = std::vector<Constant*>;
using StaticObjectList = std::vector<Object*>;
using CaseLabelList = std::vector<std::pair<Constant*, LabelStmt*>>;
using LabelJumpList = std::list<std::pair<const Token*, JumpStmt*>>;
using LabelMap = std::map<std::string, LabelStmt*>;
friend class Generator;
public:
explicit Parser(TokenSequence& ts)
: unit_(TranslationUnit::New()),
ts_(ts),
externalSymbols_(new Scope(nullptr, S_BLOCK)),
errTok_(nullptr),
curScope_(new Scope(nullptr, S_FILE)),
curFunc_(nullptr),
breakDest_(nullptr),
continueDest_(nullptr),
caseLabels_(nullptr),
defaultLabel_(nullptr) {
ts_.SetParser(this);
}
~Parser() {}
Constant* ParseConstant(const Token* tok);
Constant* ParseFloat(const Token* tok);
Constant* ParseInteger(const Token* tok);
Constant* ParseCharacter(const Token* tok);
Encoding ParseLiteral(std::string& str, const Token* tok);
Constant* ConcatLiterals(const Token* tok);
Expr* ParseGeneric();
void Parse();
void ParseTranslationUnit();
FuncDef* ParseFuncDef(Identifier* ident);
// Expressions
Expr* ParseExpr();
Expr* ParsePrimaryExpr();
QualType TryCompoundLiteral();
Object* ParseCompoundLiteral(QualType type);
Expr* ParsePostfixExpr();
Expr* ParsePostfixExprTail(Expr* primExpr);
Expr* ParseSubScripting(Expr* pointer);
BinaryOp* ParseMemberRef(const Token* tok, int op, Expr* lhs);
UnaryOp* ParsePostfixIncDec(const Token* tok, Expr* operand);
FuncCall* ParseFuncCall(Expr* caller);
Expr* ParseUnaryExpr();
Constant* ParseSizeof();
Constant* ParseAlignof();
UnaryOp* ParsePrefixIncDec(const Token* tok);
UnaryOp* ParseUnaryIntrinsicOp(int op);
UnaryOp* ParseUnaryOp(const Token* tok, int op);
Expr* ParseDerefOp(const Token* tok);
QualType ParseTypeName();
Expr* ParseCastExpr();
Expr* ParseRangeExpr();
Expr* ParseMatmulExpr();
Expr* ParseMultiplicativeExpr();
Expr* ParseAdditiveExpr();
Expr* ParseShiftExpr();
Expr* ParseRelationalExpr();
Expr* ParseEqualityExpr();
Expr* ParseBitiwiseAndExpr();
Expr* ParseBitwiseXorExpr();
Expr* ParseBitwiseOrExpr();
Expr* ParseLogicalAndExpr();
Expr* ParseLogicalOrExpr();
Expr* ParseConditionalExpr();
Expr* ParseCommaExpr();
Expr* ParseAssignExpr();
// Declarations
CompoundStmt* ParseDecl();
void ParseStaticAssert();
QualType ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec);
QualType ParseSpecQual();
int ParseAlignas();
Type* ParseStructUnionSpec(bool isStruct);
StructType* ParseStructUnionDecl(StructType* type);
void ParseBitField(StructType* structType, const Token* tok, QualType type);
Type* ParseEnumSpec();
Type* ParseEnumerator(ArithmType* type);
int ParseQual();
QualType ParsePointer(QualType typePointedTo);
DeclInfo ParseDeclarator(QualType type);
QualType ParseArrayFuncDeclarator(const Token* ident, QualType base);
int ParseArrayLength();
TileType::ShapeInt ParseTileShape();
bool ParseParamList(FuncType::ParamList& params);
Object* ParseParamDecl();
QualType ParseAbstractDeclarator(QualType type);
Identifier* ParseDirectDeclarator(QualType type,
int storageSpec,
int funcSpec,
int align);
// Initializer
void ParseInitializer(Declaration* decl,
QualType type,
int offset,
bool designated=false,
bool forceBrace=false,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0);
void ParseArrayInitializer(Declaration* decl,
ArrayType* type,
int offset,
bool designated);
StructType::Iterator ParseStructDesignator(StructType* type,
const std::string& name);
void ParseStructInitializer(Declaration* decl,
StructType* type,
int offset,
bool designated);
bool ParseLiteralInitializer(Declaration* init,
ArrayType* type,
int offset);
Declaration* ParseInitDeclarator(Identifier* ident);
Declaration* ParseInitDeclaratorSub(Object* obj);
// Statements
Stmt* ParseStmt();
CompoundStmt* ParseCompoundStmt(FuncType* funcType=nullptr);
IfStmt* ParseIfStmt();
CompoundStmt* ParseSwitchStmt();
CompoundStmt* ParseWhileStmt();
CompoundStmt* ParseDoStmt();
ForStmt *ParseForStmt();
JumpStmt* ParseGotoStmt();
JumpStmt* ParseContinueStmt();
JumpStmt* ParseBreakStmt();
ReturnStmt* ParseReturnStmt();
CompoundStmt* ParseLabelStmt(const Token* label);
CompoundStmt* ParseCaseStmt();
CompoundStmt* ParseDefaultStmt();
Identifier* ProcessDeclarator(const Token* tok,
QualType type, const ASTNode::AttrList &attrs,
int storageSpec,
int funcSpec,
int align);
// GNU extensions
ASTNode::AttrList TryAttributeSpecList();
void ParseAttributeSpec(ASTNode::AttrList &attrList);
ASTNode::Attr ParseAttribute();
bool IsTypeName(const Token* tok) const{
if (tok->IsTypeSpecQual())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
if (ident && ident->ToTypeName())
return true;
}
return false;
}
bool IsType(const Token* tok) const{
if (tok->IsDecl())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
return (ident && ident->ToTypeName());
}
return false;
}
void EnsureInteger(Expr* expr) {
if (!expr->Type()->IsInteger()) {
Error(expr, "expect integer expression");
}
}
void EnterBlock(FuncType* funcType=nullptr);
void ExitBlock() { curScope_ = curScope_->Parent(); }
void EnterProto() { curScope_ = new Scope(curScope_, S_PROTO); }
void ExitProto() { curScope_ = curScope_->Parent(); }
FuncDef* EnterFunc(Identifier* ident);
void ExitFunc();
LabelStmt* FindLabel(const std::string& label) {
auto ret = curLabels_.find(label);
if (curLabels_.end() == ret)
return nullptr;
return ret->second;
}
void AddLabel(const std::string& label, LabelStmt* labelStmt) {
assert(nullptr == FindLabel(label));
curLabels_[label] = labelStmt;
}
TranslationUnit* Unit() { return unit_; }
FuncDef* CurFunc() { return curFunc_; }
const TokenSequence& ts() const { return ts_; }
protected:
static bool IsBuiltin(FuncType* type);
static bool IsBuiltin(const std::string& name);
static Identifier* GetBuiltin(const Token* tok);
static void DefineBuiltins();
static FuncType* vaStartType_;
static FuncType* vaArgType_;
// The root of the AST
TranslationUnit* unit_;
TokenSequence& ts_;
// It is not the real scope,
// It contains all external symbols(resolved and not resolved)
Scope* externalSymbols_;
const Token* errTok_;
Scope* curScope_;
FuncDef* curFunc_;
LabelMap curLabels_;
LabelJumpList unresolvedJumps_;
LabelStmt* breakDest_;
LabelStmt* continueDest_;
CaseLabelList* caseLabels_;
LabelStmt* defaultLabel_;
};
#endif

View File

@@ -1,86 +0,0 @@
#pragma once
#ifndef _WGTCC_SCANNER_H_
#define _WGTCC_SCANNER_H_
#include "error.h"
#include "encoding.h"
#include "token.h"
#include <string>
#include <cassert>
class Scanner {
public:
explicit Scanner(const Token* tok)
: Scanner(&tok->str_, tok->loc_) {}
Scanner(const std::string* text, const SourceLocation& loc)
: Scanner(text, loc.filename_, loc.line_, loc.column_) {}
explicit Scanner(const std::string* text,
const std::string* filename=nullptr,
unsigned line=1, unsigned column=1)
: text_(text), tok_(Token::END) {
// TODO(wgtdkp): initialization
p_ = &(*text_)[0];
loc_ = {filename, p_, line, 1};
}
virtual ~Scanner() {}
Scanner(const Scanner& other) = delete;
Scanner& operator=(const Scanner& other) = delete;
// Scan plain text and generate tokens in ts.
// The param 'ts' need not be empty, if so, the tokens
// are inserted at the *header* of 'ts'.
// The param 'ws' tells if there is leading white space
// before this token, it is only SkipComment() that will
// set this param.
Token* Scan(bool ws=false);
void Tokenize(TokenSequence& ts);
static std::string ScanHeadName(const Token* lhs, const Token* rhs);
Encoding ScanCharacter(int& val);
Encoding ScanLiteral(std::string& val);
std::string ScanIdentifier();
private:
Token* SkipIdentifier();
Token* SkipNumber();
Token* SkipLiteral();
Token* SkipCharacter();
Token* MakeToken(int tag);
Token* MakeNewLine();
Encoding ScanEncoding(int c);
int ScanEscaped();
int ScanHexEscaped();
int ScanOctEscaped(int c);
int ScanUCN(int len);
void SkipWhiteSpace();
void SkipComment();
bool IsUCN(int c) { return c == '\\' && (Test('u') || Test('U')); }
bool IsOctal(int c) { return '0' <= c && c <= '7'; }
int XDigit(int c);
bool Empty() const { return *p_ == 0; }
int Peek();
bool Test(int c) { return Peek() == c; };
int Next();
void PutBack();
bool Try(int c) {
if (Peek() == c) {
Next();
return true;
}
return false;
};
void Mark() { tok_.loc_ = loc_; };
const std::string* text_;
SourceLocation loc_;
Token tok_;
const char* p_;
};
std::string* ReadFile(const std::string& filename);
#endif

View File

@@ -1,72 +0,0 @@
#pragma once
#ifndef _WGTCC_SCOPE_H_
#define _WGTCC_SCOPE_H_
#include <iostream>
#include <map>
#include <string>
#include <vector>
class Identifier;
class Token;
enum ScopeType {
S_FILE,
S_PROTO,
S_BLOCK,
S_FUNC,
};
class Scope {
friend class StructType;
using TagList = std::vector<Identifier*>;
using IdentMap = std::map<std::string, Identifier*>;
public:
explicit Scope(Scope* parent, enum ScopeType type)
: parent_(parent), type_(type) {}
~Scope() {}
Scope* Parent() { return parent_; }
void SetParent(Scope* parent) { parent_ = parent; }
enum ScopeType Type() const { return type_; }
Identifier* Find(const Token* tok);
Identifier* FindInCurScope(const Token* tok);
Identifier* FindTag(const Token* tok);
Identifier* FindTagInCurScope(const Token* tok);
TagList AllTagsInCurScope() const;
void Insert(Identifier* ident);
void Insert(const std::string& name, Identifier* ident);
void InsertTag(Identifier* ident);
void Print();
bool operator==(const Scope& other) const { return type_ == other.type_; }
IdentMap::iterator begin() { return identMap_.begin(); }
IdentMap::iterator end() { return identMap_.end(); }
size_t size() const { return identMap_.size(); }
private:
Identifier* Find(const std::string& name);
Identifier* FindInCurScope(const std::string& name);
Identifier* FindTag(const std::string& name);
Identifier* FindTagInCurScope(const std::string& name);
std::string TagName(const std::string& name) {
return name + "@:tag";
}
static bool IsTagName(const std::string& name) {
return name.size() > 5 && name[name.size() - 5] == '@';
}
const Scope& operator=(const Scope& other);
Scope(const Scope& scope);
Scope* parent_;
enum ScopeType type_;
IdentMap identMap_;
};
#endif

View File

@@ -1,434 +0,0 @@
#pragma once
#ifndef _WGTCC_TOKEN_H_
#define _WGTCC_TOKEN_H_
#include "error.h"
#include <cassert>
#include <cstring>
#include <iostream>
#include <list>
#include <set>
#include <string>
#include <unordered_map>
class Generator;
class Parser;
class Scanner;
class Token;
class TokenSequence;
using HideSet = std::set<std::string>;
using TokenList = std::list<const Token*>;
struct SourceLocation {
const std::string* filename_;
const char* lineBegin_;
unsigned line_;
unsigned column_;
const char* Begin() const {
return lineBegin_ + column_ - 1;
}
};
class Token {
friend class Scanner;
public:
enum {
// Punctuators
LPAR = '(',
RPAR = ')',
LSQB = '[',
RSQB = ']',
COLON = ':',
COMMA = ',',
SEMI = ';',
ADD = '+',
SUB = '-',
MUL = '*',
DIV = '/',
OR = '|',
AND = '&',
XOR = '^',
LESS = '<',
GREATER = '>',
EQUAL = '=',
DOT = '.',
MOD = '%',
LBRACE = '{',
RBRACE = '}',
TILDE = '~',
NOT = '!',
COND = '?',
SHARP = '#',
MATMUL = '@',
NEW_LINE = '\n',
DSHARP = 128, // '##'
PTR,
INC,
DEC,
LEFT,
RIGHT,
LE,
GE,
EQ,
NE,
LOGICAL_AND,
LOGICAL_OR,
MUL_ASSIGN,
DIV_ASSIGN,
MOD_ASSIGN,
ADD_ASSIGN,
SUB_ASSIGN,
LEFT_ASSIGN,
RIGHT_ASSIGN,
AND_ASSIGN,
XOR_ASSIGN,
OR_ASSIGN,
ELLIPSIS,
MASKED_DEREF,
// Punctuators end
// KEYWORD BEGIN
// TYPE QUALIFIER BEGIN
CONST,
RESTRICT,
VOLATILE,
ATOMIC,
// TYPE QUALIFIER END
// TYPE SPECIFIER BEGIN
VOID,
CHAR,
SHORT,
INT,
LONG,
HALF,
FLOAT,
DOUBLE,
SIGNED,
UNSIGNED,
BOOL, // _Bool
COMPLEX, // _Complex
STRUCT,
UNION,
ENUM,
// TYPE SPECIFIER END
ATTRIBUTE, // GNU extension __attribute__
// FUNCTION SPECIFIER BEGIN
INLINE,
NORETURN, // _Noreturn
// FUNCTION SPECIFIER END
// TILE ARITHMETICS BEGIN
NEWAXIS,
MAX,
MIN,
// TILE ARITHMETICS END
ALIGNAS, // _Alignas
// For syntactic convenience
STATIC_ASSERT, // _Static_assert
// STORAGE CLASS SPECIFIER BEGIN
TYPEDEF,
EXTERN,
STATIC,
THREAD, // _Thread_local
AUTO,
GLOBAL,
CMEM, // constant memory
// STORAGE CLASS SPECIFIER END
BREAK,
CASE,
CONTINUE,
DEFAULT,
DO,
ELSE,
FOR,
GOTO,
IF,
RETURN,
SIZEOF,
SWITCH,
WHILE,
ALIGNOF, // _Alignof
GENERIC, // _Generic
IMAGINARY, // _Imaginary
// function keywords
BITCAST,
EXP,
LOG,
SQRTF,
// KEYWORD END
IDENTIFIER,
CONSTANT,
I_CONSTANT,
C_CONSTANT,
F_CONSTANT,
LITERAL,
// For the parser, a identifier is a typedef name or user defined type
POSTFIX_INC,
POSTFIX_DEC,
PREFIX_INC,
PREFIX_DEC,
ADDR, // '&'
DEREF, // '*'
PLUS,
MINUS,
CAST,
REDUCE,
// For preprocessor
PP_IF,
PP_IFDEF,
PP_IFNDEF,
PP_ELIF,
PP_ELSE,
PP_ENDIF,
PP_INCLUDE,
PP_DEFINE,
PP_UNDEF,
PP_LINE,
PP_ERROR,
PP_PRAGMA,
PP_NONE,
PP_EMPTY,
IGNORE,
INVALID,
END,
NOTOK = -1,
};
static Token* New(int tag);
static Token* New(const Token& other);
static Token* New(int tag,
const SourceLocation& loc,
const std::string& str,
bool ws=false);
Token& operator=(const Token& other) {
tag_ = other.tag_;
ws_ = other.ws_;
loc_ = other.loc_;
str_ = other.str_;
hs_ = other.hs_ ? new HideSet(*other.hs_): nullptr;
return *this;
}
virtual ~Token() {}
// Token::NOTOK represents not a kw.
static int KeyWordTag(const std::string& key) {
auto kwIter = kwTypeMap_.find(key);
if (kwTypeMap_.end() == kwIter)
return Token::NOTOK; // Not a key word type
return kwIter->second;
}
static bool IsKeyWord(const std::string& name);
static bool IsKeyWord(int tag) { return CONST <= tag && tag < IDENTIFIER; }
bool IsKeyWord() const { return IsKeyWord(tag_); }
bool IsPunctuator() const { return 0 <= tag_ && tag_ <= ELLIPSIS; }
bool IsLiteral() const { return tag_ == LITERAL; }
bool IsConstant() const { return CONSTANT <= tag_ && tag_ <= F_CONSTANT; }
bool IsIdentifier() const { return IDENTIFIER == tag_; }
bool IsEOF() const { return tag_ == Token::END; }
bool IsTypeSpecQual() const { return CONST <= tag_ && tag_ <= ENUM; }
bool IsDecl() const { return CONST <= tag_ && tag_ <= GLOBAL; }
static const char* Lexeme(int tag) {
auto iter = tagLexemeMap_.find(tag);
if (iter == tagLexemeMap_.end())
return nullptr;
return iter->second;
}
int tag_;
// 'ws_' standards for weither there is preceding white space
// This is to simplify the '#' operator(stringize) in macro expansion
bool ws_ { false };
SourceLocation loc_;
std::string str_;
HideSet* hs_ { nullptr };
private:
explicit Token(int tag): tag_(tag) {}
Token(int tag, const SourceLocation& loc,
const std::string& str, bool ws=false)
: tag_(tag), ws_(ws), loc_(loc), str_(str) {}
Token(const Token& other) {
*this = other;
}
static const std::unordered_map<std::string, int> kwTypeMap_;
static const std::unordered_map<int, const char*> tagLexemeMap_;
};
class TokenSequence {
friend class Preprocessor;
public:
TokenSequence(): tokList_(new TokenList()),
begin_(tokList_->begin()), end_(tokList_->end()) {}
explicit TokenSequence(Token* tok) {
TokenSequence();
InsertBack(tok);
}
explicit TokenSequence(TokenList* tokList)
: tokList_(tokList),
begin_(tokList->begin()),
end_(tokList->end()) {}
TokenSequence(TokenList* tokList,
TokenList::iterator begin,
TokenList::iterator end)
: tokList_(tokList), begin_(begin), end_(end) {}
~TokenSequence() {}
TokenSequence(const TokenSequence& other) { *this = other; }
const TokenSequence& operator=(const TokenSequence& other) {
tokList_ = other.tokList_;
begin_ = other.begin_;
end_ = other.end_;
return *this;
}
void Copy(const TokenSequence& other) {
tokList_ = new TokenList(other.begin_, other.end_);
begin_ = tokList_->begin();
end_ = tokList_->end();
for (auto iter = begin_; iter != end_; ++iter)
*iter = Token::New(**iter);
}
void UpdateHeadLocation(const SourceLocation& loc) {
assert(!Empty());
auto tok = const_cast<Token*>(Peek());
tok->loc_ = loc;
}
void FinalizeSubst(bool leadingWS, const HideSet& hs) {
auto ts = *this;
while (!ts.Empty()) {
auto tok = const_cast<Token*>(ts.Next());
if (!tok->hs_)
tok->hs_ = new HideSet(hs);
else
tok->hs_->insert(hs.begin(), hs.end());
}
// Even if the token sequence is empty
const_cast<Token*>(Peek())->ws_ = leadingWS;
}
const Token* Expect(int expect);
bool Try(int tag) {
if (Peek()->tag_ == tag) {
Next();
return true;
}
return false;
}
bool Test(int tag) { return Peek()->tag_ == tag; }
const Token* Next() {
auto ret = Peek();
if (!ret->IsEOF()) {
++begin_;
Peek(); // May skip newline token, but why ?
} else {
++exceed_end;
}
return ret;
}
void PutBack() {
assert(begin_ != tokList_->begin());
if (exceed_end > 0) {
--exceed_end;
} else {
--begin_;
if ((*begin_)->tag_ == Token::NEW_LINE)
PutBack();
}
}
const Token* Peek() const;
const Token* Peek2() {
if (Empty())
return Peek(); // Return the Token::END
Next();
auto ret = Peek();
PutBack();
return ret;
}
const Token* Back() const {
auto back = end_;
return *--back;
}
void PopBack() {
assert(!Empty());
assert(end_ == tokList_->end());
auto size_eq1 = tokList_->back() == *begin_;
tokList_->pop_back();
end_ = tokList_->end();
if (size_eq1)
begin_ = end_;
}
TokenList::iterator Mark() { return begin_; }
void ResetTo(TokenList::iterator mark) { begin_ = mark; }
bool Empty() const { return Peek()->tag_ == Token::END; }
void InsertBack(TokenSequence& ts) {
auto pos = tokList_->insert(end_, ts.begin_, ts.end_);
if (begin_ == end_) {
begin_ = pos;
}
}
void InsertBack(const Token* tok) {
auto pos = tokList_->insert(end_, tok);
if (begin_ == end_) {
begin_ = pos;
}
}
// If there is preceding newline
void InsertFront(TokenSequence& ts) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, ts.begin_, ts.end_);
}
void InsertFront(const Token* tok) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, tok);
}
bool IsBeginOfLine() const;
TokenSequence GetLine();
void SetParser(Parser* parser) { parser_ = parser; }
void Print(FILE* fp=stdout) const;
void Print(std::string *str) const;
private:
// Find a insert position with no preceding newline
TokenList::iterator GetInsertFrontPos() {
auto pos = begin_;
if (pos == tokList_->begin())
return pos;
--pos;
while (pos != tokList_->begin() && (*pos)->tag_ == Token::NEW_LINE)
--pos;
return ++pos;
}
TokenList* tokList_;
mutable TokenList::iterator begin_;
TokenList::iterator end_;
Parser* parser_ {nullptr};
int exceed_end {0};
};
#endif

View File

@@ -1,453 +0,0 @@
#pragma once
#ifndef _WGTCC_TYPE_H_
#define _WGTCC_TYPE_H_
#include "mem_pool.h"
#include "scope.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <list>
class Scope;
class Token;
class Expr;
class Type;
class QualType;
class VoidType;
class Identifier;
class Object;
class Constant;
class ArithmType;
class DerivedType;
class ArrayType;
class TileType;
class FuncType;
class PointerType;
class StructType;
class EnumType;
enum {
// Storage class specifiers
S_TYPEDEF = 0x01,
S_EXTERN = 0x02,
S_STATIC = 0x04,
S_THREAD = 0x08,
S_CONSTANT = 0x10,
S_GLOBAL = 0x20,
// Type specifier
T_SIGNED = 0x40,
T_UNSIGNED = 0x80,
T_CHAR = 0x100,
T_SHORT = 0x200,
T_INT = 0x400,
T_LONG = 0x800,
T_VOID = 0x1000,
T_HALF = 0x2000,
T_FLOAT = 0x4000,
T_DOUBLE = 0x8000,
T_BOOL = 0x10000,
T_COMPLEX = 0x20000,
// T_ATOMIC = 0x40000,
T_STRUCT_UNION = 0x80000,
T_ENUM = 0x100000,
T_TYPEDEF_NAME = 0x200000,
T_LLONG = 0x4000000,
// Function specifier
F_INLINE = 0x8000000,
F_NORETURN = 0x10000000,
};
struct Qualifier {
enum {
CONST = 0x01,
RESTRICT = 0x02,
VOLATILE = 0x04,
CMEM = 0x08,
MASK = CONST | RESTRICT | VOLATILE | CMEM
};
};
class QualType {
public:
QualType(Type* ptr, int quals=0x00)
: ptr_(reinterpret_cast<intptr_t>(ptr)) {
assert((quals & ~Qualifier::MASK) == 0);
ptr_ |= quals;
}
operator bool() const { return !IsNull(); }
bool IsNull() const { return GetPtr() == nullptr; }
const Type* GetPtr() const {
return reinterpret_cast<const Type*>(ptr_ & ~Qualifier::MASK);
}
Type* GetPtr() {
return reinterpret_cast<Type*>(ptr_ & ~Qualifier::MASK);
}
Type& operator*() { return *GetPtr(); }
const Type& operator*() const { return *GetPtr(); }
Type* operator->() { return GetPtr(); }
const Type* operator->() const { return GetPtr(); }
// Indicate whether the specified types are identical(exclude qualifiers).
friend bool operator==(QualType lhs, QualType rhs) {
return lhs.operator->() == rhs.operator->();
}
friend bool operator!=(QualType lhs, QualType rhs) {
return !(lhs == rhs);
}
int Qual() const { return ptr_ & 0x07; }
bool IsConstQualified() const { return ptr_ & Qualifier::CONST; }
bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; }
bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; }
bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; }
private:
intptr_t ptr_;
};
class Type {
public:
static const int intWidth_ = 4;
static const int machineWidth_ = 8;
bool operator!=(const Type& other) const = delete;
bool operator==(const Type& other) const = delete;
virtual bool Compatible(const Type& other) const {
return complete_ == other.complete_;
}
virtual ~Type() {}
// For Debugging
virtual std::string Str() const = 0;
virtual int Width() const = 0;
virtual int Align() const { return Width(); }
static int MakeAlign(int offset, int align) {
if ((offset % align) == 0)
return offset;
if (offset >= 0)
return offset + align - (offset % align);
else
return offset - align - (offset % align);
}
static QualType MayCast(QualType type, bool inProtoScope=false);
bool Complete() const { return complete_; }
void SetComplete(bool complete) const { complete_ = complete; }
bool IsReal() const { return IsInteger() || IsFloat(); };
virtual bool IsScalar() const { return false; }
virtual bool IsFloat() const { return false; }
virtual bool IsInteger() const { return false; }
virtual bool IsBool() const { return false; }
virtual bool IsVoidPointer() const { return false; }
virtual bool IsUnsigned() const { return false; }
virtual bool IsTile() const { return ToTile() != nullptr; }
const Type* ScalarType() const;
Type* ScalarType();
virtual VoidType* ToVoid() { return nullptr; }
virtual const VoidType* ToVoid() const { return nullptr; }
virtual ArithmType* ToArithm() { return nullptr; }
virtual const ArithmType* ToArithm() const { return nullptr; }
virtual ArrayType* ToArray() { return nullptr; }
virtual const ArrayType* ToArray() const { return nullptr; }
virtual TileType* ToTile() { return nullptr; }
virtual const TileType* ToTile() const { return nullptr; }
virtual FuncType* ToFunc() { return nullptr; }
virtual const FuncType* ToFunc() const { return nullptr; }
virtual PointerType* ToPointer() { return nullptr; }
virtual const PointerType* ToPointer() const { return nullptr; }
virtual DerivedType* ToDerived() { return nullptr; }
virtual const DerivedType* ToDerived() const { return nullptr; }
virtual StructType* ToStruct() { return nullptr; }
virtual const StructType* ToStruct() const { return nullptr; }
protected:
Type(MemPool* pool, bool complete)
: complete_(complete), pool_(pool) {}
mutable bool complete_;
MemPool* pool_;
};
class VoidType : public Type {
public:
static VoidType* New();
virtual ~VoidType() {}
virtual VoidType* ToVoid() { return this; }
virtual const VoidType* ToVoid() const { return this; }
virtual bool Compatible(const Type& other) const { return other.ToVoid(); }
virtual int Width() const {
// Non-standard GNU extension
return 1;
}
virtual std::string Str() const { return "void:1"; }
protected:
explicit VoidType(MemPool* pool): Type(pool, false) {}
};
class ArithmType : public Type {
public:
static ArithmType* New(int typeSpec);
virtual ~ArithmType() {}
virtual ArithmType* ToArithm() { return this; }
virtual const ArithmType* ToArithm() const { return this; }
virtual bool Compatible(const Type& other) const {
// C11 6.2.7 [1]: Two types have compatible type if their types are the same
// But I would to loose this constraints: integer and pointer are compatible
// if (IsInteger() && other.ToPointer())
// return other.Compatible(*this);
return this == &other;
}
virtual int Width() const;
virtual std::string Str() const;
virtual bool IsScalar() const { return true; }
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
virtual bool IsFloat() const {
return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
}
virtual bool IsBool() const { return tag_ & T_BOOL; }
bool IsComplex() const { return tag_ & T_COMPLEX; }
int Tag() const { return tag_; }
int Rank() const;
static ArithmType* IntegerPromote(ArithmType* type) {
assert(type->IsInteger());
if (type->Rank() < ArithmType::New(T_INT)->Rank())
return ArithmType::New(T_INT);
return type;
}
static ArithmType* MaxType(ArithmType* lhsType,
ArithmType* rhsType);
protected:
explicit ArithmType(MemPool* pool, int spec)
: Type(pool, true), tag_(Spec2Tag(spec)) {}
private:
static int Spec2Tag(int spec);
int tag_;
};
class DerivedType : public Type {
public:
QualType Derived() const { return derived_; }
void SetDerived(QualType derived) { derived_ = derived; }
virtual DerivedType* ToDerived() { return this; }
virtual const DerivedType* ToDerived() const { return this; }
protected:
DerivedType(MemPool* pool, QualType derived)
: Type(pool, true), derived_(derived) {}
QualType derived_;
};
class PointerType : public DerivedType {
public:
static PointerType* New(QualType derived);
virtual ~PointerType() {}
virtual PointerType* ToPointer() { return this; }
virtual const PointerType* ToPointer() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 8; }
virtual bool IsScalar() const { return true; }
virtual bool IsVoidPointer() const { return derived_->ToVoid(); }
virtual std::string Str() const {
return derived_->Str() + "*:" + std::to_string(Width());
}
protected:
PointerType(MemPool* pool, QualType derived): DerivedType(pool, derived) {}
};
class ArrayType : public DerivedType {
public:
static ArrayType* New(int len, QualType eleType);
static ArrayType* New(Expr* expr, QualType eleType);
virtual ~ArrayType() { /*delete derived_;*/ }
virtual ArrayType* ToArray() { return this; }
virtual const ArrayType* ToArray() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const {
return Complete() ? (derived_->Width() * len_): 0;
}
virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const {
return derived_->Str() + "[]:" + std::to_string(Width());
}
int GetElementOffset(int idx) const { return derived_->Width() * idx; }
int Len() const { return len_; }
void SetLen(int len) { len_ = len; }
bool Variadic() const { return lenExpr_ != nullptr; }
protected:
ArrayType(MemPool* pool, Expr* lenExpr, QualType derived)
: DerivedType(pool, derived),
lenExpr_(lenExpr), len_(0) {
SetComplete(false);
}
ArrayType(MemPool* pool, int len, QualType derived)
: DerivedType(pool, derived),
lenExpr_(nullptr), len_(len) {
SetComplete(len_ >= 0);
}
const Expr* lenExpr_;
int len_;
};
class TileType : public DerivedType {
public:
using ShapeExpr = std::vector<Expr*>;
using ShapeInt = std::vector<int>;
public:
static TileType* New(const ShapeInt& shape, QualType eleType);
virtual ~TileType() { }
virtual TileType* ToTile() { return this; }
virtual const TileType* ToTile() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const {
return derived_->Str() + "[{}]:" + std::to_string(Width());
}
ShapeInt Shape() { return shape_; }
int NumEle() const {
int ret = 1;
for(int s: shape_)
ret *= s;
return ret;
}
bool CheckPow2NumEl() const {
int n = NumEle();
return n && !(n & (n - 1));
}
protected:
TileType(MemPool* pool, const ShapeInt& shape, QualType derived);
protected:
ShapeExpr shapeExpr_;
ShapeInt shape_;
};
class FuncType : public DerivedType {
public:
using ParamList = std::vector<Object*>;
public:
static FuncType* New(QualType derived,
int funcSpec,
bool variadic,
const ParamList& params);
virtual ~FuncType() {}
virtual FuncType* ToFunc() { return this; }
virtual const FuncType* ToFunc() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 1; }
virtual std::string Str() const;
const ParamList& Params() const { return params_; }
void SetParams(const ParamList& params) { params_ = params; }
bool Variadic() const { return variadic_; }
bool IsInline() const { return inlineNoReturn_ & F_INLINE; }
bool IsNoReturn() const { return inlineNoReturn_ & F_NORETURN; }
protected:
FuncType(MemPool* pool, QualType derived, int inlineReturn,
bool variadic, const ParamList& params)
: DerivedType(pool, derived), inlineNoReturn_(inlineReturn),
variadic_(variadic), params_(params) {
SetComplete(false);
}
private:
int inlineNoReturn_;
bool variadic_;
ParamList params_;
};
class StructType : public Type {
public:
using MemberList = std::list<Object*>;
using Iterator = std::list<Object*>::iterator;
public:
static StructType* New(bool isStruct,
bool hasTag,
Scope* parent);
virtual ~StructType() {}
virtual StructType* ToStruct() { return this; }
virtual const StructType* ToStruct() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return width_; }
virtual int Align() const { return align_; }
virtual std::string Str() const;
// struct/union
void AddMember(Object* member);
void AddBitField(Object* member, int offset);
bool IsStruct() const { return isStruct_; }
Object* GetMember(const std::string& member);
Scope* MemberMap() { return memberMap_; }
MemberList& Members() { return members_; }
int Offset() const { return offset_; }
bool HasTag() const { return hasTag_; }
void MergeAnony(Object* anony);
void Finalize();
protected:
// Default is incomplete
StructType(MemPool* pool, bool isStruct, bool hasTag, Scope* parent);
StructType(const StructType& other);
private:
void CalcWidth();
bool isStruct_;
bool hasTag_;
Scope* memberMap_;
MemberList members_;
int offset_;
int width_;
int align_;
int bitFieldAlign_;
};
#endif

View File

@@ -1,56 +0,0 @@
#pragma once
#ifndef _WGTCC_VISITOR_H_
#define _WGTCC_VISITOR_H_
class BinaryOp;
class UnaryOp;
class TransOp;
class ConditionalOp;
class FuncCall;
class Identifier;
class Object;
class Enumerator;
class Constant;
class TempVar;
class Declaration;
class IfStmt;
class ForStmt;
class JumpStmt;
class ReturnStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
class Visitor {
public:
virtual ~Visitor() {}
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
virtual void VisitTransOp(TransOp* trans) = 0;
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
virtual void VisitEnumerator(Enumerator* enumer) = 0;
virtual void VisitIdentifier(Identifier* ident) = 0;
virtual void VisitObject(Object* obj) = 0;
virtual void VisitConstant(Constant* cons) = 0;
virtual void VisitTempVar(TempVar* tempVar) = 0;
virtual void VisitDeclaration(Declaration* init) = 0;
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
virtual void VisitForStmt(ForStmt* ifStmt) = 0;
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) = 0;
virtual void VisitCompoundStmt(CompoundStmt* compStmt) = 0;
virtual void VisitFuncDef(FuncDef* funcDef) = 0;
virtual void VisitTranslationUnit(TranslationUnit* unit) = 0;
};
#endif

View File

@@ -1,27 +0,0 @@
#pragma once
#ifndef _TRITON_RUNTIME_ARG_H_
#define _TRITON_RUNTIME_ARG_H_
#include <string>
#include <stdexcept>
#include <sstream>
namespace triton{
namespace ir{
class type;
}
namespace driver{
class buffer;
}
namespace runtime {
}
}
#endif

View File

@@ -1,34 +0,0 @@
#pragma once
#ifndef _TRITON_RUNTIME_ERROR_H_
#define _TRITON_RUNTIME_ERROR_H_
#include <exception>
#include <string>
namespace triton {
namespace runtime{
namespace exception {
class base: public std::exception {};
#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } };
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory")
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers")
class no_valid_configuration: public exception::base {
public:
no_valid_configuration(const std::string& err): err_(err) { }
const char * what() const throw(){ return err_.c_str(); }
private:
std::string err_;
};
#undef TRITON_CREATE_RUNTIME_EXCEPTION
}
}
}
#endif

View File

@@ -1,159 +0,0 @@
#pragma once
#ifndef _TRITON_RUNTIME_FUNCTION_H_
#define _TRITON_RUNTIME_FUNCTION_H_
#include <map>
#include <unordered_map>
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <functional>
// codegen
#include "triton/ir/function.h"
#include "triton/ir/context.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/error.h"
// driver forward declaration
namespace triton {
namespace driver{
class module;
class stream;
class kernel;
class context;
class device;
}
}
// ir forward declaration
namespace triton{
namespace ir {
class module;
class function;
class context;
}
}
namespace triton{
namespace runtime{
/* ------------------------- */
/* Compilation options */
/* ------------------------- */
struct options_t {
template<class T>
T D(const std::string& name) const {
return std::stoi(defines.at(name));
}
std::unordered_map<std::string, std::string> defines;
int num_warps;
};
/* ------------------------- */
/* Runtime arguments */
/* ------------------------- */
enum arg_type {
INT1_T,
INT8_T,
INT16_T,
INT32_T,
INT64_T,
HALF_T,
FLOAT_T,
DOUBLE_T,
BUFFER_T
};
inline size_t size_of(arg_type ty){
switch(ty){
case INT1_T : return 1;
case INT8_T : return 1;
case INT16_T : return 2;
case INT32_T : return 4;
case INT64_T : return 8;
case HALF_T : return 2;
case FLOAT_T : return 4;
case DOUBLE_T: return 8;
case BUFFER_T: return 8;
default: throw std::runtime_error("unknown type");
}
}
template<class T>
void add_arg(std::stringstream& ss, T arg) {
ss.write((char*)&arg, sizeof(T));
}
/* ------------------------- */
/* ------------------------- */
class kernel{
public:
typedef std::vector<size_t> grid_t;
public:
static std::shared_ptr<ir::module> src_to_ir(const std::string& src, const options_t& opt);
static std::tuple<std::shared_ptr<driver::module>,
std::shared_ptr<driver::kernel>,
size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt);
public:
kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
std::string get_asm(const std::string &mode);
public:
const options_t opt;
private:
driver::device* dev_;
// handles
std::shared_ptr<ir::module> ir_;
std::shared_ptr<driver::module> mod_;
std::shared_ptr<driver::kernel> ker_;
// shared mem
size_t shared_mem_;
};
struct config {
std::map<std::string, std::string> defines;
int num_warps;
};
class function {
public:
typedef std::function<kernel::grid_t(const options_t&)> grid_fn_ty;
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
typedef std::vector<config> autotune_confs_t;
public:
function(const std::string& src, const options_t& opt, driver::device *device,
const std::vector<config>& tune_confs = {}, const std::vector<std::string> &tune_key = {});
kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
const std::vector<arg_type> get_signature() { return sig_; }
private:
std::map<std::vector<uint64_t>, std::vector<std::shared_ptr<kernel>>> kernels_;
std::map<std::vector<uint64_t>, kernel*> cache_;
std::vector<arg_type> sig_;
std::vector<int> align_idxs_;
std::vector<int> int_idxs_;
std::vector<int> key_idxs_;
std::vector<int> arg_size_;
std::vector<int> arg_off_;
std::vector<options_t> opts_;
std::string src_;
driver::device* device_;
};
}
}
#endif

View File

@@ -55,8 +55,8 @@ inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
std::vector<unsigned> align::get_shapes(ir::value *v) {
ir::type *ty = v->get_type();
if(ty->is_tile_ty())
return ty->get_tile_shapes();
if(ty->is_block_ty())
return ty->get_block_shapes();
else
return {1};
}
@@ -95,7 +95,7 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
unsigned current = 0;
bool is_skewed = false;
@@ -119,7 +119,7 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++)
if(op_shapes[d] == 1)
@@ -229,7 +229,7 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
@@ -251,7 +251,7 @@ std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_ins
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++)
if(op_shapes[d] == 1)
@@ -317,9 +317,9 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
}
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
if(!v->get_type()->is_tile_ty())
if(!v->get_type()->is_block_ty())
return add_to_cache(v, {1}, max_contiguous_);
auto shapes = v->get_type()->get_tile_shapes();
auto shapes = v->get_type()->get_block_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
if(dynamic_cast<ir::make_range_sta*>(v))
@@ -450,8 +450,8 @@ std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_tile_ty()) {
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
if(ty->is_block_ty()) {
return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
@@ -462,7 +462,7 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits();
int nbytes = nbits / 8;
int nbytes = std::max<int>(nbits / 8, 1);
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
}
}

View File

@@ -15,7 +15,7 @@ void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis();
ir::value *arg = red->get_operand(0);
auto in_shapes = arg->get_type()->get_tile_shapes();
auto in_shapes = arg->get_type()->get_block_shapes();
unsigned current = 0;
for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis)
@@ -29,8 +29,8 @@ void axes::update_graph_reshape(ir::instruction *i) {
// operands
ir::value *op = reshape->get_operand(0);
// shapes
auto op_shapes = op->get_type()->get_tile_shapes();
auto res_shapes = reshape->get_type()->get_tile_shapes();
auto op_shapes = op->get_type()->get_block_shapes();
auto res_shapes = reshape->get_type()->get_block_shapes();
// construct edges
unsigned current = 0;
bool is_skewed = false;
@@ -58,10 +58,10 @@ void axes::update_graph_trans(ir::instruction *i) {
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_tile_shapes();
auto shapes = broadcast->get_type()->get_block_shapes();
ir::value *op = broadcast->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_tile_shapes();
const auto& op_shapes = op_ty->get_block_shapes();
// add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d])
@@ -70,7 +70,7 @@ void axes::update_graph_broadcast(ir::instruction *i) {
void axes::update_graph_dot(ir::instruction *i) {
auto *dot = static_cast<ir::dot_inst*>(i);
auto shapes = dot->get_type()->get_tile_shapes();
auto shapes = dot->get_type()->get_block_shapes();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
@@ -83,7 +83,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_tile_ty())
if(!op->get_type()->is_block_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
@@ -96,7 +96,7 @@ void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
}
void axes::update_graph_no_edge(ir::instruction *i) {
if(!i->get_type()->is_tile_ty())
if(!i->get_type()->is_block_ty())
return;
auto rank = i->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)

View File

@@ -325,9 +325,9 @@ layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps,
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
if(!x->get_type()->is_block_ty())
return;
if(!y->get_type()->is_tile_ty())
if(!y->get_type()->is_block_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
@@ -364,7 +364,7 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
const auto& shapes = largest->get_type()->get_block_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
@@ -411,7 +411,7 @@ void layouts::run(ir::module &mod) {
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
// shape
auto shapes = arg->get_type()->get_tile_shapes();
auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
// create layout
@@ -425,8 +425,8 @@ void layouts::run(ir::module &mod) {
if(!in_layout || !out_layout)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes();
ir::type::block_shapes_t shape(in_shape.size());
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)

103
lib/codegen/pass.cc Normal file
View File

@@ -0,0 +1,103 @@
#include "triton/codegen/pass.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/driver/device.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
namespace triton {
namespace codegen {
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
// generate llvm code
llvm::LLVMContext ctx;
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
std::unique_ptr<codegen::target> target = dev->make_target();
bool cts_use_async = target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
// run passes
dce.run(ir);
//ir::print(ir, std::cout);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
dce.run(ir);
//ir::print(ir, std::cout);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
align.run(ir);
dce.run(ir);
if (target->is_gpu()) {
reassociate.run(ir);
cts.run(ir);
}
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);
barriers.run(ir);
// ir::print(ir, std::cout);
isel.visit(ir, *llvm);
mod = driver::module::create(dev, std::move(llvm));
ker = driver::kernel::create(&*mod, name.c_str());
shared_mem = allocation.allocated_size();
}
} // namespace codegen
} // namespace triton

View File

@@ -150,7 +150,7 @@ generator::generator(analysis::axes *a_axes,
void generator::visit_value(ir::value* v) {
if(!seen_.insert(v).second)
return;
if(v->get_type()->is_tile_ty()){
if(v->get_type()->is_block_ty()){
if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
auto double_buffer = layout->get_double_buffer();
// offset
@@ -384,7 +384,7 @@ void generator::visit_load_inst(ir::load_inst* x){
// compute vector width
size_t vec = 1;
if(op->get_type()->is_tile_ty()){
if(op->get_type()->is_block_ty()){
auto ord = ords_.at(op);
size_t aln = alignment_->get(op, ord[0]);
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
@@ -407,7 +407,10 @@ void generator::visit_load_inst(ir::load_inst* x){
PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2);
Instruction *then_term;
Instruction *else_term;
builder_->SetInsertPoint(_ret->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term);
dummy->removeFromParent();
builder_->SetInsertPoint(then_term);
Value* then_ret = load(ptr);
builder_->SetInsertPoint(else_term);
@@ -441,7 +444,7 @@ void generator::visit_store_inst(ir::store_inst * x){
ir::value *val_op = x->get_value_operand();
// vector size
size_t vec = 1;
if(val_op->get_type()->is_tile_ty()){
if(val_op->get_type()->is_block_ty()){
auto ord = ords_.at(x->get_pointer_operand());
size_t aln = alignment_->get(ptr_op, ord[0]);
size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
@@ -461,7 +464,10 @@ void generator::visit_store_inst(ir::store_inst * x){
if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
builder_->SetInsertPoint(no_op->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
dummy->removeFromParent();
builder_->SetInsertPoint(term);
store(val, ptr);
builder_->SetInsertPoint(no_op);
@@ -501,13 +507,15 @@ void generator::visit_splat_inst(ir::splat_inst* x) {
*/
void generator::visit_broadcast_inst(ir::broadcast_inst* x) {
ir::value* op = x->get_operand(0);
const auto& shape = op->get_type()->get_tile_shapes();
const auto& shape = op->get_type()->get_block_shapes();
for(auto out_idx: idxs_.at(x)){
indices_t in_idx = out_idx;
for(size_t k = 0; k < in_idx.size(); k++)
in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k];
vals_[x][out_idx] = vals_[op][in_idx];
}
// for(size_t i = 0; i < idxs_.at(x).size(); i++)
// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]];
}
/**
@@ -527,9 +535,9 @@ void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) {
}
/**
* \brief Code Generation for `get_num_program`
* \brief Code Generation for `get_num_programs`
*/
void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) {
void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) {
Module *module = builder_->GetInsertBlock()->getModule();
Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis());
vals_[np][{}] = ret;
@@ -621,7 +629,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
//TODO: clean-up
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
if(add->get_type()->is_tile_ty()){
if(add->get_type()->is_block_ty()){
ir::value* ptr = add->get_operand(0);
ir::value* val = add->get_operand(1);
ir::value* msk = add->get_operand(2);
@@ -706,9 +714,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
//TODO: clean-up
void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
// shapes
auto shape_c = C->get_type()->get_tile_shapes();
auto shape_a = A->get_type()->get_tile_shapes();
auto shape_b = B->get_type()->get_tile_shapes();
auto shape_c = C->get_type()->get_block_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
// order
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
@@ -877,7 +885,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
*/
//TODO: clean-up
void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
const auto& shapes = dot->get_type()->get_tile_shapes();
const auto& shapes = dot->get_type()->get_block_shapes();
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
@@ -887,8 +895,8 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
fcs[key].push_back(vals_[D][idx]);
};
auto shape_a = A->get_type()->get_tile_shapes();
auto shape_b = B->get_type()->get_tile_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
@@ -1059,9 +1067,9 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
* \brief Code Generation for FMA-based `dot` (FP32, FP64, Default)
*/
void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) {
auto shape_c = C->get_type()->get_tile_shapes();
auto shape_a = A->get_type()->get_tile_shapes();
auto shape_b = B->get_type()->get_tile_shapes();
auto shape_c = C->get_type()->get_block_shapes();
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline();
@@ -1161,7 +1169,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
ir::value *D = dot->get_operand(2);
Type *c_ty = cvt(D->get_type()->get_scalar_ty());
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
auto A_shapes = A->get_type()->get_block_shapes();
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
bool is_outer = NK == 1;
@@ -1236,7 +1244,10 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
// reduce across warps
Value *cond = icmp_eq(warp, i32(0));
Instruction *barrier = add_barrier();
builder_->SetInsertPoint(barrier->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
dummy->removeFromParent();
builder_->SetInsertPoint(term);
Value* ret = load(gep(base, thread));
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
@@ -1359,10 +1370,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
* \brief Code Generation for `select`
*/
void generator::visit_select_inst(ir::select_inst* x) {
for(indices_t idx: idxs_.at(x))
for(indices_t idx: idxs_.at(x)){
vals_[x][idx] = select(vals_[x->get_operand(0)][idx],
vals_[x->get_operand(1)][idx],
vals_[x->get_operand(2)][idx]);
}
}
/**
@@ -1370,7 +1382,7 @@ void generator::visit_select_inst(ir::select_inst* x) {
*/
void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
ir::value *op = rc->get_operand(0);
ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes();
ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes();
// pointer to temporary shared memory
Type *ty = cvt(rc->get_type()->get_scalar_ty());
// layout
@@ -1435,7 +1447,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
auto shapes = x->get_type()->get_tile_shapes();
auto shapes = x->get_type()->get_block_shapes();
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
std::map<std::pair<int, int>, Value*> tmp;
@@ -1520,7 +1532,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_tile_shapes();
auto shapes = cts->get_type()->get_block_shapes();
// store to shared
Value *current = nullptr;
@@ -1901,13 +1913,13 @@ void generator::visit_argument(ir::argument* arg) {
void generator::init_idx(ir::value *v) {
idxs_[v].clear();
if(!v->get_type()->is_tile_ty()){
if(!v->get_type()->is_block_ty()){
idxs_[v].push_back({});
return;
}
if(layouts_->get(v)->to_shared())
return;
const auto &shapes = v->get_type()->get_tile_shapes();
const auto &shapes = v->get_type()->get_block_shapes();
size_t rank = shapes.size();
std::vector<distributed_axis> axes(rank);
std::vector<int> ord(rank);

View File

@@ -37,7 +37,7 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
val_set_t ret;
for(ir::value* a: as){
if(!a->get_type()->is_tile_ty())
if(!a->get_type()->is_block_ty())
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
if(!a_layout)
@@ -45,7 +45,7 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
int a_start = alloc_->offset(a_layout);
int a_end = a_start + a_layout->get_size();
for(ir::value* b: bs){
if(!b->get_type()->is_tile_ty())
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
if(!b_layout)
@@ -80,7 +80,7 @@ void membar::transfer(ir::basic_block *block,
// Get shared memory reads
std::set<ir::value*> read;
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
[&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
// RAW (async)
val_set_t tmp;
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));

View File

@@ -58,7 +58,8 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, 0) + c -> dot(a, b, c)
// dot(a, b, c) + d -> dot(a, b, c + d)
// d + dot(a, b, c) -> dot(a, b, c + d)
auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && add->get_op() == ir::binary_op_t::FAdd) {
ir::value *lhs = add->get_operand(0);
@@ -131,10 +132,10 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
if(!x)
return false;
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_tile_shapes();
auto shapes = arg->get_type()->get_block_shapes();
if(shapes[x->get_axis()] == 1){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
x->replace_all_uses_with(new_red);
return true;
}

View File

@@ -23,6 +23,24 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
recursive_deps(u, block, ret);
}
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi->get_incoming_value(phi_idx);
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize(builder, op, phi_idx));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
@@ -55,21 +73,27 @@ void pipeline::run(ir::module &mod) {
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
false_value = masked_load->get_false_value_operand();
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
first_mask = builder.create_and(first_mask, remat_mask);
false_value = remat_false_value;
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
builder.set_insert_point(block->get_first_non_phi());

View File

@@ -40,7 +40,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
// handle retiling
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
auto shapes = op->get_type()->get_tile_shapes();
auto shapes = op->get_type()->get_block_shapes();
ir::value *old_arg = op->get_operand(0);
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
// retile(x + y) = retile(x) + retile(y)
@@ -54,19 +54,19 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
builder.set_insert_point(op);
new_lhs = builder.create_reshape(old_lhs, shapes);
new_rhs = builder.create_reshape(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
new_value = builder.create_add(new_lhs, new_rhs);
}
if(dynamic_cast<ir::broadcast_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_broadcast(old_lhs, shapes);
new_rhs = builder.create_broadcast(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
new_value = builder.create_add(new_lhs, new_rhs);
}
if(dynamic_cast<ir::splat_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_splat(old_lhs, shapes);
new_rhs = builder.create_splat(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
new_value = builder.create_add(new_lhs, new_rhs);
}
}
}
@@ -84,10 +84,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
ir::value *rlhs = bin_lhs->get_operand(1);
// (cst + x) + y -> cst + (x + y)
if(is_cst(llhs))
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs), name);
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
// (x + cst) + y -> cst + (x + y)
if(is_cst(rlhs))
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
}
// x + (y + z)
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
@@ -95,10 +95,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
ir::value *rrhs = bin_rhs->get_operand(1);
// x + (cst + y) -> cst + (x + y)
if(is_cst(lrhs))
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), name, cst);
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
// x + (y + cst) -> cst + (x + y)
if(is_cst(rrhs))
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
}
}
// extract constant and non-constant
@@ -166,7 +166,7 @@ void reassociate::run(ir::module &mod) {
ir::value* dyn = infos.at(op).dyn_ptr;
ir::value* cst = *sta->idx_begin();
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
auto shapes = rt->get_type()->get_tile_shapes();
auto shapes = rt->get_type()->get_block_shapes();
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
ir::value* broadcast = builder.create_broadcast(cst, shapes);
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
@@ -202,7 +202,7 @@ void reassociate::run(ir::module &mod) {
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *pz_dyn = builder.create_gep(dyn, {off});
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
pz->replace_all_uses_with(pz_sta);
infos[pz_sta].dyn_ptr = pz_dyn;
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
@@ -235,7 +235,8 @@ void reassociate::run(ir::module &mod) {
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
// re-add the offset
ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
phi_sta->set_name( phi->get_name() + "_sta");
phi->replace_all_uses_with(phi_sta);
// remove offset from pz
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
@@ -245,8 +246,8 @@ void reassociate::run(ir::module &mod) {
builder.set_insert_point(*it);
}
ir::value *_0 = builder.get_int32(0);
if(off->get_type()->is_tile_ty())
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
if(off->get_type()->is_block_ty())
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
ir::value *neg_off = builder.create_sub(_0, off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));

View File

@@ -11,38 +11,38 @@ namespace codegen{
namespace transform{
void reorder::run(ir::module& mod){
ir::builder &builder = mod.get_builder();
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
// ir::builder &builder = mod.get_builder();
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
ir::value* _ptr = ld->get_pointer_operand();
ir::value* _msk = ld->get_mask_operand();
ir::value* _val = ld->get_false_value_operand();
auto ptr = std::find(block->begin(), block->end(), _ptr);
auto msk = std::find(block->begin(), block->end(), _msk);
auto val = std::find(block->begin(), block->end(), _val);
if(ptr == block->end() || msk == block->end() || val == block->end())
continue;
auto it = std::find(block->begin(), block->end(), i);
int dist_ptr = std::distance(ptr, it);
int dist_msk = std::distance(msk, it);
int dist_val = std::distance(val, it);
if(dist_ptr < dist_msk && dist_ptr < dist_val)
builder.set_insert_point(++ptr);
if(dist_msk < dist_ptr && dist_msk < dist_val)
builder.set_insert_point(++msk);
if(dist_val < dist_ptr && dist_val < dist_msk)
builder.set_insert_point(++val);
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
to_replace.push_back(std::make_pair(ld, new_ld));
}
}
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction* i: block->get_inst_list()){
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
// ir::value* _ptr = ld->get_pointer_operand();
// ir::value* _msk = ld->get_mask_operand();
// ir::value* _val = ld->get_false_value_operand();
// auto ptr = std::find(block->begin(), block->end(), _ptr);
// auto msk = std::find(block->begin(), block->end(), _msk);
// auto val = std::find(block->begin(), block->end(), _val);
// if(ptr == block->end() || msk == block->end() || val == block->end())
// continue;
// auto it = std::find(block->begin(), block->end(), i);
// int dist_ptr = std::distance(ptr, it);
// int dist_msk = std::distance(msk, it);
// int dist_val = std::distance(val, it);
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
// builder.set_insert_point(++ptr);
// if(dist_msk < dist_ptr && dist_msk < dist_val)
// builder.set_insert_point(++msk);
// if(dist_val < dist_ptr && dist_val < dist_msk)
// builder.set_insert_point(++val);
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
// to_replace.push_back(std::make_pair(ld, new_ld));
// }
// }
for(auto& x: to_replace)
x.first->replace_all_uses_with(x.second);
// for(auto& x: to_replace)
// x.first->replace_all_uses_with(x.second);
}

View File

@@ -212,7 +212,7 @@ static std::map<int, int> vptx = {
{11020, 72}
};
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 64;
@@ -316,6 +316,7 @@ void cu_module::init_from_ptx(const std::string& ptx) {
// log = match.suffix();
// }
// std::cout << log << std::endl;
// std::cout << ptx_ << std::endl;
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
@@ -351,8 +352,9 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
oss << *ll_module;
oss.flush();
std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
if(cache_path.empty())
ptx_ = compile_llvm_module(std::move(ll_module), device);
if(cache_path.empty()){
ptx_ = compile_llvm_module(ll_module.get(), device);
}
else{
tools::mkdir(cache_path);
// update cache path to PTX file
@@ -370,7 +372,7 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
ptx_ = _ptx.str();
// compile and write-back if read empty
if(ptx_.empty()){
ptx_ = compile_llvm_module(std::move(ll_module), device);
ptx_ = compile_llvm_module(ll_module.get(), device);
std::ofstream ofs(cache_path);
ofs << ptx_;
}

View File

@@ -54,6 +54,18 @@ value *builder::get_int32(int32_t val)
value *builder::get_int64(int64_t val)
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
value *builder::get_float16(float val)
{ return constant_fp::get(type::get_half_ty(ctx_), val); }
value *builder::get_float32(float val)
{ return constant_fp::get(type::get_float_ty(ctx_), val); }
value *builder::get_range(int32_t _lo, int32_t _hi) {
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
constant_int* hi = static_cast<constant_int*>(get_int32(_hi));
return insert(make_range::create(lo, hi));
}
type *builder::get_void_ty()
{ return type::get_void_ty(ctx_); }
@@ -105,8 +117,8 @@ value *builder::create_ret_void() {
// cast instructions
//===----------------------------------------------------------------------===//
#define DEFINE_CAST_INSTR(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *src, type *dst_ty, std::string const &name){\
return create_cast(OPCODE, src, dst_ty, name);\
value *builder::create_ ## SUFFIX(value *src, type *dst_ty){\
return create_cast(OPCODE, src, dst_ty);\
}
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
@@ -117,20 +129,20 @@ DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI)
DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt)
DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc)
value* builder::create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name){
return insert(cast_inst::create(op, v, dst_ty), name);
value* builder::create_cast(cast_op_t op, value *v, type *dst_ty){
return insert(cast_inst::create(op, v, dst_ty));
}
value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name){
return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed), name);
value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed){
return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed));
}
//===----------------------------------------------------------------------===//
// phi instructions
//===----------------------------------------------------------------------===//
phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string &name){
return insert(phi_node::create(ty, num_reserved), name);
phi_node* builder::create_phi(type *ty, unsigned num_reserved){
return insert(phi_node::create(ty, num_reserved));
}
//===----------------------------------------------------------------------===//
@@ -138,8 +150,8 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string
//===----------------------------------------------------------------------===//
#define DEFINE_BINARY_FLOAT(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\
return insert(binary_operator::create(OPCODE, lhs, rhs));\
}
// Binary
@@ -156,22 +168,22 @@ DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
value *rhs, const std::string &name,
value *rhs,
bool has_nuw, bool has_nsw) {
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
binary_operator* result = insert(binary_operator::create(op, lhs, rhs));
if (has_nuw) result->set_has_no_unsigned_wrap();
if (has_nsw) result->set_has_no_signed_wrap();
return result;
}
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
value* builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw){\
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, has_nuw, has_nsw);\
value* builder::create_ ## SUFFIX(value *lhs, value *rhs, bool has_nuw, bool has_nsw){\
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, has_nuw, has_nsw);\
}\
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, false, false);\
}
@@ -196,21 +208,21 @@ DEFINE_BINARY_INT(xor, binary_op_t::Xor)
// getelementptr instructions
//===----------------------------------------------------------------------===//
value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list, const std::string &name){
return insert(getelementptr_inst::create(ptr, idx_list), name);
value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list){
return insert(getelementptr_inst::create(ptr, idx_list));
}
//===----------------------------------------------------------------------===//
// icmp instructions
//===----------------------------------------------------------------------===//
value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){
return insert(icmp_inst::create(pred, lhs, rhs), name);
value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs){
return insert(icmp_inst::create(pred, lhs, rhs));
}
#define DEFINE_ICMP_INSTR(SUFFIX, OPCODE)\
value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
return create_icmp(OPCODE, lhs, rhs, name);\
value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs){\
return create_icmp(OPCODE, lhs, rhs);\
}
// Signed
@@ -232,13 +244,13 @@ DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE)
// fcmp instructions
//===----------------------------------------------------------------------===//
value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name){
return insert(fcmp_inst::create(pred, lhs, rhs), name);
value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs){
return insert(fcmp_inst::create(pred, lhs, rhs));
}
#define DEFINE_FCMP_INSTR(SUFFIX, OPCODE)\
value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
return create_fcmp(OPCODE, lhs, rhs, name);\
value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs){\
return create_fcmp(OPCODE, lhs, rhs);\
}
// Ordered
@@ -255,102 +267,92 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
// load/store instructions
//===----------------------------------------------------------------------===//
value *builder::create_load(value *ptr, const std::string &name){
return insert(unmasked_load_inst::create(ptr, name));
// type *ty = ptr->get_type()->get_pointer_element_ty();
// value *mask = constant_int::get(get_int1_ty(), 1);
// value *undef = undef_value::get(ty);
// if(ptr->get_type()->is_tile_ty()){
// auto shapes = ptr->get_type()->get_tile_shapes();
// return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name));
// }
// return insert(masked_load_inst::create(ptr, mask, undef, name));
value *builder::create_load(value *ptr){
return insert(unmasked_load_inst::create(ptr));
}
value *builder::create_store(value *ptr, value *val, const std::string &name){
return insert(unmasked_store_inst::create(ptr, val, name));
value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, const std::string &name){
return insert(masked_load_inst::create(ptr, mask, false_value, name));
value *builder::create_masked_load(value *ptr, value *mask, value *false_value){
return insert(masked_load_inst::create(ptr, mask, false_value));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){
return insert(masked_store_inst::create(ptr, val, mask, name));
value *builder::create_masked_store(value *ptr, value *val, value *mask){
return insert(masked_store_inst::create(ptr, val, mask));
}
//===----------------------------------------------------------------------===//
// tile instructions
// block instructions
//===----------------------------------------------------------------------===//
value *builder::create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
return insert(reshape_inst::create(arg, shapes, name));
value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
return insert(reshape_inst::create(arg, shapes));
}
value *builder::create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
return insert(splat_inst::create(arg, shapes, name));
value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
return insert(splat_inst::create(arg, shapes));
}
value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
return insert(broadcast_inst::create(arg, shapes, name));
value *builder::create_broadcast(value *arg, const type::block_shapes_t &shapes) {
return insert(broadcast_inst::create(arg, shapes));
}
value *builder::create_downcast(value *arg, const std::string &name) {
return insert(downcast_inst::create(arg, name));
value *builder::create_downcast(value *arg) {
return insert(downcast_inst::create(arg));
}
//===----------------------------------------------------------------------===//
// built-in instructions
//===----------------------------------------------------------------------===//
value *builder::create_get_program_id(unsigned axis, const std::string &name) {
return insert(get_program_id_inst::create(ctx_, axis, name));
value *builder::create_get_program_id(unsigned axis) {
return insert(get_program_id_inst::create(ctx_, axis));
}
value *builder::create_get_num_program(unsigned axis, const std::string &name) {
return insert(get_num_program_inst::create(ctx_, axis, name));
value *builder::create_get_num_programs(unsigned axis) {
return insert(get_num_programs_inst::create(ctx_, axis));
}
value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){
return insert(atomic_cas_inst::create(ptr, cmp, val, name));
value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
return insert(atomic_cas_inst::create(ptr, cmp, val));
}
value *builder::create_atomic_exch(value *ptr, value *val, const std::string &name){
return insert(atomic_exch_inst::create(ptr, val, name));
value *builder::create_atomic_exch(value *ptr, value *val){
return insert(atomic_exch_inst::create(ptr, val));
}
value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){
return insert(atomic_add_inst::create(ptr, val, msk, name));
value *builder::create_atomic_add(value *ptr, value *val, value *msk){
return insert(atomic_add_inst::create(ptr, val, msk));
}
value *builder::create_exp(value *arg, const std::string &name){
return insert(exp_inst::create(arg, name));
value *builder::create_exp(value *arg){
return insert(exp_inst::create(arg));
}
value *builder::create_log(value *arg, const std::string &name){
return insert(log_inst::create(arg, name));
value *builder::create_log(value *arg){
return insert(log_inst::create(arg));
}
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
return insert(dot_inst::create_nn(A, B, C, name));
value *builder::create_dot(value *A, value *B, value *C) {
return insert(dot_inst::create_nn(A, B, C));
}
value *builder::create_trans(value *A, const std::vector<int>& perm, const std::string &name) {
return insert(trans_inst::create(A, perm, name));
value *builder::create_trans(value *A, const std::vector<int>& perm) {
return insert(trans_inst::create(A, perm));
}
value *builder::create_sqrt(value *A, const std::string &name) {
return insert(sqrt_inst::create(A, name));
value *builder::create_sqrt(value *A) {
return insert(sqrt_inst::create(A));
}
value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) {
return insert(reduce_inst::create(A, op, axis, name));
value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis) {
return insert(reduce_inst::create(A, op, axis));
}
value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
return insert(select_inst::create(pred, if_value, else_value, name));
value *builder::create_select(value *pred, value *if_value, value *else_value){
return insert(select_inst::create(pred, if_value, else_value));
}
//===----------------------------------------------------------------------===//
@@ -358,26 +360,28 @@ value *builder::create_select(value *pred, value *if_value, value *else_value, c
//===----------------------------------------------------------------------===//
value *builder::create_copy_to_shared(value *arg, const std::string &name) {
return insert(copy_to_shared_inst::create(arg, name));
value *builder::create_copy_to_shared(value *arg) {
return insert(copy_to_shared_inst::create(arg));
}
value *builder::create_copy_from_shared(value *arg, const std::string &name) {
return insert(copy_from_shared_inst::create(arg, name));
value *builder::create_copy_from_shared(value *arg) {
return insert(copy_from_shared_inst::create(arg));
}
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) {
return insert(masked_load_async_inst::create(ptr, mask, false_value, name));
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value) {
return insert(masked_load_async_inst::create(ptr, mask, false_value));
}
value *builder::create_barrier(const std::string &name) {
return insert(barrier_inst::create(ctx_, name));
return insert(barrier_inst::create(ctx_));
}
value *builder::create_async_wait(int N) {
return insert(async_wait_inst::create(ctx_, N));
}
}
}

View File

@@ -30,9 +30,9 @@ constant *constant::get_null_value(type *ty) {
constant *constant::get_all_ones_value(type *ty) {
if(ty->is_integer_ty())
return constant_int::get(ty, 0xFFFFFFFF);
return constant_int::get(ty, 0xFFFFFFFFFFFFFFFF);
if(ty->is_floating_point_ty())
return constant_fp::get(ty, 0xFFFFFFFF);
return constant_fp::get(ty, 0xFFFFFFFFFFFFFFFF);
throw std::runtime_error("Cannot create all ones value for that type!");
}

613
lib/ir/dispatch.cc Normal file
View File

@@ -0,0 +1,613 @@
#include "triton/ir/dispatch.h"
#include <iostream>
namespace triton{
namespace ir{
ir::value* throw_unreachable(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
"This is likely a bug on our side.");
return 0;
}
//===----------------------------------------------------------------------===//
// Programming Model
//===----------------------------------------------------------------------===//
ir::value *dispatch::program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
}
ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
}
//===----------------------------------------------------------------------===//
// Implicit Casting Utilities
//===----------------------------------------------------------------------===//
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
int a_rank = a_ty->get_integer_bitwidth();
int b_rank = b_ty->get_integer_bitwidth();
return a_rank > b_rank ? a_ty : b_ty;
}
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
context &ctx = a_ty->get_context();
// 1) if one operand is double, the other is implicitly
// converted to double
if(a_ty->is_double_ty() || b_ty->is_double_ty())
return type::get_double_ty(ctx);
// 2) if one operand is float, the other is implicitly
// converted to float
if(a_ty->is_float_ty() || b_ty->is_float_ty())
return type::get_float_ty(ctx);
// 3 ) if one operand is half, the other is implicitly
// converted to half
if(a_ty->is_half_ty() || b_ty->is_half_ty())
return type::get_half_ty(ctx);
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
throw_unreachable("augment_types");
// 4 ) both operands are integer and undergo
// integer promotion
return integer_promote(a_ty, b_ty);
}
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
void throw_incompatible_types(ir::type* type_a, ir::type* type_b) {
throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr());
}
void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
if(type_a->is_pointer_ty()){
if(!allow_ptr_a)
throw_incompatible_types(type_a, type_b);
// T* + U* with T != U
if(type_b->is_pointer_ty() && (type_a != type_b))
throw_incompatible_types(type_a, type_b);
// T* + float
if(type_b->is_floating_point_ty())
throw_incompatible_types(type_a, type_b);
}
}
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
bool arithmetic_check = true){
// implicit broadcasting
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
// implicit typecasting
ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty();
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty);
lhs = dispatch::cast(lhs, ret_sca_ty, builder);
rhs = dispatch::cast(rhs, ret_sca_ty, builder);
}
}
ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, true);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// offset + ptr
// ptr + offset
if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty())
std::swap(input, other);
if (input_scalar_ty->is_pointer_ty())
return builder->create_gep(input, {other});
// float + float
else if (input_scalar_ty->is_floating_point_ty())
return builder->create_fadd(input, other);
// int + int
else if (input_scalar_ty->is_integer_ty())
return builder->create_add(input, other);
return throw_unreachable("add");
}
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, false);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// ptr - offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(input, {dispatch::minus(other, builder)});
// float + float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(input, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(input, other);
return throw_unreachable("sub");
}
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(input, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(input, other);
return throw_unreachable("mul");
}
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float / int
if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty())
other = cast(other, input_scalar_ty, builder);
// int / float
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty())
input = cast(input, other_scalar_ty, builder);
// int / int (cast to float32)
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
input = cast(input, builder->get_float_ty(), builder);
other = cast(other, builder->get_float_ty(), builder);
}
// float / float (cast to highest exponent type)
else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){
if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width())
other = cast(other, input_scalar_ty, builder);
else
input = cast(input, other_scalar_ty, builder);
}
// unreachable
else
return throw_unreachable("div");
return builder->create_fdiv(input, other);
}
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
input = dispatch::cast(input, ret_ty, builder);
other = dispatch::cast(other, ret_ty, builder);
return builder->create_sdiv(input, other);
}
return throw_unreachable("floordiv");
}
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float % int
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(input, other);
// int % int
else if (scalar_ty->is_integer_ty())
return builder->create_srem(input, other);
return throw_unreachable("mod");
}
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
throw_incompatible_types(input_sca_ty, other_sca_ty);
// for some reason pytorch assigns the result of binary op to have the type of the lhs...
if(force_lhs_type){
if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder);
}
else{
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
input = dispatch::cast(input, other_sca_ty, builder);
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder);
}
}
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true);
return builder->create_and(input, other);
}
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true);
return builder->create_or(input, other);
}
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true);
return builder->create_xor(input, other);
}
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, false);
return builder->create_lshr(input, other);
}
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, false);
return builder->create_shl(input, other);
}
//===----------------------------------------------------------------------===//
// Unary Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::plus(ir::value *input, ir::builder *) {
return input;
}
ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty())
throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")");
ir::value *_0 = ir::constant::get_null_value(input_sca_ty);
return dispatch::sub(_0, input, builder);
}
ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty())
throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")");
ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty);
return dispatch::xor_(input, _1, builder);
}
//===----------------------------------------------------------------------===//
// Comparison Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(input, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(input, other);
return throw_unreachable("greater_than");
}
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(input, other);
return throw_unreachable("greater_equal");
}
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(input, other);
return throw_unreachable("less_than");
}
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(input, other);
return throw_unreachable("less_equal");
}
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
return throw_unreachable("equal");
}
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpONE(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpNE(input, other);
return throw_unreachable("equal");
}
//===----------------------------------------------------------------------===//
// Block Creation
//===----------------------------------------------------------------------===//
ir::value* dispatch::arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
}
ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
}
//===----------------------------------------------------------------------===//
// Shape Manipulation
//===----------------------------------------------------------------------===//
ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
unsigned numel = 1;
for(unsigned s: dst_shape) numel *= s;
if(input->get_type()->get_tile_num_elements() != numel)
throw semantic_error("cannot reshape block of different shape");
return builder->create_reshape(input, dst_shape);
}
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
return builder->create_broadcast(input, shape);
}
std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, true);
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
return builder->create_fp_to_ui(input, dst_ty);
else
return builder->create_fp_to_si(input, dst_ty);
}
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
if(src_sca_ty->is_bool_ty())
return builder->create_ui_to_fp(input, dst_ty);
else
return builder->create_si_to_fp(input, dst_ty);
}
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, builder->get_int64_ty(), builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
return throw_unreachable("cast");
}
//===----------------------------------------------------------------------===//
// Memory Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder* builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask){
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
}
if(other){
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
}
}
if (!mask && !other)
return builder->create_load(ptr);
if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`");
ir::type *elt_ty = ptr->get_type()->get_scalar_ty()->get_pointer_element_ty();
auto shape = ptr->get_type()->get_block_shapes();
if(!other){
other = ir::undef_value::get(elt_ty);
if(ptr->get_type()->is_block_ty())
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
}
return builder->create_masked_load(ptr, mask, other);
}
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty())
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
ir::type *ptr_ty = ptr->get_type();
val = dispatch::cast(val, ptr_ty->get_scalar_ty()->get_pointer_element_ty(), builder);
if (!mask)
return builder->create_store(ptr, val);
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
throw semantic_error("Mask must have boolean scalar type");
return builder->create_masked_store(ptr, val, mask);
}
ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){
return builder->create_atomic_cas(ptr, cmp, val);
}
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder){
return builder->create_atomic_exch(ptr, val);
}
//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
return builder->create_dot(lhs, rhs, _0);
}
//===----------------------------------------------------------------------===//
// Indexing
//===----------------------------------------------------------------------===//
ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){
condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
if(condition->get_type()->is_block_ty()){
x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
}
if(x->get_type()->get_scalar_ty() != y->get_type()->get_scalar_ty())
throw_incompatible_types(x->get_type()->get_scalar_ty(), y->get_type()->get_scalar_ty());
return builder->create_select(condition, x, y);
}
//===----------------------------------------------------------------------===//
// Reductions
//===----------------------------------------------------------------------===//
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
return throw_unreachable(name);
}
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
}
ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
}
ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
return builder->create_exp(x);
}
ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
return builder->create_log(x);
}
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}
//
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("multiple_of");
i->set_metadata(ir::metadata::multiple_of, value);
return i;
}
ir::value *dispatch::debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
}
}

View File

@@ -1,4 +1,5 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/context.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
@@ -30,9 +31,9 @@ void instruction::erase_from_parent() {
}
bool instruction::has_tile_result_or_op() {
bool result = get_type()->is_tile_ty();
bool result = get_type()->is_block_ty();
for(unsigned i = 0; i < get_num_operands(); i++)
result |= get_operand(i)->get_type()->is_tile_ty();
result |= get_operand(i)->get_type()->is_block_ty();
return result;
}
@@ -209,8 +210,8 @@ cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *
type* cmp_inst::make_cmp_result_type(type *ty){
type* int1_ty = type::get_int1_ty(ty->get_context());
if (tile_type* tile_ty = dynamic_cast<tile_type*>(ty))
return tile_type::get_same_shapes(int1_ty, tile_ty);
if (block_type* tile_ty = dynamic_cast<block_type*>(ty))
return block_type::get_same_shapes(int1_ty, tile_ty);
return int1_ty;
}
@@ -279,7 +280,7 @@ std::string cast_inst::repr_impl() const {
}
// TODO
bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) {
assert(arg->get_type()->is_tile_ty() == ty->is_tile_ty());
assert(arg->get_type()->is_block_ty() == ty->is_block_ty());
return true;
}
@@ -383,11 +384,11 @@ type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vec
unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space();
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space);
// Tile GEP
if(ty->is_tile_ty())
return tile_type::get_same_shapes(ptr_ty, ty);
if(ty->is_block_ty())
return block_type::get_same_shapes(ptr_ty, ty);
for(value *idx : idx_list)
if (idx->get_type()->is_tile_ty())
return tile_type::get_same_shapes(ptr_ty, ty);
if (idx->get_type()->is_block_ty())
return block_type::get_same_shapes(ptr_ty, ty);
// Scalar GEP
return ptr_ty;
}
@@ -440,8 +441,8 @@ load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::str
type *load_inst::get_pointee_type(type *ty) {
type *scalar_ty = ty->get_scalar_ty();
type *pointee_ty = scalar_ty->get_pointer_element_ty();
if(ty->is_tile_ty())
return tile_type::get_same_shapes(pointee_ty, ty);
if(ty->is_block_ty())
return block_type::get_same_shapes(pointee_ty, ty);
return pointee_ty;
}
@@ -531,14 +532,14 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
// retile_inst classes
//===----------------------------------------------------------------------===//
retile_inst::retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes,
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
const std::string &name, instruction *next)
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
: unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
// reshape
instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
const std::string &name, instruction *next) {
return new reshape_inst(arg, INST_RESHAPE, shapes, name, next);
}
@@ -546,14 +547,14 @@ instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
// splat
instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes,
instruction* splat_inst::create(value *arg, const type::block_shapes_t &shapes,
const std::string &name, instruction *next) {
return new splat_inst(arg, INST_SPLAT, shapes, name, next);
}
// broadcast
instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes,
instruction* broadcast_inst::create(value *arg, const type::block_shapes_t &shapes,
const std::string &name, instruction *next) {
return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next);
}
@@ -610,20 +611,20 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C,
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<int> perm) {
// get argument shapes
ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
ir::block_type::block_shapes_t arg_shapes = ty->get_block_shapes();
// permutate argument shapes
perm = init_perm(ty, perm);
ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
ir::block_type::block_shapes_t res_shapes = arg_shapes;
for(size_t i = 0; i < perm.size(); i++)
res_shapes[i] = arg_shapes[perm[i]];
// construct type
return tile_type::get(ty->get_scalar_ty(), res_shapes);
return block_type::get(ty->get_scalar_ty(), res_shapes);
}
std::vector<int> trans_inst::init_perm(ir::type* ty, const std::vector<int>& perm) {
if(!perm.empty())
return perm;
auto size = ty->get_tile_shapes().size();
auto size = ty->get_block_shapes().size();
std::vector<int> result;
result.push_back(size - 1);
for(size_t i = 0; i < size - 1; i++)
@@ -682,13 +683,13 @@ std::string reduce_inst::to_str(op_t op) {
}
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
ir::block_type::block_shapes_t shapes = arg->get_type()->get_block_shapes();
shapes.erase(shapes.begin() + axis);
type *scalar_ty = arg->get_type()->get_scalar_ty();
if(shapes.empty())
// shapes.push_back(1);
return scalar_ty;
return tile_type::get(scalar_ty, shapes);
return block_type::get(scalar_ty, shapes);
}
reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
@@ -733,13 +734,13 @@ instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std:
}
// get_num_program
get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
get_num_programs_inst::get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){
}
instruction* get_num_program_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
return new get_num_program_inst(type::get_int32_ty(ctx), axis, name, next);
instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
}
@@ -863,7 +864,7 @@ make_range *make_range::create(constant_int *first, constant_int *last) {
assert(first->get_type()->is_integer_ty());
assert(first->get_type() == last->get_type());
assert(((constant_int*)first)->get_value() == 0);
type *ty = tile_type::get(first->get_type(), {(unsigned)last->get_value()});
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value()});
return new make_range(ty, first, last);
}

View File

@@ -10,8 +10,8 @@ namespace triton{
namespace ir{
/* Module */
module::module(const std::string &name)
: name_(name), builder_(context_) {
module::module(const std::string &name, builder &builder)
: name_(name), builder_(builder) {
sealed_blocks_.insert(nullptr);
}
@@ -19,10 +19,6 @@ ir::builder& module::get_builder() {
return builder_;
}
ir::context& module::get_context() {
return context_;
}
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
values_[val_key_t{name, block}] = value;
auto it = metadatas_.find(name);
@@ -98,7 +94,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
ir::value *result;
bool is_const = const_.find(name) != const_.end();
auto &preds = block->get_predecessors();
ir::type *ty = get_scope().types.at(name);
ir::type *ty = get_scope().get_type(name);
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];

View File

@@ -14,7 +14,7 @@ namespace ir{
// attributes
type *type::get_scalar_ty() const {
if(is_tile_ty())
if(is_block_ty())
return get_tile_element_ty();
return const_cast<type*>(this);
}
@@ -28,7 +28,7 @@ unsigned type::get_primitive_size_in_bits() const {
case FP128TyID: return 128;
case PPC_FP128TyID: return 128;
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
case TileTyID: return ((tile_type*)(this))->get_bitwidth();
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
default: return 0;
}
}
@@ -37,19 +37,19 @@ unsigned type::get_integer_bitwidth() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
unsigned type::get_tile_bitwidth() const
{ return ((tile_type*)(this))->get_bitwidth(); }
{ return ((block_type*)(this))->get_bitwidth(); }
unsigned type::get_fp_mantissa_width() const {
id_t id = get_scalar_ty()->id_;
assert(is_floating_point_ty() && "Not a floating point type!");
if (id == HalfTyID) return 11;
if (id == FloatTyID) return 24;
if (id == HalfTyID) return 10;
if (id == FloatTyID) return 23;
if (id == DoubleTyID) return 53;
throw std::runtime_error("unreachable");
}
type* type::get_tile_element_ty() const {
assert(is_tile_ty());
assert(is_block_ty());
return contained_tys_[0];
}
@@ -62,31 +62,31 @@ type * type::get_pointer_element_ty() const {
type *ptr_ty = get_scalar_ty();
assert(ptr_ty->is_pointer_ty());
type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty();
if(is_tile_ty())
return tile_type::get_same_shapes(scalar_ty, (type*)this);
if(is_block_ty())
return block_type::get_same_shapes(scalar_ty, (type*)this);
return scalar_ty;
}
const type::tile_shapes_t &type::get_tile_shapes() const {
assert(is_tile_ty());
return ((tile_type*)this)->get_shapes();
type::block_shapes_t type::get_block_shapes() const {
assert(is_block_ty());
return ((block_type*)this)->get_shapes();
}
const size_t type::get_tile_rank() const {
return get_tile_shapes().size();
return get_block_shapes().size();
}
const size_t type::get_tile_ranks1() const {
int ret = 0;
for(int s: get_tile_shapes())
for(int s: get_block_shapes())
ret += s > 1;
return ret;
}
unsigned type::get_tile_num_elements() const {
const tile_shapes_t& shapes = get_tile_shapes();
const block_shapes_t& shapes = get_block_shapes();
unsigned result = 1;
for(auto shape: shapes)
result *= shape;
@@ -112,7 +112,7 @@ bool type::is_sized() const {
return true;
}
// tile types are sizes
if(is_tile_ty())
if(is_block_ty())
return get_scalar_ty()->is_sized();
return false;
}
@@ -160,12 +160,12 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
//===----------------------------------------------------------------------===//
type* composite_type::get_type_at_index(value *) const{
assert(is_tile_ty());
assert(is_block_ty());
return get_scalar_ty();
}
bool composite_type::index_valid(value *idx) const{
assert(is_tile_ty());
assert(is_block_ty());
return idx->get_type()->is_int_or_tileint_ty();
}
@@ -173,41 +173,41 @@ bool composite_type::index_valid(value *idx) const{
// tile_type class
//===----------------------------------------------------------------------===//
tile_type::tile_type(type *ty, const tile_shapes_t &shapes)
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
block_type::block_type(type *ty, const block_shapes_t &shapes)
: composite_type(ty->get_context(), BlockTyID), shapes_(shapes) {
contained_tys_.push_back(ty);
}
bool tile_type::is_valid_elt_ty(type *ty) {
bool block_type::is_valid_elt_ty(type *ty) {
return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
}
unsigned tile_type::get_num_elements() const {
unsigned block_type::get_num_elements() const {
unsigned res = 1;
for(auto shape: shapes_)
res *= shape;
return res;
}
unsigned tile_type::get_bitwidth() const {
unsigned block_type::get_bitwidth() const {
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
}
tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
assert(elt_ty && "Can't get a tile of <null> type!");
assert(shapes.size() && "Can't create a tile with empty shapes!");
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
// look-up
context_impl *impl = elt_ty->get_context().p_impl.get();
tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)];
block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
if(!entry)
entry = new tile_type(elt_ty, shapes);
entry = new block_type(elt_ty, shapes);
return entry;
}
tile_type* tile_type::get_same_shapes(type *ty, type *ref){
assert(ref->is_tile_ty());
return get(ty, ref->get_tile_shapes());
block_type* block_type::get_same_shapes(type *ty, type *ref){
assert(ref->is_block_ty());
return get(ty, ref->get_block_shapes());
}
//===----------------------------------------------------------------------===//

File diff suppressed because it is too large Load Diff

View File

@@ -1,847 +0,0 @@
#include "triton/lang/code_gen.h"
#include "triton/lang/evaluator.h"
#include "triton/lang/parser.h"
#include "triton/lang/token.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
// Helpers
void Generator::set_ret(ir::value* value) {
ret_ = value;
}
inline bool is_terminator(ir::value* x) {
return x && dynamic_cast<ir::terminator_inst*>(x);
}
// Expression
void Generator::VisitBinaryOp(BinaryOp* binary) {
Visit(binary->rhs_);
ir::value* rhs = ret_;
if(binary->op_ == '=')
return set_ret(assign_->GenExpr(binary->lhs_, rhs));
Visit(binary->lhs_);
ir::value* lhs = ret_;
// op info
auto type = binary->lhs_->Type()->ScalarType();
auto flt = type->IsFloat();
auto sign = !type->IsUnsigned();
// return
switch(binary->op_){
case Token::LOGICAL_AND: return set_ret(bld_->create_and(lhs, rhs));
case Token::LOGICAL_OR: return set_ret(bld_->create_or(lhs, rhs));
case '|': return set_ret(bld_->create_or(lhs, rhs));
case '&': return set_ret(bld_->create_and(lhs, rhs));
case '^': return set_ret(bld_->create_xor(lhs, rhs));
case Token::LEFT: return set_ret(bld_->create_shl(lhs, rhs));
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
case '.': return error_not_implemented(". binary operator not implemented");
case ',': return error_not_implemented(", binary operator not implemented");
case '@' : {
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
ir::type* ret_scal_ty = ret_ty->get_scalar_ty();
ir::value* _0;
if(ret_scal_ty->is_float_ty())
_0 = ir::constant_fp::get(ret_scal_ty, 0);
else
_0 = ir::constant_int::get(ret_scal_ty, 0);
_0 = bld_->create_splat(_0, ret_ty->get_tile_shapes());
return set_ret(bld_->create_dot(lhs, rhs, _0));
}
case Token::MASKED_DEREF: {
// TODO: FIXME
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
auto it = bld_->get_insert_block();
if(ret_ty->is_tile_ty())
false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
bld_->set_insert_point(it);
return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
}
case Token::ELLIPSIS: {
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
if(!clhs || !crhs)
error_not_implemented("ellipsis between variables not implemented");
return set_ret(bld_->insert(ir::make_range::create(clhs, crhs)));
}
case '+':
if(binary->lhs_->Type()->ScalarType()->ToPointer()){
return set_ret(bld_->create_gep(lhs, {rhs}));
}
else if(flt)
return set_ret(bld_->create_fadd(lhs, rhs));
else
return set_ret(bld_->create_add(lhs, rhs));
case '-':
if(binary->lhs_->Type()->ToPointer())
return set_ret(bld_->create_gep(lhs, {GenUnaryMinus(rhs)}));
else if(flt)
return set_ret(bld_->create_fsub(lhs, rhs));
else
return set_ret(bld_->create_sub(lhs, rhs));
case '*':
if(flt)
return set_ret(bld_->create_fmul(lhs, rhs));
else
return set_ret(bld_->create_mul(lhs, rhs));
case '/':
if(flt)
return set_ret(bld_->create_fdiv(lhs, rhs));
else if(sign)
return set_ret(bld_->create_sdiv(lhs, rhs));
else if(!sign)
return set_ret(bld_->create_udiv(lhs, rhs));
else
return should_not_happen("/ should not encounter type not in {float, int}");
case '%':
if(flt)
return set_ret(bld_->create_frem(lhs, rhs));
else if(sign)
return set_ret(bld_->create_srem(lhs, rhs));
else
return set_ret(bld_->create_urem(lhs, rhs));
case '<':
if(flt)
return set_ret(bld_->create_fcmpOLT(lhs, rhs));
else if(sign)
return set_ret(bld_->create_icmpSLT(lhs, rhs));
else if(!sign)
return set_ret(bld_->create_icmpULT(lhs, rhs));
else
return should_not_happen("< should not encounter type not in {float, int}");
case '>':
if(flt)
return set_ret(bld_->create_fcmpOGT(lhs, rhs));
else if(sign)
return set_ret(bld_->create_icmpSGT(lhs, rhs));
else if(!sign)
return set_ret(bld_->create_icmpUGT(lhs, rhs));
else
return should_not_happen("> should not encounter type not in {float, int}");
case Token::LE:
if(flt)
return set_ret(bld_->create_fcmpOLE(lhs, rhs));
else if(sign)
return set_ret(bld_->create_icmpSLE(lhs, rhs));
else if(!sign)
return set_ret(bld_->create_icmpULE(lhs, rhs));
else
return should_not_happen("<= should not encounter type not in {float, int}");
case Token::GE:
if(flt)
return set_ret(bld_->create_fcmpOGE(lhs, rhs));
else if(sign)
return set_ret(bld_->create_icmpSGE(lhs, rhs));
else if(!sign)
return set_ret(bld_->create_icmpUGE(lhs, rhs));
else
return should_not_happen(">= should not encounter type not in {float, int}");
case Token::EQ:
if(flt)
return set_ret(bld_->create_fcmpOEQ(lhs, rhs));
else
return set_ret(bld_->create_icmpEQ(lhs, rhs));
case Token::NE:
if(flt)
return set_ret(bld_->create_fcmpONE(lhs, rhs));
else
return set_ret(bld_->create_icmpNE(lhs, rhs));
default:
return error_not_implemented("binary operator " + std::to_string(binary->op_) + " not implemented");
}
should_not_happen("");
}
ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
using ir::reduce_inst;
switch(tag){
case Token::ADD: return is_float ? reduce_inst::FADD : reduce_inst::ADD;
case Token::SUB: return is_float ? reduce_inst::FSUB : reduce_inst::SUB;
case Token::MAX: return is_float ? reduce_inst::FMAX : reduce_inst::MAX;
case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
default: break;
}
error_not_implemented("reduction operator " + std::to_string(tag) + " not implemented");
return reduce_inst::op_t();
}
ir::value* Generator::GenUnaryMinus(ir::value* arg) {
ir::type *ty = arg->get_type();
ir::type *sca_ty = ty->get_scalar_ty();
ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty);
if(ty->is_tile_ty())
_0 = bld_->create_splat(_0, ty->get_tile_shapes());
if(sca_ty->is_floating_point_ty())
return bld_->create_fsub(_0, arg);
else
return bld_->create_sub(_0, arg);
}
ir::value* Generator::GenUnaryInc(UnaryOp* expr, bool is_postfix,
bool is_inc) {
Visit(expr->operand_);
ir::value* arg = ret_;
ir::value *_1 = nullptr;
ir::value *instr = nullptr;
if (arg->get_type()->is_floating_point_ty()) {
_1 = ir::constant_fp::get(arg->get_type(), 1.0);
if (is_inc)
instr = bld_->create_fadd(arg, _1);
else
instr = bld_->create_fsub(arg, _1);
} else if (arg->get_type()->is_integer_ty()) {
_1 = ir::constant_int::get(arg->get_type(), 1);
if (is_inc)
instr = bld_->create_add(arg, _1);
else
instr = bld_->create_sub(arg, _1);
} else if (arg->get_type()->is_pointer_ty()) {
ir::type *ty = ir::type::get_int64_ty(*ctx_);
_1 = ir::constant_int::get(ty, 1);
if (is_inc)
instr = bld_->create_gep(arg, {_1});
else {
ir::value *neg_1 = ir::constant_int::get(ty, -1);
instr = bld_->create_gep(arg, {neg_1});
}
} else
error_not_implemented("data type not supported for unary inc");
mod_->set_value(arg->get_name(), instr);
if (is_postfix)
return arg;
else
return instr;
}
void Generator::VisitUnaryOp(UnaryOp* unary) {
// recursion
Visit(unary->operand_);
ir::value* arg = ret_;
ir::type *arg_ty = arg->get_type();
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
// return
switch (unary->op_) {
case Token::PREFIX_INC: return set_ret(GenUnaryInc(unary, false, true));
case Token::PREFIX_DEC: return set_ret(GenUnaryInc(unary, false, false));
case Token::POSTFIX_INC: return set_ret(GenUnaryInc(unary, true, true));
case Token::POSTFIX_DEC: return set_ret(GenUnaryInc(unary, true, false));
case Token::ADDR: return error_not_implemented("unary & not implemented");
case Token::DEREF: return set_ret(bld_->create_load(arg));
case Token::PLUS: return error_not_implemented("unary + not implemented");
case Token::MINUS: return set_ret(GenUnaryMinus(arg));
case '~': return error_not_implemented("unary ~ not implemented");
case '!': return error_not_implemented("unary ! not implemented");
case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast
case Token::LOG: return set_ret(bld_->create_log(arg));
case Token::SQRTF: return set_ret(bld_->create_sqrt(arg));
case Token::REDUCE: {
int ax, tag;
UnaryOp::decodeRed(unary->info_, ax, tag);
bool is_float = arg_scal_ty->is_floating_point_ty();
ir::reduce_inst::op_t op = reduce_op(tag, is_float);
return set_ret(bld_->create_reduce(arg, op, ax));
}
default: error_not_implemented("unary " + std::to_string(unary->op_) + " not implemented");
}
return should_not_happen("");
}
void Generator::VisitTransOp(TransOp *trans) {
Visit(trans->operand_);
ir::value* arg = ret_;
return set_ret(bld_->create_trans(arg, trans->getPerm()));
}
void Generator::VisitConditionalOp(ConditionalOp* condOp) {
auto &instructions = bld_->get_insert_block()->get_inst_list();
VisitExpr(condOp->cond_);
ir::value* true_cond = ret_;
ir::instruction* start = instructions.back();
VisitExpr(condOp->exprTrue_);
ir::value* true_val = ret_;
VisitExpr(condOp->exprFalse_);
ir::value* false_val = ret_;
auto begin = std::find(instructions.begin(), instructions.end(), start);
bool is_in_true_cond = true;
for(auto it = begin; it != instructions.end(); it++){
ir::instruction* instr = *it;
// we mask load with `cond` when used to compute true_value
// we mask load with `!cond` when used to compute false_value
if(auto ld = dynamic_cast<ir::unmasked_load_inst*>(instr)){
bld_->set_insert_point(ld);
ir::type* ty = ld->get_type();
ir::value* cond = is_in_true_cond ? true_cond : true_cond;
ir::value* ptr = ld->get_pointer_operand();
ir::value* else_val = ir::undef_value::get(ty);
ir::value* masked_ld = bld_->create_masked_load(ptr, cond, else_val);
ld->replace_all_uses_with(masked_ld);
ld->erase_from_parent();
if(true_val == ld)
true_val = masked_ld;
if(false_val == ld)
false_val = masked_ld;
it = std::find(instructions.begin(), instructions.end(), masked_ld);
}
if(instr == true_val)
is_in_true_cond = false;
}
bld_->set_insert_point(bld_->get_insert_block());
return set_ret(bld_->create_select(true_cond, true_val, false_val));
// VisitExpr(condOp->cond_);
// ir::value* cond = ret_;
// VisitExpr(condOp->exprTrue_);
// ir::value* true_val = ret_;
// VisitExpr(condOp->exprFalse_);
// ir::value* false_val = ret_;
// if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
// if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
// false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
// ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
// ld->replace_all_uses_with(new_ld);
// ld->erase_from_parent();
// return set_ret(new_ld);
// }
// return set_ret(bld_->create_select(cond, true_val, false_val));
}
void Generator::VisitFuncCall(FuncCall* funcCall) {
std::string name = funcCall->Name();
if(name == "get_program_id"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
return set_ret(bld_->create_get_program_id(axis->get_value()));
else
return should_not_happen("get_program_id argument should be constant");
}
if(name == "get_num_programs"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
return set_ret(bld_->create_get_num_program(axis->get_value()));
else
return should_not_happen("get_num_programs argument should be constant");
}
if(name == "atomic_cas"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* cmp = ret_;
VisitExpr(funcCall->Args()->at(2));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_cas(ptr, cmp, val));
}
if(name == "atomic_xchg"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_exch(ptr, val));
}
if(name.substr(0, 10) == "atomic_add"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* val = ret_;
VisitExpr(funcCall->Args()->at(2));
ir::value* msk = ret_;
return set_ret(bld_->create_atomic_add(ptr, val, msk));
}
if(name == "calloc"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
ir::constant_int *size = dynamic_cast<ir::constant_int*>(ret);
assert(size);
ir::alloc_const* alloc = new ir::alloc_const(bld_->get_int8_ty(), size);
mod_->add_alloc(alloc);
return set_ret(alloc);
}
//TODO: integrate this into conditionalop
if(name == "select"){
VisitExpr(funcCall->Args()->at(0));
ir::value* cond = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* true_val = ret_;
VisitExpr(funcCall->Args()->at(2));
ir::value* false_val = ret_;
return set_ret(bld_->create_select(cond, true_val, false_val));
}
if(name == "__debug_barrier"){
bld_->create_barrier();
return;
}
return error_not_implemented("function calls not implemented");
}
void Generator::VisitObject(Object* obj) {
return set_ret(mod_->get_value(obj->Name()));
}
void Generator::VisitEnumerator(Enumerator* enumer) {
return error_not_implemented("enumeration not implemented");
}
void Generator::VisitIdentifier(Identifier* ident) {
return set_ret(mod_->get_value(ident->Name()));
}
void Generator::VisitConstant(Constant* cons) {
Type* ctype = cons->Type();
ir::type *type = GenIRType(cons->Type(), *ctx_);
if(ctype->IsInteger())
return set_ret(ir::constant_int::get(type, cons->IVal()));
if(ctype->IsFloat() && ctype->IsReal())
return set_ret(ir::constant_fp::get(type, cons->FVal()));
return error_not_implemented("constant of type not in {int, float} not implemented");
}
void Generator::VisitTempVar(TempVar* tempVar) {
return error_not_implemented("temporary variable not implemented");
}
// Statement
// TODO: int x = x; crashes
void Generator::VisitDeclaration(Declaration* decl) {
auto obj = decl->obj_;
// initialize to undef
ir::type* ty = GenIRType(obj->Type(), *ctx_);
ir::value* val = ir::undef_value::get(ty);
//obj->GetAttrList()
// compute initializers
std::vector<ir::value*> inits;
for (const Initializer& init: decl->Inits()) {
VisitExpr(init.expr_);
ir::value *val = ret_;
for(const auto& attr: obj->GetAttrList())
SetIRMetadata(attr, val);
inits.push_back(val);
}
// initialize declaration
ir::type::id_t id = ty->get_type_id();
if(id == ir::type::StructTyID)
error_not_implemented("struct not implemented");
if(inits.size() > 1)
error_not_implemented("initializer list > 1 element not implemented");
if(inits.size() > 0)
val = inits[0];
assert(val->get_type() == ty);
// update scope symbols table
const std::string &name = obj->Name();
if(!name.empty()){
mod_->set_value(name, val);
mod_->get_scope().types[name] = ty;
}
}
void Generator::VisitEmptyStmt(EmptyStmt*) {
return;
}
void Generator::VisitIfStmt(IfStmt* ifStmt) {
ir::function *fn = bld_->get_insert_block()->get_parent();
Stmt *then_ = ifStmt->then_;
Stmt *else_ = ifStmt->else_;
VisitExpr(ifStmt->cond_);
ir::value* cond = ret_;
ir::basic_block *then_bb = ir::basic_block::create(*ctx_, "then", fn);
ir::basic_block *else_bb = else_? ir::basic_block::create(*ctx_, "else", fn) : nullptr;
ir::basic_block *endif_bb = ir::basic_block::create(*ctx_, "endif", fn);
// seal blocks
mod_->seal_block(then_bb);
if(else_bb)
mod_->seal_block(else_bb);
// branches
if(else_)
bld_->create_cond_br(cond, then_bb, else_bb);
else
bld_->create_cond_br(cond, then_bb, endif_bb);
// then
bld_->set_insert_point(then_bb);
VisitStmt(then_);
if(!is_terminator(ret_))
bld_->create_br(endif_bb);
// else
if(else_){
bld_->set_insert_point(else_bb);
VisitStmt(else_);
if(!is_terminator(ret_))
bld_->create_br(endif_bb);
}
// endif
mod_->seal_block(endif_bb);
bld_->set_insert_point(endif_bb);
}
void Generator::VisitForStmt(ForStmt *forStmt) {
Stmt *init_ = forStmt->init_;
Expr *cond_ = forStmt->cond_;
Expr *step_ = forStmt->step_;
Stmt *body_ = forStmt->body_;
ir::basic_block *current_bb = bld_->get_insert_block();
ir::function *fn = current_bb->get_parent();
ir::basic_block *loop_bb = ir::basic_block::create(*ctx_, "loop", fn);
ir::basic_block *next_bb = ir::basic_block::create(*ctx_, "postloop", fn);
mod_->set_continue_fn([&](){
if(step_)
VisitExpr(step_);
VisitExpr(cond_);
ir::value *cond = ret_;
return bld_->create_cond_br(cond, loop_bb, next_bb);
});
if(init_)
VisitStmt(init_);
VisitExpr(cond_);
ir::value *cond = ret_;
bld_->create_cond_br(cond, loop_bb, next_bb);
// bld_->create_br(loop_bb);
bld_->set_insert_point(loop_bb);
if(body_)
VisitStmt(body_);
if(!is_terminator(ret_))
mod_->get_continue_fn()();
ir::basic_block *stop_bb = bld_->get_insert_block();
mod_->seal_block(stop_bb);
mod_->seal_block(loop_bb);
mod_->seal_block(bld_->get_insert_block());
mod_->seal_block(next_bb);
bld_->set_insert_point(next_bb);
}
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
return error_not_implemented("jump not implemented");
}
void Generator::VisitReturnStmt(ReturnStmt* returnStmt) {
ir::value *ret;
if(returnStmt->expr_)
return error_not_implemented("non-void return not implemented");
else
ret = bld_->create_ret_void();
return set_ret(ret);
}
void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
return error_not_implemented("label not implemented");
}
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
if (compoundStmt->scope_)
pushScope();
for (auto stmt: compoundStmt->stmts_)
Visit(stmt);
if(compoundStmt->scope_)
popScope();
}
void Generator::VisitFuncDef(FuncDef* funcDef) {
Stmt *body = funcDef->body_;
const std::string& name = funcDef->Name();
FuncType* type = funcDef->FuncType();
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
if(!prototype)
should_not_happen("could not parse function prototype");
ir::function *fn = mod_->get_or_insert_function(name, prototype);
std::vector<ir::argument*> args = fn->args();
size_t i = 0;
for(Object* obj: type->Params()){
std::string name = obj->Name();
args[i]->set_name(name);
if(obj->Type()->ToPointer())
fn->add_attr(i + 1, ir::attribute(ir::aligned, 16));
for(ASTNode::Attr attr: obj->GetAttrList()){
fn->add_attr(i + 1, GenIRAttr(attr));
}
if(obj->IsRestrictQualified())
fn->add_attr(i, ir::attribute(ir::noalias));
mod_->set_value(name, nullptr, args[i]);
mod_->get_scope().types[name] = args[i]->get_type();
i++;
}
ir::basic_block *entry = ir::basic_block::create(mod_->get_context(), "entry", fn);
mod_->seal_block(entry);
mod_->get_builder().set_insert_point(entry);
VisitStmt(body);
if(!dynamic_cast<ir::return_inst*>(ret_))
mod_->get_builder().create_ret_void();
}
void Generator::VisitTranslationUnit(TranslationUnit* unit) {
pushScope();
for (auto extDecl: unit->ExtDecls())
Visit(extDecl);
popScope();
}
void Generator::Gen(ir::module *mod) {
mod_ = mod;
ctx_ = &mod_->get_context();
bld_ = &mod_->get_builder();
assign_ = new LValAssigner(this);
VisitTranslationUnit(parser_->Unit());
delete assign_;
assign_ = nullptr;
}
ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
if(src->get_type() == dst_ty)
return src;
if(dst_ty->is_tile_ty()) {
ir::type *src_ty = src->get_type();
auto dst_shapes = dst_ty->get_tile_shapes();
if(!src_ty->is_tile_ty())
return bld_->create_splat(src, dst_shapes);
auto src_shapes = src_ty->get_tile_shapes();
if(src_shapes.size() != dst_shapes.size()){
unsigned src_numel = 1;
for(unsigned s: src_shapes)
src_numel *= s;
unsigned dst_numel = 1;
for(unsigned s: dst_shapes)
dst_numel *= s;
if(src_numel == dst_numel)
return bld_->create_reshape(src, dst_shapes);
else {
auto padded_shapes = src_shapes;
while(padded_shapes.size() != dst_shapes.size())
padded_shapes.insert(padded_shapes.begin(), 1);
// check that broadcast is legal
for(size_t d = 0; d < padded_shapes.size(); d++){
if(dst_shapes[d] != padded_shapes[d] &&
padded_shapes[d] != 1)
should_not_happen("broadcast should not happen between these shapes");
}
// pad and broadcast
ir::value *padded = bld_->create_reshape(src, padded_shapes);
return bld_->create_broadcast(padded, dst_shapes);
}
}
else{
return bld_->create_broadcast(src, dst_shapes);
}
}
else if(src->get_type()->is_tile_ty() && src->get_type()->get_tile_num_elements() == 1){
return bld_->create_downcast(src);
}
return src;
}
ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
if(src->get_type()->is_tile_ty())
dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type());
bool src_signed = false;
bool dst_signed = false;
if(src_scalar_ty == dst_scalar_ty)
return src;
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_bool_ty())
return bld_->create_icmpNE(bld_->create_ptr_to_int(src, ir::tile_type::get_same_shapes(bld_->get_int64_ty(), src->get_type())),
bld_->create_splat(bld_->get_int64(0), src->get_type()->get_tile_shapes()));
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_si_to_fp(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_ui_to_fp(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
return bld_->create_fp_to_si(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
return bld_->create_fp_to_ui(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_ext(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_trunc(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
src_scalar_ty->get_integer_bitwidth())
return bld_->create_int_cast(src, dst_ty, dst_signed);
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty())
return bld_->create_cast(ir::BitCast, src, dst_ty);
else{
error_not_implemented("cast type not implemented");
return nullptr;
}
}
ir::value* Generator::GenSemCastOp(ir::value* src, ir::type* dst_ty) {
return GenNumcastOp(GenBroadcastOp(src, dst_ty), dst_ty);
}
ir::value* Generator::GenBitCastOp(ir::value* src, ir::type* dst_ty) {
return bld_->create_cast(ir::BitCast, GenBroadcastOp(src, dst_ty), dst_ty);
}
// Triton-IR Attr
ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
if(attr.kind == ASTNode::Attr::MULTIPLEOF) {
VisitExpr(attr.vals[0]);
auto cst = dynamic_cast<ir::constant_int*>(ret_);
if(!cst) should_not_happen("multipleof only works on constants");
return ir::attribute(ir::multiple_of, cst->get_value());
}
if(attr.kind == ASTNode::Attr::ALIGNED) {
VisitExpr(attr.vals[0]);
auto cst = dynamic_cast<ir::constant_int*>(ret_);
return ir::attribute(ir::aligned, cst->get_value());
}
if(attr.kind == ASTNode::Attr::NOALIAS)
return ir::attribute(ir::noalias);
if(attr.kind == ASTNode::Attr::READONLY)
return ir::attribute(ir::readonly);
if(attr.kind == ASTNode::Attr::WRITEONLY)
return ir::attribute(ir::writeonly);
if(attr.kind == ASTNode::Attr::RETUNE)
return ir::attribute(ir::retune);
error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented");
return ir::attribute(ir::not_implemented);
}
void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) {
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(attr.kind == ASTNode::Attr::MULTIPLEOF)
i->set_metadata(ir::metadata::multiple_of, GenIRAttr(attr).get_value());
}
// Triton-IR Types
ir::type* Generator::GenIRType(::Type* type, ir::context& ctx) {
if(auto T = type->ToVoid())
return ir::type::get_void_ty(ctx);
if(auto T = type->ToArithm())
return GenIRArithmType(T, ctx);
if(auto T = type->ToArray())
return GenIRArrayType(T, ctx);
if(auto T = type->ToTile())
return GenIRTileType(T, ctx);
if(auto T = type->ToFunc())
return GenIRFuncType(T, ctx);
if(auto T = type->ToPointer())
return GenIRPointerType(T, ctx);
if(auto T = type->ToStruct())
return GenIRStructType(T, ctx);
assert(false);
return nullptr;
}
ir::type* Generator::GenIRArithmType(ArithmType* type, ir::context& ctx) {
int tag = type->Tag();
if(tag & T_BOOL)
return ir::type::get_int1_ty(ctx);
if(tag & T_CHAR)
return ir::type::get_int8_ty(ctx);
if(tag & T_SHORT)
return ir::type::get_int16_ty(ctx);
if(tag & T_INT)
return ir::type::get_int32_ty(ctx);
if(tag & T_LONG)
return ir::type::get_int64_ty(ctx);
if(tag & T_HALF)
return ir::type::get_half_ty(ctx);
if(tag & T_FLOAT)
return ir::type::get_float_ty(ctx);
if(tag & T_DOUBLE)
return ir::type::get_double_ty(ctx);
assert(false);
return nullptr;
}
ir::type* Generator::GenIRArrayType(ArrayType* type, ir::context& ctx) {
assert(false);
return nullptr;
}
ir::type* Generator::GenIRTileType(TileType* type, ir::context& ctx) {
ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx);
auto _shape = type->Shape();
ir::tile_type::tile_shapes_t shape;
for(int s: _shape)
shape.push_back(static_cast<unsigned>(s));
return ir::tile_type::get(ele_ty, shape);
}
ir::type* Generator::GenIRFuncType(FuncType* type, ir::context& ctx) {
ir::type* ret_ty = GenIRType(type->Derived().GetPtr(), ctx);
std::vector<ir::type*> param_tys;
for(Object* obj: type->Params())
param_tys.push_back(GenIRType(obj->Type(), ctx));
return ir::function_type::get(ret_ty, param_tys);
}
ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx);
unsigned addr_space = 1;
if(type->Derived().IsConstantQualified())
addr_space = 4;
return ir::pointer_type::get(ele_ty, addr_space);
}
ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) {
error_not_implemented("struct not implemented");
return nullptr;
}
void Generator::AllocObjects(Scope* scope, const FuncDef::ParamList& params) {
return error_not_implemented("alloc not implemented");
}
// SSA
void Generator::pushScope() {
mod_->add_new_scope();
}
void Generator::popScope() {
mod_->pop_scope();
}
// LValue Generator
void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
if(binary->op_ != Token::MASKED_DEREF)
error_not_implemented("lvalue for binary non masked-deref not implemented");
gen_->VisitExpr(binary->lhs_);
ir::value* mask = gen_->ret_;
gen_->VisitExpr(binary->rhs_);
ir::value* addr = gen_->ret_;
ret_ = gen_->bld_->create_masked_store(addr, rhs_, mask);
}
void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
if(unary->op_ != Token::DEREF)
error_not_implemented("lvalue for unary non deref not implemented");
gen_->VisitExpr(unary->operand_);
ir::value* addr = gen_->ret_;
ret_ = gen_->bld_->create_store(addr, rhs_);
}
void LValAssigner::VisitObject(Object* obj) {
std::string name = obj->Name();
gen_->mod_->set_value(name, rhs_);
ret_ = rhs_;
}
void LValAssigner::VisitIdentifier(Identifier* ident) {
std::string name = ident->Name();
gen_->mod_->set_value(name, rhs_);
ret_ = rhs_;
}

View File

@@ -1,884 +0,0 @@
#include "triton/lang/cpp.h"
#include "triton/lang/evaluator.h"
#include "triton/lang/parser.h"
#include <ctime>
#include <fcntl.h>
#include <unistd.h>
#include <unordered_map>
using DirectiveMap = std::unordered_map<std::string, int>;
static const DirectiveMap directiveMap {
{"if", Token::PP_IF},
{"ifdef", Token::PP_IFDEF},
{"ifndef", Token::PP_IFNDEF},
{"elif", Token::PP_ELIF},
{"else", Token::PP_ELSE},
{"endif", Token::PP_ENDIF},
{"include", Token::PP_INCLUDE},
// Non-standard GNU extension
{"include_next", Token::PP_INCLUDE},
{"define", Token::PP_DEFINE},
{"undef", Token::PP_UNDEF},
{"line", Token::PP_LINE},
{"error", Token::PP_ERROR},
{"pragma", Token::PP_PRAGMA}
};
/*
* params:
* is: input token sequence
* os: output token sequence
*/
void Preprocessor::Expand(TokenSequence& os, TokenSequence is, bool inCond) {
Macro* macro = nullptr;
int direcitve;
while (!is.Empty()) {
UpdateFirstTokenLine(is);
auto tok = is.Peek();
const auto& name = tok->str_;
if ((direcitve = GetDirective(is)) != Token::INVALID) {
ParseDirective(os, is, direcitve);
} else if (!inCond && !NeedExpand()) {
// Discards the token
is.Next();
} else if (inCond && name == "defined") {
is.Next();
os.InsertBack(EvalDefOp(is));
} else if (tok->hs_ && tok->hs_->find(name) != tok->hs_->end()) {
os.InsertBack(is.Next());
} else if ((macro = FindMacro(name))) {
is.Next();
if (name == "__FILE__") {
HandleTheFileMacro(os, tok);
} else if (name == "__LINE__") {
HandleTheLineMacro(os, tok);
} else if (macro->ObjLike()) {
// Make a copy, as subst will change repSeq
auto repSeq = macro->RepSeq(tok->loc_.filename_, tok->loc_.line_);
TokenList tokList;
TokenSequence repSeqSubsted(&tokList);
ParamMap paramMap;
// TODO(wgtdkp): hideset is not right
// Make a copy of hideset
// HS U {name}
auto hs = tok->hs_ ? *tok->hs_: HideSet();
hs.insert(name);
Subst(repSeqSubsted, repSeq, tok->ws_, hs, paramMap);
is.InsertFront(repSeqSubsted);
} else if (is.Try('(')) {
ParamMap paramMap;
auto rpar = ParseActualParam(is, macro, paramMap);
auto repSeq = macro->RepSeq(tok->loc_.filename_, tok->loc_.line_);
TokenList tokList;
TokenSequence repSeqSubsted(&tokList);
// (HS ^ HS') U {name}
// Use HS' U {name} directly
auto hs = rpar->hs_ ? *rpar->hs_: HideSet();
hs.insert(name);
Subst(repSeqSubsted, repSeq, tok->ws_, hs, paramMap);
is.InsertFront(repSeqSubsted);
} else {
os.InsertBack(tok);
}
} else {
os.InsertBack(is.Next());
}
}
}
static bool FindActualParam(TokenSequence& ap,
ParamMap& params,
const std::string& fp) {
auto res = params.find(fp);
if (res == params.end()) {
return false;
}
ap.Copy(res->second);
return true;
}
void Preprocessor::Subst(TokenSequence& os,
TokenSequence is,
bool leadingWS,
const HideSet& hs,
ParamMap& params) {
TokenSequence ap;
while (!is.Empty()) {
if (is.Test('#') && FindActualParam(ap, params, is.Peek2()->str_)) {
is.Next(); is.Next();
auto tok = Stringize(ap);
os.InsertBack(tok);
} else if (is.Test(Token::DSHARP) &&
FindActualParam(ap, params, is.Peek2()->str_)) {
is.Next(); is.Next();
if (!ap.Empty())
Glue(os, ap);
} else if (is.Test(Token::DSHARP)) {
is.Next();
auto tok = is.Next();
Glue(os, tok);
} else if (is.Peek2()->tag_ == Token::DSHARP &&
FindActualParam(ap, params, is.Peek()->str_)) {
is.Next();
if (ap.Empty()) {
is.Next();
if (FindActualParam(ap, params, is.Peek()->str_)) {
is.Next();
os.InsertBack(ap);
}
} else {
os.InsertBack(ap);
}
} else if (FindActualParam(ap, params, is.Peek()->str_)) {
auto tok = is.Next();
const_cast<Token*>(ap.Peek())->ws_ = tok->ws_;
Expand(os, ap);
} else {
os.InsertBack(is.Peek());
is.Next();
}
}
os.FinalizeSubst(leadingWS, hs);
}
void Preprocessor::Glue(TokenSequence& os, const Token* tok) {
TokenList tokList {tok};
TokenSequence is(&tokList);
Glue(os, is);
}
void Preprocessor::Glue(TokenSequence& os, TokenSequence is) {
auto lhs = os.Back();
auto rhs = is.Peek();
auto str = new std::string(lhs->str_ + rhs->str_);
TokenSequence ts;
Scanner scanner(str, lhs->loc_);
scanner.Tokenize(ts);
is.Next();
if (ts.Empty()) {
// TODO(wgtdkp):
// No new Token generated
// How to handle it???
} else {
os.PopBack();
auto newTok = const_cast<Token*>(ts.Next());
newTok->ws_ = lhs->ws_;
newTok->hs_ = lhs->hs_;
os.InsertBack(newTok);
}
if (!ts.Empty()) {
Error(lhs, "macro expansion failed: cannot concatenate");
}
os.InsertBack(is);
}
/*
* This is For the '#' operator in func-like macro
*/
const Token* Preprocessor::Stringize(TokenSequence is) {
std::string str = "\"";
while (!is.Empty()) {
auto tok = is.Next();
// Have preceding white space
// and is not the first token of the sequence
str.append(tok->ws_ && str.size() > 1, ' ');
if (tok->tag_ == Token::LITERAL || tok->tag_ == Token::C_CONSTANT) {
for (auto c: tok->str_) {
if (c == '"' || c == '\\')
str.push_back('\\');
str.push_back(c);
}
} else {
str += tok->str_;
}
}
str.push_back('\"');
auto ret = Token::New(*is.Peek());
ret->tag_ = Token::LITERAL;
ret->str_ = str;
return ret;
}
void Preprocessor::Finalize(TokenSequence os) {
while (!os.Empty()) {
auto tok = os.Next();
if (tok->tag_ == Token::INVALID) {
Error(tok, "stray token in program");
} else if (tok->tag_ == Token::IDENTIFIER) {
auto tag = Token::KeyWordTag(tok->str_);
if (Token::IsKeyWord(tag)) {
const_cast<Token*>(tok)->tag_ = tag;
} else {
const_cast<Token*>(tok)->str_ = Scanner(tok).ScanIdentifier();
}
}
if (fName_ && !tok->loc_.filename_) {
assert(false);
}
}
}
// TODO(wgtdkp): add predefined macros
void Preprocessor::Process(TokenSequence& os) {
TokenSequence is;
// Add source file
if(fName_)
IncludeFile(is, fName_);
else
IncludeSrc(is, fSrc_, nullptr);
// Expand
Expand(os, is);
Finalize(os);
}
const Token* Preprocessor::ParseActualParam(TokenSequence& is,
Macro* macro,
ParamMap& paramMap) {
const Token* ret;
if (macro->Params().size() == 0 && !macro->Variadic()) {
ret = is.Next();
if (ret->tag_ != ')')
Error(ret, "too many arguments");
return ret;
}
auto fp = macro->Params().begin();
TokenSequence ap;
int cnt = 1;
while (cnt > 0) {
if (is.Empty())
Error(is.Peek(), "premature end of input");
else if (is.Test('('))
++cnt;
else if (is.Test(')'))
--cnt;
if ((is.Test(',') && cnt == 1) || cnt == 0) {
if (fp == macro->Params().end()) {
if (!macro->Variadic())
Error(is.Peek(), "too many arguments");
if (cnt == 0)
paramMap.insert(std::make_pair("__VA_ARGS__", ap));
else
ap.InsertBack(is.Peek());
} else {
paramMap.insert(std::make_pair(*fp, ap));
ap = TokenSequence();
++fp;
}
} else {
ap.InsertBack(is.Peek());
}
ret = is.Next();
}
if (fp != macro->Params().end())
Error(is.Peek(), "too few params");
return ret;
}
const Token* Preprocessor::EvalDefOp(TokenSequence& is) {
auto hasPar = is.Try('(');
auto macro = is.Expect(Token::IDENTIFIER);
auto cons = Token::New(*macro);
if (hasPar) is.Expect(')');
cons->tag_ = Token::I_CONSTANT;
cons->str_ = FindMacro(macro->str_) ? "1": "0";
return cons;
}
void Preprocessor::ReplaceIdent(TokenSequence& is) {
TokenSequence os;
while (!is.Empty()) {
auto tok = is.Next();
if (tok->tag_ == Token::IDENTIFIER) {
auto cons = Token::New(*tok);
cons->tag_ = Token::I_CONSTANT;
cons->str_ = "0";
os.InsertBack(cons);
} else {
os.InsertBack(tok);
}
}
is = os;
}
int Preprocessor::GetDirective(TokenSequence& is) {
if (!is.Test('#') || !is.IsBeginOfLine())
return Token::INVALID;
is.Next();
if (is.IsBeginOfLine())
return Token::PP_EMPTY;
auto tag = is.Peek()->tag_;
if (tag == Token::IDENTIFIER || Token::IsKeyWord(tag)) {
auto str = is.Peek()->str_;
auto res = directiveMap.find(str);
if (res == directiveMap.end())
return Token::PP_NONE;
return res->second;
}
return Token::PP_NONE;
}
void Preprocessor::ParseDirective(TokenSequence& os,
TokenSequence& is,
int directive) {
if (directive == Token::PP_EMPTY)
return;
auto ls = is.GetLine();
switch(directive) {
case Token::PP_IF:
ParseIf(ls); break;
case Token::PP_IFDEF:
ParseIfdef(ls); break;
case Token::PP_IFNDEF:
ParseIfndef(ls); break;
case Token::PP_ELIF:
ParseElif(ls); break;
case Token::PP_ELSE:
ParseElse(ls); break;
case Token::PP_ENDIF:
ParseEndif(ls); break;
case Token::PP_INCLUDE:
if (NeedExpand())
ParseInclude(is, ls);
break;
case Token::PP_DEFINE:
if (NeedExpand())
ParseDef(ls);
break;
case Token::PP_UNDEF:
if (NeedExpand())
ParseUndef(ls);
break;
case Token::PP_LINE:
if (NeedExpand())
ParseLine(ls);
break;
case Token::PP_ERROR:
if (NeedExpand())
ParseError(ls);
break;
case Token::PP_PRAGMA:
if (NeedExpand())
ParsePragma(ls);
break;
case Token::PP_NONE:
break;
default:
assert(false);
}
}
void Preprocessor::ParsePragma(TokenSequence ls) {
// TODO(wgtdkp):
ls.Next();
}
void Preprocessor::ParseError(TokenSequence ls) {
ls.Next();
const auto& literal = Stringize(ls);
std::string msg;
Scanner(literal).ScanLiteral(msg);
Error(ls.Peek(), "%s", msg.c_str());
}
void Preprocessor::ParseLine(TokenSequence ls) {
auto directive = ls.Next(); // Skip directive 'line'
TokenSequence ts;
Expand(ts, ls);
auto tok = ts.Expect(Token::I_CONSTANT);
int line = 0;
size_t end = 0;
try {
line = stoi(tok->str_, &end, 10);
} catch (const std::out_of_range& oor) {
Error(tok, "line number out of range");
}
if (line == 0 || end != tok->str_.size()) {
Error(tok, "illegal line number");
}
curLine_ = line;
lineLine_ = directive->loc_.line_;
if (ts.Empty())
return;
tok = ts.Expect(Token::LITERAL);
// Enusure "s-char-sequence"
if (tok->str_.front() != '"' || tok->str_.back() != '"') {
Error(tok, "expect s-char-sequence");
}
}
void Preprocessor::ParseIf(TokenSequence ls) {
if (!NeedExpand()) {
ppCondStack_.push({Token::PP_IF, false, false});
return;
}
auto tok = ls.Next(); // Skip the directive
if (ls.Empty()) {
Error(tok, "expect expression in 'if' directive");
}
TokenSequence ts;
Expand(ts, ls, true);
ReplaceIdent(ts);
Parser parser(ts);
auto expr = parser.ParseExpr();
if (!parser.ts().Empty()) {
Error(parser.ts().Peek(), "unexpected extra expression");
}
bool cond;
if (expr->Type()->IsFloat()) {
cond = static_cast<bool>(Evaluator<double>().Eval(expr));
} else {
cond = static_cast<bool>(Evaluator<long>().Eval(expr));
}
ppCondStack_.push({Token::PP_IF, NeedExpand(), cond});
}
void Preprocessor::ParseIfdef(TokenSequence ls) {
if (!NeedExpand()) {
ppCondStack_.push({Token::PP_IFDEF, false, false});
return;
}
ls.Next();
auto ident = ls.Expect(Token::IDENTIFIER);
if (!ls.Empty()) {
Error(ls.Peek(), "expect new line");
}
auto cond = FindMacro(ident->str_) != nullptr;
ppCondStack_.push({Token::PP_IFDEF, NeedExpand(), cond});
}
void Preprocessor::ParseIfndef(TokenSequence ls) {
ParseIfdef(ls);
auto top = ppCondStack_.top();
ppCondStack_.pop();
top.tag_ = Token::PP_IFNDEF;
top.cond_ = !top.cond_;
ppCondStack_.push(top);
}
void Preprocessor::ParseElif(TokenSequence ls) {
auto directive = ls.Next(); // Skip the directive
if (ppCondStack_.empty())
Error(directive, "unexpected 'elif' directive");
auto top = ppCondStack_.top();
if (top.tag_ == Token::PP_ELSE)
Error(directive, "unexpected 'elif' directive");
while (!ppCondStack_.empty()) {
top = ppCondStack_.top();
if (top.tag_ == Token::PP_IF ||
top.tag_ == Token::PP_IFDEF ||
top.tag_ == Token::PP_IFNDEF ||
top.cond_) {
break;
}
ppCondStack_.pop();
}
if (ppCondStack_.empty())
Error(directive, "unexpected 'elif' directive");
auto enabled = top.enabled_;
if (!enabled) {
ppCondStack_.push({Token::PP_ELIF, false, false});
return;
}
if (ls.Empty()) {
Error(ls.Peek(), "expect expression in 'elif' directive");
}
TokenSequence ts;
Expand(ts, ls, true);
ReplaceIdent(ts);
Parser parser(ts);
auto expr = parser.ParseExpr();
if (!parser.ts().Empty()) {
Error(parser.ts().Peek(), "unexpected extra expression");
}
bool cond;
if (expr->Type()->IsFloat()) {
std::cout << Evaluator<double>().Eval(expr) << std::endl;
cond = static_cast<bool>(Evaluator<double>().Eval(expr));
} else {
cond = static_cast<bool>(Evaluator<long>().Eval(expr));
}
cond = cond && !top.cond_;
ppCondStack_.push({Token::PP_ELIF, true, cond});
}
void Preprocessor::ParseElse(TokenSequence ls) {
auto directive = ls.Next();
if (!ls.Empty())
Error(ls.Peek(), "expect new line");
if (ppCondStack_.empty())
Error(directive, "unexpected 'else' directive");
auto top = ppCondStack_.top();
if (top.tag_ == Token::PP_ELSE)
Error(directive, "unexpected 'else' directive");
while (!ppCondStack_.empty()) {
top = ppCondStack_.top();
if (top.tag_ == Token::PP_IF ||
top.tag_ == Token::PP_IFDEF ||
top.tag_ == Token::PP_IFNDEF ||
top.cond_) {
break;
}
ppCondStack_.pop();
}
if (ppCondStack_.empty())
Error(directive, "unexpected 'else' directive");
auto cond = !top.cond_;
auto enabled = top.enabled_;
ppCondStack_.push({Token::PP_ELSE, enabled, cond});
}
void Preprocessor::ParseEndif(TokenSequence ls) {
auto directive = ls.Next();
if (!ls.Empty())
Error(ls.Peek(), "expect new line");
while ( !ppCondStack_.empty()) {
auto top = ppCondStack_.top();
ppCondStack_.pop();
if (top.tag_ == Token::PP_IF
|| top.tag_ == Token::PP_IFDEF
|| top.tag_ == Token::PP_IFNDEF) {
return;
}
}
if (ppCondStack_.empty())
Error(directive, "unexpected 'endif' directive");
}
// Have Read the '#'
void Preprocessor::ParseInclude(TokenSequence& is, TokenSequence ls) {
bool next = ls.Next()->str_ == "include_next"; // Skip 'include'
if (!ls.Test(Token::LITERAL) && !ls.Test('<')) {
TokenSequence ts;
Expand(ts, ls, true);
ls = ts;
}
auto tok = ls.Next();
if (tok->tag_ == Token::LITERAL) {
if (!ls.Empty()) {
Error(ls.Peek(), "expect new line");
}
std::string filename;
Scanner(tok).ScanLiteral(filename);
auto fullPath = SearchFile(filename, false, next, *tok->loc_.filename_);
if (fullPath == nullptr)
Error(tok, "%s: No such file or directory", filename.c_str());
IncludeFile(is, fullPath);
} else if (tok->tag_ == '<') {
auto lhs = tok;
auto rhs = tok;
int cnt = 1;
while (!(rhs = ls.Next())->IsEOF()) {
if (rhs->tag_ == '<')
++cnt;
else if (rhs->tag_ == '>')
--cnt;
if (cnt == 0)
break;
}
if (cnt != 0)
Error(rhs, "expect '>'");
if (!ls.Empty())
Error(ls.Peek(), "expect new line");
const auto& filename = Scanner::ScanHeadName(lhs, rhs);
auto fullPath = SearchFile(filename, true, next, *tok->loc_.filename_);
if (fullPath == nullptr) {
Error(tok, "%s: No such file or directory", filename.c_str());
}
IncludeFile(is, fullPath);
} else {
Error(tok, "expect filename(string or in '<>')");
}
}
void Preprocessor::ParseUndef(TokenSequence ls) {
ls.Next(); // Skip directive
auto ident = ls.Expect(Token::IDENTIFIER);
if (!ls.Empty())
Error(ls.Peek(), "expect new line");
RemoveMacro(ident->str_);
}
void Preprocessor::ParseDef(TokenSequence ls) {
ls.Next();
auto ident = ls.Expect(Token::IDENTIFIER);
if (ident->str_ == "defined") {
Error(ident, "'defined' cannot be used as a macro name");
}
auto tok = ls.Peek();
if (tok->tag_ == '(' && !tok->ws_) {
// There is no white space between ident and '('
// Hence, we are defining function-like macro
// Parse Identifier list
ls.Next(); // Skip '('
ParamList params;
auto variadic = ParseIdentList(params, ls);
const auto& macro = Macro(variadic, params, ls);
AddMacro(ident->str_, macro);
} else {
AddMacro(ident->str_, Macro(ls));
}
}
bool Preprocessor::ParseIdentList(ParamList& params, TokenSequence& is) {
const Token* tok = is.Peek();
while (!is.Empty()) {
tok = is.Next();
if (tok->tag_ == ')') {
return false;
} else if (tok->tag_ == Token::ELLIPSIS) {
is.Expect(')');
return true;
} else if (tok->tag_ != Token::IDENTIFIER) {
Error(tok, "expect identifier");
}
for (const auto& param: params) {
if (param == tok->str_)
Error(tok, "duplicated param");
}
params.push_back(tok->str_);
if (!is.Try(',')) {
is.Expect(')');
return false;
}
}
Error(tok, "unexpected end of line");
}
void Preprocessor::IncludeSrc(TokenSequence& is,
const std::string* text,
const std::string* filename) {
TokenSequence ts {is.tokList_, is.begin_, is.begin_};
Scanner scanner(text, filename);
scanner.Tokenize(ts);
// We done including header file
is.begin_ = ts.begin_;
}
void Preprocessor::IncludeFile(TokenSequence& is,
const std::string* filename) {
IncludeSrc(is, ReadFile(*filename), filename);
}
static std::string GetDir(const std::string& path) {
auto pos = path.rfind('/');
if (pos == std::string::npos)
return "./";
return path.substr(0, pos + 1);
}
std::string* Preprocessor::SearchFile(const std::string& name,
const bool libHeader,
bool next,
const std::string& curPath) {
if (libHeader && !next) {
searchPaths_.push_back(GetDir(curPath));
} else {
searchPaths_.push_front(GetDir(curPath));
}
auto iter = searchPaths_.begin();
for (; iter != searchPaths_.end(); ++iter) {
auto dd = open(iter->c_str(), O_RDONLY);
if (dd == -1) // TODO(wgtdkp): or ensure it before preprocessing
continue;
auto fd = openat(dd, name.c_str(), O_RDONLY);
close(dd);
if (fd != -1) {
// Intentional, so that recursive include
// will result in running out of file descriptor
//close(fd);
auto path = *iter + name;
if (next) {
if (path != curPath)
continue;
else
next = false;
} else {
if (path == curPath)
continue;
if (libHeader && !next)
searchPaths_.pop_back();
else
searchPaths_.pop_front();
return new std::string(path);
}
} else if (errno == EMFILE) {
Error("may recursive include");
}
}
return nullptr;
}
void Preprocessor::AddMacro(const std::string& name,
std::string* text,
bool preDef) {
TokenSequence ts;
Scanner scanner(text);
scanner.Tokenize(ts);
Macro macro(ts, preDef);
AddMacro(name, macro);
}
static std::string* Date() {
time_t t = time(NULL);
struct tm* tm = localtime(&t);
char buf[14];
strftime(buf, sizeof buf, "\"%a %M %Y\"", tm);
return new std::string(buf);
}
void Preprocessor::Init() {
// Preinclude search paths
AddSearchPath("/usr/local/include/");
AddSearchPath("/usr/include/x86_64-linux-gnu/");
AddSearchPath("/usr/include/linux/");
AddSearchPath("/usr/include/");
AddSearchPath("/usr/local/include/");
// The __FILE__ and __LINE__ macro is empty
// They are handled seperately
AddMacro("__FILE__", Macro(TokenSequence(), true));
AddMacro("__LINE__", Macro(TokenSequence(), true));
AddMacro("__DATE__", Date(), true);
AddMacro("__STDC__", new std::string("1"), true);
AddMacro("__STDC__HOSTED__", new std::string("0"), true);
AddMacro("__STDC_VERSION__", new std::string("201103L"), true);
}
void Preprocessor::HandleTheFileMacro(TokenSequence& os, const Token* macro) {
auto file = Token::New(*macro);
file->tag_ = Token::LITERAL;
file->str_ = "\"" + *macro->loc_.filename_ + "\"";
os.InsertBack(file);
}
void Preprocessor::HandleTheLineMacro(TokenSequence& os, const Token* macro) {
auto line = Token::New(*macro);
line->tag_ = Token::I_CONSTANT;
line->str_ = std::to_string(macro->loc_.line_);
os.InsertBack(line);
}
void Preprocessor::UpdateFirstTokenLine(TokenSequence ts) {
auto loc = ts.Peek()->loc_;
loc.line_ = curLine_ + loc.line_ - lineLine_ - 1;
ts.UpdateHeadLocation(loc);
}
TokenSequence Macro::RepSeq(const std::string* filename, unsigned line) {
// Update line
TokenList tl;
TokenSequence ret(&tl);
ret.Copy(repSeq_);
auto ts = ret;
while (!ts.Empty()) {
auto loc = ts.Peek()->loc_;
loc.filename_ = filename;
loc.line_ = line;
ts.UpdateHeadLocation(loc);
ts.Next();
}
return ret;
}
void Preprocessor::AddSearchPath(std::string path) {
if (path.back() != '/')
path += "/";
if (path[0] != '/')
path = "./" + path;
searchPaths_.push_front(path);
}

View File

@@ -1,42 +0,0 @@
#include "triton/lang/encoding.h"
#include <climits>
#include <codecvt>
#include <locale>
#include <iostream>
static void Append16LE(std::string& str, char16_t c) {
str.push_back(c & UCHAR_MAX);
str.push_back((c >> 8) & UCHAR_MAX);
}
static void Append32LE(std::string& str, char32_t c) {
Append16LE(str, c & USHRT_MAX);
Append16LE(str, (c >> 16) & USHRT_MAX);
}
void ConvertToUTF16(std::string& str) {
std::wstring_convert<std::codecvt_utf8<char16_t>, char16_t> utf8_ucs2_cvt;
auto str16 = utf8_ucs2_cvt.from_bytes(str);
str.resize(0);
for (auto c16: str16)
Append16LE(str, c16);
}
void ConvertToUTF32(std::string& str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> utf8_ucs4_cvt;
auto str32 = utf8_ucs4_cvt.from_bytes(str);
str.resize(0);
for (auto c32: str32)
Append32LE(str, c32);
}
void AppendUCN(std::string& str, int c) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> utf8_ucs4_cvt;
str += utf8_ucs4_cvt.to_bytes(static_cast<char32_t>(c));
}

View File

@@ -1,91 +0,0 @@
#include "triton/lang/error.h"
#include "triton/lang/ast.h"
#include "triton/lang/token.h"
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <string>
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
void Error(const char* format, ...) {
fprintf(stderr,
ANSI_COLOR_RED "error: " ANSI_COLOR_RESET);
va_list args;
va_start(args, format);
vfprintf(stderr, format, args);
va_end(args);
fprintf(stderr, "\n");
exit(-1);
}
[[noreturn]]
static void VError(const SourceLocation& loc,
const char* format,
va_list args) {
const char* filename = nullptr;
if(loc.filename_)
filename = loc.filename_->c_str();
fprintf(stderr,
"%s:%d:%d: " ANSI_COLOR_RED "error: " ANSI_COLOR_RESET,
filename,
loc.line_,
loc.column_);
vfprintf(stderr, format, args);
fprintf(stderr, "\n ");
bool sawNoSpace = false;
int nspaces = 0;
for (auto p = loc.lineBegin_; *p != '\n' && *p != 0; p++) {
if (!sawNoSpace && (*p == ' ' || *p == '\t')) {
++nspaces;
} else {
sawNoSpace = true;
fputc(*p, stderr);
}
}
fprintf(stderr, "\n ");
for (unsigned i = 1; i + nspaces < loc.column_; ++i)
fputc(' ', stderr);
fprintf(stderr, "^\n");
exit(-1);
}
void Error(const SourceLocation& loc, const char* format, ...) {
va_list args;
va_start(args, format);
VError(loc, format, args);
va_end(args);
}
void Error(const Token* tok, const char* format, ...) {
va_list args;
va_start(args, format);
VError(tok->loc_, format, args);
va_end(args);
}
void Error(const Expr* expr, const char* format, ...) {
va_list args;
va_start(args, format);
VError(expr->Tok()->loc_, format, args);
va_end(args);
}

View File

@@ -1,206 +0,0 @@
#include "triton/lang/evaluator.h"
#include "triton/lang/ast.h"
#include "triton/lang/token.h"
template<typename T>
void Evaluator<T>::VisitBinaryOp(BinaryOp* binary) {
#define L Evaluator<T>().Eval(binary->lhs_)
#define R Evaluator<T>().Eval(binary->rhs_)
#define LL Evaluator<long>().Eval(binary->lhs_)
#define LR Evaluator<long>().Eval(binary->rhs_)
if (binary->Type()->ToPointer()) {
auto val = Evaluator<Addr>().Eval(binary);
if (val.label_.size()) {
Error(binary, "expect constant integer expression");
}
val_ = static_cast<T>(val.offset_);
return;
}
switch (binary->op_) {
case '+': val_ = L + R; break;
case '-': val_ = L - R; break;
case '*': val_ = L * R; break;
case '/': {
auto l = L, r = R;
if (r == 0)
Error(binary, "division by zero");
val_ = l / r;
} break;
case '%': {
auto l = LL, r = LR;
if (r == 0)
Error(binary, "division by zero");
val_ = l % r;
} break;
// Bitwise operators that do not accept float
case '|': val_ = LL | LR; break;
case '&': val_ = LL & LR; break;
case '^': val_ = LL ^ LR; break;
case Token::LEFT: val_ = LL << LR; break;
case Token::RIGHT: val_ = LL >> LR; break;
case '<': val_ = L < R; break;
case '>': val_ = L > R; break;
case Token::LOGICAL_AND: val_ = L && R; break;
case Token::LOGICAL_OR: val_ = L || R; break;
case Token::EQ: val_ = L == R; break;
case Token::NE: val_ = L != R; break;
case Token::LE: val_ = L <= R; break;
case Token::GE: val_ = L >= R; break;
case '=': case ',': val_ = R; break;
case '.': {
auto addr = Evaluator<Addr>().Eval(binary);
if (addr.label_.size())
Error(binary, "expect constant expression");
val_ = addr.offset_;
}
default: assert(false);
}
#undef L
#undef R
#undef LL
#undef LR
}
template<typename T>
void Evaluator<T>::VisitUnaryOp(UnaryOp* unary) {
#define VAL Evaluator<T>().Eval(unary->operand_)
#define LVAL Evaluator<long>().Eval(unary->operand_)
switch (unary->op_) {
case Token::PLUS: val_ = VAL; break;
case Token::MINUS: val_ = -VAL; break;
case '~': val_ = ~LVAL; break;
case '!': val_ = !VAL; break;
case Token::CAST:
if (unary->Type()->IsInteger())
val_ = static_cast<long>(VAL);
else
val_ = VAL;
break;
case Token::ADDR: {
auto addr = Evaluator<Addr>().Eval(unary->operand_);
if (addr.label_.size())
Error(unary, "expect constant expression");
val_ = addr.offset_;
} break;
default: Error(unary, "expect constant expression");
}
#undef LVAL
#undef VAL
}
template<typename T>
void Evaluator<T>::VisitConditionalOp(ConditionalOp* condOp) {
bool cond;
auto condType = condOp->cond_->Type();
if (condType->IsInteger()) {
auto val = Evaluator<long>().Eval(condOp->cond_);
cond = val != 0;
} else if (condType->IsFloat()) {
auto val = Evaluator<double>().Eval(condOp->cond_);
cond = val != 0.0;
} else if (condType->ToPointer()) {
auto val = Evaluator<Addr>().Eval(condOp->cond_);
cond = val.label_.size() || val.offset_;
} else {
assert(false);
}
if (cond) {
val_ = Evaluator<T>().Eval(condOp->exprTrue_);
} else {
val_ = Evaluator<T>().Eval(condOp->exprFalse_);
}
}
void Evaluator<Addr>::VisitBinaryOp(BinaryOp* binary) {
#define LR Evaluator<long>().Eval(binary->rhs_)
#define R Evaluator<Addr>().Eval(binary->rhs_)
auto l = Evaluator<Addr>().Eval(binary->lhs_);
int width = 1;
auto pointerType = binary->Type()->ToPointer();
if (pointerType)
width = pointerType->Derived()->Width();
switch (binary->op_) {
case '+':
assert(pointerType);
addr_.label_ = l.label_;
addr_.offset_ = l.offset_ + LR * width;
break;
case '-':
assert(pointerType);
addr_.label_ = l.label_;
addr_.offset_ = l.offset_ + LR * width;
break;
case '.': {
addr_.label_ = l.label_;
auto type = binary->lhs_->Type()->ToStruct();
auto offset = type->GetMember(binary->rhs_->tok_->str_)->Offset();
addr_.offset_ = l.offset_ + offset;
break;
}
default: assert(false);
}
#undef LR
#undef R
}
void Evaluator<Addr>::VisitUnaryOp(UnaryOp* unary) {
auto addr = Evaluator<Addr>().Eval(unary->operand_);
switch (unary->op_) {
case Token::CAST:
case Token::ADDR:
case Token::DEREF:
addr_ = addr; break;
default: assert(false);
}
}
void Evaluator<Addr>::VisitConditionalOp(ConditionalOp* condOp) {
bool cond;
auto condType = condOp->cond_->Type();
if (condType->IsInteger()) {
auto val = Evaluator<long>().Eval(condOp->cond_);
cond = val != 0;
} else if (condType->IsFloat()) {
auto val = Evaluator<double>().Eval(condOp->cond_);
cond = val != 0.0;
} else if (condType->ToPointer()) {
auto val = Evaluator<Addr>().Eval(condOp->cond_);
cond = val.label_.size() || val.offset_;
} else {
assert(false);
}
if (cond) {
addr_ = Evaluator<Addr>().Eval(condOp->exprTrue_);
} else {
addr_ = Evaluator<Addr>().Eval(condOp->exprFalse_);
}
}
void Evaluator<Addr>::VisitConstant(Constant* cons) {
if (cons->Type()->IsInteger()) {
addr_ = {"", static_cast<int>(cons->IVal())};
} else if (cons->Type()->ToArray()) {
assert(false);
} else {
assert(false);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,452 +0,0 @@
#include "triton/lang/scanner.h"
#include <cctype>
#include <climits>
void Scanner::Tokenize(TokenSequence& ts) {
while (true) {
auto tok = Scan();
if (tok->tag_ == Token::END) {
if (ts.Empty() || (ts.Back()->tag_ != Token::NEW_LINE)) {
auto t = Token::New(*tok);
t->tag_ = Token::NEW_LINE;
t->str_ = "\n";
ts.InsertBack(t);
}
break;
} else {
if (!ts.Empty() && ts.Back()->tag_ == Token::NEW_LINE)
tok->ws_ = true;
ts.InsertBack(tok);
}
}
}
std::string Scanner::ScanHeadName(const Token* lhs, const Token* rhs) {
std::string str;
const char* begin = lhs->loc_.Begin() + 1;
const char* end = rhs->loc_.Begin();
for (; begin != end; ++begin) {
if (*begin == '\n' && str.back() == '\\')
str.pop_back();
else
str.push_back(*begin);
}
return str;
}
Token* Scanner::Scan(bool ws) {
tok_.ws_ = ws;
SkipWhiteSpace();
Mark();
if (Test('\n')) {
auto ret = MakeNewLine();
Next();
return ret;
}
auto c = Next();
switch (c) {
case '#': return MakeToken(Try('#') ? Token::DSHARP: c);
case ':': return MakeToken(Try('>') ? ']': c);
case '(': case ')': case '[': case ']':
case '?': case ',': case '{': case '}':
case '~': case ';': case '@':
return MakeToken(c);
case '-':
if (Try('>')) return MakeToken(Token::PTR);
if (Try('-')) return MakeToken(Token::DEC);
if (Try('=')) return MakeToken(Token::SUB_ASSIGN);
return MakeToken(c);
case '+':
if (Try('+')) return MakeToken(Token::INC);
if (Try('=')) return MakeToken(Token::ADD_ASSIGN);
return MakeToken(c);
case '<':
if (Try('<')) return MakeToken(Try('=') ? Token::LEFT_ASSIGN: Token::LEFT);
if (Try('=')) return MakeToken(Token::LE);
if (Try(':')) return MakeToken('[');
if (Try('%')) return MakeToken('{');
return MakeToken(c);
case '%':
if (Try('=')) return MakeToken(Token::MOD_ASSIGN);
if (Try('>')) return MakeToken('}');
if (Try(':')) {
if (Try('%')) {
if (Try(':')) return MakeToken(Token::DSHARP);
PutBack();
}
return MakeToken('#');
}
return MakeToken(c);
case '>':
if (Try('>')) return MakeToken(Try('=') ? Token::RIGHT_ASSIGN: Token::RIGHT);
if (Try('=')) return MakeToken(Token::GE);
return MakeToken(c);
case '=': return MakeToken(Try('=') ? Token::EQ: c);
case '!': return MakeToken(Try('=') ? Token::NE: c);
case '&':
if (Try('&')) return MakeToken(Token::LOGICAL_AND);
if (Try('=')) return MakeToken(Token::AND_ASSIGN);
return MakeToken(c);
case '|':
if (Try('|')) return MakeToken(Token::LOGICAL_OR);
if (Try('=')) return MakeToken(Token::OR_ASSIGN);
return MakeToken(c);
case '*': return MakeToken(Try('=') ? Token::MUL_ASSIGN: c);
case '/':
if (Test('/') || Test('*')) {
SkipComment();
return Scan(true);
}
return MakeToken(Try('=') ? Token::DIV_ASSIGN: c);
case '^': return MakeToken(Try('=') ? Token::XOR_ASSIGN: c);
case '.':
if (isdigit(Peek())) return SkipNumber();
if (Try('.')) {
if (Try('.')) return MakeToken(Token::ELLIPSIS);
PutBack();
return MakeToken('.');
}
return MakeToken(c);
case '0' ... '9': return SkipNumber();
case 'u': case 'U': case 'L': {
/*auto enc = */ScanEncoding(c);
if (Try('\'')) return SkipCharacter();
if (Try('\"')) return SkipLiteral();
return SkipIdentifier();
}
case '\'': return SkipCharacter();
case '\"': return SkipLiteral();
case 'a' ... 't': case 'v' ... 'z': case 'A' ... 'K':
case 'M' ... 'T': case 'V' ... 'Z': case '_': case '$':
case 0x80 ... 0xfd:
return SkipIdentifier();
case '\\':
// Universal character name is allowed in identifier
if (Test('u') || Test('U'))
return SkipIdentifier();
return MakeToken(Token::INVALID);
case '\0': return MakeToken(Token::END);
default: return MakeToken(Token::INVALID);
}
}
void Scanner::SkipWhiteSpace() {
while (isspace(Peek()) && Peek() != '\n') {
tok_.ws_ = true;
Next();
}
}
void Scanner::SkipComment() {
if (Try('/')) {
// Line comment terminated an newline or eof
while (!Empty()) {
if (Peek() == '\n')
return;
Next();
}
return;
} else if (Try('*')) {
while (!Empty()) {
auto c = Next();
if (c == '*' && Peek() == '/') {
Next();
return;
}
}
Error(loc_, "unterminated block comment");
}
assert(false);
}
std::string Scanner::ScanIdentifier() {
std::string val;
while (!Empty()) {
auto c = Next();
if (IsUCN(c)) {
c = ScanEscaped(); // Call ScanUCN()
AppendUCN(val, c);
} else {
val.push_back(c);
}
}
return val;
}
Token* Scanner::SkipIdentifier() {
PutBack();
auto c = Next();
while (isalnum(c)
|| (0x80 <= c && c <= 0xfd)
|| c == '_'
|| c == '$'
|| IsUCN(c)) {
if (IsUCN(c))
c = ScanEscaped(); // Just read it
c = Next();
}
PutBack();
return MakeToken(Token::IDENTIFIER);
}
// Scan PP-Number
Token* Scanner::SkipNumber() {
PutBack();
bool sawHexPrefix = false;
int tag = Token::I_CONSTANT;
auto c = Next();
while (c == '.' || isdigit(c) || isalpha(c) || c == '_' || IsUCN(c)) {
if (c == 'e' || c =='E' || c == 'p' || c == 'P') {
if (!Try('-')) Try('+');
if (!((c == 'e' || c == 'E') && sawHexPrefix))
tag = Token::F_CONSTANT;
} else if (IsUCN(c)) {
ScanEscaped();
} else if (c == '.') {
tag = Token::F_CONSTANT;
} else if (c == 'x' || c == 'X') {
sawHexPrefix = true;
}
c = Next();
}
PutBack();
return MakeToken(tag);
}
Encoding Scanner::ScanLiteral(std::string& val) {
auto enc = Test('\"') ? Encoding::NONE: ScanEncoding(Next());
Next();
val.resize(0);
while (!Test('\"')) {
auto c = Next();
bool isucn = IsUCN(c);
if (c == '\\')
c = ScanEscaped();
if (isucn)
AppendUCN(val, c);
else
val.push_back(c);
}
return enc;
}
Token* Scanner::SkipLiteral() {
auto c = Next();
while (c != '\"' && c != '\n' && c != '\0') {
if (c == '\\') Next();
c = Next();
}
if (c != '\"')
Error(loc_, "unterminated string literal");
return MakeToken(Token::LITERAL);
}
Encoding Scanner::ScanCharacter(int& val) {
auto enc = Test('\'') ? Encoding::NONE: ScanEncoding(Next());
Next();
val = 0;
while (!Test('\'')) {
auto c = Next();
if (c == '\\')
c = ScanEscaped();
if (enc == Encoding::NONE)
val = (val << 8) + c;
else
val = c;
}
return enc;
}
Token* Scanner::SkipCharacter() {
auto c = Next();
while (c != '\'' && c != '\n' && c != '\0') {
if (c == '\\') Next();
c = Next();
}
if (c != '\'')
Error(loc_, "unterminated character constant");
return MakeToken(Token::C_CONSTANT);
}
int Scanner::ScanEscaped() {
auto c = Next();
switch (c) {
case '\\': case '\'': case '\"': case '\?':
return c;
case 'a': return '\a';
case 'b': return '\b';
case 'f': return '\f';
case 'n': return '\n';
case 'r': return '\r';
case 't': return '\t';
case 'v': return '\v';
// Non-standard GCC extention
case 'e': return '\033';
case 'x': return ScanHexEscaped();
case '0' ... '7': return ScanOctEscaped(c);
case 'u': return ScanUCN(4);
case 'U': return ScanUCN(8);
default: Error(loc_, "unrecognized escape character '%c'", c);
}
return c; // Make compiler happy
}
int Scanner::ScanHexEscaped() {
int val = 0, c = Peek();
if (!isxdigit(c))
Error(loc_, "expect xdigit, but got '%c'", c);
while (isxdigit(c)) {
val = (val << 4) + XDigit(c);
Next();
c = Peek();
}
return val;
}
int Scanner::ScanOctEscaped(int c) {
int val = XDigit(c);
c = Peek();
if (!IsOctal(c))
return val;
val = (val << 3) + XDigit(c);
Next();
c = Peek();
if (!IsOctal(c))
return val;
val = (val << 3) + XDigit(c);
Next();
return val;
}
int Scanner::ScanUCN(int len) {
assert(len == 4 || len == 8);
int val = 0;
for (auto i = 0; i < len; ++i) {
auto c = Next();
if (!isxdigit(c))
Error(loc_, "expect xdigit, but got '%c'", c);
val = (val << 4) + XDigit(c);
}
return val;
}
int Scanner::XDigit(int c) {
switch (c) {
case '0' ... '9': return c - '0';
case 'a' ... 'z': return c - 'a' + 10;
case 'A' ... 'Z': return c - 'A' + 10;
default: assert(false); return c;
}
}
Encoding Scanner::ScanEncoding(int c) {
switch (c) {
case 'u': return Try('8') ? Encoding::UTF8: Encoding::CHAR16;
case 'U': return Encoding::CHAR32;
case 'L': return Encoding::WCHAR;
default: assert(false); return Encoding::NONE;
}
}
std::string* ReadFile(const std::string& filename) {
FILE* f = fopen(filename.c_str(), "r");
if (!f) Error("%s: No such file or directory", filename.c_str());
auto text = new std::string;
int c;
while (EOF != (c = fgetc(f)))
text->push_back(c);
fclose(f);
return text;
}
int Scanner::Next() {
int c = Peek();
++p_;
if (c == '\n') {
++loc_.line_;
loc_.column_ = 1;
loc_.lineBegin_ = p_;
} else {
++loc_.column_;
}
return c;
}
int Scanner::Peek() {
int c = (uint8_t)(*p_);
if (c == '\\' && p_[1] == '\n') {
p_ += 2;
++loc_.line_;
loc_.column_ = 1;
loc_.lineBegin_ = p_;
return Peek();
}
return c;
}
// There couldn't be more than one PutBack() that
// cross two line, so just leave lineBegin, because
// we never care about the pos of newline token
void Scanner::PutBack() {
int c = *--p_;
if (c == '\n' && p_[-1] == '\\') {
--loc_.line_;
--p_;
return PutBack();
} else if (c == '\n') {
--loc_.line_;
} else {
--loc_.column_;
}
}
Token* Scanner::MakeToken(int tag) {
tok_.tag_ = tag;
auto& str = tok_.str_;
str.resize(0);
const char* p = tok_.loc_.lineBegin_ + tok_.loc_.column_ - 1;
for (; p < p_; ++p) {
if (p[0] == '\n' && p[-1] == '\\')
str.pop_back();
else
str.push_back(p[0]);
}
return Token::New(tok_);
}
/*
* New line is special, it is generated before reading the character '\n'
*/
Token* Scanner::MakeNewLine() {
tok_.tag_ = '\n';
tok_.str_ = std::string(p_, p_ + 1);
return Token::New(tok_);
}

View File

@@ -1,111 +0,0 @@
#include "triton/lang/scope.h"
#include "triton/lang/ast.h"
#include <cassert>
#include <iostream>
Identifier* Scope::Find(const Token* tok) {
auto ret = Find(tok->str_);
if (ret) ret->SetTok(tok);
return ret;
}
Identifier* Scope::FindInCurScope(const Token* tok) {
auto ret = FindInCurScope(tok->str_);
if (ret) ret->SetTok(tok);
return ret;
}
Identifier* Scope::FindTag(const Token* tok) {
auto ret = FindTag(tok->str_);
if (ret) ret->SetTok(tok);
return ret;
}
Identifier* Scope::FindTagInCurScope(const Token* tok) {
auto ret = FindTagInCurScope(tok->str_);
if (ret) ret->SetTok(tok);
return ret;
}
void Scope::Insert(Identifier* ident) {
Insert(ident->Name(), ident);
}
void Scope::InsertTag(Identifier* ident) {
Insert(TagName(ident->Name()), ident);
}
Identifier* Scope::Find(const std::string& name) {
auto ident = identMap_.find(name);
if (ident != identMap_.end())
return ident->second;
if (type_ == S_FILE || parent_ == nullptr)
return nullptr;
return parent_->Find(name);
}
Identifier* Scope::FindInCurScope(const std::string& name) {
auto ident = identMap_.find(name);
if (ident == identMap_.end())
return nullptr;
return ident->second;
}
void Scope::Insert(const std::string& name, Identifier* ident) {
assert(FindInCurScope(name) == nullptr);
identMap_[name] = ident;
}
Identifier* Scope::FindTag(const std::string& name) {
auto tag = Find(TagName(name));
if (tag) assert(tag->ToTypeName());
return tag;
}
Identifier* Scope::FindTagInCurScope(const std::string& name) {
auto tag = FindInCurScope(TagName(name));
assert(tag == nullptr || tag->ToTypeName());
return tag;
}
Scope::TagList Scope::AllTagsInCurScope() const {
TagList tags;
for (auto& kv: identMap_) {
if (IsTagName(kv.first))
tags.push_back(kv.second);
}
return tags;
}
void Scope::Print() {
std::cout << "scope: " << this << std::endl;
auto iter = identMap_.begin();
for (; iter != identMap_.end(); ++iter) {
auto name = iter->first;
auto ident = iter->second;
if (ident->ToTypeName()) {
std::cout << name << "\t[type:\t"
<< ident->Type()->Str() << "]" << std::endl;
} else {
std::cout << name << "\t[object:\t"
<< ident->Type()->Str() << "]" << std::endl;
}
}
std::cout << std::endl;
}

View File

@@ -1,271 +0,0 @@
#include "triton/lang/token.h"
#include "triton/lang/mem_pool.h"
#include "triton/lang/parser.h"
static MemPoolImp<Token> tokenPool;
const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "__constant__", Token::CMEM },
{ "__global__", Token::GLOBAL },
{ "auto", Token::AUTO },
{ "break", Token::BREAK },
{ "case", Token::CASE },
{ "char", Token::CHAR },
{ "const", Token::CONST },
{ "continue", Token::CONTINUE },
{ "default", Token::DEFAULT },
{ "do", Token::DO },
{ "double", Token::DOUBLE },
{ "else", Token::ELSE },
{ "enum", Token::ENUM },
{ "extern", Token::EXTERN },
{ "float", Token::FLOAT },
{ "for", Token::FOR },
{ "goto", Token::GOTO },
{ "half", Token::HALF },
{ "if", Token::IF },
{ "inline", Token::INLINE },
{ "int", Token::INT },
{ "long", Token::LONG },
{ "newaxis", Token::NEWAXIS },
{ "signed", Token::SIGNED },
{ "unsigned", Token::UNSIGNED },
{ "restrict", Token::RESTRICT },
{ "return", Token::RETURN },
{ "short", Token::SHORT },
{ "sizeof", Token::SIZEOF },
{ "static", Token::STATIC },
{ "struct", Token::STRUCT },
{ "switch", Token::SWITCH },
{ "typedef", Token::TYPEDEF },
{ "union", Token::UNION },
{ "void", Token::VOID },
{ "volatile", Token::VOLATILE },
{ "while", Token::WHILE },
{ "bitcast", Token::BITCAST },
{ "exp", Token::EXP },
{ "log", Token::LOG },
{ "sqrtf", Token::SQRTF },
{ "_Alignas", Token::ALIGNAS },
{ "_Alignof", Token::ALIGNOF },
{ "_Atomic", Token::ATOMIC },
{ "__attribute__", Token::ATTRIBUTE },
{ "_Bool", Token::BOOL },
{ "_Complex", Token::COMPLEX },
{ "_Generic", Token::GENERIC },
{ "_Imaginary", Token::IMAGINARY },
{ "_Noreturn", Token::NORETURN },
{ "_Static_assert", Token::STATIC_ASSERT },
{ "_Thread_local", Token::THREAD },
{ "max", Token::MAX },
{ "min", Token::MIN },
};
const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
{ '(', "(" },
{ ')', ")" },
{ '[', "[" },
{ ']', "]" },
{ ':', ":" },
{ ',', "," },
{ ';', ";" },
{ '+', "+" },
{ '-', "-" },
{ '*', "*" },
{ '/', "/" },
{ '|', "|" },
{ '&', "&" },
{ '<', "<" },
{ '>', ">" },
{ '=', "=" },
{ '.', "." },
{ '%', "%" },
{ '{', "{" },
{ '}', "}" },
{ '^', "^" },
{ '~', "~" },
{ '!', "!" },
{ '?', "?" },
{ '#', "#" },
{ '@', "@" },
{ Token::DSHARP, "##" },
{ Token::PTR, "->" },
{ Token::INC, "++" },
{ Token::DEC, "--" },
{ Token::LEFT, "<<" },
{ Token::RIGHT, ">>" },
{ Token::LE, "<=" },
{ Token::GE, ">=" },
{ Token::EQ, "==" },
{ Token::NE, "!=" },
{ Token::LOGICAL_AND, "&&" },
{ Token::LOGICAL_OR, "||" },
{ Token::MUL_ASSIGN, "*=" },
{ Token::DIV_ASSIGN, "/=" },
{ Token::MOD_ASSIGN, "%=" },
{ Token::ADD_ASSIGN, "+=" },
{ Token::SUB_ASSIGN, "-=" },
{ Token::LEFT_ASSIGN, "<<=" },
{ Token::RIGHT_ASSIGN, ">>=" },
{ Token::AND_ASSIGN, "&=" },
{ Token::XOR_ASSIGN, "^=" },
{ Token::OR_ASSIGN, "|=" },
{ Token::ELLIPSIS, "..." },
{ Token::AUTO, "auto" },
{ Token::BREAK, "break" },
{ Token::CASE, "case" },
{ Token::CHAR, "char" },
{ Token::CONST, "const" },
{ Token::CONTINUE, "continue" },
{ Token::DEFAULT, "default" },
{ Token::DO, "do" },
{ Token::DOUBLE, "double" },
{ Token::ELSE, "else" },
{ Token::ENUM, "enum" },
{ Token::EXTERN, "extern" },
{ Token::FLOAT, "float" },
{ Token::FOR, "for" },
{ Token::GLOBAL, "global" },
{ Token::GOTO, "goto" },
{ Token::IF, "if" },
{ Token::INLINE, "inline" },
{ Token::INT, "int" },
{ Token::LONG, "long" },
{ Token::NEWAXIS, "newaxis" },
{ Token::SIGNED, "signed" },
{ Token::UNSIGNED, "unsigned" },
{ Token::RESTRICT, "restrict" },
{ Token::RETURN, "return" },
{ Token::SHORT, "short" },
{ Token::SIZEOF, "sizeof" },
{ Token::STATIC, "static" },
{ Token::STRUCT, "struct" },
{ Token::SWITCH, "switch" },
{ Token::TYPEDEF, "typedef" },
{ Token::UNION, "union" },
{ Token::VOID, "void" },
{ Token::VOLATILE, "volatile" },
{ Token::WHILE, "while" },
{ Token::BITCAST, "bitcast" },
{ Token::EXP, "exp" },
{ Token::LOG, "log" },
{ Token::SQRTF, "sqrtf" },
{ Token::ALIGNAS, "_Alignas" },
{ Token::ALIGNOF, "_Alignof" },
{ Token::ATOMIC, "_Atomic" },
{ Token::ATTRIBUTE, "__attribute__" },
{ Token::BOOL, "_Bool" },
{ Token::COMPLEX, "_Complex" },
{ Token::GENERIC, "_Generic" },
{ Token::IMAGINARY, "_Imaginary" },
{ Token::NORETURN, "_Noreturn" },
{ Token::STATIC_ASSERT, "_Static_assert" },
{ Token::THREAD, "_Thread_local" },
{ Token::END, "(eof)" },
{ Token::IDENTIFIER, "(identifier)" },
{ Token::CONSTANT, "(constant)" },
{ Token::LITERAL, "(string literal)" },
};
Token* Token::New(int tag) {
return new (tokenPool.Alloc()) Token(tag);
}
Token* Token::New(const Token& other) {
return new (tokenPool.Alloc()) Token(other);
}
Token* Token::New(int tag,
const SourceLocation& loc,
const std::string& str,
bool ws) {
return new (tokenPool.Alloc()) Token(tag, loc, str, ws);
}
TokenSequence TokenSequence::GetLine() {
auto begin = begin_;
while (begin_ != end_ && (*begin_)->tag_ != Token::NEW_LINE)
++begin_;
auto end = begin_;
return {tokList_, begin, end};
}
/*
* If this seq starts from the begin of a line.
* Called only after we have saw '#' in the token sequence.
*/
bool TokenSequence::IsBeginOfLine() const {
if (begin_ == tokList_->begin())
return true;
auto pre = begin_;
--pre;
// We do not insert a newline at the end of a source file.
// Thus if two token have different filename, the second is
// the begin of a line.
return ((*pre)->tag_ == Token::NEW_LINE ||
(*pre)->loc_.filename_ != (*begin_)->loc_.filename_);
}
const Token* TokenSequence::Peek() const {
static auto eof = Token::New(Token::END);
if (begin_ != end_ && (*begin_)->tag_ == Token::NEW_LINE) {
++begin_;
return Peek();
} else if (begin_ == end_) {
if (end_ != tokList_->begin())
*eof = *Back();
eof->tag_ = Token::END;
return eof;
} else if (parser_ && (*begin_)->tag_ == Token::IDENTIFIER &&
(*begin_)->str_ == "__func__") {
auto filename = Token::New(*(*begin_));
filename->tag_ = Token::LITERAL;
filename->str_ = "\"" + parser_->CurFunc()->Name() + "\"";
*begin_ = filename;
}
return *begin_;
}
const Token* TokenSequence::Expect(int expect) {
auto tok = Peek();
if (!Try(expect)) {
Error(tok, "'%s' expected, but got '%s'",
Token::Lexeme(expect), tok->str_.c_str());
}
return tok;
}
void TokenSequence::Print(FILE* fp) const {
unsigned lastLine = 0;
auto ts = *this;
while (!ts.Empty()) {
auto tok = ts.Next();
if (lastLine != tok->loc_.line_) {
fputs("\n", fp);
for (unsigned i = 0; i < tok->loc_.column_; ++i)
fputc(' ', fp);
} else if (tok->ws_) {
fputc(' ', fp);
}
fputs(tok->str_.c_str(), fp);
fflush(fp);
lastLine = tok->loc_.line_;
}
fputs("\n", fp);
}
//void TokenSequence::Print(std::string *str) const {
//}

View File

@@ -1,508 +0,0 @@
#include "triton/lang/type.h"
#include "triton/lang/ast.h"
#include "triton/lang/scope.h"
#include "triton/lang/token.h"
#include <cassert>
#include <algorithm>
#include <iostream>
static MemPoolImp<VoidType> voidTypePool;
static MemPoolImp<ArrayType> arrayTypePool;
static MemPoolImp<TileType> tileTypePool;
static MemPoolImp<FuncType> funcTypePool;
static MemPoolImp<PointerType> pointerTypePool;
static MemPoolImp<StructType> structUnionTypePool;
static MemPoolImp<ArithmType> arithmTypePool;
QualType Type::MayCast(QualType type, bool inProtoScope) {
auto funcType = type->ToFunc();
auto arrayType = type->ToArray();
if (funcType) {
return PointerType::New(funcType);
} else if (arrayType) {
auto ret = PointerType::New(arrayType->Derived());
// C11 6.7.6.3 [7]: qualifiers are specified in '[]'
// As we do not support qualifiers in '[]', the qualifier whould be none
return QualType(ret, inProtoScope? 0: Qualifier::CONST);
}
return type;
}
const Type* Type::ScalarType() const {
if(IsScalar())
return this;
if(const TileType* p = ToTile())
return p->Derived().GetPtr();
return nullptr;
}
Type* Type::ScalarType() {
auto cthis = const_cast<const Type*>(this);
return const_cast<Type*>(cthis->ScalarType());
}
VoidType* VoidType::New() {
static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool);
return ret;
}
ArithmType* ArithmType::New(int typeSpec) {
#define NEW_TYPE(tag) \
new (arithmTypePool.Alloc()) ArithmType(&arithmTypePool, tag);
static auto boolType = NEW_TYPE(T_BOOL);
static auto charType = NEW_TYPE(T_CHAR);
static auto ucharType = NEW_TYPE(T_UNSIGNED | T_CHAR);
static auto shortType = NEW_TYPE(T_SHORT);
static auto ushortType = NEW_TYPE(T_UNSIGNED | T_SHORT);
static auto intType = NEW_TYPE(T_INT);
static auto uintType = NEW_TYPE(T_UNSIGNED | T_INT);
static auto longType = NEW_TYPE(T_LONG);
static auto ulongType = NEW_TYPE(T_UNSIGNED | T_LONG);
static auto llongType = NEW_TYPE(T_LLONG)
static auto ullongType = NEW_TYPE(T_UNSIGNED | T_LLONG);
static auto halfType = NEW_TYPE(T_HALF);
static auto floatType = NEW_TYPE(T_FLOAT);
static auto doubleType = NEW_TYPE(T_DOUBLE);
static auto ldoubleType = NEW_TYPE(T_LONG | T_DOUBLE);
auto tag = ArithmType::Spec2Tag(typeSpec);
switch (tag) {
case T_BOOL: return boolType;
case T_CHAR: return charType;
case T_UNSIGNED | T_CHAR: return ucharType;
case T_SHORT: return shortType;
case T_UNSIGNED | T_SHORT:return ushortType;
case T_INT: return intType;
case T_UNSIGNED:
case T_UNSIGNED | T_INT: return uintType;
case T_LONG: return longType;
case T_UNSIGNED | T_LONG: return ulongType;
case T_LLONG: return llongType;
case T_UNSIGNED | T_LLONG:return ullongType;
case T_HALF: return halfType;
case T_FLOAT: return floatType;
case T_DOUBLE: return doubleType;
case T_LONG | T_DOUBLE: return ldoubleType;
default:
assert(tag & T_COMPLEX);
Error("complex not supported yet");
}
return nullptr; // Make compiler happy
#undef NEW_TYPE
}
ArrayType* ArrayType::New(int len, QualType eleType) {
return new (arrayTypePool.Alloc())
ArrayType(&arrayTypePool, len, eleType);
}
ArrayType* ArrayType::New(Expr* expr, QualType eleType) {
return new (arrayTypePool.Alloc())
ArrayType(&arrayTypePool, expr, eleType);
}
TileType* TileType::New(const ShapeInt &shape, QualType eleType) {
return new (tileTypePool.Alloc())
TileType(&tileTypePool, shape, eleType);
}
FuncType* FuncType::New(QualType derived,
int funcSpec,
bool variadic,
const ParamList& params) {
return new (funcTypePool.Alloc())
FuncType(&funcTypePool, derived, funcSpec, variadic, params);
}
PointerType* PointerType::New(QualType derived) {
return new (pointerTypePool.Alloc())
PointerType(&pointerTypePool, derived);
}
StructType* StructType::New(bool isStruct,
bool hasTag,
Scope* parent) {
return new (structUnionTypePool.Alloc())
StructType(&structUnionTypePool, isStruct, hasTag, parent);
}
int ArithmType::Width() const {
switch (tag_) {
case T_BOOL: case T_CHAR: case T_UNSIGNED | T_CHAR:
return 1;
case T_SHORT: case T_UNSIGNED | T_SHORT:
return intWidth_ >> 1;
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT:
return intWidth_;
case T_LONG: case T_UNSIGNED | T_LONG:
return intWidth_ << 1;
case T_LLONG: case T_UNSIGNED | T_LLONG:
return intWidth_ << 1;
case T_HALF:
return intWidth_ >> 1;
case T_FLOAT:
return intWidth_;
case T_DOUBLE:
return intWidth_ << 1;
case T_LONG | T_DOUBLE:
return intWidth_ << 1;
case T_HALF | T_COMPLEX:
return intWidth_;
case T_FLOAT | T_COMPLEX:
return intWidth_ << 1;
case T_DOUBLE | T_COMPLEX:
return intWidth_ << 2;
case T_LONG | T_DOUBLE | T_COMPLEX:
return intWidth_ << 2;
default:
assert(false);
}
return intWidth_; // Make compiler happy
}
int ArithmType::Rank() const {
switch (tag_) {
case T_BOOL: return 0;
case T_CHAR: case T_UNSIGNED | T_CHAR: return 1;
case T_SHORT: case T_UNSIGNED | T_SHORT: return 2;
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3;
case T_LONG: case T_UNSIGNED | T_LONG: return 4;
case T_LLONG: case T_UNSIGNED | T_LLONG: return 5;
case T_HALF: return 6;
case T_FLOAT: return 7;
case T_DOUBLE: return 8;
case T_LONG | T_DOUBLE: return 9;
default:
assert(tag_ & T_COMPLEX);
Error("complex not supported yet");
}
return 0;
}
ArithmType* ArithmType::MaxType(ArithmType* lhs,
ArithmType* rhs) {
if (lhs->IsInteger())
lhs = ArithmType::IntegerPromote(lhs);
if (rhs->IsInteger())
rhs = ArithmType::IntegerPromote(rhs);
auto ret = lhs->Rank() > rhs->Rank() ? lhs: rhs;
if (lhs->Width() == rhs->Width() && (lhs->IsUnsigned() || rhs->IsUnsigned()))
return ArithmType::New(T_UNSIGNED | ret->Tag());
return ret;
}
/*
* Converting from type specifier to type tag
*/
int ArithmType::Spec2Tag(int spec) {
if (spec == T_SIGNED) {
return T_INT;
}
spec &= ~T_SIGNED;
if ((spec & T_SHORT) || (spec & T_LONG)
|| (spec & T_LLONG)) {
spec &= ~T_INT;
}
return spec;
}
std::string ArithmType::Str() const {
std::string width = ":" + std::to_string(Width());
switch (tag_) {
case T_BOOL:
return "bool" + width;
case T_CHAR:
return "char" + width;
case T_UNSIGNED | T_CHAR:
return "unsigned char" + width;
case T_SHORT:
return "short" + width;
case T_UNSIGNED | T_SHORT:
return "unsigned short" + width;
case T_INT:
return "int" + width;
case T_UNSIGNED:
return "unsigned int" + width;
case T_LONG:
return "long" + width;
case T_UNSIGNED | T_LONG:
return "unsigned long" + width;
case T_LLONG:
return "long long" + width;
case T_UNSIGNED | T_LLONG:
return "unsigned long long" + width;
case T_FLOAT:
return "float" + width;
case T_DOUBLE:
return "double" + width;
case T_LONG | T_DOUBLE:
return "long double" + width;
case T_FLOAT | T_COMPLEX:
return "float complex" + width;
case T_DOUBLE | T_COMPLEX:
return "double complex" + width;
case T_LONG | T_DOUBLE | T_COMPLEX:
return "long double complex" + width;
default:
assert(false);
}
return "error"; // Make compiler happy
}
bool PointerType::Compatible(const Type& other) const {
// C11 6.7.6.1 [2]: pointer compatibility
auto otherPointer = other.ToPointer();
return otherPointer &&
derived_->Compatible(*otherPointer->derived_);
// FIXME(wgtdkp): cannot loose compatible constraints
//return other.IsInteger() ||
// (otherPointer && derived_->Compatible(*otherPointer->derived_));
}
bool ArrayType::Compatible(const Type& other) const {
// C11 6.7.6.2 [6]: For two array type to be compatible,
// the element types must be compatible, and have same length
// if both specified.
auto otherArray = other.ToArray();
if (!otherArray) return false;
if (!derived_->Compatible(*otherArray->derived_)) return false;
// The lengths should equal if both specified
if (complete_ && otherArray->complete_)
return len_ == otherArray->len_;
return true;
}
TileType::TileType(MemPool* pool, const ShapeInt& shape, QualType derived)
: DerivedType(pool, derived),
shape_(shape) {
bool isComplete = true;
for(int s: shape_)
isComplete = isComplete && (s>=0);
SetComplete(isComplete);
}
bool TileType::Compatible(const Type& other) const {
// For two tile type to be compatible,
// the element types must be compatible
// and they must have the same shapea
auto otherTile = other.ToTile();
if(!otherTile)
return false;
if (!derived_->Compatible(*otherTile->derived_))
return false;
// The shapes should be equal if both specified
if(complete_ && otherTile->complete_)
return shape_ == otherTile->shape_;
return true;
}
bool FuncType::Compatible(const Type& other) const {
auto otherFunc = other.ToFunc();
// The other type is not an function type
if (!otherFunc) return false;
// TODO(wgtdkp): do we need to check the type of return value when deciding
// compatibility of two function types ??
if (!derived_->Compatible(*otherFunc->derived_))
return false;
if (params_.size() != otherFunc->params_.size())
return false;
auto thisIter = params_.begin();
auto otherIter = otherFunc->params_.begin();
while (thisIter != params_.end()) {
if (!(*thisIter)->Type()->Compatible(*(*otherIter)->Type()))
return false;
++thisIter;
++otherIter;
}
return true;
}
std::string FuncType::Str() const {
auto str = derived_->Str() + "(";
auto iter = params_.begin();
for (; iter != params_.end(); ++iter) {
str += (*iter)->Type()->Str() + ", ";
}
if (variadic_)
str += "...";
else if (params_.size())
str.resize(str.size() - 2);
return str + ")";
}
StructType::StructType(MemPool* pool,
bool isStruct,
bool hasTag,
Scope* parent)
: Type(pool, false),
isStruct_(isStruct),
hasTag_(hasTag),
memberMap_(new Scope(parent, S_BLOCK)),
offset_(0),
width_(0),
// If a struct type has no member, it gets alignment of 1
align_(1),
bitFieldAlign_(1) {}
Object* StructType::GetMember(const std::string& member) {
auto ident = memberMap_->FindInCurScope(member);
if (ident == nullptr)
return nullptr;
return ident->ToObject();
}
void StructType::CalcWidth() {
width_ = 0;
auto iter = memberMap_->identMap_.begin();
for (; iter != memberMap_->identMap_.end(); ++iter) {
width_ += iter->second->Type()->Width();
}
}
bool StructType::Compatible(const Type& other) const {
return this == &other; // Pointer comparison
}
// TODO(wgtdkp): more detailed representation
std::string StructType::Str() const {
std::string str = isStruct_ ? "struct": "union";
return str + ":" + std::to_string(width_);
}
// Remove useless unnamed bitfield members as they are just for parsing
void StructType::Finalize() {
for (auto iter = members_.begin(); iter != members_.end();) {
if ((*iter)->BitFieldWidth() && (*iter)->Anonymous()) {
members_.erase(iter++);
} else {
++iter;
}
}
}
void StructType::AddMember(Object* member) {
auto offset = MakeAlign(offset_, member->Align());
member->SetOffset(offset);
members_.push_back(member);
memberMap_->Insert(member->Name(), member);
align_ = std::max(align_, member->Align());
bitFieldAlign_ = std::max(bitFieldAlign_, align_);
if (isStruct_) {
offset_ = offset + member->Type()->Width();
width_ = MakeAlign(offset_, align_);
} else {
assert(offset_ == 0);
width_ = std::max(width_, member->Type()->Width());
width_ = MakeAlign(width_, align_);
}
}
void StructType::AddBitField(Object* bitField, int offset) {
bitField->SetOffset(offset);
members_.push_back(bitField);
if (!bitField->Anonymous())
memberMap_->Insert(bitField->Name(), bitField);
auto bytes = MakeAlign(bitField->BitFieldEnd(), 8) / 8;
bitFieldAlign_ = std::max(bitFieldAlign_, bitField->Align());
// Does not aligned, default is 1
if (isStruct_) {
offset_ = offset + bytes;
width_ = MakeAlign(offset_, std::max(bitFieldAlign_, bitField->Align()));
} else {
assert(offset_ == 0);
width_ = std::max(width_, bitField->Type()->Width());
}
}
// Move members of Anonymous struct/union to external struct/union
void StructType::MergeAnony(Object* anony) {
auto anonyType = anony->Type()->ToStruct();
auto offset = MakeAlign(offset_, anony->Align());
// Members in map are never anonymous
for (auto& kv: *anonyType->memberMap_) {
auto& name = kv.first;
auto member = kv.second->ToObject();
if (member == nullptr) {
continue;
}
// Every member of anonymous struct/union
// are offseted by external struct/union
member->SetOffset(offset + member->Offset());
if (GetMember(name)) {
Error(member, "duplicated member '%s'", name.c_str());
}
// Simplify anony struct's member searching
memberMap_->Insert(name, member);
}
anony->SetOffset(offset);
members_.push_back(anony);
align_ = std::max(align_, anony->Align());
if (isStruct_) {
offset_ = offset + anonyType->Width();
width_ = MakeAlign(offset_, align_);
} else {
assert(offset_ == 0);
width_ = std::max(width_, anonyType->Width());
}
}

View File

View File

@@ -1,364 +0,0 @@
#include <string>
#include <mutex>
#include <regex>
#include <functional>
#include <algorithm>
#include <sstream>
#include <memory>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/runtime/function.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/lang/code_gen.h"
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/driver/error.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/print.h"
#include "triton/runtime/error.h"
#include "triton/tools/bench.hpp"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/getenv.hpp"
#include "triton/tools/sys/mkdir.hpp"
#include "llvm/IR/Module.h"
#include <mutex>
#include <fstream>
namespace triton{
namespace runtime {
/* --------------------------------- */
/* --------------------------------- */
/* --------------------------------- */
std::shared_ptr<ir::module> kernel::src_to_ir(const std::string& _src, const options_t& opt) {
std::string src =
R"(
#define bool _Bool
#define true 1
#define false 0
#define __readonly __attribute__((readonly))
#define __writeonly __attribute__((writeonly))
#define __noalias __attribute__((noalias))
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
#define __retune __attribute__((retune))
#define F32_INFINITY bitcast<float>(0x7F800000)
#define F16_INFINITY bitcast<half>((int16)0x7C00)
#define min(a,b) (((a)<(b))?(a):(b))
#define max(a,b) (((a)>(b))?(a):(b))
#define PASTER(a, b, _) a ## _ ## b
#define EVALUATOR(a, b, _) PASTER(a, b, _)
#define atomic_add(TYPE, TM, TN) EVALUATOR(atomic_add, EVALUATOR(TYPE, EVALUATOR(TM, TN, x), _), _)
#define DECLARATION(TYPE, TM, TN) extern void atomic_add(TYPE, TM, TN)(TYPE*[TM, TN], TYPE[TM, TN], bool[TM, TN])
DECLARATION(float, 64, 64);
DECLARATION(float, 64, 128);
DECLARATION(float, 128, 64);
DECLARATION(float, 128, 128);
extern void atomic_add_half_1x1(half*, half, bool);
DECLARATION(half , 64, 64);
DECLARATION(half , 64, 128);
DECLARATION(half , 128, 64);
DECLARATION(half , 128, 128);
extern void atomic_add_float_1x1(float*, float, bool);
extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern int get_program_id(int);
extern void __debug_barrier();
extern int get_num_programs(int);
extern int select(bool, int, int);
extern char __constant__ * calloc(int);
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long uint64;
typedef char int8;
typedef short int16;
typedef int int32;
typedef long int64;
)";
src += _src;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&src, true);
for(auto it: opt.defines)
cpp.AddMacro(it.first, &it.second);
cpp.Process(tokens);
// src -> ast
Parser parser(tokens);
parser.Parse();
// ast -> triton-ir
auto ret = std::make_shared<ir::module>("");
Generator gen(&parser);
gen.Gen(&*ret);
return ret;
}
std::tuple<std::shared_ptr<driver::module>,
std::shared_ptr<driver::kernel>,
size_t> kernel::ir_to_bin(ir::module &ir, driver::device* dev, const options_t& opt) {
// generate llvm code
llvm::LLVMContext ctx;
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
std::unique_ptr<codegen::target> target = dev->make_target();
bool cts_use_async = target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
// run passes
dce.run(ir);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
dce.run(ir);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
// ir::print(ir, std::cout);
if(target->is_gpu())
cts.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
align.run(ir);
dce.run(ir);
if(target->is_gpu()){
reassociate.run(ir);
cts.run(ir);
}
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);
barriers.run(ir);
isel.visit(ir, *llvm);
std::shared_ptr<driver::module> mod(driver::module::create(dev, std::move(llvm)));
std::shared_ptr<driver::kernel> ker(driver::kernel::create(&*mod, name.c_str()));
size_t shared_mem = allocation.allocated_size();
return std::make_tuple(mod, ker, shared_mem);
}
kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev, const std::map<int, ir::attribute> &attrs):
opt(opt), dev_(dev) {
// compile to Triton IR
ir_ = src_to_ir(src, opt);
// add attributes
for(const auto&x: attrs)
ir_->get_function_list()[0]->add_attr(x.first, x.second);
// compile to binary
std::tie(mod_, ker_, shared_mem_) = ir_to_bin(*ir_, dev, opt);
}
void kernel::operator()(const std::string& args, driver::stream *stream, const std::vector<size_t>& _grid) const{
// set grid
if(_grid.size() > 3)
throw std::runtime_error("grid size must be no greater than 3");
std::array<size_t, 3> grid;
for(size_t i = 0; i < 3; i++)
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
// enqueue
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_);
}
std::string kernel::get_asm(const std::string& mode) {
std::vector<std::string> modes = {"llir", "ptx"};
if(std::find(modes.begin(), modes.end(), mode) == modes.end()){
std::string err = "Unrecognized mode. Supported values are: ";
for(std::string m: modes){
if(m != modes[0])
err += ", ";
err += m;
}
throw std::runtime_error(err);
}
if(mode == "llir")
return ((driver::cu_module*)mod_.get())->llir();
if(mode == "ptx")
return ((driver::cu_module*)mod_.get())->ptx();
assert(false);
return "";
}
/* --------------------------------- */
/* --------------------------------- */
/* --------------------------------- */
function::function(const std::string& src, const options_t &opt, driver::device *device,
const std::vector<config> &tune_confs, const std::vector<std::string>& tune_key)
: src_(src), device_(device) {
// kernel options
size_t num_opts = std::max(tune_confs.size(), (size_t)1);
opts_ = std::vector<options_t>(num_opts, opt);
for(size_t i = 0; i < tune_confs.size(); i++){
opts_[i].defines.insert(tune_confs[i].defines.begin(), tune_confs[i].defines.end());
opts_[i].num_warps = tune_confs[i].num_warps;
}
std::shared_ptr<ir::module> ir = kernel::src_to_ir(src, opts_[0]);
std::vector<ir::argument*> args = ir->get_function_list()[0]->args();
// signature
auto convert = [](ir::type *ty) {
if(ty->is_integer_ty(1)) return INT1_T;
if(ty->is_integer_ty(8)) return INT8_T;
if(ty->is_integer_ty(16)) return INT16_T;
if(ty->is_integer_ty(32)) return INT32_T;
if(ty->is_integer_ty(64)) return INT64_T;
if(ty->is_half_ty()) return HALF_T;
if(ty->is_float_ty()) return FLOAT_T;
if(ty->is_double_ty()) return DOUBLE_T;
if(ty->is_pointer_ty()) return BUFFER_T;
throw std::runtime_error("unknown type");
};
for(ir::argument* arg: args)
sig_.push_back(convert(arg->get_type()));
// find indices of autotune keys
for(const std::string& name: tune_key){
auto pred = [&](ir::argument* arg) { return arg->get_name() == name; };
// std::cout << "----" << std::endl;
// for(ir::argument* arg: args)
// std::cout << arg->get_name() << std::endl;
auto it = std::find_if(args.begin(), args.end(), pred);
if(it == args.end())
throw std::runtime_error(name + " is not a valid argument name");
key_idxs_.push_back(std::distance(args.begin(), it));
}
// find indices of pointer
for(size_t i = 0; i < args.size(); i++)
if(args[i]->get_type()->is_pointer_ty() ||
args[i]->get_type()->is_integer_ty())
align_idxs_.push_back(i);
// argument size and offset
size_t curr = 0;
for(arg_type ty: sig_){
arg_size_.push_back(size_of(ty));
arg_off_.push_back(curr);
curr += arg_size_.back();
}
}
uint64_t pow2_divisor(uint64_t N){
if(N % 16 == 0) return 16;
if(N % 8 == 0) return 8;
if(N % 4 == 0) return 4;
if(N % 2 == 0) return 2;
return 1;
}
kernel* function::autotune(const std::string &args, const grid_fn_ty& grid_fn, driver::stream* stream) {
// align key
std::vector<uint64_t> rt_key(align_idxs_.size(), 0);
for(size_t i = 0; i < align_idxs_.size(); i++){
int idx = align_idxs_[i];
uint64_t tmp = 0;
std::memcpy((void*)&tmp, (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]);
rt_key[i] = pow2_divisor(tmp);
}
// auto-tuning key
std::vector<uint64_t> at_key(key_idxs_.size(), 0);
for(size_t i = 0; i < at_key.size(); i++){
int idx = key_idxs_[i];
std::memcpy((void*)&at_key[i], (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]);
}
// cache key
std::vector<uint64_t> cache_key;
cache_key.reserve(rt_key.size() + at_key.size());
cache_key.insert(cache_key.end(), rt_key.begin(), rt_key.end());
cache_key.insert(cache_key.end(), at_key.begin(), at_key.end());
auto it = cache_.find(cache_key);
if(it != cache_.end())
return it->second;
// compile kernels
if(kernels_.find(rt_key) == kernels_.end()){
std::map<int, ir::attribute> attrs;
for(size_t i = 0; i < align_idxs_.size(); i++){
bool is_ptr = sig_[align_idxs_[i]] == BUFFER_T;
attrs.insert({align_idxs_[i] + 1, ir::attribute(is_ptr ? ir::aligned : ir::multiple_of, rt_key[i])});
}
for(const options_t& opt: opts_)
kernels_[rt_key].emplace_back(new kernel(src_, opt, device_, attrs));
}
// run auto-tuner
double best_ts = INFINITY;
auto& kernels = kernels_.at(rt_key);
kernel* ret = nullptr;
if(kernels.size() == 1)
ret = &*kernels.back();
else{
for(auto &current : kernels_.at(rt_key)){
auto grid = grid_fn(current->opt);
while(grid.size() < 3)
grid.push_back(1);
double ts = tools::bench([&]() { (*current)(args, stream, grid); },
stream, 5, 20);
ret = (ts < best_ts) ? &*current : ret;
best_ts = std::min(ts, best_ts);
}
stream->synchronize();
}
it = cache_.insert({cache_key, ret}).first;
return it->second;
}
void function::operator()(const std::string& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
runtime::kernel* fn = autotune(args, grid_fn, stream);
(*fn)(args, stream, grid_fn(fn->opt));
}
}
}

View File

@@ -49,11 +49,11 @@ class CMakeBuild(build_ext):
self.build_extension(ext)
def build_extension(self, ext):
# self.debug = True
self.debug = False
#self.debug = True
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm")
build_suffix = 'debug' if self.debug else 'release'
llvm_build_dir = os.path.join(tempfile.gettempdir(), f"llvm-{build_suffix}")
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
if not os.path.exists(llvm_build_dir):

676
python/src/functions.h Normal file
View File

@@ -0,0 +1,676 @@
#include "triton/ir/builder.h"
#include <functional>
#include <iostream>
#include <pybind11/pybind11.h>
namespace ir = triton::ir;
namespace py = pybind11;
static const std::string _builder_doc = R"pbdoc(
:param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function
:type builder: triton.ir.builder
)pbdoc";
#define VA_ARGS(...) , ##__VA_ARGS__
#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
void throw_not_implemented(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side.");
}
void throw_not_int_or_float(std::string key) {
throw std::runtime_error("`" + key + "` only supported for integer and floating point types.");
}
enum type_code {
_bool,
int8,
int16,
int32,
int64,
float16,
float32,
float64
};
ir::type *make_ir(type_code ty, ir::builder *builder) {
switch (ty) {
case float16:
return builder->get_half_ty();
case float32:
return builder->get_float_ty();
default:
throw_not_implemented("make_ir");
}
}
type_code from_ir(ir::type *ty) {
if (ty->is_half_ty())
return float16;
if (ty->is_float_ty())
return float32;
throw_not_implemented("from_ir");
}
/*----------------------------------------------
definition of triton.cast / triton.ir.value.to
----------------------------------------------*/
std::string cast_docstr = R"pbdoc(
Tries to cast a block to a new data type.
:param input: The input block.
:type input: triton.ir.value
)pbdoc";
ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
ir::type *src_ty = input->get_type();
ir::type *dst_ty = make_ir(_dtype, builder);
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, true);
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty())
return builder->create_fp_to_si(input, dst_ty);
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty())
return builder->create_si_to_fp(input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, int64, builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_not_implemented("cast");
}
/*----------------------------------------------
definition of triton.broadcast_check
----------------------------------------------*/
std::string try_broadcast_docstr = R"pbdoc(
Tries to broadcast two blocks to a common compatible shape.
:param input: The first input block.
:type input: triton.ir.value
:param other: The second input block.
:type other: triton.ir.value
)pbdoc";
std::tuple<ir::value *, ir::value *> try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
/*----------------------------------------------
definition of triton.broadcast_to
----------------------------------------------*/
std::string broadcast_to_docstr = R"pbdoc(
Tries to broadcast a block to a new shape.
:param input: The input block.
:type input: triton.value
:param shape: The new shape.
:type shape: tuple of int
)pbdoc";
ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
return builder->create_broadcast(input, shape);
}
/*----------------------------------------------
definition of triton.load
----------------------------------------------*/
std::string load_docstr = R"pbdoc(
Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.
:param pointer: Pointer to the data to be loaded.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]`
:type other: Block of triton.value, optional
)pbdoc";
ir::value *load(ir::value *pointer, std::optional<ir::value *> _mask, std::optional<ir::value *> _other, ir::builder *builder) {
if (!_mask.has_value() && !_other.has_value())
return builder->create_load(pointer);
if (!_mask.has_value())
throw std::runtime_error("`other` cannot be provided without `mask`");
ir::value *mask = _mask.value();
ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty();
auto shape = pointer->get_type()->get_block_shapes();
ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty);
other = cast(other, from_ir(elt_ty), builder);
other = broadcast_to(other, shape, builder);
mask = broadcast_to(mask, shape, builder);
return builder->create_masked_load(pointer, mask, other);
}
/*----------------------------------------------
definition of triton.store
----------------------------------------------*/
std::string store_docstr = R"pbdoc(
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
:param pointer: The memory locations where the elements of `value` are stored.
:type pointer: Block of triton.pointer
:param value: The block of elements to be stored.
:type value: Block of triton.value
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
:type mask: Block of triton.bool, optional
)pbdoc";
ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mask, ir::builder *builder) {
if (!_mask.has_value())
return builder->create_store(ptr, val);
ir::value *mask = _mask.value();
return builder->create_masked_store(ptr, val, mask);
}
/*----------------------------------------------
definition of triton.dot
----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {`float16`, `float32`}
)pbdoc";
ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
return builder->create_dot(lhs, rhs, _0);
}
/*----------------------------------------------
definition of triton.where
----------------------------------------------*/
std::string where_docstr = R"pbdoc(
Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
:param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False.
)pbdoc";
ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
return builder->create_select(condition, x, y);
};
/*----------------------------------------------
definition of triton.arange
----------------------------------------------*/
std::string arange_docstr = R"pbdoc(
Returns contiguous values within the open interval [start, end).
:param start: Start of the interval.
:type start: int
:param stop: End of the interval.
:type stop: int
)pbdoc";
ir::value *arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
};
/*----------------------------------------------
definition of triton.program_id
----------------------------------------------*/
std::string program_id_docstr = R"pbdoc(
Returns the id of the current program instance along the given `axis`.
Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
};
/*----------------------------------------------
definition of triton.num_programs
----------------------------------------------*/
std::string num_programs_docstr = R"pbdoc(
Returns the number of program instances launched along the given `axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
};
/*----------------------------------------------
definition of triton.zeros
----------------------------------------------*/
std::string zeros_docstr = R"pbdoc(
Returns a block filled with the scalar value 0 and the given shape.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., triton.float16
:type dtype: triton.ir.dtype
)pbdoc";
ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
ir::type *dtype = make_ir(_dtype, builder);
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
};
/*----------------------------------------------
definition of triton.exp
----------------------------------------------*/
std::string _exp_docstr = R"pbdoc(
Returns the element-wise exponential of `input`.
)pbdoc";
ir::value *_exp(ir::value *input, ir::builder *builder) {
return builder->create_exp(input);
};
/*----------------------------------------------
definition of triton.log
----------------------------------------------*/
std::string _log_docstr = R"pbdoc(
Returns the element-wise natural logarithm of `input`.
)pbdoc";
ir::value *_log(ir::value *input, ir::builder *builder) {
return builder->create_log(input);
};
/*----------------------------------------------
definition of triton.sqrt
----------------------------------------------*/
std::string sqrt_docstr = R"pbdoc(
Returns the element-wise square root of `input`.
)pbdoc";
ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input);
};
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
else
throw_not_int_or_float(name);
}
std::string min_docstr = R"pbdoc(
Returns the minimum value of `input`.
)pbdoc";
ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};
/*----------------------------------------------
definition of triton.max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
Returns the maximum value of `input`.
)pbdoc";
ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};
/*----------------------------------------------
definition of triton.sum
----------------------------------------------*/
std::string sum_docstr = R"pbdoc(
Returns the sum of `input`.
)pbdoc";
ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
};
/*----------------------------------------------
definition of triton.atomic_cas
----------------------------------------------*/
std::string atomic_cas_docstr = R"pbdoc(
Atomic compare-and-swap.
)pbdoc";
ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
return builder->create_atomic_cas(ptr, cmp, val);
};
/*----------------------------------------------
definition of triton.atomic_xchg
----------------------------------------------*/
std::string atomic_xchg_docstr = R"pbdoc(
Atomic exchange.
)pbdoc";
ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) {
return builder->create_atomic_exch(ptr, val);
};
/*----------------------------------------------
debug barrier
----------------------------------------------*/
std::string debug_barrier_docstr = R"pbdoc(
Temporary hacky fixup for when the compiler forgets to insert sync barriers
)pbdoc";
ir::value *debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
template <class FN>
std::function<ir::value *(ir::value *, ir::value *, ir::builder *builder)>
binary_op(const FN &fn) {
auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) {
//std::tie(self, other) = try_broadcast(self, other, builder);
return fn(self, other, builder);
};
return ret;
}
/*----------------------------------------------
definition of self + other
----------------------------------------------*/
std::string add_docstr = R"pbdoc(
Returns self + other, element-wise.
)pbdoc";
ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// ptr + offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
// float + float
else if (scalar_ty->is_floating_point_ty())
return builder->create_fadd(self, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_add(self, other);
throw_not_implemented("add");
}
/*----------------------------------------------
definition of self - other
----------------------------------------------*/
std::string sub_docstr = R"pbdoc(
Returns self - other, element-wise.
)pbdoc";
ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// ptr + offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
// float + float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(self, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(self, other);
throw_not_implemented("sub");
}
/*----------------------------------------------
definition of self * other
----------------------------------------------*/
std::string mul_docstr = R"pbdoc(
Returns self * other, element-wise.
)pbdoc";
ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(self, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(self, other);
throw_not_implemented("mul");
}
/*----------------------------------------------
definition of self > other
----------------------------------------------*/
std::string greater_than_docstr = R"pbdoc(
Returns self > other, element-wise.
)pbdoc";
ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(self, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(self, other);
throw_not_implemented("greater_than");
}
/*----------------------------------------------
definition of self >= other
----------------------------------------------*/
std::string greater_equal_docstr = R"pbdoc(
Returns self >= other, element-wise.
)pbdoc";
ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(self, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(self, other);
throw_not_implemented("greater_equal");
}
/*----------------------------------------------
definition of self < other
----------------------------------------------*/
std::string less_than_docstr = R"pbdoc(
Returns self < other, element-wise.
)pbdoc";
ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(self, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(self, other);
throw_not_implemented("less_than");
}
/*----------------------------------------------
definition of self <= other
----------------------------------------------*/
std::string less_equal_docstr = R"pbdoc(
Returns self <= other, element-wise.
)pbdoc";
ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(self, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(self, other);
throw_not_implemented("less_equal");
}
/*----------------------------------------------
definition of self == other
----------------------------------------------*/
std::string equal_docstr = R"pbdoc(
Returns self == other, element-wise.
)pbdoc";
ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(self, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(self, other);
throw_not_implemented("equal");
}
/*----------------------------------------------
definition of self / other
----------------------------------------------*/
std::string _div_docstr = R"pbdoc(
Returns self / other, element-wise.
)pbdoc";
ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float / float
if (scalar_ty->is_floating_point_ty())
return builder->create_fdiv(self, other);
// int / int
else if (scalar_ty->is_integer_ty())
return builder->create_sdiv(self, other);
throw_not_implemented("div");
}
/*----------------------------------------------
definition of self % other
----------------------------------------------*/
std::string mod_docstr = R"pbdoc(
Returns self % other, element-wise.
)pbdoc";
ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float % int
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(self, other);
// int % int
else if (scalar_ty->is_integer_ty())
return builder->create_srem(self, other);
throw_not_implemented("mod");
}
/*----------------------------------------------
definition of self & other
----------------------------------------------*/
std::string _and_docstr = R"pbdoc(
Returns self & other, element-wise.
)pbdoc";
ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) {
return builder->create_and(self, other);
}
/*----------------------------------------------
definition of minimum(self, other)
----------------------------------------------*/
std::string minimum_docstr = R"pbdoc(
Returns element-wise minimum of self and other
)pbdoc";
ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) {
return where(less_than(self, other, builder), self, other, builder);
}
/*----------------------------------------------
definition of self[slices]
----------------------------------------------*/
enum slice_mode_t {
NEWAXIS,
ALL
};
std::string subscript_docstr = R"pbdoc(
returns self[slices].
:param slices: The slices to subscript with.
:type slices: List of `None` or `:` slices.
)pbdoc";
ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builder *builder) {
std::vector<slice_mode_t> modes;
for (py::object slice : slices) {
py::object none = py::none();
py::object all = py::make_tuple(none, none, none);
if (slice.is(none))
modes.push_back(NEWAXIS);
else if (all.attr("__eq__")(slice))
modes.push_back(ALL);
else
throw std::runtime_error("slice must be None or (None, None, None)");
}
ir::type::block_shapes_t shape;
size_t curr = 0;
for (slice_mode_t mode : modes) {
if (mode == NEWAXIS)
shape.push_back(1);
else {
assert(mode == ALL);
shape.push_back(self->get_type()->get_block_shapes()[curr++]);
}
}
return builder->create_reshape(self, shape);
}

View File

@@ -1,5 +1,13 @@
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/codegen/pass.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/driver/stream.h"
#include "triton/ir/builder.h"
#include "triton/ir/dispatch.h"
#include "triton/ir/enums.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
@@ -8,78 +16,9 @@
#include <string>
namespace py = pybind11;
using namespace triton;
namespace rt = triton::runtime;
namespace ir = triton::ir;
namespace drv = triton::driver;
/*****************************************************************************/
/* Python bindings for triton::tools */
/*****************************************************************************/
/*!
@brief Function for extracting kernels out of a given source-string
This can be important to enable pre-processor macros (or tunable parameters) that should only
be defined within the scope of a single kernel function
*/
std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
if (names.empty())
return str;
// search for all regex matches of kernel_regex in str
std::smatch matches;
std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
std::sregex_iterator it(str.begin(), str.end(), regex);
std::sregex_iterator end;
std::vector<std::tuple<std::string, int, int>> kernels;
for (; it != end; ++it) {
int pos = it->position();
int len = it->length();
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}
// check that all the kernels provided actually exist
for (const std::string &name : names) {
auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if (!found)
throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}
// simple parsing logic to extract the declaration and body of each specified kernel
std::string ret;
for (const auto &k : kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
if (std::find(names.begin(), names.end(), name) == names.end())
continue;
std::string def = str.substr(pos, str.size() - pos);
// skip over declaration
// by finding matching ')' for first '('
int count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
// by finding matching '{' for first '}'
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}
return ret;
}
void init_triton_tools(py::module &&m) {
m.def("extract_kernels", &extract_kernels);
}
/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/
@@ -88,14 +27,14 @@ void init_triton_driver(py::module &&m) {
// base device
py::class_<drv::device>(m, "device");
// cuda device
py::class_<drv::cu_device, driver::device>(m, "cu_device")
py::class_<drv::cu_device, drv::device>(m, "cu_device")
.def(py::init([](int dev_id, bool take_ownership) {
CUdevice handle;
drv::dispatch::cuDeviceGet(&handle, dev_id);
return new drv::cu_device(handle, take_ownership);
}));
// host device
py::class_<drv::host_device, driver::device>(m, "host_device")
py::class_<drv::host_device, drv::device>(m, "host_device")
.def(py::init<>());
// base stream
@@ -108,54 +47,236 @@ void init_triton_driver(py::module &&m) {
// py doesn't support opaque pointer (e.g., CUstream) so
// we assume it has been converted to uint64_t
.def(py::init([](uint64_t handle, bool take_ownership) {
return std::unique_ptr<driver::cu_stream>(new driver::cu_stream((CUstream)handle, take_ownership));
}));
return std::unique_ptr<drv::cu_stream>(new drv::cu_stream((CUstream)handle, take_ownership));
}))
.def("enqueue", [](drv::cu_stream *self, drv::kernel *kernel,
size_t grid_0, size_t grid_1, size_t grid_2,
size_t block_0, size_t block_1, size_t block_2,
const std::string &args,
size_t shared_mem) {
return self->enqueue(kernel, {grid_0, grid_1, grid_2}, {block_0, block_1, block_2},
(void *)args.data(), args.size(), shared_mem);
});
py::class_<drv::module>(m, "module");
//py::class_<drv::cu_module, drv::module>(m, "cu_module");
py::class_<drv::kernel>(m, "kernel");
}
/*****************************************************************************/
/* Python bindings for triton::runtime */
/* Python bindings for triton::codegen */
/*****************************************************************************/
void init_triton_runtime(py::module &&m) {
// argument type
py::enum_<rt::arg_type>(m, "arg_type")
.value("int1", rt::INT1_T)
.value("int8", rt::INT8_T)
.value("int16", rt::INT16_T)
.value("int32", rt::INT32_T)
.value("int64", rt::INT64_T)
.value("half", rt::HALF_T)
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
// compilation options
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
.def(py::init<>())
.def_readwrite("defines", &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def("__getattr__", [](rt::options_t *opt, const std::string &name) {
return opt->D<int>(name);
});
// kernel
py::class_<rt::kernel>(m, "kernel")
.def("__call__", &rt::kernel::operator())
.def_readonly("opt", &rt::kernel::opt)
.def("asm", &rt::kernel::get_asm);
// tune conf
py::class_<rt::config>(m, "config")
.def(py::init<std::map<std::string, std::string>, int>(),
py::arg("defines") = std::map<std::string, std::string>(),
py::arg("num_warps"));
// function
py::class_<rt::function>(m, "function")
.def(py::init<const std::string &, const rt::options_t &, driver::device *, const std::vector<rt::config> &, const std::vector<std::string> &>())
.def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal)
.def("signature", &rt::function::get_signature);
void init_triton_codegen(py::module &&m) {
m.def(
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) {
drv::module *mod;
drv::kernel *ker;
size_t shared_mem;
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
return std::make_tuple(mod, ker, shared_mem);
},
py::return_value_policy::take_ownership);
}
/*****************************************************************************/
/* User-facing language features */
/*****************************************************************************/
void init_triton_frontend(py::module &&m) {
using ret = py::return_value_policy;
// programming model
m.def("program_id", &ir::dispatch::program_id, ret::reference);
m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
// binary
m.def("add", &ir::dispatch::add, ret::reference);
m.def("sub", &ir::dispatch::sub, ret::reference);
m.def("mul", &ir::dispatch::mul, ret::reference);
m.def("truediv", &ir::dispatch::truediv, ret::reference);
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
m.def("mod", &ir::dispatch::mod, ret::reference);
m.def("and_", &ir::dispatch::and_, ret::reference);
m.def("or_", &ir::dispatch::or_, ret::reference);
m.def("xor_", &ir::dispatch::xor_, ret::reference);
m.def("lshr", &ir::dispatch::lshr, ret::reference);
m.def("shl", &ir::dispatch::shl, ret::reference);
// unary
m.def("plus", &ir::dispatch::plus, ret::reference);
m.def("minus", &ir::dispatch::minus, ret::reference);
m.def("invert", &ir::dispatch::invert, ret::reference);
// comparison
m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
m.def("less_than", &ir::dispatch::less_than, ret::reference);
m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
m.def("equal", &ir::dispatch::equal, ret::reference);
m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
// block creation
m.def("arange", &ir::dispatch::arange, ret::reference);
m.def("zeros", &ir::dispatch::zeros, ret::reference);
// type manipuatation
m.def("reshape", &ir::dispatch::reshape, ret::reference);
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("cast", &ir::dispatch::cast, ret::reference);
// memory
m.def("load", &ir::dispatch::load, ret::reference);
m.def("store", &ir::dispatch::store, ret::reference);
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
// linear algebra
m.def("dot", &ir::dispatch::dot, ret::reference);
// indexing
m.def("where", &ir::dispatch::where, ret::reference);
// reduction
m.def("min", &ir::dispatch::min, ret::reference);
m.def("max", &ir::dispatch::max, ret::reference);
m.def("sum", &ir::dispatch::sum, ret::reference);
// math
m.def("exp", &ir::dispatch::exp, ret::reference);
m.def("log", &ir::dispatch::log, ret::reference);
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
// internal (debugging only)
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
}
/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy;
using namespace pybind11::literals;
py::class_<ir::context>(m, "context")
.def(py::init<>());
auto value = py::class_<ir::value>(m, "value");
value.def_property("name", &ir::value::get_name, &ir::value::set_name);
value.def_property_readonly("type", &ir::value::get_type);
py::class_<ir::user, ir::value>(m, "user");
py::class_<ir::constant, ir::user>(m, "constant");
py::class_<ir::undef_value, ir::constant>(m, "undef")
.def("get", &ir::undef_value::get, ret::reference);
py::class_<ir::constant_int, ir::constant>(m, "constant_int")
.def_property_readonly("value", &ir::constant_int::get_value)
.def("__int__", [](ir::constant_int *self) { return self->get_value(); });
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
.def_property_readonly("value", &ir::constant_fp::get_value);
py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("make_ptr", &ir::pointer_type::get, ret::reference)
.def("make_function", &ir::function_type::get, ret::reference)
.def("make_block", &ir::block_type::get, ret::reference)
.def("get_void", &ir::type::get_void_ty, ret::reference)
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
.def("get_int1", &ir::type::get_int1_ty, ret::reference)
.def("get_int8", &ir::type::get_int8_ty, ret::reference)
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("is_void", &ir::type::is_void_ty)
.def("is_fp16", &ir::type::is_half_ty)
.def("is_fp32", &ir::type::is_float_ty)
.def("is_fp64", &ir::type::is_double_ty)
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
.def_property_readonly("context", &ir::type::get_context, ret::reference);
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
py::class_<ir::function_type, ir::type>(m, "function_type");
py::class_<ir::integer_type, ir::type>(m, "integer_type");
py::class_<ir::block_type, ir::type>(m, "block_type")
.def_property_readonly("shape", &ir::block_type::get_shapes)
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
py::class_<ir::scope>(m, "scope")
.def(py::init<>())
.def_property_readonly("values", &ir::scope::get_values)
.def("set_type", &ir::scope::set_type);
py::class_<ir::module>(m, "module")
.def(py::init<std::string, ir::builder &>())
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("add_new_scope", &ir::module::add_new_scope, ret::reference)
.def("seal_block", &ir::module::seal_block)
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
.def("pop_scope", &ir::module::pop_scope)
.def_property_readonly("scope", &ir::module::get_scope, ret::reference)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t;
py::enum_<eattr>(m, "attribute_kind")
.value("readonly", eattr::readonly)
.value("writeonly", eattr::writeonly)
.value("noalias", eattr::noalias)
.value("aligned", eattr::aligned)
.value("multiple_of", eattr::multiple_of)
.value("retune", eattr::retune)
.value("not_implemented", eattr::not_implemented);
py::class_<ir::attribute>(m, "attribute")
.def(py::init<eattr, int>());
py::class_<ir::function>(m, "function")
.def_property_readonly("args", &ir::function::args)
.def_property_readonly("attrs", &ir::function::attrs)
.def("add_attr", &ir::function::add_attr);
py::class_<ir::argument, ir::value>(m, "argument");
py::class_<ir::basic_block, ir::value>(m, "basic_block")
.def("create", &ir::basic_block::create, ret::reference)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
.def(py::init<ir::context &>())
// getters
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
// control flow
.def("br", &ir::builder::create_br, ret::reference)
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
// constants
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", &ir::builder::get_int32, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference);
}
void init_triton(py::module &m) {
py::module subm = m.def_submodule("triton");
init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_driver(std::move(subm.def_submodule("driver")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_tools(std::move(subm.def_submodule("tools")));
init_triton_ir(std::move(subm.def_submodule("ir")));
init_triton_frontend(std::move(subm.def_submodule("frontend")));
}

View File

@@ -0,0 +1,209 @@
import torch
import triton
import copy
import pytest
import ast
torch.manual_seed(0)
# convert from string to torch.dtype
# Necessary because doesn't print torch.dtype properly
cvt = {
'bool': torch.bool,
'int8': torch.int8,
'int16': torch.int16,
'int32': torch.int32,
'int64': torch.int64,
'float16': torch.float16,
'float32': torch.float32,
'float64': torch.float64,
}
int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
def patch_kernel(template, to_replace):
kernel = copy.deepcopy(template)
for key, value in to_replace.items():
kernel.src = kernel.src.replace(key, value)
return kernel
# generic test functions
def _test_unary(dtype_x, expr, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
# reference result
z_ref = eval(expr)
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_allclose(z_ref, z_tri)
def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
y = triton.load(Y + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
# reference result
z_ref = eval(expr)
# triton result
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_allclose(z_ref, z_tri)
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['+', '-', '*', '/', '%'] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['&', '|', '^'] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
if 'float' in dtype_x + dtype_y:
with pytest.raises(RuntimeError):
_test_binary(dtype_x, dtype_y, expr, device=device)
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test compare ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['==', '!=', '>', '<', '>=', '<='] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, f' -x') for dtype_x in float_dtypes
] + [\
(dtype_x, f' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
# ----------------
# test indexing
# ----------------
def make_ptr_str(name, shape):
rank = len(shape)
offsets = []
stride = 1
for i in reversed(range(rank)):
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
stride *= shape[i]
return f"{name} + {' + '.join(offsets)}"
@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
['None, :', ':, None',\
'None, :, :', ':, :, None']\
])
def test_index1d(expr, device='cuda'):
dtype = torch.int32
rank_x = expr.count(':')
rank_y = expr.count(',') + 1
shape_x = [32 for _ in range(rank_x)]
shape_z = [32 for _ in range(rank_y)]
# Triton kernel
@triton.jit
def kernel(Z, X, **meta):
SIZE = meta['SIZE']
m = triton.arange(0, SIZE)
n = triton.arange(0, SIZE)
x = triton.load(X_PTR_EXPR)
z = GENERATE_TEST_HERE
triton.store(Z_PTR_EXPR, z)
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
'GENERATE_TEST_HERE': expr,
}
kernel = patch_kernel(kernel, to_replace)
# torch result
x = triton.testing.random(shape_x, dtype=dtype, device=device)
y = torch.zeros(shape_z, dtype=dtype, device=device)
z_ref = eval(expr) + y
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
# compare
triton.testing.assert_allclose(z_ref, z_tri)
# ---------------
# test load
# ---------------
# ---------------
# test store
# ---------------
# ---------------
# test if
# ---------------
# ---------------
# test for
# ---------------
# ---------------
# test while
# ---------------

View File

@@ -1,17 +0,0 @@
import torch
import triton
def test_op():
torch.manual_seed(0)
DTYPE = torch.float16
N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad, stride, = (1, 1), (1, 1)
dilation = (1, 1)
a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
tt_c = triton.ops.conv(a, b, pad, stride)
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)

View File

@@ -3,9 +3,11 @@ import itertools
import triton
import torch
@pytest.mark.parametrize(
"TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(*[
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
@@ -17,14 +19,14 @@ import torch
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# # 2 warp
# 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# # 4 warp
# 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -32,37 +34,43 @@ import torch
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
# (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
# (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
# (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# split-k
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# # split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
# # variable input
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
]),
]
),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
torch.manual_seed(0)
defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)]
if M is None:
M = TM
if N is None:
N = TN
if K is None:
K = TK * SPLITK
# nuke kernel decorators -- will set meta-parameters manually
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
configs = [triton.Config(meta=META, num_warps=NWARP)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []
triton.autotune(configs, [])(kernel)
kernel.kernel_decorators += decorators[1:]
# get matrix shape
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.ops.matmul(a, b)
assert triton.testing.allclose(th_c, tt_c)

View File

@@ -2,9 +2,10 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from . import testing
from .kernel import *
from . import ops
from .code_gen import jit, autotune, heuristics, Config, Autotuner
from .core import *
from . import testing
from . import ops
# version
__version__ = '1.0.0'

648
python/triton/code_gen.py Normal file
View File

@@ -0,0 +1,648 @@
import inspect
import struct
import enum
import types
import torch
import ast
import builtins
import triton._C.libtriton.triton as _triton
import triton
import sys
import textwrap
from abc import ABC, abstractmethod
class CodeGenerator(ast.NodeVisitor):
def get_value(self, name):
# search node.id in local scope
ret = None
if name in self.lscope:
ret = self.lscope[name]
# search node.id in global scope
elif name in self.gscope:
ret = self.gscope[name]
# search node.id in builtins
elif name in self.builtins:
ret = self.builtins[name]
else:
raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.block):
handle = self.module.get_value(name)
return triton.block(handle)
return ret
def set_value(self, name, value):
if isinstance(value, _triton.ir.value):
value = triton.block(value)
if isinstance(value, triton.block):
self.module.set_value(name, value.handle)
self.module.scope.set_type(name, value.handle.type)
self.lscope[name] = value
def is_triton_object(self, value):
return isinstance(value, triton.block)
def visit_compound_statement(self, stmts, add_scope=False):
if add_scope:
self.module.add_new_scope()
for stmt in stmts:
self.last_ret = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
if add_scope:
self.module.pop_scope()
return self.last_ret
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
self.builder = _triton.ir.builder(context)
self.module = _triton.ir.module('', self.builder)
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.kwargs = kwargs
self.last_node = None
self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}
def visit_Module(self, node):
self.module.add_new_scope()
ast.NodeVisitor.generic_visit(self, node)
self.module.pop_scope()
def visit_List(self, node):
ctx = self.visit(node.ctx)
assert ctx is None
elts = [self.visit(elt) for elt in node.elts]
return elts
# By design, only non-kernel functions can return
def visit_Return(self, node):
return self.visit(node.value)
def visit_FunctionDef(self, node, inline=False, arg_values=None):
arg_names, kwarg_names = self.visit(node.args)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function
if inline:
pass
else:
fn = self.module.get_or_insert_function(node.name, self.prototype)
arg_values = []
for i, arg_name in enumerate(arg_names):
if i in self.constants:
arg_values.append(self.constants[i])
else:
if i in self.attributes:
is_ptr = fn.args[i].type.is_ptr()
attr = 'aligned' if is_ptr else 'multiple_of'
attr = getattr(_triton.ir.attribute_kind, attr)
attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(i + 1, attr)
fn.args[i].name = arg_name
arg_values.append(fn.args[i])
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
return self.visit_compound_statement(node.body, add_scope=True)
else:
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.module.seal_block(entry)
self.builder.set_insert_block(entry)
# visit function body
self.visit_compound_statement(node.body, add_scope=True)
# finalize function
self.builder.ret_void()
def visit_arguments(self, node):
arg_names = []
for arg in node.args:
arg_names += [self.visit(arg)]
kwarg_names = self.visit(node.kwarg)
return arg_names, kwarg_names
def visit_arg(self, node):
ast.NodeVisitor.generic_visit(self, node)
return node.arg
def visit_Assign(self, node):
names = []
for target in node.targets:
names += [self.visit(target)]
assert len(names) == 1
name = names[0]
value = self.visit(node.value)
self.set_value(names[0], value)
def visit_AugAssign(self, node):
name = node.target.id
lhs = ast.Name(id=name, ctx=ast.Load())
rhs = ast.BinOp(lhs, node.op, node.value)
assign = ast.Assign(targets=[node.target], value=rhs)
self.visit(assign)
return self.get_value(name)
def visit_Name(self, node):
if type(node.ctx) == ast.Store:
return node.id
return self.get_value(node.id)
def visit_Store(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_Load(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_Tuple(self, node):
args = [self.visit(x) for x in node.elts]
return tuple(args)
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
fn = {
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}[type(node.op)]
kws = dict()
if self.is_triton_object(lhs):
kws['builder'] = self.builder
ret = getattr(lhs, fn)(rhs, **kws)
if ret is NotImplemented:
if self.is_triton_object(rhs):
kws['builder'] = self.builder
fn = fn[:2] + 'r' + fn[2:]
ret = getattr(rhs, fn)(lhs, **kws)
return ret
def visit_If(self, node):
cond = self.visit(node.test)
if self.is_triton_object(cond):
current_bb = self.builder.get_insert_block()
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
self.module.seal_block(then_bb)
if else_bb:
self.module.seal_block(else_bb)
self.builder.cond_br(cond.handle, then_bb, else_bb)
else:
self.builder.cond_br(cond.handle, then_bb, endif_bb)
self.builder.set_insert_block(then_bb)
self.visit_compound_statement(node.body, add_scope=True)
# TODO: last statement is a terminator?
self.builder.br(endif_bb)
if else_bb:
self.builder.set_insert_block(else_bb)
self.visit_compound_statement(node.orelse, add_scope=True)
#TODO: last statement is a terminator?
self.builder.br(endif_bb)
self.module.seal_block(endif_bb)
self.builder.set_insert_block(endif_bb)
else:
if cond:
self.visit_compound_statement(node.body)
else:
self.visit_compound_statement(node.orelse)
def visit_IfExp(self, node):
cond = self.visit(node.test)
if cond:
return self.visit(node.body)
else:
return self.visit(node.orelse)
def visit_Pass(self, node):
pass
def visit_Compare(self, node):
assert len(node.comparators) == 1
assert len(node.ops) == 1
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
fn = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
ast.Lt: '__lt__',
ast.LtE: '__le__',
ast.Gt: '__gt__',
ast.GtE: '__ge__',
ast.Is: '__eq__',
ast.IsNot: '__ne__',
}[type(node.ops[0])]
if self.is_triton_object(lhs) or self.is_triton_object(rhs):
return getattr(lhs, fn)(rhs, builder=self.builder)
return getattr(lhs, fn)(rhs)
def visit_UnaryOp(self, node):
op = self.visit(node.operand)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
ast.Invert: '__invert__',
}[type(node.op)]
if self.is_triton_object(op):
return getattr(op, fn)(builder=self.builder)
return getattr(op, fn)()
def visit_While(self, node):
current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
def continue_fn():
cond = self.visit(node.test)
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
continue_fn()
self.builder.set_insert_block(loop_bb)
self.visit_compound_statement(node.body, add_scope=True)
continue_fn()
stop_bb = self.builder.get_insert_block()
self.module.seal_block(stop_bb)
self.module.seal_block(loop_bb)
self.module.seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Str(self, node):
return ast.literal_eval(node)
def visit_Subscript(self, node):
assert node.ctx.__class__.__name__ == "Load"
lhs = self.visit(node.value)
slices = self.visit(node.slice)
if self.is_triton_object(lhs):
return lhs.__getitem__(slices, builder=self.builder)
return lhs[slices]
def visit_ExtSlice(self, node):
return [self.visit(dim) for dim in node.dims]
def visit_For(self, node):
iterator = self.visit(node.iter.func)
assert iterator == self.builtins['range']
# create nodes
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
init_node = ast.Assign(targets=[st_target], value=node.iter.args[0])
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
builder=self.builder)
#cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2])
# code generation
current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
def continue_fn():
self.visit(step_node)
cond = build_cond()
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
self.visit(init_node)
cond = build_cond()
self.builder.cond_br(cond.handle, loop_bb, next_bb)
self.builder.set_insert_block(loop_bb)
self.visit_compound_statement(node.body, add_scope=True)
# TODO: handle case where body breaks control flow
continue_fn()
stop_bb = self.builder.get_insert_block()
self.module.seal_block(stop_bb)
self.module.seal_block(loop_bb)
self.module.seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):
lower = self.visit(node.lower)
upper = self.visit(node.upper)
step = self.visit(node.step)
return slice(lower, upper, step)
def visit_Index(self, node):
return self.visit(node.value)
def visit_NameConstant(self, node):
return node.value
def visit_keyword(self, node):
return {node.arg: self.visit(node.value)}
def visit_Call(self, node):
fn = self.visit(node.func)
kws = dict()
for keyword in node.keywords:
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.core:
return fn(*args, builder=self.builder, **kws)
return fn(*args, **kws)
def visit_Num(self, node):
return node.n
def visit_Attribute(self, node):
lhs = self.visit(node.value)
return getattr(lhs, node.attr)
def visit_Expr(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_NoneType(self, node):
return None
def visit(self, node):
if node is not None:
self.last_node = node
return super().visit(node)
def generic_visit(self, node):
typename = type(node).__name__
raise NotImplementedError("Unsupported node: {}".format(typename))
class Binary:
def __init__(self, module, kernel, num_warps, shared_mem):
self.module = module
self.kernel = kernel
self.shared_mem = shared_mem
self.num_warps = num_warps
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem)
class CompilationError(Exception):
def __init__(self, src, node, err):
self.message = '\n'.join(src.split('\n')[:node.lineno])
self.message += '\n' + ' ' * node.col_offset + '^'
self.message += '\n Error: ' + str(err)
super().__init__(self.message)
class Kernel:
type_names = {
int: 'I',
float: 'f',
bool: 'B',
torch.float16: 'f16',
torch.float32: 'f32',
torch.float64: 'f64',
torch.bool: 'i1',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
}
@staticmethod
def _to_triton_ir(context, obj):
type_map = {
'I': _triton.ir.type.get_int32,
'f': _triton.ir.type.get_fp32,
'B': _triton.ir.type.get_int1,
'f16': _triton.ir.type.get_fp16,
'f32': _triton.ir.type.get_fp32,
'f64': _triton.ir.type.get_fp64,
'i1': _triton.ir.type.get_int1,
'i8': _triton.ir.type.get_int8,
'i16': _triton.ir.type.get_int16,
'i32': _triton.ir.type.get_int32,
'i64': _triton.ir.type.get_int64,
}
# convert torch.Tensor to Triton IR pointers
if isinstance(obj, torch.Tensor):
name = Kernel.type_names[obj.dtype]
elt_ty = type_map[name](context)
return _triton.ir.type.make_ptr(elt_ty, 1)
# default path returns triton.ir.type directly
name = Kernel.type_names[obj.__class__]
return type_map[name](context)
@staticmethod
def _types_key(*wargs, tensor_idxs):
# type inference
types_key = [None] * len(wargs)
for i, arg in enumerate(wargs):
prefix = 'P' if i in tensor_idxs else ''
suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
types_key[i] = prefix + suffix
return tuple(types_key)
@staticmethod
def pow2_divisor(N):
if N % 16 == 0: return 16
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
return 1
def __init__(self, fn):
self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
# explicitly set device
torch.cuda.set_device(device.index)
# create IR module
context = _triton.ir.context()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
ret_type = _triton.ir.type.get_void(context)
prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR
# export symbols visible from self.fn into code-generator object
gscope = sys.modules[self.fn.module].__dict__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
try:
generator.visit(self.fn.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.fn.src, node, e)
tt_device = _triton.driver.cu_device(device.index, False)
# Compile to machine code
mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
return Binary(mod, ker, num_warps, shared_mem)
def __call__(self, *wargs, grid, num_warps=4, **meta):
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
if len(tensor_idxs) == 0:
raise ValueError("No Tensor argument found.")
device = wargs[tensor_idxs[0]].device
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
# determine if we need to re-compile
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = frozenset(attributes.items())
meta_key = frozenset(meta.items())
const_key = frozenset(constants.items())
key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key)
cache = self.fn.cache
if key not in cache:
# compile and cache configuration if necessary
cache[key] = self._compile(
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
params = struct.pack(fmt, *args)
# enqueue cached function into stream
binary = cache[key]
cu_stream = torch.cuda.current_stream(device.index).cuda_stream
stream = _triton.driver.cu_stream(cu_stream, False)
grid = grid(meta) if hasattr(grid, '__call__') else grid
binary(stream, params, *grid)
class Launcher:
def __init__(self, kernel, grid):
self.kernel = kernel
self.grid = grid
def __call__(self, *wargs, **kwargs):
self.kernel(*wargs, **kwargs, grid=self.grid)
class Autotuner:
def __init__(self, kernel, arg_names, configs, key):
if not configs:
self.configs = [Config(dict(), num_warps=4)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = dict()
self.kernel = kernel
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.meta.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.meta)
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **meta):
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
timings = {config: self._bench(*args, config=config, **meta) \
for config in self.configs}
self.cache[key] = builtins.min(timings, key=timings.get)
config = self.cache[key]
else:
config = self.configs[0]
self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
class JITFunction:
def __init__(self, fn):
self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args
self.cache = dict()
self.kernel_decorators = []
self.src = textwrap.dedent(inspect.getsource(fn))
self.kernel = None
# we do not parse in the constructor because
# the user might want to monkey-patch self.src dynamically.
# Some unit tests do this, for example.
def parse(self):
tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, generator: CodeGenerator, **meta):
try:
return generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node, e)
def _init_kernel(self):
if self.kernel is None:
self.kernel = Kernel(self)
for decorator in reversed(self.kernel_decorators):
self.kernel = decorator(self.kernel)
return self.kernel
def __getitem__(self, grid):
return Launcher(self._init_kernel(), grid)
class Config:
def __init__(self, meta, num_warps=4):
self.meta = meta
self.num_warps = num_warps
def autotune(configs, key):
def decorator(fn):
def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key)
fn.kernel_decorators.append(wrapper)
return fn
return decorator
def heuristics(values):
def decorator(fn):
def wrapper(kernel):
def fun(*args, **meta):
for v, heur in values.items():
assert v not in meta
meta[v] = heur(*args, **meta)
return kernel(*args, **meta)
return fun
fn.kernel_decorators.append(wrapper)
return fn
return decorator
def jit(fn):
return JITFunction(fn)

499
python/triton/core.py Normal file
View File

@@ -0,0 +1,499 @@
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
import triton
from functools import wraps
def _patch(fn):
# convert block/dtype to ir values
def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
return builder.get_int32(x)
elif isinstance(x, float):
return builder.get_float32(x)
if isinstance(x, block):
return x.handle
if isinstance(x, dtype):
return x.handle(builder)
return x
def _from_ir(x):
if isinstance(x, ir.value):
if x.type.is_void():
return None
return block(x)
return x
def wrapper(*args, **kwargs):
builder = args[-1]
assert isinstance(builder, ir.builder)
args = [_to_ir(x, builder) for x in args]
kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
ret = fn(*args, **kwargs)
if isinstance(ret, tuple):
return map(_from_ir, ret)
return _from_ir(ret)
return wrapper
for name in dir(frontend):
fn = getattr(frontend, name)
if callable(fn):
setattr(frontend, name, _patch(fn))
def builtin(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if 'builder' not in kwargs or \
kwargs['builder'] is None:
raise ValueError("Builder argument must be provided outside of JIT functions")
return fn(*args, **kwargs)
if wrapper.__doc__:
wrapper.__doc__ += """\
:param builder: IR builder to generate code into, optional from within @triton.jit functions
:type builder: triton.ir.builder
"""
return wrapper
class dtype:
def __init__(self, init):
self.init = init
def handle(self, builder):
ctx = builder.context
return self.init(ctx)
class pointer_dtype:
def __init__(self, element_ty):
self.element_ty = element_ty
def handle(self, builder):
return ir.type.make_ptr(self.element_ty, 1)
int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
float16 = dtype(ir.type.get_fp16)
float32 = dtype(ir.type.get_fp32)
float64 = dtype(ir.type.get_fp64)
class block:
@staticmethod
def _init_dtype(ir_type):
# primitive type
if ir_type.is_int1(): return int1
if ir_type.is_int8(): return int8
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_fp16(): return float16
if ir_type.is_fp32(): return float32
if ir_type.is_fp64(): return float64
# pointer type
if ir_type.is_ptr():
element_ty = block._init_dtype(ir_type.element)
return pointer_dtype(element_ty)
raise ValueError(f"Unsupported type {ir_type}")
def __init__(self, handle):
# IR handle
self.handle = handle
# Block shape
self.shape = (1, )
if self.handle.type.is_block():
self.shape = self.handle.type.shape
# Data-type wrapper
self.dtype = block._init_dtype(self.handle.type.scalar)
@builtin
def __add__(self, other, builder=None):
return frontend.add(self, other, builder)
def __radd__(self, other, builder=None):
return self.__add__(other, builder=builder)
@builtin
def __sub__(self, other, builder=None):
return frontend.sub(self, other, builder)
@builtin
def __mul__(self, other, builder=None):
return frontend.mul(self, other, builder)
def __rmul__(self, other, builder=None):
return self.__mul__(other, builder=builder)
@builtin
def __truediv__(self, other, builder=None):
return frontend.truediv(self, other, builder)
def __rtruediv__(self, other, builder=None):
return frontend.truediv(other, self, builder)
@builtin
def __floordiv__(self, other, builder=None):
return frontend.floordiv(self, other, builder)
@builtin
def __mod__(self, other, builder=None):
return frontend.mod(self, other, builder)
# unary operators
@builtin
def __neg__(self, builder=None):
return frontend.minus(self, builder)
@builtin
def __invert__(self, builder=None):
return frontend.invert(self, builder)
# bitwise operators
@builtin
def __and__(self, other, builder=None):
return frontend.and_(self, other, builder)
@builtin
def __or__(self, other, builder=None):
return frontend.or_(self, other, builder)
@builtin
def __xor__(self, other, builder=None):
return frontend.xor_(self, other, builder)
@builtin
def __lshift__(self, other, builder=None):
return frontend.shl(self, other, builder)
@builtin
def __rshift__(self, other, builder=None):
return frontend.lshr(self, other, builder)
# comparison operators
@builtin
def __gt__(self, other, builder=None):
return frontend.greater_than(self, other, builder)
@builtin
def __ge__(self, other, builder=None):
return frontend.greater_equal(self, other, builder)
@builtin
def __lt__(self, other, builder=None):
return frontend.less_than(self, other, builder)
@builtin
def __le__(self, other, builder=None):
return frontend.less_equal(self, other, builder)
@builtin
def __eq__(self, other, builder=None):
return frontend.equal(self, other, builder)
@builtin
def __ne__(self, other, builder=None):
return frontend.not_equal(self, other, builder)
@builtin
def __getitem__(self, slices, builder=None):
if isinstance(slices, slice):
slices = [slices]
src_shape = self.shape
dst_shape = []
curr = 0
for sl in slices:
if sl == None:
dst_shape.append(1)
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr])
curr += 1
ret = frontend.reshape(self, dst_shape, builder)
return ret
@builtin
def to(self, dtype, builder=None):
return frontend.cast(self, dtype.handle(builder), builder)
# -----------------------
# SPMD Programming Model
# -----------------------
@builtin
def program_id(axis, builder=None):
"""
Returns the id of the current program instance along the given `axis`.
Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
"""
return frontend.program_id(axis, builder)
@builtin
def num_programs(axis, builder=None):
"""
Returns the number of program instances launched along the given `axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
"""
return frontend.num_programs(axis, builder)
# -----------------------
# Block Initialization
# -----------------------
@builtin
def arange(start, end, builder=None):
"""
Returns contiguous values within the open interval [start, end).
:param start: Start of the interval.
:type start: int
:param stop: End of the interval.
:type stop: int
"""
return frontend.arange(start, end, builder)
@builtin
def zeros(shape, dtype, builder=None):
"""
Returns a block filled with the scalar value 0 and the given shape.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., triton.float16
:type dtype: triton.ir.dtype
"""
return frontend.zeros(shape, dtype, builder)
# -----------------------
# Shape Manipulation
# -----------------------
@builtin
def broadcast(input, other, builder=None):
"""
Tries to broadcast two blocks to a common compatible shape.
:param input: The first input block.
:type input: triton.ir.value
:param other: The second input block.
:type other: triton.ir.value
"""
return frontend.broadcast(input, other, builder)
@builtin
def broadcast_to(input, shape, builder=None):
"""
Tries to broadcast a block to a new shape.
:param input: The input block.
:type input: triton.value
:param shape: The new shape.
:type shape: tuple of int
"""
return frontend.broadcast_to(input, shape, builder)
@builtin
def reshape(input, shape, builder=None):
"""
Reshapes a block to a new shape.
"""
return frontend.reshape(input, shape, builder)
# -----------------------
# Linear Algebra
# -----------------------
@builtin
def dot(input, other, builder=None):
"""
Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {`float16`, `float32`}
"""
return frontend.dot(input, other, builder)
# -----------------------
# Memory Operations
# -----------------------
@builtin
def load(pointer, mask=None, other=None, builder=None):
"""
Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.
:param pointer: Pointer to the data to be loaded.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]`
:type other: Block of triton.value, optional
"""
return frontend.load(pointer, mask, other, builder)
@builtin
def store(pointer, value, mask=None, builder=None):
"""
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
:param pointer: The memory locations where the elements of `value` are stored.
:type pointer: Block of triton.pointer
:param value: The block of elements to be stored.
:type value: Block of triton.value
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
:type mask: Block of triton.bool, optional
"""
return frontend.store(pointer, value, mask, builder)
@builtin
def atomic_cas(ptr, cmp, val, builder=None):
return frontend.atomic_cas(ptr, cmp, val, builder)
@builtin
def atomic_xchg(ptr, val, builder=None):
return frontend.atomic_xchg(ptr, val, builder)
# -----------------------
# Conditioning
# -----------------------
@builtin
def where(condition, x, y, builder=None):
"""
Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
The shape of `x` and `y` are both broadcast to the shape of `condition`.
`x` and `y` must have the data type.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
:param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False.
"""
return frontend.where(condition, x, y, builder)
# -----------------------
# Math
# -----------------------
@builtin
def exp(x, builder=None):
return frontend.exp(x, builder)
@builtin
def log(x, builder=None):
return frontend.log(x, builder)
# -----------------------
# Reductions
# -----------------------
@builtin
def max(input, axis, builder=None):
return frontend.max(input, axis, builder)
@builtin
def min(input, axis, builder=None):
return frontend.min(input, axis, builder)
@builtin
def sum(input, axis, builder=None):
return frontend.sum(input, axis, builder)
# -----------------------
# Internal for debugging
# -----------------------
@builtin
def debug_barrier(builder=None):
return frontend.debug_barrier(builder)
@builtin
def multiple_of(x, value, builder=None):
return frontend.multiple_of(x, value, builder)
# -----------------------
# Standard library
# -----------------------
@triton.jit
def minimum(x, y):
return triton.where(x < y, x, y)
@triton.jit
def maximum(x, y):
return triton.where(x > y, x, y)
@triton.jit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@triton.jit
def ravel(x):
return triton.reshape(x, [x.type.numel])
@triton.jit
def softmax(x):
z = x - triton.max(x, 0)
num = triton.exp(z)
den = triton.sum(num, 0)
return num / den
def cdiv(x, y):
return (x + y - 1) // y

View File

@@ -1,119 +0,0 @@
import os
import struct
from typing import Optional, Dict, List, Callable
import torch
import triton._C.libtriton.triton as _triton
codes = {
_triton.runtime.arg_type.int1: 'B',
_triton.runtime.arg_type.int8: 'B',
_triton.runtime.arg_type.int32: 'I',
_triton.runtime.arg_type.int64: 'Q',
_triton.runtime.arg_type.half: 'H',
_triton.runtime.arg_type.float: 'f',
_triton.runtime.arg_type.double: 'd',
_triton.runtime.arg_type.buffer: 'P',
}
def th_to_triton(obj):
""" Convert a `torch.dtype` to a Triton-C type string. """
tys = {
torch.int8: 'char',
torch.int16: 'short',
torch.int32: 'int',
torch.int64: 'long',
torch.float16: 'half',
torch.float32: 'float',
torch.float64: 'double',
}
if isinstance(obj, torch.dtype):
return tys[obj]
return str(obj)
def cdiv(a: int, b: int) -> int:
""" Ceil division (a + b - 1) // b"""
return (a + b - 1) // b
def read(path: str, kernel_names: Optional[List] = None) -> str:
""" Extracts the source code for `kernel_names` from the given `path` file."""
if kernel_names is None:
kernel_names = []
with open(path, 'r') as f:
source = f.read()
source = _triton.tools.extract_kernels(source, kernel_names)
return source
config = _triton.runtime.config
class kernel:
"""
A class used to represent a Triton kernel.
"""
def __init__(
self,
src: str,
device: torch.device,
defines: Optional[Dict] = None,
num_warps: int = 4,
autotune_configs: Optional[List] = None,
autotune_key: Optional[List] = None
):
"""
:param src: The source code of the kernel.
:param device: The device to compile the kernel for.
:param defines: A dictionary of preprocessor #define for the compiler.
:param num_warps: Optimization flag for the compiler's internal auto-parallelization engine.
:param autotune_configs: A list of triton.config objects for the autotuner to try.
:param autotune_key: A list of kernel argument names whose change in value should trigger the autotuner to re-run.
"""
if defines is None:
defines = {}
if autotune_configs is None:
autotune_configs = []
if autotune_key is None:
autotune_key = []
# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')
self.src = src
# device
assert device.type in ['cuda', 'cpu']
if device.type == 'cuda':
self.device_id = torch.cuda.current_device() if device.index is None else device.index
self.device = _triton.driver.cu_device(self.device_id, False)
cu_stream = torch.cuda.current_stream(self.device_id).cuda_stream
self.stream = _triton.driver.cu_stream(cu_stream, False)
if device.type == 'cpu':
self.device_id = -1
self.device = _triton.driver.host_device()
self.device = _triton.driver.host_stream()
torch.cuda.set_device(self.device_id)
# function
self.opt = _triton.runtime.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# autotune_configs = [({}, 4)]
self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_configs, autotune_key)
self.tys = ''.join([codes[x] for x in self.fn.signature()])
def __call__(self, *args, grid: Callable[[_triton.runtime.options], tuple]):
"""
Runs the kernel on the given arguments and launch grid.
:param args: The arguments to the kernel in the orders that they appear in the Triton-C source.
:param grid: The launch grid for the kernel, i.e., callable that transform compilation options into a tuple of at most 3 integers.
:return: None
"""
# make sure that the executing thread is on the right device
torch.cuda.set_device(self.device_id)
# pack parameters into a byte buffer
params = struct.pack(self.tys, *args)
kernel = self.fn.autotune(params, grid, self.stream)
# run kernel
grid = grid(kernel.opt)
kernel(params, self.stream, grid)

View File

@@ -1,4 +1,4 @@
from .conv import _conv, conv
#from .conv import _conv, conv
from .matmul import _matmul, matmul
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse

View File

@@ -1,199 +0,0 @@
__global__ void NAME(TYPE *A __readonly __noalias,
TYPE *B __readonly __noalias,
TYPE *C __noalias,
int lda,
int ldb,
int ldc,
long stride_za,
long stride_zb,
long stride_zc,
long stride_ha,
long stride_hb,
long stride_hc,
int DS0, int DS1,
int SDD_K,
int SDD_off_width,
int *lut, int *locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
#ifdef SDD
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
int offlutn[TN] = blockidn * 4;
int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
#ifdef DSD
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
#endif
#ifdef DDS
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
#endif
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
#endif
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK;
TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn [newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for (int k = AS1; k > 0; k -= step) {
acc += a @b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
#else
pinc += 2;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *? (checka)pa;
b = *? (checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
int rr_offlutn[TN] = rr_blockidn * 4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn [newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN;
// write-back directly
if (lockid == 0) {
*? (checkc)pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}

View File

@@ -4,7 +4,183 @@ import torch
import os
import math
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
@triton.jit
def _kernel(
A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc,
stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta
):
TM = meta['TM']
TN = meta['TN']
TK = meta['TK']
TZ = meta['TZ']
BLOCK = meta['BLOCK']
#------------#
#- Prologue -#
#------------#
pid0 = triton.program_id(0)
pid1 = triton.program_id(1)
pidz = triton.program_id(2)
if meta['SDD']:
pid1 = pid1 + SDD_off_width
blockidm = triton.arange(0, TM) // BLOCK
blockidn = triton.arange(0, TN) // BLOCK
offlutm = blockidm * (TN // BLOCK) * 4
offlutn = blockidn * 4
header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
z = triton.load(header + 0)
i = triton.load(header + 1 + offlutm)
j = triton.load(header + 2 + offlutn)
AS1 = SDD_K // TZ
lockid = triton.where(TZ > 1, 1, 0)
offka = pid0 * AS1
offkb = pid0 * AS1
offmc = 0
offnc = 0
offpa = 0
offpb = 0
maxid = TZ
offhc = 0
offha = z
offhb = z
ram = i * BLOCK + (triton.arange(0, TM) % BLOCK)
rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK)
else:
header = lut + pid0 * 6
offset = triton.load(header + 0)
AS1 = triton.load(header + 1)
column = triton.load(header + 2)
depth = triton.load(header + 3)
lockid = triton.load(header + 4)
maxid = triton.load(header + 5)
pinc = lut + offset
offhc = depth
if meta['DSD']:
# output offset
offnc = pid1 * TN
offmc = column * TM
offpc = 0
# dense input offset
offnb = pid1 * TN
offkb = triton.load(pinc)
offkb = triton.multiple_of(offkb, 8) # compiler hint
offpb = 0
# sparse input offset
offma = 0
offka = 0
offpa = triton.load(pinc + 1)
offpa = triton.multiple_of(offpa, 8) # compiler hint
offpa = offpa * BLOCK * BLOCK
offha = 0
offhb = depth
else:
# output offset
offmc = pid1 * TM
offnc = column * TN
offpc = 0
# dense input offset
offma = pid1 * TM
offka = triton.load(pinc)
offka = triton.multiple_of(offka, 8) # compiler hint
offpa = 0
# sparse input offset
offnb = 0
offkb = 0
offpb = triton.load(pinc + 1)
offpb = triton.multiple_of(offpb, 8) # compiler hint
offpb = offpb * BLOCK * BLOCK
offha = depth
offhb = 0
ram = offma + triton.arange(0, TM)
rbn = offnb + triton.arange(0, TN)
# initialize a, b pointers
rka = offka + triton.arange(0, TK)
rkb = offkb + triton.arange(0, TK)
pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
if meta['DDS']:
checkam = ram[:, None] < DS0
else:
checkam = AS1 > 0
if meta['DSD']:
checkbn = rbn[None, :] < DS0
else:
checkbn = AS1 > 0
a = triton.load(pa, mask=checkam, other=0.)
b = triton.load(pb, mask=checkbn, other=0.)
## ---------------- ##
## Inner Loop ##
## ---------------- ##
acc = triton.zeros((TM, TN), dtype=triton.float32)
for k in range(AS1, 0, -TK):
acc += triton.dot(a, b)
if meta['SDD']:
inc_a = TK * stride_ka
inc_b = TK * stride_kb
else:
pinc += 2
if meta['DSD']:
inc_b = triton.load(pinc)
inc_a = triton.load(pinc + 1)
inc_b = triton.multiple_of(inc_b, 8)
inc_a = triton.multiple_of(inc_a, 8)
inc_b = inc_b * stride_kb
if meta['DDS']:
inc_a = triton.load(pinc)
inc_b = triton.load(pinc + 1)
inc_a = triton.multiple_of(inc_a, 8)
inc_b = triton.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
pa += inc_a
pb += inc_b
# pre-fetch
checkak = k > TK
checkbk = k > TK
checka = checkam & checkak
checkb = checkbn & checkbk
a = triton.load(pa, mask=checka)
b = triton.load(pb, mask=checkb)
c = acc.to(C.dtype.element_ty)
if meta['SDD']:
checkc = True
rr_blockidm = triton.arange(0, TM) // BLOCK
rr_blockidn = triton.arange(0, TN) // BLOCK
rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
rr_offlutn = rr_blockidn * 4
off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
bkid = triton.load(header + off_bkid)
offpc = bkid * BLOCK * BLOCK
rcm = triton.arange(0, TM) % BLOCK
rcn = triton.arange(0, TN) % BLOCK
else:
rcm = offmc + triton.arange(0, TM)
rcn = offnc + triton.arange(0, TN)
if meta['DSD']:
checkc = rcn[None, :] < DS0
if meta['DDS']:
checkc = rcm[:, None] < DS0
pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
# write-back directly
if lockid == 0:
triton.store(pc, c, mask=checkc)
# accumulate partial results using spin-locks
else:
plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1
pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks
while triton.atomic_cas(plock, 0, 1) == 1:
pass
count = triton.load(pcount)
if count == 0:
triton.store(pc, c, mask=checkc)
else:
d = triton.load(pc, mask=checkc)
triton.store(pc, d + c, mask=checkc)
triton.atomic_xchg(pcount, (count + 1) % maxid)
triton.atomic_xchg(plock, 0)
##############
@@ -118,31 +294,11 @@ class _matmul(torch.autograd.Function):
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
defines = {
'TM': block * pack,
'TN': block * pack,
'TMN': block * block * pack * pack,
'BLOCK': block,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc',
'STRIDE_CN': '1',
'SDD': True,
'TZ': 1,
'NAME': 'sdd_kernel'
}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
kernel = _matmul.sdd_cache[key]
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \
'SDD': True, 'DSD': False, 'DDS': False}
# create output
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
# maximum grid size is 65535
@@ -150,27 +306,32 @@ class _matmul(torch.autograd.Function):
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.stride(2),
b.stride(2),
block,
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(0),
c.stride(2),
c.stride(3),
AS2,
AS2,
AS3,
off_width,
lut.data_ptr(),
locks.data_ptr(),
lut,
locks,
num_lock,
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]
num_warps=4,
**meta
)
# save for backward pass
return c
@@ -282,25 +443,8 @@ class _matmul(torch.autograd.Function):
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
defines = {
'TM': 128,
'TN': block,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True
}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
'SDD': False, 'DSD': False, 'DDS': True}
# output
CS0 = AS0
CS1 = AS1
@@ -308,27 +452,32 @@ class _matmul(torch.autograd.Function):
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.stride(2),
block,
c.stride(2),
grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(1),
c.stride(3 if trans_c else 2),
c.stride(2 if trans_c else 3),
AS2,
BS2,
0,
0,
lut.data_ptr(),
locks.data_ptr(),
lut,
locks,
num_locks,
grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]
num_warps=4,
**meta
)
return c
@@ -344,25 +493,8 @@ class _matmul(torch.autograd.Function):
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
defines = {
'TM': block,
'TN': 128,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True
}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
'SDD': False, 'DSD': True, 'DDS': False}
# output
CS0 = BS0
CS1 = BS1
@@ -370,27 +502,32 @@ class _matmul(torch.autograd.Function):
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
block,
b.stride(2),
c.stride(2),
grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(1),
c.stride(2),
c.stride(3),
BS3,
AS1,
0,
0,
lut.data_ptr(),
locks.data_ptr(),
lut,
locks,
num_locks,
grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]
num_warps=4,
**meta
)
return c

View File

@@ -1,135 +0,0 @@
__global__ void forward(TYPE *X __readonly __noalias,
float scale,
int *LUT __readonly __noalias,
TYPE *RPE __readonly __noalias,
TYPE *KP_M __readonly __noalias,
TYPE *ATTN_M __readonly __noalias,
int sizemax,
long stride_zx,
long stride_zrpe,
int stride_hrpe,
int stride_srpe,
int stride_zkpm,
int stride_zattnm) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int *header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// block id and column id
long blockid[TN] = *(LUT + offset + rbmn * 4 + 0);
long columnid[TN] = *(LUT + offset + rbmn * 4 + 1);
long rowid[TN] = *(LUT + offset + rbmn * 4 + 2);
long headid[TN] = *(LUT + offset + rbmn * 4 + 3);
// pointers to X
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
#ifdef APPLY_RPE
// pointers to relative position embedding
TYPE *prpe[TN] = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn;
#endif
#ifdef APPLY_KP_MASK
// pointers to key padding mask
TYPE *pkp_m[TN] = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn;
#endif
#ifdef APPLY_ATTN_MASK
// pointers to attention mask
TYPE *pattn_m[TN] = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn;
#endif
// load input
TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
// load relative position embedding
TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
// load key-padding mask
TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
// load attention mask
TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif
// compute softmax in float
#ifdef APPLY_RPE
float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
Fkp_m = (Fkp_m == 0) ? (float[TN]) - INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
Fattn_m = (Fattn_m == 0) ? (float[TN]) - INFINITY : 0;
#endif
float Fx[TN] = x;
#ifdef APPLY_SCALE
Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
Fx = Fx + Fattn_m; // apply attention mask
#endif
float Fxmax = Fx[max];
float Fy[TN] = exp(Fx - Fxmax);
float Fysum = (check ? Fy : 0)[+];
// write-back in half/float
TYPE y[TN] = Fy;
TYPE ysum = Fysum;
*? (check)px = y / ysum;
}
__global__ void backward(TYPE *X __readonly __noalias,
float scale,
TYPE *DX __readonly __noalias,
int *LUT,
int sizemax,
long stride_zx,
long stride_zdx) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int *header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
// bounds checking on lut
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// initialize pointers to block-sparse input
long blockid[TN] = *(LUT + offset + rbmn * 4);
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
TYPE *pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
// compute fused softmax backward
TYPE x[TN] = check ? *px : 0;
TYPE dx[TN] = check ? *pdx : 0;
float Fdx[TN] = dx;
float Fx[TN] = x;
float Fxdx[TN] = Fdx * Fx;
float Fxdxsum = Fxdx[+];
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
TYPE y[TN] = Fy;
// write-back
*? (check)pdx = y;
}

View File

@@ -2,15 +2,8 @@ import triton
import torch
import os
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
fwd_kernels = dict()
bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward'])
bwd_kernels = dict()
class _softmax(torch.autograd.Function):
@staticmethod
def next_power_of_2(n):
def next_power_of_2(n):
n -= 1
n |= n >> 1
n |= n >> 2
@@ -20,6 +13,107 @@ class _softmax(torch.autograd.Function):
n += 1
return n
def num_warps(n):
if n < 512:
return 4
if n < 2048:
return 8
return 16
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
@triton.jit
def _forward(
X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
**meta
):
TN = meta['TN']
BLOCK = meta['BLOCK']
pidhm = triton.program_id(0)
pidz = triton.program_id(1)
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK
# extract information from LUT
header = LUT + rbm * 2
size = triton.load(header + 0)
offset = triton.load(header + 1)
check = rbn < size
rbmn = triton.where(check, rbn, size - 1)
# block id and column id
blockid = triton.load(LUT + offset + rbmn * 4 + 0)
columnid = triton.load(LUT + offset + rbmn * 4 + 1)
rowid = triton.load(LUT + offset + rbmn * 4 + 2)
headid = triton.load(LUT + offset + rbmn * 4 + 3)
# pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = triton.load(px, mask=check, other=-float('inf'))
x = x.to(triton.float32)
# apply scale
if meta['APPLY_SCALE']:
x = x * scale
# apply RPE
if meta['APPLY_RPE']:
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
rpe = triton.load(prpe, mask=check, other=0)
x = x + rpe
# apply key-padding mask
if meta['APPLY_KP_MASK']:
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
kp_m = triton.load(pkp_m, mask=check, other=-float('inf'))
if meta['KP_MASK_MUL']:
kp_m = triton.where(kp_m == 0, -float('inf'), 0.)
x = x + kp_m
# apply attention mask
if meta['APPLY_ATTN_MASK']:
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
attn_m = triton.load(pattn_m, mask=check, other=-float('inf'))
if meta['ATTN_MASK_MUL']:
attn_m = triton.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m
# computation
x = triton.softmax(x)
triton.store(px, x, mask=check)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
pidhm = triton.program_id(0)
pidz = triton.program_id(1)
TN = meta['TN']
BLOCK = meta['BLOCK']
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = triton.arange(0, TN) % BLOCK
rbn = triton.arange(0, TN) // BLOCK
# extract information from look-up table
header = LUT + rbm * 2
size = triton.load(header + 0)
offset = triton.load(header + 1)
# bounds checking on lut
check = rbn < size
rbmn = triton.where(check, rbn, size - 1)
# initialize pointers to block-sparse input
blockid = triton.load(LUT + offset + rbmn * 4)
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
# compute fused softmax backward
x = triton.load(X, mask=check, other=0)
dx = triton.load(DX, mask=check, other=0)
x = x.to(triton.float32)
dx = dx.to(triton.float32)
y = x * (dx - triton.sum(x * dx, 0)) * scale
triton.store(DX, y, mask=check)
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
@@ -43,40 +137,9 @@ class _softmax(torch.autograd.Function):
return lut, int(sizes.max())
@staticmethod
def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
kp_mask_mode, attn_mask_mode):
if max_k >= 32768:
raise NotImplementedError('Reductions larger than 32768 elements '\
'are not yet implemented')
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
TN = _softmax.next_power_of_2(max_k)
# just-in-time compile kernel
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
kp_mask_mode, attn_mask_mode)
if key not in cache:
defines = {
'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
{torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
}
if apply_scale:
defines['APPLY_SCALE'] = True
if apply_rpe:
defines['APPLY_RPE'] = True
if apply_kp_mask:
defines['APPLY_KP_MASK'] = True
if kp_mask_mode == 'mul':
defines['KP_MASK_MUL'] = True
if apply_attn_mask:
defines['APPLY_ATTN_MASK'] = True
if attn_mask_mode == 'mul':
defines['ATTN_MASK_MUL'] = True
kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
cache[key] = kernel
return cache[key]
@staticmethod
def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
maxlut, bench, time):
def forward(
ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, maxlut, bench, time
):
apply_scale = False if scale == 1.0 else True
# handle None rpe
@@ -107,26 +170,20 @@ class _softmax(torch.autograd.Function):
stride_zattnm = attn_mask.stride(0)
# run kernel
kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
M = x.shape[0]
meta = {
'BLOCK': block,
'APPLY_SCALE': apply_scale,
'APPLY_RPE': apply_rpe,
'APPLY_KP_MASK': apply_kp_mask,
'APPLY_ATTN_MASK': apply_attn_mask,
'KP_MASK_MUL': kp_mask_mode == 'mul',
'ATTN_MASK_MUL': attn_mask_mode == 'mul',
}
grid = lambda opt: [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
# run kernel
kernel(x.data_ptr(),
scale,
lut.data_ptr(),
rpe.data_ptr(),
key_padding_mask.data_ptr(),
attn_mask.data_ptr(),
maxlut,
x.stride(0),
stride_zrpe,
stride_hrpe,
stride_srpe,
stride_zkpm,
stride_zattnm,
grid=grid)
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
@@ -147,14 +204,12 @@ class _softmax(torch.autograd.Function):
# retrieve from context
x, lut = ctx.saved_tensors
# run kernel
kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
ctx.kp_mask_mode, ctx.attn_mask_mode)
M = x.shape[0]
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class softmax:
apply_softmax = _softmax.apply
@@ -172,14 +227,9 @@ class softmax:
self.bench = bench
self.lut_cache = dict()
def __call__(self,
x,
scale=1.,
rpe=None,
key_padding_mask=None,
attn_mask=None,
key_padding_mask_mode='add',
attn_mask_mode='add'):
def __call__(
self, x, scale=1., rpe=None, key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add'
):
time_y = [None]
if rpe is not None and rpe.dtype != x.dtype:
raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -188,6 +238,8 @@ class softmax:
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
raise ValueError('Key padding mask must be %s' % x.dtype)
lut, maxlut = self.make_lut(x.device)
x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
self.spdims, self.block, lut, maxlut, self.bench, time_y)
x = softmax.apply_softmax(
x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, self.spdims, self.block, lut,
maxlut, self.bench, time_y
)
return x

View File

@@ -1,123 +0,0 @@
__global__ void conv(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
// equivalent matmul
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z, int lda_ci, int lda_h, int lda_w,
int ldb_ci, int ldb_r, int ldb_s, int ldb_co,
int ldc_z, int ldc_co, int ldc_p, int ldc_q)
{
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs[TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr[TK] = rcir % RR;
int rci[TK] = rcir / RR;
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :];
// pointers to lhs
int offa[TM, TK] = rz[:, newaxis] * lda_z + rci [newaxis, :] * lda_ci +
rh * lda_h + rw * 1;
TYPE *pa[TM, TK] = A + offa;
int *padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci + rr[:, newaxis] * ldb_r +
rs[:, newaxis] * ldb_s + rn [newaxis, :] * 1;
TYPE *pb[TK, TN] = B + offb;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;
// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK)
{
acc += a @b;
// increment A
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta [newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr [newaxis, :];
rw = rw_0[:, newaxis] + rs [newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = *? (checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz[:, newaxis] * ldc_z + rn [newaxis, :] * ldc_co +
rp[:, newaxis] * ldc_p + rq[:, newaxis] * 1;
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;
#if (TZ == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -1,81 +0,0 @@
import torch
import triton
import os
class _conv(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
kernel = dict()
@staticmethod
def unpack(IDX, CI, R, S):
s = IDX % S
cr = IDX // S
r = cr % R
ci = cr // R
return ci, r, s
@staticmethod
def forward(ctx, a, b, pad, stride):
# create kernel if necessary
dtype = a.dtype
device = a.device
# shapes
Z, CI, H, W = a.shape
_, R, S, CO = b.shape
P = (H + 2 * pad[0] - R) // stride[0] + 1
Q = (W + 2 * pad[1] - S) // stride[1] + 1
# compile kernel
if (dtype, device) not in _conv.kernel:
TK = 16
defines = {
'TYPE': dtype,
'TM': 64,
'TN': 64,
'TK': TK,
'TZ': 1,
'HH': H,
'WW': W,
'PP': P,
'QQ': Q,
'SS': S,
'RR': R,
}
idx = torch.arange(CI * R * S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
# enqueue
kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
1.,
Z * P * Q,
CO,
CI * R * S,
pad[0],
pad[1],
stride[0],
stride[1],
delta.data_ptr(),
a.stride(0),
a.stride(1),
a.stride(2),
a.stride(3),
b.stride(0),
b.stride(1),
b.stride(2),
b.stride(3),
c.stride(0),
c.stride(1),
c.stride(2),
c.stride(3),
grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)])
return c
conv = _conv.apply

View File

@@ -1,35 +0,0 @@
__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) {
int row = get_program_id(0);
bool check[TILE] = ((0 ... TILE) < n_cols);
int offset[TILE] = row * n_cols + 0 ... TILE;
TYPE *px[TILE] = logit + offset;
TYPE *pmodified[TILE] = modified_logit + offset;
long local_ind = *(indices + row);
TYPE F16[TILE] = check ? *px : -INFINITY;
float shifted_logit[TILE] = F16 - F16[max];
float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit;
*? (check)pmodified = neg_logprob;
__debug_barrier();
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
}
__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) {
int row = get_program_id(0);
// pointer arithmetic
bool check[TILE] = ((0 ... TILE) < n_cols);
int offset[TILE] = row * n_cols + 0 ... TILE;
TYPE *px[TILE] = neg_logprobs + offset;
long local_ind = *(indices + row);
TYPE local_dn = *(dneg_logprobs + row);
// We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
// and we have -log(p[k]) stored, so this is easy
TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * px) : 0;
// selected_logit_idx is selected logit index for our token
bool find_one[TILE] = ((0 ... TILE) == local_ind);
intermediate = intermediate - ((TYPE[TILE])find_one);
// multiply by dneg_logprobs
*? (check)px = intermediate * local_dn;
}

View File

@@ -2,6 +2,7 @@ import os
import triton
import torch
def next_power_of_2(n):
n -= 1
n |= n >> 1
@@ -12,34 +13,61 @@ def next_power_of_2(n):
n += 1
return n
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
return 1
def make_kernel(device, dtype, n_cols, cache, name):
rounded = next_power_of_2(n_cols)
div = largest_pow2_divisor(n_cols)
key = (dtype, rounded, div)
if key not in cache:
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
src = triton.read(fname, kernel_names=[name])
infinities = {
torch.float16: "F16_INFINITY",
torch.float32: "F32_INFINITY",
}
defines = {"TILE": rounded, "TYPE": dtype, "INFINITY": infinities[dtype], "N_COLS_MULT": div}
cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
return cache[key]
def num_warps(N):
if N < 2048:
return 4
elif N < 8192:
return 8
return 16
# forward kernel
fwd_kernels = dict()
make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")
# backward kernel
bwd_kernels = dict()
make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
BLOCK = meta['BLOCK']
row = triton.program_id(0)
cols = triton.arange(0, BLOCK)
idx = triton.load(IDX + row)
# pointers to logit and probs
LOGITS = LOGITS + row * N + cols
WRIT_PROBS = PROBS + row * N + cols
READ_PROBS = PROBS + row * N + idx
# write-back negative log-probs
logits = triton.load(LOGITS, mask=cols < N, other=-float('inf'))
logits = logits.to(triton.float32)
logits = logits - triton.max(logits, 0)
probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits
triton.store(WRIT_PROBS, probs, mask=cols < N)
# There is a bug in the compiler, which fails to insert a barrier here.
# We add it explicitly for now. Will be fixed soon.
triton.debug_barrier()
# write-back loss
probs = triton.load(READ_PROBS)
triton.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, **meta):
BLOCK = meta['BLOCK']
row = triton.program_id(0)
cols = triton.arange(0, BLOCK)
idx = triton.load(IDX + row)
# pointers to probs
PROBS = PROBS + row * N + cols
# We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
# and we have -log(p[k]) stored in PROBS, so this is easy
probs = -triton.load(PROBS, mask=cols < N, other=float('inf'))
probs = triton.exp(probs.to(triton.float32))
delta = cols == idx
# write result in-place in PROBS
dout = triton.load(DPROBS + row)
din = (probs - delta) * dout
triton.store(PROBS, din.to(triton.float16), mask=cols < N)
class _cross_entropy(torch.autograd.Function):
@classmethod
@@ -49,16 +77,11 @@ class _cross_entropy(torch.autograd.Function):
# make kernel
device, dtype = logits.device, logits.dtype
n_cols = logits.shape[-1]
kernel = make_fwd_kernel(device, dtype, n_cols)
# run the kernel
result = torch.empty_like(indices, dtype=dtype, device=device)
neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
kernel(logits.data_ptr(),
neg_logprobs.data_ptr(),
indices.data_ptr(),
result.data_ptr(),
n_cols,
grid=lambda opt: (logits.numel() // n_cols, ))
grid = lambda opt: (logits.numel() // n_cols, )
_forward[grid](logits, neg_logprobs, indices, result, n_cols)
# save for backward
ctx.save_for_backward(neg_logprobs, indices)
return result
@@ -75,14 +98,11 @@ class _cross_entropy(torch.autograd.Function):
# make kernel
device, dtype = neg_logprobs.device, neg_logprobs.dtype
n_cols = neg_logprobs.shape[-1]
kernel = make_bwd_kernel(device, dtype, n_cols)
# run the kernel
# neg_logprobs will be modified in place to become our gradient:
kernel(neg_logprobs.data_ptr(),
indices.data_ptr(),
dneg_logprobs.data_ptr(),
n_cols,
grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
grid = lambda opt: (neg_logprobs.numel() // n_cols, )
_backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
return neg_logprobs, None
cross_entropy = _cross_entropy.apply

View File

@@ -1,94 +0,0 @@
#define STM 1
#define STN 1
__global__ void matmul(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
int M, int N, int K,
int lda, int ldb, int ldc,
int *locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN;
// swizzle for better L2 performance
int width = STM * gridn;
int stm = pid / width;
int RSTM = min(gridm - stm * STM, STM);
int stn = (pid % width) / (RSTM * STN);
int RSTN = min(gridn - stn * STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm * STM + lanem;
int pidn = stn * STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// split-k for better parrallelism
K = K / SPLITK;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN;
TYPE *pa[TM, TK] = A + offa;
TYPE *pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk [newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK) {
#if (IS_TK_DIV_K == 1)
bool checkk[TK] = k > TK;
#else
bool checkk[TK] = rk < k - TK;
#endif
bool checka[TM, TK] = checkk [newaxis, :];
bool checkb[TK, TN] = checkk[:, newaxis];
acc += a @b;
#if (IS_TK_DIV_K == 1)
a = *? (checka)pa;
b = *? (checkb)pb;
#else
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
#endif
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
#if (SPLITK == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % SPLITK);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -1,108 +1,117 @@
import torch
import triton
import os
@triton.heuristics({
'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),\
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),
],
key=['M', 'N', 'K']
)
@triton.jit
def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, LOCKS, **META):
# extract meta-parameters
BLOCK_M = META['BLOCK_M']
BLOCK_N = META['BLOCK_N']
BLOCK_K = META['BLOCK_K']
GROUP_M = META['GROUP_M']
SPLIT_K = META['SPLIT_K']
# matrix multiplication
pid = triton.program_id(0)
pid_z = triton.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
# pointers
K = K // SPLIT_K
A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
for k in range(K, 0, -BLOCK_K):
if META['EVEN_K']:
a = triton.load(A)
b = triton.load(B)
else:
a = triton.load(A, mask=rk[None, :] < k, other=0.)
b = triton.load(B, mask=rk[:, None] < k, other=0.)
acc += triton.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(triton.float16)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
triton.store(C, acc, mask=mask)
else:
LOCKS = LOCKS + triton.program_id(0)
COUNT = LOCKS + triton.num_programs(0)
while triton.atomic_cas(LOCKS, 0, 1) == 1:
pass
count = triton.load(COUNT)
if count == 0:
triton.store(C, acc, mask=mask)
else:
curr = triton.load(C, mask=mask, other=0.)
triton.store(C, acc + curr, mask=mask)
triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
triton.atomic_xchg(LOCKS, 0)
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
_DEFAULT_CONFIGS = [
triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
]
_CONFIGS = _DEFAULT_CONFIGS
@staticmethod
def largest_pow2_divisor(N):
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
kernel = _kernel
_locks = dict()
_kernels = dict()
@staticmethod
def _call(a, b):
dtype = a.dtype
device = a.device
# allocate output
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), dtype=dtype, device=device)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# kernel hash
is_a_row = a.stride(1) == 1
is_b_row = b.stride(1) == 1
lda = a.stride(0) if is_a_row else a.stride(1)
ldb = b.stride(0) if is_b_row else b.stride(1)
ldc = c.stride(0)
lda_pow2_div = _matmul.largest_pow2_divisor(lda)
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 64 == 0
key = (
device,
dtype,
is_a_row,
is_b_row,
lda_pow2_div,
ldb_pow2_div,
ldc_pow2_div,
is_tk_div_k,
)
if key not in _matmul._kernels:
defines = {
"TYPE": dtype,
"STRIDE_AM": "lda" if is_a_row else "1",
"STRIDE_AK": "1" if is_a_row else "lda",
"STRIDE_BK": "ldb" if is_b_row else "1",
"STRIDE_BN": "1" if is_b_row else "ldb",
"LDA_POW2_DIV": lda_pow2_div,
"LDB_POW2_DIV": ldb_pow2_div,
"LDC_POW2_DIV": ldc_pow2_div,
"IS_TK_DIV_K": int(is_tk_div_k),
}
_matmul._kernels[key] = triton.kernel(
_matmul.src,
device,
defines=defines,
autotune_configs=_matmul._CONFIGS,
autotune_key=["M", "N", "K"],
)
kernel = _matmul._kernels[key]
# # locks for split-k
if device not in _matmul._locks:
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# allocate locks for split-k
if a.device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# enqueue
alpha = 1.0
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.SPLITK]
kernel(*args, grid=grid)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), locks)
# done
return c
@staticmethod
def forward(ctx, a, b):
c = _matmul._call(a, b)
return c
return _matmul._call(a, b)
matmul = _matmul.apply

View File

@@ -47,13 +47,37 @@ def mask_tensor(x, mask, block, value=0):
def allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
if x.dtype != y.dtype:
raise RuntimeError(f'{x.dtype} did not match with {x.dtype}')
if x.shape != y.shape:
raise RuntimeError(f'{x.shape} did not match with {y.shape}')
if x.dtype == torch.bool:
return torch.sum(x ^ y) == 0
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
tol = 0
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
tol = 1e-2
err = torch.max(diff) / torch.max(x_max, y_max)
return err < tol
return err <= tol
def assert_allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
assert allclose(x, y, tol)
def random(shape, dtype, device):
if isinstance(shape, int):
shape = (shape, )
if dtype == torch.bool:
return torch.randint(0, 2, shape, dtype=dtype, device=device)
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
return torch.randint(1, 32, shape, dtype=dtype, device=device)
if dtype in [torch.float16, torch.float32, torch.float64]:
return torch.randn(shape, dtype=dtype, device=device)
raise RuntimeError(f'Unknown dtype {dtype}')
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):

View File

@@ -1,139 +1,71 @@
import torch
import triton
"""
Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:
- The basic syntax of the Triton programming language
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
- The basic programming model used by Triton
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
- The best practices for validating and benchmarking custom ops against native reference implementations
"""
# %%
# Compute Kernel
# --------------------------
#
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
# on different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)
# programming model for more details).
#
# .. code-block:: C
#
# __global__ void add(float* z, float* x, float* y, int N){
# // The `get_program_id(i)` returns the i-th coordinate
# // of the program in the overaching SPMD context
# // (a.k.a launch grid). This is what allows us to process
# // different chunks of data in parallel.
# // For those similar with CUDA, `get_program_id({0,1,2})`
# // is similar to blockIdx.{x,y,z}
# int pid = get_program_id(0);
# // In Triton, arrays are first-class citizen. In other words,
# // they are primitives data-types and are -- contrary to C and
# // CUDA -- not implemented as pointers to contiguous chunks of
# // memory.
# // In the few lines below, we create an array of `BLOCK` pointers
# // whose memory values are, e.g.:
# // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
# // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
# int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
# float* pz [BLOCK] = z + offset;
# float* px [BLOCK] = x + offset;
# float* py [BLOCK] = y + offset;
# // Simple element-wise control-flow for load/store operations can
# // be achieved using the the ternary operator `cond ? val_true : val_false`
# // or the conditional dereferencing operator `*?(cond)ptr
# // Here, we make sure that we do not access memory out-of-bounds when we
# // write-back `z`
# bool check[BLOCK] = offset < N;
# *?(check)pz = *?(check)px + *?(check)py;
# }
#
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
@triton.jit
def _add(
X, # *Pointer* to first input vector
Y, # *Pointer* to second input vector
Z, # *Pointer* to output vector
N, # Size of the vector
**meta # Optional meta-parameters for the kernel
):
pid = triton.program_id(0)
# Create an offset for the blocks of pointers to be
# processed by this program instance
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
# Create a mask to guard memory operations against
# out-of-bounds accesses
mask = offsets < N
# Load x
x = triton.load(X + offsets, mask=mask)
y = triton.load(Y + offsets, mask=mask)
# Write back x + y
z = x + y
triton.store(Z + offsets, z)
# %%
# Torch Bindings
# --------------------------
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
#
# - :code:`source: string`: the source-code of the kernel you want to create
# - :code:`device: torch.device`: the device you want to compile this code for
# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
import torch
import triton
# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
// program id
int pid = get_program_id(0);
// create arrays of pointers
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
float* pz[BLOCK] = z + offset;
float* px[BLOCK] = x + offset;
float* py[BLOCK] = y + offset;
// bounds checking
bool check[BLOCK] = offset < N;
// write-back
*?(check)pz = *?(check)px + *?(check)py;
}
"""
# We can also declara a helper function that handles allocating the output vector
# and enqueueing the kernel.
# This function returns a callable `triton.kernel` object created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`
# We compile the kernel with -DBLOCK=1024
def make_add_kernel(device):
cache = make_add_kernel.cache
if device not in cache:
defines = {'BLOCK': 1024}
cache[device] = triton.kernel(_src, device=device, defines=defines)
return cache[device]
make_add_kernel.cache = dict()
# This is a standard torch custom autograd Function;
# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`
class _add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
# constraints of the op
assert x.dtype == torch.float32
# *allocate output*
def add(x, y):
z = torch.empty_like(x)
# *create launch grid*:
# this is a function which takes compilation parameters `opt`
# as input and returns a tuple of int (i.e., launch grid) for the kernel.
# triton.cdiv is a shortcut for ceil division:
# triton.cdiv(a, b) = (a + b - 1) // b
N = z.shape[0]
grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
# *launch kernel*:
# pointer to the data of torch tensors can be retrieved with
# the `.data_ptr()` method
kernel = make_add_kernel(z.device)
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
# The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
# NOTE:
# - torch.tensor objects are implicitly converted to pointers to their first element.
# - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
_add[grid](x, y, z, N, BLOCK=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously.
return z
# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function
add = _add.apply
# %%
# We can now use the above function to compute the sum of two `torch.tensor` objects:
# %%
# Unit Test
# -----------
#
# Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:
# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:
torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
za = x + y
zb = add(x, y)
print(za)

View File

@@ -4,8 +4,7 @@ Fused Softmax
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- The syntax and usage of reduction operators in Triton.
- The automatic vectorization capabilities of the Triton compiler.
- The reduction operators in Triton.
"""
# %%
@@ -36,79 +35,45 @@ def naive_softmax(x):
# %%
# When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
# In this case, we would be reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
# %%
# Compute Kernel
# ----------------
# Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
#
# .. code-block:: C
#
# __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
# // row index
# int m = get_program_id(0);
# // column indices
# int n [BLOCK] = 0 ... BLOCK;
# // the memory address of all the elements
# // that we want to load can be computed as follows
# float* px [BLOCK] = X + m*stride_xm + n;
# // because BLOCK has to be a power of two
# // (per Triton-C specs), it is important
# // to guard each memory operation with predicates
# // or we will read out of bounds
# bool check[BLOCK] = n < N;
# float x [BLOCK] = check ? *px : -F32_INFINITY;
# // syntax for reduction in Triton is:
# // x[:, :, OPERATOR, :, :]
# // ^
# // index
# // where operator is in {min, max, +}
# // for 1D vectors, this is just x[OPERATOR].
# float z [BLOCK] = x - x[max];
# // Note that exponentials in Triton are fast
# // but approximate (i.e., think __expf in CUDA)
# float num [BLOCK] = exp(z);
# float denom = num[+];
# // The result of the reduction is now stored in y
# float y [BLOCK] = num / denom;
# // We write it back
# float* py [BLOCK] = Y + m*stride_ym + n;
# *?(check)py = y;
# }
# %%
# Torch Bindings
# ---------------
# Here our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial.
# We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
# This means that different values of BLOCK will result in different kernels
import torch
import triton
# Source code for the Triton kernel
_src = """
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
int m = get_program_id(0);
int n [BLOCK] = 0 ... BLOCK;
float* px [BLOCK] = X + m*stride_xm + n;
bool check[BLOCK] = n < N;
float x [BLOCK] = check ? *px : -F32_INFINITY;
float z [BLOCK] = x - x[max];
float num [BLOCK] = exp(z);
float denom = num[+];
float y [BLOCK] = num / denom;
float* py [BLOCK] = Y + m*stride_ym + n;
*?(check)py = y;
}
"""
@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index
m = triton.program_id(0)
# col indices
n = triton.arange(0, meta['BLOCK'])
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n
x = triton.load(X, mask=n < N, other=-float('inf'))
# Substract maximum for numerical stability
z = x - triton.max(x, axis=0)
# Note that exponentials in Triton are fast
# but approximate (i.e., think __expf in CUDA)
num = triton.exp(z)
denom = triton.sum(num, axis=0)
y = num / denom
# Write back to Y
Y = Y + m * stride_ym + n
triton.store(Y, y, mask=n < N)
# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
# helper function to get the smaller power-of-two larger than a given number
def next_power_of_2(n):
n -= 1
n |= n >> 1
@@ -120,11 +85,9 @@ def next_power_of_2(n):
return n
# kernel caching mechanism
def make_kernel(N, device):
cache = make_kernel.cache
# Now are kernels are indexed not only by the provided device but also
# by the rounded number of columns in the input matrix
def softmax(x):
M, N = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK = next_power_of_2(N)
# Another trick we can use is to ask the compiler to parallelize each
# row-normalization more aggressively -- i.e., with more warps -- vectors
@@ -134,37 +97,13 @@ def make_kernel(N, device):
num_warps = 4
if BLOCK >= 2048: num_warps = 8
if BLOCK >= 4096: num_warps = 16
# Each (BLOCK, num_warps, device) results in a different kernel
key = (BLOCK, num_warps, device)
if key not in cache:
defines = {'BLOCK': BLOCK}
cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
return cache[key]
make_kernel.cache = dict()
class _softmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# constraints of the op
assert x.dtype == torch.float32
# Allocate output
y = torch.empty_like(x)
# The launch grid is simple: we have one kernel instance per row of the input matrix
M, N = y.shape
grid = lambda opt: (M, )
# Launch kernel
kernel = make_kernel(N, y.device)
kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
return y
softmax = _softmax.apply
# %%
# We can use the above softmax function to compute the row-wise softmax of a given matrix.
# %%
# Unit Test
# ----------

View File

@@ -1,10 +1,10 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
You will specifically learn about:
- The block-level matrix multiplication operator `@`
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
@@ -15,7 +15,7 @@ You will specifically learn about:
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
@@ -23,322 +23,212 @@ You will specifically learn about:
# .. code-block:: python
#
# # do in parallel
# for m in range(0, M, MB):
# for m in range(0, M, BLOCK_M):
# # do in parallel
# for n in range(0, N, NB):
# acc = zeros((MB, NB), dtype=float32)
# for k in range(0, K, KB):
# acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
# C[m : m+MB, n : n+NB] = acc;
# for n in range(0, N, BLOCK_N):
# acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
# for k in range(0, K, BLOCK_K):
# a = A[m : m+BLOCK_M, k : k+BLOCK_K]
# b = B[k : k+BLOCK_K, n : n+BLOCK_N]
# acc += dot(a, b)
# C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
#
# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
# %%
# Compute Kernel
# ----------------
#
# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
# The above algorithm is actually fairly straightforward to implement in Triton.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
#
# Pointer Arithmetics
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
# &A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
# &B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
#
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
#
# .. code-block:: C
# .. code-block:: python
# :force:
#
# int rm[MB] = program_id_m * MB + 0 ... MB;
# int rn[NB] = program_id_n * NB + 0 ... NB;
# int rk[KB] = 0 ... KB;
# TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
# pid_m = triton.program_id(0)
# pid_n = triton.program_id(1)
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
# rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
# rk = triton.arange(0, BLOCK_K)
# // pointer for A operand
# pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
# // pointer for B operand
# pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
#
# These pointers can then be updated in the inner loop as:
#
# .. code-block:: C
# .. code-block:: python
#
# pa += KB * 1;
# pb += KB * ldb;
# pa += BLOCK_K * stride_a_1;
# pb += BLOCK_K * stride_b_0;
#
#
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
# However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
# This means that a naive row-major ordering:
#
# .. code-block:: C
# .. code-block:: Python
#
# int program_id = get_program_id(0);
# int grid_m = (M + MB - 1) / MB;
# int grid_n = (N + NB - 1) / NB;
# int program_id_m = program_id / grid_n;
# int program_id_n = program_id % grid_n;
# pid = triton.program_id(0);
# grid_m = (M + BLOCK_M - 1) / BLOCK_M;
# grid_n = (N + BLOCK_N - 1) / BLOCK_N;
# pid_m = pid / grid_n;
# pid_n = pid % grid_n;
#
# is unlikely to result in optimal performance.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
#
# .. code-block:: C
#
# int program_id = get_program_id(0);
# int width = GROUP_SIZE * grid_n;
# int group_id = pid / width;
# // we need to handle the case where M % (GROUP_SIZE*BM) != 0
# int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
# int pid_m = group_id * GROUP_SIZE + (pid % group_size);
# int pid_n = (pid % width) / (group_size);
# pid = triton.program_id(0);
# width = GROUP_M * grid_n;
# group_id = pid / width;
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
# pid_m = group_id * GROUP_M + (pid % group_size);
# pid_n = (pid % width) / (group_size);
#
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
#
# Final Result
# ~~~~~~~~~~~~~~
#
# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
# Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure.
# This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.
#
# .. code-block:: C
# :force:
#
# #define MAX_GROUP_SIZE 8
#
# __global__ void dot(TYPE* A, TYPE* B, TYPE* C,
# int M, int N, int K,
# int stride_a_0, int stride_b_0, int stride_c_0) {
# // prologue
# int pid = get_program_id(0);
# int grid_m = (M + MB - 1) / MB;
# int grid_n = (N + NB - 1) / NB;
# // re-order program ID for better L2 performance
# int width = MAX_GROUP_SIZE * grid_n;
# int group_id = pid / width;
# int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
# int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
# int pid_n = (pid % width) / (group_size);
# // pointers to operands
# // note the parentheses here; they force the offset
# // computation to happen in typeof(stride_a_0) = int32 rather than
# // typeof(A) = int64
# int rm[MB] = pid_m * MB + 0 ... MB;
# int rn[NB] = pid_n * NB + 0 ... NB;
# int rk[KB] = 0 ... KB;
# TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
# // reduction loop
# float acc[MB, NB] = 0;
# for (int k = K; k > 0; k -= KB) {
# acc += (*pa) @ (*pb);
# pa += KB * 1;
# pb += KB * stride_b_0;
# }
# // pointers to output
# // here we rematerialize `rm` and `rn` so that they are not live through
# // the above reduction loop. In the future, the compiler should be able to
# // do this automatically.
# rm = pid_m * MB + 0 ... MB;
# rn = pid_n * NB + 0 ... NB;
# TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
# // we write back using *?() operator. `acc` gets casted to `float32` implicitly.
# *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
# }
#
# Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
#
# %%
# Torch Bindings
# ----------------
# Final Result
# -------------
#
# Auto-Tuning
# ~~~~~~~~~~~~~~
#
# In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows:
import torch
import triton
autotune_configs = [
triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
]
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
@triton.jit
def sigmoid(x):
ret_true = 1 / (1 + triton.exp(-x))
ret_false = triton.exp(x) / (1 + triton.exp(x))
return triton.where(x >= 0, ret_true, ret_false)
@triton.jit
def swish(x):
return x * sigmoid(x)
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
],
key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
# extract meta-parameters
BLOCK_M = META['BLOCK_M']
BLOCK_N = META['BLOCK_N']
BLOCK_K = META['BLOCK_K']
GROUP_M = 8
# matrix multiplication
pid = triton.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
for k in range(K, 0, -BLOCK_K):
a = triton.load(A)
b = triton.load(B)
acc += triton.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# triton can accept arbitrary activation function
# via metaparameters!
if META['ACTIVATION']:
acc = META['ACTIVATION'](acc)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm[:, None] < M) & (rn[None, :] < N)
triton.store(C, acc, mask=mask)
# %%
# we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
# Here, we want to re-tune our kernel only when the shape of input matrices changes.
autotune_key = ["M", "N", "K"]
# %%
# We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
src = """
#define MAX_GROUP_SIZE 8
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
int M, int N, int K,
int lda, int ldb, int ldc) {
int pid = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
int width = MAX_GROUP_SIZE * grid_n;
int group_id = pid / width;
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
int rm[MB] = pid_m * MB + 0 ... MB;
int rn[NB] = pid_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
float acc[MB, NB] = 0;
for (int k = K; k > 0; k -= KB) {
acc += (*pa) @ (*pb);
pa += KB * 1;
pb += KB * ldb;
}
rm = pid_m * MB + 0 ... MB;
rn = pid_n * NB + 0 ... NB;
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
"""
# We can also create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
def make_kernel(device, dtype):
key = (device, dtype)
cache = make_kernel.cache
if key not in cache:
defines = {'TYPE': dtype}
cache[key] = triton.kernel(
src,
device=device,
defines=defines,
autotune_configs=autotune_configs,
autotune_key=autotune_key,
)
return cache[key]
make_kernel.cache = dict()
# %%
# Autograd Function
# ~~~~~~~~~~~~~~~~~~
#
# Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
# To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output.
class _dot(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
M, Ka = a.shape
Kb, N = b.shape
assert Ka == Kb, "incompatible dimensions"
assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
def matmul(a, b, activation=None):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
assert b.is_contiguous(), "matrix B must be contiguous"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
kernel = make_kernel(a.device, a.dtype)
grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
M, N, Ka, \
a.stride(0), b.stride(0), c.stride(0), \
grid=grid)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
_matmul[grid](
a, b, c, M, N, K, \
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
ACTIVATION = activation
)
# return output
return c
dot = _dot.apply
# %%
# Unit Test
# -----------
#
# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
# Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
c_0 = dot(a, b)
c_1 = torch.matmul(a, b)
#torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=swish)
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
print(c_0)
print(c_1)
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
print(triton.testing.allclose(c_0, c_1))
# %%
# Benchmark
# --------------
#
# Installing The CUTLASS Bindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
# For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
# To install CUTLASS, you need a recent version of cmake:
#
# .. code-block:: bash
#
# cd /path/to/cutlass/
# git clone https://github.com/NVIDIA/cutlass.git
# cd cutlass
# mkdir build
# cd build
# wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
# tar xzvf *.tar.gz
#
# You can then install CUTLASS as follows for V100
#
# .. code-block:: bash
#
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
# make -j8 install
#
# Or as follows for A100:
#
# .. code-block:: bash
#
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
# make -j8 install
#
# Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
# Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
# To re-install Triton with the updated CUTLASS bindings, run the following command:
#
# .. code-block:: bash
#
# export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
# export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
# pip uninstall -y triton
# pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
#
# Which we can test as follows:
import triton
c_2 = triton.testing.cutlass_matmul(a, b)
print(c_2)
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
# %%
# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
@@ -347,29 +237,25 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
x_vals=[8192], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['cublas', 'triton', 'cutlass'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton", 'CUTLASS'], # label name for the lines
y_vals=['cublas', 'triton'], # possible keys for `y_name`
y_lines=["cuBLAS", "Triton"], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
)
)
def benchmark(M, N, K, provider):
silu = torch.nn.SiLU()
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
if provider == 'cutlass':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True)
# %%
# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
benchmark.run(print_data=True)