[CODEGEN][TRANSFORM][PEEPHOLE] Fixed bug in *1 multiplication rewrite: only fold when the splatted constant equals 1

This commit is contained in:
Philippe Tillet
2020-02-18 23:04:46 -05:00
committed by Philippe Tillet
parent 9e54a03006
commit 4181f9f2af
3 changed files with 13 additions and 6 deletions

View File

@@ -104,11 +104,17 @@ bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs))
_1_lhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs))
_1_rhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;

View File

@@ -236,6 +236,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
peephole.run(module);
dce.run(module);
// ir::print(module, std::cout);
// exit(EXIT_FAILURE);
align.run(module);
cts.run(module);
axes.run(module);
@@ -255,7 +257,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
barriers.run(module);
// ir::print(module, std::cout);
isel.visit(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -321,7 +321,7 @@ __global__ void {name}(
}
"""
#print(src)
# print(src)
ret = triton.kernel(src)
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a)