From a2c31ff4349a1674cae925452347385d4a18da0c Mon Sep 17 00:00:00 2001 From: Yan Da Date: Thu, 17 Mar 2022 20:40:55 +0800 Subject: [PATCH] Init commit --- CMakeLists.txt | 23 + include/CMakeLists.txt | 1 + include/triton/ir/CMakeLists.txt | 8 + include/triton/ir/Dialect.h | 18 + include/triton/ir/TritonDialect.td | 39 ++ include/triton/ir/TritonOps.td | 235 +++++++ include/triton/ir/Types.h | 46 ++ include/triton/ir/basic_block.h | 88 --- include/triton/ir/builder.h | 191 ------ include/triton/ir/constant.h | 113 ---- include/triton/ir/context.h | 29 - include/triton/ir/context_impl.h | 46 -- include/triton/ir/enums.h | 175 ------ include/triton/ir/function.h | 142 ----- include/triton/ir/instructions.h | 978 ----------------------------- include/triton/ir/metadata.h | 32 - include/triton/ir/module.h | 92 --- include/triton/ir/print.h | 22 - include/triton/ir/type.h | 239 ------- include/triton/ir/utils.h | 30 - include/triton/ir/value.h | 95 --- include/triton/ir/visitor.h | 170 ----- lib/ir/CMakeLists.txt | 12 + lib/ir/Dialect.cpp | 71 +++ lib/ir/Ops.cpp | 63 ++ lib/ir/Types.cpp | 55 ++ lib/ir/basic_block.cc | 41 -- lib/ir/builder.cc | 434 ------------- lib/ir/constant.cc | 118 ---- lib/ir/context.cc | 40 -- lib/ir/function.cc | 66 -- lib/ir/instructions.cc | 928 --------------------------- lib/ir/metadata.cc | 14 - lib/ir/module.cc | 22 - lib/ir/print.cc | 450 ------------- lib/ir/type.cc | 233 ------- lib/ir/utils.cc | 68 -- lib/ir/value.cc | 81 --- python/src/triton.cc | 7 +- python/triton/code_gen.py | 43 ++ 40 files changed, 615 insertions(+), 4943 deletions(-) create mode 100644 include/CMakeLists.txt create mode 100644 include/triton/ir/CMakeLists.txt create mode 100644 include/triton/ir/Dialect.h create mode 100644 include/triton/ir/TritonDialect.td create mode 100644 include/triton/ir/TritonOps.td create mode 100644 include/triton/ir/Types.h delete mode 100644 include/triton/ir/basic_block.h delete mode 100644 include/triton/ir/builder.h delete mode 100644 include/triton/ir/constant.h delete mode 100644 include/triton/ir/context.h delete mode 100644 include/triton/ir/context_impl.h delete mode 100644 include/triton/ir/enums.h delete mode 100644 include/triton/ir/function.h delete mode 100644 include/triton/ir/instructions.h delete mode 100644 include/triton/ir/metadata.h delete mode 100644 include/triton/ir/module.h delete mode 100644 include/triton/ir/print.h delete mode 100644 include/triton/ir/type.h delete mode 100644 include/triton/ir/utils.h delete mode 100644 include/triton/ir/value.h delete mode 100644 include/triton/ir/visitor.h create mode 100644 lib/ir/CMakeLists.txt create mode 100644 lib/ir/Dialect.cpp create mode 100644 lib/ir/Ops.cpp create mode 100644 lib/ir/Types.cpp delete mode 100644 lib/ir/basic_block.cc delete mode 100644 lib/ir/builder.cc delete mode 100644 lib/ir/constant.cc delete mode 100644 lib/ir/context.cc delete mode 100644 lib/ir/function.cc delete mode 100644 lib/ir/instructions.cc delete mode 100644 lib/ir/metadata.cc delete mode 100644 lib/ir/module.cc delete mode 100644 lib/ir/print.cc delete mode 100644 lib/ir/type.cc delete mode 100644 lib/ir/utils.cc delete mode 100644 lib/ir/value.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index c3aadf9c7..96d3ab446 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -152,6 +152,29 @@ else() add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) endif() +# MLIR +find_package(MLIR 14 REQUIRED CONFIG) +include(TableGen) # required by AddMLIR +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) # human-friendly error message + +include_directories(${MLIR_INCLUDE_DIRS}) +include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files + +add_subdirectory(lib) +# lib +add_library(triton) +# add_subdirectory(ir) +target_link_libraries(triton + PUBLIC + TRITONIR +# # optimizations +# MLIRPass +# MLIRTransforms +) + target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) if(WIN32) diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt new file mode 100644 index 000000000..dd1b708ce --- /dev/null +++ b/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ir) diff --git a/include/triton/ir/CMakeLists.txt b/include/triton/ir/CMakeLists.txt new file mode 100644 index 000000000..2036df67a --- /dev/null +++ b/include/triton/ir/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonTableGen) diff --git a/include/triton/ir/Dialect.h b/include/triton/ir/Dialect.h new file mode 100644 index 000000000..8d2d0a3d2 --- /dev/null +++ b/include/triton/ir/Dialect.h @@ -0,0 +1,18 @@ +#ifndef TRITON_IR_DIALECT_H_ +#define TRITON_IR_DIALECT_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" + +#include "triton/Dialect.h.inc" + +#include "triton/OpsEnums.h.inc" + +#define GET_OP_CLASSES +#include "triton/Ops.h.inc" + +#endif // TRITON_IR_DIALECT_H_ diff --git a/include/triton/ir/TritonDialect.td b/include/triton/ir/TritonDialect.td new file mode 100644 index 000000000..1d96b0967 --- /dev/null +++ b/include/triton/ir/TritonDialect.td @@ -0,0 +1,39 @@ +#ifndef TRITON_DIALECT +#define TRITON_DIALECT + +include "mlir/IR/OpBase.td" + +def Triton_Dialect : Dialect { + let name = "triton"; + + let cppNamespace = "::mlir::triton"; + + let summary = "The Triton IR in MLIR"; + + let description = [{ + Triton Dialect. + + Dependent Dialects: + * Arithmetic: + * addf, addi, andi, cmpf, cmpi, divf, fptosi, ... + * Tensor: + * reshape (?) + * ControlFlow: + * bf, cond_bf + * Func: + * call, return + }]; + + let dependentDialects = [ + "arith::ArithmeticDialect", + "tensor::TensorDialect", + "cf::ControlFlowDialect", + "func::FuncDialect" + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; +} + +#endif // TRITON_DIALECT diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td new file mode 100644 index 000000000..dca22ff4f --- /dev/null +++ b/include/triton/ir/TritonOps.td @@ -0,0 +1,235 @@ +#ifndef Triton_OPS +#define Triton_OPS + +include "TritonDialect.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType + +// +// Types +// +// FloatType +def F8 : Type">, + /*descr*/"8bit float", + /*cppClassName*/"::mlir::triton::Float8Type">; + +def BF8 : Type()">, + /*descr*/"8bit bfloat", + /*cppClassName*/"::mlir::triton::BFloat8Type">; + +def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : TensorOf<[TT_Float]>; + +// IntegerType +def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">; +def TT_IntegerTensor : TensorOf<[TT_Int]>; + +// PointerType +def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">; +def TT_AnyPtr : DialectType; +def TT_PtrTensor : TensorOf<[TT_AnyPtr]>; + +def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>; +def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>; + +def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor, + TT_AnyPtr, TT_PtrTensor]>; + +// +// Op Base +// +class TT_Op traits = []> : + Op; + +// +// CastOps +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins I64Tensor:$from); + + let results = (outs TT_PtrTensor:$result); +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrTensor:$from); + + let results = (outs I64Tensor:$result); +} + +def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8, BF8). + + F8 <-> BF8, FP16, FP32 + BF8 <-> F8, FP16, FP32 + }]; + + let arguments = (ins TT_FloatTensor:$from); + + let results = (outs TT_FloatTensor:$result); + + // TODO: We need a verifier here. +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> { + let summary = "load"; + + let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other); + + let results = (outs TT_Type:$result); + + let builders = [ + // for args with default values + OpBuilder<(ins "Value":$ptr)>, + OpBuilder<(ins "Value":$ptr, "Value":$mask)> + ]; +} + +def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> { + let summary = "store"; + + let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask); + + let builders = [ + // for args with default values + OpBuilder<(ins "Value":$ptr, "Value":$value)>, + ]; +} + +def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> { + let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset); + + let results = (outs TT_Type:$result); +} + + +// +// Shape Manipulation Ops +// +// def TT_CatOp : TT_Op<"cat", []>; +def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> { + let summary = "broadcast"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); +} + +// +// builtin Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id"> { + let arguments = (ins I32Attr:$axis); + + let results = (outs I32:$result); +} + +def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "dot"; + + let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c); + + let results = (outs TT_FpIntTensor:$d); +} + +// reduction +def TT_RedOpAttr : I32EnumAttr< + /*name*/"RedOp", /*summary*/"", + /*case*/ + [ + I32EnumAttrCase, + I32EnumAttrCase<"MAX", 2, "max">, + I32EnumAttrCase<"MIN", 3, "min">, + I32EnumAttrCase<"XOR_SUM", 4, "xor_sum"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_ReduceOp : TT_Op<"reduce"> { + let summary = "reduce"; + + let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis); +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"MAX", 5, "max">, + I32EnumAttrCase<"MIN", 6, "min">, + I32EnumAttrCase<"UMAX", 7, "umax">, + I32EnumAttrCase<"UMIN", 8, "umin"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr, + TT_Type:$val); + + let results = (outs TT_Type:$result); +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas"> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$cmp, TT_Type:$val); + + let results = (outs TT_Type:$result); +} + +// +// Intrinsics +// +// TODO: should have ConstantLike as Trait +def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> { + let summary = "make range"; + + let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntegerTensor:$result); +} + +#endif // Triton_OPS diff --git a/include/triton/ir/Types.h b/include/triton/ir/Types.h new file mode 100644 index 000000000..d2e59eb8d --- /dev/null +++ b/include/triton/ir/Types.h @@ -0,0 +1,46 @@ +#ifndef TRITON_IR_TYPES_H_ +#define TRITON_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace triton { + +namespace detail { +struct PointerTypeStorage; +} // namespace detail + +// TODO: Should be base class be FloatType? +class Float8Type : public Type::TypeBase { +public: + using Base::Base; + + static Float8Type get(MLIRContext *context); +}; + +class BFloat8Type : public Type::TypeBase { +public: + using Base::Base; + + static BFloat8Type get(MLIRContext *context); +}; + +class PointerType : public Type::TypeBase { +public: + using Base::Base; + + static PointerType get(Type pointeeType); + + static PointerType get(Type pointeeType, unsigned addressSpace); + + Type getPointeeType() const; + + unsigned getAddressSpace() const; +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/ir/basic_block.h b/include/triton/ir/basic_block.h deleted file mode 100644 index 840145246..000000000 --- a/include/triton/ir/basic_block.h +++ /dev/null @@ -1,88 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_BASIC_BLOCK_H_ -#define _TRITON_IR_BASIC_BLOCK_H_ - -#include -#include -#include "value.h" -#include "visitor.h" - -namespace triton{ -namespace ir{ - -class context; -class function; -class instruction; - -/* Basic Block */ -class basic_block: public value{ -public: - // instruction iterator types - typedef std::list inst_list_t; - typedef inst_list_t::iterator iterator; - typedef inst_list_t::const_iterator const_iterator; - typedef inst_list_t::reverse_iterator reverse_iterator; - typedef inst_list_t::const_reverse_iterator const_reverse_iterator; - -private: - // constructors - basic_block(context &ctx, const std::string &name, function *parent); - -public: - // accessors - function* get_parent() { return parent_; } - context& get_context() { return ctx_; } - - // get iterator to first instruction that is not a phi - iterator get_first_non_phi(); - - // get instruction list - inst_list_t &get_inst_list() { return inst_list_; } - const inst_list_t &get_inst_list() const { return inst_list_; } - void erase(instruction *i) { inst_list_.remove(i); } - - // instruction iterator functions - inline iterator begin() { return inst_list_.begin(); } - inline const_iterator begin() const { return inst_list_.begin(); } - inline iterator end () { return inst_list_.end(); } - inline const_iterator end () const { return inst_list_.end(); } - - inline reverse_iterator rbegin() { return inst_list_.rbegin(); } - inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); } - inline reverse_iterator rend () { return inst_list_.rend(); } - inline const_reverse_iterator rend () const { return inst_list_.rend(); } - - inline size_t size() const { return inst_list_.size(); } - inline bool empty() const { return inst_list_.empty(); } - inline const instruction &front() const { return *inst_list_.front(); } - inline instruction &front() { return *inst_list_.front(); } - inline const instruction &back() const { return *inst_list_.back(); } - inline instruction &back() { return *inst_list_.back(); } - - // predecessors - const std::vector& get_predecessors() const { return preds_; } - const std::vector& get_successors() const { return succs_; } - void add_predecessor(basic_block* pred); - - // factory functions - static basic_block* create(context &ctx, const std::string &name, function *parent); - - void print(std::ostream &os); - - // visitor - void accept(visitor *v) { v->visit_basic_block(this); } - -private: - context &ctx_; - std::string name_; - function *parent_; - std::vector preds_; - std::vector succs_; - inst_list_t inst_list_; -}; - -} -} - -#endif diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h deleted file mode 100644 index fe85be947..000000000 --- a/include/triton/ir/builder.h +++ /dev/null @@ -1,191 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_BUILDER_H_ -#define _TRITON_IR_BUILDER_H_ - -#include -#include -#include "instructions.h" -#include "basic_block.h" -#include "type.h" - -namespace triton{ -namespace ir{ - -class basic_block; -class value; -class type; -class constant_int; -class instruction; -class context; -class phi_node; - -/* Builder */ -class builder{ - typedef basic_block::iterator iterator; - -public: - // Constructor - builder(context &ctx); - // Getters - const context& get_context() { return ctx_; } - // Setters - void set_insert_point(iterator instr); - void set_insert_point(instruction* i); - void set_insert_point_after(instruction* i); - void set_insert_point(basic_block* block); - basic_block* get_insert_block() { return block_; } - iterator get_insert_point() { return insert_point_;} - // Constants - value *get_int1(bool val); - value *get_int32(uint32_t val); - value *get_int64(uint64_t val); - value *get_float16(float val); - value *get_float32(float val); - value *get_range(int32_t lo, int32_t hi); - // Types - type *get_void_ty(); - type *get_int1_ty(); - type *get_int8_ty(); - type *get_int16_ty(); - type *get_int32_ty(); - type *get_int64_ty(); - type *get_fp8_ty(); - type *get_half_ty(); - type *get_bf16_ty(); - type *get_float_ty(); - type *get_double_ty(); - // Insert - template - InstTy* insert(InstTy *inst){ - assert(block_); - block_->get_inst_list().insert(insert_point_, inst); - inst->set_parent(block_); -// for(ir::value* op: inst->ops()) -// op->add_use(inst); - return inst; - } - // terminator instructions - value* create_br(basic_block *dest); - value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); - value* create_ret_void(); - // Cast instructions - value* create_bitcast(value *src, type *dest_ty); - value *create_cast(cast_op_t op, value *v, type *dst_ty); - value* create_int_to_ptr(value *src, type *dst_ty); - value* create_ptr_to_int(value *src, type *dst_ty); - value* create_si_to_fp(value *src, type *dst_ty); - value* create_ui_to_fp(value *src, type *dst_ty); - value* create_fp_to_si(value *src, type *dst_ty); - value* create_fp_to_ui(value *src, type *dst_ty); - value* create_fp_ext(value *src, type *dst_ty); - value* create_fp_trunc(value *src, type *dst_ty); - value* create_int_cast(value *src, type *dst_ty, bool is_signed); - value *create_downcast(value *arg); - // Phi instruction - phi_node* create_phi(type *ty, unsigned num_reserved); - // Binary instructions - value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw); - value *create_fmul(value *lhs, value *rhs); - value *create_fdiv(value *lhs, value *rhs); - value *create_frem(value *lhs, value *rhs); - value *create_fadd(value *lhs, value *rhs); - value *create_fsub(value *lhs, value *rhs); - value *create_sdiv(value *lhs, value *rhs); - value *create_udiv(value *lhs, value *rhs); - value *create_srem(value *lhs, value *rhs); - value *create_urem(value *lhs, value *rhs); - value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); - value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); - value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); - value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); - value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); - value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); - // GEP - value *create_gep(value *ptr, const std::vector& idx_list); - // Comparison (int) - value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs); - value *create_icmpSLE(value *lhs, value *rhs); - value *create_icmpSLT(value *lhs, value *rhs); - value *create_icmpSGE(value *lhs, value *rhs); - value *create_icmpSGT(value *lhs, value *rhs); - value *create_icmpULE(value *lhs, value *rhs); - value *create_icmpULT(value *lhs, value *rhs); - value *create_icmpUGE(value *lhs, value *rhs); - value *create_icmpUGT(value *lhs, value *rhs); - value *create_icmpEQ(value *lhs, value *rhs); - value *create_icmpNE(value *lhs, value *rhs); - // Comparison (float) - value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs); - value *create_fcmpOLT(value *lhs, value *rhs); - value *create_fcmpOGT(value *lhs, value *rhs); - value *create_fcmpOLE(value *lhs, value *rhs); - value *create_fcmpOGE(value *lhs, value *rhs); - value *create_fcmpOEQ(value *lhs, value *rhs); - value *create_fcmpONE(value *lhs, value *rhs); - value *create_fcmpULT(value *lhs, value *rhs); - value *create_fcmpUGT(value *lhs, value *rhs); - value *create_fcmpULE(value *lhs, value *rhs); - value *create_fcmpUGE(value *lhs, value *rhs); - value *create_fcmpUEQ(value *lhs, value *rhs); - value *create_fcmpUNE(value *lhs, value *rhs); - // Logical - value *create_and(value *lhs, value *rhs); - value *create_xor(value *lhs, value *rhs); - value *create_or(value *lhs, value *rhs); - // Input/Output - value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); - value *create_store(value *ptr, value *val); - value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); - value *create_masked_store(value *ptr, value *val, value *mask); - // Block instruction - value *create_splat(value *arg, const type::block_shapes_t &shapes); - value *create_reshape(value *arg, const type::block_shapes_t &shapes); - value *create_cat(value *lhs, value *rhs); - value *create_broadcast(value *arg, const type::block_shapes_t &shapes); - // Atomic instruction - value *create_atomic_cas(value *ptr, value *cmp, value *val); - value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk); - value *create_atomic_max(value *ptr, value *val, value *msk); - value *create_atomic_umax(value *ptr, value *val, value *msk); - value *create_atomic_min(value *ptr, value *val, value *msk); - value *create_atomic_umin(value *ptr, value *val, value *msk); - value *create_atomic_fadd(value *ptr, value *val, value *msk); - value *create_atomic_add(value *ptr, value *val, value *msk); - value *create_atomic_and(value *ptr, value *val, value *msk); - value *create_atomic_or(value *ptr, value *val, value *msk); - value *create_atomic_xor(value *ptr, value *val, value *msk); - value *create_atomic_xchg(value *ptr, value *val, value *msk); - // Built-in instruction - value *create_get_program_id(unsigned axis); - value *create_get_num_programs(unsigned axis); - value *create_exp(value* arg); - value *create_cos(value* arg); - value *create_sin(value* arg); - value *create_log(value* arg); - value *create_dot(value *A, value *B, value *C, bool allow_tf32); - value *create_trans(value *A, const std::vector &perm = {}); - value *create_sqrt(value *A); - value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); - value *create_select(value *pred, value *if_value, value *else_value); - // Intrinsics - // These have no place in the IR, and hopefully they can be removed at some point - value *create_umulhi(value* lhs, value* rhs); - value *create_copy_to_shared(value *arg); - value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY); - value *create_copy_from_shared(value *arg); - value *create_barrier(const std::string &name = ""); - value *create_async_wait(int N); - value *create_prefetch_s(value *arg, int inc); - -private: - context &ctx_; - basic_block *block_; - iterator insert_point_; -}; - - -} -} - -#endif diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h deleted file mode 100644 index 671d5e5f0..000000000 --- a/include/triton/ir/constant.h +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_CONSTANT_H_ -#define _TRITON_IR_CONSTANT_H_ - -#include "enums.h" -#include "value.h" -#include -#include "visitor.h" - -namespace triton{ -namespace ir{ - -class type; -class context; - -/* Constant */ -class constant: public user{ -protected: - using user::user; - -public: - static constant* get_all_ones_value(type *ty); - static constant* get_null_value(type *ty); - virtual std::string repr() const = 0; -}; - -/* Undef value */ -class undef_value: public constant{ -private: - undef_value(type *ty); - -public: - static undef_value* get(type* ty); - std::string repr() const { return "undef"; } - void accept(visitor* vst) { vst->visit_undef_value(this); } -}; - - -/* Constant int */ -class constant_int: public constant{ -protected: - constant_int(type *ty, uint64_t value); - -public: - virtual uint64_t get_value() const { return value_; } - static constant_int *get(type *ty, uint64_t value); - std::string repr() const { return std::to_string(value_); } - void accept(visitor* vst) { vst->visit_constant_int(this); } - -protected: - uint64_t value_; -}; - -/* Constant fp */ -class constant_fp: public constant{ - constant_fp(type *ty, double value); - -public: - double get_value() { return value_; } - static constant* get_negative_zero(type *ty); - static constant* get_zero_value_for_negation(type *ty); - static constant* get(context &ctx, double v); - static constant* get(type *ty, double v); - std::string repr() const { return std::to_string(value_); } - void accept(visitor* vst) { vst->visit_constant_fp(this); } - -private: - double value_; -}; - - -/* Global Value */ -class global_value: public constant { -public: - enum linkage_types_t { - external - }; - -public: - global_value(type *ty, unsigned num_ops, - linkage_types_t linkage, const std::string &name, - unsigned addr_space); - std::string repr() const { return get_name(); } - -private: - linkage_types_t linkage_; -}; - -/* global object */ -class global_object: public global_value { -public: - global_object(type *ty, unsigned num_ops, - linkage_types_t linkage, const std::string &name, - unsigned addr_space = 0); - std::string repr() const { return get_name(); } -}; - -/* global variable */ -class alloc_const: public global_object { -public: - alloc_const(type *ty, constant_int *size, - const std::string &name = ""); - std::string repr() const { return get_name(); } - void accept(visitor* vst) { vst->visit_alloc_const(this); } - - -}; - -} -} - -#endif diff --git a/include/triton/ir/context.h b/include/triton/ir/context.h deleted file mode 100644 index d824c98b6..000000000 --- a/include/triton/ir/context.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_CONTEXT_H_ -#define _TRITON_IR_CONTEXT_H_ - -#include -#include "triton/ir/type.h" - -namespace triton{ -namespace ir{ - -class type; -class context_impl; - -/* Context */ -class context { -public: - context(); - context(const context&) = delete; - context& operator=(const context&) = delete; - -public: - std::shared_ptr p_impl; -}; - -} -} - -#endif diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h deleted file mode 100644 index ef20af6b7..000000000 --- a/include/triton/ir/context_impl.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_CONTEXT_IMPL_H_ -#define _TRITON_IR_CONTEXT_IMPL_H_ - -#include "triton/ir/type.h" -#include "triton/ir/constant.h" -#include -#include - -namespace triton{ -namespace ir{ - -class context; - -/* Context impl */ -class context_impl { -public: - // constructors - context_impl(context &ctx); - -public: - // non-numeric types - type void_ty, label_ty; - // floating point types - type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty; - // integer types - integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; - // Pointer types - std::map, std::unique_ptr> ptr_tys; - // Block types - std::map, std::unique_ptr> block_tys; - - // Int constants - std::map, std::unique_ptr> int_constants_; - // Float constants - std::map, std::unique_ptr> fp_constants_; - // undef values - std::map> uv_constants_; - -}; - -} -} - -#endif diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h deleted file mode 100644 index 8cb7835f0..000000000 --- a/include/triton/ir/enums.h +++ /dev/null @@ -1,175 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_ENUMS_H_ -#define _TRITON_IR_ENUMS_H_ - -namespace triton{ -namespace ir{ - - -enum binary_op_t: unsigned int{ - Add, - FAdd, - Sub, - FSub, - Mul, - FMul, - UDiv, - SDiv, - FDiv, - URem, - SRem, - FRem, - Shl, - LShr, - AShr, - And, - Or, - Xor -}; - -enum class atomic_rmw_op_t: unsigned int{ - And, - Or, - Xor, - Add, - Max, - Min, - UMax, - UMin, - FAdd, - Xchg, -}; - -enum cast_op_t: unsigned int { - Trunc, - ZExt, - SExt, - FPTrunc, - FPExt, - UIToFP, - SIToFP, - FPToUI, - FPToSI, - PtrToInt, - IntToPtr, - BitCast, - AddrSpaceCast -}; - -enum cmp_pred_t: unsigned int { - FIRST_FCMP_PREDICATE, - FCMP_FALSE, - FCMP_OEQ, - FCMP_OGT, - FCMP_OGE, - FCMP_OLT, - FCMP_OLE, - FCMP_ONE, - FCMP_ORD, - FCMP_UNO, - FCMP_UEQ, - FCMP_UGT, - FCMP_UGE, - FCMP_ULT, - FCMP_ULE, - FCMP_UNE, - FCMP_TRUE, - LAST_FCMP_PREDICATE, - FIRST_ICMP_PREDICATE, - ICMP_EQ, - ICMP_NE, - ICMP_UGT, - ICMP_UGE, - ICMP_ULT, - ICMP_ULE, - ICMP_SGT, - ICMP_SGE, - ICMP_SLT, - ICMP_SLE, - LAST_ICMP_PREDICATE -}; - -enum value_id_t: unsigned { - /* ------------ * - INSTRUCTIONS - * ------------ */ - INST_BEGIN, - // phi - INST_PHI, - // arithmetic - INST_BINOP, - INST_GETELEMENTPTR, - INST_SELECT, - INST_SQRT, - // cmp - INST_ICMP, - INST_FCMP, - // cast - INST_CAST_TRUNC, - INST_CAST_ZEXT, - INST_CAST_SEXT, - INST_CAST_FP_TRUNC, - INST_CAST_FP_EXT, - INST_CAST_UI_TO_FP, - INST_CAST_SI_TO_FP, - INST_CAST_FP_TO_UI, - INST_CAST_FP_TO_SI, - INST_CAST_PTR_TO_INT, - INST_CAST_INT_TO_PTR, - INST_CAST_BIT_CAST, - INST_CAST_ADDR_SPACE_CAST, - // terminators - INST_RETURN, - INST_COND_BRANCH, - INST_UNCOND_BRANCH, - // io - INST_UNMASKED_LOAD, - INST_MASKED_LOAD, - INST_MASKED_LOAD_ASYNC, - INST_UNMASKED_STORE, - INST_MASKED_STORE, - // retile - INST_RESHAPE, - INST_SPLAT, - INST_CAT, - INST_BROADCAST, - INST_DOWNCAST, - // builtin - INST_GET_PROGRAM_ID, - INST_GET_NUM_PROGRAMS, - // atomics - INST_ATOMIC_CAS, - INST_ATOMIC_EXCH, - INST_ATOMIC_RMW, - // math - INST_UMULHI, - INST_EXP, - INST_COS, - INST_SIN, - INST_LOG, - // array arithmetic - INST_TRANS, - INST_REDUCE, - INST_DOT, - // intrinsics - INST_COPY_TO_SHARED, - INST_COPY_FROM_SHARED, - INST_CVT_LAYOUT, - INST_CVT_SCANLINE, - INST_DECOALESCE, - INST_RECOALESCE, - INST_BARRIER, - INST_ASYNC_WAIT, - INST_MAKE_RANGE_DYN, - INST_MAKE_RANGE_STA, - INST_MAKE_RANGE, - INST_PREFETCH_S, -}; - - - -} -} - -#endif diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h deleted file mode 100644 index 9e1bc981a..000000000 --- a/include/triton/ir/function.h +++ /dev/null @@ -1,142 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_FUNCTION_H_ -#define _TRITON_IR_FUNCTION_H_ - -#include -#include -#include "value.h" -#include "constant.h" - -namespace triton{ -namespace ir{ - -class function; -class function_type; -class module; -class basic_block; - -/* Argument */ -class argument: public value{ - argument(type *ty, const std::string &name, function *parent, unsigned arg_no); - -public: - static argument* create(type *ty, const std::string &name, - function *parent = nullptr, unsigned arg_no = 0); - function* get_parent() const; - unsigned get_arg_no() const; - - void accept(visitor *v); - -private: - function *parent_; - unsigned arg_no_; -}; - -/* Attribute */ -enum attribute_kind_t { - readonly = 0, - writeonly, - noalias, - aligned, - multiple_of, - retune, - not_implemented -}; - -class attribute { -public: - attribute(attribute_kind_t kind, unsigned value = 0): - kind_(kind), value_(value){} - - bool operator<(const attribute& other) const { - return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_); - } - - attribute_kind_t get_kind() const { - return kind_; - } - - unsigned get_value() const { - return value_; - } - - bool is_llvm_attr() const { - return kind_ != multiple_of; - } - - std::string repr() const { - switch(kind_){ - case readonly: return ".readonly"; - case writeonly: return ".writeonly"; - case noalias: return ".noalias"; - case aligned: return ".aligned(" + std::to_string(value_) + ")"; - case multiple_of: return ".multipleof(" + std::to_string(value_) + ")"; - case retune: return ".retunr"; - default: break; - } - assert(false); - return ""; - } - -private: - attribute_kind_t kind_; - unsigned value_; -}; - -/* Function */ -class function: public global_object{ - typedef std::vector args_t; - typedef args_t::iterator arg_iterator; - typedef args_t::const_iterator const_arg_iterator; - - typedef std::vector blocks_t; - typedef blocks_t::iterator block_iterator; - typedef blocks_t::const_iterator const_block_iterator; - - typedef std::map> attr_map_t; - -private: - function(function_type *ty, linkage_types_t linkage, - const std::string &name = "", module *parent = nullptr); - -public: - // accessors - const args_t &args() const { return args_; } - function_type* get_fn_type() { return fn_ty_; } - const function_type* get_fn_type() const { return fn_ty_; } - module *get_parent() { return parent_; } - const module *get_parent() const { return parent_; } - - // factory methods - static function *create(function_type *ty, linkage_types_t linkage, - const std::string &name, module *mod); - // blocks - const blocks_t &blocks() { return blocks_; } - const blocks_t &blocks() const { return blocks_; } - void insert_block(basic_block* block, basic_block *next = nullptr); - - // attributes - void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); } - const attr_map_t &attrs() { return attrs_; } - bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); } - std::set get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; } - - void print(std::ostream &os); - - // visitor - void accept(visitor *v) { v->visit_function(this); } - -private: - module *parent_; - bool init_; - function_type *fn_ty_; - args_t args_; - blocks_t blocks_; - attr_map_t attrs_; -}; - -} -} - -#endif diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h deleted file mode 100644 index 0fb85db02..000000000 --- a/include/triton/ir/instructions.h +++ /dev/null @@ -1,978 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_INSTRUCTIONS_H_ -#define _TRITON_IR_INSTRUCTIONS_H_ - -#include -#include -#include "triton/ir/enums.h" -#include "triton/ir/constant.h" -#include "triton/ir/value.h" -#include "triton/ir/type.h" -#include "triton/ir/metadata.h" -#include "triton/ir/visitor.h" - -#define _TRITON_DEFINE_CLONE(name) \ - ir::instruction* clone_impl() const { return new name(*this); } - -#define _TRITON_DEFINE_ACCEPT(name) \ - void accept(visitor* v) { v->visit_ ## name (this); } - -namespace triton{ -namespace ir{ - -class constant_int; -class constant; -class make_range; -class basic_block; -class context; -class visitor; - -//===----------------------------------------------------------------------===// -// instruction classes -//===----------------------------------------------------------------------===// - -class result_reference; - - -class instruction: public user{ -public: - virtual std::string repr_impl() const = 0; - -private: - virtual ir::instruction* clone_impl() const = 0; - -protected: - // constructors - instruction(type *ty, value_id_t ity, unsigned num_ops, - const std::string &name = "", instruction *next = nullptr); - -public: - // parent - void set_parent(basic_block *block) { parent_ = block; } - const basic_block *get_parent() const { return parent_; } - basic_block *get_parent() { return parent_; } - void erase_from_parent(); - // helpers - bool has_tile_result_or_op(); - // repr - std::string repr() const { return repr_impl(); } - // metadata - void set_metadata(ir::metadata::kind_t kind, - unsigned value) { metadatas_[kind] = value;} - unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} - // cloning - ir::instruction* clone() { - ir::instruction* res = clone_impl(); -// for(auto it = op_begin(); it != op_end(); it++) -// (*it)->add_use(res); - res->parent_ = nullptr; - res->users_.clear(); - return res; - } - // instruction id - value_id_t get_id() const { return id_; } - - void print(std::ostream &os); - -private: - basic_block *parent_; - std::map metadatas_; - value_id_t id_; -}; - - -//===----------------------------------------------------------------------===// -// phi_node classes -//===----------------------------------------------------------------------===// - -class phi_node: public instruction { -private: - phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next); - std::string repr_impl() const { return "phi"; } - -public: - void set_incoming_value(unsigned i, value *v); - void set_incoming_block(unsigned i, basic_block *block); - value *get_value_for_block(basic_block *block); - value *get_incoming_value(unsigned i) { return get_operand(i); } - basic_block *get_incoming_block(unsigned i) { return blocks_[i]; } - unsigned get_num_incoming() { return get_num_operands(); } - void add_incoming(value *v, basic_block *block); - - // Type - void set_type(type *ty) { ty_ = ty; } - - // Factory methods - static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr); - - _TRITON_DEFINE_CLONE(phi_node) - _TRITON_DEFINE_ACCEPT(phi_node) - -private: - unsigned num_reserved_; - std::vector blocks_; -}; - -//===----------------------------------------------------------------------===// -// binary_operator classes -//===----------------------------------------------------------------------===// - -class binary_operator: public instruction { -public: - typedef binary_op_t op_t; - -private: - std::string repr_impl() const; - -protected: - // Constructors - binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next); - -public: - // Get operand - binary_op_t get_op() const { return op_; } - - // Bool - bool is_terminator() const; - bool is_binary_op() const; - bool is_int_div_rem() const; - bool is_shift() const; - bool is_cast() const; - bool is_int_mult() const; - bool is_int_add_sub() const; - bool is_int_div() const; - bool is_int_rem() const; - bool is_shl() const; - bool is_shr() const; - - // Approx - void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; } - bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; } - - // Wraps - void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; } - void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; } - - // Factory methods - static binary_operator *create(binary_op_t op, value *lhs, value *rhs, - const std::string &name = "", instruction *next = nullptr); -// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr); -// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr); -// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr); - - _TRITON_DEFINE_CLONE(binary_operator) - _TRITON_DEFINE_ACCEPT(binary_operator) - -public: - binary_op_t op_; - bool has_no_unsigned_wrap_; - bool has_no_signed_wrap_; - - bool fdiv_ieee_rnd_; -}; - - -//===----------------------------------------------------------------------===// -// cmp_inst classes -//===----------------------------------------------------------------------===// - -class cmp_inst: public instruction{ -public: - typedef cmp_pred_t pred_t; - -private: - std::string repr_impl() const; - -protected: - cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, - value *lhs, value *rhs, const std::string &name, instruction *next); - static bool is_fp_predicate(cmp_pred_t pred); - static bool is_int_predicate(cmp_pred_t pred); - static type* make_cmp_result_type(type *ty); - -public: - cmp_pred_t get_pred() const { return pred_; } - -private: - cmp_pred_t pred_; -}; - -class icmp_inst: public cmp_inst { - icmp_inst(type *ty, cmp_pred_t pred, - value *lhs, value *rhs, const std::string &name, instruction *next); - -public: - static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, - const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(icmp_inst) - _TRITON_DEFINE_ACCEPT(icmp_inst) -}; - -class fcmp_inst: public cmp_inst { - fcmp_inst(type *ty, cmp_pred_t pred, - value *lhs, value *rhs, const std::string &name, instruction *next); - -public: - static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, - const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(fcmp_inst) - _TRITON_DEFINE_ACCEPT(fcmp_inst) -}; - -//===----------------------------------------------------------------------===// -// unary_inst classes -//===----------------------------------------------------------------------===// - -class unary_inst: public instruction { -protected: - unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next); -}; - - -//===----------------------------------------------------------------------===// -// cast_inst classes -//===----------------------------------------------------------------------===// - -class cast_inst: public unary_inst{ -private: - std::string repr_impl() const; - -protected: - cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op) - : unary_inst(ty, id, v, name, next), op_(op) { } - -private: - static bool is_valid(cast_op_t op, value *arg, type *ty); - -public: - // accessors - cast_op_t get_op() const { return op_; } - - // factory methods - static cast_inst *create(cast_op_t op, value *arg, type *ty, - const std::string &name = "", instruction *next = nullptr); - static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed, - const std::string &name = "", instruction *next = nullptr); - - _TRITON_DEFINE_ACCEPT(cast_inst) - -private: - cast_op_t op_; -}; - -#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \ -class name : public cast_inst { \ - _TRITON_DEFINE_CLONE(name) \ - friend class cast_inst; \ - name(type *ty, value *v, const std::string &name, instruction *next) \ - : cast_inst(ty, id, v, name, next, op){ } \ -}; - -TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc) -TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP) -TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI) -TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr) -TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast) -TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast) - -//===----------------------------------------------------------------------===// -// terminator_inst classes -//===----------------------------------------------------------------------===// - -class terminator_inst: public instruction{ - using instruction::instruction; -}; - -// return instruction -class return_inst: public terminator_inst { -private: - std::string repr_impl() const { return "ret"; } - return_inst(context &ctx, value *ret_val, instruction *next); - -public: - // accessors - value *get_return_value() - { return get_num_operands() ? get_operand(0) : nullptr; } - - unsigned get_num_successors() const { return 0; } - - // factory methods - static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr); - - _TRITON_DEFINE_CLONE(return_inst) - _TRITON_DEFINE_ACCEPT(return_inst) -}; - -// base branch instruction -class branch_inst: public terminator_inst{ -private: - std::string repr_impl() const { return "br"; } - -protected: - using terminator_inst::terminator_inst; - -public: - static branch_inst* create(basic_block *dest, - instruction *next = nullptr); - static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest, - instruction *next = nullptr); -}; - -// conditional branch -class cond_branch_inst: public branch_inst { -private: - friend class branch_inst; - cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next); - -public: - basic_block *get_true_dest() { return (basic_block*)get_operand(0); } - basic_block *get_false_dest() { return (basic_block*)get_operand(1); } - value *get_cond() { return get_operand(2); } - _TRITON_DEFINE_CLONE(cond_branch_inst) - _TRITON_DEFINE_ACCEPT(cond_branch_inst) -}; - -// unconditional branch -class uncond_branch_inst: public branch_inst { -private: - friend class branch_inst; - uncond_branch_inst(basic_block *dst, instruction *next); - -public: - basic_block *get_dest() { return (basic_block*)get_operand(0); } - _TRITON_DEFINE_CLONE(uncond_branch_inst) - _TRITON_DEFINE_ACCEPT(uncond_branch_inst) -}; - - -//===----------------------------------------------------------------------===// -// getelementptr_inst classes -//===----------------------------------------------------------------------===// - -class getelementptr_inst: public instruction { -private: - std::string repr_impl() const { return "getelementptr"; } - getelementptr_inst(type *pointee_ty, value *ptr, const std::vector &idx, const std::string &name, instruction *next); - -private: - static type *get_return_type(type *ty, value *ptr, const std::vector &idx); - static type *get_indexed_type_impl(type *ty, const std::vector &idx); - static type *get_indexed_type(type *ty, const std::vector &idx); - -public: - // accessors - type *get_source_elt_ty() { return source_elt_ty; } - op_iterator idx_begin() { return op_begin() + 1; } - op_iterator idx_end() { return op_end(); } - value *get_pointer_operand() { return *op_begin(); } - - // factory methods - static getelementptr_inst* create(value *ptr, const std::vector &idx, - const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(getelementptr_inst) - _TRITON_DEFINE_ACCEPT(getelementptr_inst) - -private: - type *source_elt_ty; - type *res_elt_ty; -}; - -//===----------------------------------------------------------------------===// -// load_inst/store_inst classes -//===----------------------------------------------------------------------===// - -class io_inst: public instruction { -protected: - io_inst(type *ty, value_id_t id, unsigned num_ops, - const std::string &name = "", instruction *next = nullptr); - -public: - // accessors - value *get_pointer_operand() { return get_operand(0); } -}; - -// load -class load_inst: public io_inst { -public: - enum CACHE_MODIFIER : uint32_t { - NONE=0, - CA, - CG, - }; - - enum EVICTION_POLICY : uint32_t { - NORMAL=0, - EVICT_FIRST, - EVICT_LAST, - }; - - CACHE_MODIFIER get_cache_modifier() const { return cache_; } - EVICTION_POLICY get_eviction_policy() const { return eviction_; } - bool get_is_volatile() const { return is_volatile_; } - -protected: - load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction, - bool is_volatile, - const std::string &name = "", instruction *next = nullptr); - std::string get_cache_modifier_repr() const { - if (cache_ == CA) return ".ca"; - if (cache_ == CG) return ".cg"; - return ""; - } - std::string get_eviction_policy_repr() const { - if (eviction_ == EVICT_FIRST) return ".L1::evict_first"; - if (eviction_ == EVICT_LAST) return ".L2::evict_last"; - } - EVICTION_POLICY eviction_; - CACHE_MODIFIER cache_; - - std::string get_volatile_repr() { - return is_volatile_ ? ".volatile" : ""; - } - bool is_volatile_; - -private: - static type *get_pointee_type(type *ty); -}; - -// unmasked load -class unmasked_load_inst: public load_inst { -private: - std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); } - unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next); - -public: - static unmasked_load_inst* create(value *ptr, - CACHE_MODIFIER cache, EVICTION_POLICY eviction, - bool is_volatile, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(unmasked_load_inst) - _TRITON_DEFINE_ACCEPT(unmasked_load_inst) -}; - -// masked load -class masked_load_inst: public load_inst { -private: - std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); } - masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, - const std::string &name, instruction *next); - -public: - // accessors - value *get_mask_operand() { return get_operand(1); } - value *get_false_value_operand() { return get_operand(2); } - // factory method - static masked_load_inst* create(value *ptr, value *mask, value *false_value, - CACHE_MODIFIER cache, EVICTION_POLICY eviction, - bool is_volatile, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(masked_load_inst) - _TRITON_DEFINE_ACCEPT(masked_load_inst) -}; - -// masked load async -class masked_load_async_inst: public load_inst { -private: - std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); } - masked_load_async_inst(value *ptr, value *mask, value *false_value, - CACHE_MODIFIER cache, EVICTION_POLICY eviction, - const std::string &name, instruction *next); - -public: - // accessors - value *get_mask_operand() { return get_operand(1); } - value *get_false_value_operand() { return get_operand(2); } - // factory method - static masked_load_async_inst* create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, - EVICTION_POLICY eviction, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(masked_load_async_inst) - _TRITON_DEFINE_ACCEPT(masked_load_async_inst) -}; - - - -// store -class store_inst: public io_inst { -protected: - store_inst(value *ptr, value_id_t id, unsigned num_ops, - const std::string &name = "", instruction *next = nullptr); - -public: - value *get_value_operand() { return get_operand(1); } -}; - -// unmasked_store -class unmasked_store_inst: public store_inst{ -private: - std::string repr_impl() const { return "unmasked_store"; } - unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next); - -public: - // factory method - static unmasked_store_inst* create(value* ptr, value *v, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(unmasked_store_inst) - _TRITON_DEFINE_ACCEPT(unmasked_store_inst) -}; - -class masked_store_inst: public store_inst{ -private: - std::string repr_impl() const { return "masked_store"; } - masked_store_inst(value *ptr, value *v, value *mask, - const std::string &name, instruction *next); - -public: - // accessors - value *get_mask_operand() { return get_operand(2); } - // factory method - static masked_store_inst* create(value *ptr, value *v, value *mask, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(masked_store_inst) - _TRITON_DEFINE_ACCEPT(masked_store_inst) -}; - -//===----------------------------------------------------------------------===// -// retile_inst classes -//===----------------------------------------------------------------------===// - -// cat - -class cat_inst: public instruction { -private: - std::string repr_impl() const { return "cat"; } - cat_inst(value *x, value *y, const std::string &name, instruction *next); - -public: - static instruction* create(value *lhs, value *rhs, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(cat_inst) - _TRITON_DEFINE_ACCEPT(cat_inst) -}; - -// retile - -class retile_inst: public unary_inst { -protected: - retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next); -}; - -// reshape - -class reshape_inst: public retile_inst { -private: - using retile_inst::retile_inst; - std::string repr_impl() const { return "reshape"; } - -public: - static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, - const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(reshape_inst) - _TRITON_DEFINE_ACCEPT(reshape_inst) -}; - -// splat - -class splat_inst: public retile_inst { -private: - using retile_inst::retile_inst; - std::string repr_impl() const { return "splat"; } - -public: - static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, - const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(splat_inst) - _TRITON_DEFINE_ACCEPT(splat_inst) -}; - -// broadcast - -class broadcast_inst: public retile_inst { -private: - using retile_inst::retile_inst; - std::string repr_impl() const { return "broadcast"; } - -public: - static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, - const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(broadcast_inst) - _TRITON_DEFINE_ACCEPT(broadcast_inst) -}; - - -// downcast - -class downcast_inst: public unary_inst { -private: - using unary_inst::unary_inst; - std::string repr_impl() const { return "downcast"; } - -public: - static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(downcast_inst) - _TRITON_DEFINE_ACCEPT(downcast_inst) -}; - -//===----------------------------------------------------------------------===// -// builtin_inst classes -//===----------------------------------------------------------------------===// - -class builtin_inst: public instruction{ -protected: - using instruction::instruction; -}; - -class get_program_id_inst: public builtin_inst { -private: - get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next); - std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; } - -public: - static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); - unsigned get_axis() const { return axis_; } - _TRITON_DEFINE_CLONE(get_program_id_inst) - _TRITON_DEFINE_ACCEPT(get_program_id_inst) - -private: - unsigned axis_; -}; - -class get_num_programs_inst: public builtin_inst { -private: - get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next); - std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; } - -public: - static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); - unsigned get_axis() const { return axis_; } - _TRITON_DEFINE_CLONE(get_num_programs_inst) - _TRITON_DEFINE_ACCEPT(get_num_programs_inst) - -private: - unsigned axis_; -}; - - -class atomic_inst: public io_inst { -public: - using io_inst::io_inst; -}; - -class atomic_rmw_inst: public atomic_inst { -private: - atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "atomic_rmw"; } - _TRITON_DEFINE_CLONE(atomic_rmw_inst) - _TRITON_DEFINE_ACCEPT(atomic_rmw_inst) - -public: - static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr); - atomic_rmw_op_t get_op() { return op_; } - -private: - atomic_rmw_op_t op_; -}; - -class atomic_cas_inst: public atomic_inst { -private: - atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next); - std::string repr_impl() const { return "atomic_cas"; } - _TRITON_DEFINE_CLONE(atomic_cas_inst) - _TRITON_DEFINE_ACCEPT(atomic_cas_inst) - -public: - static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr); -}; - -class umulhi_inst: public builtin_inst { -private: - umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "umulhi"; } - _TRITON_DEFINE_CLONE(umulhi_inst) - _TRITON_DEFINE_ACCEPT(umulhi_inst) - -public: - static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr); -}; - -class exp_inst: public builtin_inst { -private: - exp_inst(value *val, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "exp"; } - _TRITON_DEFINE_CLONE(exp_inst) - _TRITON_DEFINE_ACCEPT(exp_inst) - -public: - static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); -}; - -class cos_inst: public builtin_inst { -private: - cos_inst(value *val, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "cos"; } - _TRITON_DEFINE_CLONE(cos_inst) - _TRITON_DEFINE_ACCEPT(cos_inst) - -public: - static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); -}; - -class sin_inst: public builtin_inst { -private: - sin_inst(value *val, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "sin"; } - _TRITON_DEFINE_CLONE(sin_inst) - _TRITON_DEFINE_ACCEPT(sin_inst) - -public: - static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); -}; - -class log_inst: public builtin_inst { -private: - log_inst(value *val, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "log"; } - _TRITON_DEFINE_CLONE(log_inst) - _TRITON_DEFINE_ACCEPT(log_inst) - -public: - static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); -}; - - -class dot_inst: public builtin_inst { -public: - enum TransT { NoTrans, Trans }; - enum DataType { - FP8, FP16, BF16, TF32, FP32, - INT1, INT4, INT8, INT32, - UNKNOWN, - }; - -private: - dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next); - std::string repr_impl() const { return "dot"; } - -public: - bool is_prefetched() const { return is_prefetched_; } - void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } - bool allow_tf32() const { return allow_tf32_; } - -public: - static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); - static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); - static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); - static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); - static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(dot_inst) - _TRITON_DEFINE_ACCEPT(dot_inst) - -private: - bool is_prefetched_ = false; - bool allow_tf32_ = false; - DataType C_type_ = DataType::FP32; - DataType A_type_ = DataType::FP16; - DataType B_type_ = DataType::FP16; -}; - -//class outer_inst: public builtin_inst { -//private: -// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next); -//public: -// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); -//}; - -class trans_inst: public builtin_inst { -public: - ir::type* get_res_ty(ir::type* in, std::vector perm); - std::vector init_perm(ir::type* ty, const std::vector& perm); - -private: - trans_inst(value *arg, const std::vector& perm, const std::string& name, instruction* next); - std::string repr_impl() const { return "trans"; } - -public: - static instruction* create(value *arg, const std::vector &perm = {}, const std::string &name = "", instruction *next = nullptr); - const std::vector get_perm() const; - _TRITON_DEFINE_CLONE(trans_inst) - _TRITON_DEFINE_ACCEPT(trans_inst) - -private: - std::vector perm_; -}; - -class sqrt_inst: public builtin_inst { -private: - sqrt_inst(value *arg, const std::string& name, instruction* next); - std::string repr_impl() const { return "sqrt"; } -public: - static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(sqrt_inst) - _TRITON_DEFINE_ACCEPT(sqrt_inst) -}; - -class reduce_inst: public builtin_inst { -public: - enum op_t{ - ADD, SUB, MAX, MIN, - FADD, FSUB, FMAX, FMIN, - XOR - }; - -private: - static type* get_res_type(value *arg, unsigned axis); - static std::string to_str(op_t op); - -private: - reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next); - std::string repr_impl() const { return "reduce"; } - _TRITON_DEFINE_CLONE(reduce_inst) - _TRITON_DEFINE_ACCEPT(reduce_inst) - -public: - static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr); - unsigned get_axis() const { return axis_; } - op_t get_op() const { return op_; } - -private: - unsigned axis_; - op_t op_; -}; - -class select_inst: public builtin_inst { -private: - select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next); - std::string repr_impl() const { return "select"; } - _TRITON_DEFINE_CLONE(select_inst) - _TRITON_DEFINE_ACCEPT(select_inst) - -public: - static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr); - value* get_pred_op() { return get_operand(0); } - value* get_if_value_op() { return get_operand(1); } - value* get_else_value_op() { return get_operand(2); } -}; - -//===----------------------------------------------------------------------===// -// intrinsics classes -//===----------------------------------------------------------------------===// - - -class copy_to_shared_inst: public unary_inst{ -private: - using unary_inst::unary_inst; - std::string repr_impl() const { return "copy_to_shared"; } - -public: - static copy_to_shared_inst* create(value *arg, const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(copy_to_shared_inst) - _TRITON_DEFINE_ACCEPT(copy_to_shared_inst) -}; - -class copy_from_shared_inst: public unary_inst{ -private: - using unary_inst::unary_inst; - std::string repr_impl() const { return "copy_from_shared"; } - -public: - static copy_from_shared_inst* create(value *arg, const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(copy_from_shared_inst) - _TRITON_DEFINE_ACCEPT(copy_from_shared_inst) -}; - -class cvt_layout_inst: public unary_inst { -private: - using unary_inst::unary_inst; - std::string repr_impl() const { return "cvt_layout_inst"; } - -public: - static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr); - _TRITON_DEFINE_CLONE(cvt_layout_inst) - _TRITON_DEFINE_ACCEPT(cvt_layout_inst) -}; - -class barrier_inst: public instruction{ -private: - barrier_inst(context &ctx, const std::string &name, instruction *next); - std::string repr_impl() const { return "barrier"; } - _TRITON_DEFINE_CLONE(barrier_inst) - _TRITON_DEFINE_ACCEPT(barrier_inst) - -public: - static barrier_inst* create(context &ctx, const std::string &name = "", - instruction *next = nullptr); -}; - -class async_wait_inst: public instruction{ -private: - async_wait_inst(context &ctx, int N, const std::string &name, instruction *next); - std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; } - _TRITON_DEFINE_CLONE(async_wait_inst) - _TRITON_DEFINE_ACCEPT(async_wait_inst) - -public: - static async_wait_inst* create(context &ctx, int N, - const std::string &name = "", instruction *next = nullptr); - int get_N() { return N_; } - void set_N(int n) { N_ = n; } - -private: - int N_; -}; - -class prefetch_s_inst : public instruction { - std::string repr_impl() const { return "prefetch_s"; } - _TRITON_DEFINE_CLONE(prefetch_s_inst) - _TRITON_DEFINE_ACCEPT(prefetch_s_inst) - - /// inc_: 0->first, 1->latch - int inc_ = 0; -public: - prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next) - : instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) { - set_operand(0, arg); - } - int get_inc() const { return inc_; } - static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "", - instruction *next=nullptr); -}; - -/* constant range */ -class make_range: public instruction{ - make_range(type *ty, constant_int* first, constant_int* last); - std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; } - _TRITON_DEFINE_CLONE(make_range) - _TRITON_DEFINE_ACCEPT(make_range) - -public: - static make_range *create(constant_int *first, constant_int *last); - const constant_int* get_first() const; - const constant_int* get_last() const; - -private: - constant_int* first_; - constant_int* last_; -}; - - -} -} - -#endif diff --git a/include/triton/ir/metadata.h b/include/triton/ir/metadata.h deleted file mode 100644 index 9d4fb1137..000000000 --- a/include/triton/ir/metadata.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_METADATA_H_ -#define _TRITON_IR_METADATA_H_ - -namespace triton{ -namespace ir{ - - -/* Metadata */ -class metadata{ -public: - enum kind_t{ - multiple_of, - max_contiguous - }; - -private: - metadata(kind_t kind, unsigned value); - -public: - static metadata* get(kind_t kind, unsigned value); - -private: - kind_t kind_; - unsigned value_; -}; - -} -} - -#endif diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h deleted file mode 100644 index ea64dfc6e..000000000 --- a/include/triton/ir/module.h +++ /dev/null @@ -1,92 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_MODULE_H_ -#define _TRITON_IR_MODULE_H_ - -#include -#include -#include -#include -#include -#include "triton/ir/builder.h" -#include "triton/ir/metadata.h" -#include "triton/ir/context.h" - -namespace triton{ - -namespace lang{ - -class iteration_statement; -class compound_statement; - -} - -namespace ir{ - -class basic_block; -class phi_node; -class value; -class context; -class function; -class attribute; -class function_type; -class constant; -class global_value; -class alloc_const; - -/* Module */ - -class module { - typedef std::pair val_key_t; - friend class function; - typedef std::pair md_pair_t; - -public: - typedef std::map symbols_map_t; - typedef std::vector functions_list_t; - struct current_iteration_info_t{ - lang::iteration_statement *statement; - basic_block *block; - }; - -private: - phi_node *make_phi(type *ty, unsigned num_values, basic_block *block); - value *try_remove_trivial_phis(ir::phi_node *&phi); - value *add_phi_operands(const std::string& name, phi_node *&phi); - value *get_value_recursive(const std::string& name, basic_block *block); - void push_function(function *fn) { functions_.push_back(fn); } - -public: - module(const std::string &name, builder &builder): name_(name), builder_(builder) {} - builder &get_builder() { return builder_; }; - const std::string& get_name() { return name_; }; - - // Functions - const functions_list_t &get_function_list() const { return functions_; } - functions_list_t &get_function_list() { return functions_; } - function *get_or_insert_function(const std::string &name, function_type *ty); - // Const allocation - void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } - const std::vector& allocs() { return allocs_; } - // Register global - void register_global(const std::string& name, ir::value *x) { globals_[name] = x; } - const std::map& globals() const { return globals_; } - // Metadata - void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } - const std::map &get_metadatas() const { return metadatas_; } - void print(std::ostream &os); - -private: - std::string name_; - builder &builder_; - functions_list_t functions_; - symbols_map_t symbols_; - std::vector allocs_; - std::map globals_; - std::map metadatas_; -}; - -} -} - -#endif diff --git a/include/triton/ir/print.h b/include/triton/ir/print.h deleted file mode 100644 index 6dbf2fe02..000000000 --- a/include/triton/ir/print.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _TRITON_IR_PRINT_H_ -#define _TRITON_IR_PRINT_H_ - -#include "builder.h" - -namespace triton{ -namespace ir{ - -class module; -class function; -class basic_block; -class instruction; - -void print(module &mod, std::ostream& os); -void print(function &func, std::ostream& os); -void print(basic_block &bb, std::ostream& os); -void print(instruction &instr, std::ostream& os); - -} -} - -#endif diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h deleted file mode 100644 index b1ef1ad22..000000000 --- a/include/triton/ir/type.h +++ /dev/null @@ -1,239 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_TYPE_H_ -#define _TRITON_IR_TYPE_H_ - -#include -#include -#include -#include - -namespace triton{ -namespace ir{ - -class context; -class value; -class integer_type; -class constant_int; - -/* Type */ -class type { -public: - typedef std::vector block_shapes_t; - -protected: - typedef std::vector contained_tys_vec_t; - typedef contained_tys_vec_t::iterator ty_iterator; - typedef contained_tys_vec_t::const_iterator const_ty_iterator; - -public: - enum id_t { - // primitive types - VoidTyID = 0, ///< type with no size - FP8TyID, ///< 8-bit floating point type (3 bits mantissa) - FP16TyID, ///< 16-bit floating point type (10 bits mantissa) - BF16TyID, ///< 16-bit floating point type (7 bits mantissa) - FP32TyID, ///< 32-bit floating point type - FP64TyID, ///< 64-bit floating point type - LabelTyID, ///< Labels - MetadataTyID, ///< Metadata - TokenTyID, ///< Token - // derived types - IntegerTyID, ///< Arbitrary bit width integers - FunctionTyID, ///< Functions - PointerTyID, ///< Pointers - StructTyID, ///< Struct - BlockTyID, ///< Block - }; - -public: - //constructors - type(context &ctx, id_t id) : ctx_(ctx), id_(id) { } - - //destructor - virtual ~type(){} - - // accessors - context &get_context() const { return ctx_; } - id_t get_type_id() const { return id_; } - // type attributes - unsigned get_fp_mantissa_width() const; - unsigned get_integer_bitwidth() const; - unsigned get_tile_bitwidth() const; - unsigned get_primitive_size_in_bits() const; - type *get_scalar_ty() const; - block_shapes_t get_block_shapes() const; - const size_t get_tile_rank() const; - const size_t get_tile_ranks1() const; - unsigned get_tile_num_elements() const; - type *get_tile_element_ty() const; - unsigned get_pointer_address_space() const; - type *get_pointer_element_ty() const; - - // primitive predicates - bool is_void_ty() const { return id_ == VoidTyID; } - bool is_fp8_ty() const { return id_ == FP8TyID; } - bool is_fp16_ty() const { return id_ == FP16TyID; } - bool is_bf16_ty() const { return id_ == BF16TyID; } - bool is_fp32_ty() const { return id_ == FP32TyID; } - bool is_fp64_ty() const { return id_ == FP64TyID; } - bool is_label_ty() const { return id_ == LabelTyID;} - bool is_metadata_ty() const { return id_ == MetadataTyID; } - bool is_token_ty() const { return id_ == TokenTyID; } - bool is_integer_ty() const { return id_ == IntegerTyID; } - bool is_bool_ty() const { return is_integer_ty(1); } - bool is_pointer_ty() const { return id_ == PointerTyID; } - bool is_block_ty() const { return id_ == BlockTyID; } - - // Composite predicates - bool is_int_or_tileint_ty(); - bool is_integer_ty(unsigned width) const; - bool is_floating_point_ty() const; - bool is_sized() const ; - - // Factory methods - // primitive types - static type *get_void_ty(context &ctx); - static type *get_label_ty(context &ctx); - // half - static type *get_fp8_ty(context &ctx); - static type *get_fp16_ty(context &ctx); - static type *get_bf16_ty(context &ctx); - static type *get_fp32_ty(context &ctx); - static type *get_fp64_ty(context &ctx); - // integer types - static integer_type *get_int1_ty(context &ctx); - static integer_type *get_int8_ty(context &ctx); - static integer_type *get_int16_ty(context &ctx); - static integer_type *get_int32_ty(context &ctx); - static integer_type *get_int64_ty(context &ctx); - static integer_type *get_int128_ty(context &ctx); - - // repr - std::string tile_repr() const { - std::string res = get_tile_element_ty()->repr(); - auto shapes = get_block_shapes(); - res += "<"; - for(size_t i = 0; i < shapes.size(); i++){ - if(i > 0) - res += ", "; - res += std::to_string(shapes[i]); - } - res+= ">"; - return res; - } - - std::string repr() const { - switch(id_) { - case VoidTyID: return "void"; - case FP8TyID: return "fp8"; - case FP16TyID: return "f16"; - case FP32TyID: return "f32"; - case FP64TyID: return "f64"; - case BF16TyID: return "bf16"; - case LabelTyID: return "label"; - case MetadataTyID: return "md"; - case TokenTyID: return "tok"; - case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth()); - case FunctionTyID: return "fn"; - case PointerTyID: return get_pointer_element_ty()->repr() + "*"; - case StructTyID: return "struct"; - case BlockTyID: return tile_repr(); - default: break; - } - throw std::logic_error("unknown type id '" + std::to_string(id_) + "'"); - }; - -private: - context &ctx_; - id_t id_; - -protected: - contained_tys_vec_t contained_tys_; -}; - -class integer_type: public type { - friend class context_impl; - -private: - // constructors - integer_type(context &ctx, unsigned bitwidth) - : type(ctx, IntegerTyID), bitwidth_(bitwidth) {} - -public: - // accessors - unsigned get_bitwidth() const { return bitwidth_; } - - // factory methods - static integer_type* get(context &ctx, unsigned width); - -private: - unsigned bitwidth_; -}; - -class composite_type: public type{ -protected: - using type::type; - -public: - bool index_valid(value *idx) const; - type* get_type_at_index(value *idx) const; -}; - -class block_type: public composite_type { -private: - block_type(type *ty, const block_shapes_t &shapes); - static bool is_valid_elt_ty(type *ty); - -public: - // accessors - const block_shapes_t& get_shapes() const { return shapes_; } - unsigned get_num_elements() const; - unsigned get_bitwidth() const; - - // factory methods - static block_type* get(type *ty, const block_shapes_t &shapes); - static block_type* get_same_shapes(type *ty, type *ref); - -private: - block_shapes_t shapes_; -}; - -class pointer_type: public type { -private: - pointer_type(type *ty, unsigned address_space); - static bool is_valid_elt_ty(type *ty); - -public: - // accessors - unsigned get_address_space() const { return address_space_; } - type *get_element_ty() const { return contained_tys_[0]; } - // factory methods - static pointer_type* get(type *ty, unsigned address_space); - -private: - unsigned address_space_; -}; - -class function_type: public type { -private: - function_type(type *ret_ty, const std::vector ¶m_tys); - -public: - // accessors - unsigned get_num_params() const { return contained_tys_.size() - 1; } - const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; } - const_ty_iterator params_end() const { return contained_tys_.end(); } - ty_iterator params_begin() { return contained_tys_.begin() + 1; } - ty_iterator params_end() { return contained_tys_.end(); } - type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); } - type* get_return_ty() const { return contained_tys_.at(0); } - // factory methods - static function_type* get(type *ret_ty, const std::vector& param_tys); -}; - - -} -} - -#endif diff --git a/include/triton/ir/utils.h b/include/triton/ir/utils.h deleted file mode 100644 index 893edd122..000000000 --- a/include/triton/ir/utils.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_CFG_H_ -#define _TRITON_IR_CFG_H_ - -#include -#include - -namespace triton{ -namespace ir{ - -class module; -class function; -class basic_block; -class instruction; -class value; - -class cfg { -public: - static std::vector post_order(function* fn); - static std::vector reverse_post_order(function* fn); -}; - -void for_each_instruction(ir::module& mod, const std::function &fn); -void for_each_value(ir::module& mod, const std::function &fn); - -} -} - -#endif diff --git a/include/triton/ir/value.h b/include/triton/ir/value.h deleted file mode 100644 index 7a132d5e2..000000000 --- a/include/triton/ir/value.h +++ /dev/null @@ -1,95 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_VALUE_H_ -#define _TRITON_IR_VALUE_H_ - -#include -#include -#include - -namespace triton{ -namespace ir{ - -class type; -class use; -class user; -class visitor; - -//===----------------------------------------------------------------------===// -// value class -//===----------------------------------------------------------------------===// - -class value { -public: - typedef std::set users_t; - -public: - // constructor - value(type *ty, const std::string &name = ""); - virtual ~value(){ } - // uses - void add_use(user* arg); - users_t::iterator erase_use(user* arg); - const std::set &get_users() { return users_; } - void replace_all_uses_with(value *target); - // name - void set_name(const std::string &name); - const std::string &get_name() const { return name_; } - bool has_name() const { return !name_.empty(); } - type* get_type() const { return ty_; } - // visitor - virtual void accept(visitor *v) = 0; - -private: - std::string name_; - -protected: - type *ty_; - users_t users_; -}; - -//===----------------------------------------------------------------------===// -// user class -//===----------------------------------------------------------------------===// - -class user: public value{ -public: - typedef std::vector ops_t; - typedef ops_t::iterator op_iterator; - typedef ops_t::const_iterator const_op_iterator; - -protected: - void resize_ops(unsigned num_ops) { ops_.resize(num_ops + num_hidden_); num_ops_ = num_ops; } - void resize_hidden(unsigned num_hidden) { ops_.resize(num_ops_ + num_hidden); num_hidden_ = num_hidden; } - -public: - // Constructor - user(type *ty, unsigned num_ops, const std::string &name = "") - : value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){ - } - virtual ~user() { } - - // Operands - const ops_t& ops() { return ops_; } - const ops_t& ops() const { return ops_; } - op_iterator op_begin() { return ops_.begin(); } - op_iterator op_end() { return ops_.end(); } - void set_operand(unsigned i, value *x); - value *get_operand(unsigned i) const; - unsigned get_num_operands() const ; - unsigned get_num_hidden() const; - - // Utils - value::users_t::iterator replace_uses_of_with(value *before, value *after); - - -private: - ops_t ops_; - unsigned num_ops_; - unsigned num_hidden_; -}; - -} -} - -#endif diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h deleted file mode 100644 index 4979b0b52..000000000 --- a/include/triton/ir/visitor.h +++ /dev/null @@ -1,170 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_VISITOR_H_ -#define _TRITON_IR_VISITOR_H_ - - -namespace triton{ -namespace ir{ - -class value; - -class instruction; - -class phi_node; -class binary_operator; -class getelementptr_inst; - -class icmp_inst; -class fcmp_inst; -class cast_inst; -class trunc_inst; -class z_ext_inst; -class s_ext_inst; -class fp_trunc_inst; -class fp_ext_inst; -class ui_to_fp_inst; -class si_to_fp_inst; -class fp_to_ui_inst; -class fp_to_si_inst; -class ptr_to_int_inst; -class int_to_ptr_inst; -class bit_cast_inst; -class addr_space_cast_inst; - -class return_inst; -class cond_branch_inst; -class uncond_branch_inst; - - -class unmasked_load_inst; -class masked_load_inst; -class unmasked_store_inst; -class masked_store_inst; - -class retile_inst; -class reshape_inst; -class splat_inst; -class cat_inst; -class broadcast_inst; -class downcast_inst; - -class umulhi_inst; -class exp_inst; -class cos_inst; -class sin_inst; -class log_inst; - -class get_program_id_inst; -class get_num_programs_inst; -class atomic_inst; -class atomic_cas_inst; -class atomic_rmw_inst; -class dot_inst; -class trans_inst; -class sqrt_inst; -class reduce_inst; -class select_inst; - -class cvt_layout_inst; -class copy_to_shared_inst; -class copy_from_shared_inst; -class masked_load_async_inst; -class barrier_inst; -class async_wait_inst; -class make_range_dyn; -class make_range; -class prefetch_s_inst; - -class make_range_sta; -class undef_value; -class constant_int; -class constant_fp; -class global_value; -class global_object; -class alloc_const; - -class constant_fp; -class undef_value; -class constant_int; -class constant_fp; -class global_value; -class global_object; -class alloc_const; - -class function; - -class basic_block; - -class argument; - -class visitor { -public: - virtual ~visitor() {} - - virtual void visit_value(ir::value*); - - virtual void visit_basic_block(basic_block*) = 0; - virtual void visit_argument(argument*) = 0; - virtual void visit_phi_node(phi_node*) = 0; - virtual void visit_binary_operator(binary_operator*) = 0; - virtual void visit_getelementptr_inst(getelementptr_inst*) = 0; - - virtual void visit_icmp_inst(icmp_inst*) = 0; - virtual void visit_fcmp_inst(fcmp_inst*) = 0; - virtual void visit_cast_inst(cast_inst*) = 0; - - virtual void visit_return_inst(return_inst*) = 0; - virtual void visit_cond_branch_inst(cond_branch_inst*) = 0; - virtual void visit_uncond_branch_inst(uncond_branch_inst*) = 0; - - - virtual void visit_unmasked_load_inst(unmasked_load_inst*) = 0; - virtual void visit_masked_load_inst(masked_load_inst*) = 0; - virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0; - virtual void visit_masked_store_inst(masked_store_inst*) = 0; - - virtual void visit_umulhi_inst(umulhi_inst*) = 0; - virtual void visit_exp_inst(exp_inst*) = 0; - virtual void visit_cos_inst(cos_inst*) = 0; - virtual void visit_sin_inst(sin_inst*) = 0; - virtual void visit_log_inst(log_inst*) = 0; - - virtual void visit_reshape_inst(reshape_inst*) = 0; - virtual void visit_splat_inst(splat_inst*) = 0; - virtual void visit_cat_inst(cat_inst*) = 0; - virtual void visit_broadcast_inst(broadcast_inst*) = 0; - virtual void visit_downcast_inst(downcast_inst*) = 0; - - virtual void visit_get_program_id_inst(get_program_id_inst*) = 0; - virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0; - virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0; - virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0; - virtual void visit_dot_inst(dot_inst*) = 0; - virtual void visit_trans_inst(trans_inst*) = 0; - virtual void visit_sqrt_inst(sqrt_inst*) = 0; - virtual void visit_reduce_inst(reduce_inst*) = 0; - virtual void visit_select_inst(select_inst*) = 0; - - virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0; - virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0; - virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; - - - virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0; - virtual void visit_barrier_inst(barrier_inst*) = 0; - virtual void visit_async_wait_inst(async_wait_inst*) = 0; - virtual void visit_make_range(make_range*) = 0; - virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0; - virtual void visit_function(function*) = 0; - - virtual void visit_undef_value(undef_value*) = 0; - virtual void visit_constant_int(constant_int*) = 0; - virtual void visit_constant_fp(constant_fp*) = 0; - virtual void visit_alloc_const(alloc_const*) = 0; -}; - -} -} - -#endif diff --git a/lib/ir/CMakeLists.txt b/lib/ir/CMakeLists.txt new file mode 100644 index 000000000..71d7d6df1 --- /dev/null +++ b/lib/ir/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(TRITONIR + Dialect.cpp + Ops.cpp + Types.cpp + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithmetic + MLIRControlFlow + MLIRFunc + MLIRTensor +) diff --git a/lib/ir/Dialect.cpp b/lib/ir/Dialect.cpp new file mode 100644 index 000000000..7df9494cb --- /dev/null +++ b/lib/ir/Dialect.cpp @@ -0,0 +1,71 @@ +#include "triton/Dialect.h" +#include "triton/Types.h" + +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/IR/DialectImplementation.h" + + +#include "triton/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Ops.cpp.inc" + >(); + + // We can also add interface here. +} + +//===----------------------------------------------------------------------===// +// Type Parsing +//===----------------------------------------------------------------------===// +// pointer-type ::= `!triton.ptr<` element-type ` >` +static Type parsePointerType(TritonDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) + return Type(); + + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType); +} + +// trtion-type ::= pointer-type +Type TritonDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return Type(); + + if (keyword == "ptr") + return parsePointerType(*this, parser); + + parser.emitError(parser.getNameLoc(), "unknown Triton type: ") << keyword; + return Type(); +} + +//===----------------------------------------------------------------------===// +// Type Printing +//===----------------------------------------------------------------------===// +static void print(PointerType type, DialectAsmPrinter &os) { + os << "ptr<" << type.getPointeeType() << ">"; +} + +void TritonDialect::printType(Type type, DialectAsmPrinter &os) const { + TypeSwitch(type) + .Case( [&](auto type) { print(type, os); }) + .Default([](Type) { llvm_unreachable("unhandled Triton type"); }); +} diff --git a/lib/ir/Ops.cpp b/lib/ir/Ops.cpp new file mode 100644 index 000000000..de5c7b3eb --- /dev/null +++ b/lib/ir/Ops.cpp @@ -0,0 +1,63 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "triton/Dialect.h" +#include "triton/Types.h" + +#define GET_OP_CLASSES +#include "triton/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { + +//-- StoreOp -- +// Default mask +void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) { + TensorType ptrType = ptr.getType().dyn_cast(); + auto shape = ptrType.getShape(); + ::mlir::Value mask = builder.create( + ptr.getLoc(), + RankedTensorType::get(shape, builder.getI1Type()), + DenseIntElementsAttr::get( + RankedTensorType::get(shape, builder.getI1Type()), true + ) + ); + state.addOperands(ptr); + state.addOperands(value); + state.addOperands(mask); +} + +//-- LoadOp -- +void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr) { + TensorType ptrType = ptr.getType().dyn_cast(); + Type elementType = ptrType.getElementType().dyn_cast().getPointeeType(); + auto shape = ptrType.getShape(); + // mask + ::mlir::Value mask = builder.create( + ptr.getLoc(), + RankedTensorType::get(shape, builder.getI1Type()), + DenseIntElementsAttr::get( + RankedTensorType::get(shape, builder.getI1Type()), true + ) + ); + // other + Type resultType = RankedTensorType::get(shape, elementType); + ::mlir::Value other = builder.create( + ptr.getLoc(), + resultType, + DenseElementsAttr::get( + resultType, builder.getZeroAttr(elementType) + ) + ); + state.addOperands(ptr); + state.addOperands(mask); + state.addOperands(other); + state.addTypes({resultType}); +} + +} // namespace triton +} // namespace mlir diff --git a/lib/ir/Types.cpp b/lib/ir/Types.cpp new file mode 100644 index 000000000..89c0e6af8 --- /dev/null +++ b/lib/ir/Types.cpp @@ -0,0 +1,55 @@ +#include "triton/Dialect.h" +#include "triton/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +// F8 & BF8 +Float8Type Float8Type::get(MLIRContext *context) { + return Base::get(context); +} + +BFloat8Type BFloat8Type::get(MLIRContext *context) { + return Base::get(context); +} + +//===----------------------------------------------------------------------===// +// PointerType +//===----------------------------------------------------------------------===// +struct triton::detail::PointerTypeStorage : public TypeStorage { + using KeyTy = std::pair; + + static PointerTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) PointerTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(pointeeType, addressSpace); + } + + PointerTypeStorage(const KeyTy &key) + : pointeeType(key.first), addressSpace(key.second) {} + + Type pointeeType; + unsigned addressSpace; +}; + +PointerType PointerType::get(Type pointeeType) { + return Base::get(pointeeType.getContext(), pointeeType, 0); +} + +PointerType PointerType::get(Type pointeeType, unsigned addressSpace) { + return Base::get(pointeeType.getContext(), pointeeType, addressSpace); +} + +Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } + +unsigned PointerType::getAddressSpace() const { return getImpl()->addressSpace; } + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes(); +} diff --git a/lib/ir/basic_block.cc b/lib/ir/basic_block.cc deleted file mode 100644 index 0654156a3..000000000 --- a/lib/ir/basic_block.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" -#include "triton/ir/function.h" - -namespace triton { -namespace ir { - -class phi_node; - - -basic_block::basic_block(context &ctx, const std::string &name, function *parent): - value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) { - if(parent_) - parent_->insert_block(this); -} - -basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){ - return new basic_block(ctx, name, parent); -} - -void basic_block::add_predecessor(basic_block *pred) { - preds_.push_back(pred); - if(pred) - pred->succs_.push_back(this); -} - - - -basic_block::iterator basic_block::get_first_non_phi(){ - auto it = begin(); - for(; it != end(); it++) - if(!dynamic_cast(*it)){ - return it; - } - return it; -} - -} - -} diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc deleted file mode 100644 index 9b8a2a45e..000000000 --- a/lib/ir/builder.cc +++ /dev/null @@ -1,434 +0,0 @@ -#include -#include -#include -#include "triton/ir/basic_block.h" -#include "triton/ir/builder.h" -#include "triton/ir/constant.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" - -namespace triton{ -namespace ir{ - -builder::builder(context &ctx): - ctx_(ctx), block_(nullptr) {} - -//===----------------------------------------------------------------------===// -// utilities -//===----------------------------------------------------------------------===// -void builder::set_insert_point(basic_block::iterator it){ - block_ = (*it)->get_parent(); - insert_point_ = it; -} - -void builder::set_insert_point(instruction* i){ - block_ = i->get_parent(); - auto it = std::find(block_->begin(), block_->end(), i); - set_insert_point(it); -} - - -void builder::set_insert_point_after(instruction* i){ - block_ = i->get_parent(); - auto it = std::find(block_->begin(), block_->end(), i); - set_insert_point(++it); -} - - -void builder::set_insert_point(basic_block *block){ - block_ = block; - insert_point_ = block->end(); -} - - -//===----------------------------------------------------------------------===// -// convenience functions -//===----------------------------------------------------------------------===// - -value *builder::get_int1(bool val) -{ return constant_int::get(type::get_int1_ty(ctx_), val); } - -value *builder::get_int32(uint32_t val) -{ return constant_int::get(type::get_int32_ty(ctx_), val);} - -value *builder::get_int64(uint64_t val) -{ return constant_int::get(type::get_int64_ty(ctx_), val);} - -value *builder::get_float16(float val) -{ return constant_fp::get(type::get_fp16_ty(ctx_), val); } - -value *builder::get_float32(float val) -{ return constant_fp::get(type::get_fp32_ty(ctx_), val); } - -value *builder::get_range(int32_t _lo, int32_t _hi) { - constant_int* lo = static_cast(get_int32(_lo)); - constant_int* hi = static_cast(get_int32(_hi)); - return insert(make_range::create(lo, hi)); -} - -type *builder::get_void_ty() -{ return type::get_void_ty(ctx_); } - -type *builder::get_int1_ty() -{ return type::get_int1_ty(ctx_); } - -type *builder::get_int8_ty() -{ return type::get_int8_ty(ctx_); } - -type *builder::get_int16_ty() -{ return type::get_int16_ty(ctx_); } - -type *builder::get_int32_ty() -{ return type::get_int32_ty(ctx_); } - -type *builder::get_int64_ty() -{ return type::get_int64_ty(ctx_); } - -type *builder::get_fp8_ty() -{ return type::get_fp8_ty(ctx_); } - -type *builder::get_half_ty() -{ return type::get_fp16_ty(ctx_); } - -type *builder::get_bf16_ty() -{ return type::get_bf16_ty(ctx_); } - -type *builder::get_float_ty() -{ return type::get_fp32_ty(ctx_); } - -type *builder::get_double_ty() -{ return type::get_fp64_ty(ctx_); } - - -//===----------------------------------------------------------------------===// -// terminator instructions -//===----------------------------------------------------------------------===// - -value* builder::create_br(basic_block *dest){ - dest->add_predecessor(block_); - return insert(branch_inst::create(dest)); -} - -value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){ - if_dest->add_predecessor(block_); - else_dest->add_predecessor(block_); - return insert(branch_inst::create(cond, if_dest, else_dest)); -} - -value *builder::create_ret_void() { - return insert(return_inst::create(ctx_)); -} - -//===----------------------------------------------------------------------===// -// cast instructions -//===----------------------------------------------------------------------===// -#define DEFINE_CAST_INSTR(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *src, type *dst_ty){\ - return create_cast(OPCODE, src, dst_ty);\ - } - -DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast) -DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr) -DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) -DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) -DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) -DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI) -DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI) -DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt) -DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc) - -value* builder::create_cast(cast_op_t op, value *v, type *dst_ty){ - return insert(cast_inst::create(op, v, dst_ty)); -} - -value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed){ - return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed)); -} - -//===----------------------------------------------------------------------===// -// phi instructions -//===----------------------------------------------------------------------===// - -phi_node* builder::create_phi(type *ty, unsigned num_reserved){ - return insert(phi_node::create(ty, num_reserved)); -} - -//===----------------------------------------------------------------------===// -// binary float instructions -//===----------------------------------------------------------------------===// - -#define DEFINE_BINARY_FLOAT(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\ - return insert(binary_operator::create(OPCODE, lhs, rhs));\ - } - -// Binary -DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul) -DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv) -DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem) -DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd) -DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub) - - -//===----------------------------------------------------------------------===// -// binary int instructions -//===----------------------------------------------------------------------===// - - -value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs, - value *rhs, - bool has_nuw, bool has_nsw) { - binary_operator* result = insert(binary_operator::create(op, lhs, rhs)); - if (has_nuw) result->set_has_no_unsigned_wrap(); - if (has_nsw) result->set_has_no_signed_wrap(); - return result; -} - -#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\ - value* builder::create_ ## SUFFIX(value *lhs, value *rhs, bool has_nuw, bool has_nsw){\ - return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, has_nuw, has_nsw);\ - }\ - -#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\ - return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, false, false);\ - } - - - -// Binary -DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul) -DEFINE_NOWRAP_BINARY(add, binary_op_t::Add) -DEFINE_NOWRAP_BINARY(sub, binary_op_t::Sub) -DEFINE_NOWRAP_BINARY(shl, binary_op_t::Shl) -DEFINE_NOWRAP_BINARY(ashr, binary_op_t::AShr) -DEFINE_NOWRAP_BINARY(lshr, binary_op_t::LShr) -DEFINE_BINARY_INT(sdiv, binary_op_t::SDiv) -DEFINE_BINARY_INT(udiv, binary_op_t::UDiv) -DEFINE_BINARY_INT(srem, binary_op_t::SRem) -DEFINE_BINARY_INT(urem, binary_op_t::URem) -DEFINE_BINARY_INT(and, binary_op_t::And) -DEFINE_BINARY_INT(or, binary_op_t::Or) -DEFINE_BINARY_INT(xor, binary_op_t::Xor) - - -//===----------------------------------------------------------------------===// -// getelementptr instructions -//===----------------------------------------------------------------------===// - -value* builder::create_gep(value *ptr, const std::vector& idx_list){ - return insert(getelementptr_inst::create(ptr, idx_list)); -} - -//===----------------------------------------------------------------------===// -// icmp instructions -//===----------------------------------------------------------------------===// - -value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs){ - return insert(icmp_inst::create(pred, lhs, rhs)); -} - -#define DEFINE_ICMP_INSTR(SUFFIX, OPCODE)\ - value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs){\ - return create_icmp(OPCODE, lhs, rhs);\ - } - -// Signed -DEFINE_ICMP_INSTR(SLE, cmp_pred_t::ICMP_SLE) -DEFINE_ICMP_INSTR(SLT, cmp_pred_t::ICMP_SLT) -DEFINE_ICMP_INSTR(SGE, cmp_pred_t::ICMP_SGE) -DEFINE_ICMP_INSTR(SGT, cmp_pred_t::ICMP_SGT) -// Unsigned -DEFINE_ICMP_INSTR(ULE, cmp_pred_t::ICMP_ULE) -DEFINE_ICMP_INSTR(ULT, cmp_pred_t::ICMP_ULT) -DEFINE_ICMP_INSTR(UGE, cmp_pred_t::ICMP_UGE) -DEFINE_ICMP_INSTR(UGT, cmp_pred_t::ICMP_UGT) -// General -DEFINE_ICMP_INSTR(EQ, cmp_pred_t::ICMP_EQ) -DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE) - - -//===----------------------------------------------------------------------===// -// fcmp instructions -//===----------------------------------------------------------------------===// - -value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs){ - return insert(fcmp_inst::create(pred, lhs, rhs)); -} - -#define DEFINE_FCMP_INSTR(SUFFIX, OPCODE)\ - value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs){\ - return create_fcmp(OPCODE, lhs, rhs);\ - } - -// Ordered -DEFINE_FCMP_INSTR(OLE, cmp_pred_t::FCMP_OLE) -DEFINE_FCMP_INSTR(OLT, cmp_pred_t::FCMP_OLT) -DEFINE_FCMP_INSTR(OGE, cmp_pred_t::FCMP_OGE) -DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT) -DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ) -DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE) - -DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE) -DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT) -DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE) -DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT) -DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ) -DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE) - - -//===----------------------------------------------------------------------===// -// load/store instructions -//===----------------------------------------------------------------------===// - -value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ - return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile)); -} - -value *builder::create_store(value *ptr, value *val){ - return insert(unmasked_store_inst::create(ptr, val)); -} - -value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ - return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile)); -} - -value *builder::create_masked_store(value *ptr, value *val, value *mask){ - return insert(masked_store_inst::create(ptr, val, mask)); -} - -//===----------------------------------------------------------------------===// -// block instructions -//===----------------------------------------------------------------------===// - -value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) { - return insert(reshape_inst::create(arg, shapes)); -} - -value *builder::create_cat(value *lhs, value *rhs) { - return insert(cat_inst::create(lhs, rhs)); -} - -value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) { - return insert(splat_inst::create(arg, shapes)); -} - -value *builder::create_broadcast(value *arg, const type::block_shapes_t &shapes) { - return insert(broadcast_inst::create(arg, shapes)); -} - -value *builder::create_downcast(value *arg) { - return insert(downcast_inst::create(arg)); -} - -// - -value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ - return insert(atomic_rmw_inst::create(op, ptr, val, msk)); -} - -#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\ - return create_atomic_rmw(OPCODE, ptr, val, mask);\ - } - -DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max) -DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax) -DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min) -DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin) -DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd) -DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add) -DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And) -DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or) -DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor) -DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg) - -//===----------------------------------------------------------------------===// -// built-in instructions -//===----------------------------------------------------------------------===// - -value *builder::create_get_program_id(unsigned axis) { - return insert(get_program_id_inst::create(ctx_, axis)); -} - -value *builder::create_get_num_programs(unsigned axis) { - return insert(get_num_programs_inst::create(ctx_, axis)); -} - -value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ - return insert(atomic_cas_inst::create(ptr, cmp, val)); -} - - -value *builder::create_exp(value *arg){ - return insert(exp_inst::create(arg)); -} - -value *builder::create_cos(value *arg){ - return insert(cos_inst::create(arg)); -} - -value *builder::create_sin(value *arg){ - return insert(sin_inst::create(arg)); -} - -value *builder::create_log(value *arg){ - return insert(log_inst::create(arg)); -} - -value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) { - return insert(dot_inst::create_nn(A, B, C, allow_tf32)); -} - -value *builder::create_trans(value *A, const std::vector& perm) { - return insert(trans_inst::create(A, perm)); -} - -value *builder::create_sqrt(value *A) { - return insert(sqrt_inst::create(A)); -} - -value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis) { - return insert(reduce_inst::create(A, op, axis)); -} - -value *builder::create_select(value *pred, value *if_value, value *else_value){ - return insert(select_inst::create(pred, if_value, else_value)); -} - -//===----------------------------------------------------------------------===// -// intrinsic instructions -//===----------------------------------------------------------------------===// - -value *builder::create_umulhi(value *lhs, value *rhs) { - return insert(umulhi_inst::create(lhs, rhs)); -} - -value *builder::create_copy_to_shared(value *arg) { - return insert(copy_to_shared_inst::create(arg)); -} - - -value *builder::create_copy_from_shared(value *arg) { - return insert(copy_from_shared_inst::create(arg)); -} - -value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) { - return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction)); -} - -value *builder::create_barrier(const std::string &name) { - return insert(barrier_inst::create(ctx_)); -} - -value *builder::create_async_wait(int N) { - return insert(async_wait_inst::create(ctx_, N)); -} - -value *builder::create_prefetch_s(value *arg, int inc) { - return insert(prefetch_s_inst::create(ctx_, arg, inc)); -} - - -} -} diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc deleted file mode 100644 index ab1f6f497..000000000 --- a/lib/ir/constant.cc +++ /dev/null @@ -1,118 +0,0 @@ -#include -#include -#include "triton/ir/constant.h" -#include "triton/ir/type.h" -#include "triton/ir/context.h" -#include "triton/ir/context_impl.h" - -namespace triton{ -namespace ir{ - - -// constant - -constant *constant::get_null_value(type *ty) { - context &ctx = ty->get_context(); - switch (ty->get_scalar_ty()->get_type_id()) { - case type::IntegerTyID: - return constant_int::get(ty, 0); - case type::FP16TyID: - return constant_fp::get(type::get_fp16_ty(ctx), 0); - case type::FP32TyID: - return constant_fp::get(type::get_fp32_ty(ctx), 0); - case type::FP64TyID: - return constant_fp::get(type::get_fp64_ty(ctx), 0); - default: - throw std::runtime_error("Cannot create a null constant of that type!"); - } -} - -// FIXME - -constant *constant::get_all_ones_value(type *ty) { - if(ty->is_integer_ty()) - return constant_int::get(ty, 0xFFFFFFFFFFFFFFFF); - if(ty->is_floating_point_ty()) - return constant_fp::get(ty, 0xFFFFFFFFFFFFFFFF); - throw std::runtime_error("Cannot create all ones value for that type!"); -} - -// constant_int -// FIXME use something like APInt - -constant_int::constant_int(type *ty, uint64_t value) - : constant(ty, 0), value_(value){ } - -constant_int *constant_int::get(type *ty, uint64_t value) { - if (!ty->is_integer_ty()) - throw std::runtime_error("Cannot create constant_int with non integer ty"); - context_impl *impl = ty->get_context().p_impl.get(); - std::unique_ptr &cst = impl->int_constants_[std::make_pair(ty, value)]; - if(!cst) - cst.reset(new constant_int(ty, value)); - return cst.get(); -} - - -// constant_fp -// FIXME use something like APFloat - -constant_fp::constant_fp(type *ty, double value) - : constant(ty, 0), value_(value){ } - -constant *constant_fp::get_negative_zero(type *ty){ - double neg_zero = 0; - return get(ty, neg_zero); -} - -constant *constant_fp::get_zero_value_for_negation(type *ty) { - if(ty->get_scalar_ty()->is_floating_point_ty()) - return constant_fp::get(ty, 0); - return constant::get_null_value(ty); -} - -constant *constant_fp::get(type *ty, double v){ - context_impl *impl = ty->get_context().p_impl.get(); - std::unique_ptr &result = impl->fp_constants_[std::make_pair(ty, v)]; - if(!result) - result.reset(new constant_fp(ty, v)); - return result.get(); -} - - -// undef value -undef_value::undef_value(type *ty) - : constant(ty, 0) { } - -undef_value *undef_value::get(type *ty) { - context_impl *impl = ty->get_context().p_impl.get(); - std::unique_ptr &result = impl->uv_constants_[ty]; - if(!result) - result.reset(new undef_value(ty)); - return result.get(); -} - -/* global value */ -global_value::global_value(type *ty, unsigned num_ops, - linkage_types_t linkage, - const std::string &name, unsigned addr_space) - : constant(pointer_type::get(ty, addr_space), num_ops, name), - linkage_(linkage) { } - - -/* global object */ -global_object::global_object(type *ty, unsigned num_ops, - linkage_types_t linkage, - const std::string &name, unsigned addr_space) - : global_value(ty, num_ops, linkage, name, addr_space) { } - - -/* alloc const */ -alloc_const::alloc_const(type *ty, constant_int *size, const std::string &name) - : global_object(ty, 1, global_value::external, name, 4) { - set_operand(0, size); -} - - -} -} diff --git a/lib/ir/context.cc b/lib/ir/context.cc deleted file mode 100644 index 0fc65ddc2..000000000 --- a/lib/ir/context.cc +++ /dev/null @@ -1,40 +0,0 @@ -#include "triton/ir/context_impl.h" -#include "triton/ir/context.h" -#include "triton/ir/type.h" - -namespace triton{ -namespace ir{ - -//===----------------------------------------------------------------------===// -// context implementation -//===----------------------------------------------------------------------===// - -context_impl::context_impl(context &ctx) - : void_ty(ctx, type::VoidTyID), - label_ty(ctx, type::LabelTyID), - // floating point - fp8_ty(ctx, type::FP8TyID), - fp16_ty(ctx, type::FP16TyID), - bf16_ty(ctx, type::BF16TyID), - fp32_ty(ctx, type::FP32TyID), - fp64_ty(ctx, type::FP64TyID), - // integers - int1_ty(ctx, 1), - int8_ty(ctx, 8), - int16_ty(ctx, 16), - int32_ty(ctx, 32), - int64_ty(ctx, 64), - int128_ty(ctx, 128) {} - -//===----------------------------------------------------------------------===// -// context -//===----------------------------------------------------------------------===// - -context::context(): - p_impl(std::make_shared(*this)) { - -} - - -} -} diff --git a/lib/ir/function.cc b/lib/ir/function.cc deleted file mode 100644 index 84d52df72..000000000 --- a/lib/ir/function.cc +++ /dev/null @@ -1,66 +0,0 @@ -#include -#include "triton/ir/function.h" -#include "triton/ir/type.h" -#include "triton/ir/module.h" - -namespace triton{ -namespace ir{ - - -/* Argument */ - -argument::argument(type *ty, const std::string &name, function *parent, unsigned arg_no) - : value(ty, name), parent_(parent), arg_no_(arg_no) { } - -argument *argument::create(type *ty, const std::string &name, - function *parent, unsigned arg_no) { - return new argument(ty, name, parent, arg_no); -} - -function* argument::get_parent() const { - return parent_; -} - -unsigned argument::get_arg_no() const { - return arg_no_; -} - -void argument::accept(visitor *v) { - v->visit_argument(this); -} - - -/* function */ -function::function(function_type *ty, linkage_types_t linkage, - const std::string &name, module *parent) - : global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) { - unsigned num_params = fn_ty_->get_num_params(); - // skip if no parameter - if(num_params == 0) - return; - // create arguments - args_.resize(num_params); - for(unsigned i = 0; i < num_params; i++){ - type *param_ty = fn_ty_->get_param_ty(i); - args_[i] = argument::create(param_ty, "", this, i); - } - if(parent) - parent->push_function(this); -} - -/* basic block */ -void function::insert_block(basic_block *block, basic_block *next) { - auto it = std::find(blocks_.begin(), blocks_.end(), next); - blocks_.insert(it, block); -} - - -function *function::create(function_type *ty, linkage_types_t linkage, - const std::string &name, module *mod) { - return new function(ty, linkage, name, mod); -} - - -} -} - diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc deleted file mode 100644 index 39bd945bc..000000000 --- a/lib/ir/instructions.cc +++ /dev/null @@ -1,928 +0,0 @@ -#include -#include -#include "triton/ir/context.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/constant.h" -#include "triton/ir/type.h" - -namespace triton{ -namespace ir{ - -//===----------------------------------------------------------------------===// -// instruction classes -//===----------------------------------------------------------------------===// - -instruction::instruction(type *ty, value_id_t ity, unsigned num_ops, - const std::string &name, instruction *next) - : user(ty, num_ops, name), id_(ity) { - if(next){ - basic_block *block = next->get_parent(); - assert(block && "Next instruction is not in a basic block!"); - auto it = std::find(block->begin(), block->end(), next); - block->get_inst_list().insert(it, next); - } -} - -void instruction::erase_from_parent() { - parent_->erase(this); - for(ir::value* op: ops()) - op->erase_use(this); -} - -bool instruction::has_tile_result_or_op() { - bool result = get_type()->is_block_ty(); - for(unsigned i = 0; i < get_num_operands(); i++) - result |= get_operand(i)->get_type()->is_block_ty(); - return result; -} - -//===----------------------------------------------------------------------===// -// phi_node classes -//===----------------------------------------------------------------------===// - -phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next) - : instruction(ty, INST_PHI, 0, name, next) { - blocks_.reserve(num_reserved); -} - -value* phi_node::get_value_for_block(basic_block * block) { - auto it = std::find(blocks_.begin(), blocks_.end(), block); - size_t n = std::distance(blocks_.begin(), it); - return get_incoming_value(n); -} - -// Set incoming value -void phi_node::set_incoming_value(unsigned i, value *v){ - assert(v && "PHI node got a null value!"); - assert(get_type() == v->get_type() && - "All operands to PHI node must be the same type as the PHI node!"); - set_operand(i, v); -} - -// Set incoming block -void phi_node::set_incoming_block(unsigned i, basic_block *block){ - assert(block && "PHI node got a null basic block!"); - blocks_[i] = block; -} - -// Add incoming -void phi_node::add_incoming(value *v, basic_block *block){ - resize_ops(get_num_operands() + 1); - blocks_.resize(get_num_operands() + 1); - set_incoming_value(get_num_operands() - 1, v); - set_incoming_block(get_num_operands() - 1, block); -} - -// Factory methods -phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){ - return new phi_node(ty, num_reserved, name, next); -} - - -//===----------------------------------------------------------------------===// -// binary_operator classes -//===----------------------------------------------------------------------===// - -std::string binary_operator::repr_impl() const { - switch(op_) { - case Add : return "add"; - case FAdd : return "fadd"; - case Sub : return "sub"; - case FSub : return "fsub"; - case Mul : return "mul"; - case FMul : return "fmul"; - case UDiv : return "udiv"; - case SDiv : return "sdiv"; - case FDiv : return "fdiv"; - case URem : return "urem"; - case SRem : return "srem"; - case FRem : return "frem"; - case Shl : return "shl"; - case LShr : return "lshr"; - case AShr : return "ashr"; - case And : return "and"; - case Or : return "or"; - case Xor : return "xor"; - default: throw std::runtime_error("unknown binary operator"); - } -} - -bool binary_operator::is_int_div() const { - return op_ == binary_op_t::UDiv || op_ == binary_op_t::SDiv; -} - -bool binary_operator::is_int_rem() const { - return op_ == binary_op_t::URem || op_ == binary_op_t::SRem; -} - -bool binary_operator::is_shl() const { - return op_ == binary_op_t::Shl; -} - -bool binary_operator::is_shr() const { - return op_ == binary_op_t::LShr || op_ == binary_op_t::AShr; -} - -bool binary_operator::is_int_mult() const { - return op_ == binary_op_t::Mul; -} - -bool binary_operator::is_int_add_sub() const { - return op_ == binary_op_t::Add || op_ == binary_op_t::Sub; -} - - -binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next) - : instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){ - set_operand(0, lhs); - set_operand(1, rhs); -} - -binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){ - assert(lhs->get_type() == rhs->get_type() && - "Cannot create binary operator with two operands of differing type!"); - return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next); -} - -//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){ -// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty()); -// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()); -// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next); -//} - -//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){ -// assert(arg->get_type()->get_scalar_ty()->is_integer_ty()); -// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty()); -// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next); -//} - -//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){ -// assert(arg->get_type()->is_integer_ty()); -// constant *mask = constant::get_all_ones_value(arg->get_type()); -// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next); -//} - -//===----------------------------------------------------------------------===// -// cmp_inst classes -//===----------------------------------------------------------------------===// - - - -// cmp_inst -std::string cmp_inst::repr_impl() const { - switch (pred_) { - case FCMP_FALSE : return "false"; - case FCMP_OEQ : return "fcmp_oeq"; - case FCMP_OGT : return "fcmp_ogt"; - case FCMP_OGE : return "fcmp_oge"; - case FCMP_OLT : return "fcmp_olt"; - case FCMP_OLE : return "fcmp_ole"; - case FCMP_ONE : return "fcmp_one"; - case FCMP_ORD : return "fcmp_ord"; - case FCMP_UNO : return "fcmp_uno"; - case FCMP_UEQ : return "fcmp_ueq"; - case FCMP_UGT : return "fcmp_ugt"; - case FCMP_UGE : return "fcmp_uge"; - case FCMP_ULT : return "fcmp_ult"; - case FCMP_ULE : return "fcmp_ule"; - case FCMP_UNE : return "fcmp_une"; - case FCMP_TRUE : return "true"; - case ICMP_EQ : return "icmp_eq"; - case ICMP_NE : return "icmp_ne"; - case ICMP_UGT : return "icmp_ugt"; - case ICMP_UGE : return "icmp_uge"; - case ICMP_ULT : return "icmp_ult"; - case ICMP_ULE : return "icmp_ule"; - case ICMP_SGT : return "icmp_sgt"; - case ICMP_SGE : return "icmp_sge"; - case ICMP_SLT : return "icmp_slt"; - case ICMP_SLE : return "icmp_sle"; - default: throw std::runtime_error("unreachable"); - } -} - -cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next) - : instruction(ty, id, 2, name, next), pred_(pred) { - set_operand(0, lhs); - set_operand(1, rhs); -} - -type* cmp_inst::make_cmp_result_type(type *ty){ - type* int1_ty = type::get_int1_ty(ty->get_context()); - if (block_type* tile_ty = dynamic_cast(ty)) - return block_type::get_same_shapes(int1_ty, tile_ty); - return int1_ty; -} - - -bool cmp_inst::is_fp_predicate(cmp_pred_t pred) { - return pred >= FIRST_FCMP_PREDICATE && pred <= LAST_FCMP_PREDICATE; -} - -bool cmp_inst::is_int_predicate(cmp_pred_t pred) { - return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE; -} - - -// icmp_inst -icmp_inst::icmp_inst(type *ty, cmp_pred_t pred, - value *lhs, value *rhs, const std::string &name, instruction *next) - : cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next){ } - -icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ - assert(is_int_predicate(pred)); - assert(lhs->get_type() == rhs->get_type()); - type *res_ty = make_cmp_result_type(lhs->get_type()); - return new icmp_inst(res_ty, pred, lhs, rhs, name, next); -} - -// fcmp_inst -fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred, - value *lhs, value *rhs, const std::string &name, instruction *next) - : cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next){ } - -fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ - assert(is_fp_predicate(pred)); - type *res_ty = make_cmp_result_type(lhs->get_type()); - return new fcmp_inst(res_ty, pred, lhs, rhs, name, next); -} - -//===----------------------------------------------------------------------===// -// unary_inst classes -//===----------------------------------------------------------------------===// - -unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next) - : instruction(ty, id, 1, name, next) { - set_operand(0, v); -} - -//===----------------------------------------------------------------------===// -// cast_inst classes -//===----------------------------------------------------------------------===// - -std::string cast_inst::repr_impl() const { - switch (op_){ - case cast_op_t::Trunc: return "trunc"; - case cast_op_t::ZExt: return "zext"; - case cast_op_t::SExt: return "sext"; - case cast_op_t::FPTrunc: return "fp_trunc"; - case cast_op_t::FPExt: return "fp_ext"; - case cast_op_t::UIToFP: return "ui_to_fp"; - case cast_op_t::SIToFP: return "si_to_fp"; - case cast_op_t::FPToUI: return "fp_to_ui"; - case cast_op_t::FPToSI: return "fp_to_si"; - case cast_op_t::PtrToInt: return "ptr_to_int"; - case cast_op_t::IntToPtr: return "int_to_ptr"; - case cast_op_t::BitCast: return "bitcast"; - case cast_op_t::AddrSpaceCast: return "addr_space_cast"; - default: throw std::runtime_error("unreachable"); - } -} -// TODO -bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) { - assert(arg->get_type()->is_block_ty() == ty->is_block_ty()); - return true; -} - -cast_inst *cast_inst::create(cast_op_t op, value *arg, type *ty, const std::string &name, instruction *next){ - assert(is_valid(op, arg, ty) && "Invalid cast!"); - // Construct and return the appropriate CastInst subclass - switch (op) { - case cast_op_t::Trunc: return new trunc_inst (ty, arg, name, next); - case cast_op_t::ZExt: return new z_ext_inst (ty, arg, name, next); - case cast_op_t::SExt: return new s_ext_inst (ty, arg, name, next); - case cast_op_t::FPTrunc: return new fp_trunc_inst (ty, arg, name, next); - case cast_op_t::FPExt: return new fp_ext_inst (ty, arg, name, next); - case cast_op_t::UIToFP: return new ui_to_fp_inst (ty, arg, name, next); - case cast_op_t::SIToFP: return new si_to_fp_inst (ty, arg, name, next); - case cast_op_t::FPToUI: return new fp_to_ui_inst (ty, arg, name, next); - case cast_op_t::FPToSI: return new fp_to_si_inst (ty, arg, name, next); - case cast_op_t::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next); - case cast_op_t::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next); - case cast_op_t::BitCast: return new bit_cast_inst (ty, arg, name, next); - case cast_op_t::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next); - default: throw std::runtime_error("unreachable"); - } -} - -cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){ - type *arg_ty = arg->get_type(); - assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!"); - unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth(); - unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth(); - cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast : - (arg_bits > dst_bits ? cast_op_t::Trunc : - (is_signed ? cast_op_t::SExt : cast_op_t::ZExt))); - return create(op, arg, ty, name, next); -} - -//===----------------------------------------------------------------------===// -// terminator_inst classes -//===----------------------------------------------------------------------===// - - -// return_inst -return_inst::return_inst(context &ctx, value *ret_val, instruction *next) - : terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){ - if(ret_val) - set_operand(0, ret_val); -} - -return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next){ - return new return_inst(ctx, ret_val, next); -} - - -// branch_inst -branch_inst* branch_inst::create(basic_block *dst, instruction *next) { - assert(dst && "Branch destination may not be null!"); - return new uncond_branch_inst(dst, next); -} - -branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *else_dst, instruction *next) { - assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!"); - return new cond_branch_inst(if_dst, else_dst, cond, next); -} - -// uncond_branch_inst -uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next) - : branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next){ - set_operand(0, dst); -} - -// cond_branch_inst -cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next) - : branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next){ - assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!"); - set_operand(0, if_dst); - set_operand(1, else_dst); - set_operand(2, cond); -} - - -//===----------------------------------------------------------------------===// -// getelementptr_inst classes -//===----------------------------------------------------------------------===// - -getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector &idx, const std::string &name, instruction *next) - : instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next), - source_elt_ty(pointee_ty), - res_elt_ty(get_indexed_type(pointee_ty, idx)){ - // sanity check - type *expected_ty = get_type()->get_scalar_ty(); - expected_ty = ((pointer_type*)expected_ty)->get_element_ty(); - assert(res_elt_ty == expected_ty); - // set operands - set_operand(0, ptr); - for(size_t i = 0; i < idx.size(); i++) - set_operand(1 + i, idx[i]); -} - -type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vector &idx_list) { - // result pointer type - type *ty = x->get_type(); - unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space(); - type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space); - // Tile GEP - if(ty->is_block_ty()) - return block_type::get_same_shapes(ptr_ty, ty); - for(value *idx : idx_list) - if (idx->get_type()->is_block_ty()) - return block_type::get_same_shapes(ptr_ty, ty); - // Scalar GEP - return ptr_ty; -} - -type *getelementptr_inst::get_indexed_type_impl(type *ty, const std::vector &idx_list) { - if(idx_list.empty()) - return ty; - if(!ty->is_sized()) - return nullptr; - unsigned cur_idx = 1; - for(; cur_idx != idx_list.size(); cur_idx++){ - composite_type *cty = dynamic_cast(ty); - if(!cty || cty->is_pointer_ty()) - break; - value *idx = idx_list[cur_idx]; - if(!cty->index_valid(idx)) - break; - ty = cty->get_type_at_index(idx); - } - return (cur_idx == idx_list.size())? ty : nullptr; -} - -type *getelementptr_inst::get_indexed_type(type *ty, const std::vector &idx_list) { - type *result = get_indexed_type_impl(ty, idx_list); - assert(result && "invalid GEP type!"); - return result; -} - -getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector &idx, const std::string &name, instruction *next) { - type *pointee_ty = ((pointer_type*)(ptr->get_type()->get_scalar_ty()))->get_element_ty(); - return new getelementptr_inst(pointee_ty, ptr, idx, name, next); -} - - -//===----------------------------------------------------------------------===// -// load_inst/store_inst classes -//===----------------------------------------------------------------------===// - -// io_inst -io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) - : instruction(ty, id, num_ops, name, next) -{ } - -// load_inst -load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) - : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile) -{ } - -// load -type *load_inst::get_pointee_type(type *ty) { - type *scalar_ty = ty->get_scalar_ty(); - type *pointee_ty = scalar_ty->get_pointer_element_ty(); - if(ty->is_block_ty()) - return block_type::get_same_shapes(pointee_ty, ty); - return pointee_ty; -} - -// unmasked_load -unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) - : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) { - set_operand(0, ptr); -} - -unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) { - return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next); -} - -// masked load -masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, - bool is_volatile, - const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) { - set_operand(0, ptr); - set_operand(1, mask); - set_operand(2, false_value); -} - -masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, - bool is_volatile, - const std::string &name, instruction *next) { - return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next); -} - -// masked load async -masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, - const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) { - set_operand(0, ptr); - set_operand(1, mask); - set_operand(2, false_value); -} - -masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, - const std::string &name, instruction *next) { - return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next); -} - -// store - -store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) - : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next) -{ } - -// unmasked_store -unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, - const std::string &name, instruction *next) - : store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) { - set_operand(0, ptr); - set_operand(1, val); -} - -unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, - const std::string &name, instruction *next) { - return new unmasked_store_inst(ptr, val, name, next); -} - -// masked store -masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, - const std::string &name, instruction *next) - : store_inst(ptr, INST_MASKED_STORE, 3, name, next) { - set_operand(0, ptr); - set_operand(1, val); - set_operand(2, mask); -} - -masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) { - return new masked_store_inst(ptr, val, mask, name, next); -} -//===----------------------------------------------------------------------===// -// retile_inst classes -//===----------------------------------------------------------------------===// - -// cat - -cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next) - : instruction(block_type::get(x->get_type()->get_scalar_ty(), - {x->get_type()->get_block_shapes()[0] + - y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) { - set_operand(0, x); - set_operand(1, y); -} - -instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) { - return new cat_inst(lhs, rhs, name, next); -} - -// retile - -retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, - const std::string &name, instruction *next) - : unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { } - - - -// reshape - -instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes, - const std::string &name, instruction *next) { - return new reshape_inst(arg, INST_RESHAPE, shapes, name, next); -} - - -// splat - -instruction* splat_inst::create(value *arg, const type::block_shapes_t &shapes, - const std::string &name, instruction *next) { - return new splat_inst(arg, INST_SPLAT, shapes, name, next); -} - -// broadcast - -instruction* broadcast_inst::create(value *arg, const type::block_shapes_t &shapes, - const std::string &name, instruction *next) { - return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next); -} - -// downcast - -instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) { - return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next); -} - -//===----------------------------------------------------------------------===// -// matmul_inst classes -//===----------------------------------------------------------------------===// - -dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, - const std::string &name, instruction *next) - : builtin_inst(C->get_type(), INST_DOT, 3, name, next) { - set_operand(0, A); - set_operand(1, B); - set_operand(2, C); - allow_tf32_ = allow_tf32; -} - -instruction *dot_inst::create(value *A, value *B, value *C, - bool AT, bool BT, bool allow_tf32, - const std::string &name, instruction *next) { - TransT OPA = AT ? Trans : NoTrans; - TransT OPB = BT ? Trans : NoTrans; - return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next); -} - -instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32, - const std::string &name, instruction *next) { - return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next); -} - -instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32, - const std::string &name, instruction *next) { - return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next); -} - -instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32, - const std::string &name, instruction *next) { - return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next); -} - -instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32, - const std::string &name, instruction *next) { - return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next); -} - -//===----------------------------------------------------------------------===// -// trans instructions -//===----------------------------------------------------------------------===// - -ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector perm) { - // get argument shapes - ir::block_type::block_shapes_t arg_shapes = ty->get_block_shapes(); - // permutate argument shapes - perm = init_perm(ty, perm); - ir::block_type::block_shapes_t res_shapes = arg_shapes; - for(size_t i = 0; i < perm.size(); i++) - res_shapes[i] = arg_shapes[perm[i]]; - // construct type - return block_type::get(ty->get_scalar_ty(), res_shapes); -} - -std::vector trans_inst::init_perm(ir::type* ty, const std::vector& perm) { - if(!perm.empty()) - return perm; - auto size = ty->get_block_shapes().size(); - std::vector result; - result.push_back(size - 1); - for(size_t i = 0; i < size - 1; i++) - result.push_back(i); - return result; -} - -trans_inst::trans_inst(value *arg, const std::vector &perm, const std::string &name, instruction *next) - : builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) { - // sanity check - perm_ = init_perm(arg->get_type(), perm); - //auto size = arg->get_type()->get_tile_shapes().size(); - //assert(perm_.size() == size); - set_operand(0, arg); -} - -instruction* trans_inst::create(value *arg, const std::vector &perm, const std::string &name, instruction *next) { - return new trans_inst(arg, perm, name, next); -} - -const std::vector trans_inst::get_perm() const { - return perm_; -} - -//===----------------------------------------------------------------------===// -// sqrt instructions -//===----------------------------------------------------------------------===// - -sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next) - : builtin_inst(arg->get_type(), INST_SQRT, 1, name, next){ - set_operand(0, arg); -} - -instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) { - return new sqrt_inst(arg, name, next); -} - -//===----------------------------------------------------------------------===// -// reduce instructions -//===----------------------------------------------------------------------===// - -std::string reduce_inst::to_str(op_t op) { - switch (op) { - case ADD: return "+"; - case SUB: return "-"; - case MAX: return "imax"; - case MIN: return "imin"; - case FADD: return "+"; - case FSUB: return "-"; - case FMAX: return "fmax"; - case FMIN: return "fmin"; - default: break; - } - assert(false); - return ""; -} - -type* reduce_inst::get_res_type(value *arg, unsigned axis) { - ir::block_type::block_shapes_t shapes = arg->get_type()->get_block_shapes(); - shapes.erase(shapes.begin() + axis); - type *scalar_ty = arg->get_type()->get_scalar_ty(); - if(shapes.empty()) -// shapes.push_back(1); - return scalar_ty; - return block_type::get(scalar_ty, shapes); -} - -reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) - : builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next), - op_(op), - axis_(axis){ - set_operand(0, arg); -} - -instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) { - return new reduce_inst(arg, op, axis, name, next); -} - - -//===----------------------------------------------------------------------===// -// select instructions -//===----------------------------------------------------------------------===// - -select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) - : builtin_inst(if_value->get_type(), INST_SELECT, 3, name, next){ - set_operand(0, pred); - set_operand(1, if_value); - set_operand(2, else_value); -} - -instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) { - return new select_inst(pred, if_value, else_value, name, next); -} -//===----------------------------------------------------------------------===// -// builtin instructions -//===----------------------------------------------------------------------===// - - -// get_program_id -get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next) - : builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next), axis_(axis){ - -} - -instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { - return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next); -} - -// get_num_program -get_num_programs_inst::get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next) - : builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){ - -} - -instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { - return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next); -} - -// atomic_rmw - -atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) - : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) { - set_operand(0, ptr); - set_operand(1, val); - set_operand(2, msk); -} - -instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) { - return new atomic_rmw_inst(op, ptr, val, msk, name, next); -} - - -// atomic cas - -atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) - : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) { - set_operand(0, ptr); - set_operand(1, cmp); - set_operand(2, val); -} - -instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) { - return new atomic_cas_inst(ptr, cmp, val, name, next); -} - - -// umulhi - -umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next) - : builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) { - set_operand(0, lhs); - set_operand(1, rhs); -} - -instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) { - return new umulhi_inst(lhs, rhs, name, next); -} - - -// exp - -exp_inst::exp_inst(value *val, const std::string &name, instruction *next) - : builtin_inst(val->get_type(), INST_EXP, 1, name, next) { - set_operand(0, val); -} - -instruction* exp_inst::create(value *val, const std::string& name, instruction *next) { - return new exp_inst(val, name, next); -} - -// cos -cos_inst::cos_inst(value *val, const std::string &name, instruction *next) - : builtin_inst(val->get_type(), INST_COS, 1, name, next) { - set_operand(0, val); -} - -instruction* cos_inst::create(value *val, const std::string& name, instruction *next) { - return new cos_inst(val, name, next); -} - -// sin -sin_inst::sin_inst(value *val, const std::string &name, instruction *next) - : builtin_inst(val->get_type(), INST_SIN, 1, name, next) { - set_operand(0, val); -} - -instruction* sin_inst::create(value *val, const std::string& name, instruction *next) { - return new sin_inst(val, name, next); -} - - -// log - -log_inst::log_inst(value *val, const std::string &name, instruction *next) - : builtin_inst(val->get_type(), INST_LOG, 1, name, next) { - set_operand(0, val); -} - -instruction* log_inst::create(value *val, const std::string& name, instruction *next) { - return new log_inst(val, name, next); -} - - -//===----------------------------------------------------------------------===// -// intrinsic instructions -//===----------------------------------------------------------------------===// - -// cvt_scanline -cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) { - return new cvt_layout_inst(arg->get_type(), INST_CVT_LAYOUT, arg, name, next); -} - -// copy to shared -copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name, - instruction *next) { - return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next); -} - -// copy from shared -copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name, - instruction *next) { - return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next); -} - -// barrier -barrier_inst::barrier_inst(context &ctx, const std::string &name, - instruction *next) - : instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { } - -barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) { - return new barrier_inst(ctx, name, next); -} - -async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next) - : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { } - -async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) { - return new async_wait_inst(ctx, N, name, next); -} - -// prefetch_s -prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) { - return new prefetch_s_inst(ctx, arg, inc, name, next); -} - -//// nv_dynamic_program_idx -//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) -// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { } - -//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) { -// return new make_range_dyn(ty, name, next); -//} - -//// nv_static_program_idx -//make_range_sta::make_range_sta(make_range *range) -// : constant(range->get_type(), 0), range_(range) { } - -//make_range* make_range_sta::get_range() const -//{ return range_; } - -//make_range_sta* make_range_sta::get(make_range* range) { -// static std::map cache; -// if(cache.find(range) == cache.end()) -// cache.insert({range, new make_range_sta(range)}); -// return cache.at(range); -//} - - -// make_range -make_range::make_range(type *ty, constant_int *first, constant_int *last) - : instruction(ty, INST_MAKE_RANGE, 0), first_(first), last_(last){ } - -make_range *make_range::create(constant_int *first, constant_int *last) { - assert(first->get_type()->is_integer_ty()); - assert(first->get_type() == last->get_type()); -// assert(((constant_int*)first)->get_value() == 0); - type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()}); - return new make_range(ty, first, last); -} - -const constant_int* make_range::get_first() const { - return first_; -} - -const constant_int* make_range::get_last() const { - return last_; -} - -} -} diff --git a/lib/ir/metadata.cc b/lib/ir/metadata.cc deleted file mode 100644 index 16bc059c5..000000000 --- a/lib/ir/metadata.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "triton/ir/metadata.h" - -namespace triton{ -namespace ir{ - -metadata::metadata(kind_t kind, unsigned value) - : kind_(kind), value_(value) { } - -metadata* metadata::get(kind_t kind, unsigned value) { - return new metadata(kind, value); -} - -} -} diff --git a/lib/ir/module.cc b/lib/ir/module.cc deleted file mode 100644 index a37d3048f..000000000 --- a/lib/ir/module.cc +++ /dev/null @@ -1,22 +0,0 @@ -#include -#include -#include "triton/ir/basic_block.h" -#include "triton/ir/module.h" -#include "triton/ir/type.h" -#include "triton/ir/constant.h" -#include "triton/ir/function.h" - -namespace triton{ -namespace ir{ - -/* functions */ -function *module::get_or_insert_function(const std::string &name, function_type *ty) { - function *&fn = (function*&)symbols_[name]; - if(fn == nullptr) - return fn = function::create(ty, global_value::external, name, this); - return fn; -} - - -} -} diff --git a/lib/ir/print.cc b/lib/ir/print.cc deleted file mode 100644 index db73ec7d9..000000000 --- a/lib/ir/print.cc +++ /dev/null @@ -1,450 +0,0 @@ -#include -#include "triton/ir/basic_block.h" -#include "triton/ir/module.h" -#include "triton/ir/type.h" -#include "triton/ir/value.h" -#include "triton/ir/constant.h" -#include "triton/ir/function.h" -#include "triton/ir/instructions.h" -#include "triton/ir/print.h" - -#include -#include - -namespace triton{ -namespace ir{ - -namespace { -class SlotTracker { - // A mapping of values to slot numbers. - using value_map = std::map; - - // The module for which we are holding slot numbers. - const module *mod_; - bool module_processed = false; - - // The function for which we are holding slot numbers. - const function *func_ = nullptr; - bool function_processed = false; - - // m_map - The slot map for the module level data. - value_map m_map; - unsigned m_next = 0; - - // f_map - The slot map for the function level data. - value_map f_map; - unsigned f_next = 0; - -public: - // Construct from a module - explicit SlotTracker(const module *mod) : mod_(mod) {} - - // Construct from a function - explicit SlotTracker(const function *f) - : mod_(f? f->get_parent() : nullptr), func_(f) {} - - // Return the slot number of the specified value. If something is not in - // the SlotTracker, return -1 - int get_local_slot(const value *v); - - void initialize_if_needed(); - - // If you'd like to deal with a function instead of just a module, use - // this method to get its data into the SlotTracker - void incorporate_function(const function *f) { - func_ = f; - function_processed = false; - } - -private: - // Add all of the module level global variables (and their initializers) - // and function declarations, but not contents of those functions. - void process_module(); - - // Add all of the functions arguments, basic blocks, and instructions. - void process_function(); - - // Insert specified value* into the slot table - void create_function_slot(const value *v); -}; - -class AssemblyWriter { - std::ostream &os; - SlotTracker &slot_tracker; - -public: - AssemblyWriter(std::ostream &os, SlotTracker &slot_tracker) - : os(os), slot_tracker(slot_tracker) {} - - void print_module(const module *mod); - void print_function(const function *f); - void print_argument(const argument *arg); - void print_basic_block(const basic_block *bb); - void print_instruction(const instruction *instr); - void print_value(const value *v); - - void write_operand(const value *op, bool print_type = false); -}; -} // anonymous namespace - -//------------------------- -// SlotTracker -//------------------------- -void SlotTracker::process_module() { - // Nothing to do at the moment. - // Create slots for global variable & unamed functions & ... - module_processed = true; -} - -void SlotTracker::process_function() { - f_next = 0; - - // Add all the function arguments with no names. - for (const argument *arg : func_->args()) - if (!arg->has_name()) - create_function_slot(arg); - - // Add all of the basic blocks and instructions with no names. - for (const basic_block *bb : func_->blocks()) { - if (!bb->has_name()) - create_function_slot(bb); - - for (const instruction *instr : bb->get_inst_list()) { - if (!instr->get_type()->is_void_ty() && !instr->has_name()) - create_function_slot(instr); - } - } - - function_processed = true; -} - -void SlotTracker::create_function_slot(const value *v) { - assert(!v->get_type()->is_void_ty() && !v->has_name() && "Doesn't need a slot"); - - unsigned dst_slot = f_next++; - f_map[v] = dst_slot; -} - -int SlotTracker::get_local_slot(const value *v) { - assert(dynamic_cast(v) == nullptr && "Can't get a constant slot"); - - // Check for uninitialized state and do lazy initialization. - initialize_if_needed(); - - value_map::iterator f_iter = f_map.find(v); - return f_iter == f_map.end() ? -1 : (int)f_iter->second; -} - -void SlotTracker::initialize_if_needed() { - if (mod_ && !module_processed) - process_module(); - - if (func_ && !function_processed) - process_function(); -} - - -//------------------------------- -// AssemblyWriter -//------------------------------- -void AssemblyWriter::write_operand(const value *operand, bool print_type) { - if (!operand) { - os << ""; - return; - } - - if (auto *c = dynamic_cast(operand)) { - os << c->repr(); - return; - } - - if (operand->has_name()) { - os << operand->get_name(); - return; - } - - // Print the normal way - int slot_num = slot_tracker.get_local_slot(operand); - - if (slot_num != -1) - os << "%" << slot_num; - else - os << ""; -} - -void AssemblyWriter::print_module(const module *mod) { - slot_tracker.initialize_if_needed(); - // ;ModuleID = ... - // source_filename = ... - - // Print all of the functions. - for (function *f : mod->get_function_list()) { - os << "\n"; - print_function(f); - } -} - -void AssemblyWriter::print_function(const function *f) { - // Annotation & Attributes - - slot_tracker.incorporate_function(f); - - os << "def "; - ir::type *rt_type = f->get_fn_type()->get_return_ty(); - // Functions must have names. - os << rt_type->repr() << " " << f->get_name() << "("; - // Print arguments - for (ir::argument *arg : f->args()) { - if (arg->get_arg_no() > 0) - os << ", "; - print_argument(arg); - } - os << ")"; - - // Print function body - os << "{"; - for (const basic_block *bb : f->blocks()) - print_basic_block(bb); - os << "}\n"; -} - -void AssemblyWriter::print_argument(const argument *arg) { - // Print type - os << arg->get_type()->repr(); - - // Print name, if available. - if (arg->has_name()) - os << " " << arg->get_name(); - else { - int slot_num = slot_tracker.get_local_slot(arg); - assert(slot_num != -1 && "expect argument in function here"); - os << " %" << slot_num; - } - - // Print attributes - std::set attrs = arg->get_parent()->get_attributes(arg); - for (attribute attr : attrs) - os << " " << attr.repr(); -} - -void AssemblyWriter::print_basic_block(const basic_block *bb) { - // bb label - if (bb->has_name()) { - os << "\n"; - os << bb->get_name() << ":"; - } else { - os << "\n"; - int slot_num = slot_tracker.get_local_slot(bb); - if (slot_num != -1) - os << slot_num << ":"; - else - os << ":"; - } - - // Print predecessors for the block - auto const &predecessors = bb->get_predecessors(); - if (!predecessors.empty()) { - os << std::setw(50) << std::setfill(' ') - << "; preds = "; - for (size_t i=0; iget_inst_list()) - print_instruction(instr); -} - -void AssemblyWriter::print_instruction(const instruction *instr) { - // Print out indentation for an instruction. - os << " "; - - ir::type *type = instr->get_type(); - if (instr->has_name()) { - os << instr->get_name(); - os << " = "; - } else if (!type->is_void_ty()) { - // Print out the def slot taken. - int slot_num = slot_tracker.get_local_slot(instr); - if (slot_num == -1) - os << " = "; - else - os << "%" << slot_num << " = "; - } - - // Print out opcode - os << instr->repr() << " " << type->repr(); - - size_t num_ops = instr->get_num_operands(); - if (num_ops > 0) - os << " "; - ir::instruction::ops_t ops = instr->ops(); - for (unsigned i = 0; i < num_ops; ++i) { - if (i) - os << ", "; - write_operand(ops[i]); - } - - os << ";\n"; -} - -void AssemblyWriter::print_value(const value *v) { - // Not implemented -} - - -//------------------------------- -// External interface -//------------------------------- -void module::print(std::ostream &os) { - SlotTracker slot_tracker(this); - AssemblyWriter writer(os, slot_tracker); - writer.print_module(this); -} - -void function::print(std::ostream &os) { - SlotTracker slot_tracker(this); - AssemblyWriter writer(os, slot_tracker); - writer.print_function(this); -} - -void basic_block::print(std::ostream &os) { - SlotTracker slot_tracker(this->get_parent()); - AssemblyWriter writer(os, slot_tracker); - writer.print_basic_block(this); -} - -void instruction::print(std::ostream &os) { - SlotTracker slot_tracker(this->get_parent()->get_parent()); - AssemblyWriter writer(os, slot_tracker); - writer.print_instruction(this); -} - -//------------------------------- -// legacy print interface -//------------------------------- -std::string get_name(ir::value *v, unsigned i) { - if(v->get_name().empty()){ - std::string name = "%" + std::to_string(i); - v->set_name(name); - } - return v->get_name(); -} - - -void print(module &mod, std::ostream& os) { - unsigned cnt = 0; - for(ir::function *fn: mod.get_function_list()){ - os << "def " << fn->get_fn_type()->get_return_ty()->repr() << " " << fn->get_name() << "(" ; - for(ir::argument* arg: fn->args()) { - if(arg->get_arg_no() > 0) - os << ", "; - os << arg->get_type()->repr() << " " << arg->get_name(); - auto attrs = fn->get_attributes(arg); - if(attrs.size() > 0) - os << " "; - for(ir::attribute attr: attrs) - os << attr.repr() << " "; - } - os << ")" << std::endl; - os << "{" << std::endl; - for(ir::basic_block *block: fn->blocks()){ - auto const &predecessors = block->get_predecessors(); - os << block->get_name() << ":"; - if(!predecessors.empty()){ - os << " "; - os << "; preds = "; - auto const &predecessors = block->get_predecessors(); - for(ir::basic_block *pred: predecessors) - os << pred->get_name() << (pred!=predecessors.back()?", ":""); - } - os << std::endl; - for(ir::instruction *inst: block->get_inst_list()){ - os << " "; - if(!inst->get_type()->is_void_ty()){ - os << get_name(inst, cnt++); - os << " = "; - } - ir::type* type = inst->get_type(); - os << inst->repr() << " " << type->repr(); - ir::instruction::ops_t ops = inst->ops(); - size_t num_ops = inst->get_num_operands(); - if(num_ops > 0) - os << " ";; - for(unsigned i = 0; i < num_ops; i++){ - if(auto *x = dynamic_cast(ops[i])) - os << x->repr(); - else - os << get_name(ops[i], cnt++); - os << (i < num_ops - 1?", ":""); - } - os << ";"; -// os << " ("; -// for(ir::user* usr: inst->get_users()) -// os << get_name(usr, cnt++) << ", " ; -// os << " )"; - os << std::endl; - } - } - os << "}" << std::endl; - } -} - -void print(function &fn, std::ostream &os) { - // -} - -void print(basic_block &bb, std::ostream &os) { - auto const &predecessors = bb.get_predecessors(); - os << bb.get_name() << ":"; - if(!predecessors.empty()){ - os << " "; - os << "; preds = "; - auto const &predecessors = bb.get_predecessors(); - for(ir::basic_block *pred: predecessors) - os << pred->get_name() << (pred!=predecessors.back()?", ":""); - } - os << std::endl; - for(ir::instruction *inst: bb.get_inst_list()){ - print(*inst, os); - } -} - -void print(instruction &instr, std::ostream &os) { - instruction *inst = &instr; - os << " "; - if(!inst->get_type()->is_void_ty()){ - os << instr.get_name(); - os << " = "; - } - ir::type* type = inst->get_type(); - os << inst->repr() << " " << type->repr(); - ir::instruction::ops_t ops = inst->ops(); - size_t num_ops = inst->get_num_operands(); - if(num_ops > 0) - os << " ";; - for(unsigned i = 0; i < num_ops; i++){ - if(auto *x = dynamic_cast(ops[i])) - os << x->repr(); - else - os << ops[i]->get_name(); - os << (i < num_ops - 1?", ":""); - } - os << ";"; -// os << " ("; -// for(ir::user* usr: inst->get_users()) -// os << get_name(usr, cnt++) << ", " ; -// os << " )"; - os << std::endl; -} - - -} -} diff --git a/lib/ir/type.cc b/lib/ir/type.cc deleted file mode 100644 index 056ae99e6..000000000 --- a/lib/ir/type.cc +++ /dev/null @@ -1,233 +0,0 @@ -#include -#include -#include "triton/ir/type.h" -#include "triton/ir/context.h" -#include "triton/ir/context_impl.h" -#include "triton/ir/value.h" -#include "triton/ir/constant.h" - -namespace triton{ -namespace ir{ - -//===----------------------------------------------------------------------===// -// type class -//===----------------------------------------------------------------------===// - -// attributes -type *type::get_scalar_ty() const { - if(is_block_ty()) - return get_tile_element_ty(); - return const_cast(this); -} - -unsigned type::get_primitive_size_in_bits() const { - switch (id_) { - case FP8TyID: return 8; - case FP16TyID: return 16; - case BF16TyID: return 16; - case FP32TyID: return 32; - case FP64TyID: return 64; - case IntegerTyID: return ((integer_type*)(this))->get_bitwidth(); - case BlockTyID: return ((block_type*)(this))->get_bitwidth(); - default: return 0; - } -} - -unsigned type::get_integer_bitwidth() const -{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } - -unsigned type::get_tile_bitwidth() const -{ return ((block_type*)(this))->get_bitwidth(); } - -unsigned type::get_fp_mantissa_width() const { - id_t id = get_scalar_ty()->id_; - assert(is_floating_point_ty() && "Not a floating point type!"); - if (id == FP8TyID) return 3; - if (id == FP16TyID) return 10; - if (id == BF16TyID) return 7; - if (id == FP32TyID) return 23; - if (id == FP64TyID) return 53; - throw std::runtime_error("unreachable"); -} - -type* type::get_tile_element_ty() const { - assert(is_block_ty()); - return contained_tys_[0]; -} - -unsigned type::get_pointer_address_space() const { - assert(is_pointer_ty()); - return ((pointer_type*)this)->get_address_space(); -} - -type * type::get_pointer_element_ty() const { - type *ptr_ty = get_scalar_ty(); - assert(ptr_ty->is_pointer_ty()); - type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty(); - if(is_block_ty()) - return block_type::get_same_shapes(scalar_ty, (type*)this); - return scalar_ty; -} - - -type::block_shapes_t type::get_block_shapes() const { - assert(is_block_ty()); - return ((block_type*)this)->get_shapes(); -} - -const size_t type::get_tile_rank() const { - return get_block_shapes().size(); -} - -const size_t type::get_tile_ranks1() const { - int ret = 0; - for(int s: get_block_shapes()) - ret += s > 1; - return ret; -} - - -unsigned type::get_tile_num_elements() const { - const block_shapes_t& shapes = get_block_shapes(); - unsigned result = 1; - for(auto shape: shapes) - result *= shape; - return result; -} - - -// composite predicates -bool type::is_int_or_tileint_ty() -{ return get_scalar_ty()->is_integer_ty(); } - -bool type::is_integer_ty(unsigned width) const -{ return is_integer_ty() && get_integer_bitwidth()== width; } - - -bool type::is_floating_point_ty() const -{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); } - -bool type::is_sized() const { - // primitive types are sized - if(is_integer_ty() || is_floating_point_ty() || - is_pointer_ty()){ - return true; - } - // tile types are sizes - if(is_block_ty()) - return get_scalar_ty()->is_sized(); - return false; -} - -// primitive types -type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; } -type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; } -// floating point -type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; } -type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; } -type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; } -type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; } -type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; } -// integer types -integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; } -integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; } -integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } -integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } -integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } -integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } - - - -pointer_type::pointer_type(type *ty, unsigned address_space) - : type(ty->get_context(), PointerTyID), address_space_(address_space){ - contained_tys_.push_back(ty); -} - -bool pointer_type::is_valid_elt_ty(type *ty){ - return !ty->is_void_ty() && !ty->is_label_ty() && - !ty->is_metadata_ty() && !ty->is_token_ty(); -} - -pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){ - assert(elt_ty && "Can't get a pointer to type!"); - assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); - // look-up - context_impl *impl = elt_ty->get_context().p_impl.get(); - std::unique_ptr &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; - if(!entry) - entry.reset(new pointer_type(elt_ty, address_space)); - return entry.get(); -} - -//===----------------------------------------------------------------------===// -// composite_type class -//===----------------------------------------------------------------------===// - -type* composite_type::get_type_at_index(value *) const{ - assert(is_block_ty()); - return get_scalar_ty(); -} - -bool composite_type::index_valid(value *idx) const{ - assert(is_block_ty()); - return idx->get_type()->is_int_or_tileint_ty(); -} - -//===----------------------------------------------------------------------===// -// tile_type class -//===----------------------------------------------------------------------===// - -block_type::block_type(type *ty, const block_shapes_t &shapes) - : composite_type(ty->get_context(), BlockTyID), shapes_(shapes) { - contained_tys_.push_back(ty); -} - -bool block_type::is_valid_elt_ty(type *ty) { - return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty(); -} - -unsigned block_type::get_num_elements() const { - unsigned res = 1; - for(auto shape: shapes_) - res *= shape; - return res; -} - -unsigned block_type::get_bitwidth() const { - return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits(); -} - -block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) { - assert(elt_ty && "Can't get a tile of type!"); - assert(shapes.size() && "Can't create a tile with empty shapes!"); - assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!"); - // look-up - context_impl *impl = elt_ty->get_context().p_impl.get(); - std::unique_ptr &entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; - if(!entry) - entry.reset(new block_type(elt_ty, shapes)); - return entry.get(); -} - -block_type* block_type::get_same_shapes(type *ty, type *ref){ - assert(ref->is_block_ty()); - return get(ty, ref->get_block_shapes()); -} - -//===----------------------------------------------------------------------===// -// function_type class -//===----------------------------------------------------------------------===// - -function_type::function_type(type *ret_ty, const std::vector ¶m_tys): - type(ret_ty->get_context(), FunctionTyID) { - contained_tys_.push_back(ret_ty); - for(type *ty: param_tys) - contained_tys_.push_back(ty); -} - -function_type* function_type::get(type *ret_ty, const std::vector ¶m_tys) { - return new function_type(ret_ty, param_tys); -} - -} -} diff --git a/lib/ir/utils.cc b/lib/ir/utils.cc deleted file mode 100644 index cbfb4baf9..000000000 --- a/lib/ir/utils.cc +++ /dev/null @@ -1,68 +0,0 @@ -#include -#include -#include "triton/ir/utils.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" - -namespace triton{ -namespace ir{ - -std::vector cfg::post_order(function* fn) { - std::stack stack; - std::set visited; - std::vector result; - // initialize stack - for(ir::basic_block* block: fn->blocks()) - if(block->get_predecessors().empty()){ - stack.push(block); - visited.insert(block); - } - // DFS - while(!stack.empty()) { - basic_block* current = stack.top(); - bool tail = true; - for(basic_block* succ: current->get_successors()) - if(visited.find(succ) == visited.end()){ - stack.push(succ); - visited.insert(succ); - tail = false; - break; - } - if(tail){ - stack.pop(); - result.push_back(current); - } - } - return result; -} - -std::vector cfg::reverse_post_order(function* fn) { - auto result = post_order(fn); - std::reverse(result.begin(), result.end()); - return result; -} - -void for_each_instruction(module &mod, const std::function &do_work) { - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: cfg::reverse_post_order(fn)) - for(ir::instruction *i: block->get_inst_list()) - do_work(i); -} - -void for_each_value(module &mod, const std::function &do_work) { - std::set seen; - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: cfg::reverse_post_order(fn)) - for(ir::instruction *i: block->get_inst_list()){ - for(ir::value *op: i->ops()){ - if(seen.insert(op).second) - do_work(op); - } - if(seen.insert(i).second) - do_work(i); - } -} - -} -} diff --git a/lib/ir/value.cc b/lib/ir/value.cc deleted file mode 100644 index b970e07d7..000000000 --- a/lib/ir/value.cc +++ /dev/null @@ -1,81 +0,0 @@ -#include -#include -#include "triton/ir/value.h" -#include "triton/ir/instructions.h" - -namespace triton{ -namespace ir{ - -class type; - -//===----------------------------------------------------------------------===// -// value class -//===----------------------------------------------------------------------===// - -value::value(type *ty, const std::string &name): ty_(ty){ - set_name(name); -} - -void value::add_use(user *arg) { - users_.insert(arg); -} - -value::users_t::iterator value::erase_use(user *arg){ - auto it = users_.find(arg); - if(it == users_.end()) - return it; - return users_.erase(it); -} - -// TODO: automatic naming scheme + update symbol table -void value::set_name(const std::string &name){ - name_ = name; -} - -void value::replace_all_uses_with(value *target){ - for (auto it = users_.begin(); it != users_.end(); ) { - it = (*it)->replace_uses_of_with(this, target); - } -} - - -void visitor::visit_value(ir::value* v) { - v->accept(this); -} - - -//===----------------------------------------------------------------------===// -// user class -//===----------------------------------------------------------------------===// -void user::set_operand(unsigned i, value *x) { - assert(i < ops_.size() && "set_operand() out of range!"); - ops_[i] = x; - x->add_use(this); -} - -value* user::get_operand(unsigned i) const { - assert(i < ops_.size() && "get_operand() out of range!"); - return ops_[i]; -} - -unsigned user::get_num_operands() const { - return num_ops_; -} - -unsigned user::get_num_hidden() const { - return num_hidden_; -} - -value::users_t::iterator user::replace_uses_of_with(value *before, value *after) { - for(size_t i = 0; i < ops_.size(); i++) - if(ops_[i] == before){ - ops_[i] = after; - after->add_use(this); - } - return before->erase_use(this); -} - - - -} -} diff --git a/python/src/triton.cc b/python/src/triton.cc index b66761ec3..1a32c73e5 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -906,12 +906,7 @@ void init_triton_ir(py::module &&m) { // Intrinsics // These have no place in the IR, and hopefully they can be removed at some point .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) - .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference) - .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference) - .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference) - .def("create_barrier", &ir::builder::create_barrier, ret::reference) - .def("create_async_wait", &ir::builder::create_async_wait, ret::reference) - .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference); + .def("create_barrier", &ir::builder::create_barrier, ret::reference); } void init_triton(py::module &m) { diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index a253e2c4c..170b71a09 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -819,6 +819,49 @@ class Kernel: return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) + # Compile to ttir, for the propose of testing MLIR rewriting + def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + # TODO: share code with _compile & __call__ + + # preparing args + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + # attributes + attributes = dict() + for i, arg in enumerate(wargs): + if i in self.fn.do_not_specialize: + continue + if isinstance(arg, int): + attributes[i] = Kernel.pow2_divisor(arg) + elif i in tensor_idxs: + addr = arg.data_ptr() + range_size = _triton.runtime.get_pointer_range_size(addr) + attributes[i] = min(Kernel.pow2_divisor(addr), + Kernel.pow2_divisor(range_size)) + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) + arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] + + # create IR module + context = _triton.ir.context() + # get just-in-time proto-type of kernel + arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] + ret_type = triton.language.void + prototype = triton.language.function_type(ret_type, arg_types) + # generate Triton-IR + # export symbols visible from self into code-generator object + gscope = self.__globals__ + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + try: + generator.visit(self.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(self.src, node) from e + return generator.module + class Launcher: def __init__(self, kernel, grid):