[CODEGEN][TRANSFORM][PEEPHOLE] Fixed bug in *1 multiplication
This commit is contained in:
committed by
Philippe Tillet
parent
9e54a03006
commit
4181f9f2af
@@ -104,11 +104,17 @@ bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
|
||||
ir::value *lhs = binop->get_operand(0);
|
||||
ir::value *rhs = binop->get_operand(1);
|
||||
ir::constant_int *_1_lhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs))
|
||||
_1_lhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_lhs = cst;
|
||||
}
|
||||
ir::constant_int *_1_rhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs))
|
||||
_1_rhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_rhs = cst;
|
||||
}
|
||||
if(_1_lhs){
|
||||
binop->replace_all_uses_with(rhs);
|
||||
return true;
|
||||
|
@@ -236,6 +236,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
// exit(EXIT_FAILURE);
|
||||
align.run(module);
|
||||
cts.run(module);
|
||||
axes.run(module);
|
||||
@@ -255,7 +257,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
isel.visit(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
@@ -321,7 +321,7 @@ __global__ void {name}(
|
||||
}
|
||||
"""
|
||||
|
||||
#print(src)
|
||||
# print(src)
|
||||
ret = triton.kernel(src)
|
||||
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
|
||||
ret.set_constant('AD', delta_a)
|
||||
|
Reference in New Issue
Block a user