[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)

This massively simplifies implementation of `reassociate` and also fixes
a bunch of bug. The pass could still be improved, but can already be used
to generate constant pointer offsets in eg the matmul epilogue
This commit is contained in:
Philippe Tillet
2021-05-07 17:54:37 -04:00
committed by Philippe Tillet
parent e16bee1a27
commit 840140bf26
12 changed files with 204 additions and 667 deletions

View File

@@ -23,6 +23,63 @@ namespace codegen{
using namespace llvm;
Value* adder::operator()(Value *x, Value *y, const std::string& name) {
// (x + cst) + y -> (x + y) + cst
if(auto* bin = dyn_cast<BinaryOperator>(x))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1))){
return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
bin->getOperand(1));
}
// (x + (y + cst)) -> (x + y) + cst
if(auto* bin = dyn_cast<BinaryOperator>(y))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1))){
return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
bin->getOperand(1));
}
// default
return (*builder_)->CreateAdd(x, y, name);
}
Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
// (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
if(auto* bin = dyn_cast<BinaryOperator>(x))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(dyn_cast<Constant>(bin->getOperand(1)))
if(dyn_cast<Constant>(y)){
return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
(*builder_)->CreateMul(bin->getOperand(1), y));
}
// default
return (*builder_)->CreateMul(x, y, name);
}
Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
return (*builder_)->CreateGEP(gep->getPointerOperand(),
(*builder_)->CreateAdd(cst1, cst2));
}
// ptr + (off + cst) -> (ptr + off) + cst
if(auto* bin = dyn_cast<BinaryOperator>(off))
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)),
bin->getOperand(1));
}
// default
return (*builder_)->CreateGEP(ptr, off, name);
}
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
//}
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
@@ -34,7 +91,6 @@ using namespace llvm;
// constants
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define add(...) builder_->CreateAdd(__VA_ARGS__)
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
@@ -52,7 +108,6 @@ using namespace llvm;
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
#define gep(...) builder_->CreateGEP(__VA_ARGS__)
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
@@ -64,7 +119,6 @@ using namespace llvm;
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define mul(...) builder_->CreateMul(__VA_ARGS__)
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
#define ret(...) builder_->CreateRet(__VA_ARGS__)
@@ -144,7 +198,7 @@ generator::generator(analysis::axes *a_axes,
target *tgt,
unsigned num_warps)
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
tgt_(tgt), num_warps_(num_warps) {
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
}
@@ -207,8 +261,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
* \brief Code Generation for `binary_operator`
*/
void generator::visit_binary_operator(ir::binary_operator*x) {
using ll = llvm::Instruction::BinaryOps;
auto cvt = [](ir::binary_op_t op){
using ll = llvm::Instruction::BinaryOps;
using tt = ir::binary_op_t;
switch(op) {
case tt::Add: return ll::Add;
@@ -235,7 +289,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs);
auto op = cvt(x->get_op());
if(op == ll::Add)
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}
}
@@ -248,8 +308,8 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
std::vector<Value*> vals;
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
vals.push_back(vals_[*it][idx]);
Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty());
vals_[x][idx] = gep(ty, ptr, vals);
assert(vals.size() == 1);
vals_[x][idx] = gep(ptr, vals[0]);
}
}
@@ -640,7 +700,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
std::vector<llvm::Type*> tys = {f32_ty};
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
@@ -1576,7 +1636,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
in_base = cst ? in_base : in_gep;
// output ptr info
Value* out_base = shared[i].first;
@@ -1683,34 +1743,34 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
call(iasm);
}
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
for(indices_t idx: idxs_.at(x)){
assert(idx.size() == 1);
if(idx[0] == i32(0))
vals_[x][idx] = idx[0];
else{
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
vals_[x][idx] = bin_add->getOperand(0);
}
}
}
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
// for(indices_t idx: idxs_.at(x)){
// assert(idx.size() == 1);
// if(idx[0] == i32(0))
// vals_[x][idx] = idx[0];
// else{
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
// assert(bin_add);
// vals_[x][idx] = bin_add->getOperand(0);
// }
// }
//}
void generator::visit_make_range_sta(ir::make_range_sta* x) {
for(indices_t idx: idxs_.at(x)){
assert(idx.size() == 1);
if(idx[0] == i32(0)){
vals_[x][idx] = idx[0];
}
else{
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *cst = bin_add->getOperand(1);
assert(isa<Constant>(cst));
vals_[x][idx] = cst;
}
};
}
//void generator::visit_make_range_sta(ir::make_range_sta* x) {
// for(indices_t idx: idxs_.at(x)){
// assert(idx.size() == 1);
// if(idx[0] == i32(0)){
// vals_[x][idx] = idx[0];
// }
// else{
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
// assert(bin_add);
// Value *cst = bin_add->getOperand(1);
// assert(isa<Constant>(cst));
// vals_[x][idx] = cst;
// }
// };
//}
void generator::visit_make_range(ir::make_range* x) {
for(indices_t idx: idxs_.at(x)){