Init commit
This commit is contained in:
@@ -152,6 +152,29 @@ else()
|
||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
endif()
|
||||
|
||||
# MLIR
|
||||
find_package(MLIR 14 REQUIRED CONFIG)
|
||||
include(TableGen) # required by AddMLIR
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
include(HandleLLVMOptions) # human-friendly error message
|
||||
|
||||
include_directories(${MLIR_INCLUDE_DIRS})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
|
||||
|
||||
add_subdirectory(lib)
|
||||
# lib
|
||||
add_library(triton)
|
||||
# add_subdirectory(ir)
|
||||
target_link_libraries(triton
|
||||
PUBLIC
|
||||
TRITONIR
|
||||
# # optimizations
|
||||
# MLIRPass
|
||||
# MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
|
||||
if(WIN32)
|
||||
|
1
include/CMakeLists.txt
Normal file
1
include/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(ir)
|
8
include/triton/ir/CMakeLists.txt
Normal file
8
include/triton/ir/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(TritonTableGen)
|
18
include/triton/ir/Dialect.h
Normal file
18
include/triton/ir/Dialect.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef TRITON_IR_DIALECT_H_
|
||||
#define TRITON_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
|
||||
#include "triton/Dialect.h.inc"
|
||||
|
||||
#include "triton/OpsEnums.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Ops.h.inc"
|
||||
|
||||
#endif // TRITON_IR_DIALECT_H_
|
39
include/triton/ir/TritonDialect.td
Normal file
39
include/triton/ir/TritonDialect.td
Normal file
@@ -0,0 +1,39 @@
|
||||
#ifndef TRITON_DIALECT
|
||||
#define TRITON_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Triton_Dialect : Dialect {
|
||||
let name = "triton";
|
||||
|
||||
let cppNamespace = "::mlir::triton";
|
||||
|
||||
let summary = "The Triton IR in MLIR";
|
||||
|
||||
let description = [{
|
||||
Triton Dialect.
|
||||
|
||||
Dependent Dialects:
|
||||
* Arithmetic:
|
||||
* addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
|
||||
* Tensor:
|
||||
* reshape (?)
|
||||
* ControlFlow:
|
||||
* bf, cond_bf
|
||||
* Func:
|
||||
* call, return
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"tensor::TensorDialect",
|
||||
"cf::ControlFlowDialect",
|
||||
"func::FuncDialect"
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerTypes();
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TRITON_DIALECT
|
235
include/triton/ir/TritonOps.td
Normal file
235
include/triton/ir/TritonOps.td
Normal file
@@ -0,0 +1,235 @@
|
||||
#ifndef Triton_OPS
|
||||
#define Triton_OPS
|
||||
|
||||
include "TritonDialect.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
// FloatType
|
||||
def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
|
||||
/*descr*/"8bit float",
|
||||
/*cppClassName*/"::mlir::triton::Float8Type">;
|
||||
|
||||
def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
|
||||
/*descr*/"8bit bfloat",
|
||||
/*cppClassName*/"::mlir::triton::BFloat8Type">;
|
||||
|
||||
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
|
||||
// IntegerType
|
||||
def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">;
|
||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||
|
||||
// PointerType
|
||||
def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
|
||||
def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
|
||||
def TT_PtrTensor : TensorOf<[TT_AnyPtr]>;
|
||||
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
|
||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
|
||||
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
||||
TT_AnyPtr, TT_PtrTensor]>;
|
||||
|
||||
//
|
||||
// Op Base
|
||||
//
|
||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Triton_Dialect, mnemonic, traits>;
|
||||
|
||||
//
|
||||
// CastOps
|
||||
//
|
||||
// Use cast ops in arith:
|
||||
// bitcast
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// extf, tructf,
|
||||
// extui, extsi, tructi
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast int64 to pointer";
|
||||
|
||||
let arguments = (ins I64Tensor:$from);
|
||||
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast pointer to int64";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$from);
|
||||
|
||||
let results = (outs I64Tensor:$result);
|
||||
}
|
||||
|
||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Floating point casting for custom types";
|
||||
|
||||
let description = [{
|
||||
Floating point casting for custom types (F8, BF8).
|
||||
|
||||
F8 <-> BF8, FP16, FP32
|
||||
BF8 <-> F8, FP16, FP32
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FloatTensor:$from);
|
||||
|
||||
let results = (outs TT_FloatTensor:$result);
|
||||
|
||||
// TODO: We need a verifier here.
|
||||
}
|
||||
|
||||
//
|
||||
// Load/Store Ops
|
||||
//
|
||||
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
||||
let summary = "load";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
// for args with default values
|
||||
OpBuilder<(ins "Value":$ptr)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask)>
|
||||
];
|
||||
}
|
||||
|
||||
def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
|
||||
let summary = "store";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
|
||||
|
||||
let builders = [
|
||||
// for args with default values
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Shape Manipulation Ops
|
||||
//
|
||||
// def TT_CatOp : TT_Op<"cat", []>;
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
//
|
||||
// builtin Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "dot";
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
}
|
||||
|
||||
// reduction
|
||||
def TT_RedOpAttr : I32EnumAttr<
|
||||
/*name*/"RedOp", /*summary*/"",
|
||||
/*case*/
|
||||
[
|
||||
I32EnumAttrCase</*sym*/"SUM", 1, /*str*/"sum">,
|
||||
I32EnumAttrCase<"MAX", 2, "max">,
|
||||
I32EnumAttrCase<"MIN", 3, "min">,
|
||||
I32EnumAttrCase<"XOR_SUM", 4, "xor_sum">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_ReduceOp : TT_Op<"reduce"> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis);
|
||||
}
|
||||
|
||||
// atomic
|
||||
def TT_AtomicRMWAttr : I32EnumAttr<
|
||||
"RMWOp", "",
|
||||
[
|
||||
I32EnumAttrCase<"AND", 1, "and">,
|
||||
I32EnumAttrCase<"OR", 2, "or">,
|
||||
I32EnumAttrCase<"XOR", 3, "xor">,
|
||||
I32EnumAttrCase<"ADD", 4, "add">,
|
||||
I32EnumAttrCase<"MAX", 5, "max">,
|
||||
I32EnumAttrCase<"MIN", 6, "min">,
|
||||
I32EnumAttrCase<"UMAX", 7, "umax">,
|
||||
I32EnumAttrCase<"UMIN", 8, "umin">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
let summary = "atomic rmw";
|
||||
|
||||
let description = [{
|
||||
load data at $ptr, do $rmw_op with $val, and store result to $ptr.
|
||||
|
||||
return old value at $ptr
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
|
||||
TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
let summary = "atomic cas";
|
||||
|
||||
let description = [{
|
||||
compare $cmp with data $old at location $ptr,
|
||||
|
||||
if $old == $cmp, store $val to $ptr,
|
||||
|
||||
else store $old to $ptr,
|
||||
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
//
|
||||
// Intrinsics
|
||||
//
|
||||
// TODO: should have ConstantLike as Trait
|
||||
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
let summary = "make range";
|
||||
|
||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||
|
||||
let results = (outs TT_IntegerTensor:$result);
|
||||
}
|
||||
|
||||
#endif // Triton_OPS
|
46
include/triton/ir/Types.h
Normal file
46
include/triton/ir/Types.h
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef TRITON_IR_TYPES_H_
|
||||
#define TRITON_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
namespace detail {
|
||||
struct PointerTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
// TODO: Should be base class be FloatType?
|
||||
class Float8Type : public Type::TypeBase<Float8Type, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static Float8Type get(MLIRContext *context);
|
||||
};
|
||||
|
||||
class BFloat8Type : public Type::TypeBase<BFloat8Type, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static BFloat8Type get(MLIRContext *context);
|
||||
};
|
||||
|
||||
class PointerType : public Type::TypeBase<PointerType, Type,
|
||||
detail::PointerTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static PointerType get(Type pointeeType);
|
||||
|
||||
static PointerType get(Type pointeeType, unsigned addressSpace);
|
||||
|
||||
Type getPointeeType() const;
|
||||
|
||||
unsigned getAddressSpace() const;
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
@@ -1,88 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BASIC_BLOCK_H_
|
||||
#define _TRITON_IR_BASIC_BLOCK_H_
|
||||
|
||||
#include <string>
|
||||
#include <list>
|
||||
#include "value.h"
|
||||
#include "visitor.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class function;
|
||||
class instruction;
|
||||
|
||||
/* Basic Block */
|
||||
class basic_block: public value{
|
||||
public:
|
||||
// instruction iterator types
|
||||
typedef std::list<instruction*> inst_list_t;
|
||||
typedef inst_list_t::iterator iterator;
|
||||
typedef inst_list_t::const_iterator const_iterator;
|
||||
typedef inst_list_t::reverse_iterator reverse_iterator;
|
||||
typedef inst_list_t::const_reverse_iterator const_reverse_iterator;
|
||||
|
||||
private:
|
||||
// constructors
|
||||
basic_block(context &ctx, const std::string &name, function *parent);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
function* get_parent() { return parent_; }
|
||||
context& get_context() { return ctx_; }
|
||||
|
||||
// get iterator to first instruction that is not a phi
|
||||
iterator get_first_non_phi();
|
||||
|
||||
// get instruction list
|
||||
inst_list_t &get_inst_list() { return inst_list_; }
|
||||
const inst_list_t &get_inst_list() const { return inst_list_; }
|
||||
void erase(instruction *i) { inst_list_.remove(i); }
|
||||
|
||||
// instruction iterator functions
|
||||
inline iterator begin() { return inst_list_.begin(); }
|
||||
inline const_iterator begin() const { return inst_list_.begin(); }
|
||||
inline iterator end () { return inst_list_.end(); }
|
||||
inline const_iterator end () const { return inst_list_.end(); }
|
||||
|
||||
inline reverse_iterator rbegin() { return inst_list_.rbegin(); }
|
||||
inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); }
|
||||
inline reverse_iterator rend () { return inst_list_.rend(); }
|
||||
inline const_reverse_iterator rend () const { return inst_list_.rend(); }
|
||||
|
||||
inline size_t size() const { return inst_list_.size(); }
|
||||
inline bool empty() const { return inst_list_.empty(); }
|
||||
inline const instruction &front() const { return *inst_list_.front(); }
|
||||
inline instruction &front() { return *inst_list_.front(); }
|
||||
inline const instruction &back() const { return *inst_list_.back(); }
|
||||
inline instruction &back() { return *inst_list_.back(); }
|
||||
|
||||
// predecessors
|
||||
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
|
||||
const std::vector<basic_block*>& get_successors() const { return succs_; }
|
||||
void add_predecessor(basic_block* pred);
|
||||
|
||||
// factory functions
|
||||
static basic_block* create(context &ctx, const std::string &name, function *parent);
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
// visitor
|
||||
void accept(visitor *v) { v->visit_basic_block(this); }
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
std::string name_;
|
||||
function *parent_;
|
||||
std::vector<basic_block*> preds_;
|
||||
std::vector<basic_block*> succs_;
|
||||
inst_list_t inst_list_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,191 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BUILDER_H_
|
||||
#define _TRITON_IR_BUILDER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "instructions.h"
|
||||
#include "basic_block.h"
|
||||
#include "type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class basic_block;
|
||||
class value;
|
||||
class type;
|
||||
class constant_int;
|
||||
class instruction;
|
||||
class context;
|
||||
class phi_node;
|
||||
|
||||
/* Builder */
|
||||
class builder{
|
||||
typedef basic_block::iterator iterator;
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
builder(context &ctx);
|
||||
// Getters
|
||||
const context& get_context() { return ctx_; }
|
||||
// Setters
|
||||
void set_insert_point(iterator instr);
|
||||
void set_insert_point(instruction* i);
|
||||
void set_insert_point_after(instruction* i);
|
||||
void set_insert_point(basic_block* block);
|
||||
basic_block* get_insert_block() { return block_; }
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(uint32_t val);
|
||||
value *get_int64(uint64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
// Types
|
||||
type *get_void_ty();
|
||||
type *get_int1_ty();
|
||||
type *get_int8_ty();
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_fp8_ty();
|
||||
type *get_half_ty();
|
||||
type *get_bf16_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
template<typename InstTy>
|
||||
InstTy* insert(InstTy *inst){
|
||||
assert(block_);
|
||||
block_->get_inst_list().insert(insert_point_, inst);
|
||||
inst->set_parent(block_);
|
||||
// for(ir::value* op: inst->ops())
|
||||
// op->add_use(inst);
|
||||
return inst;
|
||||
}
|
||||
// terminator instructions
|
||||
value* create_br(basic_block *dest);
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value* create_bitcast(value *src, type *dest_ty);
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||
value* create_int_to_ptr(value *src, type *dst_ty);
|
||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||
value* create_si_to_fp(value *src, type *dst_ty);
|
||||
value* create_ui_to_fp(value *src, type *dst_ty);
|
||||
value* create_fp_to_si(value *src, type *dst_ty);
|
||||
value* create_fp_to_ui(value *src, type *dst_ty);
|
||||
value* create_fp_ext(value *src, type *dst_ty);
|
||||
value* create_fp_trunc(value *src, type *dst_ty);
|
||||
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
|
||||
value *create_downcast(value *arg);
|
||||
// Phi instruction
|
||||
phi_node* create_phi(type *ty, unsigned num_reserved);
|
||||
// Binary instructions
|
||||
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
|
||||
value *create_fmul(value *lhs, value *rhs);
|
||||
value *create_fdiv(value *lhs, value *rhs);
|
||||
value *create_frem(value *lhs, value *rhs);
|
||||
value *create_fadd(value *lhs, value *rhs);
|
||||
value *create_fsub(value *lhs, value *rhs);
|
||||
value *create_sdiv(value *lhs, value *rhs);
|
||||
value *create_udiv(value *lhs, value *rhs);
|
||||
value *create_srem(value *lhs, value *rhs);
|
||||
value *create_urem(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
// GEP
|
||||
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
|
||||
// Comparison (int)
|
||||
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_icmpSLE(value *lhs, value *rhs);
|
||||
value *create_icmpSLT(value *lhs, value *rhs);
|
||||
value *create_icmpSGE(value *lhs, value *rhs);
|
||||
value *create_icmpSGT(value *lhs, value *rhs);
|
||||
value *create_icmpULE(value *lhs, value *rhs);
|
||||
value *create_icmpULT(value *lhs, value *rhs);
|
||||
value *create_icmpUGE(value *lhs, value *rhs);
|
||||
value *create_icmpUGT(value *lhs, value *rhs);
|
||||
value *create_icmpEQ(value *lhs, value *rhs);
|
||||
value *create_icmpNE(value *lhs, value *rhs);
|
||||
// Comparison (float)
|
||||
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_fcmpOLT(value *lhs, value *rhs);
|
||||
value *create_fcmpOGT(value *lhs, value *rhs);
|
||||
value *create_fcmpOLE(value *lhs, value *rhs);
|
||||
value *create_fcmpOGE(value *lhs, value *rhs);
|
||||
value *create_fcmpOEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpONE(value *lhs, value *rhs);
|
||||
value *create_fcmpULT(value *lhs, value *rhs);
|
||||
value *create_fcmpUGT(value *lhs, value *rhs);
|
||||
value *create_fcmpULE(value *lhs, value *rhs);
|
||||
value *create_fcmpUGE(value *lhs, value *rhs);
|
||||
value *create_fcmpUEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpUNE(value *lhs, value *rhs);
|
||||
// Logical
|
||||
value *create_and(value *lhs, value *rhs);
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
// Block instruction
|
||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_cat(value *lhs, value *rhs);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Atomic instruction
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_atomic_max(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umax(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_min(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umin(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_fadd(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_and(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_or(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xor(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xchg(value *ptr, value *val, value *msk);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
value *create_sqrt(value *A);
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||
value *create_select(value *pred, value *if_value, value *else_value);
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
value *create_umulhi(value* lhs, value* rhs);
|
||||
value *create_copy_to_shared(value *arg);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
value *create_prefetch_s(value *arg, int inc);
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
basic_block *block_;
|
||||
iterator insert_point_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,113 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONSTANT_H_
|
||||
#define _TRITON_IR_CONSTANT_H_
|
||||
|
||||
#include "enums.h"
|
||||
#include "value.h"
|
||||
#include <cassert>
|
||||
#include "visitor.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class context;
|
||||
|
||||
/* Constant */
|
||||
class constant: public user{
|
||||
protected:
|
||||
using user::user;
|
||||
|
||||
public:
|
||||
static constant* get_all_ones_value(type *ty);
|
||||
static constant* get_null_value(type *ty);
|
||||
virtual std::string repr() const = 0;
|
||||
};
|
||||
|
||||
/* Undef value */
|
||||
class undef_value: public constant{
|
||||
private:
|
||||
undef_value(type *ty);
|
||||
|
||||
public:
|
||||
static undef_value* get(type* ty);
|
||||
std::string repr() const { return "undef"; }
|
||||
void accept(visitor* vst) { vst->visit_undef_value(this); }
|
||||
};
|
||||
|
||||
|
||||
/* Constant int */
|
||||
class constant_int: public constant{
|
||||
protected:
|
||||
constant_int(type *ty, uint64_t value);
|
||||
|
||||
public:
|
||||
virtual uint64_t get_value() const { return value_; }
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_int(this); }
|
||||
|
||||
protected:
|
||||
uint64_t value_;
|
||||
};
|
||||
|
||||
/* Constant fp */
|
||||
class constant_fp: public constant{
|
||||
constant_fp(type *ty, double value);
|
||||
|
||||
public:
|
||||
double get_value() { return value_; }
|
||||
static constant* get_negative_zero(type *ty);
|
||||
static constant* get_zero_value_for_negation(type *ty);
|
||||
static constant* get(context &ctx, double v);
|
||||
static constant* get(type *ty, double v);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_fp(this); }
|
||||
|
||||
private:
|
||||
double value_;
|
||||
};
|
||||
|
||||
|
||||
/* Global Value */
|
||||
class global_value: public constant {
|
||||
public:
|
||||
enum linkage_types_t {
|
||||
external
|
||||
};
|
||||
|
||||
public:
|
||||
global_value(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space);
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
private:
|
||||
linkage_types_t linkage_;
|
||||
};
|
||||
|
||||
/* global object */
|
||||
class global_object: public global_value {
|
||||
public:
|
||||
global_object(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space = 0);
|
||||
std::string repr() const { return get_name(); }
|
||||
};
|
||||
|
||||
/* global variable */
|
||||
class alloc_const: public global_object {
|
||||
public:
|
||||
alloc_const(type *ty, constant_int *size,
|
||||
const std::string &name = "");
|
||||
std::string repr() const { return get_name(); }
|
||||
void accept(visitor* vst) { vst->visit_alloc_const(this); }
|
||||
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,29 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONTEXT_H_
|
||||
#define _TRITON_IR_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class context_impl;
|
||||
|
||||
/* Context */
|
||||
class context {
|
||||
public:
|
||||
context();
|
||||
context(const context&) = delete;
|
||||
context& operator=(const context&) = delete;
|
||||
|
||||
public:
|
||||
std::shared_ptr<context_impl> p_impl;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,46 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
|
||||
#define _TRITON_IR_CONTEXT_IMPL_H_
|
||||
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
|
||||
/* Context impl */
|
||||
class context_impl {
|
||||
public:
|
||||
// constructors
|
||||
context_impl(context &ctx);
|
||||
|
||||
public:
|
||||
// non-numeric types
|
||||
type void_ty, label_ty;
|
||||
// floating point types
|
||||
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
||||
// Block types
|
||||
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
|
||||
|
||||
// Int constants
|
||||
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
|
||||
// Float constants
|
||||
std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
|
||||
// undef values
|
||||
std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,175 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_ENUMS_H_
|
||||
#define _TRITON_IR_ENUMS_H_
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
enum binary_op_t: unsigned int{
|
||||
Add,
|
||||
FAdd,
|
||||
Sub,
|
||||
FSub,
|
||||
Mul,
|
||||
FMul,
|
||||
UDiv,
|
||||
SDiv,
|
||||
FDiv,
|
||||
URem,
|
||||
SRem,
|
||||
FRem,
|
||||
Shl,
|
||||
LShr,
|
||||
AShr,
|
||||
And,
|
||||
Or,
|
||||
Xor
|
||||
};
|
||||
|
||||
enum class atomic_rmw_op_t: unsigned int{
|
||||
And,
|
||||
Or,
|
||||
Xor,
|
||||
Add,
|
||||
Max,
|
||||
Min,
|
||||
UMax,
|
||||
UMin,
|
||||
FAdd,
|
||||
Xchg,
|
||||
};
|
||||
|
||||
enum cast_op_t: unsigned int {
|
||||
Trunc,
|
||||
ZExt,
|
||||
SExt,
|
||||
FPTrunc,
|
||||
FPExt,
|
||||
UIToFP,
|
||||
SIToFP,
|
||||
FPToUI,
|
||||
FPToSI,
|
||||
PtrToInt,
|
||||
IntToPtr,
|
||||
BitCast,
|
||||
AddrSpaceCast
|
||||
};
|
||||
|
||||
enum cmp_pred_t: unsigned int {
|
||||
FIRST_FCMP_PREDICATE,
|
||||
FCMP_FALSE,
|
||||
FCMP_OEQ,
|
||||
FCMP_OGT,
|
||||
FCMP_OGE,
|
||||
FCMP_OLT,
|
||||
FCMP_OLE,
|
||||
FCMP_ONE,
|
||||
FCMP_ORD,
|
||||
FCMP_UNO,
|
||||
FCMP_UEQ,
|
||||
FCMP_UGT,
|
||||
FCMP_UGE,
|
||||
FCMP_ULT,
|
||||
FCMP_ULE,
|
||||
FCMP_UNE,
|
||||
FCMP_TRUE,
|
||||
LAST_FCMP_PREDICATE,
|
||||
FIRST_ICMP_PREDICATE,
|
||||
ICMP_EQ,
|
||||
ICMP_NE,
|
||||
ICMP_UGT,
|
||||
ICMP_UGE,
|
||||
ICMP_ULT,
|
||||
ICMP_ULE,
|
||||
ICMP_SGT,
|
||||
ICMP_SGE,
|
||||
ICMP_SLT,
|
||||
ICMP_SLE,
|
||||
LAST_ICMP_PREDICATE
|
||||
};
|
||||
|
||||
enum value_id_t: unsigned {
|
||||
/* ------------ *
|
||||
INSTRUCTIONS
|
||||
* ------------ */
|
||||
INST_BEGIN,
|
||||
// phi
|
||||
INST_PHI,
|
||||
// arithmetic
|
||||
INST_BINOP,
|
||||
INST_GETELEMENTPTR,
|
||||
INST_SELECT,
|
||||
INST_SQRT,
|
||||
// cmp
|
||||
INST_ICMP,
|
||||
INST_FCMP,
|
||||
// cast
|
||||
INST_CAST_TRUNC,
|
||||
INST_CAST_ZEXT,
|
||||
INST_CAST_SEXT,
|
||||
INST_CAST_FP_TRUNC,
|
||||
INST_CAST_FP_EXT,
|
||||
INST_CAST_UI_TO_FP,
|
||||
INST_CAST_SI_TO_FP,
|
||||
INST_CAST_FP_TO_UI,
|
||||
INST_CAST_FP_TO_SI,
|
||||
INST_CAST_PTR_TO_INT,
|
||||
INST_CAST_INT_TO_PTR,
|
||||
INST_CAST_BIT_CAST,
|
||||
INST_CAST_ADDR_SPACE_CAST,
|
||||
// terminators
|
||||
INST_RETURN,
|
||||
INST_COND_BRANCH,
|
||||
INST_UNCOND_BRANCH,
|
||||
// io
|
||||
INST_UNMASKED_LOAD,
|
||||
INST_MASKED_LOAD,
|
||||
INST_MASKED_LOAD_ASYNC,
|
||||
INST_UNMASKED_STORE,
|
||||
INST_MASKED_STORE,
|
||||
// retile
|
||||
INST_RESHAPE,
|
||||
INST_SPLAT,
|
||||
INST_CAT,
|
||||
INST_BROADCAST,
|
||||
INST_DOWNCAST,
|
||||
// builtin
|
||||
INST_GET_PROGRAM_ID,
|
||||
INST_GET_NUM_PROGRAMS,
|
||||
// atomics
|
||||
INST_ATOMIC_CAS,
|
||||
INST_ATOMIC_EXCH,
|
||||
INST_ATOMIC_RMW,
|
||||
// math
|
||||
INST_UMULHI,
|
||||
INST_EXP,
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
INST_LOG,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
INST_REDUCE,
|
||||
INST_DOT,
|
||||
// intrinsics
|
||||
INST_COPY_TO_SHARED,
|
||||
INST_COPY_FROM_SHARED,
|
||||
INST_CVT_LAYOUT,
|
||||
INST_CVT_SCANLINE,
|
||||
INST_DECOALESCE,
|
||||
INST_RECOALESCE,
|
||||
INST_BARRIER,
|
||||
INST_ASYNC_WAIT,
|
||||
INST_MAKE_RANGE_DYN,
|
||||
INST_MAKE_RANGE_STA,
|
||||
INST_MAKE_RANGE,
|
||||
INST_PREFETCH_S,
|
||||
};
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,142 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_FUNCTION_H_
|
||||
#define _TRITON_IR_FUNCTION_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "value.h"
|
||||
#include "constant.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class function;
|
||||
class function_type;
|
||||
class module;
|
||||
class basic_block;
|
||||
|
||||
/* Argument */
|
||||
class argument: public value{
|
||||
argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
|
||||
|
||||
public:
|
||||
static argument* create(type *ty, const std::string &name,
|
||||
function *parent = nullptr, unsigned arg_no = 0);
|
||||
function* get_parent() const;
|
||||
unsigned get_arg_no() const;
|
||||
|
||||
void accept(visitor *v);
|
||||
|
||||
private:
|
||||
function *parent_;
|
||||
unsigned arg_no_;
|
||||
};
|
||||
|
||||
/* Attribute */
|
||||
enum attribute_kind_t {
|
||||
readonly = 0,
|
||||
writeonly,
|
||||
noalias,
|
||||
aligned,
|
||||
multiple_of,
|
||||
retune,
|
||||
not_implemented
|
||||
};
|
||||
|
||||
class attribute {
|
||||
public:
|
||||
attribute(attribute_kind_t kind, unsigned value = 0):
|
||||
kind_(kind), value_(value){}
|
||||
|
||||
bool operator<(const attribute& other) const {
|
||||
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
|
||||
}
|
||||
|
||||
attribute_kind_t get_kind() const {
|
||||
return kind_;
|
||||
}
|
||||
|
||||
unsigned get_value() const {
|
||||
return value_;
|
||||
}
|
||||
|
||||
bool is_llvm_attr() const {
|
||||
return kind_ != multiple_of;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(kind_){
|
||||
case readonly: return ".readonly";
|
||||
case writeonly: return ".writeonly";
|
||||
case noalias: return ".noalias";
|
||||
case aligned: return ".aligned(" + std::to_string(value_) + ")";
|
||||
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
|
||||
case retune: return ".retunr";
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
}
|
||||
|
||||
private:
|
||||
attribute_kind_t kind_;
|
||||
unsigned value_;
|
||||
};
|
||||
|
||||
/* Function */
|
||||
class function: public global_object{
|
||||
typedef std::vector<argument*> args_t;
|
||||
typedef args_t::iterator arg_iterator;
|
||||
typedef args_t::const_iterator const_arg_iterator;
|
||||
|
||||
typedef std::vector<basic_block*> blocks_t;
|
||||
typedef blocks_t::iterator block_iterator;
|
||||
typedef blocks_t::const_iterator const_block_iterator;
|
||||
|
||||
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
|
||||
|
||||
private:
|
||||
function(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name = "", module *parent = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const args_t &args() const { return args_; }
|
||||
function_type* get_fn_type() { return fn_ty_; }
|
||||
const function_type* get_fn_type() const { return fn_ty_; }
|
||||
module *get_parent() { return parent_; }
|
||||
const module *get_parent() const { return parent_; }
|
||||
|
||||
// factory methods
|
||||
static function *create(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *mod);
|
||||
// blocks
|
||||
const blocks_t &blocks() { return blocks_; }
|
||||
const blocks_t &blocks() const { return blocks_; }
|
||||
void insert_block(basic_block* block, basic_block *next = nullptr);
|
||||
|
||||
// attributes
|
||||
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
||||
const attr_map_t &attrs() { return attrs_; }
|
||||
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
|
||||
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
// visitor
|
||||
void accept(visitor *v) { v->visit_function(this); }
|
||||
|
||||
private:
|
||||
module *parent_;
|
||||
bool init_;
|
||||
function_type *fn_ty_;
|
||||
args_t args_;
|
||||
blocks_t blocks_;
|
||||
attr_map_t attrs_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,978 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_INSTRUCTIONS_H_
|
||||
#define _TRITON_IR_INSTRUCTIONS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "triton/ir/visitor.h"
|
||||
|
||||
#define _TRITON_DEFINE_CLONE(name) \
|
||||
ir::instruction* clone_impl() const { return new name(*this); }
|
||||
|
||||
#define _TRITON_DEFINE_ACCEPT(name) \
|
||||
void accept(visitor* v) { v->visit_ ## name (this); }
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class constant_int;
|
||||
class constant;
|
||||
class make_range;
|
||||
class basic_block;
|
||||
class context;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// instruction classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class result_reference;
|
||||
|
||||
|
||||
class instruction: public user{
|
||||
public:
|
||||
virtual std::string repr_impl() const = 0;
|
||||
|
||||
private:
|
||||
virtual ir::instruction* clone_impl() const = 0;
|
||||
|
||||
protected:
|
||||
// constructors
|
||||
instruction(type *ty, value_id_t ity, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// parent
|
||||
void set_parent(basic_block *block) { parent_ = block; }
|
||||
const basic_block *get_parent() const { return parent_; }
|
||||
basic_block *get_parent() { return parent_; }
|
||||
void erase_from_parent();
|
||||
// helpers
|
||||
bool has_tile_result_or_op();
|
||||
// repr
|
||||
std::string repr() const { return repr_impl(); }
|
||||
// metadata
|
||||
void set_metadata(ir::metadata::kind_t kind,
|
||||
unsigned value) { metadatas_[kind] = value;}
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
// cloning
|
||||
ir::instruction* clone() {
|
||||
ir::instruction* res = clone_impl();
|
||||
// for(auto it = op_begin(); it != op_end(); it++)
|
||||
// (*it)->add_use(res);
|
||||
res->parent_ = nullptr;
|
||||
res->users_.clear();
|
||||
return res;
|
||||
}
|
||||
// instruction id
|
||||
value_id_t get_id() const { return id_; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
value_id_t id_;
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class phi_node: public instruction {
|
||||
private:
|
||||
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "phi"; }
|
||||
|
||||
public:
|
||||
void set_incoming_value(unsigned i, value *v);
|
||||
void set_incoming_block(unsigned i, basic_block *block);
|
||||
value *get_value_for_block(basic_block *block);
|
||||
value *get_incoming_value(unsigned i) { return get_operand(i); }
|
||||
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
|
||||
unsigned get_num_incoming() { return get_num_operands(); }
|
||||
void add_incoming(value *v, basic_block *block);
|
||||
|
||||
// Type
|
||||
void set_type(type *ty) { ty_ = ty; }
|
||||
|
||||
// Factory methods
|
||||
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(phi_node)
|
||||
_TRITON_DEFINE_ACCEPT(phi_node)
|
||||
|
||||
private:
|
||||
unsigned num_reserved_;
|
||||
std::vector<basic_block*> blocks_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class binary_operator: public instruction {
|
||||
public:
|
||||
typedef binary_op_t op_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
// Constructors
|
||||
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// Get operand
|
||||
binary_op_t get_op() const { return op_; }
|
||||
|
||||
// Bool
|
||||
bool is_terminator() const;
|
||||
bool is_binary_op() const;
|
||||
bool is_int_div_rem() const;
|
||||
bool is_shift() const;
|
||||
bool is_cast() const;
|
||||
bool is_int_mult() const;
|
||||
bool is_int_add_sub() const;
|
||||
bool is_int_div() const;
|
||||
bool is_int_rem() const;
|
||||
bool is_shl() const;
|
||||
bool is_shr() const;
|
||||
|
||||
// Approx
|
||||
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
|
||||
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
|
||||
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
|
||||
// Factory methods
|
||||
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(binary_operator)
|
||||
_TRITON_DEFINE_ACCEPT(binary_operator)
|
||||
|
||||
public:
|
||||
binary_op_t op_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
|
||||
bool fdiv_ieee_rnd_;
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cmp_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cmp_inst: public instruction{
|
||||
public:
|
||||
typedef cmp_pred_t pred_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
static bool is_fp_predicate(cmp_pred_t pred);
|
||||
static bool is_int_predicate(cmp_pred_t pred);
|
||||
static type* make_cmp_result_type(type *ty);
|
||||
|
||||
public:
|
||||
cmp_pred_t get_pred() const { return pred_; }
|
||||
|
||||
private:
|
||||
cmp_pred_t pred_;
|
||||
};
|
||||
|
||||
class icmp_inst: public cmp_inst {
|
||||
icmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(icmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(icmp_inst)
|
||||
};
|
||||
|
||||
class fcmp_inst: public cmp_inst {
|
||||
fcmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(fcmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(fcmp_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// unary_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class unary_inst: public instruction {
|
||||
protected:
|
||||
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cast_inst: public unary_inst{
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
|
||||
: unary_inst(ty, id, v, name, next), op_(op) { }
|
||||
|
||||
private:
|
||||
static bool is_valid(cast_op_t op, value *arg, type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
cast_op_t get_op() const { return op_; }
|
||||
|
||||
// factory methods
|
||||
static cast_inst *create(cast_op_t op, value *arg, type *ty,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_ACCEPT(cast_inst)
|
||||
|
||||
private:
|
||||
cast_op_t op_;
|
||||
};
|
||||
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
|
||||
class name : public cast_inst { \
|
||||
_TRITON_DEFINE_CLONE(name) \
|
||||
friend class cast_inst; \
|
||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||
: cast_inst(ty, id, v, name, next, op){ } \
|
||||
};
|
||||
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class terminator_inst: public instruction{
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
// return instruction
|
||||
class return_inst: public terminator_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "ret"; }
|
||||
return_inst(context &ctx, value *ret_val, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_return_value()
|
||||
{ return get_num_operands() ? get_operand(0) : nullptr; }
|
||||
|
||||
unsigned get_num_successors() const { return 0; }
|
||||
|
||||
// factory methods
|
||||
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(return_inst)
|
||||
_TRITON_DEFINE_ACCEPT(return_inst)
|
||||
};
|
||||
|
||||
// base branch instruction
|
||||
class branch_inst: public terminator_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "br"; }
|
||||
|
||||
protected:
|
||||
using terminator_inst::terminator_inst;
|
||||
|
||||
public:
|
||||
static branch_inst* create(basic_block *dest,
|
||||
instruction *next = nullptr);
|
||||
static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
// conditional branch
|
||||
class cond_branch_inst: public branch_inst {
|
||||
private:
|
||||
friend class branch_inst;
|
||||
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
|
||||
|
||||
public:
|
||||
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
|
||||
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
|
||||
value *get_cond() { return get_operand(2); }
|
||||
_TRITON_DEFINE_CLONE(cond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cond_branch_inst)
|
||||
};
|
||||
|
||||
// unconditional branch
|
||||
class uncond_branch_inst: public branch_inst {
|
||||
private:
|
||||
friend class branch_inst;
|
||||
uncond_branch_inst(basic_block *dst, instruction *next);
|
||||
|
||||
public:
|
||||
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
||||
_TRITON_DEFINE_CLONE(uncond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(uncond_branch_inst)
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class getelementptr_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "getelementptr"; }
|
||||
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
|
||||
|
||||
private:
|
||||
static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
|
||||
static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
|
||||
static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
type *get_source_elt_ty() { return source_elt_ty; }
|
||||
op_iterator idx_begin() { return op_begin() + 1; }
|
||||
op_iterator idx_end() { return op_end(); }
|
||||
value *get_pointer_operand() { return *op_begin(); }
|
||||
|
||||
// factory methods
|
||||
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(getelementptr_inst)
|
||||
_TRITON_DEFINE_ACCEPT(getelementptr_inst)
|
||||
|
||||
private:
|
||||
type *source_elt_ty;
|
||||
type *res_elt_ty;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load_inst/store_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class io_inst: public instruction {
|
||||
protected:
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
};
|
||||
|
||||
// load
|
||||
class load_inst: public io_inst {
|
||||
public:
|
||||
enum CACHE_MODIFIER : uint32_t {
|
||||
NONE=0,
|
||||
CA,
|
||||
CG,
|
||||
};
|
||||
|
||||
enum EVICTION_POLICY : uint32_t {
|
||||
NORMAL=0,
|
||||
EVICT_FIRST,
|
||||
EVICT_LAST,
|
||||
};
|
||||
|
||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||
bool get_is_volatile() const { return is_volatile_; }
|
||||
|
||||
protected:
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
std::string get_cache_modifier_repr() const {
|
||||
if (cache_ == CA) return ".ca";
|
||||
if (cache_ == CG) return ".cg";
|
||||
return "";
|
||||
}
|
||||
std::string get_eviction_policy_repr() const {
|
||||
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||
}
|
||||
EVICTION_POLICY eviction_;
|
||||
CACHE_MODIFIER cache_;
|
||||
|
||||
std::string get_volatile_repr() {
|
||||
return is_volatile_ ? ".volatile" : "";
|
||||
}
|
||||
bool is_volatile_;
|
||||
|
||||
private:
|
||||
static type *get_pointee_type(type *ty);
|
||||
};
|
||||
|
||||
// unmasked load
|
||||
class unmasked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static unmasked_load_inst* create(value *ptr,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_load_inst)
|
||||
};
|
||||
|
||||
// masked load
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(1); }
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_inst)
|
||||
};
|
||||
|
||||
// masked load async
|
||||
class masked_load_async_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(1); }
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
|
||||
};
|
||||
|
||||
|
||||
|
||||
// store
|
||||
class store_inst: public io_inst {
|
||||
protected:
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
};
|
||||
|
||||
// unmasked_store
|
||||
class unmasked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_store"; }
|
||||
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// factory method
|
||||
static unmasked_store_inst* create(value* ptr, value *v,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_store_inst)
|
||||
};
|
||||
|
||||
class masked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_store"; }
|
||||
masked_store_inst(value *ptr, value *v, value *mask,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_store_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
class cat_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "cat"; }
|
||||
cat_inst(value *x, value *y, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cat_inst)
|
||||
};
|
||||
|
||||
// retile
|
||||
|
||||
class retile_inst: public unary_inst {
|
||||
protected:
|
||||
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
// reshape
|
||||
|
||||
class reshape_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "reshape"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(reshape_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reshape_inst)
|
||||
};
|
||||
|
||||
// splat
|
||||
|
||||
class splat_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "splat"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(splat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(splat_inst)
|
||||
};
|
||||
|
||||
// broadcast
|
||||
|
||||
class broadcast_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "broadcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(broadcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(broadcast_inst)
|
||||
};
|
||||
|
||||
|
||||
// downcast
|
||||
|
||||
class downcast_inst: public unary_inst {
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "downcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(downcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(downcast_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// builtin_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class builtin_inst: public instruction{
|
||||
protected:
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
class get_program_id_inst: public builtin_inst {
|
||||
private:
|
||||
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_program_id_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_program_id_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class get_num_programs_inst: public builtin_inst {
|
||||
private:
|
||||
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_num_programs_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
|
||||
class atomic_inst: public io_inst {
|
||||
public:
|
||||
using io_inst::io_inst;
|
||||
};
|
||||
|
||||
class atomic_rmw_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_rmw"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
atomic_rmw_op_t get_op() { return op_; }
|
||||
|
||||
private:
|
||||
atomic_rmw_op_t op_;
|
||||
};
|
||||
|
||||
class atomic_cas_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "atomic_cas"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_cas_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_cas_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class umulhi_inst: public builtin_inst {
|
||||
private:
|
||||
umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "umulhi"; }
|
||||
_TRITON_DEFINE_CLONE(umulhi_inst)
|
||||
_TRITON_DEFINE_ACCEPT(umulhi_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class exp_inst: public builtin_inst {
|
||||
private:
|
||||
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "exp"; }
|
||||
_TRITON_DEFINE_CLONE(exp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(exp_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class cos_inst: public builtin_inst {
|
||||
private:
|
||||
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "cos"; }
|
||||
_TRITON_DEFINE_CLONE(cos_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cos_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class sin_inst: public builtin_inst {
|
||||
private:
|
||||
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "sin"; }
|
||||
_TRITON_DEFINE_CLONE(sin_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sin_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class log_inst: public builtin_inst {
|
||||
private:
|
||||
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "log"; }
|
||||
_TRITON_DEFINE_CLONE(log_inst)
|
||||
_TRITON_DEFINE_ACCEPT(log_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
|
||||
class dot_inst: public builtin_inst {
|
||||
public:
|
||||
enum TransT { NoTrans, Trans };
|
||||
enum DataType {
|
||||
FP8, FP16, BF16, TF32, FP32,
|
||||
INT1, INT4, INT8, INT32,
|
||||
UNKNOWN,
|
||||
};
|
||||
|
||||
private:
|
||||
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "dot"; }
|
||||
|
||||
public:
|
||||
bool is_prefetched() const { return is_prefetched_; }
|
||||
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
|
||||
bool allow_tf32() const { return allow_tf32_; }
|
||||
|
||||
public:
|
||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(dot_inst)
|
||||
_TRITON_DEFINE_ACCEPT(dot_inst)
|
||||
|
||||
private:
|
||||
bool is_prefetched_ = false;
|
||||
bool allow_tf32_ = false;
|
||||
DataType C_type_ = DataType::FP32;
|
||||
DataType A_type_ = DataType::FP16;
|
||||
DataType B_type_ = DataType::FP16;
|
||||
};
|
||||
|
||||
//class outer_inst: public builtin_inst {
|
||||
//private:
|
||||
// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
|
||||
//public:
|
||||
// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
//};
|
||||
|
||||
class trans_inst: public builtin_inst {
|
||||
public:
|
||||
ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
|
||||
std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
|
||||
|
||||
private:
|
||||
trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "trans"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||
const std::vector<int> get_perm() const;
|
||||
_TRITON_DEFINE_CLONE(trans_inst)
|
||||
_TRITON_DEFINE_ACCEPT(trans_inst)
|
||||
|
||||
private:
|
||||
std::vector<int> perm_;
|
||||
};
|
||||
|
||||
class sqrt_inst: public builtin_inst {
|
||||
private:
|
||||
sqrt_inst(value *arg, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "sqrt"; }
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(sqrt_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sqrt_inst)
|
||||
};
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
public:
|
||||
enum op_t{
|
||||
ADD, SUB, MAX, MIN,
|
||||
FADD, FSUB, FMAX, FMIN,
|
||||
XOR
|
||||
};
|
||||
|
||||
private:
|
||||
static type* get_res_type(value *arg, unsigned axis);
|
||||
static std::string to_str(op_t op);
|
||||
|
||||
private:
|
||||
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "reduce"; }
|
||||
_TRITON_DEFINE_CLONE(reduce_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reduce_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
op_t get_op() const { return op_; }
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
op_t op_;
|
||||
};
|
||||
|
||||
class select_inst: public builtin_inst {
|
||||
private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "select"; }
|
||||
_TRITON_DEFINE_CLONE(select_inst)
|
||||
_TRITON_DEFINE_ACCEPT(select_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
|
||||
value* get_pred_op() { return get_operand(0); }
|
||||
value* get_if_value_op() { return get_operand(1); }
|
||||
value* get_else_value_op() { return get_operand(2); }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsics classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
class copy_to_shared_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "copy_to_shared"; }
|
||||
|
||||
public:
|
||||
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
|
||||
};
|
||||
|
||||
class copy_from_shared_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "copy_from_shared"; }
|
||||
|
||||
public:
|
||||
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
|
||||
};
|
||||
|
||||
class cvt_layout_inst: public unary_inst {
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "cvt_layout_inst"; }
|
||||
|
||||
public:
|
||||
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cvt_layout_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
|
||||
};
|
||||
|
||||
class barrier_inst: public instruction{
|
||||
private:
|
||||
barrier_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "barrier"; }
|
||||
_TRITON_DEFINE_CLONE(barrier_inst)
|
||||
_TRITON_DEFINE_ACCEPT(barrier_inst)
|
||||
|
||||
public:
|
||||
static barrier_inst* create(context &ctx, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class async_wait_inst: public instruction{
|
||||
private:
|
||||
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
|
||||
_TRITON_DEFINE_CLONE(async_wait_inst)
|
||||
_TRITON_DEFINE_ACCEPT(async_wait_inst)
|
||||
|
||||
public:
|
||||
static async_wait_inst* create(context &ctx, int N,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
int get_N() { return N_; }
|
||||
void set_N(int n) { N_ = n; }
|
||||
|
||||
private:
|
||||
int N_;
|
||||
};
|
||||
|
||||
class prefetch_s_inst : public instruction {
|
||||
std::string repr_impl() const { return "prefetch_s"; }
|
||||
_TRITON_DEFINE_CLONE(prefetch_s_inst)
|
||||
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
|
||||
|
||||
/// inc_: 0->first, 1->latch
|
||||
int inc_ = 0;
|
||||
public:
|
||||
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
|
||||
set_operand(0, arg);
|
||||
}
|
||||
int get_inc() const { return inc_; }
|
||||
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
|
||||
instruction *next=nullptr);
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
class make_range: public instruction{
|
||||
make_range(type *ty, constant_int* first, constant_int* last);
|
||||
std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
|
||||
_TRITON_DEFINE_CLONE(make_range)
|
||||
_TRITON_DEFINE_ACCEPT(make_range)
|
||||
|
||||
public:
|
||||
static make_range *create(constant_int *first, constant_int *last);
|
||||
const constant_int* get_first() const;
|
||||
const constant_int* get_last() const;
|
||||
|
||||
private:
|
||||
constant_int* first_;
|
||||
constant_int* last_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,32 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_METADATA_H_
|
||||
#define _TRITON_IR_METADATA_H_
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/* Metadata */
|
||||
class metadata{
|
||||
public:
|
||||
enum kind_t{
|
||||
multiple_of,
|
||||
max_contiguous
|
||||
};
|
||||
|
||||
private:
|
||||
metadata(kind_t kind, unsigned value);
|
||||
|
||||
public:
|
||||
static metadata* get(kind_t kind, unsigned value);
|
||||
|
||||
private:
|
||||
kind_t kind_;
|
||||
unsigned value_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,92 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_MODULE_H_
|
||||
#define _TRITON_IR_MODULE_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "triton/ir/context.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace lang{
|
||||
|
||||
class iteration_statement;
|
||||
class compound_statement;
|
||||
|
||||
}
|
||||
|
||||
namespace ir{
|
||||
|
||||
class basic_block;
|
||||
class phi_node;
|
||||
class value;
|
||||
class context;
|
||||
class function;
|
||||
class attribute;
|
||||
class function_type;
|
||||
class constant;
|
||||
class global_value;
|
||||
class alloc_const;
|
||||
|
||||
/* Module */
|
||||
|
||||
class module {
|
||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||
friend class function;
|
||||
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
||||
|
||||
public:
|
||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||
typedef std::vector<function*> functions_list_t;
|
||||
struct current_iteration_info_t{
|
||||
lang::iteration_statement *statement;
|
||||
basic_block *block;
|
||||
};
|
||||
|
||||
private:
|
||||
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
|
||||
value *try_remove_trivial_phis(ir::phi_node *&phi);
|
||||
value *add_phi_operands(const std::string& name, phi_node *&phi);
|
||||
value *get_value_recursive(const std::string& name, basic_block *block);
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
|
||||
builder &get_builder() { return builder_; };
|
||||
const std::string& get_name() { return name_; };
|
||||
|
||||
// Functions
|
||||
const functions_list_t &get_function_list() const { return functions_; }
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||
// Const allocation
|
||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||
// Register global
|
||||
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
// Metadata
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
builder &builder_;
|
||||
functions_list_t functions_;
|
||||
symbols_map_t symbols_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
std::map<std::string, md_pair_t> metadatas_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,22 +0,0 @@
|
||||
#ifndef _TRITON_IR_PRINT_H_
|
||||
#define _TRITON_IR_PRINT_H_
|
||||
|
||||
#include "builder.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class module;
|
||||
class function;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
|
||||
void print(module &mod, std::ostream& os);
|
||||
void print(function &func, std::ostream& os);
|
||||
void print(basic_block &bb, std::ostream& os);
|
||||
void print(instruction &instr, std::ostream& os);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,239 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_TYPE_H_
|
||||
#define _TRITON_IR_TYPE_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class value;
|
||||
class integer_type;
|
||||
class constant_int;
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
typedef std::vector<unsigned> block_shapes_t;
|
||||
|
||||
protected:
|
||||
typedef std::vector<type*> contained_tys_vec_t;
|
||||
typedef contained_tys_vec_t::iterator ty_iterator;
|
||||
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
|
||||
|
||||
public:
|
||||
enum id_t {
|
||||
// primitive types
|
||||
VoidTyID = 0, ///< type with no size
|
||||
FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
|
||||
FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
|
||||
BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
|
||||
FP32TyID, ///< 32-bit floating point type
|
||||
FP64TyID, ///< 64-bit floating point type
|
||||
LabelTyID, ///< Labels
|
||||
MetadataTyID, ///< Metadata
|
||||
TokenTyID, ///< Token
|
||||
// derived types
|
||||
IntegerTyID, ///< Arbitrary bit width integers
|
||||
FunctionTyID, ///< Functions
|
||||
PointerTyID, ///< Pointers
|
||||
StructTyID, ///< Struct
|
||||
BlockTyID, ///< Block
|
||||
};
|
||||
|
||||
public:
|
||||
//constructors
|
||||
type(context &ctx, id_t id) : ctx_(ctx), id_(id) { }
|
||||
|
||||
//destructor
|
||||
virtual ~type(){}
|
||||
|
||||
// accessors
|
||||
context &get_context() const { return ctx_; }
|
||||
id_t get_type_id() const { return id_; }
|
||||
// type attributes
|
||||
unsigned get_fp_mantissa_width() const;
|
||||
unsigned get_integer_bitwidth() const;
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
block_shapes_t get_block_shapes() const;
|
||||
const size_t get_tile_rank() const;
|
||||
const size_t get_tile_ranks1() const;
|
||||
unsigned get_tile_num_elements() const;
|
||||
type *get_tile_element_ty() const;
|
||||
unsigned get_pointer_address_space() const;
|
||||
type *get_pointer_element_ty() const;
|
||||
|
||||
// primitive predicates
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||
bool is_fp16_ty() const { return id_ == FP16TyID; }
|
||||
bool is_bf16_ty() const { return id_ == BF16TyID; }
|
||||
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
|
||||
// Composite predicates
|
||||
bool is_int_or_tileint_ty();
|
||||
bool is_integer_ty(unsigned width) const;
|
||||
bool is_floating_point_ty() const;
|
||||
bool is_sized() const ;
|
||||
|
||||
// Factory methods
|
||||
// primitive types
|
||||
static type *get_void_ty(context &ctx);
|
||||
static type *get_label_ty(context &ctx);
|
||||
// half
|
||||
static type *get_fp8_ty(context &ctx);
|
||||
static type *get_fp16_ty(context &ctx);
|
||||
static type *get_bf16_ty(context &ctx);
|
||||
static type *get_fp32_ty(context &ctx);
|
||||
static type *get_fp64_ty(context &ctx);
|
||||
// integer types
|
||||
static integer_type *get_int1_ty(context &ctx);
|
||||
static integer_type *get_int8_ty(context &ctx);
|
||||
static integer_type *get_int16_ty(context &ctx);
|
||||
static integer_type *get_int32_ty(context &ctx);
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
std::string res = get_tile_element_ty()->repr();
|
||||
auto shapes = get_block_shapes();
|
||||
res += "<";
|
||||
for(size_t i = 0; i < shapes.size(); i++){
|
||||
if(i > 0)
|
||||
res += ", ";
|
||||
res += std::to_string(shapes[i]);
|
||||
}
|
||||
res+= ">";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(id_) {
|
||||
case VoidTyID: return "void";
|
||||
case FP8TyID: return "fp8";
|
||||
case FP16TyID: return "f16";
|
||||
case FP32TyID: return "f32";
|
||||
case FP64TyID: return "f64";
|
||||
case BF16TyID: return "bf16";
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
case BlockTyID: return tile_repr();
|
||||
default: break;
|
||||
}
|
||||
throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
|
||||
};
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
id_t id_;
|
||||
|
||||
protected:
|
||||
contained_tys_vec_t contained_tys_;
|
||||
};
|
||||
|
||||
class integer_type: public type {
|
||||
friend class context_impl;
|
||||
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_bitwidth() const { return bitwidth_; }
|
||||
|
||||
// factory methods
|
||||
static integer_type* get(context &ctx, unsigned width);
|
||||
|
||||
private:
|
||||
unsigned bitwidth_;
|
||||
};
|
||||
|
||||
class composite_type: public type{
|
||||
protected:
|
||||
using type::type;
|
||||
|
||||
public:
|
||||
bool index_valid(value *idx) const;
|
||||
type* get_type_at_index(value *idx) const;
|
||||
};
|
||||
|
||||
class block_type: public composite_type {
|
||||
private:
|
||||
block_type(type *ty, const block_shapes_t &shapes);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const block_shapes_t& get_shapes() const { return shapes_; }
|
||||
unsigned get_num_elements() const;
|
||||
unsigned get_bitwidth() const;
|
||||
|
||||
// factory methods
|
||||
static block_type* get(type *ty, const block_shapes_t &shapes);
|
||||
static block_type* get_same_shapes(type *ty, type *ref);
|
||||
|
||||
private:
|
||||
block_shapes_t shapes_;
|
||||
};
|
||||
|
||||
class pointer_type: public type {
|
||||
private:
|
||||
pointer_type(type *ty, unsigned address_space);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_address_space() const { return address_space_; }
|
||||
type *get_element_ty() const { return contained_tys_[0]; }
|
||||
// factory methods
|
||||
static pointer_type* get(type *ty, unsigned address_space);
|
||||
|
||||
private:
|
||||
unsigned address_space_;
|
||||
};
|
||||
|
||||
class function_type: public type {
|
||||
private:
|
||||
function_type(type *ret_ty, const std::vector<type *> ¶m_tys);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_num_params() const { return contained_tys_.size() - 1; }
|
||||
const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; }
|
||||
const_ty_iterator params_end() const { return contained_tys_.end(); }
|
||||
ty_iterator params_begin() { return contained_tys_.begin() + 1; }
|
||||
ty_iterator params_end() { return contained_tys_.end(); }
|
||||
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
|
||||
type* get_return_ty() const { return contained_tys_.at(0); }
|
||||
// factory methods
|
||||
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,30 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CFG_H_
|
||||
#define _TRITON_IR_CFG_H_
|
||||
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class module;
|
||||
class function;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class value;
|
||||
|
||||
class cfg {
|
||||
public:
|
||||
static std::vector<basic_block *> post_order(function* fn);
|
||||
static std::vector<basic_block *> reverse_post_order(function* fn);
|
||||
};
|
||||
|
||||
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
|
||||
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,95 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_VALUE_H_
|
||||
#define _TRITON_IR_VALUE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class use;
|
||||
class user;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// value class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class value {
|
||||
public:
|
||||
typedef std::set<user*> users_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
value(type *ty, const std::string &name = "");
|
||||
virtual ~value(){ }
|
||||
// uses
|
||||
void add_use(user* arg);
|
||||
users_t::iterator erase_use(user* arg);
|
||||
const std::set<user*> &get_users() { return users_; }
|
||||
void replace_all_uses_with(value *target);
|
||||
// name
|
||||
void set_name(const std::string &name);
|
||||
const std::string &get_name() const { return name_; }
|
||||
bool has_name() const { return !name_.empty(); }
|
||||
type* get_type() const { return ty_; }
|
||||
// visitor
|
||||
virtual void accept(visitor *v) = 0;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
|
||||
protected:
|
||||
type *ty_;
|
||||
users_t users_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// user class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class user: public value{
|
||||
public:
|
||||
typedef std::vector<value*> ops_t;
|
||||
typedef ops_t::iterator op_iterator;
|
||||
typedef ops_t::const_iterator const_op_iterator;
|
||||
|
||||
protected:
|
||||
void resize_ops(unsigned num_ops) { ops_.resize(num_ops + num_hidden_); num_ops_ = num_ops; }
|
||||
void resize_hidden(unsigned num_hidden) { ops_.resize(num_ops_ + num_hidden); num_hidden_ = num_hidden; }
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
user(type *ty, unsigned num_ops, const std::string &name = "")
|
||||
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
|
||||
}
|
||||
virtual ~user() { }
|
||||
|
||||
// Operands
|
||||
const ops_t& ops() { return ops_; }
|
||||
const ops_t& ops() const { return ops_; }
|
||||
op_iterator op_begin() { return ops_.begin(); }
|
||||
op_iterator op_end() { return ops_.end(); }
|
||||
void set_operand(unsigned i, value *x);
|
||||
value *get_operand(unsigned i) const;
|
||||
unsigned get_num_operands() const ;
|
||||
unsigned get_num_hidden() const;
|
||||
|
||||
// Utils
|
||||
value::users_t::iterator replace_uses_of_with(value *before, value *after);
|
||||
|
||||
|
||||
private:
|
||||
ops_t ops_;
|
||||
unsigned num_ops_;
|
||||
unsigned num_hidden_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,170 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_VISITOR_H_
|
||||
#define _TRITON_IR_VISITOR_H_
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class value;
|
||||
|
||||
class instruction;
|
||||
|
||||
class phi_node;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
|
||||
class icmp_inst;
|
||||
class fcmp_inst;
|
||||
class cast_inst;
|
||||
class trunc_inst;
|
||||
class z_ext_inst;
|
||||
class s_ext_inst;
|
||||
class fp_trunc_inst;
|
||||
class fp_ext_inst;
|
||||
class ui_to_fp_inst;
|
||||
class si_to_fp_inst;
|
||||
class fp_to_ui_inst;
|
||||
class fp_to_si_inst;
|
||||
class ptr_to_int_inst;
|
||||
class int_to_ptr_inst;
|
||||
class bit_cast_inst;
|
||||
class addr_space_cast_inst;
|
||||
|
||||
class return_inst;
|
||||
class cond_branch_inst;
|
||||
class uncond_branch_inst;
|
||||
|
||||
|
||||
class unmasked_load_inst;
|
||||
class masked_load_inst;
|
||||
class unmasked_store_inst;
|
||||
class masked_store_inst;
|
||||
|
||||
class retile_inst;
|
||||
class reshape_inst;
|
||||
class splat_inst;
|
||||
class cat_inst;
|
||||
class broadcast_inst;
|
||||
class downcast_inst;
|
||||
|
||||
class umulhi_inst;
|
||||
class exp_inst;
|
||||
class cos_inst;
|
||||
class sin_inst;
|
||||
class log_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
class get_num_programs_inst;
|
||||
class atomic_inst;
|
||||
class atomic_cas_inst;
|
||||
class atomic_rmw_inst;
|
||||
class dot_inst;
|
||||
class trans_inst;
|
||||
class sqrt_inst;
|
||||
class reduce_inst;
|
||||
class select_inst;
|
||||
|
||||
class cvt_layout_inst;
|
||||
class copy_to_shared_inst;
|
||||
class copy_from_shared_inst;
|
||||
class masked_load_async_inst;
|
||||
class barrier_inst;
|
||||
class async_wait_inst;
|
||||
class make_range_dyn;
|
||||
class make_range;
|
||||
class prefetch_s_inst;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class function;
|
||||
|
||||
class basic_block;
|
||||
|
||||
class argument;
|
||||
|
||||
class visitor {
|
||||
public:
|
||||
virtual ~visitor() {}
|
||||
|
||||
virtual void visit_value(ir::value*);
|
||||
|
||||
virtual void visit_basic_block(basic_block*) = 0;
|
||||
virtual void visit_argument(argument*) = 0;
|
||||
virtual void visit_phi_node(phi_node*) = 0;
|
||||
virtual void visit_binary_operator(binary_operator*) = 0;
|
||||
virtual void visit_getelementptr_inst(getelementptr_inst*) = 0;
|
||||
|
||||
virtual void visit_icmp_inst(icmp_inst*) = 0;
|
||||
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
|
||||
virtual void visit_cast_inst(cast_inst*) = 0;
|
||||
|
||||
virtual void visit_return_inst(return_inst*) = 0;
|
||||
virtual void visit_cond_branch_inst(cond_branch_inst*) = 0;
|
||||
virtual void visit_uncond_branch_inst(uncond_branch_inst*) = 0;
|
||||
|
||||
|
||||
virtual void visit_unmasked_load_inst(unmasked_load_inst*) = 0;
|
||||
virtual void visit_masked_load_inst(masked_load_inst*) = 0;
|
||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_umulhi_inst(umulhi_inst*) = 0;
|
||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||
virtual void visit_cos_inst(cos_inst*) = 0;
|
||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||
virtual void visit_log_inst(log_inst*) = 0;
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||
virtual void visit_cat_inst(cat_inst*) = 0;
|
||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||
virtual void visit_downcast_inst(downcast_inst*) = 0;
|
||||
|
||||
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
|
||||
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
||||
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
||||
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
|
||||
virtual void visit_dot_inst(dot_inst*) = 0;
|
||||
virtual void visit_trans_inst(trans_inst*) = 0;
|
||||
virtual void visit_sqrt_inst(sqrt_inst*) = 0;
|
||||
virtual void visit_reduce_inst(reduce_inst*) = 0;
|
||||
virtual void visit_select_inst(select_inst*) = 0;
|
||||
|
||||
virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
|
||||
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
|
||||
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
|
||||
|
||||
|
||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||
virtual void visit_function(function*) = 0;
|
||||
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
virtual void visit_alloc_const(alloc_const*) = 0;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
12
lib/ir/CMakeLists.txt
Normal file
12
lib/ir/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
add_mlir_dialect_library(TRITONIR
|
||||
Dialect.cpp
|
||||
Ops.cpp
|
||||
Types.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRArithmetic
|
||||
MLIRControlFlow
|
||||
MLIRFunc
|
||||
MLIRTensor
|
||||
)
|
71
lib/ir/Dialect.cpp
Normal file
71
lib/ir/Dialect.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
||||
|
||||
#include "triton/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void TritonDialect::initialize() {
|
||||
registerTypes();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Ops.cpp.inc"
|
||||
>();
|
||||
|
||||
// We can also add interface here.
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Parsing
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pointer-type ::= `!triton.ptr<` element-type ` >`
|
||||
static Type parsePointerType(TritonDialect const &dialect,
|
||||
DialectAsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
|
||||
|
||||
Type pointeeType;
|
||||
if (parser.parseType(pointeeType))
|
||||
return Type();
|
||||
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
|
||||
return PointerType::get(pointeeType);
|
||||
}
|
||||
|
||||
// trtion-type ::= pointer-type
|
||||
Type TritonDialect::parseType(DialectAsmParser &parser) const {
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
|
||||
if (keyword == "ptr")
|
||||
return parsePointerType(*this, parser);
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown Triton type: ") << keyword;
|
||||
return Type();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Printing
|
||||
//===----------------------------------------------------------------------===//
|
||||
static void print(PointerType type, DialectAsmPrinter &os) {
|
||||
os << "ptr<" << type.getPointeeType() << ">";
|
||||
}
|
||||
|
||||
void TritonDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
TypeSwitch<Type>(type)
|
||||
.Case<PointerType>( [&](auto type) { print(type, os); })
|
||||
.Default([](Type) { llvm_unreachable("unhandled Triton type"); });
|
||||
}
|
63
lib/ir/Ops.cpp
Normal file
63
lib/ir/Ops.cpp
Normal file
@@ -0,0 +1,63 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Ops.cpp.inc"
|
||||
|
||||
// enum attribute definitions
|
||||
#include "triton/OpsEnums.cpp.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
//-- StoreOp --
|
||||
// Default mask
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
auto shape = ptrType.getShape();
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(value);
|
||||
state.addOperands(mask);
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
// mask
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
// other
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
::mlir::Value other = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
resultType,
|
||||
DenseElementsAttr::get(
|
||||
resultType, builder.getZeroAttr(elementType)
|
||||
)
|
||||
);
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(mask);
|
||||
state.addOperands(other);
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
55
lib/ir/Types.cpp
Normal file
55
lib/ir/Types.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
// F8 & BF8
|
||||
Float8Type Float8Type::get(MLIRContext *context) {
|
||||
return Base::get(context);
|
||||
}
|
||||
|
||||
BFloat8Type BFloat8Type::get(MLIRContext *context) {
|
||||
return Base::get(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PointerType
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct triton::detail::PointerTypeStorage : public TypeStorage {
|
||||
using KeyTy = std::pair<Type, unsigned>;
|
||||
|
||||
static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
return new (allocator.allocate<PointerTypeStorage>()) PointerTypeStorage(key);
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(pointeeType, addressSpace);
|
||||
}
|
||||
|
||||
PointerTypeStorage(const KeyTy &key)
|
||||
: pointeeType(key.first), addressSpace(key.second) {}
|
||||
|
||||
Type pointeeType;
|
||||
unsigned addressSpace;
|
||||
};
|
||||
|
||||
PointerType PointerType::get(Type pointeeType) {
|
||||
return Base::get(pointeeType.getContext(), pointeeType, 0);
|
||||
}
|
||||
|
||||
PointerType PointerType::get(Type pointeeType, unsigned addressSpace) {
|
||||
return Base::get(pointeeType.getContext(), pointeeType, addressSpace);
|
||||
}
|
||||
|
||||
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
|
||||
|
||||
unsigned PointerType::getAddressSpace() const { return getImpl()->addressSpace; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Triton Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
void TritonDialect::registerTypes() {
|
||||
addTypes<Float8Type, BFloat8Type, PointerType>();
|
||||
}
|
@@ -1,41 +0,0 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
|
||||
class phi_node;
|
||||
|
||||
|
||||
basic_block::basic_block(context &ctx, const std::string &name, function *parent):
|
||||
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
|
||||
if(parent_)
|
||||
parent_->insert_block(this);
|
||||
}
|
||||
|
||||
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){
|
||||
return new basic_block(ctx, name, parent);
|
||||
}
|
||||
|
||||
void basic_block::add_predecessor(basic_block *pred) {
|
||||
preds_.push_back(pred);
|
||||
if(pred)
|
||||
pred->succs_.push_back(this);
|
||||
}
|
||||
|
||||
|
||||
|
||||
basic_block::iterator basic_block::get_first_non_phi(){
|
||||
auto it = begin();
|
||||
for(; it != end(); it++)
|
||||
if(!dynamic_cast<phi_node*>(*it)){
|
||||
return it;
|
||||
}
|
||||
return it;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
@@ -1,434 +0,0 @@
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
builder::builder(context &ctx):
|
||||
ctx_(ctx), block_(nullptr) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
void builder::set_insert_point(basic_block::iterator it){
|
||||
block_ = (*it)->get_parent();
|
||||
insert_point_ = it;
|
||||
}
|
||||
|
||||
void builder::set_insert_point(instruction* i){
|
||||
block_ = i->get_parent();
|
||||
auto it = std::find(block_->begin(), block_->end(), i);
|
||||
set_insert_point(it);
|
||||
}
|
||||
|
||||
|
||||
void builder::set_insert_point_after(instruction* i){
|
||||
block_ = i->get_parent();
|
||||
auto it = std::find(block_->begin(), block_->end(), i);
|
||||
set_insert_point(++it);
|
||||
}
|
||||
|
||||
|
||||
void builder::set_insert_point(basic_block *block){
|
||||
block_ = block;
|
||||
insert_point_ = block->end();
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// convenience functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::get_int1(bool val)
|
||||
{ return constant_int::get(type::get_int1_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_int32(uint32_t val)
|
||||
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_int64(uint64_t val)
|
||||
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_float16(float val)
|
||||
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_float32(float val)
|
||||
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_range(int32_t _lo, int32_t _hi) {
|
||||
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
|
||||
constant_int* hi = static_cast<constant_int*>(get_int32(_hi));
|
||||
return insert(make_range::create(lo, hi));
|
||||
}
|
||||
|
||||
type *builder::get_void_ty()
|
||||
{ return type::get_void_ty(ctx_); }
|
||||
|
||||
type *builder::get_int1_ty()
|
||||
{ return type::get_int1_ty(ctx_); }
|
||||
|
||||
type *builder::get_int8_ty()
|
||||
{ return type::get_int8_ty(ctx_); }
|
||||
|
||||
type *builder::get_int16_ty()
|
||||
{ return type::get_int16_ty(ctx_); }
|
||||
|
||||
type *builder::get_int32_ty()
|
||||
{ return type::get_int32_ty(ctx_); }
|
||||
|
||||
type *builder::get_int64_ty()
|
||||
{ return type::get_int64_ty(ctx_); }
|
||||
|
||||
type *builder::get_fp8_ty()
|
||||
{ return type::get_fp8_ty(ctx_); }
|
||||
|
||||
type *builder::get_half_ty()
|
||||
{ return type::get_fp16_ty(ctx_); }
|
||||
|
||||
type *builder::get_bf16_ty()
|
||||
{ return type::get_bf16_ty(ctx_); }
|
||||
|
||||
type *builder::get_float_ty()
|
||||
{ return type::get_fp32_ty(ctx_); }
|
||||
|
||||
type *builder::get_double_ty()
|
||||
{ return type::get_fp64_ty(ctx_); }
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value* builder::create_br(basic_block *dest){
|
||||
dest->add_predecessor(block_);
|
||||
return insert(branch_inst::create(dest));
|
||||
}
|
||||
|
||||
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
|
||||
if_dest->add_predecessor(block_);
|
||||
else_dest->add_predecessor(block_);
|
||||
return insert(branch_inst::create(cond, if_dest, else_dest));
|
||||
}
|
||||
|
||||
value *builder::create_ret_void() {
|
||||
return insert(return_inst::create(ctx_));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
#define DEFINE_CAST_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *src, type *dst_ty){\
|
||||
return create_cast(OPCODE, src, dst_ty);\
|
||||
}
|
||||
|
||||
DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
|
||||
DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
|
||||
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
|
||||
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
|
||||
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
|
||||
DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI)
|
||||
DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI)
|
||||
DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt)
|
||||
DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc)
|
||||
|
||||
value* builder::create_cast(cast_op_t op, value *v, type *dst_ty){
|
||||
return insert(cast_inst::create(op, v, dst_ty));
|
||||
}
|
||||
|
||||
value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed){
|
||||
return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
phi_node* builder::create_phi(type *ty, unsigned num_reserved){
|
||||
return insert(phi_node::create(ty, num_reserved));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary float instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define DEFINE_BINARY_FLOAT(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\
|
||||
return insert(binary_operator::create(OPCODE, lhs, rhs));\
|
||||
}
|
||||
|
||||
// Binary
|
||||
DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
|
||||
DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
|
||||
DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
|
||||
DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
|
||||
DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary int instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
|
||||
value *rhs,
|
||||
bool has_nuw, bool has_nsw) {
|
||||
binary_operator* result = insert(binary_operator::create(op, lhs, rhs));
|
||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||
if (has_nsw) result->set_has_no_signed_wrap();
|
||||
return result;
|
||||
}
|
||||
|
||||
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
||||
value* builder::create_ ## SUFFIX(value *lhs, value *rhs, bool has_nuw, bool has_nsw){\
|
||||
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, has_nuw, has_nsw);\
|
||||
}\
|
||||
|
||||
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\
|
||||
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, false, false);\
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Binary
|
||||
DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
|
||||
DEFINE_NOWRAP_BINARY(add, binary_op_t::Add)
|
||||
DEFINE_NOWRAP_BINARY(sub, binary_op_t::Sub)
|
||||
DEFINE_NOWRAP_BINARY(shl, binary_op_t::Shl)
|
||||
DEFINE_NOWRAP_BINARY(ashr, binary_op_t::AShr)
|
||||
DEFINE_NOWRAP_BINARY(lshr, binary_op_t::LShr)
|
||||
DEFINE_BINARY_INT(sdiv, binary_op_t::SDiv)
|
||||
DEFINE_BINARY_INT(udiv, binary_op_t::UDiv)
|
||||
DEFINE_BINARY_INT(srem, binary_op_t::SRem)
|
||||
DEFINE_BINARY_INT(urem, binary_op_t::URem)
|
||||
DEFINE_BINARY_INT(and, binary_op_t::And)
|
||||
DEFINE_BINARY_INT(or, binary_op_t::Or)
|
||||
DEFINE_BINARY_INT(xor, binary_op_t::Xor)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list){
|
||||
return insert(getelementptr_inst::create(ptr, idx_list));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// icmp instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs){
|
||||
return insert(icmp_inst::create(pred, lhs, rhs));
|
||||
}
|
||||
|
||||
#define DEFINE_ICMP_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs){\
|
||||
return create_icmp(OPCODE, lhs, rhs);\
|
||||
}
|
||||
|
||||
// Signed
|
||||
DEFINE_ICMP_INSTR(SLE, cmp_pred_t::ICMP_SLE)
|
||||
DEFINE_ICMP_INSTR(SLT, cmp_pred_t::ICMP_SLT)
|
||||
DEFINE_ICMP_INSTR(SGE, cmp_pred_t::ICMP_SGE)
|
||||
DEFINE_ICMP_INSTR(SGT, cmp_pred_t::ICMP_SGT)
|
||||
// Unsigned
|
||||
DEFINE_ICMP_INSTR(ULE, cmp_pred_t::ICMP_ULE)
|
||||
DEFINE_ICMP_INSTR(ULT, cmp_pred_t::ICMP_ULT)
|
||||
DEFINE_ICMP_INSTR(UGE, cmp_pred_t::ICMP_UGE)
|
||||
DEFINE_ICMP_INSTR(UGT, cmp_pred_t::ICMP_UGT)
|
||||
// General
|
||||
DEFINE_ICMP_INSTR(EQ, cmp_pred_t::ICMP_EQ)
|
||||
DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// fcmp instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs){
|
||||
return insert(fcmp_inst::create(pred, lhs, rhs));
|
||||
}
|
||||
|
||||
#define DEFINE_FCMP_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs){\
|
||||
return create_fcmp(OPCODE, lhs, rhs);\
|
||||
}
|
||||
|
||||
// Ordered
|
||||
DEFINE_FCMP_INSTR(OLE, cmp_pred_t::FCMP_OLE)
|
||||
DEFINE_FCMP_INSTR(OLT, cmp_pred_t::FCMP_OLT)
|
||||
DEFINE_FCMP_INSTR(OGE, cmp_pred_t::FCMP_OGE)
|
||||
DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT)
|
||||
DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ)
|
||||
DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
|
||||
|
||||
DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE)
|
||||
DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT)
|
||||
DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE)
|
||||
DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT)
|
||||
DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ)
|
||||
DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load/store instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val){
|
||||
return insert(unmasked_store_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
return insert(masked_store_inst::create(ptr, val, mask));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// block instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(reshape_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
value *builder::create_cat(value *lhs, value *rhs) {
|
||||
return insert(cat_inst::create(lhs, rhs));
|
||||
}
|
||||
|
||||
value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(splat_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
value *builder::create_broadcast(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(broadcast_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
value *builder::create_downcast(value *arg) {
|
||||
return insert(downcast_inst::create(arg));
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||
}
|
||||
|
||||
#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
|
||||
return create_atomic_rmw(OPCODE, ptr, val, mask);\
|
||||
}
|
||||
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// built-in instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_get_program_id(unsigned axis) {
|
||||
return insert(get_program_id_inst::create(ctx_, axis));
|
||||
}
|
||||
|
||||
value *builder::create_get_num_programs(unsigned axis) {
|
||||
return insert(get_num_programs_inst::create(ctx_, axis));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
|
||||
return insert(atomic_cas_inst::create(ptr, cmp, val));
|
||||
}
|
||||
|
||||
|
||||
value *builder::create_exp(value *arg){
|
||||
return insert(exp_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_cos(value *arg){
|
||||
return insert(cos_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_sin(value *arg){
|
||||
return insert(sin_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_log(value *arg){
|
||||
return insert(log_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
|
||||
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
|
||||
}
|
||||
|
||||
value *builder::create_trans(value *A, const std::vector<int>& perm) {
|
||||
return insert(trans_inst::create(A, perm));
|
||||
}
|
||||
|
||||
value *builder::create_sqrt(value *A) {
|
||||
return insert(sqrt_inst::create(A));
|
||||
}
|
||||
|
||||
value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis) {
|
||||
return insert(reduce_inst::create(A, op, axis));
|
||||
}
|
||||
|
||||
value *builder::create_select(value *pred, value *if_value, value *else_value){
|
||||
return insert(select_inst::create(pred, if_value, else_value));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_umulhi(value *lhs, value *rhs) {
|
||||
return insert(umulhi_inst::create(lhs, rhs));
|
||||
}
|
||||
|
||||
value *builder::create_copy_to_shared(value *arg) {
|
||||
return insert(copy_to_shared_inst::create(arg));
|
||||
}
|
||||
|
||||
|
||||
value *builder::create_copy_from_shared(value *arg) {
|
||||
return insert(copy_from_shared_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
|
||||
}
|
||||
|
||||
value *builder::create_barrier(const std::string &name) {
|
||||
return insert(barrier_inst::create(ctx_));
|
||||
}
|
||||
|
||||
value *builder::create_async_wait(int N) {
|
||||
return insert(async_wait_inst::create(ctx_, N));
|
||||
}
|
||||
|
||||
value *builder::create_prefetch_s(value *arg, int inc) {
|
||||
return insert(prefetch_s_inst::create(ctx_, arg, inc));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,118 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
// constant
|
||||
|
||||
constant *constant::get_null_value(type *ty) {
|
||||
context &ctx = ty->get_context();
|
||||
switch (ty->get_scalar_ty()->get_type_id()) {
|
||||
case type::IntegerTyID:
|
||||
return constant_int::get(ty, 0);
|
||||
case type::FP16TyID:
|
||||
return constant_fp::get(type::get_fp16_ty(ctx), 0);
|
||||
case type::FP32TyID:
|
||||
return constant_fp::get(type::get_fp32_ty(ctx), 0);
|
||||
case type::FP64TyID:
|
||||
return constant_fp::get(type::get_fp64_ty(ctx), 0);
|
||||
default:
|
||||
throw std::runtime_error("Cannot create a null constant of that type!");
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME
|
||||
|
||||
constant *constant::get_all_ones_value(type *ty) {
|
||||
if(ty->is_integer_ty())
|
||||
return constant_int::get(ty, 0xFFFFFFFFFFFFFFFF);
|
||||
if(ty->is_floating_point_ty())
|
||||
return constant_fp::get(ty, 0xFFFFFFFFFFFFFFFF);
|
||||
throw std::runtime_error("Cannot create all ones value for that type!");
|
||||
}
|
||||
|
||||
// constant_int
|
||||
// FIXME use something like APInt
|
||||
|
||||
constant_int::constant_int(type *ty, uint64_t value)
|
||||
: constant(ty, 0), value_(value){ }
|
||||
|
||||
constant_int *constant_int::get(type *ty, uint64_t value) {
|
||||
if (!ty->is_integer_ty())
|
||||
throw std::runtime_error("Cannot create constant_int with non integer ty");
|
||||
context_impl *impl = ty->get_context().p_impl.get();
|
||||
std::unique_ptr<constant_int> &cst = impl->int_constants_[std::make_pair(ty, value)];
|
||||
if(!cst)
|
||||
cst.reset(new constant_int(ty, value));
|
||||
return cst.get();
|
||||
}
|
||||
|
||||
|
||||
// constant_fp
|
||||
// FIXME use something like APFloat
|
||||
|
||||
constant_fp::constant_fp(type *ty, double value)
|
||||
: constant(ty, 0), value_(value){ }
|
||||
|
||||
constant *constant_fp::get_negative_zero(type *ty){
|
||||
double neg_zero = 0;
|
||||
return get(ty, neg_zero);
|
||||
}
|
||||
|
||||
constant *constant_fp::get_zero_value_for_negation(type *ty) {
|
||||
if(ty->get_scalar_ty()->is_floating_point_ty())
|
||||
return constant_fp::get(ty, 0);
|
||||
return constant::get_null_value(ty);
|
||||
}
|
||||
|
||||
constant *constant_fp::get(type *ty, double v){
|
||||
context_impl *impl = ty->get_context().p_impl.get();
|
||||
std::unique_ptr<constant_fp> &result = impl->fp_constants_[std::make_pair(ty, v)];
|
||||
if(!result)
|
||||
result.reset(new constant_fp(ty, v));
|
||||
return result.get();
|
||||
}
|
||||
|
||||
|
||||
// undef value
|
||||
undef_value::undef_value(type *ty)
|
||||
: constant(ty, 0) { }
|
||||
|
||||
undef_value *undef_value::get(type *ty) {
|
||||
context_impl *impl = ty->get_context().p_impl.get();
|
||||
std::unique_ptr<undef_value> &result = impl->uv_constants_[ty];
|
||||
if(!result)
|
||||
result.reset(new undef_value(ty));
|
||||
return result.get();
|
||||
}
|
||||
|
||||
/* global value */
|
||||
global_value::global_value(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage,
|
||||
const std::string &name, unsigned addr_space)
|
||||
: constant(pointer_type::get(ty, addr_space), num_ops, name),
|
||||
linkage_(linkage) { }
|
||||
|
||||
|
||||
/* global object */
|
||||
global_object::global_object(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage,
|
||||
const std::string &name, unsigned addr_space)
|
||||
: global_value(ty, num_ops, linkage, name, addr_space) { }
|
||||
|
||||
|
||||
/* alloc const */
|
||||
alloc_const::alloc_const(type *ty, constant_int *size, const std::string &name)
|
||||
: global_object(ty, 1, global_value::external, name, 4) {
|
||||
set_operand(0, size);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,40 +0,0 @@
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// context implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
context_impl::context_impl(context &ctx)
|
||||
: void_ty(ctx, type::VoidTyID),
|
||||
label_ty(ctx, type::LabelTyID),
|
||||
// floating point
|
||||
fp8_ty(ctx, type::FP8TyID),
|
||||
fp16_ty(ctx, type::FP16TyID),
|
||||
bf16_ty(ctx, type::BF16TyID),
|
||||
fp32_ty(ctx, type::FP32TyID),
|
||||
fp64_ty(ctx, type::FP64TyID),
|
||||
// integers
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
int16_ty(ctx, 16),
|
||||
int32_ty(ctx, 32),
|
||||
int64_ty(ctx, 64),
|
||||
int128_ty(ctx, 128) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// context
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
context::context():
|
||||
p_impl(std::make_shared<context_impl>(*this)) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,66 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/* Argument */
|
||||
|
||||
argument::argument(type *ty, const std::string &name, function *parent, unsigned arg_no)
|
||||
: value(ty, name), parent_(parent), arg_no_(arg_no) { }
|
||||
|
||||
argument *argument::create(type *ty, const std::string &name,
|
||||
function *parent, unsigned arg_no) {
|
||||
return new argument(ty, name, parent, arg_no);
|
||||
}
|
||||
|
||||
function* argument::get_parent() const {
|
||||
return parent_;
|
||||
}
|
||||
|
||||
unsigned argument::get_arg_no() const {
|
||||
return arg_no_;
|
||||
}
|
||||
|
||||
void argument::accept(visitor *v) {
|
||||
v->visit_argument(this);
|
||||
}
|
||||
|
||||
|
||||
/* function */
|
||||
function::function(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *parent)
|
||||
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) {
|
||||
unsigned num_params = fn_ty_->get_num_params();
|
||||
// skip if no parameter
|
||||
if(num_params == 0)
|
||||
return;
|
||||
// create arguments
|
||||
args_.resize(num_params);
|
||||
for(unsigned i = 0; i < num_params; i++){
|
||||
type *param_ty = fn_ty_->get_param_ty(i);
|
||||
args_[i] = argument::create(param_ty, "", this, i);
|
||||
}
|
||||
if(parent)
|
||||
parent->push_function(this);
|
||||
}
|
||||
|
||||
/* basic block */
|
||||
void function::insert_block(basic_block *block, basic_block *next) {
|
||||
auto it = std::find(blocks_.begin(), blocks_.end(), next);
|
||||
blocks_.insert(it, block);
|
||||
}
|
||||
|
||||
|
||||
function *function::create(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *mod) {
|
||||
return new function(ty, linkage, name, mod);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -1,928 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// instruction classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
instruction::instruction(type *ty, value_id_t ity, unsigned num_ops,
|
||||
const std::string &name, instruction *next)
|
||||
: user(ty, num_ops, name), id_(ity) {
|
||||
if(next){
|
||||
basic_block *block = next->get_parent();
|
||||
assert(block && "Next instruction is not in a basic block!");
|
||||
auto it = std::find(block->begin(), block->end(), next);
|
||||
block->get_inst_list().insert(it, next);
|
||||
}
|
||||
}
|
||||
|
||||
void instruction::erase_from_parent() {
|
||||
parent_->erase(this);
|
||||
for(ir::value* op: ops())
|
||||
op->erase_use(this);
|
||||
}
|
||||
|
||||
bool instruction::has_tile_result_or_op() {
|
||||
bool result = get_type()->is_block_ty();
|
||||
for(unsigned i = 0; i < get_num_operands(); i++)
|
||||
result |= get_operand(i)->get_type()->is_block_ty();
|
||||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
|
||||
: instruction(ty, INST_PHI, 0, name, next) {
|
||||
blocks_.reserve(num_reserved);
|
||||
}
|
||||
|
||||
value* phi_node::get_value_for_block(basic_block * block) {
|
||||
auto it = std::find(blocks_.begin(), blocks_.end(), block);
|
||||
size_t n = std::distance(blocks_.begin(), it);
|
||||
return get_incoming_value(n);
|
||||
}
|
||||
|
||||
// Set incoming value
|
||||
void phi_node::set_incoming_value(unsigned i, value *v){
|
||||
assert(v && "PHI node got a null value!");
|
||||
assert(get_type() == v->get_type() &&
|
||||
"All operands to PHI node must be the same type as the PHI node!");
|
||||
set_operand(i, v);
|
||||
}
|
||||
|
||||
// Set incoming block
|
||||
void phi_node::set_incoming_block(unsigned i, basic_block *block){
|
||||
assert(block && "PHI node got a null basic block!");
|
||||
blocks_[i] = block;
|
||||
}
|
||||
|
||||
// Add incoming
|
||||
void phi_node::add_incoming(value *v, basic_block *block){
|
||||
resize_ops(get_num_operands() + 1);
|
||||
blocks_.resize(get_num_operands() + 1);
|
||||
set_incoming_value(get_num_operands() - 1, v);
|
||||
set_incoming_block(get_num_operands() - 1, block);
|
||||
}
|
||||
|
||||
// Factory methods
|
||||
phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){
|
||||
return new phi_node(ty, num_reserved, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string binary_operator::repr_impl() const {
|
||||
switch(op_) {
|
||||
case Add : return "add";
|
||||
case FAdd : return "fadd";
|
||||
case Sub : return "sub";
|
||||
case FSub : return "fsub";
|
||||
case Mul : return "mul";
|
||||
case FMul : return "fmul";
|
||||
case UDiv : return "udiv";
|
||||
case SDiv : return "sdiv";
|
||||
case FDiv : return "fdiv";
|
||||
case URem : return "urem";
|
||||
case SRem : return "srem";
|
||||
case FRem : return "frem";
|
||||
case Shl : return "shl";
|
||||
case LShr : return "lshr";
|
||||
case AShr : return "ashr";
|
||||
case And : return "and";
|
||||
case Or : return "or";
|
||||
case Xor : return "xor";
|
||||
default: throw std::runtime_error("unknown binary operator");
|
||||
}
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_div() const {
|
||||
return op_ == binary_op_t::UDiv || op_ == binary_op_t::SDiv;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_rem() const {
|
||||
return op_ == binary_op_t::URem || op_ == binary_op_t::SRem;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shl() const {
|
||||
return op_ == binary_op_t::Shl;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shr() const {
|
||||
return op_ == binary_op_t::LShr || op_ == binary_op_t::AShr;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_mult() const {
|
||||
return op_ == binary_op_t::Mul;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_add_sub() const {
|
||||
return op_ == binary_op_t::Add || op_ == binary_op_t::Sub;
|
||||
}
|
||||
|
||||
|
||||
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(lhs->get_type() == rhs->get_type() &&
|
||||
"Cannot create binary operator with two operands of differing type!");
|
||||
return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next);
|
||||
}
|
||||
|
||||
//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
|
||||
// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
|
||||
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
|
||||
// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
|
||||
//}
|
||||
|
||||
//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
|
||||
// assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
|
||||
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty());
|
||||
// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
|
||||
//}
|
||||
|
||||
//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
|
||||
// assert(arg->get_type()->is_integer_ty());
|
||||
// constant *mask = constant::get_all_ones_value(arg->get_type());
|
||||
// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
|
||||
//}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cmp_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
|
||||
// cmp_inst
|
||||
std::string cmp_inst::repr_impl() const {
|
||||
switch (pred_) {
|
||||
case FCMP_FALSE : return "false";
|
||||
case FCMP_OEQ : return "fcmp_oeq";
|
||||
case FCMP_OGT : return "fcmp_ogt";
|
||||
case FCMP_OGE : return "fcmp_oge";
|
||||
case FCMP_OLT : return "fcmp_olt";
|
||||
case FCMP_OLE : return "fcmp_ole";
|
||||
case FCMP_ONE : return "fcmp_one";
|
||||
case FCMP_ORD : return "fcmp_ord";
|
||||
case FCMP_UNO : return "fcmp_uno";
|
||||
case FCMP_UEQ : return "fcmp_ueq";
|
||||
case FCMP_UGT : return "fcmp_ugt";
|
||||
case FCMP_UGE : return "fcmp_uge";
|
||||
case FCMP_ULT : return "fcmp_ult";
|
||||
case FCMP_ULE : return "fcmp_ule";
|
||||
case FCMP_UNE : return "fcmp_une";
|
||||
case FCMP_TRUE : return "true";
|
||||
case ICMP_EQ : return "icmp_eq";
|
||||
case ICMP_NE : return "icmp_ne";
|
||||
case ICMP_UGT : return "icmp_ugt";
|
||||
case ICMP_UGE : return "icmp_uge";
|
||||
case ICMP_ULT : return "icmp_ult";
|
||||
case ICMP_ULE : return "icmp_ule";
|
||||
case ICMP_SGT : return "icmp_sgt";
|
||||
case ICMP_SGE : return "icmp_sge";
|
||||
case ICMP_SLT : return "icmp_slt";
|
||||
case ICMP_SLE : return "icmp_sle";
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, 2, name, next), pred_(pred) {
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
type* cmp_inst::make_cmp_result_type(type *ty){
|
||||
type* int1_ty = type::get_int1_ty(ty->get_context());
|
||||
if (block_type* tile_ty = dynamic_cast<block_type*>(ty))
|
||||
return block_type::get_same_shapes(int1_ty, tile_ty);
|
||||
return int1_ty;
|
||||
}
|
||||
|
||||
|
||||
bool cmp_inst::is_fp_predicate(cmp_pred_t pred) {
|
||||
return pred >= FIRST_FCMP_PREDICATE && pred <= LAST_FCMP_PREDICATE;
|
||||
}
|
||||
|
||||
bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
|
||||
return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE;
|
||||
}
|
||||
|
||||
|
||||
// icmp_inst
|
||||
icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next){ }
|
||||
|
||||
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_int_predicate(pred));
|
||||
assert(lhs->get_type() == rhs->get_type());
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
// fcmp_inst
|
||||
fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next){ }
|
||||
|
||||
fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_fp_predicate(pred));
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
return new fcmp_inst(res_ty, pred, lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// unary_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, 1, name, next) {
|
||||
set_operand(0, v);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string cast_inst::repr_impl() const {
|
||||
switch (op_){
|
||||
case cast_op_t::Trunc: return "trunc";
|
||||
case cast_op_t::ZExt: return "zext";
|
||||
case cast_op_t::SExt: return "sext";
|
||||
case cast_op_t::FPTrunc: return "fp_trunc";
|
||||
case cast_op_t::FPExt: return "fp_ext";
|
||||
case cast_op_t::UIToFP: return "ui_to_fp";
|
||||
case cast_op_t::SIToFP: return "si_to_fp";
|
||||
case cast_op_t::FPToUI: return "fp_to_ui";
|
||||
case cast_op_t::FPToSI: return "fp_to_si";
|
||||
case cast_op_t::PtrToInt: return "ptr_to_int";
|
||||
case cast_op_t::IntToPtr: return "int_to_ptr";
|
||||
case cast_op_t::BitCast: return "bitcast";
|
||||
case cast_op_t::AddrSpaceCast: return "addr_space_cast";
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
// TODO
|
||||
bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) {
|
||||
assert(arg->get_type()->is_block_ty() == ty->is_block_ty());
|
||||
return true;
|
||||
}
|
||||
|
||||
cast_inst *cast_inst::create(cast_op_t op, value *arg, type *ty, const std::string &name, instruction *next){
|
||||
assert(is_valid(op, arg, ty) && "Invalid cast!");
|
||||
// Construct and return the appropriate CastInst subclass
|
||||
switch (op) {
|
||||
case cast_op_t::Trunc: return new trunc_inst (ty, arg, name, next);
|
||||
case cast_op_t::ZExt: return new z_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::SExt: return new s_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPExt: return new fp_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
|
||||
case cast_op_t::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
|
||||
case cast_op_t::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
|
||||
case cast_op_t::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
|
||||
case cast_op_t::BitCast: return new bit_cast_inst (ty, arg, name, next);
|
||||
case cast_op_t::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
|
||||
type *arg_ty = arg->get_type();
|
||||
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
|
||||
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
|
||||
(arg_bits > dst_bits ? cast_op_t::Trunc :
|
||||
(is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
|
||||
return create(op, arg, ty, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// return_inst
|
||||
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
|
||||
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
|
||||
if(ret_val)
|
||||
set_operand(0, ret_val);
|
||||
}
|
||||
|
||||
return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next){
|
||||
return new return_inst(ctx, ret_val, next);
|
||||
}
|
||||
|
||||
|
||||
// branch_inst
|
||||
branch_inst* branch_inst::create(basic_block *dst, instruction *next) {
|
||||
assert(dst && "Branch destination may not be null!");
|
||||
return new uncond_branch_inst(dst, next);
|
||||
}
|
||||
|
||||
branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *else_dst, instruction *next) {
|
||||
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
|
||||
return new cond_branch_inst(if_dst, else_dst, cond, next);
|
||||
}
|
||||
|
||||
// uncond_branch_inst
|
||||
uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next)
|
||||
: branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next){
|
||||
set_operand(0, dst);
|
||||
}
|
||||
|
||||
// cond_branch_inst
|
||||
cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
|
||||
: branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next){
|
||||
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
|
||||
set_operand(0, if_dst);
|
||||
set_operand(1, else_dst);
|
||||
set_operand(2, cond);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
|
||||
: instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next),
|
||||
source_elt_ty(pointee_ty),
|
||||
res_elt_ty(get_indexed_type(pointee_ty, idx)){
|
||||
// sanity check
|
||||
type *expected_ty = get_type()->get_scalar_ty();
|
||||
expected_ty = ((pointer_type*)expected_ty)->get_element_ty();
|
||||
assert(res_elt_ty == expected_ty);
|
||||
// set operands
|
||||
set_operand(0, ptr);
|
||||
for(size_t i = 0; i < idx.size(); i++)
|
||||
set_operand(1 + i, idx[i]);
|
||||
}
|
||||
|
||||
type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vector<value *> &idx_list) {
|
||||
// result pointer type
|
||||
type *ty = x->get_type();
|
||||
unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space();
|
||||
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space);
|
||||
// Tile GEP
|
||||
if(ty->is_block_ty())
|
||||
return block_type::get_same_shapes(ptr_ty, ty);
|
||||
for(value *idx : idx_list)
|
||||
if (idx->get_type()->is_block_ty())
|
||||
return block_type::get_same_shapes(ptr_ty, ty);
|
||||
// Scalar GEP
|
||||
return ptr_ty;
|
||||
}
|
||||
|
||||
type *getelementptr_inst::get_indexed_type_impl(type *ty, const std::vector<value *> &idx_list) {
|
||||
if(idx_list.empty())
|
||||
return ty;
|
||||
if(!ty->is_sized())
|
||||
return nullptr;
|
||||
unsigned cur_idx = 1;
|
||||
for(; cur_idx != idx_list.size(); cur_idx++){
|
||||
composite_type *cty = dynamic_cast<composite_type*>(ty);
|
||||
if(!cty || cty->is_pointer_ty())
|
||||
break;
|
||||
value *idx = idx_list[cur_idx];
|
||||
if(!cty->index_valid(idx))
|
||||
break;
|
||||
ty = cty->get_type_at_index(idx);
|
||||
}
|
||||
return (cur_idx == idx_list.size())? ty : nullptr;
|
||||
}
|
||||
|
||||
type *getelementptr_inst::get_indexed_type(type *ty, const std::vector<value *> &idx_list) {
|
||||
type *result = get_indexed_type_impl(ty, idx_list);
|
||||
assert(result && "invalid GEP type!");
|
||||
return result;
|
||||
}
|
||||
|
||||
getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next) {
|
||||
type *pointee_ty = ((pointer_type*)(ptr->get_type()->get_scalar_ty()))->get_element_ty();
|
||||
return new getelementptr_inst(pointee_ty, ptr, idx, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load_inst/store_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// io_inst
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next)
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||
{ }
|
||||
|
||||
// load
|
||||
type *load_inst::get_pointee_type(type *ty) {
|
||||
type *scalar_ty = ty->get_scalar_ty();
|
||||
type *pointee_ty = scalar_ty->get_pointer_element_ty();
|
||||
if(ty->is_block_ty())
|
||||
return block_type::get_same_shapes(pointee_ty, ty);
|
||||
return pointee_ty;
|
||||
}
|
||||
|
||||
// unmasked_load
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
}
|
||||
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load async
|
||||
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
|
||||
}
|
||||
|
||||
// store
|
||||
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
|
||||
{ }
|
||||
|
||||
// unmasked_store
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
|
||||
const std::string &name, instruction *next) {
|
||||
return new unmasked_store_inst(ptr, val, name, next);
|
||||
}
|
||||
|
||||
// masked store
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, mask);
|
||||
}
|
||||
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, name, next);
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
|
||||
: instruction(block_type::get(x->get_type()->get_scalar_ty(),
|
||||
{x->get_type()->get_block_shapes()[0] +
|
||||
y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) {
|
||||
set_operand(0, x);
|
||||
set_operand(1, y);
|
||||
}
|
||||
|
||||
instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
|
||||
return new cat_inst(lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
// retile
|
||||
|
||||
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next)
|
||||
: unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
|
||||
|
||||
|
||||
|
||||
// reshape
|
||||
|
||||
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new reshape_inst(arg, INST_RESHAPE, shapes, name, next);
|
||||
}
|
||||
|
||||
|
||||
// splat
|
||||
|
||||
instruction* splat_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new splat_inst(arg, INST_SPLAT, shapes, name, next);
|
||||
}
|
||||
|
||||
// broadcast
|
||||
|
||||
instruction* broadcast_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next);
|
||||
}
|
||||
|
||||
// downcast
|
||||
|
||||
instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// matmul_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||
set_operand(0, A);
|
||||
set_operand(1, B);
|
||||
set_operand(2, C);
|
||||
allow_tf32_ = allow_tf32;
|
||||
}
|
||||
|
||||
instruction *dot_inst::create(value *A, value *B, value *C,
|
||||
bool AT, bool BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
TransT OPA = AT ? Trans : NoTrans;
|
||||
TransT OPB = BT ? Trans : NoTrans;
|
||||
return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// trans instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<int> perm) {
|
||||
// get argument shapes
|
||||
ir::block_type::block_shapes_t arg_shapes = ty->get_block_shapes();
|
||||
// permutate argument shapes
|
||||
perm = init_perm(ty, perm);
|
||||
ir::block_type::block_shapes_t res_shapes = arg_shapes;
|
||||
for(size_t i = 0; i < perm.size(); i++)
|
||||
res_shapes[i] = arg_shapes[perm[i]];
|
||||
// construct type
|
||||
return block_type::get(ty->get_scalar_ty(), res_shapes);
|
||||
}
|
||||
|
||||
std::vector<int> trans_inst::init_perm(ir::type* ty, const std::vector<int>& perm) {
|
||||
if(!perm.empty())
|
||||
return perm;
|
||||
auto size = ty->get_block_shapes().size();
|
||||
std::vector<int> result;
|
||||
result.push_back(size - 1);
|
||||
for(size_t i = 0; i < size - 1; i++)
|
||||
result.push_back(i);
|
||||
return result;
|
||||
}
|
||||
|
||||
trans_inst::trans_inst(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
|
||||
// sanity check
|
||||
perm_ = init_perm(arg->get_type(), perm);
|
||||
//auto size = arg->get_type()->get_tile_shapes().size();
|
||||
//assert(perm_.size() == size);
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* trans_inst::create(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next) {
|
||||
return new trans_inst(arg, perm, name, next);
|
||||
}
|
||||
|
||||
const std::vector<int> trans_inst::get_perm() const {
|
||||
return perm_;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// sqrt instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
|
||||
: builtin_inst(arg->get_type(), INST_SQRT, 1, name, next){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new sqrt_inst(arg, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// reduce instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string reduce_inst::to_str(op_t op) {
|
||||
switch (op) {
|
||||
case ADD: return "+";
|
||||
case SUB: return "-";
|
||||
case MAX: return "imax";
|
||||
case MIN: return "imin";
|
||||
case FADD: return "+";
|
||||
case FSUB: return "-";
|
||||
case FMAX: return "fmax";
|
||||
case FMIN: return "fmin";
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
}
|
||||
|
||||
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
|
||||
ir::block_type::block_shapes_t shapes = arg->get_type()->get_block_shapes();
|
||||
shapes.erase(shapes.begin() + axis);
|
||||
type *scalar_ty = arg->get_type()->get_scalar_ty();
|
||||
if(shapes.empty())
|
||||
// shapes.push_back(1);
|
||||
return scalar_ty;
|
||||
return block_type::get(scalar_ty, shapes);
|
||||
}
|
||||
|
||||
reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next),
|
||||
op_(op),
|
||||
axis_(axis){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new reduce_inst(arg, op, axis, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// select instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
|
||||
: builtin_inst(if_value->get_type(), INST_SELECT, 3, name, next){
|
||||
set_operand(0, pred);
|
||||
set_operand(1, if_value);
|
||||
set_operand(2, else_value);
|
||||
}
|
||||
|
||||
instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) {
|
||||
return new select_inst(pred, if_value, else_value, name, next);
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// builtin instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// get_program_id
|
||||
get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// get_num_program
|
||||
get_num_programs_inst::get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// atomic_rmw
|
||||
|
||||
atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, msk);
|
||||
}
|
||||
|
||||
instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_rmw_inst(op, ptr, val, msk, name, next);
|
||||
}
|
||||
|
||||
|
||||
// atomic cas
|
||||
|
||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, cmp);
|
||||
set_operand(2, val);
|
||||
}
|
||||
|
||||
instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
|
||||
return new atomic_cas_inst(ptr, cmp, val, name, next);
|
||||
}
|
||||
|
||||
|
||||
// umulhi
|
||||
|
||||
umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) {
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
|
||||
return new umulhi_inst(lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
|
||||
// exp
|
||||
|
||||
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_EXP, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* exp_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new exp_inst(val, name, next);
|
||||
}
|
||||
|
||||
// cos
|
||||
cos_inst::cos_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_COS, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* cos_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new cos_inst(val, name, next);
|
||||
}
|
||||
|
||||
// sin
|
||||
sin_inst::sin_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_SIN, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* sin_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new sin_inst(val, name, next);
|
||||
}
|
||||
|
||||
|
||||
// log
|
||||
|
||||
log_inst::log_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_LOG, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* log_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new log_inst(val, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cvt_scanline
|
||||
cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new cvt_layout_inst(arg->get_type(), INST_CVT_LAYOUT, arg, name, next);
|
||||
}
|
||||
|
||||
// copy to shared
|
||||
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next);
|
||||
}
|
||||
|
||||
// copy from shared
|
||||
copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
|
||||
}
|
||||
|
||||
// barrier
|
||||
barrier_inst::barrier_inst(context &ctx, const std::string &name,
|
||||
instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
|
||||
|
||||
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
|
||||
return new barrier_inst(ctx, name, next);
|
||||
}
|
||||
|
||||
async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
|
||||
|
||||
async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
|
||||
return new async_wait_inst(ctx, N, name, next);
|
||||
}
|
||||
|
||||
// prefetch_s
|
||||
prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) {
|
||||
return new prefetch_s_inst(ctx, arg, inc, name, next);
|
||||
}
|
||||
|
||||
//// nv_dynamic_program_idx
|
||||
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
|
||||
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
// return new make_range_dyn(ty, name, next);
|
||||
//}
|
||||
|
||||
//// nv_static_program_idx
|
||||
//make_range_sta::make_range_sta(make_range *range)
|
||||
// : constant(range->get_type(), 0), range_(range) { }
|
||||
|
||||
//make_range* make_range_sta::get_range() const
|
||||
//{ return range_; }
|
||||
|
||||
//make_range_sta* make_range_sta::get(make_range* range) {
|
||||
// static std::map<make_range*, make_range_sta*> cache;
|
||||
// if(cache.find(range) == cache.end())
|
||||
// cache.insert({range, new make_range_sta(range)});
|
||||
// return cache.at(range);
|
||||
//}
|
||||
|
||||
|
||||
// make_range
|
||||
make_range::make_range(type *ty, constant_int *first, constant_int *last)
|
||||
: instruction(ty, INST_MAKE_RANGE, 0), first_(first), last_(last){ }
|
||||
|
||||
make_range *make_range::create(constant_int *first, constant_int *last) {
|
||||
assert(first->get_type()->is_integer_ty());
|
||||
assert(first->get_type() == last->get_type());
|
||||
// assert(((constant_int*)first)->get_value() == 0);
|
||||
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
|
||||
return new make_range(ty, first, last);
|
||||
}
|
||||
|
||||
const constant_int* make_range::get_first() const {
|
||||
return first_;
|
||||
}
|
||||
|
||||
const constant_int* make_range::get_last() const {
|
||||
return last_;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,14 +0,0 @@
|
||||
#include "triton/ir/metadata.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
metadata::metadata(kind_t kind, unsigned value)
|
||||
: kind_(kind), value_(value) { }
|
||||
|
||||
metadata* metadata::get(kind_t kind, unsigned value) {
|
||||
return new metadata(kind, value);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,22 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
/* functions */
|
||||
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
|
||||
function *&fn = (function*&)symbols_[name];
|
||||
if(fn == nullptr)
|
||||
return fn = function::create(ty, global_value::external, name, this);
|
||||
return fn;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
450
lib/ir/print.cc
450
lib/ir/print.cc
@@ -1,450 +0,0 @@
|
||||
#include <iostream>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/print.h"
|
||||
|
||||
#include <map>
|
||||
#include <iomanip>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
namespace {
|
||||
class SlotTracker {
|
||||
// A mapping of values to slot numbers.
|
||||
using value_map = std::map<const value*, unsigned>;
|
||||
|
||||
// The module for which we are holding slot numbers.
|
||||
const module *mod_;
|
||||
bool module_processed = false;
|
||||
|
||||
// The function for which we are holding slot numbers.
|
||||
const function *func_ = nullptr;
|
||||
bool function_processed = false;
|
||||
|
||||
// m_map - The slot map for the module level data.
|
||||
value_map m_map;
|
||||
unsigned m_next = 0;
|
||||
|
||||
// f_map - The slot map for the function level data.
|
||||
value_map f_map;
|
||||
unsigned f_next = 0;
|
||||
|
||||
public:
|
||||
// Construct from a module
|
||||
explicit SlotTracker(const module *mod) : mod_(mod) {}
|
||||
|
||||
// Construct from a function
|
||||
explicit SlotTracker(const function *f)
|
||||
: mod_(f? f->get_parent() : nullptr), func_(f) {}
|
||||
|
||||
// Return the slot number of the specified value. If something is not in
|
||||
// the SlotTracker, return -1
|
||||
int get_local_slot(const value *v);
|
||||
|
||||
void initialize_if_needed();
|
||||
|
||||
// If you'd like to deal with a function instead of just a module, use
|
||||
// this method to get its data into the SlotTracker
|
||||
void incorporate_function(const function *f) {
|
||||
func_ = f;
|
||||
function_processed = false;
|
||||
}
|
||||
|
||||
private:
|
||||
// Add all of the module level global variables (and their initializers)
|
||||
// and function declarations, but not contents of those functions.
|
||||
void process_module();
|
||||
|
||||
// Add all of the functions arguments, basic blocks, and instructions.
|
||||
void process_function();
|
||||
|
||||
// Insert specified value* into the slot table
|
||||
void create_function_slot(const value *v);
|
||||
};
|
||||
|
||||
class AssemblyWriter {
|
||||
std::ostream &os;
|
||||
SlotTracker &slot_tracker;
|
||||
|
||||
public:
|
||||
AssemblyWriter(std::ostream &os, SlotTracker &slot_tracker)
|
||||
: os(os), slot_tracker(slot_tracker) {}
|
||||
|
||||
void print_module(const module *mod);
|
||||
void print_function(const function *f);
|
||||
void print_argument(const argument *arg);
|
||||
void print_basic_block(const basic_block *bb);
|
||||
void print_instruction(const instruction *instr);
|
||||
void print_value(const value *v);
|
||||
|
||||
void write_operand(const value *op, bool print_type = false);
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
//-------------------------
|
||||
// SlotTracker
|
||||
//-------------------------
|
||||
void SlotTracker::process_module() {
|
||||
// Nothing to do at the moment.
|
||||
// Create slots for global variable & unamed functions & ...
|
||||
module_processed = true;
|
||||
}
|
||||
|
||||
void SlotTracker::process_function() {
|
||||
f_next = 0;
|
||||
|
||||
// Add all the function arguments with no names.
|
||||
for (const argument *arg : func_->args())
|
||||
if (!arg->has_name())
|
||||
create_function_slot(arg);
|
||||
|
||||
// Add all of the basic blocks and instructions with no names.
|
||||
for (const basic_block *bb : func_->blocks()) {
|
||||
if (!bb->has_name())
|
||||
create_function_slot(bb);
|
||||
|
||||
for (const instruction *instr : bb->get_inst_list()) {
|
||||
if (!instr->get_type()->is_void_ty() && !instr->has_name())
|
||||
create_function_slot(instr);
|
||||
}
|
||||
}
|
||||
|
||||
function_processed = true;
|
||||
}
|
||||
|
||||
void SlotTracker::create_function_slot(const value *v) {
|
||||
assert(!v->get_type()->is_void_ty() && !v->has_name() && "Doesn't need a slot");
|
||||
|
||||
unsigned dst_slot = f_next++;
|
||||
f_map[v] = dst_slot;
|
||||
}
|
||||
|
||||
int SlotTracker::get_local_slot(const value *v) {
|
||||
assert(dynamic_cast<const constant*>(v) == nullptr && "Can't get a constant slot");
|
||||
|
||||
// Check for uninitialized state and do lazy initialization.
|
||||
initialize_if_needed();
|
||||
|
||||
value_map::iterator f_iter = f_map.find(v);
|
||||
return f_iter == f_map.end() ? -1 : (int)f_iter->second;
|
||||
}
|
||||
|
||||
void SlotTracker::initialize_if_needed() {
|
||||
if (mod_ && !module_processed)
|
||||
process_module();
|
||||
|
||||
if (func_ && !function_processed)
|
||||
process_function();
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------
|
||||
// AssemblyWriter
|
||||
//-------------------------------
|
||||
void AssemblyWriter::write_operand(const value *operand, bool print_type) {
|
||||
if (!operand) {
|
||||
os << "<null operand!>";
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto *c = dynamic_cast<const ir::constant*>(operand)) {
|
||||
os << c->repr();
|
||||
return;
|
||||
}
|
||||
|
||||
if (operand->has_name()) {
|
||||
os << operand->get_name();
|
||||
return;
|
||||
}
|
||||
|
||||
// Print the normal way
|
||||
int slot_num = slot_tracker.get_local_slot(operand);
|
||||
|
||||
if (slot_num != -1)
|
||||
os << "%" << slot_num;
|
||||
else
|
||||
os << "<badref>";
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_module(const module *mod) {
|
||||
slot_tracker.initialize_if_needed();
|
||||
// ;ModuleID = ...
|
||||
// source_filename = ...
|
||||
|
||||
// Print all of the functions.
|
||||
for (function *f : mod->get_function_list()) {
|
||||
os << "\n";
|
||||
print_function(f);
|
||||
}
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_function(const function *f) {
|
||||
// Annotation & Attributes
|
||||
|
||||
slot_tracker.incorporate_function(f);
|
||||
|
||||
os << "def ";
|
||||
ir::type *rt_type = f->get_fn_type()->get_return_ty();
|
||||
// Functions must have names.
|
||||
os << rt_type->repr() << " " << f->get_name() << "(";
|
||||
// Print arguments
|
||||
for (ir::argument *arg : f->args()) {
|
||||
if (arg->get_arg_no() > 0)
|
||||
os << ", ";
|
||||
print_argument(arg);
|
||||
}
|
||||
os << ")";
|
||||
|
||||
// Print function body
|
||||
os << "{";
|
||||
for (const basic_block *bb : f->blocks())
|
||||
print_basic_block(bb);
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_argument(const argument *arg) {
|
||||
// Print type
|
||||
os << arg->get_type()->repr();
|
||||
|
||||
// Print name, if available.
|
||||
if (arg->has_name())
|
||||
os << " " << arg->get_name();
|
||||
else {
|
||||
int slot_num = slot_tracker.get_local_slot(arg);
|
||||
assert(slot_num != -1 && "expect argument in function here");
|
||||
os << " %" << slot_num;
|
||||
}
|
||||
|
||||
// Print attributes
|
||||
std::set<attribute> attrs = arg->get_parent()->get_attributes(arg);
|
||||
for (attribute attr : attrs)
|
||||
os << " " << attr.repr();
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_basic_block(const basic_block *bb) {
|
||||
// bb label
|
||||
if (bb->has_name()) {
|
||||
os << "\n";
|
||||
os << bb->get_name() << ":";
|
||||
} else {
|
||||
os << "\n";
|
||||
int slot_num = slot_tracker.get_local_slot(bb);
|
||||
if (slot_num != -1)
|
||||
os << slot_num << ":";
|
||||
else
|
||||
os << "<badref>:";
|
||||
}
|
||||
|
||||
// Print predecessors for the block
|
||||
auto const &predecessors = bb->get_predecessors();
|
||||
if (!predecessors.empty()) {
|
||||
os << std::setw(50) << std::setfill(' ')
|
||||
<< "; preds = ";
|
||||
for (size_t i=0; i<predecessors.size(); ++i) {
|
||||
if (i)
|
||||
os << ", ";
|
||||
write_operand(predecessors[i]);
|
||||
}
|
||||
}
|
||||
|
||||
os << "\n";
|
||||
|
||||
// Annotation?
|
||||
|
||||
// Print all of the instructions in the basic block
|
||||
for (const ir::instruction *instr : bb->get_inst_list())
|
||||
print_instruction(instr);
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_instruction(const instruction *instr) {
|
||||
// Print out indentation for an instruction.
|
||||
os << " ";
|
||||
|
||||
ir::type *type = instr->get_type();
|
||||
if (instr->has_name()) {
|
||||
os << instr->get_name();
|
||||
os << " = ";
|
||||
} else if (!type->is_void_ty()) {
|
||||
// Print out the def slot taken.
|
||||
int slot_num = slot_tracker.get_local_slot(instr);
|
||||
if (slot_num == -1)
|
||||
os << "<badref> = ";
|
||||
else
|
||||
os << "%" << slot_num << " = ";
|
||||
}
|
||||
|
||||
// Print out opcode
|
||||
os << instr->repr() << " " << type->repr();
|
||||
|
||||
size_t num_ops = instr->get_num_operands();
|
||||
if (num_ops > 0)
|
||||
os << " ";
|
||||
ir::instruction::ops_t ops = instr->ops();
|
||||
for (unsigned i = 0; i < num_ops; ++i) {
|
||||
if (i)
|
||||
os << ", ";
|
||||
write_operand(ops[i]);
|
||||
}
|
||||
|
||||
os << ";\n";
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_value(const value *v) {
|
||||
// Not implemented
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------
|
||||
// External interface
|
||||
//-------------------------------
|
||||
void module::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this);
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_module(this);
|
||||
}
|
||||
|
||||
void function::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this);
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_function(this);
|
||||
}
|
||||
|
||||
void basic_block::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this->get_parent());
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_basic_block(this);
|
||||
}
|
||||
|
||||
void instruction::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this->get_parent()->get_parent());
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_instruction(this);
|
||||
}
|
||||
|
||||
//-------------------------------
|
||||
// legacy print interface
|
||||
//-------------------------------
|
||||
std::string get_name(ir::value *v, unsigned i) {
|
||||
if(v->get_name().empty()){
|
||||
std::string name = "%" + std::to_string(i);
|
||||
v->set_name(name);
|
||||
}
|
||||
return v->get_name();
|
||||
}
|
||||
|
||||
|
||||
void print(module &mod, std::ostream& os) {
|
||||
unsigned cnt = 0;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
os << "def " << fn->get_fn_type()->get_return_ty()->repr() << " " << fn->get_name() << "(" ;
|
||||
for(ir::argument* arg: fn->args()) {
|
||||
if(arg->get_arg_no() > 0)
|
||||
os << ", ";
|
||||
os << arg->get_type()->repr() << " " << arg->get_name();
|
||||
auto attrs = fn->get_attributes(arg);
|
||||
if(attrs.size() > 0)
|
||||
os << " ";
|
||||
for(ir::attribute attr: attrs)
|
||||
os << attr.repr() << " ";
|
||||
}
|
||||
os << ")" << std::endl;
|
||||
os << "{" << std::endl;
|
||||
for(ir::basic_block *block: fn->blocks()){
|
||||
auto const &predecessors = block->get_predecessors();
|
||||
os << block->get_name() << ":";
|
||||
if(!predecessors.empty()){
|
||||
os << " ";
|
||||
os << "; preds = ";
|
||||
auto const &predecessors = block->get_predecessors();
|
||||
for(ir::basic_block *pred: predecessors)
|
||||
os << pred->get_name() << (pred!=predecessors.back()?", ":"");
|
||||
}
|
||||
os << std::endl;
|
||||
for(ir::instruction *inst: block->get_inst_list()){
|
||||
os << " ";
|
||||
if(!inst->get_type()->is_void_ty()){
|
||||
os << get_name(inst, cnt++);
|
||||
os << " = ";
|
||||
}
|
||||
ir::type* type = inst->get_type();
|
||||
os << inst->repr() << " " << type->repr();
|
||||
ir::instruction::ops_t ops = inst->ops();
|
||||
size_t num_ops = inst->get_num_operands();
|
||||
if(num_ops > 0)
|
||||
os << " ";;
|
||||
for(unsigned i = 0; i < num_ops; i++){
|
||||
if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
|
||||
os << x->repr();
|
||||
else
|
||||
os << get_name(ops[i], cnt++);
|
||||
os << (i < num_ops - 1?", ":"");
|
||||
}
|
||||
os << ";";
|
||||
// os << " (";
|
||||
// for(ir::user* usr: inst->get_users())
|
||||
// os << get_name(usr, cnt++) << ", " ;
|
||||
// os << " )";
|
||||
os << std::endl;
|
||||
}
|
||||
}
|
||||
os << "}" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void print(function &fn, std::ostream &os) {
|
||||
//
|
||||
}
|
||||
|
||||
void print(basic_block &bb, std::ostream &os) {
|
||||
auto const &predecessors = bb.get_predecessors();
|
||||
os << bb.get_name() << ":";
|
||||
if(!predecessors.empty()){
|
||||
os << " ";
|
||||
os << "; preds = ";
|
||||
auto const &predecessors = bb.get_predecessors();
|
||||
for(ir::basic_block *pred: predecessors)
|
||||
os << pred->get_name() << (pred!=predecessors.back()?", ":"");
|
||||
}
|
||||
os << std::endl;
|
||||
for(ir::instruction *inst: bb.get_inst_list()){
|
||||
print(*inst, os);
|
||||
}
|
||||
}
|
||||
|
||||
void print(instruction &instr, std::ostream &os) {
|
||||
instruction *inst = &instr;
|
||||
os << " ";
|
||||
if(!inst->get_type()->is_void_ty()){
|
||||
os << instr.get_name();
|
||||
os << " = ";
|
||||
}
|
||||
ir::type* type = inst->get_type();
|
||||
os << inst->repr() << " " << type->repr();
|
||||
ir::instruction::ops_t ops = inst->ops();
|
||||
size_t num_ops = inst->get_num_operands();
|
||||
if(num_ops > 0)
|
||||
os << " ";;
|
||||
for(unsigned i = 0; i < num_ops; i++){
|
||||
if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
|
||||
os << x->repr();
|
||||
else
|
||||
os << ops[i]->get_name();
|
||||
os << (i < num_ops - 1?", ":"");
|
||||
}
|
||||
os << ";";
|
||||
// os << " (";
|
||||
// for(ir::user* usr: inst->get_users())
|
||||
// os << get_name(usr, cnt++) << ", " ;
|
||||
// os << " )";
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
233
lib/ir/type.cc
233
lib/ir/type.cc
@@ -1,233 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/constant.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// attributes
|
||||
type *type::get_scalar_ty() const {
|
||||
if(is_block_ty())
|
||||
return get_tile_element_ty();
|
||||
return const_cast<type*>(this);
|
||||
}
|
||||
|
||||
unsigned type::get_primitive_size_in_bits() const {
|
||||
switch (id_) {
|
||||
case FP8TyID: return 8;
|
||||
case FP16TyID: return 16;
|
||||
case BF16TyID: return 16;
|
||||
case FP32TyID: return 32;
|
||||
case FP64TyID: return 64;
|
||||
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
|
||||
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned type::get_integer_bitwidth() const
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_tile_bitwidth() const
|
||||
{ return ((block_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == FP8TyID) return 3;
|
||||
if (id == FP16TyID) return 10;
|
||||
if (id == BF16TyID) return 7;
|
||||
if (id == FP32TyID) return 23;
|
||||
if (id == FP64TyID) return 53;
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
type* type::get_tile_element_ty() const {
|
||||
assert(is_block_ty());
|
||||
return contained_tys_[0];
|
||||
}
|
||||
|
||||
unsigned type::get_pointer_address_space() const {
|
||||
assert(is_pointer_ty());
|
||||
return ((pointer_type*)this)->get_address_space();
|
||||
}
|
||||
|
||||
type * type::get_pointer_element_ty() const {
|
||||
type *ptr_ty = get_scalar_ty();
|
||||
assert(ptr_ty->is_pointer_ty());
|
||||
type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty();
|
||||
if(is_block_ty())
|
||||
return block_type::get_same_shapes(scalar_ty, (type*)this);
|
||||
return scalar_ty;
|
||||
}
|
||||
|
||||
|
||||
type::block_shapes_t type::get_block_shapes() const {
|
||||
assert(is_block_ty());
|
||||
return ((block_type*)this)->get_shapes();
|
||||
}
|
||||
|
||||
const size_t type::get_tile_rank() const {
|
||||
return get_block_shapes().size();
|
||||
}
|
||||
|
||||
const size_t type::get_tile_ranks1() const {
|
||||
int ret = 0;
|
||||
for(int s: get_block_shapes())
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
unsigned type::get_tile_num_elements() const {
|
||||
const block_shapes_t& shapes = get_block_shapes();
|
||||
unsigned result = 1;
|
||||
for(auto shape: shapes)
|
||||
result *= shape;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// composite predicates
|
||||
bool type::is_int_or_tileint_ty()
|
||||
{ return get_scalar_ty()->is_integer_ty(); }
|
||||
|
||||
bool type::is_integer_ty(unsigned width) const
|
||||
{ return is_integer_ty() && get_integer_bitwidth()== width; }
|
||||
|
||||
|
||||
bool type::is_floating_point_ty() const
|
||||
{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
|
||||
|
||||
bool type::is_sized() const {
|
||||
// primitive types are sized
|
||||
if(is_integer_ty() || is_floating_point_ty() ||
|
||||
is_pointer_ty()){
|
||||
return true;
|
||||
}
|
||||
// tile types are sizes
|
||||
if(is_block_ty())
|
||||
return get_scalar_ty()->is_sized();
|
||||
return false;
|
||||
}
|
||||
|
||||
// primitive types
|
||||
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||
// floating point
|
||||
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||
type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
|
||||
type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
|
||||
type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
|
||||
type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
|
||||
// integer types
|
||||
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
|
||||
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
|
||||
integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
|
||||
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
|
||||
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
|
||||
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
|
||||
|
||||
|
||||
|
||||
pointer_type::pointer_type(type *ty, unsigned address_space)
|
||||
: type(ty->get_context(), PointerTyID), address_space_(address_space){
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
bool pointer_type::is_valid_elt_ty(type *ty){
|
||||
return !ty->is_void_ty() && !ty->is_label_ty() &&
|
||||
!ty->is_metadata_ty() && !ty->is_token_ty();
|
||||
}
|
||||
|
||||
pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
|
||||
assert(elt_ty && "Can't get a pointer to <null> type!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
|
||||
// look-up
|
||||
context_impl *impl = elt_ty->get_context().p_impl.get();
|
||||
std::unique_ptr<pointer_type> &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)];
|
||||
if(!entry)
|
||||
entry.reset(new pointer_type(elt_ty, address_space));
|
||||
return entry.get();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// composite_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
type* composite_type::get_type_at_index(value *) const{
|
||||
assert(is_block_ty());
|
||||
return get_scalar_ty();
|
||||
}
|
||||
|
||||
bool composite_type::index_valid(value *idx) const{
|
||||
assert(is_block_ty());
|
||||
return idx->get_type()->is_int_or_tileint_ty();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tile_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
block_type::block_type(type *ty, const block_shapes_t &shapes)
|
||||
: composite_type(ty->get_context(), BlockTyID), shapes_(shapes) {
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
bool block_type::is_valid_elt_ty(type *ty) {
|
||||
return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
|
||||
}
|
||||
|
||||
unsigned block_type::get_num_elements() const {
|
||||
unsigned res = 1;
|
||||
for(auto shape: shapes_)
|
||||
res *= shape;
|
||||
return res;
|
||||
}
|
||||
|
||||
unsigned block_type::get_bitwidth() const {
|
||||
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
|
||||
}
|
||||
|
||||
block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
|
||||
assert(elt_ty && "Can't get a tile of <null> type!");
|
||||
assert(shapes.size() && "Can't create a tile with empty shapes!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
|
||||
// look-up
|
||||
context_impl *impl = elt_ty->get_context().p_impl.get();
|
||||
std::unique_ptr<block_type> &entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
|
||||
if(!entry)
|
||||
entry.reset(new block_type(elt_ty, shapes));
|
||||
return entry.get();
|
||||
}
|
||||
|
||||
block_type* block_type::get_same_shapes(type *ty, type *ref){
|
||||
assert(ref->is_block_ty());
|
||||
return get(ty, ref->get_block_shapes());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// function_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
function_type::function_type(type *ret_ty, const std::vector<type*> ¶m_tys):
|
||||
type(ret_ty->get_context(), FunctionTyID) {
|
||||
contained_tys_.push_back(ret_ty);
|
||||
for(type *ty: param_tys)
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
function_type* function_type::get(type *ret_ty, const std::vector<type *> ¶m_tys) {
|
||||
return new function_type(ret_ty, param_tys);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,68 +0,0 @@
|
||||
#include <stack>
|
||||
#include <iostream>
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
std::vector<basic_block*> cfg::post_order(function* fn) {
|
||||
std::stack<basic_block*> stack;
|
||||
std::set<basic_block*> visited;
|
||||
std::vector<basic_block*> result;
|
||||
// initialize stack
|
||||
for(ir::basic_block* block: fn->blocks())
|
||||
if(block->get_predecessors().empty()){
|
||||
stack.push(block);
|
||||
visited.insert(block);
|
||||
}
|
||||
// DFS
|
||||
while(!stack.empty()) {
|
||||
basic_block* current = stack.top();
|
||||
bool tail = true;
|
||||
for(basic_block* succ: current->get_successors())
|
||||
if(visited.find(succ) == visited.end()){
|
||||
stack.push(succ);
|
||||
visited.insert(succ);
|
||||
tail = false;
|
||||
break;
|
||||
}
|
||||
if(tail){
|
||||
stack.pop();
|
||||
result.push_back(current);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
|
||||
auto result = post_order(fn);
|
||||
std::reverse(result.begin(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: cfg::reverse_post_order(fn))
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
do_work(i);
|
||||
}
|
||||
|
||||
void for_each_value(module &mod, const std::function<void (value *)> &do_work) {
|
||||
std::set<ir::value*> seen;
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: cfg::reverse_post_order(fn))
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
for(ir::value *op: i->ops()){
|
||||
if(seen.insert(op).second)
|
||||
do_work(op);
|
||||
}
|
||||
if(seen.insert(i).second)
|
||||
do_work(i);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,81 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// value class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value::value(type *ty, const std::string &name): ty_(ty){
|
||||
set_name(name);
|
||||
}
|
||||
|
||||
void value::add_use(user *arg) {
|
||||
users_.insert(arg);
|
||||
}
|
||||
|
||||
value::users_t::iterator value::erase_use(user *arg){
|
||||
auto it = users_.find(arg);
|
||||
if(it == users_.end())
|
||||
return it;
|
||||
return users_.erase(it);
|
||||
}
|
||||
|
||||
// TODO: automatic naming scheme + update symbol table
|
||||
void value::set_name(const std::string &name){
|
||||
name_ = name;
|
||||
}
|
||||
|
||||
void value::replace_all_uses_with(value *target){
|
||||
for (auto it = users_.begin(); it != users_.end(); ) {
|
||||
it = (*it)->replace_uses_of_with(this, target);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void visitor::visit_value(ir::value* v) {
|
||||
v->accept(this);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// user class
|
||||
//===----------------------------------------------------------------------===//
|
||||
void user::set_operand(unsigned i, value *x) {
|
||||
assert(i < ops_.size() && "set_operand() out of range!");
|
||||
ops_[i] = x;
|
||||
x->add_use(this);
|
||||
}
|
||||
|
||||
value* user::get_operand(unsigned i) const {
|
||||
assert(i < ops_.size() && "get_operand() out of range!");
|
||||
return ops_[i];
|
||||
}
|
||||
|
||||
unsigned user::get_num_operands() const {
|
||||
return num_ops_;
|
||||
}
|
||||
|
||||
unsigned user::get_num_hidden() const {
|
||||
return num_hidden_;
|
||||
}
|
||||
|
||||
value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
|
||||
for(size_t i = 0; i < ops_.size(); i++)
|
||||
if(ops_[i] == before){
|
||||
ops_[i] = after;
|
||||
after->add_use(this);
|
||||
}
|
||||
return before->erase_use(this);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -906,12 +906,7 @@ void init_triton_ir(py::module &&m) {
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
.def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
.def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
|
||||
.def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
|
||||
.def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
|
||||
.def("create_barrier", &ir::builder::create_barrier, ret::reference)
|
||||
.def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
|
||||
.def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
|
||||
.def("create_barrier", &ir::builder::create_barrier, ret::reference);
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
|
@@ -819,6 +819,49 @@ class Kernel:
|
||||
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
|
||||
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
|
||||
|
||||
# Compile to ttir, for the propose of testing MLIR rewriting
|
||||
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
|
||||
# TODO: share code with _compile & __call__
|
||||
|
||||
# preparing args
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
# attributes
|
||||
attributes = dict()
|
||||
for i, arg in enumerate(wargs):
|
||||
if i in self.fn.do_not_specialize:
|
||||
continue
|
||||
if isinstance(arg, int):
|
||||
attributes[i] = Kernel.pow2_divisor(arg)
|
||||
elif i in tensor_idxs:
|
||||
addr = arg.data_ptr()
|
||||
range_size = _triton.runtime.get_pointer_range_size(addr)
|
||||
attributes[i] = min(Kernel.pow2_divisor(addr),
|
||||
Kernel.pow2_divisor(range_size))
|
||||
# transforms ints whose value is one into constants for just-in-time compilation
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
|
||||
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
|
||||
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
|
||||
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
|
||||
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type(ret_type, arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
||||
try:
|
||||
generator.visit(self.parse())
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||
raise e
|
||||
raise CompilationError(self.src, node) from e
|
||||
return generator.module
|
||||
|
||||
|
||||
class Launcher:
|
||||
def __init__(self, kernel, grid):
|
||||
|
Reference in New Issue
Block a user