[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)
This massively simplifies the implementation of `reassociate` and also fixes a number of bugs. The pass could still be improved, but it can already be used to generate constant pointer offsets, e.g. in the matmul epilogue.
This commit is contained in:
committed by
Philippe Tillet
parent
e16bee1a27
commit
840140bf26
@@ -23,6 +23,63 @@ namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
/**
 * \brief Emits `x + y`, moving a constant addend to the outermost position
 *
 * When either operand is itself an `add` whose right-hand operand is a
 * constant, the sum is re-associated so that the constant is added last:
 *   (a + cst) + y  ->  (a + y) + cst
 *   x + (b + cst)  ->  (x + b) + cst
 * Otherwise a plain `add` is emitted.
 *
 * \param x    left operand
 * \param y    right operand
 * \param name name requested for the resulting value; applied to the
 *             outermost instruction on every path
 * \return the value produced by the builder for the (possibly re-associated) sum
 */
Value* adder::operator()(Value *x, Value *y, const std::string& name) {
  // Returns `v` as an `add` with a constant right operand, or nullptr if it is not one.
  auto as_add_of_constant = [](Value *v) -> BinaryOperator* {
    auto *bin = dyn_cast<BinaryOperator>(v);
    if(bin
       && bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add
       && isa<Constant>(bin->getOperand(1)))
      return bin;
    return nullptr;
  };

  // (a + cst) + y -> (a + y) + cst
  if(BinaryOperator *lhs = as_add_of_constant(x)){
    Value *inner = (*builder_)->CreateAdd(lhs->getOperand(0), y);
    return (*builder_)->CreateAdd(inner, lhs->getOperand(1), name);
  }

  // x + (b + cst) -> (x + b) + cst
  if(BinaryOperator *rhs = as_add_of_constant(y)){
    Value *inner = (*builder_)->CreateAdd(x, rhs->getOperand(0));
    return (*builder_)->CreateAdd(inner, rhs->getOperand(1), name);
  }

  // default
  return (*builder_)->CreateAdd(x, y, name);
}
|
||||
|
||||
/**
 * \brief Emits `x * y`, distributing a constant factor over `(a + cst)`
 *
 * When `x` is an `add` whose right-hand operand is a constant and `y` is
 * itself a constant, the product is expanded:
 *   (a + cst1) * cst2  ->  (a * cst2) + (cst1 * cst2)
 * Otherwise a plain `mul` is emitted.
 *
 * \param x    left operand
 * \param y    right operand
 * \param name name requested for the resulting value; applied to the
 *             outermost instruction on every path
 * \return the value produced by the builder for the (possibly expanded) product
 */
Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
  // (a + cst1) * cst2 -> (a * cst2) + (cst1 * cst2)
  auto *sum = dyn_cast<BinaryOperator>(x);
  const bool distributable = sum
      && sum->getOpcode() == llvm::BinaryOperator::BinaryOps::Add
      && isa<Constant>(sum->getOperand(1))
      && isa<Constant>(y);
  if(distributable){
    // Same creation order as before: variable term first, then the constant term.
    Value *scaled_var = (*builder_)->CreateMul(sum->getOperand(0), y);
    Value *scaled_cst = (*builder_)->CreateMul(sum->getOperand(1), y);
    return (*builder_)->CreateAdd(scaled_var, scaled_cst, name);
  }

  // default
  return (*builder_)->CreateMul(x, y, name);
}
|
||||
|
||||
/**
 * \brief Emits a single-index GEP `ptr + off`, keeping constant offsets outermost
 *
 * Two rewrites are attempted before falling back to a plain GEP:
 *   (base + cst1) + cst2  ->  base + (cst1 + cst2)
 *   ptr + (o + cst)       ->  (ptr + o) + cst
 *
 * \param ptr  base pointer
 * \param off  offset added to `ptr`
 * \param name name requested for the resulting value; applied to the
 *             outermost instruction on every path
 * \return the value produced by the builder for the (possibly re-associated) GEP
 */
Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
  // (base + cst1) + cst2 -> base + (cst1 + cst2)
  if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
  if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
  if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
    // NOTE(review): both addends are constants, so the builder presumably
    // folds this into a single constant offset -- confirm.
    Value *merged = (*builder_)->CreateAdd(cst1, cst2);
    return (*builder_)->CreateGEP(gep->getPointerOperand(), merged, name);
  }

  // ptr + (o + cst) -> (ptr + o) + cst
  if(auto* bin = dyn_cast<BinaryOperator>(off))
  if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
  if(isa<ConstantInt>(bin->getOperand(1))){
    Value *inner = (*builder_)->CreateGEP(ptr, bin->getOperand(0));
    return (*builder_)->CreateGEP(inner, bin->getOperand(1), name);
  }

  // default
  return (*builder_)->CreateGEP(ptr, off, name);
}
|
||||
|
||||
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
|
||||
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
|
||||
//}
|
||||
|
||||
|
||||
// types
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
@@ -34,7 +91,6 @@ using namespace llvm;
|
||||
// constants
|
||||
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
||||
// ops
|
||||
#define add(...) builder_->CreateAdd(__VA_ARGS__)
|
||||
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
|
||||
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
|
||||
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
|
||||
@@ -52,7 +108,6 @@ using namespace llvm;
|
||||
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
|
||||
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
|
||||
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
|
||||
#define gep(...) builder_->CreateGEP(__VA_ARGS__)
|
||||
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
|
||||
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
|
||||
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
|
||||
@@ -64,7 +119,6 @@ using namespace llvm;
|
||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
||||
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
|
||||
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
|
||||
#define ret(...) builder_->CreateRet(__VA_ARGS__)
|
||||
@@ -144,7 +198,7 @@ generator::generator(analysis::axes *a_axes,
|
||||
target *tgt,
|
||||
unsigned num_warps)
|
||||
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
|
||||
tgt_(tgt), num_warps_(num_warps) {
|
||||
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
|
||||
|
||||
}
|
||||
|
||||
@@ -207,8 +261,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
|
||||
* \brief Code Generation for `binary_operator`
|
||||
*/
|
||||
void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
using ll = llvm::Instruction::BinaryOps;
|
||||
auto cvt = [](ir::binary_op_t op){
|
||||
using ll = llvm::Instruction::BinaryOps;
|
||||
using tt = ir::binary_op_t;
|
||||
switch(op) {
|
||||
case tt::Add: return ll::Add;
|
||||
@@ -235,7 +289,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||
vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs);
|
||||
auto op = cvt(x->get_op());
|
||||
if(op == ll::Add)
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,8 +308,8 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
|
||||
std::vector<Value*> vals;
|
||||
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
|
||||
vals.push_back(vals_[*it][idx]);
|
||||
Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty());
|
||||
vals_[x][idx] = gep(ty, ptr, vals);
|
||||
assert(vals.size() == 1);
|
||||
vals_[x][idx] = gep(ptr, vals[0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,7 +700,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
||||
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
||||
@@ -1576,7 +1636,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
||||
Value *in_base = in_gep->getPointerOperand();
|
||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
|
||||
in_base = cst ? in_base : in_gep;
|
||||
// output ptr info
|
||||
Value* out_base = shared[i].first;
|
||||
@@ -1683,34 +1743,34 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
call(iasm);
|
||||
}
|
||||
|
||||
/**
 * \brief Code Generation for `make_range_dyn`
 *
 * For every (one-dimensional) index of `x`, stores either the index itself
 * when it equals the i32 zero constant, or operand 0 of the binary operator
 * that the index is expected to be.
 */
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
  for(indices_t idx: idxs_.at(x)){
    assert(idx.size() == 1);
    auto first = idx[0];
    // zero index: forwarded unchanged
    if(first == i32(0)){
      vals_[x][idx] = first;
      continue;
    }
    // otherwise the index must be a binary operator; keep its left operand
    auto *sum = dyn_cast<BinaryOperator>(first);
    assert(sum);
    vals_[x][idx] = sum->getOperand(0);
  }
}
|
||||
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
// if(idx[0] == i32(0))
|
||||
// vals_[x][idx] = idx[0];
|
||||
// else{
|
||||
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
// assert(bin_add);
|
||||
// vals_[x][idx] = bin_add->getOperand(0);
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
/**
 * \brief Code Generation for `make_range_sta`
 *
 * For every (one-dimensional) index of `x`, stores either the index itself
 * when it equals the i32 zero constant, or operand 1 of the binary operator
 * that the index is expected to be; that operand is asserted to be a constant.
 */
void generator::visit_make_range_sta(ir::make_range_sta* x) {
  for(indices_t idx: idxs_.at(x)){
    assert(idx.size() == 1);
    auto first = idx[0];
    // zero index: forwarded unchanged
    if(first == i32(0)){
      vals_[x][idx] = first;
      continue;
    }
    // otherwise the index must be a binary operator; keep its constant right operand
    auto *sum = dyn_cast<BinaryOperator>(first);
    assert(sum);
    Value *addend = sum->getOperand(1);
    assert(isa<Constant>(addend));
    vals_[x][idx] = addend;
  }
}
|
||||
//void generator::visit_make_range_sta(ir::make_range_sta* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
// if(idx[0] == i32(0)){
|
||||
// vals_[x][idx] = idx[0];
|
||||
// }
|
||||
// else{
|
||||
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
// assert(bin_add);
|
||||
// Value *cst = bin_add->getOperand(1);
|
||||
// assert(isa<Constant>(cst));
|
||||
// vals_[x][idx] = cst;
|
||||
// }
|
||||
// };
|
||||
//}
|
||||
|
||||
void generator::visit_make_range(ir::make_range* x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
|
Reference in New Issue
Block a user