[codegen] [selection] merged selection into generator visit
This commit is contained in:
@@ -1,346 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_SELECTION_H
|
||||
#define TDL_INCLUDE_CODEGEN_SELECTION_H
|
||||
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
|
||||
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
|
||||
// typedefs
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class liveness;
|
||||
class tiles;
|
||||
class align;
|
||||
class allocation;
|
||||
class cts;
|
||||
class axes;
|
||||
class layout;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
class coalesce;
|
||||
}
|
||||
|
||||
class target;
|
||||
|
||||
typedef std::vector<Value*> indices_t;
|
||||
|
||||
// Per-axis record used by distributed tiles (see distributed_tile::axis()).
// Plain aggregate: members are not initialized by this struct itself.
struct distributed_axis {
|
||||
// NOTE(review): meaning not visible in this header — presumably the number of
// contiguous elements along this axis; confirm against the layout code.
int contiguous;
|
||||
// LLVM values associated with this axis.
std::vector<Value*> values;
|
||||
// LLVM value named after a thread id — TODO confirm how it is derived.
Value* thread_id;
|
||||
};
|
||||
|
||||
class tile {
|
||||
protected:
|
||||
typedef std::vector<unsigned> shapes_t;
|
||||
|
||||
public:
|
||||
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
|
||||
virtual void set_value(indices_t idx, Value *v) = 0;
|
||||
virtual Value* get_value(indices_t idx) = 0;
|
||||
Type *get_ty() const { return ty_; }
|
||||
shapes_t get_shapes() const { return shapes_; }
|
||||
|
||||
protected:
|
||||
Type *ty_;
|
||||
shapes_t shapes_;
|
||||
};
|
||||
|
||||
class shared_tile: public tile {
|
||||
private:
|
||||
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
|
||||
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
|
||||
|
||||
|
||||
public:
|
||||
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
|
||||
void set_vector_size(unsigned vector_size);
|
||||
void set_return_mode(bool return_vector);
|
||||
void set_value(indices_t, Value *);
|
||||
Value* get_ptr_to(indices_t idx);
|
||||
Value* get_value(indices_t idx);
|
||||
Value* get_pointer() { return ptr_; }
|
||||
Value* get_offset() { return offset_; }
|
||||
const std::vector<int>& get_perm() { return perm_; }
|
||||
const std::vector<int>& get_order() { return order_; }
|
||||
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
|
||||
|
||||
private:
|
||||
Value *ptr_;
|
||||
bool return_vector_;
|
||||
Builder &builder_;
|
||||
Value *offset_;
|
||||
std::map<indices_t, Value*> ptr_cache_;
|
||||
unsigned vector_size_;
|
||||
std::vector<int> order_;
|
||||
std::vector<int> perm_;
|
||||
};
|
||||
|
||||
// Distributed tile
|
||||
class distributed_tile: public tile{
|
||||
typedef std::vector<distributed_axis> axes_t;
|
||||
typedef std::vector<indices_t> ordered_indices_vec_t;
|
||||
typedef std::map<indices_t, unsigned> indices_map_t;
|
||||
typedef std::map<indices_t, Value*> values_map_t;
|
||||
|
||||
private:
|
||||
void init_indices();
|
||||
Type *make_vector_ty(Type *ty, size_t vector_size);
|
||||
|
||||
public:
|
||||
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
|
||||
void set_value(indices_t idx, Value *v);
|
||||
Value* get_value(indices_t idx);
|
||||
const std::vector<int>& get_order() { return order_; }
|
||||
unsigned get_linear_index(indices_t idx);
|
||||
indices_t get_ordered_indices(unsigned id);
|
||||
void for_each(std::function<void(indices_t)> fn);
|
||||
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
|
||||
|
||||
private:
|
||||
axes_t axes_;
|
||||
std::vector<int> order_;
|
||||
indices_map_t indices_;
|
||||
values_map_t values_;
|
||||
ordered_indices_vec_t ordered_indices_;
|
||||
size_t vector_size_;
|
||||
Builder &builder_;
|
||||
};
|
||||
|
||||
class machine_layout_t {
|
||||
public:
|
||||
virtual tile* create(ir::value *v) = 0;
|
||||
};
|
||||
|
||||
class machine_layout_shared_t: public machine_layout_t {
|
||||
public:
|
||||
machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, analysis::layout_t* layout,
|
||||
std::map<ir::value *, Value *>& vmap,
|
||||
std::map<ir::value *, tile *>& tmap);
|
||||
|
||||
tile* create(ir::value *v);
|
||||
|
||||
Module *mod_;
|
||||
Builder *builder_;
|
||||
target *tgt_;
|
||||
analysis::allocation* alloc_;
|
||||
Value *&sh_mem_ptr_;
|
||||
analysis::layout_t* layout_;
|
||||
std::map<ir::value *, Value *>& vmap_;
|
||||
std::map<ir::value *, tile *>& tmap_;
|
||||
|
||||
Value *offset_;
|
||||
Value *ptr_;
|
||||
Value *pre_ptr_;
|
||||
Value *next_ptr_;
|
||||
|
||||
};
|
||||
|
||||
class machine_layout_distributed_t: public machine_layout_t {
|
||||
public:
|
||||
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_t* layout);
|
||||
|
||||
tile* create(ir::value *v);
|
||||
Module *mod_;
|
||||
Builder *builder_;
|
||||
target *tgt_;
|
||||
Type *ty_;
|
||||
analysis::axes *a_axes_;
|
||||
std::map<unsigned, distributed_axis>& axes_;
|
||||
analysis::layout_t* layout_;
|
||||
};
|
||||
|
||||
|
||||
class machine_layout_hmma_884_t: public machine_layout_distributed_t {
|
||||
public:
|
||||
machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_hmma_884_t* layout);
|
||||
Value *offset_a_i_, *offset_a_k_;
|
||||
Value *offset_b_j_, *offset_b_k_;
|
||||
unsigned pack_size_0_;
|
||||
unsigned pack_size_1_;
|
||||
unsigned num_packs_0_;
|
||||
unsigned num_packs_1_;
|
||||
};
|
||||
|
||||
class machine_layout_scanline_t: public machine_layout_distributed_t {
|
||||
public:
|
||||
machine_layout_scanline_t(Module *mod, Builder *builder,
|
||||
target *tgt, Type *ty,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::layout_scanline_t* layout);
|
||||
};
|
||||
|
||||
// Visitor over the IR (ir::visitor) and over analysis layouts
// (analysis::layout_visitor) that emits into an LLVM module.
// It keeps the mapping from IR values to LLVM values (vmap_) and to
// tiles (tmap_), and the machine layouts created for each analysis layout.
class generator: public ir::visitor, public analysis::layout_visitor {
private:
  // Tile helpers: iterate over the indices of `x`, and read/write the LLVM
  // value associated with `x` at index `idx`.
  void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
  Value* get_value(ir::value *x, const indices_t& idx);
  void set_value(ir::value *x, const indices_t& idx, Value* v);

  // Specialized lowerings used by visit_dot_inst.
  void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
  void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
  void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
                       Type *c_ty, Function *f_mul_add);

  // Finalization steps.
  void finalize_shared_layout(analysis::layout_shared_t*);
  void finalize_function(ir::function*);
  void finalize_phi_node(ir::phi_node*);

public:
  // dst is the LLVM module to emit into; the remaining arguments are the
  // analyses, the target and the number of warps, stored in the members below.
  generator(Module *dst,
            analysis::axes *a_axes,
            target *tgt,
            analysis::layout *layouts,
            analysis::align *alignment,
            analysis::allocation *alloc,
            unsigned num_warps);

  void visit_value(ir::value* v);

  void visit_phi_node(ir::phi_node*);
  void visit_binary_operator(ir::binary_operator*);
  void visit_getelementptr_inst(ir::getelementptr_inst*);

  void visit_icmp_inst(ir::icmp_inst*);
  void visit_fcmp_inst(ir::fcmp_inst*);
  void visit_cast_inst(ir::cast_inst*);

  void visit_return_inst(ir::return_inst*);
  void visit_cond_branch_inst(ir::cond_branch_inst*);
  void visit_uncond_branch_inst(ir::uncond_branch_inst*);

  void visit_unmasked_load_inst(ir::unmasked_load_inst*);
  void visit_masked_load_inst(ir::masked_load_inst*);
  void visit_unmasked_store_inst(ir::unmasked_store_inst*);
  void visit_masked_store_inst(ir::masked_store_inst*);

  void visit_reshape_inst(ir::reshape_inst*);
  void visit_splat_inst(ir::splat_inst*);
  void visit_broadcast_inst(ir::broadcast_inst*);
  void visit_downcast_inst(ir::downcast_inst*);

  void visit_get_program_id_inst(ir::get_program_id_inst*);
  void visit_get_num_program_inst(ir::get_num_program_inst*);
  void visit_atomic_cas_inst(ir::atomic_cas_inst*);
  void visit_atomic_exch_inst(ir::atomic_exch_inst*);
  void visit_atomic_add_inst(ir::atomic_add_inst*);
  void visit_dot_inst(ir::dot_inst*);
  void visit_trans_inst(ir::trans_inst*);
  void visit_sqrt_inst(ir::sqrt_inst*);
  void visit_reduce_inst(ir::reduce_inst*);
  void visit_select_inst(ir::select_inst*);

  void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
  void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
  void visit_barrier_inst(ir::barrier_inst*);
  void visit_make_range_dyn(ir::make_range_dyn*);
  void visit_make_range(ir::make_range*);

  void visit_make_range_sta(ir::make_range_sta*);
  void visit_undef_value(ir::undef_value*);
  void visit_constant_int(ir::constant_int*);
  void visit_constant_fp(ir::constant_fp*);
  void visit_alloc_const(ir::alloc_const*);

  void visit_function(ir::function*);
  void visit_basic_block(ir::basic_block*);
  void visit_argument(ir::argument*);

  // analysis::layout_visitor interface.
  void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
  void visit_layout_scanline(analysis::layout_scanline_t*);
  void visit_layout_shared(analysis::layout_shared_t*);

private:
  // Declaration order is kept unchanged: it fixes member initialization order.
  LLVMContext *ctx_;
  std::unique_ptr<Builder> builder_;
  Module *mod_;

  // Machine layout created for each analysis layout (raw pointers).
  std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
  analysis::axes *a_axes_;
  std::map<unsigned, distributed_axis> axes_;
  std::map<ir::value *, Value *> vmap_;
  std::map<ir::value *, tile *> tmap_;
  target *tgt_;
  analysis::layout *layouts_;
  analysis::align *alignment_;
  analysis::allocation *alloc_;
  // The constructor only assigns this when the target is a GPU and the
  // allocation size is non-zero; default to nullptr so it is never read
  // while indeterminate.
  Value *sh_mem_ptr_ = nullptr;
  unsigned num_warps_;

  std::set<ir::value*> seen_;
};
|
||||
|
||||
|
||||
// Selection pass
|
||||
class selection{
|
||||
typedef std::map<ir::value *, Value *> vmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
|
||||
public:
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc,
|
||||
analysis::align *alignment, analysis::axes *axes,
|
||||
analysis::layout *layouts, target *tgt, unsigned num_warps)
|
||||
: liveness_(liveness), alloc_(alloc),
|
||||
alignment_(alignment), a_axes_(axes), layouts_(layouts),
|
||||
tgt_(tgt), num_warps_(num_warps){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::allocation *alloc_;
|
||||
analysis::axes *a_axes_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
target *tgt_;
|
||||
unsigned num_warps_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -79,12 +79,11 @@ private:
|
||||
void finalize_phi_node(ir::phi_node*);
|
||||
|
||||
public:
|
||||
generator(Module *dst,
|
||||
analysis::axes *a_axes,
|
||||
target *tgt,
|
||||
generator(analysis::axes *a_axes,
|
||||
analysis::layout *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
target *tgt,
|
||||
unsigned num_warps);
|
||||
|
||||
void visit_value(ir::value* v);
|
||||
@@ -143,6 +142,8 @@ public:
|
||||
void visit_layout_scanline(analysis::layout_scanline_t*);
|
||||
void visit_layout_shared(analysis::layout_shared_t*);
|
||||
|
||||
void visit(ir::module &, llvm::Module &);
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Builder* builder_;
|
||||
|
@@ -1,70 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_SELECTION_SELECTION_H_
|
||||
#define _TRITON_SELECTION_SELECTION_H_
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace llvm{
|
||||
class Module;
|
||||
class Value;
|
||||
}
|
||||
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
// typedef
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Value Value;
|
||||
// forward
|
||||
namespace analysis{
|
||||
class liveness;
|
||||
class align;
|
||||
class allocation;
|
||||
class axes;
|
||||
class layout;
|
||||
}
|
||||
class target;
|
||||
class tile;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
// Selection pass
|
||||
class selection{
|
||||
typedef std::map<ir::value *, Value *> vmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
|
||||
public:
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc,
|
||||
analysis::align *alignment, analysis::axes *axes,
|
||||
analysis::layout *layouts, target *tgt, unsigned num_warps)
|
||||
: liveness_(liveness), alloc_(alloc),
|
||||
alignment_(alignment), a_axes_(axes), layouts_(layouts),
|
||||
tgt_(tgt), num_warps_(num_warps){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::allocation *alloc_;
|
||||
analysis::axes *a_axes_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
target *tgt_;
|
||||
unsigned num_warps_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -9,7 +9,7 @@
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
// codegen
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
|
@@ -109,18 +109,18 @@ llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) {
|
||||
}
|
||||
|
||||
|
||||
inline Type *type(ir::type *ty, LLVMContext &ctx) {
|
||||
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||
Type *return_ty = type(tt->get_return_ty(), ctx);
|
||||
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
|
||||
std::vector<Type*> param_tys;
|
||||
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
|
||||
[&ctx](ir::type* t){ return type(t, ctx);});
|
||||
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
|
||||
return FunctionType::get(return_ty, param_tys, false);
|
||||
}
|
||||
// pointer
|
||||
if(ty->is_pointer_ty()){
|
||||
Type *elt_ty = type(ty->get_pointer_element_ty(), ctx);
|
||||
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
|
||||
unsigned addr_space = ty->get_pointer_address_space();
|
||||
return PointerType::get(elt_ty, addr_space);
|
||||
}
|
||||
@@ -173,29 +173,14 @@ inline bool is_trans(ir::value *v) {
|
||||
|
||||
|
||||
|
||||
generator::generator(Module *dst,
|
||||
analysis::axes *a_axes,
|
||||
target *tgt,
|
||||
analysis::layout *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
unsigned num_warps)
|
||||
: ctx_(&dst->getContext()), mod_(dst),
|
||||
builder_(new Builder(dst->getContext())),
|
||||
a_axes_(a_axes), tgt_(tgt),
|
||||
layouts_(layouts), alignment_(alignment), alloc_(alloc),
|
||||
num_warps_(num_warps) {
|
||||
|
||||
if(tgt_->is_gpu())
|
||||
if(unsigned alloc_size = alloc_->allocated_size()){
|
||||
Type *int_8_ty = Type::getInt8Ty(*ctx_);
|
||||
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||
Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||
GlobalVariable *sh_mem_array =
|
||||
new GlobalVariable(*dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
generator::generator(analysis::axes *a_axes,
|
||||
analysis::layout *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
target *tgt,
|
||||
unsigned num_warps)
|
||||
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc),
|
||||
tgt_(tgt), num_warps_(num_warps) {
|
||||
|
||||
}
|
||||
|
||||
@@ -226,7 +211,7 @@ void generator::visit_value(ir::value* v) {
|
||||
}
|
||||
|
||||
void generator::visit_phi_node(ir::phi_node* phi) {
|
||||
Type *ty = type(phi->get_type()->get_scalar_ty(), *ctx_);
|
||||
Type *ty = llvm_type(phi->get_type()->get_scalar_ty(), *ctx_);
|
||||
unsigned num_ops = phi->get_num_operands();
|
||||
for_each(phi, [&](indices_t idx){
|
||||
set_value(phi, idx, builder_->Insert(PHINode::Create(ty, num_ops)));
|
||||
@@ -248,7 +233,7 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* gep) {
|
||||
std::vector<Value*> idx_vals;
|
||||
std::transform(gep->idx_begin(), gep->idx_end(), std::back_inserter(idx_vals),
|
||||
[&](ir::value* x){ return get_value(x, idx);});
|
||||
Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_);
|
||||
Type *source_ty = llvm_type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_);
|
||||
Value *ret = builder_->Insert(GetElementPtrInst::CreateInBounds(source_ty, ptr, idx_vals));
|
||||
set_value(gep, idx, ret);
|
||||
});
|
||||
@@ -277,7 +262,7 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* fcmp) {
|
||||
void generator::visit_cast_inst(ir::cast_inst* cast) {
|
||||
for_each(cast, [&](indices_t idx){
|
||||
Value *arg = get_value(cast->get_operand(0), idx);
|
||||
Type *dst_ty = type(cast->get_type()->get_scalar_ty(), *ctx_);
|
||||
Type *dst_ty = llvm_type(cast->get_type()->get_scalar_ty(), *ctx_);
|
||||
Value *ret = builder_->Insert(CastInst::Create(llvm_op(cast->get_op()), arg, dst_ty));
|
||||
set_value(cast, idx, ret);
|
||||
});
|
||||
@@ -726,7 +711,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
ir::value *D = dot->get_operand(2);
|
||||
|
||||
distributed_tile *TD = (distributed_tile*)tmap_.at(D);
|
||||
Type *c_ty = type(D->get_type()->get_scalar_ty(), *ctx_);
|
||||
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
size_t red_axis = 1;
|
||||
@@ -851,22 +836,22 @@ void generator::visit_make_range(ir::make_range* x) {
|
||||
|
||||
|
||||
void generator::visit_undef_value(ir::undef_value *ud) {
|
||||
vmap_[ud] = llvm::UndefValue::get(type(ud->get_type(), *ctx_));
|
||||
vmap_[ud] = llvm::UndefValue::get(llvm_type(ud->get_type(), *ctx_));
|
||||
}
|
||||
|
||||
void generator::visit_constant_int(ir::constant_int *cst){
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_);
|
||||
Type *ty = llvm_type(cst->get_type()->get_scalar_ty(), *ctx_);
|
||||
vmap_[cst] = ConstantInt::get(ty, cst->get_value());
|
||||
}
|
||||
|
||||
void generator::visit_constant_fp(ir::constant_fp *cst){
|
||||
Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_);
|
||||
Type *ty = llvm_type(cst->get_type()->get_scalar_ty(), *ctx_);
|
||||
vmap_[cst] = ConstantFP::get(ty, cst->get_value());
|
||||
}
|
||||
|
||||
void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
|
||||
Type *element_ty = type(alloc->get_type()->get_pointer_element_ty(), *ctx_);
|
||||
Type *element_ty = llvm_type(alloc->get_type()->get_pointer_element_ty(), *ctx_);
|
||||
Type *array_ty = llvm::ArrayType::get(element_ty, size);
|
||||
Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
|
||||
nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
|
||||
@@ -876,7 +861,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
|
||||
void generator::visit_function(ir::function* fn) {
|
||||
LLVMContext &ctx = builder_->getContext();
|
||||
FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type(), *ctx_);
|
||||
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), *ctx_);
|
||||
if(!tgt_->is_gpu()){
|
||||
Type *fn_ret_ty = fn_ty->getReturnType();
|
||||
std::vector<Type*> fn_args_ty;
|
||||
@@ -925,11 +910,11 @@ void generator::visit_function(ir::function* fn) {
|
||||
|
||||
|
||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
|
||||
@@ -1026,5 +1011,29 @@ void generator::finalize_phi_node(ir::phi_node *phi) {
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point of the generator: binds it to the LLVM module `dst`, declares
// the global memory buffers, then visits every function of `src`.
// NOTE(review): builder_ is assigned from a raw `new` on every call and
// nothing in this function releases it — confirm ownership.
void generator::visit(ir::module &src, llvm::Module &dst) {
  // bind to the destination module and its context
  mod_ = &dst;
  ctx_ = &dst.getContext();
  builder_ = new Builder(*ctx_);

  // shared memory: a single i8 array global in address space 3, exposed
  // through sh_mem_ptr_ as an i8 pointer
  if(tgt_->is_gpu()) {
    const unsigned shared_bytes = alloc_->allocated_size();
    if(shared_bytes != 0) {
      Type *byte_ty = Type::getInt8Ty(*ctx_);
      ArrayType *buffer_ty = ArrayType::get(byte_ty, shared_bytes);
      Type *byte_ptr_ty = PointerType::get(byte_ty, 3);
      GlobalVariable *shared_buffer =
          new GlobalVariable(*mod_, buffer_ty, false, GlobalVariable::ExternalLinkage,
                             nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
      sh_mem_ptr_ = builder_->CreateBitCast(shared_buffer, byte_ptr_ty);
    }
  }

  // constant memory: one visit per alloc_const of the source module
  for(ir::alloc_const *cst_alloc: src.allocs()) {
    visit_alloc_const(cst_alloc);
  }

  // functions, in the order of the module's function list
  for(ir::function *func: src.get_function_list()) {
    visit_function(func);
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
#include <numeric>
|
||||
#include "triton/codegen/selection/machine_layout.h"
|
||||
#include "triton/codegen/selection/machine_value.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/target.h"
|
||||
@@ -13,18 +14,18 @@ namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
inline Type *type(ir::type *ty, LLVMContext &ctx) {
|
||||
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||
Type *return_ty = type(tt->get_return_ty(), ctx);
|
||||
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
|
||||
std::vector<Type*> param_tys;
|
||||
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
|
||||
[&ctx](ir::type* t){ return type(t, ctx);});
|
||||
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
|
||||
return FunctionType::get(return_ty, param_tys, false);
|
||||
}
|
||||
// pointer
|
||||
if(ty->is_pointer_ty()){
|
||||
Type *elt_ty = type(ty->get_pointer_element_ty(), ctx);
|
||||
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
|
||||
unsigned addr_space = ty->get_pointer_address_space();
|
||||
return PointerType::get(elt_ty, addr_space);
|
||||
}
|
||||
@@ -80,7 +81,7 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
|
||||
auto shapes = layout_->shapes;
|
||||
shapes[order[0]] += layout_->pad;
|
||||
|
||||
Type* ty = type(layout_->ty, builder_->getContext());
|
||||
Type* ty = llvm_type(layout_->ty, builder_->getContext());
|
||||
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
|
||||
// double-buffered
|
||||
@@ -113,7 +114,7 @@ tile* machine_layout_shared_t::create(ir::value *v) {
|
||||
auto order = layout_->order;
|
||||
auto shapes = layout_->shapes;
|
||||
shapes[order[0]] += layout_->pad;
|
||||
Type* ty = type(layout_->ty, builder_->getContext());
|
||||
Type* ty = llvm_type(layout_->ty, builder_->getContext());
|
||||
// double-buffered
|
||||
if(layout_->double_buffer) {
|
||||
if(v == layout_->double_buffer->phi)
|
||||
@@ -135,7 +136,7 @@ machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder
|
||||
}
|
||||
|
||||
tile *machine_layout_distributed_t::create(ir::value *v) {
|
||||
Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext());
|
||||
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
|
@@ -1,20 +0,0 @@
|
||||
#include <numeric>
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
// Runs the selection pass: builds a generator targeting `dst`, then feeds it
// the constant allocations and the functions of `src`, in that order.
void selection::run(ir::module &src, Module &dst) {
  generator emitter(&dst, a_axes_, tgt_, layouts_, alignment_, alloc_, num_warps_);

  // constant allocations first
  for(ir::alloc_const *cst_alloc: src.allocs()) {
    emitter.visit_alloc_const(cst_alloc);
  }

  // then every function of the module
  for(ir::function *func: src.get_function_list()) {
    emitter.visit_function(func);
  }
}
|
||||
|
||||
}
|
||||
}
|
@@ -13,7 +13,7 @@
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/lang/cpp.h"
|
||||
#include "triton/lang/parser.h"
|
||||
@@ -217,7 +217,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::transform::reassociate reassociate(&align);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::cts cts;
|
||||
codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
// ir::print(module, std::cout);
|
||||
peephole.run(module);
|
||||
@@ -243,7 +243,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
selection.run(module, *llvm);
|
||||
isel.visit(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
// done
|
||||
|
Reference in New Issue
Block a user