From 297d1a99d1efa23ecf72e931358591e81478fcee Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 7 Jan 2019 22:49:37 -0500 Subject: [PATCH] [code generation] adding missing files --- lib/codegen/allocation.cpp | 129 ++++++++++ lib/codegen/layout.cpp | 55 +++++ lib/codegen/liveness.cpp | 42 ++++ lib/codegen/loop_info.cpp | 0 lib/codegen/selection.cpp | 189 +++++++++++++++ lib/codegen/tune.cpp | 468 +++++++++++++++++++++++++++++++++++++ 6 files changed, 883 insertions(+) create mode 100644 lib/codegen/allocation.cpp create mode 100644 lib/codegen/layout.cpp create mode 100644 lib/codegen/liveness.cpp create mode 100644 lib/codegen/loop_info.cpp create mode 100644 lib/codegen/selection.cpp create mode 100644 lib/codegen/tune.cpp diff --git a/lib/codegen/allocation.cpp b/lib/codegen/allocation.cpp new file mode 100644 index 000000000..2dcc4fbc9 --- /dev/null +++ b/lib/codegen/allocation.cpp @@ -0,0 +1,129 @@ +#include "codegen/allocation.h" +#include "codegen/liveness.h" +#include "codegen/layout.h" +#include "codegen/loop_info.h" +#include "ir/basic_block.h" +#include "ir/type.h" +#include "ir/value.h" +#include "ir/function.h" +#include "ir/instructions.h" + +namespace tdl{ +namespace codegen{ + +unsigned allocation::get_num_bytes(ir::value *x) const { + ir::type *ty = x->get_type(); + unsigned num_elements = ty->get_tile_num_elements(); + if(has_double_buffer(x)) + num_elements *= 2; + return num_elements * ty->get_scalar_ty()->get_size_in_bits(); +} + + +void allocation::run(ir::function &fn){ + using std::max; + using std::min; + typedef std::multimap triples_map_type; + + // Fill double buffering info + for(ir::basic_block *block: fn.blocks()) + for(ir::instruction *v: block->get_inst_list()) + // If requires shared memory + if(layout_->get_num_shared_views(v) && + loop_info_->get_loop_for(block)) + double_buffer_.insert(v); + + std::vector I; + for(auto x: liveness_->intervals()) + I.push_back(x.first); + std::vector J = I; + + triples_map_type H; + H.insert({0, segment{0, 100}}); + + std::vector V; + std::map starts; + while(!J.empty()){ + auto h_it = H.begin(); + unsigned w = h_it->first; + segment xh = h_it->second; + H.erase(h_it); + auto j_it = std::find_if(J.begin(), J.end(), [&](ir::value *JJ){ + segment xj = liveness_->get_interval(JJ); + bool res = xj.intersect(xh); + for(auto val: H) + res = res && !val.second.intersect(xj); + return res; + }); + if(j_it != J.end()){ + unsigned size = get_num_bytes(*j_it); + segment xj = liveness_->get_interval(*j_it); + starts[*j_it] = w; + H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}}); + if(xh.start < xj.start) + H.insert({w, segment{xh.start, xj.end}}); + if(xj.end < xh.end) + H.insert({w, segment{xj.start, xh.end}}); + V.push_back(*j_it); + J.erase(j_it); + } + } + + + // Build interference graph + std::map> interferences; + for(ir::value *x: V) + for(ir::value *y: V){ + if(x == y) + continue; + unsigned X0 = starts[x], Y0 = starts[y]; + unsigned NX = get_num_bytes(x); + unsigned NY = get_num_bytes(y); + segment XS = {X0, X0 + NX}; + segment YS = {Y0, Y0 + NY}; + if(liveness_->get_interval(x).intersect(liveness_->get_interval(y)) + && XS.intersect(YS)) + interferences[x].insert(y); + } + + // Initialize colors + std::map colors; + for(ir::value *X: V) + colors[X] = (X==V[0])?0:-1; + + // First-fit coloring + std::vector available(V.size()); + for(ir::value *x: V){ + // Non-neighboring colors are available + std::fill(available.begin(), available.end(), true); + for(ir::value *Y: interferences[x]){ + int color = colors[Y]; + if(color >= 0) + available[color] = false; + } + // Assigns first available color + auto It = std::find(available.begin(), available.end(), true); + colors[x] = std::distance(available.begin(), It); + } + + // Finalize allocation + for(ir::value *x: V){ + unsigned Adj = 0; + for(ir::value *y: interferences[x]) + Adj = std::max(Adj, starts[y] + get_num_bytes(y)); + offsets_[x] = starts[x] + colors[x] * Adj; + if(auto *phi = dynamic_cast(x)) + for(ir::value *px: phi->ops()){ + if(offsets_.find(px) == offsets_.end()) + offsets_[px] = offsets_[x]; + } + } + + // Save maximum size of induced memory space + allocated_size_ = 0; + for(auto &x: offsets_) + allocated_size_ = std::max(allocated_size_, x.second + get_num_bytes(x.first)); +} + +} +} diff --git a/lib/codegen/layout.cpp b/lib/codegen/layout.cpp new file mode 100644 index 000000000..cdddb1d17 --- /dev/null +++ b/lib/codegen/layout.cpp @@ -0,0 +1,55 @@ +#include "codegen/layout.h" +#include "ir/function.h" +#include "ir/basic_block.h" +#include "ir/instructions.h" + +namespace tdl{ +namespace codegen{ + + +shared_view_info layout::get_shared_view(ir::value *v, unsigned idx){ + return shared_views_.at(v)[idx]; +} + +unsigned layout::get_num_shared_views(ir::value *v){ + return shared_views_.at(v).size(); +} + +// Phi node +void layout::add_phi_nodes(ir::value *v){ + if(ir::phi_node *phi = dynamic_cast(v)) + if(shared_views_.find(phi) != shared_views_.end()) + for(ir::value *v: phi->ops()){ + shared_views_[v] = shared_views_[phi]; + for(shared_view_info &info: shared_views_[v]) + info.has_dedicated_storage = false; + } +} + +// Memory Layout +void layout::add_shared_views(ir::value *v){ + // GEMM has shared inputs + if(dynamic_cast(v)) + shared_views_[v].push_back({v, true}); + if(dynamic_cast(v)) + shared_views_[v].push_back({v, true}); +} + +// Entry point +bool layout::run(ir::function &fn) { + // Non-phis + for(ir::basic_block *block: fn.blocks()) + for(ir::instruction *instr: block->get_inst_list()) { + add_shared_views(instr); + } + // Phi nodes + for(ir::basic_block *block: fn.blocks()) + for(ir::instruction *instr: block->get_inst_list()) { + add_phi_nodes(instr); + } + // Done + return false; +} + +} +} diff --git a/lib/codegen/liveness.cpp b/lib/codegen/liveness.cpp new file mode 100644 index 000000000..0e56aac03 --- /dev/null +++ b/lib/codegen/liveness.cpp @@ -0,0 +1,42 @@ +#include "codegen/liveness.h" +#include "codegen/layout.h" +#include "ir/basic_block.h" +#include "ir/function.h" +#include "ir/instructions.h" +#include "ir/value.h" + +namespace tdl{ +namespace codegen{ + + +// Entry point +void liveness::run(ir::function *fn) { + + // Assigns index to each instruction + slot_index index = 0; + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *instr: block->get_inst_list()){ + index += 1; + indices_.insert({instr, index}); + } + + // Liveness analysis + // Creates live intervals + for(auto i: indices_){ + ir::value *v = i.first; + if(!layouts_->get_num_shared_views(v)) + continue; + if(!layouts_->get_shared_view(v, 0).has_dedicated_storage) + continue; + unsigned start = i.second; + unsigned end = start; + for(ir::value *u: v->get_users()){ + start = std::min(start, indices_.at(u)); + end = std::max(end, indices_.at(u)); + } + intervals_[v] = segment{start, end}; + } +} + +} +} diff --git a/lib/codegen/loop_info.cpp b/lib/codegen/loop_info.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp new file mode 100644 index 000000000..edf48262c --- /dev/null +++ b/lib/codegen/selection.cpp @@ -0,0 +1,189 @@ +#include "codegen/selection.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/IRBuilder.h" +#include "ir/context.h" +#include "ir/module.h" +#include "ir/function.h" +#include "ir/type.h" + + +namespace tdl{ +namespace codegen{ + +using namespace llvm; + + +/* convert ir::type to Type */ +Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { + // function + if(auto* tt = dynamic_cast(ty)){ + Type *return_ty = llvm_type(tt->get_return_ty(), ctx); + std::vector param_tys; + std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), + [this,&ctx](ir::type* t){ return llvm_type(t, ctx);}); + return FunctionType::get(return_ty, param_tys, false); + } + // pointer + if(ty->is_pointer_ty()){ + Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); + unsigned addr_space = ty->get_pointer_address_space(); + return PointerType::get(elt_ty, addr_space); + } + // integer + if(ty->is_integer_ty()){ + unsigned bitwidth = ty->get_integer_bitwidth(); + return IntegerType::get(ctx, bitwidth); + } + // primitive types + switch(ty->get_type_id()){ + case ir::type::VoidTyID: return Type::getVoidTy(ctx); + case ir::type::HalfTyID: return Type::getHalfTy(ctx); + case ir::type::FloatTyID: return Type::getFloatTy(ctx); + case ir::type::DoubleTyID: return Type::getDoubleTy(ctx); + case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx); + case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx); + case ir::type::LabelTyID: return Type::getLabelTy(ctx); + case ir::type::MetadataTyID: return Type::getMetadataTy(ctx); + case ir::type::TokenTyID: return Type::getTokenTy(ctx); + default: break; + } + // unknown type + throw std::runtime_error("unknown conversion from ir::type to Type"); +} + +/* convert ir::constant to Constant */ +Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) { + Type *dst_ty = llvm_type(cst->get_type(), ctx); + if(auto* cc = dynamic_cast(cst)) + return ConstantInt::get(dst_ty, cc->get_value()); + if(auto* cc = dynamic_cast(cst)) + return ConstantFP::get(dst_ty, cc->get_value()); + // unknown constant + throw std::runtime_error("unknown conversion from ir::constant to Constant"); +} + + +/* convert ir::instruction to Instruction */ +Instruction *selection::llvm_inst(ir::instruction *inst, LLVMContext & ctx) { + auto value = [&](ir::value *x) { return llvm_value(x, ctx); }; + auto block = [&](ir::basic_block *x) { return bmap_.at(x); }; + auto type = [&](ir::type *x) { return llvm_type(x, ctx); }; + if(auto* ii = dynamic_cast(inst)){ + BasicBlock *true_dest = block(ii->get_true_dest()); + BasicBlock *false_dest = block(ii->get_false_dest()); + Value *cond = value(ii->get_cond()); + return BranchInst::Create(true_dest, false_dest, cond); + } + if(auto* ii = dynamic_cast(inst)){ + BasicBlock *dest = block(ii->get_dest()); + return BranchInst::Create(dest); + } + if(auto* ii = dynamic_cast(inst)){ + Type *ty = type(ii->get_type()); + unsigned num_ops = ii->get_num_operands(); + return PHINode::Create(ty, num_ops, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + ir::value *ret_val = ii->get_return_value(); + return ReturnInst::Create(ctx, ret_val?value(ret_val):nullptr); + } + if(auto* ii = dynamic_cast(inst)){ + Value *lhs = value(ii->get_operand(0)); + Value *rhs = value(ii->get_operand(1)); + return BinaryOperator::Create(ii->get_op(), lhs, rhs, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + CmpInst::Predicate pred = ii->get_pred(); + Value *lhs = value(ii->get_operand(0)); + Value *rhs = value(ii->get_operand(1)); + return CmpInst::Create(Instruction::ICmp, pred, lhs, rhs, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + CmpInst::Predicate pred = ii->get_pred(); + Value *lhs = value(ii->get_operand(0)); + Value *rhs = value(ii->get_operand(1)); + return FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + Value *arg = value(ii->get_operand(0)); + Type *dst_ty = type(ii->get_type()); + return CastInst::Create(ii->get_op(), arg, dst_ty, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + std::vector idx_vals; + std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals), + [&value](ir::value* x){ return value(x);}); + Type *source_ty = type(ii->get_source_elt_ty()); + Value *arg = value(ii->get_operand(0)); + return GetElementPtrInst::Create(source_ty, arg, idx_vals, ii->get_name()); + } + if(ir::load_inst* ii = dynamic_cast(inst)){ + Value *ptr = value(ii->get_pointer_operand()); + return new LoadInst(ptr, ii->get_name()); + } + // unknown instruction + throw std::runtime_error("unknown conversion from ir::type to Type"); +} + +Value* selection::llvm_value(ir::value *v, LLVMContext &ctx) { + if(vmap_.find(v) != vmap_.end()) + return vmap_.at(v); + // create operands + if(auto *uu = dynamic_cast(v)) + for(ir::value* u: uu->ops()) + vmap_[u] = llvm_value(u, ctx); + if(auto *cc = dynamic_cast(v)) + return llvm_constant(cc, ctx); + // instruction + if(auto *ii = dynamic_cast(v)) + return llvm_inst(ii, ctx); + // unknown value + throw std::runtime_error("unknown conversion from ir::value to Value"); +} + +void selection::run(ir::module &src, Module &dst){ + vmap_.clear(); + bmap_.clear(); + LLVMContext &dst_ctx = dst.getContext(); + IRBuilder<> dst_builder(dst_ctx); + // iterate over functions + for(ir::function *fn: src.get_function_list()) { + // create LLVM function + FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx); + Function *dst_fn = Function::Create(fn_ty, Function::ExternalLinkage, "kernel", &dst); + // map parameters + for(unsigned i = 0; i < fn->args().size(); i++) + vmap_[fn->args()[i]] = &*(dst_fn->arg_begin() + i); + // create blocks + for(ir::basic_block *block: fn->blocks()) { + BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn); + bmap_[block] = dst_block; + } + // iterate through block + for(ir::basic_block *block: fn->blocks()) { + dst_builder.SetInsertPoint(bmap_[block]); + for(ir::instruction *inst: block->get_inst_list()) { + Instruction *dst_inst = llvm_inst(inst, dst_ctx); + vmap_[inst] = dst_inst; + dst_builder.Insert(dst_inst); + } + } + // add phi operands + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *inst: block->get_inst_list()) + if(auto *phi = dynamic_cast(inst)){ + PHINode *dst_phi = (PHINode*)vmap_.at(phi); + for(unsigned i = 0; i < phi->get_num_incoming(); i++){ + ir::value *inc_val = phi->get_incoming_value(i); + ir::basic_block *inc_block = phi->get_incoming_block(i); + Value *llvm_inc_val = llvm_value(inc_val, dst_ctx); + BasicBlock *llvm_block = bmap_[inc_block]; + dst_phi->addIncoming(llvm_inc_val, llvm_block); + } + } + } +} + + +} +} diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp new file mode 100644 index 000000000..6e646a4e3 --- /dev/null +++ b/lib/codegen/tune.cpp @@ -0,0 +1,468 @@ +//#include "codegen/tune.h" + +//namespace tdl{ +//namespace codegen{ + + +//// Layout binding pass +//class TLVMAddTunerConstraints: public FunctionPass { +//public: +// static char ID; +// TLVMAddTunerConstraints(): FunctionPass(ID){ } + +// void getAnalysisUsage(AnalysisUsage & AU) const override; +// bool runOnFunction(Function &F) override; +//}; + +//// Initialization +//char TLVMAddTunerConstraints::ID = 0; +//INITIALIZE_PASS_BEGIN(TLVMAddTunerConstraints, "tlvm-add-tuner-constraints", +// "Add Tuner Constraints (TLVM)", false, true) +//INITIALIZE_PASS_END(TLVMAddTunerConstraints, "tlvm-add-tuner-constraints", +// "Add Tuner Constraints (TLVM)", false, true) +//FunctionPass *llvm::createTLVMAddTunerConstraintsPass() { return new TLVMAddTunerConstraints(); } + +//// Analysis usage +//void TLVMAddTunerConstraints::getAnalysisUsage(AnalysisUsage &AU) const { +// AU.setPreservesAll(); +// FunctionPass::getAnalysisUsage(AU); +//} + + +//inline unsigned MDRead(MDNode* Node){ +// Metadata *MD = Node->getOperand(0).get(); +// Constant *Cst = ((ConstantAsMetadata*)MD)->getValue(); +// unsigned Result = Cst->getUniqueInteger().getZExtValue(); +// return Result; +//} + +//inline unsigned getNumGT1Dim(Instruction &I){ +// unsigned Res = 0; +// for(unsigned K = 0; K < I.getType()->getTileNumDimensions(); K++) +// if(MDRead(I.getMetadata("nvvm.param.shape.d" + itostr(K))) > 1) +// Res++; +// return Res; +//} +//// Run +//bool TLVMAddTunerConstraints::runOnFunction(Function &F) { +// LLVMContext &Ctx = F.getContext(); + +// DenseMap Refs; +// for(Function::iterator::value_type &BB: F) +// for(Instruction &I : BB) +// if(isTLVMValue(&I)){ +// SmallVector, 4> MDs; +// I.getAllMetadata(MDs); +// for(auto &X: MDs){ +// if(MDRead(X.second)==1) +// continue; +// Instruction *&Ref = Refs[X.second]; +// if(!Ref || getNumGT1Dim(I) > getNumGT1Dim(*Ref)) +// Ref = &I; +// } +// } +// SmallVector Grids; +// for(auto &R: Refs) +// if(std::find(Grids.begin(), Grids.end(), R.second) == Grids.end()) +// Grids.push_back(R.second); + + +// Instruction *FirstTile = Grids.front(); +// for(Instruction *I: Grids){ +// Type *Ty = I->getType(); +// size_t NumDim = Ty->getTileNumDimensions(); + +// // For each dimension, the product of layout components +// // must divide shape +// for(size_t K = 0; K < NumDim; K++){ +// unsigned Shape = MDRead(I->getMetadata("nvvm.param.shape.d" + itostr(K))); +// unsigned S0 = MDRead(I->getMetadata("nvvm.param.layout.p0.d" + itostr(K))); +// unsigned S1 = MDRead(I->getMetadata("nvvm.param.layout.p1.d" + itostr(K))); +// unsigned S2 = MDRead(I->getMetadata("nvvm.param.layout.p2.d" + itostr(K))); +// bool Constraint = Shape % (S0*S1*S2)== 0; +// Constant *Cst = Constraint?ConstantInt::getTrue(Ctx):ConstantInt::getFalse(Ctx); +// I->setMetadata("nvvm.constraint.shape.d" + itostr(K), MDNode::get(Ctx, ConstantAsMetadata::get(Cst))); +// }; +// // The number of threads per warp is 32 +// { +// int NumThreads = 1; +// for(size_t K = 0; K < NumDim; K++){ +// unsigned PC = MDRead(I->getMetadata("nvvm.param.layout.p1.d" + itostr(K))); +// NumThreads *= PC; +// } +// bool Constraint = NumThreads==32; +// Constant *Cst = Constraint?ConstantInt::getTrue(Ctx):ConstantInt::getFalse(Ctx); +// I->setMetadata("nvvm.constraint.threads", MDNode::get(Ctx, ConstantAsMetadata::get(Cst))); +// } +// // The number of warps required by the layout is the same +// // for all tiles in the function +// { +// int NumWarps = 1; +// int RefNumWarps = 1; +// for(size_t K = 0; K < NumDim; K++){ +// unsigned PC = MDRead(I->getMetadata("nvvm.param.layout.p2.d" + itostr(K))); +// unsigned PR = MDRead(FirstTile->getMetadata("nvvm.param.layout.p2.d" + itostr(K))); +// NumWarps *= PC; +// RefNumWarps *= PR; +// } +// bool Constraint = NumWarps==RefNumWarps; +// Constant *Cst = Constraint?ConstantInt::getTrue(Ctx):ConstantInt::getFalse(Ctx); +// I->setMetadata("nvvm.constraint.warps", MDNode::get(Ctx, ConstantAsMetadata::get(Cst))); +// }; +// } +// return true; +//} + + +//// Layout binding pass +//class TLVMAddTunerParams: public FunctionPass { +//private: +// enum CType{ +// Layout = 0, Shape = 1 +// }; +// // Params pool +// SmallVector LParamsPool; +// // Constraints +// typedef std::pair CNodeType; +// typedef DenseMap> CGraphType; +// // Layout constraints +// CGraphType LCGraph; +// DenseSet LCNodes; +// // Shape constraints +// CGraphType SCGraph; +// DenseSet SCNodes; +// // Relational +// std::map, std::function> ExtraParams; +// DenseSet Constants; + +// void addConstraint(CNodeType X, CNodeType Y, CType CT); +// void initCPhi(Instruction *I); +// void initCGraph(Instruction *V); +// void connectedComponents(CNodeType X, ArrayRef Vals, CType CT, DenseSet &Nodes, CGraphType &Graph); + +//public: +// static char ID; +// TLVMAddTunerParams(): FunctionPass(ID){ } + +// void getAnalysisUsage(AnalysisUsage & AU) const override; +// bool runOnFunction(Function &F) override; + +//private: +// std::map, Constant*> KnownParams; +//}; + +//// Initialization +//char TLVMAddTunerParams::ID = 0; +//INITIALIZE_PASS_BEGIN(TLVMAddTunerParams, "tlvm-add-tuner-parameters", +// "Add Tuner Parameters (TLVM)", false, true) +//INITIALIZE_PASS_END(TLVMAddTunerParams, "tlvm-add-tuner-parameters", +// "Add Tuner Parameters (TLVM)", false, true) +//FunctionPass *llvm::createTLVMAddTunerParamsPass() { return new TLVMAddTunerParams(); } + +//// Analysis usage +//void TLVMAddTunerParams::getAnalysisUsage(AnalysisUsage &AU) const { +// AU.setPreservesAll(); +// FunctionPass::getAnalysisUsage(AU); +//} + +//void TLVMAddTunerParams::addConstraint(CNodeType X, CNodeType Y, CType CT){ +// // Layout Constraint +// if(CT == Layout){ +// LCGraph[X].insert(Y); +// LCGraph[Y].insert(X); +// LCNodes.insert(X); +// LCNodes.insert(Y); +// } +// if(CT == Shape || CT == Layout){ +// SCGraph[X].insert(Y); +// SCGraph[Y].insert(X); +// SCNodes.insert(X); +// SCNodes.insert(Y); +// } +//} + +//void TLVMAddTunerParams::initCPhi(Instruction *I){ +// unsigned NumDim = 0; +// // Phi Nodes: all the incoming value share the result layout +// if(PHINode *Phi = dyn_cast(I)){ +// Type *Ty = Phi->getType(); +// NumDim = Ty->getTileNumDimensions(); +// unsigned NumInc = Phi->getNumIncomingValues(); +// for(unsigned PI = 0; PI < NumInc; PI++){ +// Value *Inc = Phi->getIncomingValue(PI); +// for(unsigned K = 0; K < NumDim; K++){ +// CType CT = (LCGraph.find({Inc,K}) != LCGraph.end() || +// LCGraph.find({Phi,K}) != LCGraph.end())?Layout:Shape; +// addConstraint({Phi, K}, {Inc, K}, CT); +// } +// } +// } +//} + +//void TLVMAddTunerParams::initCGraph(Instruction *I) { +// unsigned NumDim = 0; +// LLVMContext &Context = I->getContext(); +// Constant *_1 = ConstantInt::get(Type::getInt32Ty(Context), 1); +// // Function call +// if(CallInst *Call = dyn_cast(I)) +// if(Function *Callee = Call->getCalledFunction()){ +// Intrinsic::ID IntrinsicID = Callee->getIntrinsicID(); +// switch (IntrinsicID) { +// // Outer +// case Intrinsic::tlvm_outer_add: LLVM_FALLTHROUGH; +// case Intrinsic::tlvm_outer_and: { +// addConstraint({Call, 0}, {Call->getOperand(0), 0}, Layout); +// addConstraint({Call, 1}, {Call->getOperand(1), 0}, Layout); +// break; +// } +// // Slice +// case Intrinsic::tlvm_read_slice_x: LLVM_FALLTHROUGH; +// case Intrinsic::tlvm_read_slice_y: { +// addConstraint({Call, 0}, {Call->getOperand(0), 0}, Shape); +// break; +// } +// // Range +// case Intrinsic::tlvm_range: { +// addConstraint({Call, 0}, {Call->getOperand(1), 0}, Shape); +// break; +// } +// // GetTilePtr +// case Intrinsic::tlvm_gtp_2d: NumDim++; LLVM_FALLTHROUGH; +// case Intrinsic::tlvm_gtp_1d: NumDim++; { +// Value *Offset = Call->getOperand(1); +// for(unsigned K = 0; K < NumDim; K++){ +// addConstraint({Call, K}, {Offset, K}, Layout); +// } +// break; +// } +// // SlideTilePtr: Pointer shares result layout +// case Intrinsic::tlvm_stp_2d: NumDim++; LLVM_FALLTHROUGH; +// case Intrinsic::tlvm_stp_1d: NumDim++; { +// for(unsigned K = 0; K < NumDim; K++){ +// addConstraint({Call, K}, {Call->getOperand(0), K}, Layout); +// addConstraint({Call, K}, {Call->getOperand(1), K}, Layout); +// } +// break; +// } +// // Transpose +// case Intrinsic::tlvm_transpose_2d: NumDim++; NumDim++; { +// Value *Op = Call->getOperand(0); +// addConstraint({Call, 0}, {Op, 1}, Shape); +// addConstraint({Call, 1}, {Op, 0}, Shape); +// break; +// } +// // Reshape +// case Intrinsic::tlvm_reshape_2d: NumDim++; NumDim++; { +// for(unsigned K = 0; K < NumDim; K++) +// addConstraint({Call, K}, {Call->getOperand(1 + K), 0}, Shape); +// break; +// } +// // Reshape distributed +// case Intrinsic::tlvm_reshape_2d_1d: NumDim++; NumDim++; { +// size_t Current = 0; +// for(unsigned K = 0; K < NumDim; K++){ +// if(Call->getOperand(1 + K) == _1) +// addConstraint({Call, K}, {_1, 0}, Layout); +// else +// addConstraint({Call, K}, {Call->getOperand(0), Current++}, Layout); +// } +// break; +// } +// // Broadcast +// case Intrinsic::tlvm_broadcast_2d: NumDim++; LLVM_FALLTHROUGH; +// case Intrinsic::tlvm_broadcast_1d: NumDim++; { +// for(unsigned K = 0; K < NumDim; K++) +// addConstraint({Call, K}, {Call->getOperand(1 + K), 0}, Shape); +// break; +// } +// // Splat +// case Intrinsic::tlvm_splat_2d: NumDim++; LLVM_FALLTHROUGH; +// case Intrinsic::tlvm_splat_1d: NumDim++; { +// for(unsigned K = 0; K < NumDim; K++) +// addConstraint({Call, K}, {Call->getOperand(K), 0}, Shape); +// break; +// } + +// case Intrinsic::tlvm_load:{ +// NumDim = Call->getType()->getTileNumDimensions(); +// Value *Ptr = Call->getOperand(0); +// for(unsigned K = 0; K < NumDim; K++) +// addConstraint({Call, K}, {Ptr, K}, Layout); +// break; +// } + +// // Masked Load +// case Intrinsic::tlvm_masked_load: { +// NumDim = Call->getType()->getTileNumDimensions(); +// for(unsigned K = 0; K < NumDim; K++){ +// addConstraint({Call, K}, {Call->getOperand(0), K}, Layout); +// addConstraint({Call, K}, {Call->getOperand(1), K}, Layout); +// } +// break; +// } +// // Masked store +// case Intrinsic::tlvm_atomic_load_add_f32: LLVM_FALLTHROUGH; +// case Intrinsic::tlvm_masked_store: { +// Value *Val = Call->getOperand(0); +// Value *Ptr = Call->getOperand(1); +// Value *Mask = Call->getOperand(2); +// NumDim = Val->getType()->getTileNumDimensions(); +// for(unsigned K = 0; K < NumDim; K++){ +// addConstraint({Val, K}, {Ptr, K}, Layout); +// addConstraint({Val, K}, {Mask, K}, Layout); +// } +// break; +// } +// // Set Mask +// case Intrinsic::tlvm_set_mask_2d: NumDim++; NumDim++; { +// for(unsigned K = 0; K < NumDim; K++){ +// Value *Op = Call->getOperand(NumDim + K); +// addConstraint({Call, K}, {Op, 0}, Layout); +// } +// break; +// } +// // MMA +// // A shares first axis with C +// // B shares last axis with C +// case Intrinsic::tlvm_mma_nn: +// case Intrinsic::tlvm_mma_nt: +// case Intrinsic::tlvm_mma_tn: +// case Intrinsic::tlvm_mma_tt:{ +// bool AT = IntrinsicID == Intrinsic::tlvm_mma_tn || IntrinsicID == Intrinsic::tlvm_mma_tt; +// bool BT = IntrinsicID == Intrinsic::tlvm_mma_nt || IntrinsicID == Intrinsic::tlvm_mma_tt; +// Value *A = Call->getOperand(0); +// Value *B = Call->getOperand(1); +// Value *D = Call->getOperand(2); +// size_t AOuter = 0, AInner = 1; +// size_t BOuter = 1, BInner = 0; +// if(AT) std::swap(AOuter, AInner); +// if(BT) std::swap(BOuter, BInner); +// addConstraint({Call, 0}, {A, AOuter}, Shape); +// addConstraint({Call, 1}, {B, BOuter}, Shape); +// addConstraint({A, AInner}, {B, BInner}, Shape); +// addConstraint({Call, 0}, {D, 0}, Layout); +// addConstraint({Call, 1}, {D, 1}, Layout); +// break; +// } +// default: +// break; +// } +// } +// // LoadInst: Pointer shares the result layout +// if(LoadInst *Load = dyn_cast(I)){ +// NumDim = Load->getType()->getTileNumDimensions(); +// Value *Ptr = Load->getPointerOperand(); +// for(unsigned K = 0; K < NumDim; K++) +// addConstraint({Load, K}, {Ptr, K}, Layout); +// } +// // StoreInst: Pointer shares the value layout +// if(StoreInst *Store = dyn_cast(I)){ +// Value *Ptr = Store->getPointerOperand(); +// Value *Val = Store->getValueOperand(); +// NumDim = Val->getType()->getTileNumDimensions(); +// for(unsigned K = 0; K < NumDim; K++) +// addConstraint({Ptr, K}, {Val, K}, Layout); +// } +// // SelectInst: Selected tensor share layout +// if(SelectInst *Select = dyn_cast(I)){ +// NumDim = Select->getType()->getTileNumDimensions(); +// for(unsigned K = 0; K < NumDim; K++){ +// addConstraint({Select->getTrueValue(), K}, {Select, K}, Layout); +// addConstraint({Select->getFalseValue(), K}, {Select, K}, Layout); +// } +// } +// if(isa(I)){ +// NumDim = I->getType()->getTileNumDimensions(); +// for(unsigned K = 0; K < NumDim; K++){ +// addConstraint({I->getOperand(0), K}, {I, K}, Layout); +// } +// } +// // Phi Nodes: all the incoming value share the result layout +// if(PHINode *Phi = dyn_cast(I)){ +// Type *Ty = Phi->getType(); +// NumDim = Ty->getTileNumDimensions(); +// unsigned NumInc = Phi->getNumIncomingValues(); +// for(unsigned PI = 0; PI < NumInc; PI++){ +// Value *Inc = Phi->getIncomingValue(PI); +// for(unsigned K = 0; K < NumDim; K++){ +// CType CT = (LCGraph.find({Inc,K}) != LCGraph.end() || +// LCGraph.find({Phi,K}) != LCGraph.end())?Layout:Shape; +// addConstraint({Phi, K}, {Inc, K}, CT); +// } +// } +// } +// // Binary op: All the arguments share the result layout +// Instruction *BinOp = static_cast(I); +// if(isa(BinOp) || isa(BinOp)){ +// NumDim = BinOp->getType()->getTileNumDimensions(); +// Value *A = BinOp->getOperand(0); +// Value *B = BinOp->getOperand(1); +// for(unsigned K = 0; K < NumDim; K++){ +// addConstraint({BinOp, K}, {A, K}, Layout); +// addConstraint({BinOp, K}, {B, K}, Layout); +// } +// } +//} + +//void TLVMAddTunerParams::connectedComponents(CNodeType X, ArrayRef Vals, CType CT, +// DenseSet &Nodes, CGraphType &Graph){ +// if(Nodes.find(X) != Nodes.end()){ +// Nodes.erase(X); +// std::string Suffix = ".d" + itostr(X.second); +// if(Instruction *Instr = dyn_cast(X.first)){ +// if(CT==Shape){ +// Instr->setMetadata("nvvm.param.shape" + Suffix, Vals[0]); +// } +// if(CT==Layout){ +// Instr->setMetadata("nvvm.param.layout.p0" + Suffix, Vals[0]); +// Instr->setMetadata("nvvm.param.layout.p1" + Suffix, Vals[1]); +// Instr->setMetadata("nvvm.param.layout.p2" + Suffix, Vals[2]); +// } +// } +// if(ConstantInt *Cst = dyn_cast(X.first)){ +// Metadata *CstMD = ConstantAsMetadata::get(Cst); +// if(CT==Shape){ +// Vals[0]->replaceOperandWith(0, CstMD); +// } +// if(CT==Layout){ +// Vals[0]->replaceOperandWith(0, CstMD); +// Vals[1]->replaceOperandWith(0, CstMD); +// Vals[2]->replaceOperandWith(0, CstMD); +// } +// } +// for(CNodeType &E: Graph[X]) +// connectedComponents(E, Vals, CT, Nodes, Graph); +// } +//} + +//// Run +//bool TLVMAddTunerParams::runOnFunction(Function &F) { +// // Build constraints graph +// for(Function::iterator::value_type &BB: F) +// for(BasicBlock::iterator::value_type &I : BB) +// if(isTLVMValue(&I)) +// initCGraph(&I); +// for(Function::iterator::value_type &BB: F) +// for(BasicBlock::iterator::value_type &I : BB) +// if(isTLVMValue(&I)) +// initCPhi(&I); +// // Add parameters +// LLVMContext &Ctx = F.getContext(); +// Metadata *UndefMD = ConstantAsMetadata::get(UndefValue::get(Type::getInt32Ty(Ctx))); +// // Shape parameters +// while(!SCNodes.empty()){ +// MDNode *V0 = MDNode::getTemporary(Ctx, UndefMD).release(); +// connectedComponents(*SCNodes.begin(), {V0}, Shape, SCNodes, SCGraph); +// } +// // Layout parameters +// while(!LCNodes.empty()){ +// MDNode *V0 = MDNode::getTemporary(Ctx, UndefMD).release(); +// MDNode *V1 = MDNode::getTemporary(Ctx, UndefMD).release(); +// MDNode *V2 = MDNode::getTemporary(Ctx, UndefMD).release(); +// connectedComponents(*LCNodes.begin(), {V0, V1, V2}, Layout, LCNodes, LCGraph); +// } +// return true; +//} + +//} +//}