[CODEGEN][TRANSFORM][PEEPHOLE] Fixed bug in *1 multiplication
This commit is contained in:
		
				
					committed by
					
						
						Philippe Tillet
					
				
			
			
				
	
			
			
			
						parent
						
							9e54a03006
						
					
				
				
					commit
					4181f9f2af
				
			@@ -104,11 +104,17 @@ bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
 | 
				
			|||||||
      ir::value *lhs = binop->get_operand(0);
 | 
					      ir::value *lhs = binop->get_operand(0);
 | 
				
			||||||
      ir::value *rhs = binop->get_operand(1);
 | 
					      ir::value *rhs = binop->get_operand(1);
 | 
				
			||||||
      ir::constant_int *_1_lhs = nullptr;
 | 
					      ir::constant_int *_1_lhs = nullptr;
 | 
				
			||||||
      if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs))
 | 
					      if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
 | 
				
			||||||
        _1_lhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
 | 
					        auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
 | 
				
			||||||
 | 
					        if(cst && cst->get_value() == 1)
 | 
				
			||||||
 | 
					          _1_lhs = cst;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      ir::constant_int *_1_rhs = nullptr;
 | 
					      ir::constant_int *_1_rhs = nullptr;
 | 
				
			||||||
      if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs))
 | 
					      if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
 | 
				
			||||||
        _1_rhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
 | 
					        auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
 | 
				
			||||||
 | 
					        if(cst && cst->get_value() == 1)
 | 
				
			||||||
 | 
					          _1_rhs = cst;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      if(_1_lhs){
 | 
					      if(_1_lhs){
 | 
				
			||||||
        binop->replace_all_uses_with(rhs);
 | 
					        binop->replace_all_uses_with(rhs);
 | 
				
			||||||
        return true;
 | 
					        return true;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -236,6 +236,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
 | 
				
			|||||||
  dce.run(module);
 | 
					  dce.run(module);
 | 
				
			||||||
  peephole.run(module);
 | 
					  peephole.run(module);
 | 
				
			||||||
  dce.run(module);
 | 
					  dce.run(module);
 | 
				
			||||||
 | 
					//  ir::print(module, std::cout);
 | 
				
			||||||
 | 
					//  exit(EXIT_FAILURE);
 | 
				
			||||||
  align.run(module);
 | 
					  align.run(module);
 | 
				
			||||||
  cts.run(module);
 | 
					  cts.run(module);
 | 
				
			||||||
  axes.run(module);
 | 
					  axes.run(module);
 | 
				
			||||||
@@ -255,7 +257,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
 | 
				
			|||||||
  if(allocation.allocated_size() > context->device()->max_shared_memory())
 | 
					  if(allocation.allocated_size() > context->device()->max_shared_memory())
 | 
				
			||||||
    return std::unique_ptr<driver::module>();
 | 
					    return std::unique_ptr<driver::module>();
 | 
				
			||||||
  barriers.run(module);
 | 
					  barriers.run(module);
 | 
				
			||||||
//  ir::print(module, std::cout);
 | 
					 | 
				
			||||||
  isel.visit(module, *llvm);
 | 
					  isel.visit(module, *llvm);
 | 
				
			||||||
  // return binary
 | 
					  // return binary
 | 
				
			||||||
  std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
 | 
					  std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -321,7 +321,7 @@ __global__ void {name}(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #print(src)
 | 
					        # print(src)
 | 
				
			||||||
        ret = triton.kernel(src)
 | 
					        ret = triton.kernel(src)
 | 
				
			||||||
        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
 | 
					        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
 | 
				
			||||||
            ret.set_constant('AD', delta_a)
 | 
					            ret.set_constant('AD', delta_a)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user