Init commit
This commit is contained in:
12
lib/ir/CMakeLists.txt
Normal file
12
lib/ir/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
add_mlir_dialect_library(TRITONIR
|
||||
Dialect.cpp
|
||||
Ops.cpp
|
||||
Types.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRArithmetic
|
||||
MLIRControlFlow
|
||||
MLIRFunc
|
||||
MLIRTensor
|
||||
)
|
71
lib/ir/Dialect.cpp
Normal file
71
lib/ir/Dialect.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
||||
|
||||
#include "triton/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void TritonDialect::initialize() {
|
||||
registerTypes();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Ops.cpp.inc"
|
||||
>();
|
||||
|
||||
// We can also add interface here.
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Parsing
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pointer-type ::= `!triton.ptr<` element-type ` >`
|
||||
static Type parsePointerType(TritonDialect const &dialect,
|
||||
DialectAsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
|
||||
|
||||
Type pointeeType;
|
||||
if (parser.parseType(pointeeType))
|
||||
return Type();
|
||||
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
|
||||
return PointerType::get(pointeeType);
|
||||
}
|
||||
|
||||
// trtion-type ::= pointer-type
|
||||
Type TritonDialect::parseType(DialectAsmParser &parser) const {
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
|
||||
if (keyword == "ptr")
|
||||
return parsePointerType(*this, parser);
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown Triton type: ") << keyword;
|
||||
return Type();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Printing
|
||||
//===----------------------------------------------------------------------===//
|
||||
static void print(PointerType type, DialectAsmPrinter &os) {
|
||||
os << "ptr<" << type.getPointeeType() << ">";
|
||||
}
|
||||
|
||||
void TritonDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
TypeSwitch<Type>(type)
|
||||
.Case<PointerType>( [&](auto type) { print(type, os); })
|
||||
.Default([](Type) { llvm_unreachable("unhandled Triton type"); });
|
||||
}
|
63
lib/ir/Ops.cpp
Normal file
63
lib/ir/Ops.cpp
Normal file
@@ -0,0 +1,63 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Ops.cpp.inc"
|
||||
|
||||
// enum attribute definitions
|
||||
#include "triton/OpsEnums.cpp.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
//-- StoreOp --
|
||||
// Default mask
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
auto shape = ptrType.getShape();
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(value);
|
||||
state.addOperands(mask);
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
// mask
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true
|
||||
)
|
||||
);
|
||||
// other
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
::mlir::Value other = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(),
|
||||
resultType,
|
||||
DenseElementsAttr::get(
|
||||
resultType, builder.getZeroAttr(elementType)
|
||||
)
|
||||
);
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(mask);
|
||||
state.addOperands(other);
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
55
lib/ir/Types.cpp
Normal file
55
lib/ir/Types.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
// F8 & BF8
|
||||
Float8Type Float8Type::get(MLIRContext *context) {
|
||||
return Base::get(context);
|
||||
}
|
||||
|
||||
BFloat8Type BFloat8Type::get(MLIRContext *context) {
|
||||
return Base::get(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PointerType
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct triton::detail::PointerTypeStorage : public TypeStorage {
|
||||
using KeyTy = std::pair<Type, unsigned>;
|
||||
|
||||
static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
return new (allocator.allocate<PointerTypeStorage>()) PointerTypeStorage(key);
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(pointeeType, addressSpace);
|
||||
}
|
||||
|
||||
PointerTypeStorage(const KeyTy &key)
|
||||
: pointeeType(key.first), addressSpace(key.second) {}
|
||||
|
||||
Type pointeeType;
|
||||
unsigned addressSpace;
|
||||
};
|
||||
|
||||
PointerType PointerType::get(Type pointeeType) {
|
||||
return Base::get(pointeeType.getContext(), pointeeType, 0);
|
||||
}
|
||||
|
||||
PointerType PointerType::get(Type pointeeType, unsigned addressSpace) {
|
||||
return Base::get(pointeeType.getContext(), pointeeType, addressSpace);
|
||||
}
|
||||
|
||||
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
|
||||
|
||||
unsigned PointerType::getAddressSpace() const { return getImpl()->addressSpace; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Triton Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
void TritonDialect::registerTypes() {
|
||||
addTypes<Float8Type, BFloat8Type, PointerType>();
|
||||
}
|
@@ -1,41 +0,0 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
|
||||
class phi_node;
|
||||
|
||||
|
||||
basic_block::basic_block(context &ctx, const std::string &name, function *parent):
|
||||
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
|
||||
if(parent_)
|
||||
parent_->insert_block(this);
|
||||
}
|
||||
|
||||
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){
|
||||
return new basic_block(ctx, name, parent);
|
||||
}
|
||||
|
||||
void basic_block::add_predecessor(basic_block *pred) {
|
||||
preds_.push_back(pred);
|
||||
if(pred)
|
||||
pred->succs_.push_back(this);
|
||||
}
|
||||
|
||||
|
||||
|
||||
basic_block::iterator basic_block::get_first_non_phi(){
|
||||
auto it = begin();
|
||||
for(; it != end(); it++)
|
||||
if(!dynamic_cast<phi_node*>(*it)){
|
||||
return it;
|
||||
}
|
||||
return it;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
@@ -1,434 +0,0 @@
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
builder::builder(context &ctx):
|
||||
ctx_(ctx), block_(nullptr) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
void builder::set_insert_point(basic_block::iterator it){
|
||||
block_ = (*it)->get_parent();
|
||||
insert_point_ = it;
|
||||
}
|
||||
|
||||
void builder::set_insert_point(instruction* i){
|
||||
block_ = i->get_parent();
|
||||
auto it = std::find(block_->begin(), block_->end(), i);
|
||||
set_insert_point(it);
|
||||
}
|
||||
|
||||
|
||||
void builder::set_insert_point_after(instruction* i){
|
||||
block_ = i->get_parent();
|
||||
auto it = std::find(block_->begin(), block_->end(), i);
|
||||
set_insert_point(++it);
|
||||
}
|
||||
|
||||
|
||||
void builder::set_insert_point(basic_block *block){
|
||||
block_ = block;
|
||||
insert_point_ = block->end();
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// convenience functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::get_int1(bool val)
|
||||
{ return constant_int::get(type::get_int1_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_int32(uint32_t val)
|
||||
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_int64(uint64_t val)
|
||||
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_float16(float val)
|
||||
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_float32(float val)
|
||||
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_range(int32_t _lo, int32_t _hi) {
|
||||
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
|
||||
constant_int* hi = static_cast<constant_int*>(get_int32(_hi));
|
||||
return insert(make_range::create(lo, hi));
|
||||
}
|
||||
|
||||
type *builder::get_void_ty()
|
||||
{ return type::get_void_ty(ctx_); }
|
||||
|
||||
type *builder::get_int1_ty()
|
||||
{ return type::get_int1_ty(ctx_); }
|
||||
|
||||
type *builder::get_int8_ty()
|
||||
{ return type::get_int8_ty(ctx_); }
|
||||
|
||||
type *builder::get_int16_ty()
|
||||
{ return type::get_int16_ty(ctx_); }
|
||||
|
||||
type *builder::get_int32_ty()
|
||||
{ return type::get_int32_ty(ctx_); }
|
||||
|
||||
type *builder::get_int64_ty()
|
||||
{ return type::get_int64_ty(ctx_); }
|
||||
|
||||
type *builder::get_fp8_ty()
|
||||
{ return type::get_fp8_ty(ctx_); }
|
||||
|
||||
type *builder::get_half_ty()
|
||||
{ return type::get_fp16_ty(ctx_); }
|
||||
|
||||
type *builder::get_bf16_ty()
|
||||
{ return type::get_bf16_ty(ctx_); }
|
||||
|
||||
type *builder::get_float_ty()
|
||||
{ return type::get_fp32_ty(ctx_); }
|
||||
|
||||
type *builder::get_double_ty()
|
||||
{ return type::get_fp64_ty(ctx_); }
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value* builder::create_br(basic_block *dest){
|
||||
dest->add_predecessor(block_);
|
||||
return insert(branch_inst::create(dest));
|
||||
}
|
||||
|
||||
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
|
||||
if_dest->add_predecessor(block_);
|
||||
else_dest->add_predecessor(block_);
|
||||
return insert(branch_inst::create(cond, if_dest, else_dest));
|
||||
}
|
||||
|
||||
value *builder::create_ret_void() {
|
||||
return insert(return_inst::create(ctx_));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
#define DEFINE_CAST_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *src, type *dst_ty){\
|
||||
return create_cast(OPCODE, src, dst_ty);\
|
||||
}
|
||||
|
||||
DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
|
||||
DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
|
||||
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
|
||||
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
|
||||
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
|
||||
DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI)
|
||||
DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI)
|
||||
DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt)
|
||||
DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc)
|
||||
|
||||
value* builder::create_cast(cast_op_t op, value *v, type *dst_ty){
|
||||
return insert(cast_inst::create(op, v, dst_ty));
|
||||
}
|
||||
|
||||
value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed){
|
||||
return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
phi_node* builder::create_phi(type *ty, unsigned num_reserved){
|
||||
return insert(phi_node::create(ty, num_reserved));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary float instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define DEFINE_BINARY_FLOAT(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\
|
||||
return insert(binary_operator::create(OPCODE, lhs, rhs));\
|
||||
}
|
||||
|
||||
// Binary
|
||||
DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
|
||||
DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
|
||||
DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
|
||||
DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
|
||||
DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary int instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
|
||||
value *rhs,
|
||||
bool has_nuw, bool has_nsw) {
|
||||
binary_operator* result = insert(binary_operator::create(op, lhs, rhs));
|
||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||
if (has_nsw) result->set_has_no_signed_wrap();
|
||||
return result;
|
||||
}
|
||||
|
||||
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
||||
value* builder::create_ ## SUFFIX(value *lhs, value *rhs, bool has_nuw, bool has_nsw){\
|
||||
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, has_nuw, has_nsw);\
|
||||
}\
|
||||
|
||||
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\
|
||||
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, false, false);\
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Binary
|
||||
DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
|
||||
DEFINE_NOWRAP_BINARY(add, binary_op_t::Add)
|
||||
DEFINE_NOWRAP_BINARY(sub, binary_op_t::Sub)
|
||||
DEFINE_NOWRAP_BINARY(shl, binary_op_t::Shl)
|
||||
DEFINE_NOWRAP_BINARY(ashr, binary_op_t::AShr)
|
||||
DEFINE_NOWRAP_BINARY(lshr, binary_op_t::LShr)
|
||||
DEFINE_BINARY_INT(sdiv, binary_op_t::SDiv)
|
||||
DEFINE_BINARY_INT(udiv, binary_op_t::UDiv)
|
||||
DEFINE_BINARY_INT(srem, binary_op_t::SRem)
|
||||
DEFINE_BINARY_INT(urem, binary_op_t::URem)
|
||||
DEFINE_BINARY_INT(and, binary_op_t::And)
|
||||
DEFINE_BINARY_INT(or, binary_op_t::Or)
|
||||
DEFINE_BINARY_INT(xor, binary_op_t::Xor)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list){
|
||||
return insert(getelementptr_inst::create(ptr, idx_list));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// icmp instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs){
|
||||
return insert(icmp_inst::create(pred, lhs, rhs));
|
||||
}
|
||||
|
||||
#define DEFINE_ICMP_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs){\
|
||||
return create_icmp(OPCODE, lhs, rhs);\
|
||||
}
|
||||
|
||||
// Signed
|
||||
DEFINE_ICMP_INSTR(SLE, cmp_pred_t::ICMP_SLE)
|
||||
DEFINE_ICMP_INSTR(SLT, cmp_pred_t::ICMP_SLT)
|
||||
DEFINE_ICMP_INSTR(SGE, cmp_pred_t::ICMP_SGE)
|
||||
DEFINE_ICMP_INSTR(SGT, cmp_pred_t::ICMP_SGT)
|
||||
// Unsigned
|
||||
DEFINE_ICMP_INSTR(ULE, cmp_pred_t::ICMP_ULE)
|
||||
DEFINE_ICMP_INSTR(ULT, cmp_pred_t::ICMP_ULT)
|
||||
DEFINE_ICMP_INSTR(UGE, cmp_pred_t::ICMP_UGE)
|
||||
DEFINE_ICMP_INSTR(UGT, cmp_pred_t::ICMP_UGT)
|
||||
// General
|
||||
DEFINE_ICMP_INSTR(EQ, cmp_pred_t::ICMP_EQ)
|
||||
DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// fcmp instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs){
|
||||
return insert(fcmp_inst::create(pred, lhs, rhs));
|
||||
}
|
||||
|
||||
#define DEFINE_FCMP_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs){\
|
||||
return create_fcmp(OPCODE, lhs, rhs);\
|
||||
}
|
||||
|
||||
// Ordered
|
||||
DEFINE_FCMP_INSTR(OLE, cmp_pred_t::FCMP_OLE)
|
||||
DEFINE_FCMP_INSTR(OLT, cmp_pred_t::FCMP_OLT)
|
||||
DEFINE_FCMP_INSTR(OGE, cmp_pred_t::FCMP_OGE)
|
||||
DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT)
|
||||
DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ)
|
||||
DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
|
||||
|
||||
DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE)
|
||||
DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT)
|
||||
DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE)
|
||||
DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT)
|
||||
DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ)
|
||||
DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load/store instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val){
|
||||
return insert(unmasked_store_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
return insert(masked_store_inst::create(ptr, val, mask));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// block instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(reshape_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
value *builder::create_cat(value *lhs, value *rhs) {
|
||||
return insert(cat_inst::create(lhs, rhs));
|
||||
}
|
||||
|
||||
value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(splat_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
value *builder::create_broadcast(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(broadcast_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
value *builder::create_downcast(value *arg) {
|
||||
return insert(downcast_inst::create(arg));
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||
}
|
||||
|
||||
#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
|
||||
return create_atomic_rmw(OPCODE, ptr, val, mask);\
|
||||
}
|
||||
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// built-in instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_get_program_id(unsigned axis) {
|
||||
return insert(get_program_id_inst::create(ctx_, axis));
|
||||
}
|
||||
|
||||
value *builder::create_get_num_programs(unsigned axis) {
|
||||
return insert(get_num_programs_inst::create(ctx_, axis));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
|
||||
return insert(atomic_cas_inst::create(ptr, cmp, val));
|
||||
}
|
||||
|
||||
|
||||
value *builder::create_exp(value *arg){
|
||||
return insert(exp_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_cos(value *arg){
|
||||
return insert(cos_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_sin(value *arg){
|
||||
return insert(sin_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_log(value *arg){
|
||||
return insert(log_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
|
||||
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
|
||||
}
|
||||
|
||||
value *builder::create_trans(value *A, const std::vector<int>& perm) {
|
||||
return insert(trans_inst::create(A, perm));
|
||||
}
|
||||
|
||||
value *builder::create_sqrt(value *A) {
|
||||
return insert(sqrt_inst::create(A));
|
||||
}
|
||||
|
||||
value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis) {
|
||||
return insert(reduce_inst::create(A, op, axis));
|
||||
}
|
||||
|
||||
value *builder::create_select(value *pred, value *if_value, value *else_value){
|
||||
return insert(select_inst::create(pred, if_value, else_value));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_umulhi(value *lhs, value *rhs) {
|
||||
return insert(umulhi_inst::create(lhs, rhs));
|
||||
}
|
||||
|
||||
value *builder::create_copy_to_shared(value *arg) {
|
||||
return insert(copy_to_shared_inst::create(arg));
|
||||
}
|
||||
|
||||
|
||||
value *builder::create_copy_from_shared(value *arg) {
|
||||
return insert(copy_from_shared_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
|
||||
}
|
||||
|
||||
value *builder::create_barrier(const std::string &name) {
|
||||
return insert(barrier_inst::create(ctx_));
|
||||
}
|
||||
|
||||
value *builder::create_async_wait(int N) {
|
||||
return insert(async_wait_inst::create(ctx_, N));
|
||||
}
|
||||
|
||||
value *builder::create_prefetch_s(value *arg, int inc) {
|
||||
return insert(prefetch_s_inst::create(ctx_, arg, inc));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,118 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
// constant
|
||||
|
||||
constant *constant::get_null_value(type *ty) {
|
||||
context &ctx = ty->get_context();
|
||||
switch (ty->get_scalar_ty()->get_type_id()) {
|
||||
case type::IntegerTyID:
|
||||
return constant_int::get(ty, 0);
|
||||
case type::FP16TyID:
|
||||
return constant_fp::get(type::get_fp16_ty(ctx), 0);
|
||||
case type::FP32TyID:
|
||||
return constant_fp::get(type::get_fp32_ty(ctx), 0);
|
||||
case type::FP64TyID:
|
||||
return constant_fp::get(type::get_fp64_ty(ctx), 0);
|
||||
default:
|
||||
throw std::runtime_error("Cannot create a null constant of that type!");
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME
|
||||
|
||||
constant *constant::get_all_ones_value(type *ty) {
|
||||
if(ty->is_integer_ty())
|
||||
return constant_int::get(ty, 0xFFFFFFFFFFFFFFFF);
|
||||
if(ty->is_floating_point_ty())
|
||||
return constant_fp::get(ty, 0xFFFFFFFFFFFFFFFF);
|
||||
throw std::runtime_error("Cannot create all ones value for that type!");
|
||||
}
|
||||
|
||||
// constant_int
|
||||
// FIXME use something like APInt
|
||||
|
||||
constant_int::constant_int(type *ty, uint64_t value)
|
||||
: constant(ty, 0), value_(value){ }
|
||||
|
||||
constant_int *constant_int::get(type *ty, uint64_t value) {
|
||||
if (!ty->is_integer_ty())
|
||||
throw std::runtime_error("Cannot create constant_int with non integer ty");
|
||||
context_impl *impl = ty->get_context().p_impl.get();
|
||||
std::unique_ptr<constant_int> &cst = impl->int_constants_[std::make_pair(ty, value)];
|
||||
if(!cst)
|
||||
cst.reset(new constant_int(ty, value));
|
||||
return cst.get();
|
||||
}
|
||||
|
||||
|
||||
// constant_fp
|
||||
// FIXME use something like APFloat
|
||||
|
||||
constant_fp::constant_fp(type *ty, double value)
|
||||
: constant(ty, 0), value_(value){ }
|
||||
|
||||
constant *constant_fp::get_negative_zero(type *ty){
|
||||
double neg_zero = 0;
|
||||
return get(ty, neg_zero);
|
||||
}
|
||||
|
||||
constant *constant_fp::get_zero_value_for_negation(type *ty) {
|
||||
if(ty->get_scalar_ty()->is_floating_point_ty())
|
||||
return constant_fp::get(ty, 0);
|
||||
return constant::get_null_value(ty);
|
||||
}
|
||||
|
||||
constant *constant_fp::get(type *ty, double v){
|
||||
context_impl *impl = ty->get_context().p_impl.get();
|
||||
std::unique_ptr<constant_fp> &result = impl->fp_constants_[std::make_pair(ty, v)];
|
||||
if(!result)
|
||||
result.reset(new constant_fp(ty, v));
|
||||
return result.get();
|
||||
}
|
||||
|
||||
|
||||
// undef value
|
||||
undef_value::undef_value(type *ty)
|
||||
: constant(ty, 0) { }
|
||||
|
||||
undef_value *undef_value::get(type *ty) {
|
||||
context_impl *impl = ty->get_context().p_impl.get();
|
||||
std::unique_ptr<undef_value> &result = impl->uv_constants_[ty];
|
||||
if(!result)
|
||||
result.reset(new undef_value(ty));
|
||||
return result.get();
|
||||
}
|
||||
|
||||
/* global value */
|
||||
global_value::global_value(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage,
|
||||
const std::string &name, unsigned addr_space)
|
||||
: constant(pointer_type::get(ty, addr_space), num_ops, name),
|
||||
linkage_(linkage) { }
|
||||
|
||||
|
||||
/* global object */
|
||||
global_object::global_object(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage,
|
||||
const std::string &name, unsigned addr_space)
|
||||
: global_value(ty, num_ops, linkage, name, addr_space) { }
|
||||
|
||||
|
||||
/* alloc const */
|
||||
alloc_const::alloc_const(type *ty, constant_int *size, const std::string &name)
|
||||
: global_object(ty, 1, global_value::external, name, 4) {
|
||||
set_operand(0, size);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,40 +0,0 @@
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// context implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
context_impl::context_impl(context &ctx)
|
||||
: void_ty(ctx, type::VoidTyID),
|
||||
label_ty(ctx, type::LabelTyID),
|
||||
// floating point
|
||||
fp8_ty(ctx, type::FP8TyID),
|
||||
fp16_ty(ctx, type::FP16TyID),
|
||||
bf16_ty(ctx, type::BF16TyID),
|
||||
fp32_ty(ctx, type::FP32TyID),
|
||||
fp64_ty(ctx, type::FP64TyID),
|
||||
// integers
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
int16_ty(ctx, 16),
|
||||
int32_ty(ctx, 32),
|
||||
int64_ty(ctx, 64),
|
||||
int128_ty(ctx, 128) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// context
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
context::context():
|
||||
p_impl(std::make_shared<context_impl>(*this)) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,66 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/* Argument */
|
||||
|
||||
argument::argument(type *ty, const std::string &name, function *parent, unsigned arg_no)
|
||||
: value(ty, name), parent_(parent), arg_no_(arg_no) { }
|
||||
|
||||
argument *argument::create(type *ty, const std::string &name,
|
||||
function *parent, unsigned arg_no) {
|
||||
return new argument(ty, name, parent, arg_no);
|
||||
}
|
||||
|
||||
function* argument::get_parent() const {
|
||||
return parent_;
|
||||
}
|
||||
|
||||
unsigned argument::get_arg_no() const {
|
||||
return arg_no_;
|
||||
}
|
||||
|
||||
void argument::accept(visitor *v) {
|
||||
v->visit_argument(this);
|
||||
}
|
||||
|
||||
|
||||
/* function */
|
||||
function::function(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *parent)
|
||||
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) {
|
||||
unsigned num_params = fn_ty_->get_num_params();
|
||||
// skip if no parameter
|
||||
if(num_params == 0)
|
||||
return;
|
||||
// create arguments
|
||||
args_.resize(num_params);
|
||||
for(unsigned i = 0; i < num_params; i++){
|
||||
type *param_ty = fn_ty_->get_param_ty(i);
|
||||
args_[i] = argument::create(param_ty, "", this, i);
|
||||
}
|
||||
if(parent)
|
||||
parent->push_function(this);
|
||||
}
|
||||
|
||||
/* basic block */
|
||||
void function::insert_block(basic_block *block, basic_block *next) {
|
||||
auto it = std::find(blocks_.begin(), blocks_.end(), next);
|
||||
blocks_.insert(it, block);
|
||||
}
|
||||
|
||||
|
||||
function *function::create(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *mod) {
|
||||
return new function(ty, linkage, name, mod);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -1,928 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// instruction classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
instruction::instruction(type *ty, value_id_t ity, unsigned num_ops,
|
||||
const std::string &name, instruction *next)
|
||||
: user(ty, num_ops, name), id_(ity) {
|
||||
if(next){
|
||||
basic_block *block = next->get_parent();
|
||||
assert(block && "Next instruction is not in a basic block!");
|
||||
auto it = std::find(block->begin(), block->end(), next);
|
||||
block->get_inst_list().insert(it, next);
|
||||
}
|
||||
}
|
||||
|
||||
void instruction::erase_from_parent() {
|
||||
parent_->erase(this);
|
||||
for(ir::value* op: ops())
|
||||
op->erase_use(this);
|
||||
}
|
||||
|
||||
bool instruction::has_tile_result_or_op() {
|
||||
bool result = get_type()->is_block_ty();
|
||||
for(unsigned i = 0; i < get_num_operands(); i++)
|
||||
result |= get_operand(i)->get_type()->is_block_ty();
|
||||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
|
||||
: instruction(ty, INST_PHI, 0, name, next) {
|
||||
blocks_.reserve(num_reserved);
|
||||
}
|
||||
|
||||
value* phi_node::get_value_for_block(basic_block * block) {
|
||||
auto it = std::find(blocks_.begin(), blocks_.end(), block);
|
||||
size_t n = std::distance(blocks_.begin(), it);
|
||||
return get_incoming_value(n);
|
||||
}
|
||||
|
||||
// Set incoming value
|
||||
void phi_node::set_incoming_value(unsigned i, value *v){
|
||||
assert(v && "PHI node got a null value!");
|
||||
assert(get_type() == v->get_type() &&
|
||||
"All operands to PHI node must be the same type as the PHI node!");
|
||||
set_operand(i, v);
|
||||
}
|
||||
|
||||
// Set incoming block
|
||||
void phi_node::set_incoming_block(unsigned i, basic_block *block){
|
||||
assert(block && "PHI node got a null basic block!");
|
||||
blocks_[i] = block;
|
||||
}
|
||||
|
||||
// Add incoming
|
||||
void phi_node::add_incoming(value *v, basic_block *block){
|
||||
resize_ops(get_num_operands() + 1);
|
||||
blocks_.resize(get_num_operands() + 1);
|
||||
set_incoming_value(get_num_operands() - 1, v);
|
||||
set_incoming_block(get_num_operands() - 1, block);
|
||||
}
|
||||
|
||||
// Factory methods
|
||||
phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){
|
||||
return new phi_node(ty, num_reserved, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string binary_operator::repr_impl() const {
|
||||
switch(op_) {
|
||||
case Add : return "add";
|
||||
case FAdd : return "fadd";
|
||||
case Sub : return "sub";
|
||||
case FSub : return "fsub";
|
||||
case Mul : return "mul";
|
||||
case FMul : return "fmul";
|
||||
case UDiv : return "udiv";
|
||||
case SDiv : return "sdiv";
|
||||
case FDiv : return "fdiv";
|
||||
case URem : return "urem";
|
||||
case SRem : return "srem";
|
||||
case FRem : return "frem";
|
||||
case Shl : return "shl";
|
||||
case LShr : return "lshr";
|
||||
case AShr : return "ashr";
|
||||
case And : return "and";
|
||||
case Or : return "or";
|
||||
case Xor : return "xor";
|
||||
default: throw std::runtime_error("unknown binary operator");
|
||||
}
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_div() const {
|
||||
return op_ == binary_op_t::UDiv || op_ == binary_op_t::SDiv;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_rem() const {
|
||||
return op_ == binary_op_t::URem || op_ == binary_op_t::SRem;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shl() const {
|
||||
return op_ == binary_op_t::Shl;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shr() const {
|
||||
return op_ == binary_op_t::LShr || op_ == binary_op_t::AShr;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_mult() const {
|
||||
return op_ == binary_op_t::Mul;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_add_sub() const {
|
||||
return op_ == binary_op_t::Add || op_ == binary_op_t::Sub;
|
||||
}
|
||||
|
||||
|
||||
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(lhs->get_type() == rhs->get_type() &&
|
||||
"Cannot create binary operator with two operands of differing type!");
|
||||
return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next);
|
||||
}
|
||||
|
||||
//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
|
||||
// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
|
||||
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
|
||||
// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
|
||||
//}
|
||||
|
||||
//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
|
||||
// assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
|
||||
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty());
|
||||
// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
|
||||
//}
|
||||
|
||||
//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
|
||||
// assert(arg->get_type()->is_integer_ty());
|
||||
// constant *mask = constant::get_all_ones_value(arg->get_type());
|
||||
// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
|
||||
//}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cmp_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
|
||||
// cmp_inst
|
||||
std::string cmp_inst::repr_impl() const {
|
||||
switch (pred_) {
|
||||
case FCMP_FALSE : return "false";
|
||||
case FCMP_OEQ : return "fcmp_oeq";
|
||||
case FCMP_OGT : return "fcmp_ogt";
|
||||
case FCMP_OGE : return "fcmp_oge";
|
||||
case FCMP_OLT : return "fcmp_olt";
|
||||
case FCMP_OLE : return "fcmp_ole";
|
||||
case FCMP_ONE : return "fcmp_one";
|
||||
case FCMP_ORD : return "fcmp_ord";
|
||||
case FCMP_UNO : return "fcmp_uno";
|
||||
case FCMP_UEQ : return "fcmp_ueq";
|
||||
case FCMP_UGT : return "fcmp_ugt";
|
||||
case FCMP_UGE : return "fcmp_uge";
|
||||
case FCMP_ULT : return "fcmp_ult";
|
||||
case FCMP_ULE : return "fcmp_ule";
|
||||
case FCMP_UNE : return "fcmp_une";
|
||||
case FCMP_TRUE : return "true";
|
||||
case ICMP_EQ : return "icmp_eq";
|
||||
case ICMP_NE : return "icmp_ne";
|
||||
case ICMP_UGT : return "icmp_ugt";
|
||||
case ICMP_UGE : return "icmp_uge";
|
||||
case ICMP_ULT : return "icmp_ult";
|
||||
case ICMP_ULE : return "icmp_ule";
|
||||
case ICMP_SGT : return "icmp_sgt";
|
||||
case ICMP_SGE : return "icmp_sge";
|
||||
case ICMP_SLT : return "icmp_slt";
|
||||
case ICMP_SLE : return "icmp_sle";
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, 2, name, next), pred_(pred) {
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
type* cmp_inst::make_cmp_result_type(type *ty){
|
||||
type* int1_ty = type::get_int1_ty(ty->get_context());
|
||||
if (block_type* tile_ty = dynamic_cast<block_type*>(ty))
|
||||
return block_type::get_same_shapes(int1_ty, tile_ty);
|
||||
return int1_ty;
|
||||
}
|
||||
|
||||
|
||||
bool cmp_inst::is_fp_predicate(cmp_pred_t pred) {
|
||||
return pred >= FIRST_FCMP_PREDICATE && pred <= LAST_FCMP_PREDICATE;
|
||||
}
|
||||
|
||||
bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
|
||||
return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE;
|
||||
}
|
||||
|
||||
|
||||
// icmp_inst
|
||||
icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next){ }
|
||||
|
||||
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_int_predicate(pred));
|
||||
assert(lhs->get_type() == rhs->get_type());
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
// fcmp_inst
|
||||
fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next){ }
|
||||
|
||||
fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_fp_predicate(pred));
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
return new fcmp_inst(res_ty, pred, lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// unary_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, 1, name, next) {
|
||||
set_operand(0, v);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string cast_inst::repr_impl() const {
|
||||
switch (op_){
|
||||
case cast_op_t::Trunc: return "trunc";
|
||||
case cast_op_t::ZExt: return "zext";
|
||||
case cast_op_t::SExt: return "sext";
|
||||
case cast_op_t::FPTrunc: return "fp_trunc";
|
||||
case cast_op_t::FPExt: return "fp_ext";
|
||||
case cast_op_t::UIToFP: return "ui_to_fp";
|
||||
case cast_op_t::SIToFP: return "si_to_fp";
|
||||
case cast_op_t::FPToUI: return "fp_to_ui";
|
||||
case cast_op_t::FPToSI: return "fp_to_si";
|
||||
case cast_op_t::PtrToInt: return "ptr_to_int";
|
||||
case cast_op_t::IntToPtr: return "int_to_ptr";
|
||||
case cast_op_t::BitCast: return "bitcast";
|
||||
case cast_op_t::AddrSpaceCast: return "addr_space_cast";
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
// TODO
|
||||
bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) {
|
||||
assert(arg->get_type()->is_block_ty() == ty->is_block_ty());
|
||||
return true;
|
||||
}
|
||||
|
||||
cast_inst *cast_inst::create(cast_op_t op, value *arg, type *ty, const std::string &name, instruction *next){
|
||||
assert(is_valid(op, arg, ty) && "Invalid cast!");
|
||||
// Construct and return the appropriate CastInst subclass
|
||||
switch (op) {
|
||||
case cast_op_t::Trunc: return new trunc_inst (ty, arg, name, next);
|
||||
case cast_op_t::ZExt: return new z_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::SExt: return new s_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPExt: return new fp_ext_inst (ty, arg, name, next);
|
||||
case cast_op_t::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
|
||||
case cast_op_t::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
|
||||
case cast_op_t::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
|
||||
case cast_op_t::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
|
||||
case cast_op_t::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
|
||||
case cast_op_t::BitCast: return new bit_cast_inst (ty, arg, name, next);
|
||||
case cast_op_t::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
|
||||
type *arg_ty = arg->get_type();
|
||||
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
|
||||
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
|
||||
cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
|
||||
(arg_bits > dst_bits ? cast_op_t::Trunc :
|
||||
(is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
|
||||
return create(op, arg, ty, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// return_inst
|
||||
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
|
||||
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
|
||||
if(ret_val)
|
||||
set_operand(0, ret_val);
|
||||
}
|
||||
|
||||
return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next){
|
||||
return new return_inst(ctx, ret_val, next);
|
||||
}
|
||||
|
||||
|
||||
// branch_inst
|
||||
branch_inst* branch_inst::create(basic_block *dst, instruction *next) {
|
||||
assert(dst && "Branch destination may not be null!");
|
||||
return new uncond_branch_inst(dst, next);
|
||||
}
|
||||
|
||||
branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *else_dst, instruction *next) {
|
||||
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
|
||||
return new cond_branch_inst(if_dst, else_dst, cond, next);
|
||||
}
|
||||
|
||||
// uncond_branch_inst
|
||||
uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next)
|
||||
: branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next){
|
||||
set_operand(0, dst);
|
||||
}
|
||||
|
||||
// cond_branch_inst
|
||||
cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
|
||||
: branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next){
|
||||
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
|
||||
set_operand(0, if_dst);
|
||||
set_operand(1, else_dst);
|
||||
set_operand(2, cond);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
|
||||
: instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next),
|
||||
source_elt_ty(pointee_ty),
|
||||
res_elt_ty(get_indexed_type(pointee_ty, idx)){
|
||||
// sanity check
|
||||
type *expected_ty = get_type()->get_scalar_ty();
|
||||
expected_ty = ((pointer_type*)expected_ty)->get_element_ty();
|
||||
assert(res_elt_ty == expected_ty);
|
||||
// set operands
|
||||
set_operand(0, ptr);
|
||||
for(size_t i = 0; i < idx.size(); i++)
|
||||
set_operand(1 + i, idx[i]);
|
||||
}
|
||||
|
||||
type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vector<value *> &idx_list) {
|
||||
// result pointer type
|
||||
type *ty = x->get_type();
|
||||
unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space();
|
||||
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space);
|
||||
// Tile GEP
|
||||
if(ty->is_block_ty())
|
||||
return block_type::get_same_shapes(ptr_ty, ty);
|
||||
for(value *idx : idx_list)
|
||||
if (idx->get_type()->is_block_ty())
|
||||
return block_type::get_same_shapes(ptr_ty, ty);
|
||||
// Scalar GEP
|
||||
return ptr_ty;
|
||||
}
|
||||
|
||||
type *getelementptr_inst::get_indexed_type_impl(type *ty, const std::vector<value *> &idx_list) {
|
||||
if(idx_list.empty())
|
||||
return ty;
|
||||
if(!ty->is_sized())
|
||||
return nullptr;
|
||||
unsigned cur_idx = 1;
|
||||
for(; cur_idx != idx_list.size(); cur_idx++){
|
||||
composite_type *cty = dynamic_cast<composite_type*>(ty);
|
||||
if(!cty || cty->is_pointer_ty())
|
||||
break;
|
||||
value *idx = idx_list[cur_idx];
|
||||
if(!cty->index_valid(idx))
|
||||
break;
|
||||
ty = cty->get_type_at_index(idx);
|
||||
}
|
||||
return (cur_idx == idx_list.size())? ty : nullptr;
|
||||
}
|
||||
|
||||
type *getelementptr_inst::get_indexed_type(type *ty, const std::vector<value *> &idx_list) {
|
||||
type *result = get_indexed_type_impl(ty, idx_list);
|
||||
assert(result && "invalid GEP type!");
|
||||
return result;
|
||||
}
|
||||
|
||||
getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next) {
|
||||
type *pointee_ty = ((pointer_type*)(ptr->get_type()->get_scalar_ty()))->get_element_ty();
|
||||
return new getelementptr_inst(pointee_ty, ptr, idx, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load_inst/store_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// io_inst
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next)
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||
{ }
|
||||
|
||||
// load
|
||||
type *load_inst::get_pointee_type(type *ty) {
|
||||
type *scalar_ty = ty->get_scalar_ty();
|
||||
type *pointee_ty = scalar_ty->get_pointer_element_ty();
|
||||
if(ty->is_block_ty())
|
||||
return block_type::get_same_shapes(pointee_ty, ty);
|
||||
return pointee_ty;
|
||||
}
|
||||
|
||||
// unmasked_load
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
}
|
||||
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load async
|
||||
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
|
||||
}
|
||||
|
||||
// store
|
||||
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
|
||||
{ }
|
||||
|
||||
// unmasked_store
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
|
||||
const std::string &name, instruction *next) {
|
||||
return new unmasked_store_inst(ptr, val, name, next);
|
||||
}
|
||||
|
||||
// masked store
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, mask);
|
||||
}
|
||||
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, name, next);
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
|
||||
: instruction(block_type::get(x->get_type()->get_scalar_ty(),
|
||||
{x->get_type()->get_block_shapes()[0] +
|
||||
y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) {
|
||||
set_operand(0, x);
|
||||
set_operand(1, y);
|
||||
}
|
||||
|
||||
instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
|
||||
return new cat_inst(lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
// retile
|
||||
|
||||
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next)
|
||||
: unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
|
||||
|
||||
|
||||
|
||||
// reshape
|
||||
|
||||
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new reshape_inst(arg, INST_RESHAPE, shapes, name, next);
|
||||
}
|
||||
|
||||
|
||||
// splat
|
||||
|
||||
instruction* splat_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new splat_inst(arg, INST_SPLAT, shapes, name, next);
|
||||
}
|
||||
|
||||
// broadcast
|
||||
|
||||
instruction* broadcast_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next);
|
||||
}
|
||||
|
||||
// downcast
|
||||
|
||||
instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// matmul_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||
set_operand(0, A);
|
||||
set_operand(1, B);
|
||||
set_operand(2, C);
|
||||
allow_tf32_ = allow_tf32;
|
||||
}
|
||||
|
||||
instruction *dot_inst::create(value *A, value *B, value *C,
|
||||
bool AT, bool BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
TransT OPA = AT ? Trans : NoTrans;
|
||||
TransT OPB = BT ? Trans : NoTrans;
|
||||
return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// trans instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<int> perm) {
|
||||
// get argument shapes
|
||||
ir::block_type::block_shapes_t arg_shapes = ty->get_block_shapes();
|
||||
// permutate argument shapes
|
||||
perm = init_perm(ty, perm);
|
||||
ir::block_type::block_shapes_t res_shapes = arg_shapes;
|
||||
for(size_t i = 0; i < perm.size(); i++)
|
||||
res_shapes[i] = arg_shapes[perm[i]];
|
||||
// construct type
|
||||
return block_type::get(ty->get_scalar_ty(), res_shapes);
|
||||
}
|
||||
|
||||
std::vector<int> trans_inst::init_perm(ir::type* ty, const std::vector<int>& perm) {
|
||||
if(!perm.empty())
|
||||
return perm;
|
||||
auto size = ty->get_block_shapes().size();
|
||||
std::vector<int> result;
|
||||
result.push_back(size - 1);
|
||||
for(size_t i = 0; i < size - 1; i++)
|
||||
result.push_back(i);
|
||||
return result;
|
||||
}
|
||||
|
||||
trans_inst::trans_inst(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
|
||||
// sanity check
|
||||
perm_ = init_perm(arg->get_type(), perm);
|
||||
//auto size = arg->get_type()->get_tile_shapes().size();
|
||||
//assert(perm_.size() == size);
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* trans_inst::create(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next) {
|
||||
return new trans_inst(arg, perm, name, next);
|
||||
}
|
||||
|
||||
const std::vector<int> trans_inst::get_perm() const {
|
||||
return perm_;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// sqrt instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
|
||||
: builtin_inst(arg->get_type(), INST_SQRT, 1, name, next){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new sqrt_inst(arg, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// reduce instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string reduce_inst::to_str(op_t op) {
|
||||
switch (op) {
|
||||
case ADD: return "+";
|
||||
case SUB: return "-";
|
||||
case MAX: return "imax";
|
||||
case MIN: return "imin";
|
||||
case FADD: return "+";
|
||||
case FSUB: return "-";
|
||||
case FMAX: return "fmax";
|
||||
case FMIN: return "fmin";
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
}
|
||||
|
||||
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
|
||||
ir::block_type::block_shapes_t shapes = arg->get_type()->get_block_shapes();
|
||||
shapes.erase(shapes.begin() + axis);
|
||||
type *scalar_ty = arg->get_type()->get_scalar_ty();
|
||||
if(shapes.empty())
|
||||
// shapes.push_back(1);
|
||||
return scalar_ty;
|
||||
return block_type::get(scalar_ty, shapes);
|
||||
}
|
||||
|
||||
reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next),
|
||||
op_(op),
|
||||
axis_(axis){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new reduce_inst(arg, op, axis, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// select instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
|
||||
: builtin_inst(if_value->get_type(), INST_SELECT, 3, name, next){
|
||||
set_operand(0, pred);
|
||||
set_operand(1, if_value);
|
||||
set_operand(2, else_value);
|
||||
}
|
||||
|
||||
instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) {
|
||||
return new select_inst(pred, if_value, else_value, name, next);
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// builtin instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// get_program_id
|
||||
get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// get_num_program
|
||||
get_num_programs_inst::get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// atomic_rmw
|
||||
|
||||
atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, msk);
|
||||
}
|
||||
|
||||
instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_rmw_inst(op, ptr, val, msk, name, next);
|
||||
}
|
||||
|
||||
|
||||
// atomic cas
|
||||
|
||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, cmp);
|
||||
set_operand(2, val);
|
||||
}
|
||||
|
||||
instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
|
||||
return new atomic_cas_inst(ptr, cmp, val, name, next);
|
||||
}
|
||||
|
||||
|
||||
// umulhi
|
||||
|
||||
umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) {
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
|
||||
return new umulhi_inst(lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
|
||||
// exp
|
||||
|
||||
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_EXP, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* exp_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new exp_inst(val, name, next);
|
||||
}
|
||||
|
||||
// cos
|
||||
cos_inst::cos_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_COS, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* cos_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new cos_inst(val, name, next);
|
||||
}
|
||||
|
||||
// sin
|
||||
sin_inst::sin_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_SIN, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* sin_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new sin_inst(val, name, next);
|
||||
}
|
||||
|
||||
|
||||
// log
|
||||
|
||||
log_inst::log_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_LOG, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* log_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new log_inst(val, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cvt_scanline
|
||||
cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new cvt_layout_inst(arg->get_type(), INST_CVT_LAYOUT, arg, name, next);
|
||||
}
|
||||
|
||||
// copy to shared
|
||||
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next);
|
||||
}
|
||||
|
||||
// copy from shared
|
||||
copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
|
||||
}
|
||||
|
||||
// barrier
|
||||
barrier_inst::barrier_inst(context &ctx, const std::string &name,
|
||||
instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
|
||||
|
||||
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
|
||||
return new barrier_inst(ctx, name, next);
|
||||
}
|
||||
|
||||
async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
|
||||
|
||||
async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
|
||||
return new async_wait_inst(ctx, N, name, next);
|
||||
}
|
||||
|
||||
// prefetch_s
|
||||
prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) {
|
||||
return new prefetch_s_inst(ctx, arg, inc, name, next);
|
||||
}
|
||||
|
||||
//// nv_dynamic_program_idx
|
||||
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
|
||||
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
// return new make_range_dyn(ty, name, next);
|
||||
//}
|
||||
|
||||
//// nv_static_program_idx
|
||||
//make_range_sta::make_range_sta(make_range *range)
|
||||
// : constant(range->get_type(), 0), range_(range) { }
|
||||
|
||||
//make_range* make_range_sta::get_range() const
|
||||
//{ return range_; }
|
||||
|
||||
//make_range_sta* make_range_sta::get(make_range* range) {
|
||||
// static std::map<make_range*, make_range_sta*> cache;
|
||||
// if(cache.find(range) == cache.end())
|
||||
// cache.insert({range, new make_range_sta(range)});
|
||||
// return cache.at(range);
|
||||
//}
|
||||
|
||||
|
||||
// make_range
|
||||
make_range::make_range(type *ty, constant_int *first, constant_int *last)
|
||||
: instruction(ty, INST_MAKE_RANGE, 0), first_(first), last_(last){ }
|
||||
|
||||
make_range *make_range::create(constant_int *first, constant_int *last) {
|
||||
assert(first->get_type()->is_integer_ty());
|
||||
assert(first->get_type() == last->get_type());
|
||||
// assert(((constant_int*)first)->get_value() == 0);
|
||||
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
|
||||
return new make_range(ty, first, last);
|
||||
}
|
||||
|
||||
const constant_int* make_range::get_first() const {
|
||||
return first_;
|
||||
}
|
||||
|
||||
const constant_int* make_range::get_last() const {
|
||||
return last_;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,14 +0,0 @@
|
||||
#include "triton/ir/metadata.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
metadata::metadata(kind_t kind, unsigned value)
|
||||
: kind_(kind), value_(value) { }
|
||||
|
||||
metadata* metadata::get(kind_t kind, unsigned value) {
|
||||
return new metadata(kind, value);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,22 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
/* functions */
|
||||
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
|
||||
function *&fn = (function*&)symbols_[name];
|
||||
if(fn == nullptr)
|
||||
return fn = function::create(ty, global_value::external, name, this);
|
||||
return fn;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
450
lib/ir/print.cc
450
lib/ir/print.cc
@@ -1,450 +0,0 @@
|
||||
#include <iostream>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/print.h"
|
||||
|
||||
#include <map>
|
||||
#include <iomanip>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
namespace {
|
||||
class SlotTracker {
|
||||
// A mapping of values to slot numbers.
|
||||
using value_map = std::map<const value*, unsigned>;
|
||||
|
||||
// The module for which we are holding slot numbers.
|
||||
const module *mod_;
|
||||
bool module_processed = false;
|
||||
|
||||
// The function for which we are holding slot numbers.
|
||||
const function *func_ = nullptr;
|
||||
bool function_processed = false;
|
||||
|
||||
// m_map - The slot map for the module level data.
|
||||
value_map m_map;
|
||||
unsigned m_next = 0;
|
||||
|
||||
// f_map - The slot map for the function level data.
|
||||
value_map f_map;
|
||||
unsigned f_next = 0;
|
||||
|
||||
public:
|
||||
// Construct from a module
|
||||
explicit SlotTracker(const module *mod) : mod_(mod) {}
|
||||
|
||||
// Construct from a function
|
||||
explicit SlotTracker(const function *f)
|
||||
: mod_(f? f->get_parent() : nullptr), func_(f) {}
|
||||
|
||||
// Return the slot number of the specified value. If something is not in
|
||||
// the SlotTracker, return -1
|
||||
int get_local_slot(const value *v);
|
||||
|
||||
void initialize_if_needed();
|
||||
|
||||
// If you'd like to deal with a function instead of just a module, use
|
||||
// this method to get its data into the SlotTracker
|
||||
void incorporate_function(const function *f) {
|
||||
func_ = f;
|
||||
function_processed = false;
|
||||
}
|
||||
|
||||
private:
|
||||
// Add all of the module level global variables (and their initializers)
|
||||
// and function declarations, but not contents of those functions.
|
||||
void process_module();
|
||||
|
||||
// Add all of the functions arguments, basic blocks, and instructions.
|
||||
void process_function();
|
||||
|
||||
// Insert specified value* into the slot table
|
||||
void create_function_slot(const value *v);
|
||||
};
|
||||
|
||||
class AssemblyWriter {
|
||||
std::ostream &os;
|
||||
SlotTracker &slot_tracker;
|
||||
|
||||
public:
|
||||
AssemblyWriter(std::ostream &os, SlotTracker &slot_tracker)
|
||||
: os(os), slot_tracker(slot_tracker) {}
|
||||
|
||||
void print_module(const module *mod);
|
||||
void print_function(const function *f);
|
||||
void print_argument(const argument *arg);
|
||||
void print_basic_block(const basic_block *bb);
|
||||
void print_instruction(const instruction *instr);
|
||||
void print_value(const value *v);
|
||||
|
||||
void write_operand(const value *op, bool print_type = false);
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
//-------------------------
|
||||
// SlotTracker
|
||||
//-------------------------
|
||||
void SlotTracker::process_module() {
|
||||
// Nothing to do at the moment.
|
||||
// Create slots for global variable & unamed functions & ...
|
||||
module_processed = true;
|
||||
}
|
||||
|
||||
void SlotTracker::process_function() {
|
||||
f_next = 0;
|
||||
|
||||
// Add all the function arguments with no names.
|
||||
for (const argument *arg : func_->args())
|
||||
if (!arg->has_name())
|
||||
create_function_slot(arg);
|
||||
|
||||
// Add all of the basic blocks and instructions with no names.
|
||||
for (const basic_block *bb : func_->blocks()) {
|
||||
if (!bb->has_name())
|
||||
create_function_slot(bb);
|
||||
|
||||
for (const instruction *instr : bb->get_inst_list()) {
|
||||
if (!instr->get_type()->is_void_ty() && !instr->has_name())
|
||||
create_function_slot(instr);
|
||||
}
|
||||
}
|
||||
|
||||
function_processed = true;
|
||||
}
|
||||
|
||||
void SlotTracker::create_function_slot(const value *v) {
|
||||
assert(!v->get_type()->is_void_ty() && !v->has_name() && "Doesn't need a slot");
|
||||
|
||||
unsigned dst_slot = f_next++;
|
||||
f_map[v] = dst_slot;
|
||||
}
|
||||
|
||||
int SlotTracker::get_local_slot(const value *v) {
|
||||
assert(dynamic_cast<const constant*>(v) == nullptr && "Can't get a constant slot");
|
||||
|
||||
// Check for uninitialized state and do lazy initialization.
|
||||
initialize_if_needed();
|
||||
|
||||
value_map::iterator f_iter = f_map.find(v);
|
||||
return f_iter == f_map.end() ? -1 : (int)f_iter->second;
|
||||
}
|
||||
|
||||
void SlotTracker::initialize_if_needed() {
|
||||
if (mod_ && !module_processed)
|
||||
process_module();
|
||||
|
||||
if (func_ && !function_processed)
|
||||
process_function();
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------
|
||||
// AssemblyWriter
|
||||
//-------------------------------
|
||||
void AssemblyWriter::write_operand(const value *operand, bool print_type) {
|
||||
if (!operand) {
|
||||
os << "<null operand!>";
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto *c = dynamic_cast<const ir::constant*>(operand)) {
|
||||
os << c->repr();
|
||||
return;
|
||||
}
|
||||
|
||||
if (operand->has_name()) {
|
||||
os << operand->get_name();
|
||||
return;
|
||||
}
|
||||
|
||||
// Print the normal way
|
||||
int slot_num = slot_tracker.get_local_slot(operand);
|
||||
|
||||
if (slot_num != -1)
|
||||
os << "%" << slot_num;
|
||||
else
|
||||
os << "<badref>";
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_module(const module *mod) {
|
||||
slot_tracker.initialize_if_needed();
|
||||
// ;ModuleID = ...
|
||||
// source_filename = ...
|
||||
|
||||
// Print all of the functions.
|
||||
for (function *f : mod->get_function_list()) {
|
||||
os << "\n";
|
||||
print_function(f);
|
||||
}
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_function(const function *f) {
|
||||
// Annotation & Attributes
|
||||
|
||||
slot_tracker.incorporate_function(f);
|
||||
|
||||
os << "def ";
|
||||
ir::type *rt_type = f->get_fn_type()->get_return_ty();
|
||||
// Functions must have names.
|
||||
os << rt_type->repr() << " " << f->get_name() << "(";
|
||||
// Print arguments
|
||||
for (ir::argument *arg : f->args()) {
|
||||
if (arg->get_arg_no() > 0)
|
||||
os << ", ";
|
||||
print_argument(arg);
|
||||
}
|
||||
os << ")";
|
||||
|
||||
// Print function body
|
||||
os << "{";
|
||||
for (const basic_block *bb : f->blocks())
|
||||
print_basic_block(bb);
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_argument(const argument *arg) {
|
||||
// Print type
|
||||
os << arg->get_type()->repr();
|
||||
|
||||
// Print name, if available.
|
||||
if (arg->has_name())
|
||||
os << " " << arg->get_name();
|
||||
else {
|
||||
int slot_num = slot_tracker.get_local_slot(arg);
|
||||
assert(slot_num != -1 && "expect argument in function here");
|
||||
os << " %" << slot_num;
|
||||
}
|
||||
|
||||
// Print attributes
|
||||
std::set<attribute> attrs = arg->get_parent()->get_attributes(arg);
|
||||
for (attribute attr : attrs)
|
||||
os << " " << attr.repr();
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_basic_block(const basic_block *bb) {
|
||||
// bb label
|
||||
if (bb->has_name()) {
|
||||
os << "\n";
|
||||
os << bb->get_name() << ":";
|
||||
} else {
|
||||
os << "\n";
|
||||
int slot_num = slot_tracker.get_local_slot(bb);
|
||||
if (slot_num != -1)
|
||||
os << slot_num << ":";
|
||||
else
|
||||
os << "<badref>:";
|
||||
}
|
||||
|
||||
// Print predecessors for the block
|
||||
auto const &predecessors = bb->get_predecessors();
|
||||
if (!predecessors.empty()) {
|
||||
os << std::setw(50) << std::setfill(' ')
|
||||
<< "; preds = ";
|
||||
for (size_t i=0; i<predecessors.size(); ++i) {
|
||||
if (i)
|
||||
os << ", ";
|
||||
write_operand(predecessors[i]);
|
||||
}
|
||||
}
|
||||
|
||||
os << "\n";
|
||||
|
||||
// Annotation?
|
||||
|
||||
// Print all of the instructions in the basic block
|
||||
for (const ir::instruction *instr : bb->get_inst_list())
|
||||
print_instruction(instr);
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_instruction(const instruction *instr) {
|
||||
// Print out indentation for an instruction.
|
||||
os << " ";
|
||||
|
||||
ir::type *type = instr->get_type();
|
||||
if (instr->has_name()) {
|
||||
os << instr->get_name();
|
||||
os << " = ";
|
||||
} else if (!type->is_void_ty()) {
|
||||
// Print out the def slot taken.
|
||||
int slot_num = slot_tracker.get_local_slot(instr);
|
||||
if (slot_num == -1)
|
||||
os << "<badref> = ";
|
||||
else
|
||||
os << "%" << slot_num << " = ";
|
||||
}
|
||||
|
||||
// Print out opcode
|
||||
os << instr->repr() << " " << type->repr();
|
||||
|
||||
size_t num_ops = instr->get_num_operands();
|
||||
if (num_ops > 0)
|
||||
os << " ";
|
||||
ir::instruction::ops_t ops = instr->ops();
|
||||
for (unsigned i = 0; i < num_ops; ++i) {
|
||||
if (i)
|
||||
os << ", ";
|
||||
write_operand(ops[i]);
|
||||
}
|
||||
|
||||
os << ";\n";
|
||||
}
|
||||
|
||||
void AssemblyWriter::print_value(const value *v) {
|
||||
// Not implemented
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------
|
||||
// External interface
|
||||
//-------------------------------
|
||||
void module::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this);
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_module(this);
|
||||
}
|
||||
|
||||
void function::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this);
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_function(this);
|
||||
}
|
||||
|
||||
void basic_block::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this->get_parent());
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_basic_block(this);
|
||||
}
|
||||
|
||||
void instruction::print(std::ostream &os) {
|
||||
SlotTracker slot_tracker(this->get_parent()->get_parent());
|
||||
AssemblyWriter writer(os, slot_tracker);
|
||||
writer.print_instruction(this);
|
||||
}
|
||||
|
||||
//-------------------------------
|
||||
// legacy print interface
|
||||
//-------------------------------
|
||||
std::string get_name(ir::value *v, unsigned i) {
|
||||
if(v->get_name().empty()){
|
||||
std::string name = "%" + std::to_string(i);
|
||||
v->set_name(name);
|
||||
}
|
||||
return v->get_name();
|
||||
}
|
||||
|
||||
|
||||
void print(module &mod, std::ostream& os) {
|
||||
unsigned cnt = 0;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
os << "def " << fn->get_fn_type()->get_return_ty()->repr() << " " << fn->get_name() << "(" ;
|
||||
for(ir::argument* arg: fn->args()) {
|
||||
if(arg->get_arg_no() > 0)
|
||||
os << ", ";
|
||||
os << arg->get_type()->repr() << " " << arg->get_name();
|
||||
auto attrs = fn->get_attributes(arg);
|
||||
if(attrs.size() > 0)
|
||||
os << " ";
|
||||
for(ir::attribute attr: attrs)
|
||||
os << attr.repr() << " ";
|
||||
}
|
||||
os << ")" << std::endl;
|
||||
os << "{" << std::endl;
|
||||
for(ir::basic_block *block: fn->blocks()){
|
||||
auto const &predecessors = block->get_predecessors();
|
||||
os << block->get_name() << ":";
|
||||
if(!predecessors.empty()){
|
||||
os << " ";
|
||||
os << "; preds = ";
|
||||
auto const &predecessors = block->get_predecessors();
|
||||
for(ir::basic_block *pred: predecessors)
|
||||
os << pred->get_name() << (pred!=predecessors.back()?", ":"");
|
||||
}
|
||||
os << std::endl;
|
||||
for(ir::instruction *inst: block->get_inst_list()){
|
||||
os << " ";
|
||||
if(!inst->get_type()->is_void_ty()){
|
||||
os << get_name(inst, cnt++);
|
||||
os << " = ";
|
||||
}
|
||||
ir::type* type = inst->get_type();
|
||||
os << inst->repr() << " " << type->repr();
|
||||
ir::instruction::ops_t ops = inst->ops();
|
||||
size_t num_ops = inst->get_num_operands();
|
||||
if(num_ops > 0)
|
||||
os << " ";;
|
||||
for(unsigned i = 0; i < num_ops; i++){
|
||||
if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
|
||||
os << x->repr();
|
||||
else
|
||||
os << get_name(ops[i], cnt++);
|
||||
os << (i < num_ops - 1?", ":"");
|
||||
}
|
||||
os << ";";
|
||||
// os << " (";
|
||||
// for(ir::user* usr: inst->get_users())
|
||||
// os << get_name(usr, cnt++) << ", " ;
|
||||
// os << " )";
|
||||
os << std::endl;
|
||||
}
|
||||
}
|
||||
os << "}" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void print(function &fn, std::ostream &os) {
|
||||
//
|
||||
}
|
||||
|
||||
void print(basic_block &bb, std::ostream &os) {
|
||||
auto const &predecessors = bb.get_predecessors();
|
||||
os << bb.get_name() << ":";
|
||||
if(!predecessors.empty()){
|
||||
os << " ";
|
||||
os << "; preds = ";
|
||||
auto const &predecessors = bb.get_predecessors();
|
||||
for(ir::basic_block *pred: predecessors)
|
||||
os << pred->get_name() << (pred!=predecessors.back()?", ":"");
|
||||
}
|
||||
os << std::endl;
|
||||
for(ir::instruction *inst: bb.get_inst_list()){
|
||||
print(*inst, os);
|
||||
}
|
||||
}
|
||||
|
||||
void print(instruction &instr, std::ostream &os) {
|
||||
instruction *inst = &instr;
|
||||
os << " ";
|
||||
if(!inst->get_type()->is_void_ty()){
|
||||
os << instr.get_name();
|
||||
os << " = ";
|
||||
}
|
||||
ir::type* type = inst->get_type();
|
||||
os << inst->repr() << " " << type->repr();
|
||||
ir::instruction::ops_t ops = inst->ops();
|
||||
size_t num_ops = inst->get_num_operands();
|
||||
if(num_ops > 0)
|
||||
os << " ";;
|
||||
for(unsigned i = 0; i < num_ops; i++){
|
||||
if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
|
||||
os << x->repr();
|
||||
else
|
||||
os << ops[i]->get_name();
|
||||
os << (i < num_ops - 1?", ":"");
|
||||
}
|
||||
os << ";";
|
||||
// os << " (";
|
||||
// for(ir::user* usr: inst->get_users())
|
||||
// os << get_name(usr, cnt++) << ", " ;
|
||||
// os << " )";
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
233
lib/ir/type.cc
233
lib/ir/type.cc
@@ -1,233 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/constant.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// attributes
|
||||
type *type::get_scalar_ty() const {
|
||||
if(is_block_ty())
|
||||
return get_tile_element_ty();
|
||||
return const_cast<type*>(this);
|
||||
}
|
||||
|
||||
unsigned type::get_primitive_size_in_bits() const {
|
||||
switch (id_) {
|
||||
case FP8TyID: return 8;
|
||||
case FP16TyID: return 16;
|
||||
case BF16TyID: return 16;
|
||||
case FP32TyID: return 32;
|
||||
case FP64TyID: return 64;
|
||||
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
|
||||
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned type::get_integer_bitwidth() const
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_tile_bitwidth() const
|
||||
{ return ((block_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == FP8TyID) return 3;
|
||||
if (id == FP16TyID) return 10;
|
||||
if (id == BF16TyID) return 7;
|
||||
if (id == FP32TyID) return 23;
|
||||
if (id == FP64TyID) return 53;
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
type* type::get_tile_element_ty() const {
|
||||
assert(is_block_ty());
|
||||
return contained_tys_[0];
|
||||
}
|
||||
|
||||
unsigned type::get_pointer_address_space() const {
|
||||
assert(is_pointer_ty());
|
||||
return ((pointer_type*)this)->get_address_space();
|
||||
}
|
||||
|
||||
type * type::get_pointer_element_ty() const {
|
||||
type *ptr_ty = get_scalar_ty();
|
||||
assert(ptr_ty->is_pointer_ty());
|
||||
type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty();
|
||||
if(is_block_ty())
|
||||
return block_type::get_same_shapes(scalar_ty, (type*)this);
|
||||
return scalar_ty;
|
||||
}
|
||||
|
||||
|
||||
type::block_shapes_t type::get_block_shapes() const {
|
||||
assert(is_block_ty());
|
||||
return ((block_type*)this)->get_shapes();
|
||||
}
|
||||
|
||||
const size_t type::get_tile_rank() const {
|
||||
return get_block_shapes().size();
|
||||
}
|
||||
|
||||
const size_t type::get_tile_ranks1() const {
|
||||
int ret = 0;
|
||||
for(int s: get_block_shapes())
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
unsigned type::get_tile_num_elements() const {
|
||||
const block_shapes_t& shapes = get_block_shapes();
|
||||
unsigned result = 1;
|
||||
for(auto shape: shapes)
|
||||
result *= shape;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// composite predicates
|
||||
bool type::is_int_or_tileint_ty()
|
||||
{ return get_scalar_ty()->is_integer_ty(); }
|
||||
|
||||
bool type::is_integer_ty(unsigned width) const
|
||||
{ return is_integer_ty() && get_integer_bitwidth()== width; }
|
||||
|
||||
|
||||
bool type::is_floating_point_ty() const
|
||||
{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
|
||||
|
||||
bool type::is_sized() const {
|
||||
// primitive types are sized
|
||||
if(is_integer_ty() || is_floating_point_ty() ||
|
||||
is_pointer_ty()){
|
||||
return true;
|
||||
}
|
||||
// tile types are sizes
|
||||
if(is_block_ty())
|
||||
return get_scalar_ty()->is_sized();
|
||||
return false;
|
||||
}
|
||||
|
||||
// primitive types
|
||||
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||
// floating point
|
||||
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||
type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
|
||||
type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
|
||||
type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
|
||||
type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
|
||||
// integer types
|
||||
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
|
||||
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
|
||||
integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
|
||||
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
|
||||
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
|
||||
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
|
||||
|
||||
|
||||
|
||||
pointer_type::pointer_type(type *ty, unsigned address_space)
|
||||
: type(ty->get_context(), PointerTyID), address_space_(address_space){
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
bool pointer_type::is_valid_elt_ty(type *ty){
|
||||
return !ty->is_void_ty() && !ty->is_label_ty() &&
|
||||
!ty->is_metadata_ty() && !ty->is_token_ty();
|
||||
}
|
||||
|
||||
pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
|
||||
assert(elt_ty && "Can't get a pointer to <null> type!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
|
||||
// look-up
|
||||
context_impl *impl = elt_ty->get_context().p_impl.get();
|
||||
std::unique_ptr<pointer_type> &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)];
|
||||
if(!entry)
|
||||
entry.reset(new pointer_type(elt_ty, address_space));
|
||||
return entry.get();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// composite_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
type* composite_type::get_type_at_index(value *) const{
|
||||
assert(is_block_ty());
|
||||
return get_scalar_ty();
|
||||
}
|
||||
|
||||
bool composite_type::index_valid(value *idx) const{
|
||||
assert(is_block_ty());
|
||||
return idx->get_type()->is_int_or_tileint_ty();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tile_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
block_type::block_type(type *ty, const block_shapes_t &shapes)
|
||||
: composite_type(ty->get_context(), BlockTyID), shapes_(shapes) {
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
bool block_type::is_valid_elt_ty(type *ty) {
|
||||
return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
|
||||
}
|
||||
|
||||
unsigned block_type::get_num_elements() const {
|
||||
unsigned res = 1;
|
||||
for(auto shape: shapes_)
|
||||
res *= shape;
|
||||
return res;
|
||||
}
|
||||
|
||||
unsigned block_type::get_bitwidth() const {
|
||||
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
|
||||
}
|
||||
|
||||
block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
|
||||
assert(elt_ty && "Can't get a tile of <null> type!");
|
||||
assert(shapes.size() && "Can't create a tile with empty shapes!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
|
||||
// look-up
|
||||
context_impl *impl = elt_ty->get_context().p_impl.get();
|
||||
std::unique_ptr<block_type> &entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
|
||||
if(!entry)
|
||||
entry.reset(new block_type(elt_ty, shapes));
|
||||
return entry.get();
|
||||
}
|
||||
|
||||
block_type* block_type::get_same_shapes(type *ty, type *ref){
|
||||
assert(ref->is_block_ty());
|
||||
return get(ty, ref->get_block_shapes());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// function_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
function_type::function_type(type *ret_ty, const std::vector<type*> ¶m_tys):
|
||||
type(ret_ty->get_context(), FunctionTyID) {
|
||||
contained_tys_.push_back(ret_ty);
|
||||
for(type *ty: param_tys)
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
function_type* function_type::get(type *ret_ty, const std::vector<type *> ¶m_tys) {
|
||||
return new function_type(ret_ty, param_tys);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,68 +0,0 @@
|
||||
#include <stack>
|
||||
#include <iostream>
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
std::vector<basic_block*> cfg::post_order(function* fn) {
|
||||
std::stack<basic_block*> stack;
|
||||
std::set<basic_block*> visited;
|
||||
std::vector<basic_block*> result;
|
||||
// initialize stack
|
||||
for(ir::basic_block* block: fn->blocks())
|
||||
if(block->get_predecessors().empty()){
|
||||
stack.push(block);
|
||||
visited.insert(block);
|
||||
}
|
||||
// DFS
|
||||
while(!stack.empty()) {
|
||||
basic_block* current = stack.top();
|
||||
bool tail = true;
|
||||
for(basic_block* succ: current->get_successors())
|
||||
if(visited.find(succ) == visited.end()){
|
||||
stack.push(succ);
|
||||
visited.insert(succ);
|
||||
tail = false;
|
||||
break;
|
||||
}
|
||||
if(tail){
|
||||
stack.pop();
|
||||
result.push_back(current);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
|
||||
auto result = post_order(fn);
|
||||
std::reverse(result.begin(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: cfg::reverse_post_order(fn))
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
do_work(i);
|
||||
}
|
||||
|
||||
void for_each_value(module &mod, const std::function<void (value *)> &do_work) {
|
||||
std::set<ir::value*> seen;
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: cfg::reverse_post_order(fn))
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
for(ir::value *op: i->ops()){
|
||||
if(seen.insert(op).second)
|
||||
do_work(op);
|
||||
}
|
||||
if(seen.insert(i).second)
|
||||
do_work(i);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,81 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// value class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value::value(type *ty, const std::string &name): ty_(ty){
|
||||
set_name(name);
|
||||
}
|
||||
|
||||
void value::add_use(user *arg) {
|
||||
users_.insert(arg);
|
||||
}
|
||||
|
||||
value::users_t::iterator value::erase_use(user *arg){
|
||||
auto it = users_.find(arg);
|
||||
if(it == users_.end())
|
||||
return it;
|
||||
return users_.erase(it);
|
||||
}
|
||||
|
||||
// TODO: automatic naming scheme + update symbol table
|
||||
void value::set_name(const std::string &name){
|
||||
name_ = name;
|
||||
}
|
||||
|
||||
void value::replace_all_uses_with(value *target){
|
||||
for (auto it = users_.begin(); it != users_.end(); ) {
|
||||
it = (*it)->replace_uses_of_with(this, target);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void visitor::visit_value(ir::value* v) {
|
||||
v->accept(this);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// user class
|
||||
//===----------------------------------------------------------------------===//
|
||||
void user::set_operand(unsigned i, value *x) {
|
||||
assert(i < ops_.size() && "set_operand() out of range!");
|
||||
ops_[i] = x;
|
||||
x->add_use(this);
|
||||
}
|
||||
|
||||
value* user::get_operand(unsigned i) const {
|
||||
assert(i < ops_.size() && "get_operand() out of range!");
|
||||
return ops_[i];
|
||||
}
|
||||
|
||||
unsigned user::get_num_operands() const {
|
||||
return num_ops_;
|
||||
}
|
||||
|
||||
unsigned user::get_num_hidden() const {
|
||||
return num_hidden_;
|
||||
}
|
||||
|
||||
value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
|
||||
for(size_t i = 0; i < ops_.size(); i++)
|
||||
if(ops_[i] == before){
|
||||
ops_[i] = after;
|
||||
after->add_use(this);
|
||||
}
|
||||
return before->erase_use(this);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user