From 4181f9f2afba73f2ae4786d2f568c785b9b97ec0 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 18 Feb 2020 23:04:46 -0500 Subject: [PATCH] [CODEGEN][TRANSFORM][PEEPHOLE] Fixed bug in *1 multiplication --- lib/codegen/transform/peephole.cc | 14 ++++++++++---- lib/runtime/function.cc | 3 ++- python/triton/ops/einsum.py | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index 98b70e4ab..0720318db 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -104,11 +104,17 @@ bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) { ir::value *lhs = binop->get_operand(0); ir::value *rhs = binop->get_operand(1); ir::constant_int *_1_lhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(lhs)) - _1_lhs = dynamic_cast(splat->get_operand(0)); + if(ir::splat_inst *splat = dynamic_cast(lhs)){ + auto *cst = dynamic_cast(splat->get_operand(0)); + if(cst && cst->get_value() == 1) + _1_lhs = cst; + } ir::constant_int *_1_rhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(rhs)) - _1_rhs = dynamic_cast(splat->get_operand(0)); + if(ir::splat_inst *splat = dynamic_cast(rhs)){ + auto *cst = dynamic_cast(splat->get_operand(0)); + if(cst && cst->get_value() == 1) + _1_rhs = cst; + } if(_1_lhs){ binop->replace_all_uses_with(rhs); return true; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index e77e5772b..25d355f2c 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -236,6 +236,8 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c dce.run(module); peephole.run(module); dce.run(module); +// ir::print(module, std::cout); +// exit(EXIT_FAILURE); align.run(module); cts.run(module); axes.run(module); @@ -255,7 +257,6 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c if(allocation.allocated_size() > context->device()->max_shared_memory()) return std::unique_ptr(); barriers.run(module); -// ir::print(module, std::cout); isel.visit(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 39f0c78b5..3f49ae009 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -321,7 +321,7 @@ __global__ void {name}( } """ - #print(src) + # print(src) ret = triton.kernel(src) if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: ret.set_constant('AD', delta_a)