[LANG] Minor semantic changes (#121)

* Now using unordered instead of ordered float comparisons (fixes NaN issues)
* Bool -> int32 now converts to 1 rather than -1
* Reduce extend arguments to 32 bits if possible
This commit is contained in:
Philippe Tillet
2021-06-01 21:13:21 -04:00
committed by Philippe Tillet
parent 0274429429
commit 80c86ecf4a
5 changed files with 31 additions and 15 deletions

View File

@@ -54,7 +54,6 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
// run passes
dce.run(ir);
//ir::print(ir, std::cout);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);

View File

@@ -1554,7 +1554,7 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
std::map<indices_t, Value*> partial;
ir::value *arg = x->get_operand(0);
Type *ty = cvt(x->get_type()->get_scalar_ty());
Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
Value *acc = nullptr;
// reduce within thread
@@ -1563,13 +1563,13 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
acc = !acc ? val : do_acc(acc, val);
}
// reduce within warp
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false),
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false),
"shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
for(int i = 16; i > 0; i >>= 1)
acc = do_acc(acc, call(shfl, {acc, i32(i)}));
// pointers
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shmem_, ptr_ty(ty, addr_space));
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value* warp = udiv(thread, i32(32));
Value* lane = urem(thread, i32(32));
@@ -1688,10 +1688,10 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
// neutral element
Value *neutral;
switch(op) {
case ir::reduce_inst::ADD: neutral = i32(0); break;
case ir::reduce_inst::SUB: neutral = i32(0); break;
case ir::reduce_inst::MAX: neutral = i32(INT32_MIN); break;
case ir::reduce_inst::MIN: neutral = i32(INT32_MAX); break;
case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;