Init commit
This commit is contained in:
1
include/CMakeLists.txt
Normal file
1
include/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(ir)
|
8
include/triton/ir/CMakeLists.txt
Normal file
8
include/triton/ir/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(TritonTableGen)
|
18
include/triton/ir/Dialect.h
Normal file
18
include/triton/ir/Dialect.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef TRITON_IR_DIALECT_H_
|
||||
#define TRITON_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
|
||||
#include "triton/Dialect.h.inc"
|
||||
|
||||
#include "triton/OpsEnums.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Ops.h.inc"
|
||||
|
||||
#endif // TRITON_IR_DIALECT_H_
|
39
include/triton/ir/TritonDialect.td
Normal file
39
include/triton/ir/TritonDialect.td
Normal file
@@ -0,0 +1,39 @@
|
||||
#ifndef TRITON_DIALECT
|
||||
#define TRITON_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Triton_Dialect : Dialect {
|
||||
let name = "triton";
|
||||
|
||||
let cppNamespace = "::mlir::triton";
|
||||
|
||||
let summary = "The Triton IR in MLIR";
|
||||
|
||||
let description = [{
|
||||
Triton Dialect.
|
||||
|
||||
Dependent Dialects:
|
||||
* Arithmetic:
|
||||
* addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
|
||||
* Tensor:
|
||||
* reshape (?)
|
||||
* ControlFlow:
|
||||
* bf, cond_bf
|
||||
* Func:
|
||||
* call, return
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"tensor::TensorDialect",
|
||||
"cf::ControlFlowDialect",
|
||||
"func::FuncDialect"
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerTypes();
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TRITON_DIALECT
|
235
include/triton/ir/TritonOps.td
Normal file
235
include/triton/ir/TritonOps.td
Normal file
@@ -0,0 +1,235 @@
|
||||
#ifndef Triton_OPS
|
||||
#define Triton_OPS
|
||||
|
||||
include "TritonDialect.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
// FloatType
|
||||
def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
|
||||
/*descr*/"8bit float",
|
||||
/*cppClassName*/"::mlir::triton::Float8Type">;
|
||||
|
||||
def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
|
||||
/*descr*/"8bit bfloat",
|
||||
/*cppClassName*/"::mlir::triton::BFloat8Type">;
|
||||
|
||||
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
|
||||
// IntegerType
|
||||
def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">;
|
||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||
|
||||
// PointerType
|
||||
def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
|
||||
def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
|
||||
def TT_PtrTensor : TensorOf<[TT_AnyPtr]>;
|
||||
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
|
||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
|
||||
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
||||
TT_AnyPtr, TT_PtrTensor]>;
|
||||
|
||||
//
|
||||
// Op Base
|
||||
//
|
||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Triton_Dialect, mnemonic, traits>;
|
||||
|
||||
//
|
||||
// CastOps
|
||||
//
|
||||
// Use cast ops in arith:
|
||||
// bitcast
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// extf, tructf,
|
||||
// extui, extsi, tructi
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast int64 to pointer";
|
||||
|
||||
let arguments = (ins I64Tensor:$from);
|
||||
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast pointer to int64";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$from);
|
||||
|
||||
let results = (outs I64Tensor:$result);
|
||||
}
|
||||
|
||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Floating point casting for custom types";
|
||||
|
||||
let description = [{
|
||||
Floating point casting for custom types (F8, BF8).
|
||||
|
||||
F8 <-> BF8, FP16, FP32
|
||||
BF8 <-> F8, FP16, FP32
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FloatTensor:$from);
|
||||
|
||||
let results = (outs TT_FloatTensor:$result);
|
||||
|
||||
// TODO: We need a verifier here.
|
||||
}
|
||||
|
||||
//
|
||||
// Load/Store Ops
|
||||
//
|
||||
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
||||
let summary = "load";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
// for args with default values
|
||||
OpBuilder<(ins "Value":$ptr)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask)>
|
||||
];
|
||||
}
|
||||
|
||||
def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
|
||||
let summary = "store";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
|
||||
|
||||
let builders = [
|
||||
// for args with default values
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Shape Manipulation Ops
|
||||
//
|
||||
// def TT_CatOp : TT_Op<"cat", []>;
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
//
|
||||
// builtin Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "dot";
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
}
|
||||
|
||||
// reduction
|
||||
def TT_RedOpAttr : I32EnumAttr<
|
||||
/*name*/"RedOp", /*summary*/"",
|
||||
/*case*/
|
||||
[
|
||||
I32EnumAttrCase</*sym*/"SUM", 1, /*str*/"sum">,
|
||||
I32EnumAttrCase<"MAX", 2, "max">,
|
||||
I32EnumAttrCase<"MIN", 3, "min">,
|
||||
I32EnumAttrCase<"XOR_SUM", 4, "xor_sum">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_ReduceOp : TT_Op<"reduce"> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis);
|
||||
}
|
||||
|
||||
// atomic
|
||||
def TT_AtomicRMWAttr : I32EnumAttr<
|
||||
"RMWOp", "",
|
||||
[
|
||||
I32EnumAttrCase<"AND", 1, "and">,
|
||||
I32EnumAttrCase<"OR", 2, "or">,
|
||||
I32EnumAttrCase<"XOR", 3, "xor">,
|
||||
I32EnumAttrCase<"ADD", 4, "add">,
|
||||
I32EnumAttrCase<"MAX", 5, "max">,
|
||||
I32EnumAttrCase<"MIN", 6, "min">,
|
||||
I32EnumAttrCase<"UMAX", 7, "umax">,
|
||||
I32EnumAttrCase<"UMIN", 8, "umin">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
let summary = "atomic rmw";
|
||||
|
||||
let description = [{
|
||||
load data at $ptr, do $rmw_op with $val, and store result to $ptr.
|
||||
|
||||
return old value at $ptr
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
|
||||
TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
let summary = "atomic cas";
|
||||
|
||||
let description = [{
|
||||
compare $cmp with data $old at location $ptr,
|
||||
|
||||
if $old == $cmp, store $val to $ptr,
|
||||
|
||||
else store $old to $ptr,
|
||||
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
//
|
||||
// Intrinsics
|
||||
//
|
||||
// TODO: should have ConstantLike as Trait
|
||||
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
let summary = "make range";
|
||||
|
||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||
|
||||
let results = (outs TT_IntegerTensor:$result);
|
||||
}
|
||||
|
||||
#endif // Triton_OPS
|
46
include/triton/ir/Types.h
Normal file
46
include/triton/ir/Types.h
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef TRITON_IR_TYPES_H_
|
||||
#define TRITON_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
namespace detail {
|
||||
struct PointerTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
// TODO: Should be base class be FloatType?
|
||||
class Float8Type : public Type::TypeBase<Float8Type, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static Float8Type get(MLIRContext *context);
|
||||
};
|
||||
|
||||
class BFloat8Type : public Type::TypeBase<BFloat8Type, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static BFloat8Type get(MLIRContext *context);
|
||||
};
|
||||
|
||||
class PointerType : public Type::TypeBase<PointerType, Type,
|
||||
detail::PointerTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static PointerType get(Type pointeeType);
|
||||
|
||||
static PointerType get(Type pointeeType, unsigned addressSpace);
|
||||
|
||||
Type getPointeeType() const;
|
||||
|
||||
unsigned getAddressSpace() const;
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
@@ -1,88 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BASIC_BLOCK_H_
|
||||
#define _TRITON_IR_BASIC_BLOCK_H_
|
||||
|
||||
#include <string>
|
||||
#include <list>
|
||||
#include "value.h"
|
||||
#include "visitor.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class function;
|
||||
class instruction;
|
||||
|
||||
/* Basic Block */
|
||||
class basic_block: public value{
|
||||
public:
|
||||
// instruction iterator types
|
||||
typedef std::list<instruction*> inst_list_t;
|
||||
typedef inst_list_t::iterator iterator;
|
||||
typedef inst_list_t::const_iterator const_iterator;
|
||||
typedef inst_list_t::reverse_iterator reverse_iterator;
|
||||
typedef inst_list_t::const_reverse_iterator const_reverse_iterator;
|
||||
|
||||
private:
|
||||
// constructors
|
||||
basic_block(context &ctx, const std::string &name, function *parent);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
function* get_parent() { return parent_; }
|
||||
context& get_context() { return ctx_; }
|
||||
|
||||
// get iterator to first instruction that is not a phi
|
||||
iterator get_first_non_phi();
|
||||
|
||||
// get instruction list
|
||||
inst_list_t &get_inst_list() { return inst_list_; }
|
||||
const inst_list_t &get_inst_list() const { return inst_list_; }
|
||||
void erase(instruction *i) { inst_list_.remove(i); }
|
||||
|
||||
// instruction iterator functions
|
||||
inline iterator begin() { return inst_list_.begin(); }
|
||||
inline const_iterator begin() const { return inst_list_.begin(); }
|
||||
inline iterator end () { return inst_list_.end(); }
|
||||
inline const_iterator end () const { return inst_list_.end(); }
|
||||
|
||||
inline reverse_iterator rbegin() { return inst_list_.rbegin(); }
|
||||
inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); }
|
||||
inline reverse_iterator rend () { return inst_list_.rend(); }
|
||||
inline const_reverse_iterator rend () const { return inst_list_.rend(); }
|
||||
|
||||
inline size_t size() const { return inst_list_.size(); }
|
||||
inline bool empty() const { return inst_list_.empty(); }
|
||||
inline const instruction &front() const { return *inst_list_.front(); }
|
||||
inline instruction &front() { return *inst_list_.front(); }
|
||||
inline const instruction &back() const { return *inst_list_.back(); }
|
||||
inline instruction &back() { return *inst_list_.back(); }
|
||||
|
||||
// predecessors
|
||||
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
|
||||
const std::vector<basic_block*>& get_successors() const { return succs_; }
|
||||
void add_predecessor(basic_block* pred);
|
||||
|
||||
// factory functions
|
||||
static basic_block* create(context &ctx, const std::string &name, function *parent);
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
// visitor
|
||||
void accept(visitor *v) { v->visit_basic_block(this); }
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
std::string name_;
|
||||
function *parent_;
|
||||
std::vector<basic_block*> preds_;
|
||||
std::vector<basic_block*> succs_;
|
||||
inst_list_t inst_list_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,191 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BUILDER_H_
|
||||
#define _TRITON_IR_BUILDER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "instructions.h"
|
||||
#include "basic_block.h"
|
||||
#include "type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class basic_block;
|
||||
class value;
|
||||
class type;
|
||||
class constant_int;
|
||||
class instruction;
|
||||
class context;
|
||||
class phi_node;
|
||||
|
||||
/* Builder */
|
||||
class builder{
|
||||
typedef basic_block::iterator iterator;
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
builder(context &ctx);
|
||||
// Getters
|
||||
const context& get_context() { return ctx_; }
|
||||
// Setters
|
||||
void set_insert_point(iterator instr);
|
||||
void set_insert_point(instruction* i);
|
||||
void set_insert_point_after(instruction* i);
|
||||
void set_insert_point(basic_block* block);
|
||||
basic_block* get_insert_block() { return block_; }
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(uint32_t val);
|
||||
value *get_int64(uint64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
// Types
|
||||
type *get_void_ty();
|
||||
type *get_int1_ty();
|
||||
type *get_int8_ty();
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_fp8_ty();
|
||||
type *get_half_ty();
|
||||
type *get_bf16_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
template<typename InstTy>
|
||||
InstTy* insert(InstTy *inst){
|
||||
assert(block_);
|
||||
block_->get_inst_list().insert(insert_point_, inst);
|
||||
inst->set_parent(block_);
|
||||
// for(ir::value* op: inst->ops())
|
||||
// op->add_use(inst);
|
||||
return inst;
|
||||
}
|
||||
// terminator instructions
|
||||
value* create_br(basic_block *dest);
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value* create_bitcast(value *src, type *dest_ty);
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||
value* create_int_to_ptr(value *src, type *dst_ty);
|
||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||
value* create_si_to_fp(value *src, type *dst_ty);
|
||||
value* create_ui_to_fp(value *src, type *dst_ty);
|
||||
value* create_fp_to_si(value *src, type *dst_ty);
|
||||
value* create_fp_to_ui(value *src, type *dst_ty);
|
||||
value* create_fp_ext(value *src, type *dst_ty);
|
||||
value* create_fp_trunc(value *src, type *dst_ty);
|
||||
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
|
||||
value *create_downcast(value *arg);
|
||||
// Phi instruction
|
||||
phi_node* create_phi(type *ty, unsigned num_reserved);
|
||||
// Binary instructions
|
||||
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
|
||||
value *create_fmul(value *lhs, value *rhs);
|
||||
value *create_fdiv(value *lhs, value *rhs);
|
||||
value *create_frem(value *lhs, value *rhs);
|
||||
value *create_fadd(value *lhs, value *rhs);
|
||||
value *create_fsub(value *lhs, value *rhs);
|
||||
value *create_sdiv(value *lhs, value *rhs);
|
||||
value *create_udiv(value *lhs, value *rhs);
|
||||
value *create_srem(value *lhs, value *rhs);
|
||||
value *create_urem(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
// GEP
|
||||
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
|
||||
// Comparison (int)
|
||||
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_icmpSLE(value *lhs, value *rhs);
|
||||
value *create_icmpSLT(value *lhs, value *rhs);
|
||||
value *create_icmpSGE(value *lhs, value *rhs);
|
||||
value *create_icmpSGT(value *lhs, value *rhs);
|
||||
value *create_icmpULE(value *lhs, value *rhs);
|
||||
value *create_icmpULT(value *lhs, value *rhs);
|
||||
value *create_icmpUGE(value *lhs, value *rhs);
|
||||
value *create_icmpUGT(value *lhs, value *rhs);
|
||||
value *create_icmpEQ(value *lhs, value *rhs);
|
||||
value *create_icmpNE(value *lhs, value *rhs);
|
||||
// Comparison (float)
|
||||
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_fcmpOLT(value *lhs, value *rhs);
|
||||
value *create_fcmpOGT(value *lhs, value *rhs);
|
||||
value *create_fcmpOLE(value *lhs, value *rhs);
|
||||
value *create_fcmpOGE(value *lhs, value *rhs);
|
||||
value *create_fcmpOEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpONE(value *lhs, value *rhs);
|
||||
value *create_fcmpULT(value *lhs, value *rhs);
|
||||
value *create_fcmpUGT(value *lhs, value *rhs);
|
||||
value *create_fcmpULE(value *lhs, value *rhs);
|
||||
value *create_fcmpUGE(value *lhs, value *rhs);
|
||||
value *create_fcmpUEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpUNE(value *lhs, value *rhs);
|
||||
// Logical
|
||||
value *create_and(value *lhs, value *rhs);
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
// Block instruction
|
||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_cat(value *lhs, value *rhs);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Atomic instruction
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_atomic_max(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umax(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_min(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umin(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_fadd(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_and(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_or(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xor(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xchg(value *ptr, value *val, value *msk);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
value *create_sqrt(value *A);
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||
value *create_select(value *pred, value *if_value, value *else_value);
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
value *create_umulhi(value* lhs, value* rhs);
|
||||
value *create_copy_to_shared(value *arg);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
value *create_prefetch_s(value *arg, int inc);
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
basic_block *block_;
|
||||
iterator insert_point_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,113 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONSTANT_H_
|
||||
#define _TRITON_IR_CONSTANT_H_
|
||||
|
||||
#include "enums.h"
|
||||
#include "value.h"
|
||||
#include <cassert>
|
||||
#include "visitor.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class context;
|
||||
|
||||
/* Constant */
|
||||
class constant: public user{
|
||||
protected:
|
||||
using user::user;
|
||||
|
||||
public:
|
||||
static constant* get_all_ones_value(type *ty);
|
||||
static constant* get_null_value(type *ty);
|
||||
virtual std::string repr() const = 0;
|
||||
};
|
||||
|
||||
/* Undef value */
|
||||
class undef_value: public constant{
|
||||
private:
|
||||
undef_value(type *ty);
|
||||
|
||||
public:
|
||||
static undef_value* get(type* ty);
|
||||
std::string repr() const { return "undef"; }
|
||||
void accept(visitor* vst) { vst->visit_undef_value(this); }
|
||||
};
|
||||
|
||||
|
||||
/* Constant int */
|
||||
class constant_int: public constant{
|
||||
protected:
|
||||
constant_int(type *ty, uint64_t value);
|
||||
|
||||
public:
|
||||
virtual uint64_t get_value() const { return value_; }
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_int(this); }
|
||||
|
||||
protected:
|
||||
uint64_t value_;
|
||||
};
|
||||
|
||||
/* Constant fp */
|
||||
class constant_fp: public constant{
|
||||
constant_fp(type *ty, double value);
|
||||
|
||||
public:
|
||||
double get_value() { return value_; }
|
||||
static constant* get_negative_zero(type *ty);
|
||||
static constant* get_zero_value_for_negation(type *ty);
|
||||
static constant* get(context &ctx, double v);
|
||||
static constant* get(type *ty, double v);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_fp(this); }
|
||||
|
||||
private:
|
||||
double value_;
|
||||
};
|
||||
|
||||
|
||||
/* Global Value */
|
||||
class global_value: public constant {
|
||||
public:
|
||||
enum linkage_types_t {
|
||||
external
|
||||
};
|
||||
|
||||
public:
|
||||
global_value(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space);
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
private:
|
||||
linkage_types_t linkage_;
|
||||
};
|
||||
|
||||
/* global object */
|
||||
class global_object: public global_value {
|
||||
public:
|
||||
global_object(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space = 0);
|
||||
std::string repr() const { return get_name(); }
|
||||
};
|
||||
|
||||
/* global variable */
|
||||
class alloc_const: public global_object {
|
||||
public:
|
||||
alloc_const(type *ty, constant_int *size,
|
||||
const std::string &name = "");
|
||||
std::string repr() const { return get_name(); }
|
||||
void accept(visitor* vst) { vst->visit_alloc_const(this); }
|
||||
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,29 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONTEXT_H_
|
||||
#define _TRITON_IR_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class context_impl;
|
||||
|
||||
/* Context */
|
||||
class context {
|
||||
public:
|
||||
context();
|
||||
context(const context&) = delete;
|
||||
context& operator=(const context&) = delete;
|
||||
|
||||
public:
|
||||
std::shared_ptr<context_impl> p_impl;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,46 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
|
||||
#define _TRITON_IR_CONTEXT_IMPL_H_
|
||||
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
|
||||
/* Context impl */
|
||||
class context_impl {
|
||||
public:
|
||||
// constructors
|
||||
context_impl(context &ctx);
|
||||
|
||||
public:
|
||||
// non-numeric types
|
||||
type void_ty, label_ty;
|
||||
// floating point types
|
||||
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
||||
// Block types
|
||||
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
|
||||
|
||||
// Int constants
|
||||
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
|
||||
// Float constants
|
||||
std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
|
||||
// undef values
|
||||
std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,175 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_ENUMS_H_
|
||||
#define _TRITON_IR_ENUMS_H_
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
enum binary_op_t: unsigned int{
|
||||
Add,
|
||||
FAdd,
|
||||
Sub,
|
||||
FSub,
|
||||
Mul,
|
||||
FMul,
|
||||
UDiv,
|
||||
SDiv,
|
||||
FDiv,
|
||||
URem,
|
||||
SRem,
|
||||
FRem,
|
||||
Shl,
|
||||
LShr,
|
||||
AShr,
|
||||
And,
|
||||
Or,
|
||||
Xor
|
||||
};
|
||||
|
||||
enum class atomic_rmw_op_t: unsigned int{
|
||||
And,
|
||||
Or,
|
||||
Xor,
|
||||
Add,
|
||||
Max,
|
||||
Min,
|
||||
UMax,
|
||||
UMin,
|
||||
FAdd,
|
||||
Xchg,
|
||||
};
|
||||
|
||||
enum cast_op_t: unsigned int {
|
||||
Trunc,
|
||||
ZExt,
|
||||
SExt,
|
||||
FPTrunc,
|
||||
FPExt,
|
||||
UIToFP,
|
||||
SIToFP,
|
||||
FPToUI,
|
||||
FPToSI,
|
||||
PtrToInt,
|
||||
IntToPtr,
|
||||
BitCast,
|
||||
AddrSpaceCast
|
||||
};
|
||||
|
||||
enum cmp_pred_t: unsigned int {
|
||||
FIRST_FCMP_PREDICATE,
|
||||
FCMP_FALSE,
|
||||
FCMP_OEQ,
|
||||
FCMP_OGT,
|
||||
FCMP_OGE,
|
||||
FCMP_OLT,
|
||||
FCMP_OLE,
|
||||
FCMP_ONE,
|
||||
FCMP_ORD,
|
||||
FCMP_UNO,
|
||||
FCMP_UEQ,
|
||||
FCMP_UGT,
|
||||
FCMP_UGE,
|
||||
FCMP_ULT,
|
||||
FCMP_ULE,
|
||||
FCMP_UNE,
|
||||
FCMP_TRUE,
|
||||
LAST_FCMP_PREDICATE,
|
||||
FIRST_ICMP_PREDICATE,
|
||||
ICMP_EQ,
|
||||
ICMP_NE,
|
||||
ICMP_UGT,
|
||||
ICMP_UGE,
|
||||
ICMP_ULT,
|
||||
ICMP_ULE,
|
||||
ICMP_SGT,
|
||||
ICMP_SGE,
|
||||
ICMP_SLT,
|
||||
ICMP_SLE,
|
||||
LAST_ICMP_PREDICATE
|
||||
};
|
||||
|
||||
enum value_id_t: unsigned {
|
||||
/* ------------ *
|
||||
INSTRUCTIONS
|
||||
* ------------ */
|
||||
INST_BEGIN,
|
||||
// phi
|
||||
INST_PHI,
|
||||
// arithmetic
|
||||
INST_BINOP,
|
||||
INST_GETELEMENTPTR,
|
||||
INST_SELECT,
|
||||
INST_SQRT,
|
||||
// cmp
|
||||
INST_ICMP,
|
||||
INST_FCMP,
|
||||
// cast
|
||||
INST_CAST_TRUNC,
|
||||
INST_CAST_ZEXT,
|
||||
INST_CAST_SEXT,
|
||||
INST_CAST_FP_TRUNC,
|
||||
INST_CAST_FP_EXT,
|
||||
INST_CAST_UI_TO_FP,
|
||||
INST_CAST_SI_TO_FP,
|
||||
INST_CAST_FP_TO_UI,
|
||||
INST_CAST_FP_TO_SI,
|
||||
INST_CAST_PTR_TO_INT,
|
||||
INST_CAST_INT_TO_PTR,
|
||||
INST_CAST_BIT_CAST,
|
||||
INST_CAST_ADDR_SPACE_CAST,
|
||||
// terminators
|
||||
INST_RETURN,
|
||||
INST_COND_BRANCH,
|
||||
INST_UNCOND_BRANCH,
|
||||
// io
|
||||
INST_UNMASKED_LOAD,
|
||||
INST_MASKED_LOAD,
|
||||
INST_MASKED_LOAD_ASYNC,
|
||||
INST_UNMASKED_STORE,
|
||||
INST_MASKED_STORE,
|
||||
// retile
|
||||
INST_RESHAPE,
|
||||
INST_SPLAT,
|
||||
INST_CAT,
|
||||
INST_BROADCAST,
|
||||
INST_DOWNCAST,
|
||||
// builtin
|
||||
INST_GET_PROGRAM_ID,
|
||||
INST_GET_NUM_PROGRAMS,
|
||||
// atomics
|
||||
INST_ATOMIC_CAS,
|
||||
INST_ATOMIC_EXCH,
|
||||
INST_ATOMIC_RMW,
|
||||
// math
|
||||
INST_UMULHI,
|
||||
INST_EXP,
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
INST_LOG,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
INST_REDUCE,
|
||||
INST_DOT,
|
||||
// intrinsics
|
||||
INST_COPY_TO_SHARED,
|
||||
INST_COPY_FROM_SHARED,
|
||||
INST_CVT_LAYOUT,
|
||||
INST_CVT_SCANLINE,
|
||||
INST_DECOALESCE,
|
||||
INST_RECOALESCE,
|
||||
INST_BARRIER,
|
||||
INST_ASYNC_WAIT,
|
||||
INST_MAKE_RANGE_DYN,
|
||||
INST_MAKE_RANGE_STA,
|
||||
INST_MAKE_RANGE,
|
||||
INST_PREFETCH_S,
|
||||
};
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,142 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_FUNCTION_H_
|
||||
#define _TRITON_IR_FUNCTION_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "value.h"
|
||||
#include "constant.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class function;
|
||||
class function_type;
|
||||
class module;
|
||||
class basic_block;
|
||||
|
||||
/* Argument */
|
||||
class argument: public value{
|
||||
argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
|
||||
|
||||
public:
|
||||
static argument* create(type *ty, const std::string &name,
|
||||
function *parent = nullptr, unsigned arg_no = 0);
|
||||
function* get_parent() const;
|
||||
unsigned get_arg_no() const;
|
||||
|
||||
void accept(visitor *v);
|
||||
|
||||
private:
|
||||
function *parent_;
|
||||
unsigned arg_no_;
|
||||
};
|
||||
|
||||
/* Attribute */
|
||||
enum attribute_kind_t {
|
||||
readonly = 0,
|
||||
writeonly,
|
||||
noalias,
|
||||
aligned,
|
||||
multiple_of,
|
||||
retune,
|
||||
not_implemented
|
||||
};
|
||||
|
||||
class attribute {
|
||||
public:
|
||||
attribute(attribute_kind_t kind, unsigned value = 0):
|
||||
kind_(kind), value_(value){}
|
||||
|
||||
bool operator<(const attribute& other) const {
|
||||
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
|
||||
}
|
||||
|
||||
attribute_kind_t get_kind() const {
|
||||
return kind_;
|
||||
}
|
||||
|
||||
unsigned get_value() const {
|
||||
return value_;
|
||||
}
|
||||
|
||||
bool is_llvm_attr() const {
|
||||
return kind_ != multiple_of;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(kind_){
|
||||
case readonly: return ".readonly";
|
||||
case writeonly: return ".writeonly";
|
||||
case noalias: return ".noalias";
|
||||
case aligned: return ".aligned(" + std::to_string(value_) + ")";
|
||||
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
|
||||
case retune: return ".retunr";
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
}
|
||||
|
||||
private:
|
||||
attribute_kind_t kind_;
|
||||
unsigned value_;
|
||||
};
|
||||
|
||||
/* Function */
|
||||
class function: public global_object{
|
||||
typedef std::vector<argument*> args_t;
|
||||
typedef args_t::iterator arg_iterator;
|
||||
typedef args_t::const_iterator const_arg_iterator;
|
||||
|
||||
typedef std::vector<basic_block*> blocks_t;
|
||||
typedef blocks_t::iterator block_iterator;
|
||||
typedef blocks_t::const_iterator const_block_iterator;
|
||||
|
||||
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
|
||||
|
||||
private:
|
||||
function(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name = "", module *parent = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const args_t &args() const { return args_; }
|
||||
function_type* get_fn_type() { return fn_ty_; }
|
||||
const function_type* get_fn_type() const { return fn_ty_; }
|
||||
module *get_parent() { return parent_; }
|
||||
const module *get_parent() const { return parent_; }
|
||||
|
||||
// factory methods
|
||||
static function *create(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *mod);
|
||||
// blocks
|
||||
const blocks_t &blocks() { return blocks_; }
|
||||
const blocks_t &blocks() const { return blocks_; }
|
||||
void insert_block(basic_block* block, basic_block *next = nullptr);
|
||||
|
||||
// attributes
|
||||
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
||||
const attr_map_t &attrs() { return attrs_; }
|
||||
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
|
||||
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
// visitor
|
||||
void accept(visitor *v) { v->visit_function(this); }
|
||||
|
||||
private:
|
||||
module *parent_;
|
||||
bool init_;
|
||||
function_type *fn_ty_;
|
||||
args_t args_;
|
||||
blocks_t blocks_;
|
||||
attr_map_t attrs_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,978 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_INSTRUCTIONS_H_
|
||||
#define _TRITON_IR_INSTRUCTIONS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "triton/ir/visitor.h"
|
||||
|
||||
#define _TRITON_DEFINE_CLONE(name) \
|
||||
ir::instruction* clone_impl() const { return new name(*this); }
|
||||
|
||||
#define _TRITON_DEFINE_ACCEPT(name) \
|
||||
void accept(visitor* v) { v->visit_ ## name (this); }
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class constant_int;
|
||||
class constant;
|
||||
class make_range;
|
||||
class basic_block;
|
||||
class context;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// instruction classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class result_reference;
|
||||
|
||||
|
||||
class instruction: public user{
|
||||
public:
|
||||
virtual std::string repr_impl() const = 0;
|
||||
|
||||
private:
|
||||
virtual ir::instruction* clone_impl() const = 0;
|
||||
|
||||
protected:
|
||||
// constructors
|
||||
instruction(type *ty, value_id_t ity, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// parent
|
||||
void set_parent(basic_block *block) { parent_ = block; }
|
||||
const basic_block *get_parent() const { return parent_; }
|
||||
basic_block *get_parent() { return parent_; }
|
||||
void erase_from_parent();
|
||||
// helpers
|
||||
bool has_tile_result_or_op();
|
||||
// repr
|
||||
std::string repr() const { return repr_impl(); }
|
||||
// metadata
|
||||
void set_metadata(ir::metadata::kind_t kind,
|
||||
unsigned value) { metadatas_[kind] = value;}
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
// cloning
|
||||
ir::instruction* clone() {
|
||||
ir::instruction* res = clone_impl();
|
||||
// for(auto it = op_begin(); it != op_end(); it++)
|
||||
// (*it)->add_use(res);
|
||||
res->parent_ = nullptr;
|
||||
res->users_.clear();
|
||||
return res;
|
||||
}
|
||||
// instruction id
|
||||
value_id_t get_id() const { return id_; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
value_id_t id_;
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class phi_node: public instruction {
|
||||
private:
|
||||
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "phi"; }
|
||||
|
||||
public:
|
||||
void set_incoming_value(unsigned i, value *v);
|
||||
void set_incoming_block(unsigned i, basic_block *block);
|
||||
value *get_value_for_block(basic_block *block);
|
||||
value *get_incoming_value(unsigned i) { return get_operand(i); }
|
||||
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
|
||||
unsigned get_num_incoming() { return get_num_operands(); }
|
||||
void add_incoming(value *v, basic_block *block);
|
||||
|
||||
// Type
|
||||
void set_type(type *ty) { ty_ = ty; }
|
||||
|
||||
// Factory methods
|
||||
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(phi_node)
|
||||
_TRITON_DEFINE_ACCEPT(phi_node)
|
||||
|
||||
private:
|
||||
unsigned num_reserved_;
|
||||
std::vector<basic_block*> blocks_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class binary_operator: public instruction {
|
||||
public:
|
||||
typedef binary_op_t op_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
// Constructors
|
||||
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// Get operand
|
||||
binary_op_t get_op() const { return op_; }
|
||||
|
||||
// Bool
|
||||
bool is_terminator() const;
|
||||
bool is_binary_op() const;
|
||||
bool is_int_div_rem() const;
|
||||
bool is_shift() const;
|
||||
bool is_cast() const;
|
||||
bool is_int_mult() const;
|
||||
bool is_int_add_sub() const;
|
||||
bool is_int_div() const;
|
||||
bool is_int_rem() const;
|
||||
bool is_shl() const;
|
||||
bool is_shr() const;
|
||||
|
||||
// Approx
|
||||
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
|
||||
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
|
||||
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
|
||||
// Factory methods
|
||||
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(binary_operator)
|
||||
_TRITON_DEFINE_ACCEPT(binary_operator)
|
||||
|
||||
public:
|
||||
binary_op_t op_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
|
||||
bool fdiv_ieee_rnd_;
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cmp_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cmp_inst: public instruction{
|
||||
public:
|
||||
typedef cmp_pred_t pred_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
static bool is_fp_predicate(cmp_pred_t pred);
|
||||
static bool is_int_predicate(cmp_pred_t pred);
|
||||
static type* make_cmp_result_type(type *ty);
|
||||
|
||||
public:
|
||||
cmp_pred_t get_pred() const { return pred_; }
|
||||
|
||||
private:
|
||||
cmp_pred_t pred_;
|
||||
};
|
||||
|
||||
class icmp_inst: public cmp_inst {
|
||||
icmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(icmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(icmp_inst)
|
||||
};
|
||||
|
||||
class fcmp_inst: public cmp_inst {
|
||||
fcmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(fcmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(fcmp_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// unary_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class unary_inst: public instruction {
|
||||
protected:
|
||||
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cast_inst: public unary_inst{
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
|
||||
: unary_inst(ty, id, v, name, next), op_(op) { }
|
||||
|
||||
private:
|
||||
static bool is_valid(cast_op_t op, value *arg, type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
cast_op_t get_op() const { return op_; }
|
||||
|
||||
// factory methods
|
||||
static cast_inst *create(cast_op_t op, value *arg, type *ty,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_ACCEPT(cast_inst)
|
||||
|
||||
private:
|
||||
cast_op_t op_;
|
||||
};
|
||||
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
|
||||
class name : public cast_inst { \
|
||||
_TRITON_DEFINE_CLONE(name) \
|
||||
friend class cast_inst; \
|
||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||
: cast_inst(ty, id, v, name, next, op){ } \
|
||||
};
|
||||
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class terminator_inst: public instruction{
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
// return instruction
|
||||
class return_inst: public terminator_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "ret"; }
|
||||
return_inst(context &ctx, value *ret_val, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_return_value()
|
||||
{ return get_num_operands() ? get_operand(0) : nullptr; }
|
||||
|
||||
unsigned get_num_successors() const { return 0; }
|
||||
|
||||
// factory methods
|
||||
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(return_inst)
|
||||
_TRITON_DEFINE_ACCEPT(return_inst)
|
||||
};
|
||||
|
||||
// base branch instruction
|
||||
class branch_inst: public terminator_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "br"; }
|
||||
|
||||
protected:
|
||||
using terminator_inst::terminator_inst;
|
||||
|
||||
public:
|
||||
static branch_inst* create(basic_block *dest,
|
||||
instruction *next = nullptr);
|
||||
static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
// conditional branch
|
||||
class cond_branch_inst: public branch_inst {
|
||||
private:
|
||||
friend class branch_inst;
|
||||
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
|
||||
|
||||
public:
|
||||
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
|
||||
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
|
||||
value *get_cond() { return get_operand(2); }
|
||||
_TRITON_DEFINE_CLONE(cond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cond_branch_inst)
|
||||
};
|
||||
|
||||
// unconditional branch
|
||||
class uncond_branch_inst: public branch_inst {
|
||||
private:
|
||||
friend class branch_inst;
|
||||
uncond_branch_inst(basic_block *dst, instruction *next);
|
||||
|
||||
public:
|
||||
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
||||
_TRITON_DEFINE_CLONE(uncond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(uncond_branch_inst)
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class getelementptr_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "getelementptr"; }
|
||||
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
|
||||
|
||||
private:
|
||||
static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
|
||||
static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
|
||||
static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
type *get_source_elt_ty() { return source_elt_ty; }
|
||||
op_iterator idx_begin() { return op_begin() + 1; }
|
||||
op_iterator idx_end() { return op_end(); }
|
||||
value *get_pointer_operand() { return *op_begin(); }
|
||||
|
||||
// factory methods
|
||||
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(getelementptr_inst)
|
||||
_TRITON_DEFINE_ACCEPT(getelementptr_inst)
|
||||
|
||||
private:
|
||||
type *source_elt_ty;
|
||||
type *res_elt_ty;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load_inst/store_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class io_inst: public instruction {
|
||||
protected:
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
};
|
||||
|
||||
// load
|
||||
class load_inst: public io_inst {
|
||||
public:
|
||||
enum CACHE_MODIFIER : uint32_t {
|
||||
NONE=0,
|
||||
CA,
|
||||
CG,
|
||||
};
|
||||
|
||||
enum EVICTION_POLICY : uint32_t {
|
||||
NORMAL=0,
|
||||
EVICT_FIRST,
|
||||
EVICT_LAST,
|
||||
};
|
||||
|
||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||
bool get_is_volatile() const { return is_volatile_; }
|
||||
|
||||
protected:
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
std::string get_cache_modifier_repr() const {
|
||||
if (cache_ == CA) return ".ca";
|
||||
if (cache_ == CG) return ".cg";
|
||||
return "";
|
||||
}
|
||||
std::string get_eviction_policy_repr() const {
|
||||
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||
}
|
||||
EVICTION_POLICY eviction_;
|
||||
CACHE_MODIFIER cache_;
|
||||
|
||||
std::string get_volatile_repr() {
|
||||
return is_volatile_ ? ".volatile" : "";
|
||||
}
|
||||
bool is_volatile_;
|
||||
|
||||
private:
|
||||
static type *get_pointee_type(type *ty);
|
||||
};
|
||||
|
||||
// unmasked load
|
||||
class unmasked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static unmasked_load_inst* create(value *ptr,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_load_inst)
|
||||
};
|
||||
|
||||
// masked load
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(1); }
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_inst)
|
||||
};
|
||||
|
||||
// masked load async
|
||||
class masked_load_async_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(1); }
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
|
||||
};
|
||||
|
||||
|
||||
|
||||
// store
|
||||
class store_inst: public io_inst {
|
||||
protected:
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
};
|
||||
|
||||
// unmasked_store
|
||||
class unmasked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_store"; }
|
||||
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// factory method
|
||||
static unmasked_store_inst* create(value* ptr, value *v,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_store_inst)
|
||||
};
|
||||
|
||||
class masked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_store"; }
|
||||
masked_store_inst(value *ptr, value *v, value *mask,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_store_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
class cat_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "cat"; }
|
||||
cat_inst(value *x, value *y, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cat_inst)
|
||||
};
|
||||
|
||||
// retile
|
||||
|
||||
class retile_inst: public unary_inst {
|
||||
protected:
|
||||
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
// reshape
|
||||
|
||||
class reshape_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "reshape"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(reshape_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reshape_inst)
|
||||
};
|
||||
|
||||
// splat
|
||||
|
||||
class splat_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "splat"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(splat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(splat_inst)
|
||||
};
|
||||
|
||||
// broadcast
|
||||
|
||||
class broadcast_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "broadcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(broadcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(broadcast_inst)
|
||||
};
|
||||
|
||||
|
||||
// downcast
|
||||
|
||||
class downcast_inst: public unary_inst {
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "downcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(downcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(downcast_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// builtin_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class builtin_inst: public instruction{
|
||||
protected:
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
class get_program_id_inst: public builtin_inst {
|
||||
private:
|
||||
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_program_id_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_program_id_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class get_num_programs_inst: public builtin_inst {
|
||||
private:
|
||||
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_num_programs_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
|
||||
class atomic_inst: public io_inst {
|
||||
public:
|
||||
using io_inst::io_inst;
|
||||
};
|
||||
|
||||
class atomic_rmw_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_rmw"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
atomic_rmw_op_t get_op() { return op_; }
|
||||
|
||||
private:
|
||||
atomic_rmw_op_t op_;
|
||||
};
|
||||
|
||||
class atomic_cas_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "atomic_cas"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_cas_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_cas_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class umulhi_inst: public builtin_inst {
|
||||
private:
|
||||
umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "umulhi"; }
|
||||
_TRITON_DEFINE_CLONE(umulhi_inst)
|
||||
_TRITON_DEFINE_ACCEPT(umulhi_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class exp_inst: public builtin_inst {
|
||||
private:
|
||||
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "exp"; }
|
||||
_TRITON_DEFINE_CLONE(exp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(exp_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class cos_inst: public builtin_inst {
|
||||
private:
|
||||
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "cos"; }
|
||||
_TRITON_DEFINE_CLONE(cos_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cos_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class sin_inst: public builtin_inst {
|
||||
private:
|
||||
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "sin"; }
|
||||
_TRITON_DEFINE_CLONE(sin_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sin_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class log_inst: public builtin_inst {
|
||||
private:
|
||||
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "log"; }
|
||||
_TRITON_DEFINE_CLONE(log_inst)
|
||||
_TRITON_DEFINE_ACCEPT(log_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
|
||||
class dot_inst: public builtin_inst {
|
||||
public:
|
||||
enum TransT { NoTrans, Trans };
|
||||
enum DataType {
|
||||
FP8, FP16, BF16, TF32, FP32,
|
||||
INT1, INT4, INT8, INT32,
|
||||
UNKNOWN,
|
||||
};
|
||||
|
||||
private:
|
||||
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "dot"; }
|
||||
|
||||
public:
|
||||
bool is_prefetched() const { return is_prefetched_; }
|
||||
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
|
||||
bool allow_tf32() const { return allow_tf32_; }
|
||||
|
||||
public:
|
||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(dot_inst)
|
||||
_TRITON_DEFINE_ACCEPT(dot_inst)
|
||||
|
||||
private:
|
||||
bool is_prefetched_ = false;
|
||||
bool allow_tf32_ = false;
|
||||
DataType C_type_ = DataType::FP32;
|
||||
DataType A_type_ = DataType::FP16;
|
||||
DataType B_type_ = DataType::FP16;
|
||||
};
|
||||
|
||||
//class outer_inst: public builtin_inst {
|
||||
//private:
|
||||
// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
|
||||
//public:
|
||||
// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
//};
|
||||
|
||||
class trans_inst: public builtin_inst {
|
||||
public:
|
||||
ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
|
||||
std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
|
||||
|
||||
private:
|
||||
trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "trans"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||
const std::vector<int> get_perm() const;
|
||||
_TRITON_DEFINE_CLONE(trans_inst)
|
||||
_TRITON_DEFINE_ACCEPT(trans_inst)
|
||||
|
||||
private:
|
||||
std::vector<int> perm_;
|
||||
};
|
||||
|
||||
class sqrt_inst: public builtin_inst {
|
||||
private:
|
||||
sqrt_inst(value *arg, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "sqrt"; }
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(sqrt_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sqrt_inst)
|
||||
};
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
public:
|
||||
enum op_t{
|
||||
ADD, SUB, MAX, MIN,
|
||||
FADD, FSUB, FMAX, FMIN,
|
||||
XOR
|
||||
};
|
||||
|
||||
private:
|
||||
static type* get_res_type(value *arg, unsigned axis);
|
||||
static std::string to_str(op_t op);
|
||||
|
||||
private:
|
||||
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "reduce"; }
|
||||
_TRITON_DEFINE_CLONE(reduce_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reduce_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
op_t get_op() const { return op_; }
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
op_t op_;
|
||||
};
|
||||
|
||||
class select_inst: public builtin_inst {
|
||||
private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "select"; }
|
||||
_TRITON_DEFINE_CLONE(select_inst)
|
||||
_TRITON_DEFINE_ACCEPT(select_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
|
||||
value* get_pred_op() { return get_operand(0); }
|
||||
value* get_if_value_op() { return get_operand(1); }
|
||||
value* get_else_value_op() { return get_operand(2); }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsics classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
class copy_to_shared_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "copy_to_shared"; }
|
||||
|
||||
public:
|
||||
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
|
||||
};
|
||||
|
||||
class copy_from_shared_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "copy_from_shared"; }
|
||||
|
||||
public:
|
||||
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
|
||||
};
|
||||
|
||||
class cvt_layout_inst: public unary_inst {
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "cvt_layout_inst"; }
|
||||
|
||||
public:
|
||||
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cvt_layout_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
|
||||
};
|
||||
|
||||
class barrier_inst: public instruction{
|
||||
private:
|
||||
barrier_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "barrier"; }
|
||||
_TRITON_DEFINE_CLONE(barrier_inst)
|
||||
_TRITON_DEFINE_ACCEPT(barrier_inst)
|
||||
|
||||
public:
|
||||
static barrier_inst* create(context &ctx, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class async_wait_inst: public instruction{
|
||||
private:
|
||||
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
|
||||
_TRITON_DEFINE_CLONE(async_wait_inst)
|
||||
_TRITON_DEFINE_ACCEPT(async_wait_inst)
|
||||
|
||||
public:
|
||||
static async_wait_inst* create(context &ctx, int N,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
int get_N() { return N_; }
|
||||
void set_N(int n) { N_ = n; }
|
||||
|
||||
private:
|
||||
int N_;
|
||||
};
|
||||
|
||||
class prefetch_s_inst : public instruction {
|
||||
std::string repr_impl() const { return "prefetch_s"; }
|
||||
_TRITON_DEFINE_CLONE(prefetch_s_inst)
|
||||
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
|
||||
|
||||
/// inc_: 0->first, 1->latch
|
||||
int inc_ = 0;
|
||||
public:
|
||||
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
|
||||
set_operand(0, arg);
|
||||
}
|
||||
int get_inc() const { return inc_; }
|
||||
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
|
||||
instruction *next=nullptr);
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
class make_range: public instruction{
|
||||
make_range(type *ty, constant_int* first, constant_int* last);
|
||||
std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
|
||||
_TRITON_DEFINE_CLONE(make_range)
|
||||
_TRITON_DEFINE_ACCEPT(make_range)
|
||||
|
||||
public:
|
||||
static make_range *create(constant_int *first, constant_int *last);
|
||||
const constant_int* get_first() const;
|
||||
const constant_int* get_last() const;
|
||||
|
||||
private:
|
||||
constant_int* first_;
|
||||
constant_int* last_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,32 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_METADATA_H_
|
||||
#define _TRITON_IR_METADATA_H_
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/* Metadata */
|
||||
class metadata{
|
||||
public:
|
||||
enum kind_t{
|
||||
multiple_of,
|
||||
max_contiguous
|
||||
};
|
||||
|
||||
private:
|
||||
metadata(kind_t kind, unsigned value);
|
||||
|
||||
public:
|
||||
static metadata* get(kind_t kind, unsigned value);
|
||||
|
||||
private:
|
||||
kind_t kind_;
|
||||
unsigned value_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,92 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_MODULE_H_
|
||||
#define _TRITON_IR_MODULE_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "triton/ir/context.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace lang{
|
||||
|
||||
class iteration_statement;
|
||||
class compound_statement;
|
||||
|
||||
}
|
||||
|
||||
namespace ir{
|
||||
|
||||
class basic_block;
|
||||
class phi_node;
|
||||
class value;
|
||||
class context;
|
||||
class function;
|
||||
class attribute;
|
||||
class function_type;
|
||||
class constant;
|
||||
class global_value;
|
||||
class alloc_const;
|
||||
|
||||
/* Module */
|
||||
|
||||
class module {
|
||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||
friend class function;
|
||||
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
||||
|
||||
public:
|
||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||
typedef std::vector<function*> functions_list_t;
|
||||
struct current_iteration_info_t{
|
||||
lang::iteration_statement *statement;
|
||||
basic_block *block;
|
||||
};
|
||||
|
||||
private:
|
||||
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
|
||||
value *try_remove_trivial_phis(ir::phi_node *&phi);
|
||||
value *add_phi_operands(const std::string& name, phi_node *&phi);
|
||||
value *get_value_recursive(const std::string& name, basic_block *block);
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
|
||||
builder &get_builder() { return builder_; };
|
||||
const std::string& get_name() { return name_; };
|
||||
|
||||
// Functions
|
||||
const functions_list_t &get_function_list() const { return functions_; }
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||
// Const allocation
|
||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||
// Register global
|
||||
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
// Metadata
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
builder &builder_;
|
||||
functions_list_t functions_;
|
||||
symbols_map_t symbols_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
std::map<std::string, md_pair_t> metadatas_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,22 +0,0 @@
|
||||
#ifndef _TRITON_IR_PRINT_H_
|
||||
#define _TRITON_IR_PRINT_H_
|
||||
|
||||
#include "builder.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class module;
|
||||
class function;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
|
||||
void print(module &mod, std::ostream& os);
|
||||
void print(function &func, std::ostream& os);
|
||||
void print(basic_block &bb, std::ostream& os);
|
||||
void print(instruction &instr, std::ostream& os);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,239 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_TYPE_H_
|
||||
#define _TRITON_IR_TYPE_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class value;
|
||||
class integer_type;
|
||||
class constant_int;
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
typedef std::vector<unsigned> block_shapes_t;
|
||||
|
||||
protected:
|
||||
typedef std::vector<type*> contained_tys_vec_t;
|
||||
typedef contained_tys_vec_t::iterator ty_iterator;
|
||||
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
|
||||
|
||||
public:
|
||||
enum id_t {
|
||||
// primitive types
|
||||
VoidTyID = 0, ///< type with no size
|
||||
FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
|
||||
FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
|
||||
BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
|
||||
FP32TyID, ///< 32-bit floating point type
|
||||
FP64TyID, ///< 64-bit floating point type
|
||||
LabelTyID, ///< Labels
|
||||
MetadataTyID, ///< Metadata
|
||||
TokenTyID, ///< Token
|
||||
// derived types
|
||||
IntegerTyID, ///< Arbitrary bit width integers
|
||||
FunctionTyID, ///< Functions
|
||||
PointerTyID, ///< Pointers
|
||||
StructTyID, ///< Struct
|
||||
BlockTyID, ///< Block
|
||||
};
|
||||
|
||||
public:
|
||||
//constructors
|
||||
type(context &ctx, id_t id) : ctx_(ctx), id_(id) { }
|
||||
|
||||
//destructor
|
||||
virtual ~type(){}
|
||||
|
||||
// accessors
|
||||
context &get_context() const { return ctx_; }
|
||||
id_t get_type_id() const { return id_; }
|
||||
// type attributes
|
||||
unsigned get_fp_mantissa_width() const;
|
||||
unsigned get_integer_bitwidth() const;
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
block_shapes_t get_block_shapes() const;
|
||||
const size_t get_tile_rank() const;
|
||||
const size_t get_tile_ranks1() const;
|
||||
unsigned get_tile_num_elements() const;
|
||||
type *get_tile_element_ty() const;
|
||||
unsigned get_pointer_address_space() const;
|
||||
type *get_pointer_element_ty() const;
|
||||
|
||||
// primitive predicates
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||
bool is_fp16_ty() const { return id_ == FP16TyID; }
|
||||
bool is_bf16_ty() const { return id_ == BF16TyID; }
|
||||
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
|
||||
// Composite predicates
|
||||
bool is_int_or_tileint_ty();
|
||||
bool is_integer_ty(unsigned width) const;
|
||||
bool is_floating_point_ty() const;
|
||||
bool is_sized() const ;
|
||||
|
||||
// Factory methods
|
||||
// primitive types
|
||||
static type *get_void_ty(context &ctx);
|
||||
static type *get_label_ty(context &ctx);
|
||||
// half
|
||||
static type *get_fp8_ty(context &ctx);
|
||||
static type *get_fp16_ty(context &ctx);
|
||||
static type *get_bf16_ty(context &ctx);
|
||||
static type *get_fp32_ty(context &ctx);
|
||||
static type *get_fp64_ty(context &ctx);
|
||||
// integer types
|
||||
static integer_type *get_int1_ty(context &ctx);
|
||||
static integer_type *get_int8_ty(context &ctx);
|
||||
static integer_type *get_int16_ty(context &ctx);
|
||||
static integer_type *get_int32_ty(context &ctx);
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
std::string res = get_tile_element_ty()->repr();
|
||||
auto shapes = get_block_shapes();
|
||||
res += "<";
|
||||
for(size_t i = 0; i < shapes.size(); i++){
|
||||
if(i > 0)
|
||||
res += ", ";
|
||||
res += std::to_string(shapes[i]);
|
||||
}
|
||||
res+= ">";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(id_) {
|
||||
case VoidTyID: return "void";
|
||||
case FP8TyID: return "fp8";
|
||||
case FP16TyID: return "f16";
|
||||
case FP32TyID: return "f32";
|
||||
case FP64TyID: return "f64";
|
||||
case BF16TyID: return "bf16";
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
case BlockTyID: return tile_repr();
|
||||
default: break;
|
||||
}
|
||||
throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
|
||||
};
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
id_t id_;
|
||||
|
||||
protected:
|
||||
contained_tys_vec_t contained_tys_;
|
||||
};
|
||||
|
||||
class integer_type: public type {
|
||||
friend class context_impl;
|
||||
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_bitwidth() const { return bitwidth_; }
|
||||
|
||||
// factory methods
|
||||
static integer_type* get(context &ctx, unsigned width);
|
||||
|
||||
private:
|
||||
unsigned bitwidth_;
|
||||
};
|
||||
|
||||
class composite_type: public type{
|
||||
protected:
|
||||
using type::type;
|
||||
|
||||
public:
|
||||
bool index_valid(value *idx) const;
|
||||
type* get_type_at_index(value *idx) const;
|
||||
};
|
||||
|
||||
class block_type: public composite_type {
|
||||
private:
|
||||
block_type(type *ty, const block_shapes_t &shapes);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const block_shapes_t& get_shapes() const { return shapes_; }
|
||||
unsigned get_num_elements() const;
|
||||
unsigned get_bitwidth() const;
|
||||
|
||||
// factory methods
|
||||
static block_type* get(type *ty, const block_shapes_t &shapes);
|
||||
static block_type* get_same_shapes(type *ty, type *ref);
|
||||
|
||||
private:
|
||||
block_shapes_t shapes_;
|
||||
};
|
||||
|
||||
class pointer_type: public type {
|
||||
private:
|
||||
pointer_type(type *ty, unsigned address_space);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_address_space() const { return address_space_; }
|
||||
type *get_element_ty() const { return contained_tys_[0]; }
|
||||
// factory methods
|
||||
static pointer_type* get(type *ty, unsigned address_space);
|
||||
|
||||
private:
|
||||
unsigned address_space_;
|
||||
};
|
||||
|
||||
class function_type: public type {
|
||||
private:
|
||||
function_type(type *ret_ty, const std::vector<type *> ¶m_tys);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_num_params() const { return contained_tys_.size() - 1; }
|
||||
const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; }
|
||||
const_ty_iterator params_end() const { return contained_tys_.end(); }
|
||||
ty_iterator params_begin() { return contained_tys_.begin() + 1; }
|
||||
ty_iterator params_end() { return contained_tys_.end(); }
|
||||
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
|
||||
type* get_return_ty() const { return contained_tys_.at(0); }
|
||||
// factory methods
|
||||
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,30 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CFG_H_
|
||||
#define _TRITON_IR_CFG_H_
|
||||
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class module;
|
||||
class function;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class value;
|
||||
|
||||
class cfg {
|
||||
public:
|
||||
static std::vector<basic_block *> post_order(function* fn);
|
||||
static std::vector<basic_block *> reverse_post_order(function* fn);
|
||||
};
|
||||
|
||||
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
|
||||
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,95 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_VALUE_H_
|
||||
#define _TRITON_IR_VALUE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class use;
|
||||
class user;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// value class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class value {
|
||||
public:
|
||||
typedef std::set<user*> users_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
value(type *ty, const std::string &name = "");
|
||||
virtual ~value(){ }
|
||||
// uses
|
||||
void add_use(user* arg);
|
||||
users_t::iterator erase_use(user* arg);
|
||||
const std::set<user*> &get_users() { return users_; }
|
||||
void replace_all_uses_with(value *target);
|
||||
// name
|
||||
void set_name(const std::string &name);
|
||||
const std::string &get_name() const { return name_; }
|
||||
bool has_name() const { return !name_.empty(); }
|
||||
type* get_type() const { return ty_; }
|
||||
// visitor
|
||||
virtual void accept(visitor *v) = 0;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
|
||||
protected:
|
||||
type *ty_;
|
||||
users_t users_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// user class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class user: public value{
|
||||
public:
|
||||
typedef std::vector<value*> ops_t;
|
||||
typedef ops_t::iterator op_iterator;
|
||||
typedef ops_t::const_iterator const_op_iterator;
|
||||
|
||||
protected:
|
||||
void resize_ops(unsigned num_ops) { ops_.resize(num_ops + num_hidden_); num_ops_ = num_ops; }
|
||||
void resize_hidden(unsigned num_hidden) { ops_.resize(num_ops_ + num_hidden); num_hidden_ = num_hidden; }
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
user(type *ty, unsigned num_ops, const std::string &name = "")
|
||||
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
|
||||
}
|
||||
virtual ~user() { }
|
||||
|
||||
// Operands
|
||||
const ops_t& ops() { return ops_; }
|
||||
const ops_t& ops() const { return ops_; }
|
||||
op_iterator op_begin() { return ops_.begin(); }
|
||||
op_iterator op_end() { return ops_.end(); }
|
||||
void set_operand(unsigned i, value *x);
|
||||
value *get_operand(unsigned i) const;
|
||||
unsigned get_num_operands() const ;
|
||||
unsigned get_num_hidden() const;
|
||||
|
||||
// Utils
|
||||
value::users_t::iterator replace_uses_of_with(value *before, value *after);
|
||||
|
||||
|
||||
private:
|
||||
ops_t ops_;
|
||||
unsigned num_ops_;
|
||||
unsigned num_hidden_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,170 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_VISITOR_H_
|
||||
#define _TRITON_IR_VISITOR_H_
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class value;
|
||||
|
||||
class instruction;
|
||||
|
||||
class phi_node;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
|
||||
class icmp_inst;
|
||||
class fcmp_inst;
|
||||
class cast_inst;
|
||||
class trunc_inst;
|
||||
class z_ext_inst;
|
||||
class s_ext_inst;
|
||||
class fp_trunc_inst;
|
||||
class fp_ext_inst;
|
||||
class ui_to_fp_inst;
|
||||
class si_to_fp_inst;
|
||||
class fp_to_ui_inst;
|
||||
class fp_to_si_inst;
|
||||
class ptr_to_int_inst;
|
||||
class int_to_ptr_inst;
|
||||
class bit_cast_inst;
|
||||
class addr_space_cast_inst;
|
||||
|
||||
class return_inst;
|
||||
class cond_branch_inst;
|
||||
class uncond_branch_inst;
|
||||
|
||||
|
||||
class unmasked_load_inst;
|
||||
class masked_load_inst;
|
||||
class unmasked_store_inst;
|
||||
class masked_store_inst;
|
||||
|
||||
class retile_inst;
|
||||
class reshape_inst;
|
||||
class splat_inst;
|
||||
class cat_inst;
|
||||
class broadcast_inst;
|
||||
class downcast_inst;
|
||||
|
||||
class umulhi_inst;
|
||||
class exp_inst;
|
||||
class cos_inst;
|
||||
class sin_inst;
|
||||
class log_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
class get_num_programs_inst;
|
||||
class atomic_inst;
|
||||
class atomic_cas_inst;
|
||||
class atomic_rmw_inst;
|
||||
class dot_inst;
|
||||
class trans_inst;
|
||||
class sqrt_inst;
|
||||
class reduce_inst;
|
||||
class select_inst;
|
||||
|
||||
class cvt_layout_inst;
|
||||
class copy_to_shared_inst;
|
||||
class copy_from_shared_inst;
|
||||
class masked_load_async_inst;
|
||||
class barrier_inst;
|
||||
class async_wait_inst;
|
||||
class make_range_dyn;
|
||||
class make_range;
|
||||
class prefetch_s_inst;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class function;
|
||||
|
||||
class basic_block;
|
||||
|
||||
class argument;
|
||||
|
||||
class visitor {
|
||||
public:
|
||||
virtual ~visitor() {}
|
||||
|
||||
virtual void visit_value(ir::value*);
|
||||
|
||||
virtual void visit_basic_block(basic_block*) = 0;
|
||||
virtual void visit_argument(argument*) = 0;
|
||||
virtual void visit_phi_node(phi_node*) = 0;
|
||||
virtual void visit_binary_operator(binary_operator*) = 0;
|
||||
virtual void visit_getelementptr_inst(getelementptr_inst*) = 0;
|
||||
|
||||
virtual void visit_icmp_inst(icmp_inst*) = 0;
|
||||
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
|
||||
virtual void visit_cast_inst(cast_inst*) = 0;
|
||||
|
||||
virtual void visit_return_inst(return_inst*) = 0;
|
||||
virtual void visit_cond_branch_inst(cond_branch_inst*) = 0;
|
||||
virtual void visit_uncond_branch_inst(uncond_branch_inst*) = 0;
|
||||
|
||||
|
||||
virtual void visit_unmasked_load_inst(unmasked_load_inst*) = 0;
|
||||
virtual void visit_masked_load_inst(masked_load_inst*) = 0;
|
||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_umulhi_inst(umulhi_inst*) = 0;
|
||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||
virtual void visit_cos_inst(cos_inst*) = 0;
|
||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||
virtual void visit_log_inst(log_inst*) = 0;
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||
virtual void visit_cat_inst(cat_inst*) = 0;
|
||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||
virtual void visit_downcast_inst(downcast_inst*) = 0;
|
||||
|
||||
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
|
||||
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
||||
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
||||
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
|
||||
virtual void visit_dot_inst(dot_inst*) = 0;
|
||||
virtual void visit_trans_inst(trans_inst*) = 0;
|
||||
virtual void visit_sqrt_inst(sqrt_inst*) = 0;
|
||||
virtual void visit_reduce_inst(reduce_inst*) = 0;
|
||||
virtual void visit_select_inst(select_inst*) = 0;
|
||||
|
||||
virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
|
||||
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
|
||||
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
|
||||
|
||||
|
||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||
virtual void visit_function(function*) = 0;
|
||||
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
virtual void visit_alloc_const(alloc_const*) = 0;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
Reference in New Issue
Block a user