[ir] deleted mask/merge instructions; will be replaced by masked_load/store and select
This commit is contained in:
@@ -48,14 +48,14 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
NumericT alpha = 1;
|
||||
NumericT beta = 0;
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = M;
|
||||
cublasGemmAlgo_t fastest;
|
||||
// cublasGemmAlgo_t fastest;
|
||||
// cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
|
||||
// &alpha, da, lda,
|
||||
// db, ldb, &beta,
|
||||
@@ -109,6 +109,6 @@ int main() {
|
||||
// does the work
|
||||
for(config_t c: configs){
|
||||
perf_t perf = c.perf(stream);
|
||||
std::cout << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl;
|
||||
std::cout << "// " << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl;
|
||||
}
|
||||
}
|
||||
|
@@ -144,6 +144,6 @@ int main() {
|
||||
for(config_t c: configs){
|
||||
std::string repr = c.repr();
|
||||
perf_t perf = c.perf(stream);
|
||||
std::cout << repr << ", " << perf.triton << ", " << perf.cublas << std::endl;
|
||||
std::cout << "// " << repr << ", " << perf.triton << ", " << perf.cublas << std::endl;
|
||||
}
|
||||
}
|
||||
|
@@ -14,9 +14,9 @@ namespace ir {
|
||||
namespace codegen{
|
||||
class tune;
|
||||
|
||||
class optimize_cse {
|
||||
class optimize_dce {
|
||||
public:
|
||||
optimize_cse() {}
|
||||
optimize_dce() {}
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
@@ -19,7 +19,7 @@ namespace codegen{
|
||||
|
||||
class optimize_trans {
|
||||
private:
|
||||
ir::value *replace_phi(ir::value* value, std::vector<ir::instruction*>& to_delete, ir::builder &builder);
|
||||
ir::value *replace_phi(ir::value* value, ir::builder &builder);
|
||||
|
||||
public:
|
||||
optimize_trans() {}
|
||||
|
@@ -104,19 +104,10 @@ private:
|
||||
};
|
||||
|
||||
|
||||
// Fragmented tile
|
||||
class fragmented_tile: public tile{
|
||||
public:
|
||||
|
||||
private:
|
||||
|
||||
};
|
||||
|
||||
// Selection pass
|
||||
class selection{
|
||||
typedef std::map<ir::value *, llvm::Value *> vmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
typedef std::map<std::pair<tile*, indices_t>, llvm::BasicBlock*> pmap_t;
|
||||
|
||||
private:
|
||||
// utils
|
||||
@@ -152,8 +143,6 @@ public:
|
||||
private:
|
||||
vmap_t vmap_;
|
||||
tmap_t tmap_;
|
||||
pmap_t pmap_;
|
||||
pmap_t last_block_;
|
||||
shmem_allocation *alloc_;
|
||||
tune *params_;
|
||||
target *tgt_;
|
||||
|
@@ -101,7 +101,7 @@ inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
|
||||
inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
|
||||
size_t TM = 128;
|
||||
size_t TN = 128;
|
||||
// return {4, 8, 256, 8, 8, 64, 2, 2, 2, 2, 32, 32, 16, 1};
|
||||
// return {4, 4, 128, 8, 4, 128, 2, 2, 2, 2, 32, 32, 16, 1};
|
||||
return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
}
|
||||
|
||||
|
@@ -50,15 +50,14 @@ public:
|
||||
block_->get_inst_list().insert(insert_point_, inst);
|
||||
inst->set_parent(block_);
|
||||
inst->set_name(name);
|
||||
// for(ir::value* op: inst->ops())
|
||||
// op->add_use(inst);
|
||||
return inst;
|
||||
}
|
||||
// terminator instructions
|
||||
value* create_br(basic_block *dest);
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Tile-level control flow
|
||||
// value *create_mask(value *pred, const std::string &name = "");
|
||||
// value *create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name = "");
|
||||
// Cast instructions
|
||||
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
@@ -120,6 +119,8 @@ public:
|
||||
// Input/Output
|
||||
value *create_load(value *arg, const std::string &name = "");
|
||||
value *create_store(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, const std::string &name = "");
|
||||
value *create_masked_store(value *ptr, value *val, value *mask, const std::string &name = "");
|
||||
// Tile instruction
|
||||
value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
|
@@ -21,11 +21,6 @@ class context;
|
||||
class result_reference;
|
||||
class instruction: public user{
|
||||
public:
|
||||
// struct mask_info_t {
|
||||
// value *pred;
|
||||
// value *else_value;
|
||||
// };
|
||||
|
||||
virtual std::string repr_impl() const = 0;
|
||||
|
||||
protected:
|
||||
@@ -38,11 +33,6 @@ public:
|
||||
const basic_block *get_parent() const { return parent_; }
|
||||
basic_block *get_parent() { return parent_; }
|
||||
void erase_from_parent();
|
||||
// // mask
|
||||
// void set_mask_pred(value *pred) { resize_hidden(1); set_operand(get_num_operands(), pred); }
|
||||
// value* get_mask_pred() const { if(get_num_hidden() == 0) return nullptr; return get_operand(get_num_operands()); }
|
||||
// void set_mask_else(value *x) { resize_hidden(2); set_operand(get_num_operands() + 1, x); }
|
||||
// value* get_mask_else() const { if(get_num_hidden() < 2) return nullptr; return get_operand(get_num_operands() + 1); }
|
||||
// helpers
|
||||
bool has_tile_result_or_op();
|
||||
// repr
|
||||
@@ -56,8 +46,6 @@ public:
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
private:
|
||||
basic_block *parent_;
|
||||
// value *pred_;
|
||||
// value *mask_pred_;
|
||||
std::vector<value*> results_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
};
|
||||
@@ -336,35 +324,6 @@ public:
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
//// mask
|
||||
//class mask_inst: public instruction {
|
||||
//private:
|
||||
// std::string repr_impl() const { return "mask"; }
|
||||
// mask_inst(ir::value *pred, const std::string &name, instruction *next);
|
||||
|
||||
//public:
|
||||
// static mask_inst* create(ir::value *pred, const std::string &name = "", instruction *next = nullptr);
|
||||
//};
|
||||
|
||||
//// merge
|
||||
//class psi_inst: public instruction {
|
||||
//private:
|
||||
// std::string repr_impl() const { return "merge"; }
|
||||
// psi_inst(ir::value *mask_true, ir::value *value_true,
|
||||
// ir::value *mask_false, ir::value *value_false,
|
||||
// const std::string &name, instruction *next);
|
||||
|
||||
//public:
|
||||
// static psi_inst* create(ir::value *mask_true, ir::value *value_true,
|
||||
// ir::value *mask_false, ir::value *value_false,
|
||||
// const std::string &name = "", instruction *next = nullptr);
|
||||
// ir::value *get_mask_true() { return get_operand(0); }
|
||||
// ir::value *get_value_true() { return get_operand(1); }
|
||||
// ir::value *get_mask_false() { return get_operand(2); }
|
||||
// ir::value *get_value_false() { return get_operand(3); }
|
||||
|
||||
//};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -399,43 +358,78 @@ private:
|
||||
// load_inst/store_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class load_inst: public unary_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "load"; }
|
||||
load_inst(value *ptr, const std::string &name, instruction *next);
|
||||
class io_inst: public instruction {
|
||||
protected:
|
||||
io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
|
||||
public:
|
||||
// value *get_mask() const;
|
||||
// value *get_false_value() const;
|
||||
};
|
||||
|
||||
class load_inst: public io_inst{
|
||||
protected:
|
||||
load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next);
|
||||
|
||||
private:
|
||||
std::string repr_impl() const { return "load"; }
|
||||
static type *get_pointee_type(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
value *get_mask() const;
|
||||
value *set_mask(value *mask);
|
||||
// factory method
|
||||
static load_inst* create(value *ptr, const std::string &name = "",
|
||||
static load_inst* create(value *ptr,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
|
||||
private:
|
||||
value *mask_;
|
||||
};
|
||||
|
||||
class store_inst: public instruction{
|
||||
class masked_load_inst: public load_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "store"; }
|
||||
store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "masked_load"; }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
value *get_mask() const;
|
||||
value *set_mask(value *mask);
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(1); }
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static store_inst* create(value* ptr, value *v, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class store_inst: public io_inst{
|
||||
protected:
|
||||
store_inst(value *ptr, value *v, unsigned num_extra_ops,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
private:
|
||||
ir::value *mask_;
|
||||
std::string repr_impl() const { return "store"; }
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
// factory method
|
||||
static store_inst* create(value* ptr, value *v,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class masked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_store"; }
|
||||
masked_store_inst(value *ptr, value *v, value *mask,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -507,21 +501,6 @@ protected:
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
class get_global_range_inst: public builtin_inst {
|
||||
private:
|
||||
get_global_range_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class get_range_id_inst: public builtin_inst {
|
||||
private:
|
||||
get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
|
@@ -71,16 +71,6 @@ private:
|
||||
const constant* size_;
|
||||
};
|
||||
|
||||
class get_global_range_expression: public builtin_expression{
|
||||
public:
|
||||
get_global_range_expression(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { }
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
const constant* size_;
|
||||
const constant* axis_;
|
||||
};
|
||||
|
||||
class get_range_id_expression: public builtin_expression{
|
||||
public:
|
||||
get_range_id_expression(node *axis): axis_((constant*)axis) { }
|
||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||
%token IF ELSE FOR CONTINUE WHILE
|
||||
%token NEWAXIS ELLIPSIS AT
|
||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||
%token GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -120,8 +120,7 @@ identifier
|
||||
|
||||
/* Built-in */
|
||||
builtin_expression
|
||||
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); }
|
||||
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
||||
: GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
|
||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||
|
@@ -44,7 +44,6 @@ using triton::lang::return_void;
|
||||
"fp32" { return return_impl(FP32, yytext); }
|
||||
"fp64" { return return_impl(FP64, yytext); }
|
||||
"..." { return return_impl(ELLIPSIS, yytext); }
|
||||
"get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); }
|
||||
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
|
||||
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
|
||||
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
||||
|
@@ -11,7 +11,7 @@
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/optimize_dot.h"
|
||||
#include "triton/codegen/optimize_cse.h"
|
||||
#include "triton/codegen/optimize_dce.h"
|
||||
#include "triton/codegen/optimize_trans.h"
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
@@ -63,7 +63,7 @@ public:
|
||||
vectorize(&tune),
|
||||
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
|
||||
optimize_dot(&tune),
|
||||
optimize_cse(),
|
||||
optimize_dce(),
|
||||
optimize_trans(),
|
||||
alignment_info(),
|
||||
reassociate(&tune, &alignment_info),
|
||||
@@ -72,14 +72,11 @@ public:
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
alignment_info.run(module);
|
||||
reassociate.run(module);
|
||||
ir::print(module, std::cout);
|
||||
// exit(EXIT_FAILURE);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
@@ -87,6 +84,8 @@ public:
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
vectorize.run(module);
|
||||
optimize_dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
@@ -97,7 +96,7 @@ public:
|
||||
codegen::vectorize vectorize;
|
||||
codegen::selection selection;
|
||||
codegen::optimize_dot optimize_dot;
|
||||
codegen::optimize_cse optimize_cse;
|
||||
codegen::optimize_dce optimize_dce;
|
||||
codegen::optimize_trans optimize_trans;
|
||||
codegen::alignment_info alignment_info;
|
||||
codegen::reassociate reassociate;
|
||||
|
@@ -109,8 +109,6 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
if(!v->get_type()->is_tile_ty())
|
||||
return cache(1);
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::get_global_range_inst*>(v))
|
||||
return cache(shapes[0]->get_value());
|
||||
if(dynamic_cast<ir::constant_range*>(v))
|
||||
return cache(shapes[0]->get_value());
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
@@ -243,14 +241,6 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
return cache(op);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(v)){
|
||||
return cache(v->get_type()->get_tile_shapes()[0]->get_value());
|
||||
}
|
||||
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
// int value_true = populate_starting_multiple(x->get_value_true());
|
||||
// int value_false = populate_starting_multiple(x->get_value_false());
|
||||
// return cache(gcd(value_true, value_false));
|
||||
// }
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
@@ -313,7 +303,6 @@ void alignment_info::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_max_contiguous(i);
|
||||
std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,14 +0,0 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/optimize_cse.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
|
||||
void optimize_cse::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
}
|
||||
}
|
60
lib/codegen/optimize_dce.cpp
Normal file
60
lib/codegen/optimize_dce.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/codegen/optimize_dce.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
|
||||
void optimize_dce::run(ir::module &mod) {
|
||||
std::list<ir::instruction*> work_list;
|
||||
std::set<ir::instruction*> marked;
|
||||
|
||||
// initialize work-list
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo)
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
|
||||
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)){
|
||||
work_list.push_back(i);
|
||||
marked.insert(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mark -- ignore branches
|
||||
while(!work_list.empty()){
|
||||
ir::instruction* current = work_list.back();
|
||||
work_list.pop_back();
|
||||
// mark instruction operands
|
||||
for(ir::value* op: current->ops()) {
|
||||
if(auto *i = dynamic_cast<ir::instruction*>(op))
|
||||
if(marked.insert(i).second)
|
||||
work_list.push_back(i);
|
||||
}
|
||||
// TODO: mark last intstruction of current's reverse-dominance frontier
|
||||
}
|
||||
|
||||
// sweep -- delete non-branch unmarked instructions
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo)
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(marked.find(i) == marked.end())
|
||||
to_delete.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// delete
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -7,20 +7,18 @@ namespace codegen{
|
||||
|
||||
|
||||
ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
std::vector<ir::instruction*>& to_delete,
|
||||
ir::builder& builder){
|
||||
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
|
||||
// transpose operands
|
||||
std::vector<ir::value*> incs;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
incs.push_back(replace_phi(phi->get_incoming_value(n), to_delete, builder));
|
||||
incs.push_back(replace_phi(phi->get_incoming_value(n), builder));
|
||||
// create phi for transposed values
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
result->add_incoming(incs[n], phi->get_incoming_block(n));
|
||||
phi->replace_all_uses_with(result);
|
||||
to_delete.push_back(phi);
|
||||
return result;
|
||||
}
|
||||
else if(auto i = dynamic_cast<ir::instruction*>(value)){
|
||||
@@ -39,7 +37,6 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
|
||||
void optimize_trans::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
// iterate
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -56,15 +53,11 @@ void optimize_trans::run(ir::module &mod) {
|
||||
|
||||
// trans(phi) -> phi(trans(), trans()...)
|
||||
if(dynamic_cast<ir::phi_node*>(op)){
|
||||
ir::value* new_phi = replace_phi(op, to_delete, builder);
|
||||
to_delete.push_back(trans);
|
||||
ir::value* new_phi = replace_phi(op, builder);
|
||||
trans->replace_all_uses_with(new_phi);
|
||||
}
|
||||
}
|
||||
}
|
||||
// erase dead code
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -189,8 +189,6 @@ void reassociate::run(ir::module &mod) {
|
||||
|
||||
// reassociate
|
||||
std::map<ir::value*, cst_info> infos;
|
||||
std::map<ir::basic_block*, std::set<ir::value*>> re_ordered;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
@@ -259,11 +257,6 @@ void reassociate::run(ir::module &mod) {
|
||||
params_->copy(new_pz, pz);
|
||||
align_->copy(new_pz, pz);
|
||||
}
|
||||
|
||||
// // reassociate pointer
|
||||
// reassociate_ptr(pz, builder, offsets);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -32,7 +32,7 @@ void distributed_tile::init_indices() {
|
||||
current.push_back(axes_[d].values[id[d]]);
|
||||
size_t sz = indices_.size();
|
||||
indices_[current] = sz;
|
||||
values_[current] = UndefValue::get(ty_);
|
||||
values_[current] = nullptr;
|
||||
ordered_indices_.push_back(current);
|
||||
id[0]++;
|
||||
while(id[k] == axes_[k].values.size()){
|
||||
@@ -57,12 +57,17 @@ distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_
|
||||
init_indices();
|
||||
}
|
||||
|
||||
void distributed_tile::set_value(indices_t idx, Value *v) {
|
||||
values_[idx] = v;
|
||||
void distributed_tile::set_value(indices_t idx, Value *x) {
|
||||
assert(x->getType() == ty_ && "cannot set a value of different type");
|
||||
Value *&result = values_[idx];
|
||||
assert(!result && "value cannot be set twice");
|
||||
result = x;
|
||||
}
|
||||
|
||||
Value* distributed_tile::get_value(indices_t idx) {
|
||||
return values_[idx];
|
||||
Value *result = values_.at(idx);
|
||||
assert(result && "value has not been set");
|
||||
return result;
|
||||
}
|
||||
|
||||
unsigned distributed_tile::get_linear_index(indices_t idx) {
|
||||
@@ -688,15 +693,15 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
}
|
||||
bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
|
||||
tmap_.insert({v, T});
|
||||
bool is_inserted = tmap_.insert({v, T}).second;
|
||||
// constant range
|
||||
if(dynamic_cast<ir::constant_range*>(v)){
|
||||
if(is_inserted && dynamic_cast<ir::constant_range*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
T->set_value(idx, idx[0]);
|
||||
});
|
||||
}
|
||||
if(dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
if(is_inserted && dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
@@ -746,31 +751,41 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
Function *fn = block->getParent();
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
ir::value *mask = x->get_mask();
|
||||
if(mask) {
|
||||
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
|
||||
ptr->for_each([&](indices_t idx){
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(preds->get_value(idx), mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
else {
|
||||
ptr->for_each([&](indices_t idx){
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr->get_value(idx)))
|
||||
if(BinaryOperator *binop = dyn_cast<BinaryOperator>(*gep->idx_begin())){
|
||||
std::cout << isa<Constant>(binop->getOperand(0)) << " " << isa<Constant>(binop->getOperand(1)) << std::endl;
|
||||
}
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
});
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *scalars = tmap_.at(x->get_value_operand());
|
||||
ir::value *mask = x->get_mask_operand();
|
||||
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
|
||||
ptrs->for_each([&](indices_t idx){
|
||||
Value *scalar = scalars->get_value(idx);
|
||||
Value *ptr = ptrs->get_value(idx);
|
||||
Value *pred = preds->get_value(idx);
|
||||
// std::string offset = "";
|
||||
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
// if(gep->getNumIndices() == 1)
|
||||
// if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
|
||||
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*4);
|
||||
// }
|
||||
// FunctionType *ty = FunctionType::get(Type::getVoidTy(ctx), {pred->getType(), ptr->getType(), scalar->getType()}, false);
|
||||
// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
|
||||
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
|
||||
// builder.CreateCall(iasm, {pred, ptr, scalar});
|
||||
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateStore(scalar, ptr);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *scalars = tmap_.at(x->get_value_operand());
|
||||
ptrs->for_each([&](indices_t idx){
|
||||
builder.CreateStore(scalars->get_value(idx), ptrs->get_value(idx));
|
||||
});
|
||||
}
|
||||
else {
|
||||
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
|
||||
@@ -837,14 +852,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
if(!ins->get_type()->is_tile_ty())
|
||||
return;
|
||||
const auto& shapes = ins->get_type()->get_tile_shapes();
|
||||
// global_range
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
||||
Value *offset = tgt_->get_global_offset(module, builder, shapes[0]->get_value(), x->get_axis());
|
||||
result->for_each([&](indices_t idx){
|
||||
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
|
||||
result->set_value(idx, builder.CreateAdd(bin, offset));
|
||||
});
|
||||
}
|
||||
// nv_dynamic_range_idx_inst
|
||||
if(dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins)){
|
||||
result->for_each([&](indices_t idx){
|
||||
@@ -855,49 +862,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
// // mask
|
||||
// else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
// distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
|
||||
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
|
||||
// pred->for_each([&](indices_t idx){
|
||||
// BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
|
||||
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
// builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
|
||||
// builder.SetInsertPoint(mask_then_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_else_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_done_bb);
|
||||
// pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
|
||||
// pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
|
||||
// last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
|
||||
// last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
|
||||
// });
|
||||
// }
|
||||
// // merge
|
||||
// else if(auto *merge = dynamic_cast<ir::psi_inst*>(ins)) {
|
||||
// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
|
||||
// distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
|
||||
// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
|
||||
// distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
|
||||
// result->for_each([&](indices_t idx){
|
||||
// BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
|
||||
// Value *value_true = value_tile_true->get_value(idx);
|
||||
// BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
|
||||
// Value *value_false = value_tile_false->get_value(idx);
|
||||
// BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
|
||||
// if(block_done->getTerminator())
|
||||
// builder.SetInsertPoint(block_done->getTerminator());
|
||||
// else
|
||||
// builder.SetInsertPoint(block_done);
|
||||
// PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
|
||||
// phi->addIncoming(value_true, block_true);
|
||||
// phi->addIncoming(value_false,block_false);
|
||||
// result->set_value(idx, phi);
|
||||
// });
|
||||
// }
|
||||
// reshape
|
||||
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
||||
ir::value* in = ins->get_operand(0);
|
||||
@@ -939,9 +903,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
in->for_each([&](indices_t idx){
|
||||
unsigned linear = in->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
Value *in_value = in->get_value(idx);
|
||||
if(linear % vector_size == 0)
|
||||
packets[id] = result->get_value(idx);
|
||||
packets[id] = builder.CreateInsertElement(packets.at(id), in->get_value(idx), linear % vector_size);
|
||||
packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
|
||||
packets[id] = builder.CreateInsertElement(packets.at(id), in_value, linear % vector_size);
|
||||
});
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = in->get_linear_index(idx);
|
||||
@@ -1017,8 +982,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
TB->set_return_mode(true);
|
||||
|
||||
std::vector<Value *> fc;
|
||||
|
||||
result->for_each([&](indices_t idx){
|
||||
fc.push_back(result->get_value(idx));
|
||||
fc.push_back(TC->get_value(idx));
|
||||
// fc.push_back(UndefValue::get(TC->get_value(idx)->getType()));
|
||||
});
|
||||
|
||||
Type *fp32_ty = builder.getFloatTy();
|
||||
@@ -1076,10 +1043,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
||||
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0));
|
||||
Value *ha1 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1));
|
||||
Value *hb0 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0));
|
||||
Value *hb1 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1));
|
||||
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
|
||||
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
|
||||
std::vector<size_t> idx = {
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
@@ -1136,24 +1103,106 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
});
|
||||
}
|
||||
}
|
||||
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
|
||||
else if(auto *ld = dynamic_cast<ir::masked_load_inst*>(ins)){
|
||||
// find vector size
|
||||
ir::value *ptr = ld->get_pointer_operand();
|
||||
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile *masks = (distributed_tile*)tmap_.at(ld->get_mask_operand());
|
||||
distributed_tile *false_values = (distributed_tile*)tmap_.at(ld->get_false_value_operand());
|
||||
std::map<unsigned, Value*> packets;
|
||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0){
|
||||
Value *ptr = TP->get_value(idx);
|
||||
ptr= builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
|
||||
ptr->getType()->getPointerAddressSpace()));
|
||||
if(linear % vector_size == 0) {
|
||||
Value *ptr = pointers->get_value(idx);
|
||||
ConstantInt *cst = nullptr;
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
if(gep->getNumIndices() == 1){
|
||||
cst = dyn_cast<ConstantInt>(gep->idx_begin());
|
||||
}
|
||||
|
||||
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
|
||||
ptr->getType()->getPointerAddressSpace()));
|
||||
Value *mask = masks->get_value(idx);
|
||||
BasicBlock *current_bb = builder.GetInsertBlock();
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(mask, mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
Value *result_then = builder.CreateLoad(ptr);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
Value *result = nullptr;
|
||||
if(false_values){
|
||||
result = builder.CreatePHI(result_then->getType(), 2);
|
||||
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
|
||||
Value *result_false = false_values->get_value(idx);
|
||||
if(vector_size > 1)
|
||||
result_false = builder.CreateVectorSplat(vector_size, result_false);
|
||||
((PHINode*)result)->addIncoming(result_false, current_bb);
|
||||
}
|
||||
else
|
||||
result = result_then;
|
||||
|
||||
// std::string offset = "";
|
||||
// if(cst)
|
||||
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size);
|
||||
// Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
|
||||
// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false);
|
||||
// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
|
||||
// if(false_value)
|
||||
// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};";
|
||||
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
|
||||
// Value *result = builder.CreateCall(iasm, {mask, ptr});
|
||||
|
||||
packets[id] = result;
|
||||
}
|
||||
});
|
||||
// extract result element
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
// Value *tmp = builder.CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
|
||||
// Value *res = builder.CreateExtractElement(tmp, (linear % vector_size) % 2);
|
||||
// result->set_value(idx, res);
|
||||
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
});
|
||||
}
|
||||
else if(auto *ld = dynamic_cast<ir::load_inst*>(ins)){
|
||||
// find vector size
|
||||
ir::value *ptr = ld->get_pointer_operand();
|
||||
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
// vector loads
|
||||
std::map<unsigned, Value*> packets;
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
Value *ptr = pointers->get_value(idx);
|
||||
ConstantInt *cst = nullptr;
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
if(gep->getNumIndices() == 1){
|
||||
cst = dyn_cast<ConstantInt>(gep->idx_begin());
|
||||
}
|
||||
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
|
||||
ptr->getType()->getPointerAddressSpace()));
|
||||
packets[id] = builder.CreateLoad(ptr);
|
||||
}
|
||||
result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
});
|
||||
// extract result element
|
||||
result->for_each([&](indices_t idx){
|
||||
unsigned linear = result->get_linear_index(idx);
|
||||
unsigned id = linear / vector_size;
|
||||
// result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size));
|
||||
});
|
||||
}
|
||||
// element-wise
|
||||
|
@@ -106,9 +106,9 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
for(unsigned k = 0; k < v->get_num_results(); k++){
|
||||
ir::value *result = v->get_result(k);
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
for(ir::value* op: v->ops()){
|
||||
std::vector<ir::value*> ops = v->ops();
|
||||
for(ir::value* op: ops)
|
||||
add_constraint({result, i}, {op, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -123,14 +123,16 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
|
||||
int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = *pb;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
c = dot()" + usea + ", " + useb + R"(, c);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
int1 checka[)" + AS + R"(] = k > TK;
|
||||
int1 checkb[)" + BS + R"(] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
int32 rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int32 ryc[TN] = ridy * TN + (0 ... TN);
|
||||
@@ -138,11 +140,10 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
*pc = c;
|
||||
@checkc *pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
std::cout << res << std::endl;
|
||||
os << res;
|
||||
}
|
||||
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
std::cout << source << std::endl;
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -85,20 +85,6 @@ value *builder::create_ret_void() {
|
||||
return insert(return_inst::create(ctx_));
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tile-level control-flow instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//value *builder::create_mask(value *pred, const std::string &name){
|
||||
// return insert(mask_inst::create(pred, name));
|
||||
//}
|
||||
|
||||
//value *builder::create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name) {
|
||||
// return insert(psi_inst::create(mask_true, value_true, mask_false, value_false, name));
|
||||
//}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -264,14 +250,22 @@ DEFINE_FCMP_INSTR(ONE, llvm::FCmpInst::FCMP_ONE)
|
||||
// load/store instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *arg, const std::string &name){
|
||||
return insert(load_inst::create(arg, name));
|
||||
value *builder::create_load(value *ptr, const std::string &name){
|
||||
return insert(load_inst::create(ptr, name));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val, const std::string &name){
|
||||
return insert(store_inst::create(ptr, val, name));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, const std::string &name){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, name));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){
|
||||
return insert(masked_store_inst::create(ptr, val, mask, name));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tile instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -296,10 +290,6 @@ value *builder::create_downcast(value *arg, const std::string &name) {
|
||||
// built-in instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name) {
|
||||
return insert(get_global_range_inst::create(ctx_, axis, size, name));
|
||||
}
|
||||
|
||||
value *builder::create_get_range_id(unsigned axis, const std::string &name) {
|
||||
return insert(get_range_id_inst::create(ctx_, axis, name));
|
||||
}
|
||||
|
@@ -270,6 +270,7 @@ std::string cast_inst::repr_impl() const {
|
||||
}
|
||||
// TODO
|
||||
bool cast_inst::is_valid(op_t op, value *arg, type *ty) {
|
||||
assert(arg->get_type()->is_tile_ty() == ty->is_tile_ty());
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -348,34 +349,6 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v
|
||||
set_operand(2, cond);
|
||||
}
|
||||
|
||||
// mask_inst
|
||||
//mask_inst::mask_inst(value *pred, const std::string &name, instruction *next)
|
||||
// : instruction(pred->get_type(), 1, 2, name, next) {
|
||||
// set_operand(0, pred);
|
||||
//}
|
||||
|
||||
//mask_inst* mask_inst::create(value *pred, const std::string &name, instruction *next) {
|
||||
// return new mask_inst(pred, name, next);
|
||||
//}
|
||||
|
||||
//// merge_inst
|
||||
//psi_inst::psi_inst(value *mask_true, value *value_true,
|
||||
// value *mask_false, value *value_false,
|
||||
// const std::string &name, instruction *next)
|
||||
// : instruction(value_true->get_type(), 4, 1, name, next) {
|
||||
// set_operand(0, mask_true);
|
||||
// set_operand(1, value_true);
|
||||
// set_operand(2, mask_false);
|
||||
// set_operand(3, value_false);
|
||||
//}
|
||||
|
||||
//psi_inst* psi_inst::create(value *mask_true, value *value_true,
|
||||
// value *mask_false, value *value_false,
|
||||
// const std::string &name, instruction *next) {
|
||||
// return new psi_inst(mask_true, value_true, mask_false, value_false, name, next);
|
||||
//}
|
||||
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
@@ -440,6 +413,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load_inst/store_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// io_inst
|
||||
io_inst::io_inst(type *ty, unsigned num_ops, unsigned num_results, const std::string &name, instruction *next)
|
||||
: instruction(ty, num_ops, num_results, name, next)
|
||||
{ }
|
||||
|
||||
// load
|
||||
type *load_inst::get_pointee_type(type *ty) {
|
||||
type *scalar_ty = ty->get_scalar_ty();
|
||||
type *pointee_ty = scalar_ty->get_pointer_element_ty();
|
||||
@@ -448,43 +428,52 @@ type *load_inst::get_pointee_type(type *ty) {
|
||||
return pointee_ty;
|
||||
}
|
||||
|
||||
load_inst::load_inst(value *ptr, const std::string &name, instruction *next)
|
||||
: unary_inst(get_pointee_type(ptr->get_type()), ptr, name, next), mask_(nullptr){
|
||||
}
|
||||
|
||||
value *load_inst::get_mask() const {
|
||||
return mask_;
|
||||
}
|
||||
|
||||
value *load_inst::set_mask(value *mask) {
|
||||
mask_ = mask;
|
||||
return this;
|
||||
load_inst::load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), 1 + num_extra_ops, 1, name, next) {
|
||||
set_operand(0, ptr);
|
||||
}
|
||||
|
||||
load_inst* load_inst::create(value *ptr, const std::string &name, instruction *next) {
|
||||
return new load_inst(ptr, name, next);
|
||||
return new load_inst(ptr, 0, name, next);
|
||||
}
|
||||
|
||||
// masked load
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, 2, name, next) {
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_inst(ptr, mask, false_value, name, next);
|
||||
}
|
||||
|
||||
|
||||
// store
|
||||
store_inst::store_inst(value *ptr, value *v, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, 1, name, next), mask_(nullptr) {
|
||||
store_inst::store_inst(value *ptr, value *val, unsigned num_extra_ops,
|
||||
const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), 2 + num_extra_ops, 1, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, v);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
value *store_inst::get_mask() const {
|
||||
return mask_;
|
||||
store_inst* store_inst::create(value *ptr, value *val,
|
||||
const std::string &name, instruction *next) {
|
||||
return new store_inst(ptr, val, 0, name, next);
|
||||
}
|
||||
|
||||
value *store_inst::set_mask(value *mask) {
|
||||
mask_ = mask;
|
||||
return this;
|
||||
// masked store
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, val, 1, name, next) {
|
||||
set_operand(2, mask);
|
||||
}
|
||||
|
||||
store_inst* store_inst::create(value *ptr, value *v, const std::string &name, instruction *next) {
|
||||
return new store_inst(ptr, v, name, next);
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -636,19 +625,6 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value
|
||||
// builtin instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// get_global_range
|
||||
get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, 0, 1, name, next), axis_(axis) {
|
||||
|
||||
}
|
||||
|
||||
instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
|
||||
const std::string &name, instruction *next) {
|
||||
type *int_ty = type::get_int32_ty(ctx);
|
||||
type *tile_ty = tile_type::get(int_ty, {size});
|
||||
return new get_global_range_inst(tile_ty, axis, name, next);
|
||||
}
|
||||
|
||||
// get_range_id
|
||||
get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
|
@@ -35,12 +35,6 @@ void print(module &mod, std::ostream& os) {
|
||||
os << std::endl;
|
||||
for(ir::instruction *inst: block->get_inst_list()){
|
||||
os << " ";
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(inst))
|
||||
if(ir::value *mask = x->get_mask())
|
||||
os << "@" << get_name(mask, cnt++) << " ";
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(inst))
|
||||
if(ir::value *mask = x->get_mask())
|
||||
os << "@" << get_name(mask, cnt++) << " ";
|
||||
unsigned num_results = inst->get_num_results();
|
||||
for(unsigned i = 0; i < num_results; i++){
|
||||
os << get_name(inst->get_result(i), cnt++);
|
||||
|
@@ -151,7 +151,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
}
|
||||
else if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
value = explicit_cast(mod->get_builder(), value, ty);
|
||||
value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
|
||||
implicit_broadcast(mod, ty, value);
|
||||
}
|
||||
value->set_name(name);
|
||||
|
@@ -115,12 +115,6 @@ ir::value* alloc_const_expression::codegen(ir::module *mod) const {
|
||||
return res;
|
||||
}
|
||||
|
||||
// get_global_range
|
||||
ir::value* get_global_range_expression::codegen(ir::module *mod) const {
|
||||
ir::builder &builder = mod->get_builder();
|
||||
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
|
||||
}
|
||||
|
||||
// get_range_id
|
||||
ir::value* get_range_id_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_get_range_id(axis_->value());
|
||||
@@ -254,39 +248,24 @@ ir::value* cast_expression::codegen(ir::module *mod) const{
|
||||
}
|
||||
|
||||
/* Conditional expression */
|
||||
ir::value *conditional_expression::codegen(ir::module *mod) const{
|
||||
ir::value *conditional_expression::codegen(ir::module *mod) const {
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::value *mask = cond_->codegen(mod);
|
||||
ir::value *true_value = true_value_->codegen(mod);
|
||||
ir::value *false_value = false_value_->codegen(mod);
|
||||
bool is_float, is_ptr, is_int, is_signed;
|
||||
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||
implicit_broadcast(mod, mask, true_value);
|
||||
implicit_broadcast(mod, mask, false_value);
|
||||
if(ir::load_inst* load = dynamic_cast<ir::load_inst*>(true_value)){
|
||||
load->erase_from_parent();
|
||||
return builder.create_masked_load(load->get_pointer_operand(), mask, false_value);
|
||||
}
|
||||
if(ir::load_inst* load = dynamic_cast<ir::load_inst*>(false_value)){
|
||||
load->erase_from_parent();
|
||||
return builder.create_masked_load(load->get_pointer_operand(), mask, true_value);
|
||||
}
|
||||
throw std::runtime_error("not implemented");
|
||||
// ir::builder &builder = mod->get_builder();
|
||||
// ir::basic_block::inst_list_t &instructions = builder.get_insert_block()->get_inst_list();
|
||||
// ir::value *pred = cond_->codegen(mod);
|
||||
// ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
|
||||
// /* true value */
|
||||
// ir::value *true_mask = mask->get_result(0);
|
||||
// auto it_true_begin = instructions.end();
|
||||
// it_true_begin--;
|
||||
// ir::value *true_value = true_value_->codegen(mod);
|
||||
// implicit_broadcast(mod, pred, true_value);
|
||||
// it_true_begin++;
|
||||
// auto it_true_end = instructions.end();
|
||||
// for(auto it = it_true_begin; it != it_true_end; it++)
|
||||
//// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||
// (*it)->set_mask_pred(true_mask);
|
||||
// /* false value */
|
||||
// ir::value *false_mask = mask->get_result(1);
|
||||
// auto it_false_begin = instructions.end();
|
||||
// it_false_begin--;
|
||||
// ir::value *false_value = false_value_->codegen(mod);
|
||||
// implicit_broadcast(mod, pred, false_value);
|
||||
// bool is_float, is_ptr, is_int, is_signed;
|
||||
// implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||
// it_false_begin++;
|
||||
// auto it_false_end = instructions.end();
|
||||
// for(auto it = it_false_begin; it != it_false_end; it++)
|
||||
//// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||
// (*it)->set_mask_pred(false_mask);
|
||||
// /* psi */
|
||||
// ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
||||
// return result;
|
||||
}
|
||||
|
||||
/* Assignment expression */
|
||||
|
@@ -29,21 +29,35 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
/* Expression statement */
|
||||
ir::value* expression_statement::codegen(ir::module *mod) const{
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::value *expr = expr_->codegen(mod);
|
||||
if(pred_ == nullptr)
|
||||
return expr;
|
||||
ir::value *pred = pred_->codegen(mod);
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(expr))
|
||||
x->set_mask(pred);
|
||||
else if(auto *x = dynamic_cast<ir::store_inst*>(expr))
|
||||
x->set_mask(pred);
|
||||
else
|
||||
expr = builder.create_select(pred, expr, ir::undef_value::get(expr->get_type()));
|
||||
// get name if applicable
|
||||
std::string name = "";
|
||||
ir::value *current = nullptr;
|
||||
if(assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_))
|
||||
if(auto *named = dynamic_cast<named_expression*>(assignment)){
|
||||
std::string name = named->lvalue()->id()->name();
|
||||
mod->set_value(name, expr);
|
||||
if(const named_expression* named = dynamic_cast<const named_expression*>(assignment->lvalue())){
|
||||
name = named->id()->name();
|
||||
current = mod->get_value(name);
|
||||
}
|
||||
// lower expression
|
||||
ir::value *expr = expr_->codegen(mod);
|
||||
// modify expression if predicated
|
||||
if(pred_) {
|
||||
ir::value *pred = pred_->codegen(mod);
|
||||
if(!current)
|
||||
current = ir::undef_value::get(expr->get_type());
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(expr)){
|
||||
x->erase_from_parent();
|
||||
expr = builder.create_masked_load(x->get_pointer_operand(), pred, current);
|
||||
}
|
||||
else if(auto *x = dynamic_cast<ir::store_inst*>(expr)){
|
||||
x->erase_from_parent();
|
||||
expr =builder.create_masked_store(x->get_pointer_operand(), x->get_value_operand(), pred);
|
||||
}
|
||||
else
|
||||
expr = builder.create_select(pred, expr, current);
|
||||
}
|
||||
// update symbols table
|
||||
if(!name.empty())
|
||||
mod->set_value(name, expr);
|
||||
return expr;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user