some reassociation
This commit is contained in:
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
NumericT alpha = 1;
|
||||
NumericT beta = 0;
|
||||
@@ -98,8 +98,8 @@ int main() {
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, true, 8192, 8192, 8192},
|
||||
{false, true, 32768, 256, 512}
|
||||
{false, true, 8192, 8192, 8192}
|
||||
// {false, true, 32768, 256, 512}
|
||||
// {true, false, 8192, 512, 512},
|
||||
// {true, true, 8192, 512, 512}
|
||||
};
|
||||
|
35
include/triton/codegen/reassociate.h
Normal file
35
include/triton/codegen/reassociate.h
Normal file
@@ -0,0 +1,35 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
|
||||
// forward declaration
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class builder;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
class reassociate {
|
||||
private:
|
||||
ir::instruction* is_bin_add(ir::value *x);
|
||||
ir::value *reorder_op(ir::value *value, ir::builder &builder, std::vector<ir::instruction*>& to_delete, ir::value *&noncst, ir::value *&cst);
|
||||
|
||||
public:
|
||||
reassociate();
|
||||
void run(ir::module& module);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -99,10 +99,10 @@ inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
|
||||
|
||||
// simple parameter heuristics
|
||||
inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
|
||||
// size_t TM = 128;
|
||||
// size_t TN = 128;
|
||||
return {4, 8, 256, 8, 8, 64, 2, 2, 2, 2, 32, 32, 16, 1};
|
||||
// return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
size_t TM = 128;
|
||||
size_t TN = 128;
|
||||
// return {4, 8, 256, 8, 8, 64, 2, 2, 2, 2, 32, 32, 16, 1};
|
||||
return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -57,8 +57,8 @@ public:
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Tile-level control flow
|
||||
value *create_mask(value *pred, const std::string &name = "");
|
||||
value *create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name = "");
|
||||
// value *create_mask(value *pred, const std::string &name = "");
|
||||
// value *create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name = "");
|
||||
// Cast instructions
|
||||
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
|
@@ -20,10 +20,10 @@ class context;
|
||||
class result_reference;
|
||||
class instruction: public user{
|
||||
public:
|
||||
struct mask_info_t {
|
||||
value *pred;
|
||||
value *else_value;
|
||||
};
|
||||
// struct mask_info_t {
|
||||
// value *pred;
|
||||
// value *else_value;
|
||||
// };
|
||||
|
||||
virtual std::string repr_impl() const = 0;
|
||||
|
||||
@@ -37,11 +37,11 @@ public:
|
||||
const basic_block *get_parent() const { return parent_; }
|
||||
basic_block *get_parent() { return parent_; }
|
||||
void erase_from_parent();
|
||||
// mask
|
||||
void set_mask_pred(value *pred) { resize_hidden(1); set_operand(get_num_operands(), pred); }
|
||||
value* get_mask_pred() const { if(get_num_hidden() == 0) return nullptr; return get_operand(get_num_operands()); }
|
||||
void set_mask_else(value *x) { resize_hidden(2); set_operand(get_num_operands() + 1, x); }
|
||||
value* get_mask_else() const { if(get_num_hidden() < 2) return nullptr; return get_operand(get_num_operands() + 1); }
|
||||
// // mask
|
||||
// void set_mask_pred(value *pred) { resize_hidden(1); set_operand(get_num_operands(), pred); }
|
||||
// value* get_mask_pred() const { if(get_num_hidden() == 0) return nullptr; return get_operand(get_num_operands()); }
|
||||
// void set_mask_else(value *x) { resize_hidden(2); set_operand(get_num_operands() + 1, x); }
|
||||
// value* get_mask_else() const { if(get_num_hidden() < 2) return nullptr; return get_operand(get_num_operands() + 1); }
|
||||
// helpers
|
||||
bool has_tile_result_or_op();
|
||||
// repr
|
||||
@@ -55,8 +55,8 @@ public:
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
private:
|
||||
basic_block *parent_;
|
||||
value *pred_;
|
||||
value *mask_pred_;
|
||||
// value *pred_;
|
||||
// value *mask_pred_;
|
||||
std::vector<value*> results_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
};
|
||||
@@ -335,34 +335,34 @@ public:
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
// mask
|
||||
class mask_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "mask"; }
|
||||
mask_inst(ir::value *pred, const std::string &name, instruction *next);
|
||||
//// mask
|
||||
//class mask_inst: public instruction {
|
||||
//private:
|
||||
// std::string repr_impl() const { return "mask"; }
|
||||
// mask_inst(ir::value *pred, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static mask_inst* create(ir::value *pred, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
//public:
|
||||
// static mask_inst* create(ir::value *pred, const std::string &name = "", instruction *next = nullptr);
|
||||
//};
|
||||
|
||||
// merge
|
||||
class psi_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "merge"; }
|
||||
psi_inst(ir::value *mask_true, ir::value *value_true,
|
||||
ir::value *mask_false, ir::value *value_false,
|
||||
const std::string &name, instruction *next);
|
||||
//// merge
|
||||
//class psi_inst: public instruction {
|
||||
//private:
|
||||
// std::string repr_impl() const { return "merge"; }
|
||||
// psi_inst(ir::value *mask_true, ir::value *value_true,
|
||||
// ir::value *mask_false, ir::value *value_false,
|
||||
// const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static psi_inst* create(ir::value *mask_true, ir::value *value_true,
|
||||
ir::value *mask_false, ir::value *value_false,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
ir::value *get_mask_true() { return get_operand(0); }
|
||||
ir::value *get_value_true() { return get_operand(1); }
|
||||
ir::value *get_mask_false() { return get_operand(2); }
|
||||
ir::value *get_value_false() { return get_operand(3); }
|
||||
//public:
|
||||
// static psi_inst* create(ir::value *mask_true, ir::value *value_true,
|
||||
// ir::value *mask_false, ir::value *value_false,
|
||||
// const std::string &name = "", instruction *next = nullptr);
|
||||
// ir::value *get_mask_true() { return get_operand(0); }
|
||||
// ir::value *get_value_true() { return get_operand(1); }
|
||||
// ir::value *get_mask_false() { return get_operand(2); }
|
||||
// ir::value *get_value_false() { return get_operand(3); }
|
||||
|
||||
};
|
||||
//};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
@@ -408,9 +408,14 @@ private:
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
value *get_mask() const;
|
||||
value *set_mask(value *mask);
|
||||
// factory method
|
||||
static load_inst* create(value *ptr, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
|
||||
private:
|
||||
value *mask_;
|
||||
};
|
||||
|
||||
class store_inst: public instruction{
|
||||
@@ -421,9 +426,14 @@ private:
|
||||
public:
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
value *get_mask() const;
|
||||
value *set_mask(value *mask);
|
||||
// factory method
|
||||
static store_inst* create(value* ptr, value *v, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
|
||||
private:
|
||||
ir::value *mask_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -18,6 +18,7 @@
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/shmem_barriers.h"
|
||||
#include "triton/codegen/alignment_info.h"
|
||||
#include "triton/codegen/reassociate.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/runtime/launch_info.h"
|
||||
@@ -70,10 +71,12 @@ public:
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
ir::print(module, std::cout);
|
||||
reassociate_.run(module);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
ir::print(module, std::cout);
|
||||
alignment_info.run(module);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
@@ -95,6 +98,7 @@ public:
|
||||
codegen::optimize_cse optimize_cse;
|
||||
codegen::optimize_trans optimize_trans;
|
||||
codegen::alignment_info alignment_info;
|
||||
codegen::reassociate reassociate_;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
||||
|
@@ -72,11 +72,11 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
|
||||
cst_info rhs = populate_is_constant(rhs_op);
|
||||
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
cst_info value_true = populate_is_constant(x->get_value_true());
|
||||
cst_info value_false = populate_is_constant(x->get_value_false());
|
||||
return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
|
||||
}
|
||||
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
// cst_info value_true = populate_is_constant(x->get_value_true());
|
||||
// cst_info value_false = populate_is_constant(x->get_value_false());
|
||||
// return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
|
||||
// }
|
||||
if(v->get_type()->is_tile_ty())
|
||||
return cache({0, 0});
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
@@ -144,11 +144,11 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst));
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
int value_true = populate_max_contiguous(x->get_value_true());
|
||||
int value_false = populate_max_contiguous(x->get_value_false());
|
||||
return cache(std::min(value_true, value_false));
|
||||
}
|
||||
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
// int value_true = populate_max_contiguous(x->get_value_true());
|
||||
// int value_false = populate_max_contiguous(x->get_value_false());
|
||||
// return cache(std::min(value_true, value_false));
|
||||
// }
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
@@ -240,11 +240,11 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(v)){
|
||||
return cache(v->get_type()->get_tile_shapes()[0]->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
int value_true = populate_starting_multiple(x->get_value_true());
|
||||
int value_false = populate_starting_multiple(x->get_value_false());
|
||||
return cache(gcd(value_true, value_false));
|
||||
}
|
||||
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
// int value_true = populate_starting_multiple(x->get_value_true());
|
||||
// int value_false = populate_starting_multiple(x->get_value_false());
|
||||
// return cache(gcd(value_true, value_false));
|
||||
// }
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
|
185
lib/codegen/reassociate.cpp
Normal file
185
lib/codegen/reassociate.cpp
Normal file
@@ -0,0 +1,185 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/reassociate.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
//inline Constant *get_gep_cst_offset(GetElementPtrInst *gep){
|
||||
// std::vector<Value*> idx_vals;
|
||||
// std::transform(gep->idx_begin(), gep->idx_end(),
|
||||
// std::back_inserter(idx_vals),
|
||||
// [](Value* x){ return x;});
|
||||
// if(idx_vals.size() > 1)
|
||||
// return nullptr;
|
||||
// Value *idx = idx_vals[0];
|
||||
// if(isa<Constant>(idx))
|
||||
// return idx;
|
||||
// if(Instruction *BinOp = is_bin_add(idx)){
|
||||
// Value *LHS = BinOp->getOperand(0);
|
||||
// Value *RHS = BinOp->getOperand(1);
|
||||
// if(Constant* Res = dyn_cast<Constant>(LHS))
|
||||
// return Res;
|
||||
// if(Constant* Res = dyn_cast<Constant>(RHS))
|
||||
// return Res;
|
||||
// }
|
||||
// return nullptr;
|
||||
//}
|
||||
|
||||
|
||||
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
|
||||
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
|
||||
bool is_bin_add = bin_op && bin_op->get_op()==llvm::Instruction::Add;
|
||||
if(is_bin_add)
|
||||
return (ir::instruction*)x;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline bool is_cst(ir::value *x) {
|
||||
if(dynamic_cast<ir::constant*>(x))
|
||||
return true;
|
||||
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
|
||||
return is_cst(v->get_operand(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
inline ir::value *reassociate::reorder_op(ir::value *old_value,
|
||||
ir::builder &builder,
|
||||
std::vector<ir::instruction*>& to_delete,
|
||||
ir::value *&noncst,
|
||||
ir::value *&cst){
|
||||
// value doesn't change by default
|
||||
ir::value* new_value = old_value;
|
||||
cst = nullptr;
|
||||
noncst = old_value;
|
||||
|
||||
// handle retiling
|
||||
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
|
||||
auto shapes = op->get_type()->get_tile_shapes();
|
||||
ir::value *old_arg = op->get_operand(0);
|
||||
ir::value *new_arg = reorder_op(old_arg, builder, to_delete, noncst, cst);
|
||||
// retile(x + y) = retile(x) + retile(y)
|
||||
if(ir::instruction* bin_add = is_bin_add(new_arg))
|
||||
if(cst){
|
||||
ir::value *old_lhs = bin_add->get_operand(0);
|
||||
ir::value *old_rhs = bin_add->get_operand(1);
|
||||
if(dynamic_cast<ir::reshape_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
ir::value *new_lhs = builder.create_reshape(old_lhs, shapes);
|
||||
ir::value *new_rhs = builder.create_reshape(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
}
|
||||
if(dynamic_cast<ir::broadcast_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
ir::value *new_lhs = builder.create_broadcast(old_lhs, shapes);
|
||||
ir::value *new_rhs = builder.create_broadcast(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
}
|
||||
if(dynamic_cast<ir::splat_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
ir::value *new_lhs = builder.create_splat(old_lhs, shapes);
|
||||
ir::value *new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
}
|
||||
to_delete.push_back(op);
|
||||
}
|
||||
}
|
||||
|
||||
// handle binary addition
|
||||
if(ir::instruction* op = is_bin_add(old_value)){
|
||||
builder.set_insert_point(op);
|
||||
std::string name = op->get_name();
|
||||
ir::value *lhs = reorder_op(op->get_operand (0), builder, to_delete, noncst, cst);
|
||||
ir::value *rhs = reorder_op(op->get_operand(1), builder, to_delete, noncst, cst);
|
||||
builder.set_insert_point(op);
|
||||
// (x + y) + z
|
||||
if(ir::instruction* bin_lhs = is_bin_add(lhs)){
|
||||
ir::value *llhs = bin_lhs->get_operand(0);
|
||||
ir::value *rlhs = bin_lhs->get_operand(1);
|
||||
// (cst + x) + y -> cst + (x + y)
|
||||
if(is_cst(llhs))
|
||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs), name);
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(is_cst(rlhs))
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
|
||||
if(new_value != op)
|
||||
to_delete.push_back(bin_lhs);
|
||||
}
|
||||
// x + (y + z)
|
||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
||||
ir::value *lrhs = bin_rhs->get_operand(0);
|
||||
ir::value *rrhs = bin_rhs->get_operand(1);
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(is_cst(lrhs))
|
||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), name, cst);
|
||||
// x + (y + cst) -> cst + (x + y)
|
||||
if(is_cst(rrhs))
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
|
||||
if(new_value != op)
|
||||
to_delete.push_back(bin_rhs);
|
||||
}
|
||||
}
|
||||
|
||||
// extract constant and non-constant
|
||||
if(ir::instruction *bin_add = is_bin_add(new_value)){
|
||||
ir::value *new_lhs = bin_add->get_operand(0);
|
||||
ir::value *new_rhs = bin_add->get_operand(1);
|
||||
if(is_cst(new_lhs)){
|
||||
cst = new_lhs;
|
||||
noncst = new_rhs;
|
||||
}
|
||||
if(is_cst(new_rhs)){
|
||||
cst = new_rhs;
|
||||
noncst = new_lhs;
|
||||
}
|
||||
}
|
||||
|
||||
// clean-up if some re-ordering happened
|
||||
if(old_value != new_value){
|
||||
old_value->replace_all_uses_with(new_value);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(old_value))
|
||||
to_delete.push_back(x);
|
||||
}
|
||||
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate() {
|
||||
|
||||
}
|
||||
|
||||
void reassociate::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
bool done = false;
|
||||
do{
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(i)){
|
||||
std::vector<ir::value*> idxs(gep->idx_begin(), gep->idx_end());
|
||||
ir::value *cst = nullptr;
|
||||
ir::value *noncst = idxs[0];
|
||||
reorder_op(noncst, builder, to_delete, noncst, cst);
|
||||
// std::cout << gep->get_name() << " " << noncst << " " << cst << std::endl;
|
||||
}
|
||||
}
|
||||
done = true;
|
||||
}
|
||||
}while(!done);
|
||||
}
|
||||
// erase dead code
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -236,39 +236,6 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
|
||||
throw std::runtime_error("unknown conversion from ir::constant to Constant");
|
||||
}
|
||||
|
||||
inline Value *Reassociate(Value *V, IRBuilder<> &Builder){
|
||||
BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
|
||||
if(BinOp)
|
||||
if(BinOp->getOpcode()==BinaryOperator::BinaryOps::Add){
|
||||
Value *LHS = Reassociate(BinOp->getOperand(0), Builder);
|
||||
Value *RHS = Reassociate(BinOp->getOperand(1), Builder);
|
||||
if(BinaryOperator *BinLHS = dyn_cast<BinaryOperator>(LHS))
|
||||
if(BinLHS->getOpcode()==BinaryOperator::BinaryOps::Add){
|
||||
Value *LLHS = BinLHS->getOperand(0);
|
||||
Value *RLHS = BinLHS->getOperand(1);
|
||||
// (cst + x) + y -> cst + (x + y)
|
||||
if(isa<Constant>(LLHS))
|
||||
return Builder.CreateAdd(LLHS, Builder.CreateAdd(RLHS, RHS));
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(isa<Constant>(RLHS))
|
||||
return Builder.CreateAdd(RLHS, Builder.CreateAdd(LLHS, RHS));
|
||||
}
|
||||
if(BinaryOperator *BinRHS = dyn_cast<BinaryOperator>(RHS))
|
||||
if(BinRHS->getOpcode()==BinaryOperator::BinaryOps::Add){
|
||||
Value *LRHS = BinRHS->getOperand(0);
|
||||
Value *RRHS = BinRHS->getOperand(1);
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(isa<Constant>(LRHS))
|
||||
return Builder.CreateAdd(LRHS, Builder.CreateAdd(RRHS, LHS));
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(isa<Constant>(LRHS))
|
||||
return Builder.CreateAdd(RRHS, Builder.CreateAdd(LRHS, LHS));
|
||||
}
|
||||
return BinOp;
|
||||
}
|
||||
return V;
|
||||
}
|
||||
|
||||
/* convert ir::instruction to llvm::Instruction */
|
||||
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, IRBuilder<> &builder) {
|
||||
LLVMContext & ctx = builder.getContext();
|
||||
@@ -320,13 +287,14 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
return builder.Insert(CastInst::Create(ii->get_op(), arg, dst_ty));
|
||||
}
|
||||
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
|
||||
// get pointer
|
||||
Value *ptr = value(ii->get_operand(0));
|
||||
// reassociate first index
|
||||
std::vector<Value*> idx_vals;
|
||||
std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals),
|
||||
[&value](ir::value* x){ return value(x);});
|
||||
Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty());
|
||||
idx_vals[0] = Reassociate(idx_vals[0], builder);
|
||||
Value *arg = value(ii->get_operand(0));
|
||||
return builder.Insert(GetElementPtrInst::CreateInBounds(source_ty, arg, idx_vals));
|
||||
return builder.Insert(GetElementPtrInst::CreateInBounds(source_ty, ptr, idx_vals));
|
||||
}
|
||||
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
|
||||
Value *ptr = value(ii->get_pointer_operand());
|
||||
@@ -612,7 +580,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||
{
|
||||
// skip
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second || dynamic_cast<ir::mask_inst*>(v))
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
// recurse
|
||||
if(auto *user = dynamic_cast<ir::user*>(v))
|
||||
@@ -767,40 +735,32 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Module *module = block->getModule();
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
Function *fn = block->getParent();
|
||||
ir::value *mask = ins->get_mask_pred();
|
||||
BasicBlock *last_block = nullptr;
|
||||
auto set_mask_insert_pt = [&](indices_t idx){
|
||||
if(mask){
|
||||
distributed_tile *mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
|
||||
BasicBlock *block = pmap_.at({mask_tile, idx});
|
||||
builder.SetInsertPoint(block->getTerminator());
|
||||
last_block = last_block_.at({mask_tile, idx});
|
||||
}
|
||||
};
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
distributed_tile *mask_tile;
|
||||
if(mask)
|
||||
mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
|
||||
ptr->for_each([&](indices_t idx){
|
||||
set_mask_insert_pt(idx);
|
||||
Value *ptr_value = ptr->get_value(idx);
|
||||
Value *value_value = value->get_value(idx);
|
||||
Instruction *store;
|
||||
// if(mask){
|
||||
// Value *pred_value = mask_tile->get_value(idx);
|
||||
// value_value = builder.CreateVectorSplat(1, value_value);
|
||||
// pred_value = builder.CreateVectorSplat(1, pred_value);
|
||||
// Type *ptr_ty = PointerType::get(value_value->getType(), ptr_value->getType()->getPointerAddressSpace());
|
||||
// ptr_value = builder.CreateBitCast(ptr_value, ptr_ty);
|
||||
// store = builder.CreateMaskedStore(value_value, ptr_value, 1, pred_value);
|
||||
// }
|
||||
// else
|
||||
store = new StoreInst(value_value, ptr_value);
|
||||
builder.Insert(store);
|
||||
});
|
||||
ir::value *mask = x->get_mask();
|
||||
if(mask) {
|
||||
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
|
||||
ptr->for_each([&](indices_t idx){
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(preds->get_value(idx), mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
else {
|
||||
ptr->for_each([&](indices_t idx){
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr->get_value(idx)))
|
||||
if(BinaryOperator *binop = dyn_cast<BinaryOperator>(*gep->idx_begin())){
|
||||
std::cout << isa<Constant>(binop->getOperand(0)) << " " << isa<Constant>(binop->getOperand(1)) << std::endl;
|
||||
}
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
});
|
||||
}
|
||||
}
|
||||
else {
|
||||
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
|
||||
@@ -875,49 +835,49 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
result->set_value(idx, builder.CreateAdd(bin, offset));
|
||||
});
|
||||
}
|
||||
// mask
|
||||
else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
|
||||
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
|
||||
pred->for_each([&](indices_t idx){
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_else_bb);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
|
||||
pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
|
||||
last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
|
||||
last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
|
||||
});
|
||||
}
|
||||
// merge
|
||||
else if(auto *merge = dynamic_cast<ir::psi_inst*>(ins)) {
|
||||
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
|
||||
distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
|
||||
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
|
||||
distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
|
||||
result->for_each([&](indices_t idx){
|
||||
BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
|
||||
Value *value_true = value_tile_true->get_value(idx);
|
||||
BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
|
||||
Value *value_false = value_tile_false->get_value(idx);
|
||||
BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
|
||||
if(block_done->getTerminator())
|
||||
builder.SetInsertPoint(block_done->getTerminator());
|
||||
else
|
||||
builder.SetInsertPoint(block_done);
|
||||
PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
|
||||
phi->addIncoming(value_true, block_true);
|
||||
phi->addIncoming(value_false,block_false);
|
||||
result->set_value(idx, phi);
|
||||
});
|
||||
}
|
||||
// // mask
|
||||
// else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
// distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
|
||||
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
|
||||
// pred->for_each([&](indices_t idx){
|
||||
// BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
|
||||
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
// builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
|
||||
// builder.SetInsertPoint(mask_then_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_else_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_done_bb);
|
||||
// pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
|
||||
// pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
|
||||
// last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
|
||||
// last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
|
||||
// });
|
||||
// }
|
||||
// // merge
|
||||
// else if(auto *merge = dynamic_cast<ir::psi_inst*>(ins)) {
|
||||
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
|
||||
// distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
|
||||
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
|
||||
// distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
|
||||
// result->for_each([&](indices_t idx){
|
||||
// BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
|
||||
// Value *value_true = value_tile_true->get_value(idx);
|
||||
// BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
|
||||
// Value *value_false = value_tile_false->get_value(idx);
|
||||
// BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
|
||||
// if(block_done->getTerminator())
|
||||
// builder.SetInsertPoint(block_done->getTerminator());
|
||||
// else
|
||||
// builder.SetInsertPoint(block_done);
|
||||
// PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
|
||||
// phi->addIncoming(value_true, block_true);
|
||||
// phi->addIncoming(value_false,block_false);
|
||||
// result->set_value(idx, phi);
|
||||
// });
|
||||
// }
|
||||
// reshape
|
||||
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
||||
ir::value* in = ins->get_operand(0);
|
||||
@@ -934,7 +894,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// splat
|
||||
else if(dynamic_cast<ir::splat_inst*>(ins)) {
|
||||
result->for_each([&](indices_t idx) {
|
||||
set_mask_insert_pt(idx);
|
||||
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
|
||||
});
|
||||
}
|
||||
@@ -1163,12 +1122,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
// vector_size = result->axis(0).contiguous;
|
||||
// vector_size = 1;
|
||||
std::map<unsigned, Value*> packets;
|
||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||
result->for_each([&](indices_t idx){
|
||||
set_mask_insert_pt(idx);
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0){
|
||||
@@ -1189,20 +1145,14 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
else
|
||||
return llvm_value(x, builder);
|
||||
};
|
||||
set_mask_insert_pt(idx);
|
||||
result->set_value(idx, llvm_inst(ins, value, builder));
|
||||
});
|
||||
}
|
||||
}
|
||||
if(mask){
|
||||
builder.SetInsertPoint(block);
|
||||
if(last_block)
|
||||
builder.SetInsertPoint(last_block);
|
||||
}
|
||||
}
|
||||
|
||||
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||
if(src->has_tile_result_or_op() || (src->get_mask_pred() && src->get_mask_pred()->get_type()->is_tile_ty())) {
|
||||
if(src->has_tile_result_or_op()) {
|
||||
lower_tile_instruction(src, builder);
|
||||
}
|
||||
else {
|
||||
@@ -1310,7 +1260,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
BasicBlock *current = dst_builder.GetInsertBlock();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::psi_inst*>(i)) && !current->empty();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i)) && !current->empty();
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
dst_builder.SetInsertPoint(&*current->getFirstNonPHI());
|
||||
lower_instruction(i, dst_builder);
|
||||
|
@@ -60,6 +60,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
|
||||
}
|
||||
else {
|
||||
params_t params = heuristics();
|
||||
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
@@ -123,8 +123,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
|
||||
int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
c = dot()" + usea + ", " + useb + R"(, c);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
@@ -132,15 +132,17 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
||||
int32 ryc[TN] = ridy*TN + (0 ... TN);
|
||||
int32 rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int32 ryc[TN] = ridy * TN + (0 ... TN);
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
@checkc *pc = c;
|
||||
*pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
std::cout << res << std::endl;
|
||||
os << res;
|
||||
}
|
||||
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -90,13 +90,13 @@ value *builder::create_ret_void() {
|
||||
// tile-level control-flow instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_mask(value *pred, const std::string &name){
|
||||
return insert(mask_inst::create(pred, name));
|
||||
}
|
||||
//value *builder::create_mask(value *pred, const std::string &name){
|
||||
// return insert(mask_inst::create(pred, name));
|
||||
//}
|
||||
|
||||
value *builder::create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name) {
|
||||
return insert(psi_inst::create(mask_true, value_true, mask_false, value_false, name));
|
||||
}
|
||||
//value *builder::create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name) {
|
||||
// return insert(psi_inst::create(mask_true, value_true, mask_false, value_false, name));
|
||||
//}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -349,31 +349,31 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v
|
||||
}
|
||||
|
||||
// mask_inst
|
||||
mask_inst::mask_inst(value *pred, const std::string &name, instruction *next)
|
||||
: instruction(pred->get_type(), 1, 2, name, next) {
|
||||
set_operand(0, pred);
|
||||
}
|
||||
//mask_inst::mask_inst(value *pred, const std::string &name, instruction *next)
|
||||
// : instruction(pred->get_type(), 1, 2, name, next) {
|
||||
// set_operand(0, pred);
|
||||
//}
|
||||
|
||||
mask_inst* mask_inst::create(value *pred, const std::string &name, instruction *next) {
|
||||
return new mask_inst(pred, name, next);
|
||||
}
|
||||
//mask_inst* mask_inst::create(value *pred, const std::string &name, instruction *next) {
|
||||
// return new mask_inst(pred, name, next);
|
||||
//}
|
||||
|
||||
// merge_inst
|
||||
psi_inst::psi_inst(value *mask_true, value *value_true,
|
||||
value *mask_false, value *value_false,
|
||||
const std::string &name, instruction *next)
|
||||
: instruction(value_true->get_type(), 4, 1, name, next) {
|
||||
set_operand(0, mask_true);
|
||||
set_operand(1, value_true);
|
||||
set_operand(2, mask_false);
|
||||
set_operand(3, value_false);
|
||||
}
|
||||
//// merge_inst
|
||||
//psi_inst::psi_inst(value *mask_true, value *value_true,
|
||||
// value *mask_false, value *value_false,
|
||||
// const std::string &name, instruction *next)
|
||||
// : instruction(value_true->get_type(), 4, 1, name, next) {
|
||||
// set_operand(0, mask_true);
|
||||
// set_operand(1, value_true);
|
||||
// set_operand(2, mask_false);
|
||||
// set_operand(3, value_false);
|
||||
//}
|
||||
|
||||
psi_inst* psi_inst::create(value *mask_true, value *value_true,
|
||||
value *mask_false, value *value_false,
|
||||
const std::string &name, instruction *next) {
|
||||
return new psi_inst(mask_true, value_true, mask_false, value_false, name, next);
|
||||
}
|
||||
//psi_inst* psi_inst::create(value *mask_true, value *value_true,
|
||||
// value *mask_false, value *value_false,
|
||||
// const std::string &name, instruction *next) {
|
||||
// return new psi_inst(mask_true, value_true, mask_false, value_false, name, next);
|
||||
//}
|
||||
|
||||
|
||||
|
||||
@@ -449,7 +449,16 @@ type *load_inst::get_pointee_type(type *ty) {
|
||||
}
|
||||
|
||||
load_inst::load_inst(value *ptr, const std::string &name, instruction *next)
|
||||
: unary_inst(get_pointee_type(ptr->get_type()), ptr, name, next) {
|
||||
: unary_inst(get_pointee_type(ptr->get_type()), ptr, name, next), mask_(nullptr){
|
||||
}
|
||||
|
||||
value *load_inst::get_mask() const {
|
||||
return mask_;
|
||||
}
|
||||
|
||||
value *load_inst::set_mask(value *mask) {
|
||||
mask_ = mask;
|
||||
return this;
|
||||
}
|
||||
|
||||
load_inst* load_inst::create(value *ptr, const std::string &name, instruction *next) {
|
||||
@@ -458,11 +467,20 @@ load_inst* load_inst::create(value *ptr, const std::string &name, instruction *n
|
||||
|
||||
// store
|
||||
store_inst::store_inst(value *ptr, value *v, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, 1, name, next) {
|
||||
: instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, 1, name, next), mask_(nullptr) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, v);
|
||||
}
|
||||
|
||||
value *store_inst::get_mask() const {
|
||||
return mask_;
|
||||
}
|
||||
|
||||
value *store_inst::set_mask(value *mask) {
|
||||
mask_ = mask;
|
||||
return this;
|
||||
}
|
||||
|
||||
store_inst* store_inst::create(value *ptr, value *v, const std::string &name, instruction *next) {
|
||||
return new store_inst(ptr, v, name, next);
|
||||
}
|
||||
|
@@ -35,8 +35,12 @@ void print(module &mod, std::ostream& os) {
|
||||
os << std::endl;
|
||||
for(ir::instruction *inst: block->get_inst_list()){
|
||||
os << " ";
|
||||
if(ir::value *pred = inst->get_mask_pred())
|
||||
os << "@" << get_name(pred, cnt++) << " ";
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(inst))
|
||||
if(ir::value *mask = x->get_mask())
|
||||
os << "@" << get_name(mask, cnt++) << " ";
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(inst))
|
||||
if(ir::value *mask = x->get_mask())
|
||||
os << "@" << get_name(mask, cnt++) << " ";
|
||||
unsigned num_results = inst->get_num_results();
|
||||
for(unsigned i = 0; i < num_results; i++){
|
||||
os << get_name(inst->get_result(i), cnt++);
|
||||
|
@@ -255,37 +255,38 @@ ir::value* cast_expression::codegen(ir::module *mod) const{
|
||||
|
||||
/* Conditional expression */
|
||||
ir::value *conditional_expression::codegen(ir::module *mod) const{
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::basic_block::inst_list_t &instructions = builder.get_insert_block()->get_inst_list();
|
||||
ir::value *pred = cond_->codegen(mod);
|
||||
ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
|
||||
/* true value */
|
||||
ir::value *true_mask = mask->get_result(0);
|
||||
auto it_true_begin = instructions.end();
|
||||
it_true_begin--;
|
||||
ir::value *true_value = true_value_->codegen(mod);
|
||||
implicit_broadcast(mod, pred, true_value);
|
||||
it_true_begin++;
|
||||
auto it_true_end = instructions.end();
|
||||
for(auto it = it_true_begin; it != it_true_end; it++)
|
||||
// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||
(*it)->set_mask_pred(true_mask);
|
||||
/* false value */
|
||||
ir::value *false_mask = mask->get_result(1);
|
||||
auto it_false_begin = instructions.end();
|
||||
it_false_begin--;
|
||||
ir::value *false_value = false_value_->codegen(mod);
|
||||
implicit_broadcast(mod, pred, false_value);
|
||||
bool is_float, is_ptr, is_int, is_signed;
|
||||
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||
it_false_begin++;
|
||||
auto it_false_end = instructions.end();
|
||||
for(auto it = it_false_begin; it != it_false_end; it++)
|
||||
// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||
(*it)->set_mask_pred(false_mask);
|
||||
/* psi */
|
||||
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
||||
return result;
|
||||
throw std::runtime_error("not implemented");
|
||||
// ir::builder &builder = mod->get_builder();
|
||||
// ir::basic_block::inst_list_t &instructions = builder.get_insert_block()->get_inst_list();
|
||||
// ir::value *pred = cond_->codegen(mod);
|
||||
// ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
|
||||
// /* true value */
|
||||
// ir::value *true_mask = mask->get_result(0);
|
||||
// auto it_true_begin = instructions.end();
|
||||
// it_true_begin--;
|
||||
// ir::value *true_value = true_value_->codegen(mod);
|
||||
// implicit_broadcast(mod, pred, true_value);
|
||||
// it_true_begin++;
|
||||
// auto it_true_end = instructions.end();
|
||||
// for(auto it = it_true_begin; it != it_true_end; it++)
|
||||
//// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||
// (*it)->set_mask_pred(true_mask);
|
||||
// /* false value */
|
||||
// ir::value *false_mask = mask->get_result(1);
|
||||
// auto it_false_begin = instructions.end();
|
||||
// it_false_begin--;
|
||||
// ir::value *false_value = false_value_->codegen(mod);
|
||||
// implicit_broadcast(mod, pred, false_value);
|
||||
// bool is_float, is_ptr, is_int, is_signed;
|
||||
// implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||
// it_false_begin++;
|
||||
// auto it_false_end = instructions.end();
|
||||
// for(auto it = it_false_begin; it != it_false_end; it++)
|
||||
//// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||
// (*it)->set_mask_pred(false_mask);
|
||||
// /* psi */
|
||||
// ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
||||
// return result;
|
||||
}
|
||||
|
||||
/* Assignment expression */
|
||||
|
@@ -29,34 +29,22 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
/* Expression statement */
|
||||
ir::value* expression_statement::codegen(ir::module *mod) const{
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::basic_block *block = builder.get_insert_block();
|
||||
if(pred_) {
|
||||
// generate mask
|
||||
ir::value *pred = pred_->codegen(mod);
|
||||
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
||||
// generate expression
|
||||
unsigned szbegin = block->get_inst_list().size();
|
||||
ir::value *expr = expr_->codegen(mod);
|
||||
ir::basic_block::iterator begin = block->begin();
|
||||
std::advance(begin, szbegin);
|
||||
// set mask
|
||||
ir::type *ty = expr->get_type();
|
||||
for(auto it = begin; it != builder.get_insert_point(); it++)
|
||||
(*it)->set_mask_pred(mask->get_result(0));
|
||||
// if(auto *itn = dynamic_cast<ir::instruction*>(expr))
|
||||
// itn->set_mask_pred(mask->get_result(0));
|
||||
if(ty->is_void_ty())
|
||||
return expr;
|
||||
// merge with psi
|
||||
ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr,
|
||||
mask->get_result(1), ir::undef_value::get(ty));
|
||||
if(assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_)){
|
||||
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
|
||||
mod->set_value(name, psi);
|
||||
}
|
||||
return psi;
|
||||
ir::value *expr = expr_->codegen(mod);
|
||||
if(pred_ == nullptr)
|
||||
return expr;
|
||||
ir::value *pred = pred_->codegen(mod);
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(expr))
|
||||
x->set_mask(pred);
|
||||
else if(auto *x = dynamic_cast<ir::store_inst*>(expr))
|
||||
x->set_mask(pred);
|
||||
else
|
||||
expr = builder.create_select(pred, expr, ir::undef_value::get(expr->get_type()));
|
||||
if(assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_))
|
||||
if(auto *named = dynamic_cast<named_expression*>(assignment)){
|
||||
std::string name = named->lvalue()->id()->name();
|
||||
mod->set_value(name, expr);
|
||||
}
|
||||
return expr_->codegen(mod);
|
||||
return expr;
|
||||
}
|
||||
|
||||
/* For statement */
|
||||
|
Reference in New Issue
Block a user