2021-07-27 12:38:38 -07:00
|
|
|
#include <algorithm>
|
2021-04-20 22:29:40 -04:00
|
|
|
#include <iostream>
|
2021-07-27 12:38:38 -07:00
|
|
|
#include "triton/ir/context.h"
|
|
|
|
#include "triton/ir/basic_block.h"
|
|
|
|
#include "triton/ir/instructions.h"
|
|
|
|
#include "triton/ir/constant.h"
|
|
|
|
#include "triton/ir/type.h"
|
2022-04-03 20:58:16 -07:00
|
|
|
#include "triton/ir/function.h"
|
2021-07-27 12:38:38 -07:00
|
|
|
|
|
|
|
namespace triton{
|
|
|
|
namespace ir{
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// instruction classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Builds an instruction of type `ty` with value id `ity` and `num_ops`
// operand slots. When `next` is non-null, the new instruction is inserted
// into next's parent basic block immediately before `next`.
instruction::instruction(type *ty, value_id_t ity, unsigned num_ops,
                         const std::string &name, instruction *next)
    : user(ty, num_ops, name), id_(ity) {
  if(next){
    basic_block *block = next->get_parent();
    assert(block && "Next instruction is not in a basic block!");
    auto it = std::find(block->begin(), block->end(), next);
    // Insert the instruction under construction. The previous code inserted
    // `next` a second time, which duplicated it in the list and never added
    // the new instruction.
    block->get_inst_list().insert(it, this);
    // NOTE(review): parent_ is not assigned on this path -- confirm whether
    // callers are expected to set it after construction.
  }
}
|
|
|
|
|
|
|
|
void instruction::erase_from_parent() {
|
|
|
|
parent_->erase(this);
|
|
|
|
for(ir::value* op: ops())
|
|
|
|
op->erase_use(this);
|
|
|
|
}
|
|
|
|
|
|
|
|
bool instruction::has_tile_result_or_op() {
|
2021-04-20 22:29:40 -04:00
|
|
|
bool result = get_type()->is_block_ty();
|
2021-07-27 12:38:38 -07:00
|
|
|
for(unsigned i = 0; i < get_num_operands(); i++)
|
2021-04-20 22:29:40 -04:00
|
|
|
result |= get_operand(i)->get_type()->is_block_ty();
|
2021-07-27 12:38:38 -07:00
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// phi_node classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Creates a PHI node with no operands yet and reserves capacity for
// `num_reserved` incoming blocks.
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
    : instruction(ty, INST_PHI, 0, name, next)
{
  blocks_.reserve(num_reserved);
}
|
|
|
|
|
2021-02-21 15:19:39 -08:00
|
|
|
// Returns the incoming value associated with `block`.
// `block` must be one of this node's incoming blocks; this is checked with
// an assertion because a missing block would otherwise index one past the
// last recorded entry.
value* phi_node::get_value_for_block(basic_block * block) {
  auto it = std::find(blocks_.begin(), blocks_.end(), block);
  assert(it != blocks_.end() && "Block is not an incoming block of this PHI node!");
  size_t n = std::distance(blocks_.begin(), it);
  return get_incoming_value(n);
}
|
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
// Set incoming value
|
|
|
|
void phi_node::set_incoming_value(unsigned i, value *v){
|
|
|
|
assert(v && "PHI node got a null value!");
|
|
|
|
assert(get_type() == v->get_type() &&
|
|
|
|
"All operands to PHI node must be the same type as the PHI node!");
|
|
|
|
set_operand(i, v);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set incoming block
|
|
|
|
void phi_node::set_incoming_block(unsigned i, basic_block *block){
|
|
|
|
assert(block && "PHI node got a null basic block!");
|
|
|
|
blocks_[i] = block;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add incoming
|
|
|
|
void phi_node::add_incoming(value *v, basic_block *block){
|
2022-06-27 11:49:19 -07:00
|
|
|
assert(v && "PHI node got a null value!!");
|
2021-07-27 12:38:38 -07:00
|
|
|
resize_ops(get_num_operands() + 1);
|
|
|
|
blocks_.resize(get_num_operands() + 1);
|
|
|
|
set_incoming_value(get_num_operands() - 1, v);
|
|
|
|
set_incoming_block(get_num_operands() - 1, block);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Factory methods
|
|
|
|
phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){
|
|
|
|
return new phi_node(ty, num_reserved, name, next);
|
|
|
|
}
|
|
|
|
|
2022-04-03 20:58:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// call_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Textual form: "call " followed by the callee's name.
std::string call_inst::repr_impl() const {
  std::string text = "call ";
  text += fn_->get_name();
  return text;
}
|
|
|
|
|
|
|
|
// Builds a call to `fn`. The result type is the return type of fn's
// function type, and each entry of `values` becomes the operand at the
// same index.
call_inst::call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next)
    : instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next),
      fn_(fn)
{
  const size_t num_args = values.size();
  for(size_t arg = 0; arg != num_args; ++arg) {
    set_operand(arg, values.at(arg));
  }
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates a call instruction, forwarding all arguments to
// the constructor.
call_inst* call_inst::create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name, instruction *next) {
  call_inst *call = new call_inst(fn, values, name, next);
  return call;
}
|
|
|
|
|
|
|
|
|
|
|
|
// launch
|
|
|
|
|
|
|
|
// Builds a launch of `fn`. Operand layout:
//   [0]                     the function
//   [val_begin, val_end)    the call arguments (`values`)
//   [grid_begin, grid_end)  the grid, which must have exactly 3 entries
//   [grid_end]              num_warps
// Throws std::runtime_error when `grid` does not have 3 elements.
launch_inst::launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, const std::string& name, instruction* next)
    : instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH,
                  1 + values.size() + grid.size() + 1, name, next)
{
  if(grid.size() != 3)
    throw std::runtime_error("grid must have 3 elements");

  int slot = 0;
  set_operand(slot++, fn);

  val_begin = slot;
  for(size_t i = 0; i < values.size(); ++i)
    set_operand(slot++, values[i]);
  val_end = slot;

  grid_begin = slot;
  for(size_t i = 0; i < grid.size(); ++i)
    set_operand(slot++, grid[i]);
  grid_end = slot;

  set_operand(slot, num_warps);
}
|
|
|
|
|
|
|
|
|
|
|
|
// Returns operand 0 cast to a function pointer; the constructor stores the
// launched function in that slot.
ir::function* launch_inst::get_fn() {
  ir::function *callee = (ir::function*)get_operand(0);
  return callee;
}
|
|
|
|
|
|
|
|
std::vector<ir::value*> launch_inst::get_values() {
|
|
|
|
std::vector<ir::value*> ret;
|
|
|
|
for(int i = val_begin; i < val_end; i++)
|
|
|
|
ret.push_back(get_operand(i));
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<ir::value*> launch_inst::get_grid() {
|
|
|
|
std::vector<ir::value*> ret;
|
|
|
|
for(int i = grid_begin; i < grid_end; i++)
|
|
|
|
ret.push_back(get_operand(i));
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns the operand at index grid_end, where the constructor stores
// num_warps.
ir::value* launch_inst::get_num_warps() {
  ir::value *warps = get_operand(grid_end);
  return warps;
}
|
|
|
|
|
|
|
|
|
|
|
|
// Factory: heap-allocates a launch instruction, forwarding all arguments to
// the constructor.
launch_inst* launch_inst::create(ir::function *fn, const std::vector<ir::value *> &values, const std::vector<ir::value *> &grid, ir::value *num_warps, const std::string &name, instruction *next) {
  launch_inst *launch = new launch_inst(fn, values, grid, num_warps, name, next);
  return launch;
}
|
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// binary_operator classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
std::string binary_operator::repr_impl() const {
|
|
|
|
switch(op_) {
|
|
|
|
case Add : return "add";
|
|
|
|
case FAdd : return "fadd";
|
|
|
|
case Sub : return "sub";
|
|
|
|
case FSub : return "fsub";
|
|
|
|
case Mul : return "mul";
|
|
|
|
case FMul : return "fmul";
|
|
|
|
case UDiv : return "udiv";
|
|
|
|
case SDiv : return "sdiv";
|
|
|
|
case FDiv : return "fdiv";
|
|
|
|
case URem : return "urem";
|
|
|
|
case SRem : return "srem";
|
|
|
|
case FRem : return "frem";
|
|
|
|
case Shl : return "shl";
|
|
|
|
case LShr : return "lshr";
|
|
|
|
case AShr : return "ashr";
|
|
|
|
case And : return "and";
|
|
|
|
case Or : return "or";
|
|
|
|
case Xor : return "xor";
|
|
|
|
default: throw std::runtime_error("unknown binary operator");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
bool binary_operator::is_int_div() const {
|
|
|
|
return op_ == binary_op_t::UDiv || op_ == binary_op_t::SDiv;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool binary_operator::is_int_rem() const {
|
|
|
|
return op_ == binary_op_t::URem || op_ == binary_op_t::SRem;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool binary_operator::is_shl() const {
|
|
|
|
return op_ == binary_op_t::Shl;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool binary_operator::is_shr() const {
|
|
|
|
return op_ == binary_op_t::LShr || op_ == binary_op_t::AShr;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool binary_operator::is_int_mult() const {
|
|
|
|
return op_ == binary_op_t::Mul;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool binary_operator::is_int_add_sub() const {
|
|
|
|
return op_ == binary_op_t::Add || op_ == binary_op_t::Sub;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Builds a two-operand instruction of type `ty` with opcode `op`.
// `lhs` goes in slot 0 and `rhs` in slot 1; fdiv_ieee_rnd_ starts as false.
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
    : instruction(ty, INST_BINOP, 2, name, next),
      op_(op),
      fdiv_ieee_rnd_(false)
{
  set_operand(0, lhs);
  set_operand(1, rhs);
}
|
|
|
|
|
|
|
|
// Factory: both operands must have the same type (asserted). The result
// takes the type of `lhs`.
binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
  assert(lhs->get_type() == rhs->get_type() &&
         "Cannot create binary operator with two operands of differing type!");
  type *result_ty = lhs->get_type();
  return new binary_operator(op, lhs, rhs, result_ty, name, next);
}
|
|
|
|
|
|
|
|
//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
|
|
|
|
// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
|
|
|
|
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
|
|
|
|
// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
|
|
|
|
//}
|
|
|
|
|
|
|
|
//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
|
|
|
|
// assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
|
|
|
|
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty());
|
|
|
|
// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
|
|
|
|
//}
|
|
|
|
|
|
|
|
//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
|
|
|
|
// assert(arg->get_type()->is_integer_ty());
|
|
|
|
// constant *mask = constant::get_all_ones_value(arg->get_type());
|
|
|
|
// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
|
|
|
|
//}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// cmp_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// cmp_inst
|
|
|
|
std::string cmp_inst::repr_impl() const {
|
|
|
|
switch (pred_) {
|
|
|
|
case FCMP_FALSE : return "false";
|
|
|
|
case FCMP_OEQ : return "fcmp_oeq";
|
|
|
|
case FCMP_OGT : return "fcmp_ogt";
|
|
|
|
case FCMP_OGE : return "fcmp_oge";
|
|
|
|
case FCMP_OLT : return "fcmp_olt";
|
|
|
|
case FCMP_OLE : return "fcmp_ole";
|
|
|
|
case FCMP_ONE : return "fcmp_one";
|
|
|
|
case FCMP_ORD : return "fcmp_ord";
|
|
|
|
case FCMP_UNO : return "fcmp_uno";
|
|
|
|
case FCMP_UEQ : return "fcmp_ueq";
|
|
|
|
case FCMP_UGT : return "fcmp_ugt";
|
|
|
|
case FCMP_UGE : return "fcmp_uge";
|
|
|
|
case FCMP_ULT : return "fcmp_ult";
|
|
|
|
case FCMP_ULE : return "fcmp_ule";
|
|
|
|
case FCMP_UNE : return "fcmp_une";
|
|
|
|
case FCMP_TRUE : return "true";
|
|
|
|
case ICMP_EQ : return "icmp_eq";
|
|
|
|
case ICMP_NE : return "icmp_ne";
|
|
|
|
case ICMP_UGT : return "icmp_ugt";
|
|
|
|
case ICMP_UGE : return "icmp_uge";
|
|
|
|
case ICMP_ULT : return "icmp_ult";
|
|
|
|
case ICMP_ULE : return "icmp_ule";
|
|
|
|
case ICMP_SGT : return "icmp_sgt";
|
|
|
|
case ICMP_SGE : return "icmp_sge";
|
|
|
|
case ICMP_SLT : return "icmp_slt";
|
|
|
|
case ICMP_SLE : return "icmp_sle";
|
|
|
|
default: throw std::runtime_error("unreachable");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Builds a two-operand comparison with predicate `pred`.
// `lhs` goes in slot 0 and `rhs` in slot 1.
cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
    : instruction(ty, id, 2, name, next),
      pred_(pred)
{
  set_operand(0, lhs);
  set_operand(1, rhs);
}
|
|
|
|
|
|
|
|
// Returns the result type of comparing values of type `ty`: int1 for a
// non-block type, or a block of int1 with the same shapes as `ty` when
// `ty` is a block type.
type* cmp_inst::make_cmp_result_type(type *ty){
  type *bool_ty = type::get_int1_ty(ty->get_context());
  block_type *as_block = dynamic_cast<block_type*>(ty);
  if(as_block == nullptr)
    return bool_ty;
  return block_type::get_same_shapes(bool_ty, as_block);
}
|
|
|
|
|
|
|
|
|
|
|
|
// True when `pred` lies in [FIRST_FCMP_PREDICATE, LAST_FCMP_PREDICATE].
bool cmp_inst::is_fp_predicate(cmp_pred_t pred) {
  const bool outside = (pred < FIRST_FCMP_PREDICATE) || (pred > LAST_FCMP_PREDICATE);
  return !outside;
}
|
|
|
|
|
|
|
|
// True when `pred` lies in [FIRST_ICMP_PREDICATE, LAST_ICMP_PREDICATE].
bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
  const bool outside = (pred < FIRST_ICMP_PREDICATE) || (pred > LAST_ICMP_PREDICATE);
  return !outside;
}
|
|
|
|
|
|
|
|
|
|
|
|
// icmp_inst
|
|
|
|
// Integer comparison: forwards to cmp_inst with the INST_ICMP id.
icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
                     value *lhs, value *rhs, const std::string &name, instruction *next)
    : cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next)
{
}
|
|
|
|
|
|
|
|
// Factory: `pred` must be an integer predicate and both operands must have
// the same type (both asserted). The result type is derived from the type
// of `lhs`.
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
  assert(is_int_predicate(pred));
  assert(lhs->get_type() == rhs->get_type());
  type *result_ty = make_cmp_result_type(lhs->get_type());
  icmp_inst *cmp = new icmp_inst(result_ty, pred, lhs, rhs, name, next);
  return cmp;
}
|
|
|
|
|
|
|
|
// fcmp_inst
|
|
|
|
// Floating-point comparison: forwards to cmp_inst with the INST_FCMP id.
fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred,
                     value *lhs, value *rhs, const std::string &name, instruction *next)
    : cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next)
{
}
|
|
|
|
|
|
|
|
// Factory: `pred` must be a floating-point predicate (asserted). The result
// type is derived from the type of `lhs`.
// NOTE(review): unlike icmp_inst::create, operand types are not compared
// here -- confirm whether that is intentional.
fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
  assert(is_fp_predicate(pred));
  type *result_ty = make_cmp_result_type(lhs->get_type());
  fcmp_inst *cmp = new fcmp_inst(result_ty, pred, lhs, rhs, name, next);
  return cmp;
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// unary_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Builds a single-operand instruction; `v` is stored in slot 0.
unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next)
    : instruction(ty, id, 1, name, next)
{
  set_operand(0, v);
}
|
|
|
|
|
2022-10-12 14:14:45 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// dequantize_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Builds a three-operand dequantize instruction of type `ty`.
// Operand slots: 0 = v, 1 = scale, 2 = shift.
dequantize_inst::dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next)
    : instruction(ty, INST_DEQUANTIZE, 3, name, next)
{
  set_operand(0, v);
  set_operand(1, scale);
  set_operand(2, shift);
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates a dequantize instruction. Note the constructor
// takes the result type first, whereas this factory takes it after `shift`.
dequantize_inst *dequantize_inst::create(value *arg, value *scale, value *shift, type *ty, const std::string &name, instruction *next){
  dequantize_inst *inst = new dequantize_inst(ty, arg, scale, shift, name, next);
  return inst;
}
|
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// cast_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
std::string cast_inst::repr_impl() const {
|
|
|
|
switch (op_){
|
|
|
|
case cast_op_t::Trunc: return "trunc";
|
|
|
|
case cast_op_t::ZExt: return "zext";
|
|
|
|
case cast_op_t::SExt: return "sext";
|
|
|
|
case cast_op_t::FPTrunc: return "fp_trunc";
|
|
|
|
case cast_op_t::FPExt: return "fp_ext";
|
|
|
|
case cast_op_t::UIToFP: return "ui_to_fp";
|
|
|
|
case cast_op_t::SIToFP: return "si_to_fp";
|
|
|
|
case cast_op_t::FPToUI: return "fp_to_ui";
|
|
|
|
case cast_op_t::FPToSI: return "fp_to_si";
|
|
|
|
case cast_op_t::PtrToInt: return "ptr_to_int";
|
|
|
|
case cast_op_t::IntToPtr: return "int_to_ptr";
|
|
|
|
case cast_op_t::BitCast: return "bitcast";
|
|
|
|
case cast_op_t::AddrSpaceCast: return "addr_space_cast";
|
|
|
|
default: throw std::runtime_error("unreachable");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// TODO
|
|
|
|
// Reports whether a cast of `arg` to `ty` is structurally valid: the source
// and destination must either both be block types or both be non-block
// types. `op` is not inspected yet; no opcode-specific checks are done.
//
// The mismatch is returned as `false` rather than asserted inside the
// predicate, so callers (e.g. the "Invalid cast!" assertion in create)
// decide how to react, and release builds no longer get an unconditional
// `true`.
bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) {
  (void)op;
  return arg->get_type()->is_block_ty() == ty->is_block_ty();
}
|
|
|
|
|
|
|
|
// Factory: asserts the cast is valid, then instantiates the cast_inst
// subclass that corresponds to `op`.
// Throws std::runtime_error for an opcode that is not listed.
cast_inst *cast_inst::create(cast_op_t op, value *arg, type *ty, const std::string &name, instruction *next){
  assert(is_valid(op, arg, ty) && "Invalid cast!");
  cast_inst *result = nullptr;
  switch (op) {
    case cast_op_t::Trunc:         result = new trunc_inst           (ty, arg, name, next); break;
    case cast_op_t::ZExt:          result = new z_ext_inst           (ty, arg, name, next); break;
    case cast_op_t::SExt:          result = new s_ext_inst           (ty, arg, name, next); break;
    case cast_op_t::FPTrunc:       result = new fp_trunc_inst        (ty, arg, name, next); break;
    case cast_op_t::FPExt:         result = new fp_ext_inst          (ty, arg, name, next); break;
    case cast_op_t::UIToFP:        result = new ui_to_fp_inst        (ty, arg, name, next); break;
    case cast_op_t::SIToFP:        result = new si_to_fp_inst        (ty, arg, name, next); break;
    case cast_op_t::FPToUI:        result = new fp_to_ui_inst        (ty, arg, name, next); break;
    case cast_op_t::FPToSI:        result = new fp_to_si_inst        (ty, arg, name, next); break;
    case cast_op_t::PtrToInt:      result = new ptr_to_int_inst      (ty, arg, name, next); break;
    case cast_op_t::IntToPtr:      result = new int_to_ptr_inst      (ty, arg, name, next); break;
    case cast_op_t::BitCast:       result = new bit_cast_inst        (ty, arg, name, next); break;
    case cast_op_t::AddrSpaceCast: result = new addr_space_cast_inst (ty, arg, name, next); break;
    default: throw std::runtime_error("unreachable");
  }
  return result;
}
|
|
|
|
|
|
|
|
// Creates the integer cast that converts `arg` to `ty`, choosing the opcode
// from the scalar bit widths:
//   equal widths   -> BitCast
//   narrowing      -> Trunc
//   widening       -> SExt when is_signed, otherwise ZExt
// Both types must be integer or tile-of-integer types (asserted).
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
  type *src_ty = arg->get_type();
  assert(src_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
  const unsigned src_bits = src_ty->get_scalar_ty()->get_integer_bitwidth();
  const unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();

  cast_op_t op;
  if(src_bits == dst_bits)
    op = cast_op_t::BitCast;
  else if(src_bits > dst_bits)
    op = cast_op_t::Trunc;
  else if(is_signed)
    op = cast_op_t::SExt;
  else
    op = cast_op_t::ZExt;

  return create(op, arg, ty, name, next);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// terminator_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
|
|
// return_inst
|
|
|
|
// Builds a return. With a non-null `ret_val` the instruction has that
// value's type and one operand; otherwise it has void type and no operand.
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
    : terminator_inst(ret_val ? ret_val->get_type() : type::get_void_ty(ctx),
                      INST_RETURN, ret_val != nullptr, "", next)
{
  if(ret_val != nullptr) {
    set_operand(0, ret_val);
  }
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates a return instruction, forwarding all arguments to
// the constructor.
return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next){
  return_inst *ret = new return_inst(ctx, ret_val, next);
  return ret;
}
|
|
|
|
|
|
|
|
|
|
|
|
// branch_inst
|
|
|
|
// Factory for an unconditional branch to `dst`, which must be non-null
// (asserted).
branch_inst* branch_inst::create(basic_block *dst, instruction *next) {
  assert(dst && "Branch destination may not be null!");
  branch_inst *br = new uncond_branch_inst(dst, next);
  return br;
}
|
|
|
|
|
|
|
|
// Factory for a conditional branch. `cond` must have a 1-bit integer type
// (asserted).
branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *else_dst, instruction *next) {
  assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
  branch_inst *br = new cond_branch_inst(if_dst, else_dst, cond, next);
  return br;
}
|
|
|
|
|
|
|
|
// uncond_branch_inst
|
|
|
|
// Void-typed branch with a single operand: the destination block in slot 0.
uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next)
    : branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next)
{
  set_operand(0, dst);
}
|
|
|
|
|
|
|
|
// cond_branch_inst
|
|
|
|
// Void-typed branch with three operands:
//   0 = if_dst, 1 = else_dst, 2 = cond.
// `cond` must have a 1-bit integer type (asserted).
cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
    : branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next)
{
  assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
  set_operand(0, if_dst);
  set_operand(1, else_dst);
  set_operand(2, cond);
}
|
|
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// getelementptr_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Builds a GEP over `ptr` with indices `idx`. Operand 0 is the pointer and
// operands 1..idx.size() are the indices in order. The result type comes
// from get_return_type and the resulting element type from
// get_indexed_type.
getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
    : instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next),
      source_elt_ty(pointee_ty),
      res_elt_ty(get_indexed_type(pointee_ty, idx))
{
  // Sanity check: the element type of the (scalar) result pointer must be
  // the indexed element type computed above.
  type *result_scalar_ty = get_type()->get_scalar_ty();
  type *expected_elt_ty = ((pointer_type*)result_scalar_ty)->get_element_ty();
  assert(res_elt_ty == expected_elt_ty);

  // Operands: pointer first, then every index.
  set_operand(0, ptr);
  size_t slot = 1;
  for(value *index : idx) {
    set_operand(slot, index);
    ++slot;
  }
}
|
|
|
|
|
|
|
|
// Computes the result type of a GEP on `x` with indices `idx_list`.
// The scalar result is a pointer to the indexed element type in the same
// address space as x's scalar type. The result is a block of such pointers
// when either the pointer operand or any index has a block type.
type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vector<value *> &idx_list) {
  // result pointer type
  type *ty = x->get_type();
  unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space();
  type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space);
  // Tile GEP: the pointer operand itself is a block.
  if(ty->is_block_ty())
    return block_type::get_same_shapes(ptr_ty, ty);
  // Tile GEP: a scalar pointer indexed by a block. The shape must come from
  // the index type; `ty` is known not to be a block at this point, so using
  // it as the shape reference (as the previous code did) was incorrect.
  for(value *idx : idx_list) {
    type *idx_ty = idx->get_type();
    if(idx_ty->is_block_ty())
      return block_type::get_same_shapes(ptr_ty, idx_ty);
  }
  // Scalar GEP
  return ptr_ty;
}
|
|
|
|
|
|
|
|
// Walks `ty` through the indices in `idx_list`, starting from the second
// index (position 1), and returns the type reached.
// Returns `ty` unchanged for an empty index list, and nullptr when `ty` is
// not sized or when any step cannot be applied (the current type is not a
// composite type, is a pointer type, or rejects the index).
type *getelementptr_inst::get_indexed_type_impl(type *ty, const std::vector<value *> &idx_list) {
  if(idx_list.empty())
    return ty;
  if(!ty->is_sized())
    return nullptr;

  for(unsigned pos = 1; pos != idx_list.size(); ++pos){
    composite_type *agg = dynamic_cast<composite_type*>(ty);
    if(agg == nullptr || agg->is_pointer_ty())
      return nullptr;
    value *index = idx_list[pos];
    if(!agg->index_valid(index))
      return nullptr;
    ty = agg->get_type_at_index(index);
  }
  // Every index was consumed successfully.
  return ty;
}
|
|
|
|
|
|
|
|
// Asserting wrapper around get_indexed_type_impl: the indexed type must
// resolve to a non-null type.
type *getelementptr_inst::get_indexed_type(type *ty, const std::vector<value *> &idx_list) {
  type *indexed_ty = get_indexed_type_impl(ty, idx_list);
  assert(indexed_ty && "invalid GEP type!");
  return indexed_ty;
}
|
|
|
|
|
|
|
|
// Factory: takes the pointee type from the scalar type of `ptr` (cast to
// pointer_type without a check) and builds the GEP.
getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next) {
  type *scalar_ty = ptr->get_type()->get_scalar_ty();
  type *pointee_ty = ((pointer_type*)(scalar_ty))->get_element_ty();
  return new getelementptr_inst(pointee_ty, ptr, idx, name, next);
}
|
|
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// load_inst/store_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// io_inst
|
2022-06-27 11:49:19 -07:00
|
|
|
// Common base constructor for load/store instructions; records the
// eviction policy.
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
    : instruction(ty, id, num_ops, name, next),
      eviction_(eviction)
{
}
|
|
|
|
|
|
|
|
// load_inst
|
2022-02-24 14:56:24 -08:00
|
|
|
// Base constructor for loads. The result type is the pointee type of ptr's
// type; the cache modifier and volatility flag are recorded.
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
    : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next),
      cache_(cache),
      is_volatile_(is_volatile)
{
}
|
|
|
|
|
|
|
|
// load
|
|
|
|
// Returns the type produced by loading through a pointer of type `ty`:
// the scalar pointer's element type, wrapped in a block with ty's shapes
// when `ty` is a block type.
type *load_inst::get_pointee_type(type *ty) {
  type *elt_ty = ty->get_scalar_ty()->get_pointer_element_ty();
  if(!ty->is_block_ty())
    return elt_ty;
  return block_type::get_same_shapes(elt_ty, ty);
}
|
|
|
|
|
|
|
|
// unmasked_load
|
2022-02-24 14:56:24 -08:00
|
|
|
// Unmasked load with a single operand: the pointer in slot 0.
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
    : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next)
{
  set_operand(0, ptr);
}
|
|
|
|
|
2022-02-24 14:56:24 -08:00
|
|
|
// Factory: heap-allocates an unmasked load, forwarding all arguments to the
// constructor.
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
  unmasked_load_inst *load = new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
  return load;
}
|
|
|
|
|
|
|
|
// masked load
|
2022-02-24 14:56:24 -08:00
|
|
|
// Masked load with three operands:
//   0 = ptr, 1 = mask, 2 = false_value.
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
                                   bool is_volatile,
                                   const std::string &name, instruction *next)
    : load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next)
{
  set_operand(0, ptr);
  set_operand(1, mask);
  set_operand(2, false_value);
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates a masked load, forwarding all arguments to the
// constructor.
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
                                           load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
                                           bool is_volatile,
                                           const std::string &name, instruction *next) {
  masked_load_inst *load = new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
  return load;
}
|
|
|
|
|
2021-01-11 19:20:34 -05:00
|
|
|
// masked load async
|
|
|
|
// Asynchronous masked load with three operands:
//   0 = ptr, 1 = mask, 2 = false_value.
// The volatility flag passed to load_inst is always false for this variant.
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
                                               load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
                                               const std::string &name, instruction *next)
    : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next)
{
  set_operand(0, ptr);
  set_operand(1, mask);
  set_operand(2, false_value);
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates an asynchronous masked load, forwarding all
// arguments to the constructor.
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
                                                       load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
                                                       const std::string &name, instruction *next) {
  masked_load_async_inst *load = new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
  return load;
}
|
|
|
|
|
2020-11-19 18:19:55 -05:00
|
|
|
// store
|
2021-07-27 12:38:38 -07:00
|
|
|
|
2022-06-27 11:49:19 -07:00
|
|
|
// Base constructor for stores. The instruction has void type, obtained
// from the context of ptr's type.
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
    : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
{
}
|
|
|
|
|
|
|
|
// unmasked_store
|
2022-06-27 11:49:19 -07:00
|
|
|
// Unmasked store with two operands: 0 = ptr, 1 = val.
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
                                         const std::string &name, instruction *next)
    : store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next)
{
  set_operand(0, ptr);
  set_operand(1, val);
}
|
|
|
|
|
2022-06-27 11:49:19 -07:00
|
|
|
// Factory: heap-allocates an unmasked store, forwarding all arguments to
// the constructor.
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
                                                 const std::string &name, instruction *next) {
  unmasked_store_inst *store = new unmasked_store_inst(ptr, val, eviction, name, next);
  return store;
}
|
|
|
|
|
|
|
|
// masked store
|
2022-06-27 11:49:19 -07:00
|
|
|
// Masked store with three operands: 0 = ptr, 1 = val, 2 = mask.
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
                                     const std::string &name, instruction *next)
    : store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next)
{
  set_operand(0, ptr);
  set_operand(1, val);
  set_operand(2, mask);
}
|
|
|
|
|
2022-10-12 14:14:45 -07:00
|
|
|
// Factory: heap-allocates a masked store, forwarding all arguments to the
// constructor.
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
                                             const std::string &name, instruction *next) {
  masked_store_inst *store = new masked_store_inst(ptr, val, mask, eviction, name, next);
  return store;
}
|
2022-04-03 20:58:16 -07:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// struct classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// insert value
|
|
|
|
|
|
|
|
// Insert-value instruction. The result has the type of `val`; operand 0 is
// `val`, operand 1 is `elt`, and `idx` is recorded in idx_.
insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next)
    : instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next),
      idx_(idx)
{
  set_operand(0, val);
  set_operand(1, elt);
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates an insert-value instruction, forwarding all
// arguments to the constructor.
insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){
  insert_value_inst *inst = new insert_value_inst(val, elt, idx, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
|
|
|
|
// extract value
|
|
|
|
|
|
|
|
// Extract-value instruction. The result type is
// val->get_type()->get_struct_type(idx); operand 0 is `val`, and `idx` is
// recorded in idx_.
extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next)
    : instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next),
      idx_(idx)
{
  set_operand(0, val);
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates an extract-value instruction, forwarding all
// arguments to the constructor.
extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){
  extract_value_inst *inst = new extract_value_inst(val, idx, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// retile_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-10-24 02:30:46 -07:00
|
|
|
// cat
|
|
|
|
|
|
|
|
// Concatenation of `x` and `y`. The result is a one-dimensional block of
// x's scalar type whose extent is the sum of the first block-shape entries
// of `x` and `y`. Operand 0 is `x`, operand 1 is `y`.
cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
    : instruction(
          block_type::get(
              x->get_type()->get_scalar_ty(),
              {x->get_type()->get_block_shapes()[0] + y->get_type()->get_block_shapes()[0]}),
          INST_CAT, 2, name, next)
{
  set_operand(0, x);
  set_operand(1, y);
}
|
|
|
|
|
|
|
|
// Factory: heap-allocates a cat instruction for `lhs` and `rhs`.
instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
  instruction *inst = new cat_inst(lhs, rhs, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
// retile
|
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
// Base constructor for shape-changing unary instructions. The result is a
// block of arg's scalar type with the requested `shapes`.
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
                         const std::string &name, instruction *next)
    : unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next)
{
}
|
2021-07-27 12:38:38 -07:00
|
|
|
|
|
|
|
|
2021-10-24 02:30:46 -07:00
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
// reshape
|
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
// Factory: heap-allocates a reshape of `arg` to `shapes` with the
// INST_RESHAPE id.
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
                                  const std::string &name, instruction *next) {
  instruction *inst = new reshape_inst(arg, INST_RESHAPE, shapes, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
|
|
|
|
// splat
|
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
// Factory: allocates a splat_inst tagged INST_SPLAT for `arg` with the
// requested `shapes` and returns it as an instruction pointer.
instruction* splat_inst::create(value *arg, const type::block_shapes_t &shapes,
                                const std::string &name, instruction *next) {
  splat_inst *inst = new splat_inst(arg, INST_SPLAT, shapes, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
// broadcast
|
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
// Factory: allocates a broadcast_inst tagged INST_BROADCAST for `arg` with
// the requested `shapes` and returns it as an instruction pointer.
instruction* broadcast_inst::create(value *arg, const type::block_shapes_t &shapes,
                                    const std::string &name, instruction *next) {
  broadcast_inst *inst = new broadcast_inst(arg, INST_BROADCAST, shapes, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
// downcast
|
|
|
|
|
|
|
|
// Factory: allocates a downcast_inst tagged INST_DOWNCAST.
// The result type passed to the constructor is the scalar type of `arg`.
instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
  type *scalar_ty = arg->get_type()->get_scalar_ty();
  return new downcast_inst(scalar_ty, INST_DOWNCAST, arg, name, next);
}
|
|
|
|
|
2022-04-03 20:58:16 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// matmul_inst classes
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-01-12 02:20:31 +08:00
|
|
|
// Constructs a dot instruction with three operands (A, B, C).
// The result type is the type of `C`. The transposition flags and the
// allow_tf32 setting are stored on the instruction.
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
                   const std::string &name, instruction *next)
  : builtin_inst(C->get_type(), INST_DOT, 3, name, next),
    AT_(AT),
    BT_(BT) {
  allow_tf32_ = allow_tf32;
  // operands: 0 = A, 1 = B, 2 = C
  set_operand(0, A);
  set_operand(1, B);
  set_operand(2, C);
}
|
|
|
|
|
|
|
|
// Factory taking the transposition of A and B as booleans.
// `true` maps to Trans and `false` to NoTrans before the dot_inst is built.
instruction *dot_inst::create(value *A, value *B, value *C,
                              bool AT, bool BT, bool allow_tf32,
                              const std::string &name, instruction *next) {
  TransT trans_a = NoTrans;
  if(AT)
    trans_a = Trans;
  TransT trans_b = NoTrans;
  if(BT)
    trans_b = Trans;
  return new dot_inst(A, B, C, trans_a, trans_b, allow_tf32, name, next);
}
|
|
|
|
|
2022-01-12 02:20:31 +08:00
|
|
|
// Factory for a dot with neither A nor B flagged as transposed
// (NoTrans, NoTrans).
instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
  dot_inst *inst = new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
  return inst;
}
|
|
|
|
|
2022-01-12 02:20:31 +08:00
|
|
|
// Factory for a dot with only B flagged as transposed (NoTrans, Trans).
instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
  dot_inst *inst = new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
  return inst;
}
|
|
|
|
|
2022-01-12 02:20:31 +08:00
|
|
|
// Factory for a dot with only A flagged as transposed (Trans, NoTrans).
instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
  dot_inst *inst = new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
  return inst;
}
|
|
|
|
|
2022-01-12 02:20:31 +08:00
|
|
|
// Factory for a dot with both A and B flagged as transposed (Trans, Trans).
instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
  dot_inst *inst = new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// trans instructions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Computes the result type of permuting the dimensions of block type `ty`.
//
// `perm` is first normalised through init_perm(), which substitutes a default
// permutation when `perm` is empty. Result dimension i then takes the extent
// of argument dimension perm[i]; dimensions beyond perm.size() keep the
// argument's extent. The scalar type is unchanged.
//
// The permutation is validated with asserts: it may not have more entries
// than the block has dimensions, and every entry must index an existing
// dimension. Without these checks a bad permutation indexes past the end of
// the shape vectors.
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<int> perm) {
  // get argument shapes
  const ir::block_type::block_shapes_t arg_shapes = ty->get_block_shapes();
  const size_t rank = arg_shapes.size();
  // normalise the permutation (fills in the default when empty)
  perm = init_perm(ty, perm);
  assert(perm.size() <= rank && "permutation has more entries than the block has dimensions");
  // permute argument shapes
  ir::block_type::block_shapes_t res_shapes = arg_shapes;
  for(size_t i = 0; i < perm.size(); i++) {
    assert(perm[i] >= 0 && static_cast<size_t>(perm[i]) < rank && "permutation entry is out of range");
    res_shapes[i] = arg_shapes[perm[i]];
  }
  // construct type
  return block_type::get(ty->get_scalar_ty(), res_shapes);
}
|
|
|
|
|
|
|
|
// Returns the permutation to use for a transposition of block type `ty`.
//
// A non-empty `perm` is returned unchanged. Otherwise the default is built
// from the rank of `ty`: the last dimension first, followed by the remaining
// dimensions in their original order, i.e. [rank-1, 0, 1, ..., rank-2].
//
// A rank of zero yields an empty permutation. The rank is unsigned, so
// computing `rank - 1` in that case would wrap around and make the loop bound
// enormous; the guard below avoids that.
std::vector<int> trans_inst::init_perm(ir::type* ty, const std::vector<int>& perm) {
  if(!perm.empty())
    return perm;
  const size_t rank = ty->get_block_shapes().size();
  std::vector<int> result;
  if(rank == 0)
    return result;
  result.reserve(rank);
  result.push_back(static_cast<int>(rank - 1));
  for(size_t i = 0; i + 1 < rank; i++)
    result.push_back(static_cast<int>(i));
  return result;
}
|
|
|
|
|
|
|
|
// Constructs a transposition instruction with one operand.
// The result type comes from get_res_ty() applied to the type of `arg` and
// `perm`; the normalised permutation (see init_perm) is stored in perm_.
trans_inst::trans_inst(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next)
  : builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
  ir::type *arg_ty = arg->get_type();
  perm_ = init_perm(arg_ty, perm);
  // NOTE: a check that perm_.size() equals the operand's rank used to live
  // here but is currently disabled.
  set_operand(0, arg);
}
|
|
|
|
|
|
|
|
// Factory: allocates a trans_inst for `arg` with permutation `perm` and
// returns it as an instruction pointer.
instruction* trans_inst::create(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next) {
  trans_inst *inst = new trans_inst(arg, perm, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
const std::vector<int> trans_inst::get_perm() const {
|
|
|
|
return perm_;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// sqrt instructions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Constructs a sqrt instruction with one operand; the result type is the
// type of `arg`.
sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
  : builtin_inst(arg->get_type(),
                 INST_SQRT,
                 1,
                 name,
                 next) {
  set_operand(0, arg);
}
|
|
|
|
|
|
|
|
// Factory: allocates a sqrt_inst for `arg` and returns it as an instruction
// pointer.
instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) {
  sqrt_inst *inst = new sqrt_inst(arg, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// reduce instructions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Maps a reduction opcode to its textual spelling.
// ADD/FADD give "+", SUB/FSUB give "-", MAX/MIN give "imax"/"imin" and
// FMAX/FMIN give "fmax"/"fmin". Any other opcode trips an assert and, when
// asserts are compiled out, yields an empty string.
std::string reduce_inst::to_str(op_t op) {
  switch (op) {
    case ADD:
    case FADD:
      return "+";
    case SUB:
    case FSUB:
      return "-";
    case MAX:
      return "imax";
    case MIN:
      return "imin";
    case FMAX:
      return "fmax";
    case FMIN:
      return "fmin";
    default:
      break;
  }
  // unknown opcode
  assert(false);
  return "";
}
|
|
|
|
|
|
|
|
// Computes the result type of reducing `arg` along `axis`.
//
// The dimension at index `axis` is removed from the block shape of `arg`.
// If no dimensions remain, the scalar type of `arg` is returned; otherwise a
// block type with the same scalar type and the remaining shape is returned.
//
// `axis` must index an existing dimension; this is asserted because erasing
// past the end of the shape vector is undefined behaviour.
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
  type *arg_ty = arg->get_type();
  ir::block_type::block_shapes_t shapes = arg_ty->get_block_shapes();
  assert(axis < shapes.size() && "reduction axis is out of range for the operand's block shape");
  shapes.erase(shapes.begin() + axis);
  type *scalar_ty = arg_ty->get_scalar_ty();
  // reducing the only dimension produces a scalar rather than a block of shape {1}
  if(shapes.empty())
    return scalar_ty;
  return block_type::get(scalar_ty, shapes);
}
|
|
|
|
|
|
|
|
// Constructs a reduction instruction with one operand.
// The result type is computed by get_res_type(arg, axis); the opcode and the
// axis are stored on the instruction.
reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
  : builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next),
    op_(op), axis_(axis) {
  // operand 0: the value being reduced
  set_operand(0, arg);
}
|
|
|
|
|
|
|
|
// Factory: allocates a reduce_inst applying `op` to `arg` along `axis` and
// returns it as an instruction pointer.
instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
  reduce_inst *inst = new reduce_inst(arg, op, axis, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// select instructions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Constructs a select instruction with three operands.
// The result type is the type of `if_value`.
select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
  : builtin_inst(if_value->get_type(),
                 INST_SELECT,
                 3,
                 name,
                 next) {
  // operands: 0 = predicate, 1 = if_value, 2 = else_value
  set_operand(0, pred);
  set_operand(1, if_value);
  set_operand(2, else_value);
}
|
|
|
|
|
|
|
|
// Factory: allocates a select_inst over (pred, if_value, else_value) and
// returns it as an instruction pointer.
instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) {
  select_inst *inst = new select_inst(pred, if_value, else_value, name, next);
  return inst;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// builtin instructions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
|
|
// get_program_id
|
|
|
|
// Constructs a get_program_id instruction of type `ty` with no operands;
// the queried axis is stored in axis_.
get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
  : builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next),
    axis_(axis) {
}
|
|
|
|
|
|
|
|
// Factory: allocates a get_program_id_inst for `axis`. The instruction's
// type is the int32 type obtained from `ctx`.
instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
  type *i32_ty = type::get_int32_ty(ctx);
  return new get_program_id_inst(i32_ty, axis, name, next);
}
|
|
|
|
|
|
|
|
// get_num_programs
|
2021-04-20 22:29:40 -04:00
|
|
|
// Constructs a get_num_programs instruction of type `ty` with no operands;
// the queried axis is stored in axis_.
get_num_programs_inst::get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
  : builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next),
    axis_(axis) {
}
|
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
// Factory: allocates a get_num_programs_inst for `axis`. The instruction's
// type is the int32 type obtained from `ctx`.
instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
  type *i32_ty = type::get_int32_ty(ctx);
  return new get_num_programs_inst(i32_ty, axis, name, next);
}
|
|
|
|
|
2021-05-25 18:31:48 -04:00
|
|
|
// atomic_rmw
|
2021-07-27 12:38:38 -07:00
|
|
|
|
2021-05-25 18:31:48 -04:00
|
|
|
// Constructs an atomic read-modify-write instruction with three operands.
// The result type is the pointer element type of `ptr`'s type; the RMW
// opcode is stored in op_.
atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
  : atomic_inst(ptr->get_type()->get_pointer_element_ty(),
                INST_ATOMIC_RMW,
                3,
                name,
                next),
    op_(op) {
  // operands: 0 = pointer, 1 = value, 2 = mask
  set_operand(0, ptr);
  set_operand(1, val);
  set_operand(2, msk);
}
|
|
|
|
|
2021-05-25 18:31:48 -04:00
|
|
|
// Factory: allocates an atomic_rmw_inst for (op, ptr, val, msk) and returns
// it as an instruction pointer.
instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
  atomic_rmw_inst *inst = new atomic_rmw_inst(op, ptr, val, msk, name, next);
  return inst;
}
|
2021-05-25 18:31:48 -04:00
|
|
|
|
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
// atomic cas
|
|
|
|
|
|
|
|
// Constructs an atomic compare-and-swap instruction with three operands.
// The result type is the pointer element type of `ptr`'s type.
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
  : atomic_inst(ptr->get_type()->get_pointer_element_ty(),
                INST_ATOMIC_CAS,
                3,
                name,
                next) {
  // operands: 0 = pointer, 1 = comparand, 2 = new value
  set_operand(0, ptr);
  set_operand(1, cmp);
  set_operand(2, val);
}
|
|
|
|
|
|
|
|
// Factory: allocates an atomic_cas_inst for (ptr, cmp, val) and returns it
// as an instruction pointer.
instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
  atomic_cas_inst *inst = new atomic_cas_inst(ptr, cmp, val, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
|
2021-10-24 02:30:46 -07:00
|
|
|
// umulhi
|
|
|
|
|
|
|
|
// Constructs a umulhi instruction with two operands; the result type is the
// type of `lhs`.
umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
  : builtin_inst(lhs->get_type(),
                 INST_UMULHI,
                 2,
                 name,
                 next) {
  set_operand(0, lhs);
  set_operand(1, rhs);
}
|
|
|
|
|
|
|
|
// Factory: allocates a umulhi_inst over (lhs, rhs) and returns it as an
// instruction pointer.
instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
  umulhi_inst *inst = new umulhi_inst(lhs, rhs, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
|
2020-03-31 18:55:31 -04:00
|
|
|
// exp
|
|
|
|
|
|
|
|
// Constructs an exp instruction with one operand; the result type is the
// type of `val`.
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
  : builtin_inst(val->get_type(),
                 INST_EXP,
                 1,
                 name,
                 next) {
  set_operand(0, val);
}
|
|
|
|
|
|
|
|
// Factory: allocates an exp_inst for `val` and returns it as an instruction
// pointer.
instruction* exp_inst::create(value *val, const std::string& name, instruction *next) {
  exp_inst *inst = new exp_inst(val, name, next);
  return inst;
}
|
|
|
|
|
2021-07-14 17:16:48 -07:00
|
|
|
// cos
|
|
|
|
// Constructs a cos instruction with one operand; the result type is the
// type of `val`.
cos_inst::cos_inst(value *val, const std::string &name, instruction *next)
  : builtin_inst(val->get_type(),
                 INST_COS,
                 1,
                 name,
                 next) {
  set_operand(0, val);
}
|
|
|
|
|
|
|
|
// Factory: allocates a cos_inst for `val` and returns it as an instruction
// pointer.
instruction* cos_inst::create(value *val, const std::string& name, instruction *next) {
  cos_inst *inst = new cos_inst(val, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
// sin
|
|
|
|
// Constructs a sin instruction with one operand; the result type is the
// type of `val`.
sin_inst::sin_inst(value *val, const std::string &name, instruction *next)
  : builtin_inst(val->get_type(),
                 INST_SIN,
                 1,
                 name,
                 next) {
  set_operand(0, val);
}
|
|
|
|
|
|
|
|
// Factory: allocates a sin_inst for `val` and returns it as an instruction
// pointer.
instruction* sin_inst::create(value *val, const std::string& name, instruction *next) {
  sin_inst *inst = new sin_inst(val, name, next);
  return inst;
}
|
|
|
|
|
|
|
|
|
2020-11-03 15:50:11 -05:00
|
|
|
// log
|
|
|
|
|
|
|
|
// Constructs a log instruction with one operand; the result type is the
// type of `val`.
log_inst::log_inst(value *val, const std::string &name, instruction *next)
  : builtin_inst(val->get_type(),
                 INST_LOG,
                 1,
                 name,
                 next) {
  set_operand(0, val);
}
|
|
|
|
|
|
|
|
// Factory: allocates a log_inst for `val` and returns it as an instruction
// pointer.
instruction* log_inst::create(value *val, const std::string& name, instruction *next) {
  log_inst *inst = new log_inst(val, name, next);
  return inst;
}
|
|
|
|
|
2020-03-31 18:55:31 -04:00
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// intrinsic instructions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-08-30 11:50:35 -07:00
|
|
|
// cvt_layout
|
|
|
|
// Factory: allocates a cvt_layout_inst tagged INST_CVT_LAYOUT whose result
// type is the type of `arg`.
cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) {
  type *arg_ty = arg->get_type();
  return new cvt_layout_inst(arg_ty, INST_CVT_LAYOUT, arg, name, next);
}
|
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
// copy to shared
|
|
|
|
// Factory: allocates a copy_to_shared_inst tagged INST_COPY_TO_SHARED whose
// result type is the type of `arg`.
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
                                                 instruction *next) {
  type *arg_ty = arg->get_type();
  return new copy_to_shared_inst(arg_ty, INST_COPY_TO_SHARED, arg, name, next);
}
|
|
|
|
|
|
|
|
// copy from shared
|
|
|
|
// Factory: allocates a copy_from_shared_inst tagged INST_COPY_FROM_SHARED
// whose result type is the type of `arg`.
copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name,
                                                     instruction *next) {
  type *arg_ty = arg->get_type();
  return new copy_from_shared_inst(arg_ty, INST_COPY_FROM_SHARED, arg, name, next);
}
|
|
|
|
|
|
|
|
// barrier
|
2022-03-28 16:15:43 -07:00
|
|
|
// Constructs a barrier instruction: void-typed (type taken from `ctx`) and
// without operands.
barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
  : instruction(type::get_void_ty(ctx),
                INST_BARRIER,
                0,
                name,
                next) {
}
|
|
|
|
|
|
|
|
// Factory: allocates a barrier_inst with `new` and returns the pointer.
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
  barrier_inst *inst = new barrier_inst(ctx, name, next);
  return inst;
}
|
|
|
|
|
2021-02-21 15:19:39 -08:00
|
|
|
// Constructs an async_wait instruction: void-typed (type taken from `ctx`),
// without operands; `N` is stored in N_.
async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
  : instruction(type::get_void_ty(ctx),
                INST_ASYNC_WAIT,
                0,
                name,
                next),
    N_(N) {
}
|
2021-01-11 19:20:34 -05:00
|
|
|
|
2021-02-21 15:19:39 -08:00
|
|
|
// Factory: allocates an async_wait_inst carrying `N` and returns the pointer.
async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
  async_wait_inst *inst = new async_wait_inst(ctx, N, name, next);
  return inst;
}
|
|
|
|
|
2021-05-13 10:42:18 +08:00
|
|
|
// prefetch_s
|
|
|
|
// Factory: allocates a prefetch_s_inst for `arg` with increment `inc` and
// returns the pointer. All arguments are forwarded to the constructor.
prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) {
  prefetch_s_inst *inst = new prefetch_s_inst(ctx, arg, inc, name, next);
  return inst;
}
|
2021-07-27 12:38:38 -07:00
|
|
|
|
2022-03-28 16:15:43 -07:00
|
|
|
// global timer
|
|
|
|
// Constructs a globaltimer instruction: int64-typed (type taken from `ctx`)
// and without operands.
globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
  : instruction(type::get_int64_ty(ctx),
                INST_GLOBALTIMER,
                0,
                name,
                next) {
}
|
2021-07-27 12:38:38 -07:00
|
|
|
|
2022-03-28 16:15:43 -07:00
|
|
|
// Factory: allocates a globaltimer_inst with `new` and returns the pointer.
globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
  globaltimer_inst *inst = new globaltimer_inst(ctx, name, next);
  return inst;
}
|
2021-07-27 12:38:38 -07:00
|
|
|
|
2022-07-13 15:52:21 -07:00
|
|
|
// extern elementwise
|
|
|
|
// Constructs an extern elementwise instruction of type `ret_ty`.
// The operand count equals args.size() and operand i is args[i]. The library
// name, library path and symbol name are stored on the instruction.
// `ctx` is accepted but not used by this constructor.
extern_elementwise_inst::extern_elementwise_inst(
    context &ctx, const std::vector<value *> &args, type *ret_ty,
    const std::string &lib_name, const std::string &lib_path,
    const std::string &symbol_name, const std::string &name, instruction *next)
    : instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), name, next),
      lib_name_(lib_name), lib_path_(lib_path), symbol_name_(symbol_name) {
  // register every argument as an operand, preserving order
  size_t slot = 0;
  for (value *arg : args) {
    set_operand(slot, arg);
    ++slot;
  }
}
|
|
|
|
|
|
|
|
// Factory: allocates an extern_elementwise_inst with `new` and returns the
// pointer. All arguments are forwarded unchanged to the constructor.
extern_elementwise_inst *extern_elementwise_inst::create(
    context &ctx, const std::vector<value *> &args, type *ret_ty,
    const std::string &lib_name, const std::string &lib_path,
    const std::string &symbol_name, const std::string &name,
    instruction *next) {
  extern_elementwise_inst *inst = new extern_elementwise_inst(
      ctx, args, ret_ty, lib_name, lib_path, symbol_name, name, next);
  return inst;
}
|
|
|
|
|
2022-03-28 16:15:43 -07:00
|
|
|
// clock
|
|
|
|
// Constructs a clock instruction: int64-typed (type taken from `ctx`) and
// without operands.
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
  : instruction(type::get_int64_ty(ctx),
                INST_CLOCK,
                0,
                name,
                next) {
}
|
2021-07-27 12:38:38 -07:00
|
|
|
|
2022-03-28 16:15:43 -07:00
|
|
|
// Factory: allocates a clock_inst with `new` and returns the pointer.
clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
  clock_inst *inst = new clock_inst(ctx, name, next);
  return inst;
}
|
2021-07-27 12:38:38 -07:00
|
|
|
|
|
|
|
|
|
|
|
// make_range
|
|
|
|
// Constructs a make_range instruction of type `ty` with no operands.
// The two bounds are kept as members (first_, last_) rather than operands.
make_range::make_range(type *ty, constant_int *first, constant_int *last)
  : instruction(ty, INST_MAKE_RANGE, 0),
    first_(first),
    last_(last) {
}
|
|
|
|
|
|
|
|
// Factory for make_range.
// Asserts that `first` has an integer type and that both bounds share the
// same type. The result is a one-dimensional block of that type whose extent
// is last - first, each bound being truncated to unsigned before the
// subtraction.
make_range *make_range::create(constant_int *first, constant_int *last) {
  type *elt_ty = first->get_type();
  assert(elt_ty->is_integer_ty());
  assert(elt_ty == last->get_type());
  // `first` is not required to be zero (that check is intentionally disabled)
  const unsigned lo = static_cast<unsigned>(first->get_value());
  const unsigned hi = static_cast<unsigned>(last->get_value());
  type *ty = block_type::get(elt_ty, {hi - lo});
  return new make_range(ty, first, last);
}
|
|
|
|
|
|
|
|
// Returns the `first` bound this range was constructed with.
const constant_int* make_range::get_first() const {
  return this->first_;
}
|
|
|
|
|
|
|
|
// Returns the `last` bound this range was constructed with.
const constant_int* make_range::get_last() const {
  return this->last_;
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|