[codegen] adding visitor
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
|
||||
@@ -24,6 +25,7 @@ namespace llvm{
|
||||
class Function;
|
||||
}
|
||||
|
||||
|
||||
// typedefs
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
@@ -145,6 +147,82 @@ private:
|
||||
};
|
||||
|
||||
|
||||
class generator: public ir::visitor {
|
||||
private:
|
||||
Type *type(ir::type *ty);
|
||||
|
||||
private:
|
||||
void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
|
||||
void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
Type *c_ty, Function *f_mul_add);
|
||||
|
||||
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
|
||||
void get_value(ir::value *x, const indices_t& idx);
|
||||
void set_value(ir::value *x, const indices_t& idx, Value* v);
|
||||
|
||||
public:
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
void visit_binary_operator(ir::binary_operator*);
|
||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||
|
||||
void visit_icmp_inst(ir::icmp_inst*);
|
||||
void visit_fcmp_inst(ir::fcmp_inst*);
|
||||
void visit_cast_inst(ir::cast_inst*);
|
||||
|
||||
void visit_return_inst(ir::return_inst*);
|
||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
|
||||
|
||||
|
||||
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
|
||||
void visit_masked_load_inst(ir::masked_load_inst*);
|
||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||
|
||||
void visit_retile_inst(ir::retile_inst*);
|
||||
void visit_reshape_inst(ir::reshape_inst*);
|
||||
void visit_splat_inst(ir::splat_inst*);
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_program_inst(ir::get_num_program_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
|
||||
void visit_atomic_add_inst(ir::atomic_add_inst*);
|
||||
void visit_dot_inst(ir::dot_inst*);
|
||||
void visit_trans_inst(ir::trans_inst*);
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
|
||||
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
void visit_barrier_inst(ir::barrier_inst*);
|
||||
void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Function *fn_;
|
||||
Builder *builder_;
|
||||
|
||||
std::map<ir::value *, Value *> vmap_;
|
||||
std::map<ir::value *, tile *> tmap_;
|
||||
target *tgt_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
analysis::allocation *alloc_;
|
||||
Value *sh_mem_ptr_;
|
||||
Value *offset_a_i_, *offset_a_k_;
|
||||
Value *offset_b_j_, *offset_b_k_;
|
||||
unsigned num_packs_0_, num_packs_1_;
|
||||
unsigned pack_size_0_, pack_size_1_;
|
||||
unsigned num_warps_;
|
||||
};
|
||||
|
||||
// Selection pass
|
||||
class selection{
|
||||
typedef std::map<ir::value *, Value *> vmap_t;
|
||||
@@ -178,7 +256,7 @@ private:
|
||||
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
|
||||
|
||||
// lower scalar instruction
|
||||
void lower_instruction(ir::instruction *src, Builder &builder);
|
||||
void lower_value(ir::value *src, Builder &builder, std::set<ir::value*>& seen);
|
||||
// lower tile instruction
|
||||
void lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
|
@@ -10,10 +10,13 @@
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "triton/ir/visitor.h"
|
||||
|
||||
// Expands to a clone_impl() member for instruction class `name` that returns
// a newly allocated copy of *this.
#define _TRITON_DEFINE_CLONE(name) \
  ir::instruction* clone_impl() const { return new name(*this); }
|
||||
|
||||
// Expands to an accept() member for instruction class `name` that forwards
// `this` to the visitor's visit_<name>() callback.
#define _TRITON_DEFINE_ACCEPT(name) \
  void accept(visitor* vis) { vis->visit_##name(this); }
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -23,6 +26,7 @@ class constant;
|
||||
class make_range;
|
||||
class basic_block;
|
||||
class context;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// instruction classes
|
||||
@@ -99,6 +103,7 @@ public:
|
||||
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(phi_node)
|
||||
_TRITON_DEFINE_ACCEPT(phi_node)
|
||||
|
||||
private:
|
||||
unsigned num_reserved_;
|
||||
@@ -148,6 +153,7 @@ public:
|
||||
static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(binary_operator)
|
||||
_TRITON_DEFINE_ACCEPT(binary_operator)
|
||||
|
||||
public:
|
||||
binary_op_t op_;
|
||||
@@ -189,6 +195,7 @@ public:
|
||||
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(icmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(icmp_inst)
|
||||
};
|
||||
|
||||
class fcmp_inst: public cmp_inst {
|
||||
@@ -199,6 +206,7 @@ public:
|
||||
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(fcmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(fcmp_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -236,13 +244,15 @@ public:
|
||||
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_ACCEPT(cast_inst)
|
||||
|
||||
private:
|
||||
cast_op_t op_;
|
||||
};
|
||||
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
|
||||
class name : public cast_inst { \
|
||||
_TRITON_DEFINE_CLONE(name); \
|
||||
_TRITON_DEFINE_CLONE(name) \
|
||||
friend class cast_inst; \
|
||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||
: cast_inst(ty, id, v, name, next, op){ } \
|
||||
@@ -287,6 +297,7 @@ public:
|
||||
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(return_inst)
|
||||
_TRITON_DEFINE_ACCEPT(return_inst)
|
||||
};
|
||||
|
||||
// base branch instruction
|
||||
@@ -315,6 +326,7 @@ public:
|
||||
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
|
||||
value *get_cond() { return get_operand(2); }
|
||||
_TRITON_DEFINE_CLONE(cond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cond_branch_inst)
|
||||
};
|
||||
|
||||
// unconditional branch
|
||||
@@ -326,6 +338,7 @@ private:
|
||||
public:
|
||||
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
||||
_TRITON_DEFINE_CLONE(uncond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(uncond_branch_inst)
|
||||
};
|
||||
|
||||
|
||||
@@ -354,6 +367,7 @@ public:
|
||||
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(getelementptr_inst)
|
||||
_TRITON_DEFINE_ACCEPT(getelementptr_inst)
|
||||
|
||||
private:
|
||||
type *source_elt_ty;
|
||||
@@ -395,6 +409,7 @@ public:
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_load_inst)
|
||||
};
|
||||
|
||||
// masked load
|
||||
@@ -413,6 +428,7 @@ public:
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_inst)
|
||||
};
|
||||
|
||||
// store
|
||||
@@ -437,6 +453,7 @@ public:
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_store_inst)
|
||||
};
|
||||
|
||||
class masked_store_inst: public store_inst{
|
||||
@@ -453,6 +470,7 @@ public:
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_store_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -477,6 +495,7 @@ public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(reshape_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reshape_inst)
|
||||
};
|
||||
|
||||
// splat
|
||||
@@ -490,6 +509,7 @@ public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(splat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(splat_inst)
|
||||
};
|
||||
|
||||
// broadcast
|
||||
@@ -503,6 +523,7 @@ public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(broadcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(broadcast_inst)
|
||||
};
|
||||
|
||||
|
||||
@@ -516,6 +537,7 @@ private:
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(downcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(downcast_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -536,6 +558,7 @@ public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_program_id_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_program_id_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
@@ -550,6 +573,7 @@ public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_num_program_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_num_program_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
@@ -560,6 +584,7 @@ private:
|
||||
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "atomic_cas"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_cas_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_cas_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -570,6 +595,7 @@ private:
|
||||
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_exch"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_exch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_exch_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -580,6 +606,7 @@ private:
|
||||
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_add"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_add_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -600,6 +627,7 @@ public:
|
||||
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(dot_inst)
|
||||
_TRITON_DEFINE_ACCEPT(dot_inst)
|
||||
};
|
||||
|
||||
//class outer_inst: public builtin_inst {
|
||||
@@ -622,6 +650,7 @@ public:
|
||||
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||
const std::vector<int> get_perm() const;
|
||||
_TRITON_DEFINE_CLONE(trans_inst)
|
||||
_TRITON_DEFINE_ACCEPT(trans_inst)
|
||||
|
||||
private:
|
||||
std::vector<int> perm_;
|
||||
@@ -634,6 +663,7 @@ private:
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(sqrt_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sqrt_inst)
|
||||
};
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
@@ -644,6 +674,7 @@ private:
|
||||
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "reduce"; }
|
||||
_TRITON_DEFINE_CLONE(reduce_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reduce_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -658,6 +689,7 @@ private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "select"; }
|
||||
_TRITON_DEFINE_CLONE(select_inst)
|
||||
_TRITON_DEFINE_ACCEPT(select_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -676,6 +708,7 @@ public:
|
||||
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
|
||||
};
|
||||
|
||||
class copy_from_shared_inst: public unary_inst{
|
||||
@@ -687,6 +720,7 @@ public:
|
||||
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
|
||||
};
|
||||
|
||||
class barrier_inst: public instruction{
|
||||
@@ -694,6 +728,7 @@ private:
|
||||
barrier_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "barrier"; }
|
||||
_TRITON_DEFINE_CLONE(barrier_inst)
|
||||
_TRITON_DEFINE_ACCEPT(barrier_inst)
|
||||
|
||||
public:
|
||||
static barrier_inst* create(context &ctx, const std::string &name = "",
|
||||
@@ -708,6 +743,7 @@ private:
|
||||
make_range_dyn(type *ty, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
||||
_TRITON_DEFINE_CLONE(make_range_dyn)
|
||||
_TRITON_DEFINE_ACCEPT(make_range_dyn)
|
||||
|
||||
public:
|
||||
static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -732,6 +768,7 @@ class make_range: public instruction{
|
||||
make_range(type *ty, constant_int* first, constant_int* last);
|
||||
std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
|
||||
_TRITON_DEFINE_CLONE(make_range)
|
||||
_TRITON_DEFINE_ACCEPT(make_range)
|
||||
|
||||
public:
|
||||
static make_range *create(constant_int *first, constant_int *last);
|
||||
|
116
include/triton/ir/visitor.h
Normal file
116
include/triton/ir/visitor.h
Normal file
@@ -0,0 +1,116 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_VISITOR_H_
|
||||
#define _TRITON_IR_VISITOR_H_
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
// Forward declarations of every instruction class the visitor can receive.
class phi_node;
class binary_operator;
class getelementptr_inst;

class icmp_inst;
class fcmp_inst;
// cast_inst is the type passed to visit_cast_inst; the classes that follow
// are the individual cast instructions.
class cast_inst;
class trunc_inst;
class z_ext_inst;
class s_ext_inst;
class fp_trunc_inst;
class fp_ext_inst;
class ui_to_fp_inst;
class si_to_fp_inst;
class fp_to_ui_inst;
class fp_to_si_inst;
class ptr_to_int_inst;
class int_to_ptr_inst;
class bit_cast_inst;
class addr_space_cast_inst;

class return_inst;
class cond_branch_inst;
class uncond_branch_inst;

class unmasked_load_inst;
class masked_load_inst;
class unmasked_store_inst;
class masked_store_inst;

class retile_inst;
class reshape_inst;
class splat_inst;
class broadcast_inst;
class downcast_inst;

class get_program_id_inst;
class get_num_program_inst;
class atomic_cas_inst;
class atomic_exch_inst;
class atomic_add_inst;
class dot_inst;
class trans_inst;
class sqrt_inst;
class reduce_inst;
class select_inst;

class copy_to_shared_inst;
class copy_from_shared_inst;
class barrier_inst;
class make_range_dyn;
class make_range_sta;
class make_range;

// Abstract visitor over IR instructions: one pure-virtual callback per
// instruction kind, each receiving a pointer to the visited instruction.
class visitor {
public:
  virtual ~visitor() {}

  virtual void visit_phi_node(phi_node*) = 0;
  virtual void visit_binary_operator(binary_operator*) = 0;
  virtual void visit_getelementptr_inst(getelementptr_inst*) = 0;

  virtual void visit_icmp_inst(icmp_inst*) = 0;
  virtual void visit_fcmp_inst(fcmp_inst*) = 0;
  // Takes the cast_inst base class: that is the pointer type an accept()
  // generated for cast_inst hands to this callback.
  virtual void visit_cast_inst(cast_inst*) = 0;

  virtual void visit_return_inst(return_inst*) = 0;
  virtual void visit_cond_branch_inst(cond_branch_inst*) = 0;
  virtual void visit_uncond_branch_inst(uncond_branch_inst*) = 0;

  virtual void visit_unmasked_load_inst(unmasked_load_inst*) = 0;
  virtual void visit_masked_load_inst(masked_load_inst*) = 0;
  virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
  virtual void visit_masked_store_inst(masked_store_inst*) = 0;

  virtual void visit_retile_inst(retile_inst*) = 0;
  virtual void visit_reshape_inst(reshape_inst*) = 0;
  virtual void visit_splat_inst(splat_inst*) = 0;
  virtual void visit_broadcast_inst(broadcast_inst*) = 0;
  virtual void visit_downcast_inst(downcast_inst*) = 0;

  virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
  virtual void visit_get_num_program_inst(get_num_program_inst*) = 0;
  virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
  virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
  virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
  virtual void visit_dot_inst(dot_inst*) = 0;
  virtual void visit_trans_inst(trans_inst*) = 0;
  virtual void visit_sqrt_inst(sqrt_inst*) = 0;
  virtual void visit_reduce_inst(reduce_inst*) = 0;
  virtual void visit_select_inst(select_inst*) = 0;

  virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
  virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
  virtual void visit_barrier_inst(barrier_inst*) = 0;
  virtual void visit_make_range_dyn(make_range_dyn*) = 0;
  virtual void visit_make_range_sta(make_range_sta*) = 0;
  virtual void visit_make_range(make_range*) = 0;
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -759,24 +759,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
}
|
||||
}
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v)->order, axes, builder, false);
|
||||
bool is_inserted = tmap_.insert({v, T}).second;
|
||||
// constant range
|
||||
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
T->set_value(idx, idx[0]);
|
||||
});
|
||||
}
|
||||
if(is_inserted && dynamic_cast<ir::make_range_sta*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
Value *res = bin_add->getOperand(1);
|
||||
assert(isa<Constant>(res));
|
||||
T->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
tmap_.insert({v, T});
|
||||
}
|
||||
|
||||
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
@@ -1408,14 +1391,56 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
lower_elementwise(ins, ctx, fn, builder);
|
||||
}
|
||||
|
||||
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||
if(src->has_tile_result_or_op()) {
|
||||
lower_tile_instruction(src, builder);
|
||||
void selection::lower_value(ir::value *src, IRBuilder<> &builder, std::set<ir::value*>& seen) {
|
||||
if(!seen.insert(src).second)
|
||||
return;
|
||||
|
||||
auto *inst = dynamic_cast<ir::instruction*>(src);
|
||||
if(inst && !dynamic_cast<ir::phi_node*>(src))
|
||||
for(ir::value *op: inst->ops())
|
||||
lower_value(op, builder, seen);
|
||||
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
auto *phi = dynamic_cast<ir::phi_node*>(src);
|
||||
bool phi_inserted = phi && !current->empty();
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
builder.SetInsertPoint(&*current->getFirstNonPHI());
|
||||
|
||||
|
||||
if(dynamic_cast<ir::make_range*>(src)){
|
||||
distributed_tile *T = (distributed_tile *)tmap_.at(src);
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
T->set_value(idx, idx[0]);
|
||||
});
|
||||
}
|
||||
else {
|
||||
Instruction *i = (Instruction*)llvm_value(src, builder);
|
||||
else if(dynamic_cast<ir::make_range_sta*>(src)){
|
||||
distributed_tile *T = (distributed_tile *)tmap_.at(src);
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
Value *res = bin_add->getOperand(1);
|
||||
assert(isa<Constant>(res));
|
||||
T->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
else if(inst && inst->has_tile_result_or_op()) {
|
||||
lower_tile_instruction(inst, builder);
|
||||
}
|
||||
else if(inst){
|
||||
Instruction *i = (Instruction*)llvm_value(inst, builder);
|
||||
vmap_[src] = i;
|
||||
}
|
||||
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
builder.SetInsertPoint(current);
|
||||
|
||||
// if(dynamic_cast<ir::phi_node*>(src))
|
||||
// for(ir::value *op: inst->ops())
|
||||
// lower_value(op, builder, seen);
|
||||
|
||||
|
||||
}
|
||||
|
||||
/* ----------------------------
|
||||
@@ -1508,29 +1533,29 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
|
||||
|
||||
// iterate over functions
|
||||
std::set<ir::value*> seen;
|
||||
|
||||
for(ir::function *fn: src.get_function_list()) {
|
||||
|
||||
// create LLVM function
|
||||
llvm_fn(fn, dst_builder, dst);
|
||||
|
||||
// allocate shared memory
|
||||
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
|
||||
|
||||
// initialize layouts
|
||||
init_layouts(fn, dst_builder, sh_mem_ptr_);
|
||||
|
||||
// generate LLVM-IR code
|
||||
std::map<ir::basic_block*, BasicBlock*> last_block;
|
||||
for(ir::basic_block *block: fn->blocks()) {
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[block];
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
BasicBlock *current = dst_builder.GetInsertBlock();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i)) && !current->empty();
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
dst_builder.SetInsertPoint(&*current->getFirstNonPHI());
|
||||
lower_instruction(i, dst_builder);
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
dst_builder.SetInsertPoint(current);
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
lower_value(i, dst_builder, seen);
|
||||
last_block[block] = dst_builder.GetInsertBlock();
|
||||
}
|
||||
}
|
||||
|
||||
// finalize double-buffering
|
||||
for(const auto& x: layouts_->get_all()) {
|
||||
if(x.second->double_buffer) {
|
||||
@@ -1588,5 +1613,646 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
}
|
||||
|
||||
|
||||
/* -----------------------------------------------------
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
*
|
||||
* ------------------------------------------------------ */
|
||||
|
||||
|
||||
|
||||
// Creates one LLVM PHI node (scalar element type, room for every operand of
// `phi`) per index of `phi` and records it as the value at that index.
void generator::visit_phi_node(ir::phi_node* phi) {
  Type *elt_ty = type(phi->get_type()->get_scalar_ty());
  unsigned n_incoming = phi->get_num_operands();
  for_each(phi, [&](indices_t idx){
    Value *node = builder_->Insert(PHINode::Create(elt_ty, n_incoming));
    set_value(phi, idx, node);
  });
}
|
||||
|
||||
// Emits, for every index of `binop`, the LLVM binary operator applied to the
// values of operands 0 and 1 at that index.
void generator::visit_binary_operator(ir::binary_operator*binop) {
  for_each(binop, [&](indices_t idx){
    Value *left = get_value(binop->get_operand(0), idx);
    Value *right = get_value(binop->get_operand(1), idx);
    Value *result = builder_->Insert(BinaryOperator::Create(llvm_op(binop->get_op()), left, right));
    set_value(binop, idx, result);
  });
}
|
||||
|
||||
// Emits, for every index of `gep`, an in-bounds LLVM GEP on the value of
// operand 0 using the values of the GEP's index operands at that index.
void generator::visit_getelementptr_inst(ir::getelementptr_inst* gep) {
  for_each(gep, [&](indices_t idx){
    Value *base = get_value(gep->get_operand(0), idx);
    // gather the LLVM value of each index operand
    std::vector<Value*> offsets;
    for(auto it = gep->idx_begin(); it != gep->idx_end(); ++it){
      ir::value *op = *it;
      offsets.push_back(get_value(op, idx));
    }
    Type *elt_ty = type(gep->get_source_elt_ty()->get_scalar_ty());
    Value *addr = builder_->Insert(GetElementPtrInst::CreateInBounds(elt_ty, base, offsets));
    set_value(gep, idx, addr);
  });
}
|
||||
|
||||
// Emits, for every index of `icmp`, an LLVM integer comparison between the
// values of operands 0 and 1 using the instruction's predicate.
void generator::visit_icmp_inst(ir::icmp_inst* icmp) {
  for_each(icmp, [&](indices_t idx){
    ir::cmp_pred_t predicate = icmp->get_pred();
    Value *left = get_value(icmp->get_operand(0), idx);
    Value *right = get_value(icmp->get_operand(1), idx);
    Value *cmp = builder_->Insert(CmpInst::Create(Instruction::ICmp, llvm_pred(predicate), left, right));
    set_value(icmp, idx, cmp);
  });
}
|
||||
|
||||
// Emits, for every index of `fcmp`, an LLVM floating-point comparison between
// the values of operands 0 and 1 using the instruction's predicate.
void generator::visit_fcmp_inst(ir::fcmp_inst* fcmp) {
  for_each(fcmp, [&](indices_t idx){
    ir::cmp_pred_t predicate = fcmp->get_pred();
    Value *left = get_value(fcmp->get_operand(0), idx);
    Value *right = get_value(fcmp->get_operand(1), idx);
    Value *cmp = builder_->Insert(FCmpInst::Create(Instruction::FCmp, llvm_pred(predicate), left, right));
    set_value(fcmp, idx, cmp);
  });
}
|
||||
|
||||
// Emits, for every index of `cast`, an LLVM cast of operand 0's value to the
// scalar type of the instruction's result.
void generator::visit_cast_inst(ir::cast_inst* cast) {
  for_each(cast, [&](indices_t idx){
    Value *source = get_value(cast->get_operand(0), idx);
    Type *target_ty = type(cast->get_type()->get_scalar_ty());
    Value *converted = builder_->Insert(CastInst::Create(llvm_op(cast->get_op()), source, target_ty));
    set_value(cast, idx, converted);
  });
}
|
||||
|
||||
// Emits an LLVM `ret`. When the IR instruction carries a return value, the
// LLVM value already recorded for it in vmap_ is returned; otherwise the
// return is void.
void generator::visit_return_inst(ir::return_inst* rr) {
  ir::value *ret_val = rr->get_return_value();
  // ReturnInst::Create expects an LLVM value, so the IR value has to be
  // translated through vmap_ instead of being passed as-is
  Value *llvm_ret = ret_val ? vmap_.at(ret_val) : nullptr;
  builder_->Insert(ReturnInst::Create(*ctx_, llvm_ret));
}
|
||||
|
||||
// Emits an LLVM conditional branch to the blocks recorded in vmap_ for the
// true and false destinations, on the LLVM value recorded for the condition.
void generator::visit_cond_branch_inst(ir::cond_branch_inst* br) {
  // vmap_ stores Value*; the destination entries are read back as basic
  // blocks, which requires an explicit downcast
  BasicBlock *true_dest = static_cast<BasicBlock*>(vmap_.at(br->get_true_dest()));
  BasicBlock *false_dest = static_cast<BasicBlock*>(vmap_.at(br->get_false_dest()));
  Value *cond = vmap_.at(br->get_cond());
  builder_->Insert(BranchInst::Create(true_dest, false_dest, cond));
}
|
||||
|
||||
// Emits an LLVM unconditional branch to the block recorded in vmap_ for the
// IR destination.
void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
  // vmap_ stores Value*; the destination entry is read back as a basic block,
  // which requires an explicit downcast
  BasicBlock *dest = static_cast<BasicBlock*>(vmap_.at(br->get_dest()));
  builder_->Insert(BranchInst::Create(dest));
}
|
||||
|
||||
|
||||
// Lowers an unmasked tile load: elements are loaded in vector packets and
// each result element is then extracted from its packet.
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
  distributed_tile* out_tile = static_cast<distributed_tile*>(tmap_.at(x));
  // packet width: contiguous elements along the first axis of the pointer's
  // layout order, capped by the alignment reported for that axis
  ir::value *ptr_op = x->get_pointer_operand();
  size_t lead = layouts_->get(ptr_op)->order[0];
  unsigned align = alignment_->get(ptr_op, lead);
  unsigned width = std::min<unsigned>(out_tile->axis(lead).contiguous, align);
  distributed_tile *ptr_tile = static_cast<distributed_tile*>(tmap_.at(ptr_op));

  // one vector load per packet, issued at the packet's first element
  std::map<unsigned, Value*> packets;
  out_tile->for_each([&](indices_t idx){
    unsigned linear = out_tile->get_linear_index(idx);
    if(linear % width != 0)
      return;
    Value *addr = ptr_tile->get_value(idx);
    addr = builder_->CreateBitCast(addr, PointerType::get(VectorType::get(out_tile->get_ty(), width),
                                                          addr->getType()->getPointerAddressSpace()));
    packets[linear / width] = builder_->CreateLoad(addr);
  });

  // scatter the packet lanes back to the result tile
  out_tile->for_each([&](indices_t idx){
    unsigned linear = out_tile->get_linear_index(idx);
    Value *packet = packets.at(linear / width);
    out_tile->set_value(idx, builder_->CreateExtractElement(packet, linear % width));
  });
}
|
||||
|
||||
// Lowers a masked tile load. Elements are loaded in vector packets; each
// packet load is placed in its own "mask_then" block that is only entered
// when the mask value at the packet's first element is true. Result elements
// are then extracted from the packets.
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
  // find vector size
  distributed_tile* result = (distributed_tile*)tmap_.at(x);
  ir::value *ptr = x->get_pointer_operand();
  size_t ld = layouts_->get(ptr)->order[0];
  unsigned alignment = alignment_->get(ptr, ld);
  unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
  distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
  distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
  // The code below tests `false_values` for null, so the tile is only looked
  // up when the instruction actually has a false-value operand (tmap_.at
  // would throw on a missing key).
  distributed_tile *false_values = nullptr;
  if(ir::value *false_op = x->get_false_value_operand())
    false_values = (distributed_tile*)tmap_.at(false_op);

  std::map<unsigned, Value*> packets;
  result->for_each([&](indices_t idx){
    unsigned linear = result->get_linear_index(idx);
    unsigned id = linear / vector_size;
    // one guarded vector load per packet, at the packet's first element
    if(linear % vector_size == 0) {
      Value *ptr = pointers->get_value(idx);
      ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
                                                          ptr->getType()->getPointerAddressSpace()));
      Value *mask = masks->get_value(idx);
      BasicBlock *current_bb = builder_->GetInsertBlock();
      // BasicBlock::Create takes a mutable parent function
      Function *parent = current_bb->getParent();
      BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
      BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
      builder_->CreateCondBr(mask, mask_then_bb, mask_done_bb);
      // load only when the mask is set
      builder_->SetInsertPoint(mask_then_bb);
      Value *result_then = builder_->CreateLoad(ptr);
      builder_->CreateBr(mask_done_bb);
      builder_->SetInsertPoint(mask_done_bb);
      Value *current_result = nullptr;
      if(false_values){
        // merge the loaded packet with the value used when the mask is false
        current_result = builder_->CreatePHI(result_then->getType(), 2);
        ((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
        Value *result_false = false_values->get_value(idx);
        // NOTE(review): for vector packets the incoming value is a splat of
        // undef, not of the false value itself -- confirm this is intended.
        if(result_then->getType()->isVectorTy())
          result_false = builder_->CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType()));
        ((PHINode*)current_result)->addIncoming(result_false, current_bb);
      }
      else
        // NOTE(review): result_then is created in mask_then_bb but used from
        // mask_done_bb without a PHI -- confirm this path is valid.
        current_result = result_then;

      packets[id] = current_result;
    }
  });
  // extract result element
  result->for_each([&](indices_t idx){
    unsigned linear = result->get_linear_index(idx);
    unsigned id = linear / vector_size;
    result->set_value(idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
  });
}
|
||||
|
||||
// Lowers an unmasked store: for each index visited over the pointer operand,
// fetches the pointer and the value at that index and emits one store.
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) {
  ir::value *ptr_op = st->get_pointer_operand();
  for_each(ptr_op, [&](indices_t idx){
    Value *dst = get_value(st->get_pointer_operand(), idx);
    Value *src = get_value(st->get_value_operand(), idx);
    builder_->CreateStore(src, dst);
  });
}
|
||||
|
||||
// Lowers a masked store. For every index of the pointer tile, a conditional
// branch on the matching mask element guards a "mask_then" block that performs
// the store; control then falls through to a "mask_done" block, which becomes
// the new insertion point.
void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
  distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand());
  distributed_tile* scalars = (distributed_tile*)tmap_.at(st->get_value_operand());
  distributed_tile* preds = (distributed_tile*)tmap_.at(st->get_mask_operand());
  ptrs->for_each([&](indices_t idx){
    Value *scalar = scalars->get_value(idx);
    Value *ptr = ptrs->get_value(idx);
    Value *pred = preds->get_value(idx);
    // BasicBlock::Create expects a mutable Function* as the parent, so the
    // enclosing function must not be held through a pointer-to-const.
    Function *parent = builder_->GetInsertBlock()->getParent();
    BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
    BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
    builder_->CreateCondBr(pred, mask_then_bb, mask_done_bb);
    // guarded store
    builder_->SetInsertPoint(mask_then_bb);
    builder_->CreateStore(scalar, ptr);
    builder_->CreateBr(mask_done_bb);
    // subsequent code is emitted after the guarded region
    builder_->SetInsertPoint(mask_done_bb);
    // NOTE(review): an earlier, disabled variant emitted a predicated inline
    // PTX store ("@$0 st.global.b32 [$1 + offset], $2;", constraints "b,l,f")
    // instead of explicit control flow; it is intentionally not reproduced here.
  });
}
|
||||
|
||||
|
||||
// Lowers a reshape: each output element is taken from the input element whose
// ordered indices correspond to the output element's linear index.
void generator::visit_reshape_inst(ir::reshape_inst* reshape) {
  distributed_tile *dst = (distributed_tile*)tmap_.at(reshape);
  ir::value *src_val = reshape->get_operand(0);
  distributed_tile *src = (distributed_tile*)tmap_.at(src_val);
  for_each(reshape, [&](indices_t out_idx){
    const unsigned linear = dst->get_linear_index(out_idx);
    indices_t src_idx = src->get_ordered_indices(linear);
    Value *elem = src->get_value(src_idx);
    dst->set_value(out_idx, elem);
  });
}
|
||||
|
||||
// Lowers a splat: the operand's value (looked up with an empty index list) is
// assigned to every index of the result.
void generator::visit_splat_inst(ir::splat_inst* splat) {
  Value *scalar = get_value(splat->get_operand(0), {});
  auto assign = [&](indices_t idx){
    set_value(splat, idx, scalar);
  };
  for_each(splat, assign);
}
|
||||
|
||||
// Lowers a broadcast: each output element reads the input element at the same
// indices, except that every axis whose input extent is 1 is pinned to index 0.
void generator::visit_broadcast_inst(ir::broadcast_inst* bcast) {
  distributed_tile *dst = (distributed_tile*)tmap_.at(bcast);
  ir::value *src_val = bcast->get_operand(0);
  const auto& src_shapes = src_val->get_type()->get_tile_shapes();
  distributed_tile *src = (distributed_tile*)tmap_.at(src_val);
  dst->for_each([&](indices_t out_idx){
    indices_t src_idx = out_idx;
    size_t axis = 0;
    for(auto &component : src_idx){
      // size-1 input axes are the ones being broadcast
      if(src_shapes[axis] == 1)
        component = builder_->getInt32(0);
      ++axis;
    }
    dst->set_value(out_idx, src->get_value(src_idx));
  });
}
|
||||
|
||||
// Lowers a downcast: the scalar result is the operand tile's value at index {0}.
void generator::visit_downcast_inst(ir::downcast_inst* x) {
  ir::value *src = x->get_operand(0);
  Value *first = tmap_[src]->get_value({builder_->getInt32(0)});
  vmap_[x] = first;
}
|
||||
|
||||
// Lowers get_program_id: asks the target for the block id along the
// instruction's axis and records it as the instruction's scalar value.
void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) {
  // getModule() yields a Module*; the other target hooks used in this file
  // (get_local_id, add_barrier, add_memfence) are all called with that pointer.
  Module *module = builder_->GetInsertBlock()->getModule();
  Value *ret = tgt_->get_block_id(module, *builder_, pid->get_axis());
  vmap_[pid] = ret;
}
|
||||
|
||||
// Lowers get_num_program: asks the target for the number of blocks along the
// instruction's axis and records it as the instruction's scalar value.
void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) {
  // getModule() yields a Module*; the other target hooks used in this file
  // (get_local_id, add_barrier, add_memfence) are all called with that pointer.
  Module *module = builder_->GetInsertBlock()->getModule();
  Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis());
  vmap_[np] = ret;
}
|
||||
|
||||
// Lowers an atomic compare-and-swap.
// Only the thread whose local id on axis 0 is zero executes the cmpxchg; it
// writes the old value to a scratch slot at offset alloc_->offset(cas) from
// sh_mem_ptr_. After a memfence + barrier, the slot is loaded back and that
// load is recorded as the value of `cas`.
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
  BasicBlock *current = builder_->GetInsertBlock();
  Module *module = current->getModule();
  Value *tid = tgt_->get_local_id(module, *builder_, 0);
  Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
  // scratch slot used to publish the result, viewed as an i32 pointer
  Value *ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(cas)));
  ptr = builder_->CreateBitCast(ptr, PointerType::get(builder_->getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
  tgt_->add_memfence(module, *builder_);
  tgt_->add_barrier(module, *builder_);
  builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
  // thread 0: perform the compare-and-swap and publish the old value
  builder_->SetInsertPoint(tid_0_bb);
  Value *cas_ptr = vmap_.at(cas->get_operand(0));
  Value *cas_cmp = vmap_.at(cas->get_operand(1));
  Value *cas_val = vmap_.at(cas->get_operand(2));
  Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
  // cmpxchg yields a {value, success} pair; keep the value
  old = builder_->CreateExtractValue(old, {0});
  builder_->CreateStore(old, ptr);
  builder_->CreateBr(tid_0_done_bb);
  // all threads: synchronize, then read the published value
  builder_->SetInsertPoint(tid_0_done_bb);
  tgt_->add_memfence(module, *builder_);
  tgt_->add_barrier(module, *builder_);
  Value *res = builder_->CreateLoad(ptr);
  // A visitor method returns void, so the result is recorded in the scalar
  // value map rather than returned.
  vmap_[cas] = res;
}
|
||||
|
||||
// Lowers an atomic exchange.
// Only the thread whose local id on axis 0 is zero executes the atomic xchg;
// a memfence + barrier pair is emitted both before the branch and after the
// join block.
void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
  BasicBlock *current = builder_->GetInsertBlock();
  Module *module = current->getModule();
  Value *rmw_ptr = vmap_.at(xchg->get_operand(0));
  Value *rmw_val = vmap_.at(xchg->get_operand(1));
  Value *tid = tgt_->get_local_id(module, *builder_, 0);
  Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
  tgt_->add_memfence(module, *builder_);
  tgt_->add_barrier(module, *builder_);
  builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
  // thread 0: perform the exchange
  builder_->SetInsertPoint(tid_0_bb);
  Value *res = builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System);
  builder_->CreateBr(tid_0_done_bb);
  // all threads: synchronize
  builder_->SetInsertPoint(tid_0_done_bb);
  tgt_->add_memfence(module, *builder_);
  tgt_->add_barrier(module, *builder_);
  // A visitor method returns void, so the result is recorded in the scalar
  // value map rather than returned.
  // NOTE(review): `res` is defined inside the tid_0 block, which does not
  // dominate tid_0_done; confirm that no later instruction consumes it.
  vmap_[xchg] = res;
}
|
||||
|
||||
// Atomic add has no lowering here: visiting one always throws.
void generator::visit_atomic_add_inst(ir::atomic_add_inst* /*unused*/) {
  throw std::runtime_error("unsupported");
}
|
||||
|
||||
// Lowers a dot product through the "mma.sync.aligned.m8n8k4" inline-asm
// instruction (fp16 operands, fp32 accumulators).
//   dot : the dot instruction being lowered
//   TC  : result tile, written at the end
//   TA  : tile of operand 0, read in vectors of 4*pack_size_0_
//   TB  : tile of operand 1, read in vectors of 4*pack_size_1_
//   TD  : tile providing the initial accumulator values
//   NK  : extent of the reduction axis; consumed 4 at a time
// Accumulators are grouped by the trailing indices idx[2:] of TC (presumably
// the batch dimensions -- TODO confirm), updated in place by successive mma
// calls and finally written back to TC in TC's iteration order.
void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
  const auto& shapes = dot->get_type()->get_tile_shapes();

  TA->set_vector_size(4*pack_size_0_);
  TB->set_vector_size(4*pack_size_1_);
  TA->set_return_mode(true);
  TB->set_return_mode(true);

  // accumulators, keyed by the indices past the first two
  std::map<std::vector<Value*>, std::vector<Value*>> fcs;

  TC->for_each([&](indices_t idx){
    std::vector<Value*> key(idx.size() - 2);
    std::copy(idx.begin() + 2, idx.end(), key.begin());
    fcs[key].push_back(TD->get_value(idx));
  });

  // signature of the mma call: 4 x <2 x half> inputs, 8 x float accumulators in/out
  Type *fp32_ty = builder_->getFloatTy();
  Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2);
  // the LLVM context is held by the generator as the pointer member ctx_
  Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
  FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);

  Value *offset_a_i = offset_a_i_;
  Value *offset_a_k = offset_a_k_;
  Value *offset_b_j = offset_b_j_;
  Value *offset_b_k = offset_b_k_;

  Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);

  auto ord_a = layouts_->get(dot->get_operand(0))->order;
  auto ord_b = layouts_->get(dot->get_operand(1))->order;

  // effective row/col layout of each operand, accounting for transposition
  bool is_a_trans = is_trans(dot->get_operand(0));
  bool is_b_trans = is_trans(dot->get_operand(1));
  bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1);
  bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1);

  // for these layouts the per-thread offset moves to the non-reduction axis
  // (thread id modulo 4) and the reduction-axis offset is reset to zero
  if(is_a_row){
    offset_a_i = builder_->CreateAdd(offset_a_i, builder_->CreateURem(u_thread_id, builder_->getInt32(4)));
    offset_a_k = builder_->getInt32(0);
  }
  if(!is_b_row){
    offset_b_j = builder_->CreateAdd(offset_b_j, builder_->CreateURem(u_thread_id, builder_->getInt32(4)));
    offset_b_k = builder_->getInt32(0);
  }

  std::string op_a = is_a_row ? "row" : "col";
  std::string op_b = is_b_row ? "row" : "col";

  // outputs $0-$7 are tied to inputs $12-$19 (constraints "0".."7"), so the
  // accumulators are updated in place
  InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 "
                                             "{$0, $1, $2, $3, $4, $5, $6, $7}, "
                                             "{$8, $9}, "
                                             "{$10, $11}, "
                                             "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);

  unsigned fpw_0 = layouts_->get(dot)->fpw.at(0);
  unsigned fpw_1 = layouts_->get(dot)->fpw.at(1);
  unsigned wts_0 = fpw_0 * 8;
  unsigned wts_1 = fpw_1 * 8;
  unsigned wpt_0 = layouts_->get(dot)->wpt.at(0);
  unsigned wpt_1 = layouts_->get(dot)->wpt.at(1);
  unsigned stride_rep_i = wpt_0 * wts_0;
  unsigned stride_rep_j = wpt_1 * wts_1;
  unsigned num_rep_i = shapes[0] / stride_rep_i;
  // leading dimension used to linearize (i, j) into each accumulator vector
  unsigned ld_fc = num_rep_i * 2;

  for(auto& x: fcs){
    std::vector<Value *>& fc = x.second;
    for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
    for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
    for(unsigned K = 0; K < NK; K += 4){
      Value *_K = builder_->getInt32(K);
      Value *current_offset_a_i = builder_->CreateAdd(offset_a_i, builder_->getInt32(pack_i*stride_rep_i*pack_size_0_));
      Value *current_offset_b_i = builder_->CreateAdd(offset_b_j, builder_->getInt32(pack_j*stride_rep_j*pack_size_1_));
      indices_t idx_a = {current_offset_a_i, builder_->CreateAdd(offset_a_k, _K)};
      indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i};
      // append the grouping key so the operands are read at the same trailing indices
      idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
      idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
      Value *ha = TA->get_value(idx_a);
      Value *hb = TB->get_value(idx_b);
      for(unsigned ii = 0; ii < pack_size_0_; ii++)
      for(unsigned jj = 0; jj < pack_size_1_; jj++){
        Value *ha0 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
        Value *ha1 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
        // NOTE(review): the B lanes are strided by pack_size_0_ as in the
        // original code; confirm whether pack_size_1_ was intended.
        Value *hb0 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
        Value *hb1 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
        // position in `fc` of the accumulator at sub-offset (di, dj) of this pack
        auto fc_pos = [&](unsigned di, unsigned dj) -> size_t {
          return (pack_i*2*pack_size_0_ + ii*2 + di) + (pack_j*4*pack_size_1_ + jj*4 + dj)*ld_fc;
        };
        std::vector<size_t> idx = {
          fc_pos(0, 0), fc_pos(0, 1), fc_pos(1, 0), fc_pos(1, 1),
          fc_pos(0, 2), fc_pos(0, 3), fc_pos(1, 2), fc_pos(1, 3)
        };
        Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
        // unpack the 8 updated accumulators back into their slots
        for(unsigned r = 0; r < 8; r++)
          fc[idx[r]] = builder_->CreateExtractValue(nc, {r});
      }
    }
    }
  }

  // write back
  unsigned i = 0;
  TC->for_each([&](indices_t idx){
    std::vector<Value*> key(idx.size() - 2);
    std::copy(idx.begin() + 2, idx.end(), key.begin());
    // restart the running position once a group's accumulators are exhausted
    if(i >= fcs.at(key).size())
      i = 0;
    TC->set_value(idx, fcs.at(key)[i++]);
  });

  TA->set_return_mode(false);
  TB->set_return_mode(false);
}
|
||||
// Lowers a dot product one output element at a time: starting from the
// matching element of TD, accumulates a[i,k]*b[k,j] for k in [0, NK) through
// `f_mul_add`, casting operands to `c_ty` when their type differs, and writes
// the final value to TC.
void generator::visit_scanline_dot(ir::dot_inst* dot, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
                                   Type *c_ty, Function *f_mul_add) {
  TA->set_vector_size(TC->axis(0).contiguous);
  TB->set_vector_size(TC->axis(1).contiguous);
  TC->for_each([&](indices_t idx){
    Value *acc = TD->get_value(idx);
    for(unsigned k = 0; k < NK; ++k){
      // operand indices for this reduction step
      indices_t lhs_idx = {idx[0], builder_->getInt32(k)};
      indices_t rhs_idx = {builder_->getInt32(k), idx[1]};
      // carry over any indices past the first two
      for(size_t d = 2; d < idx.size(); d++){
        lhs_idx.push_back(idx[d]);
        rhs_idx.push_back(idx[d]);
      }
      // fetch operands and bring them to the accumulator type
      Value *lhs = TA->get_value(lhs_idx);
      Value *rhs = TB->get_value(rhs_idx);
      if(lhs->getType() != c_ty)
        lhs = builder_->CreateFPCast(lhs, c_ty);
      if(rhs->getType() != c_ty)
        rhs = builder_->CreateFPCast(rhs, c_ty);
      acc = builder_->CreateCall(f_mul_add, {lhs, rhs, acc});
    }
    TC->set_value(idx, acc);
  });
}
|
||||
|
||||
// Lowers a dot product with a single reduction step: each output element is
// f_mul_add(a, b, d), where a is read from TA at {0, idx[0]}, b from TB at
// {idx[1], 0}, and d from TD at idx. Operands are cast to `c_ty` when their
// type differs.
void generator::visit_outer_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
                                Type *c_ty, Function *f_mul_add) {
  TC->for_each([&](indices_t idx){
    Value *acc = TD->get_value(idx);
    // operand indices, already in swapped order
    indices_t lhs_idx = {builder_->getInt32(0), idx[0]};
    indices_t rhs_idx = {idx[1], builder_->getInt32(0)};
    Value *lhs = TA->get_value(lhs_idx);
    Value *rhs = TB->get_value(rhs_idx);
    if(lhs->getType() != c_ty)
      lhs = builder_->CreateFPCast(lhs, c_ty);
    if(rhs->getType() != c_ty)
      rhs = builder_->CreateFPCast(rhs, c_ty);
    acc = builder_->CreateCall(f_mul_add, {lhs, rhs, acc});
    TC->set_value(idx, acc);
  });
}
|
||||
|
||||
// Lowers a dot instruction D + A*B by dispatching to one of three strategies:
//   - reduction extent NK != 1 and HMMA_884 layout : visit_hmma_dot
//   - reduction extent NK != 1 otherwise           : visit_scanline_dot
//   - reduction extent NK == 1                     : visit_outer_dot
// NK is read from axis 1 of A's tile shape.
void generator::visit_dot_inst(ir::dot_inst* dot) {
  Function *fn = builder_->GetInsertBlock()->getParent();

  distributed_tile* TC = (distributed_tile*)tmap_.at(dot);
  Module *module = fn->getParent();
  ir::value *A = dot->get_operand(0);
  ir::value *B = dot->get_operand(1);
  ir::value *D = dot->get_operand(2);

  distributed_tile *TD = (distributed_tile*)tmap_.at(D);
  // generator::type() takes only the IR type; the member hides any
  // namespace-scope overload, so a context argument cannot be passed here.
  Type *c_ty = type(D->get_type()->get_scalar_ty());
  Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
  auto A_shapes = A->get_type()->get_tile_shapes();
  size_t red_axis = 1;
  unsigned NK = A_shapes[red_axis];

  if(NK != 1) {
    shared_tile *TA = (shared_tile*)tmap_.at(A);
    shared_tile *TB = (shared_tile*)tmap_.at(B);
    if(layouts_->get(dot)->type == analysis::HMMA_884)
      visit_hmma_dot(dot, TC, TA, TB, TD, NK);
    else
      visit_scanline_dot(dot, TC, TA, TB, TD, NK, c_ty, f_mul_add);
  }
  else {
    // NOTE(review): the operands are treated as distributed tiles here, while
    // visit_outer_dot is declared with shared_tile* parameters; the declaration
    // and this call need to be reconciled.
    distributed_tile *TA = (distributed_tile*)tmap_.at(A);
    distributed_tile *TB = (distributed_tile*)tmap_.at(B);
    visit_outer_dot(dot, TC, TA, TB, TD, NK, c_ty, f_mul_add);
  }
}
|
||||
|
||||
// Lowers a transposition: registers a new shared_tile for `trans`, built from
// the operand tile's type, shapes, order, pointer and offset together with the
// instruction's permutation.
void generator::visit_trans_inst(ir::trans_inst* trans) {
  shared_tile *src = (shared_tile*)tmap_.at(trans->get_operand(0));
  tmap_[trans] = new shared_tile(src->get_ty(), src->get_shapes(), src->get_order(), src->get_pointer(), *builder_, src->get_offset(), trans->get_perm());
}
|
||||
|
||||
// Lowers an element-wise square root: for every index of `sqrt`, calls the
// llvm.sqrt intrinsic (overloaded on the operand's LLVM type) on the operand
// value and stores the call as the result at that index.
void generator::visit_sqrt_inst(ir::sqrt_inst* sqrt) {
  for_each(sqrt, [&](indices_t idx){
    Value *val = get_value(sqrt->get_operand(0), idx);
    Module* module = builder_->GetInsertBlock()->getModule();
    // The intrinsic gets its own name so it does not shadow the instruction
    // `sqrt`, which is what set_value must receive.
    Function *sqrt_fn = Intrinsic::getDeclaration(module, Intrinsic::sqrt, {val->getType()});
    Value *ret = builder_->CreateCall(sqrt_fn, {val});
    set_value(sqrt, idx, ret);
  });
}
|
||||
|
||||
// Reductions have no lowering here: visiting one always throws.
void generator::visit_reduce_inst(ir::reduce_inst* /*unused*/) {
  throw std::runtime_error("not implemented");
}
|
||||
|
||||
// Lowers an element-wise select: for every index, inserts a SelectInst
// choosing between operand 1 and operand 2 according to operand 0.
void generator::visit_select_inst(ir::select_inst* select) {
  auto emit = [&](indices_t idx){
    Value *cond = get_value(select->get_operand(0), idx);
    Value *on_true = get_value(select->get_operand(1), idx);
    Value *on_false = get_value(select->get_operand(2), idx);
    SelectInst *sel = SelectInst::Create(cond, on_true, on_false);
    Value *ret = builder_->Insert(sel);
    set_value(select, idx, ret);
  };
  for_each(select, emit);
}
|
||||
|
||||
// Lowers a copy into a shared tile. Input elements are gathered into vectors
// ("packets") of `vector_size` consecutive linear indices, and each packet is
// written to the result tile at the index of its first element.
// vector_size is 1 unless the source and destination layouts have the same
// order, in which case it is the source layout's nts along its first-ordered axis.
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
  auto dst_order = layouts_->get(cts)->order;
  ir::value *arg = cts->get_operand(0);
  auto src_order = layouts_->get(arg)->order;
  // tiles
  shared_tile* result = (shared_tile*)tmap_.at(cts);
  distributed_tile* in = (distributed_tile*)tmap_.at(arg);
  unsigned vector_size = 1;
  if(dst_order == src_order){
    size_t ld = src_order[0];
    vector_size = layouts_->get(arg)->nts.at(ld);
  }

  // build the packets
  std::map<unsigned, Value*> packets;
  in->for_each([&](indices_t idx){
    const unsigned linear = in->get_linear_index(idx);
    const unsigned packet_id = linear / vector_size;
    const unsigned lane = linear % vector_size;
    Value *elem = in->get_value(idx);
    // the first lane starts a fresh, undefined vector
    if(lane == 0)
      packets[packet_id] = UndefValue::get(VectorType::get(elem->getType(), vector_size));
    packets[packet_id] = builder_->CreateInsertElement(packets.at(packet_id), elem, lane);
  });
  // write each packet at the index of its first lane
  in->for_each([&](indices_t idx){
    const unsigned linear = in->get_linear_index(idx);
    if(linear % vector_size != 0)
      return;
    result->set_value(idx, packets[linear / vector_size]);
  });
}
|
||||
|
||||
// Lowers a copy out of a shared tile: every index of the result tile receives
// the operand tile's value at the same index.
void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst* cfs) {
  distributed_tile *dst = (distributed_tile*)tmap_.at(cfs);
  shared_tile *src = (shared_tile*)tmap_.at(cfs->get_operand(0));
  dst->for_each([&](indices_t idx){
    Value *elem = src->get_value(idx);
    dst->set_value(idx, elem);
  });
}
|
||||
|
||||
// Lowers a barrier by asking the target to add one at the current insertion point.
void generator::visit_barrier_inst(ir::barrier_inst* /*unused*/) {
  BasicBlock *block = builder_->GetInsertBlock();
  Module *module = block->getModule();
  tgt_->add_barrier(module, *builder_);
}
|
||||
|
||||
// Lowers make_range_dyn: each (one-dimensional) index is expected to be a
// binary operator, and the element's value is that operator's first operand.
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
  distributed_tile *tile = (distributed_tile*)tmap_.at(x);
  auto emit = [&](indices_t idx){
    assert(idx.size() == 1);
    BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
    assert(bin_add);
    tile->set_value(idx, bin_add->getOperand(0));
  };
  tile->for_each(emit);
}
|
||||
|
||||
// Lowers make_range_sta: each (one-dimensional) index is expected to be a
// binary operator whose second operand is a constant; that constant is the
// element's value.
void generator::visit_make_range_sta(ir::make_range_sta* x) {
  distributed_tile *tile = (distributed_tile *)tmap_.at(x);
  auto emit = [&](indices_t idx){
    assert(idx.size() == 1);
    BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
    assert(bin_add);
    Value *res = bin_add->getOperand(1);
    assert(isa<Constant>(res));
    tile->set_value(idx, res);
  };
  tile->for_each(emit);
}
|
||||
|
||||
// Lowers make_range: each element of the (one-dimensional) tile takes its own
// index as its value.
void generator::visit_make_range(ir::make_range* x) {
  distributed_tile *tile = (distributed_tile *)tmap_.at(x);
  auto emit = [&](indices_t idx){
    assert(idx.size() == 1);
    Value *self = idx[0];
    tile->set_value(idx, self);
  };
  tile->for_each(emit);
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user