From 5b9afaa6887235f7d7c79297545466210cf15c38 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 7 Mar 2021 14:53:48 -0500
Subject: [PATCH] [CODEGEN] Fixed bug that caused conditional operator to not
 always properly mask load operations

Also includes minor improvement to benchmarking infrastructure
---
 include/triton/codegen/transform/peephole.h |  1 +
 include/triton/ir/instructions.h            |  3 +
 lib/codegen/selection/generator.cc          | 22 +++---
 lib/codegen/transform/peephole.cc           | 17 +++++
 lib/lang/code_gen.cc                        | 55 ++++++++++++---
 lib/runtime/function.cc                     |  1 +
 python/bench/bench_matmul.py                | 34 ++++++----
 python/triton/ops/cross_entropy.c           |  2 +-
 python/triton/testing.py                    | 75 +++++++++++++--------
 9 files changed, 146 insertions(+), 64 deletions(-)

diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h
index d8a21e6cc..c14c74702 100644
--- a/include/triton/codegen/transform/peephole.h
+++ b/include/triton/codegen/transform/peephole.h
@@ -32,6 +32,7 @@ private:
   bool rewrite_mult(ir::instruction *value, ir::builder& builder);
   bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
+  bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
   bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);

 private:
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 4718e7d9f..6971a751b 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -748,6 +748,9 @@ private:
 public:
   static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
+  value* get_pred_op() { return get_operand(0); }
+  value* get_if_value_op() { return get_operand(1); }
+  value* get_else_value_op() { return get_operand(2); }
 };

 //===----------------------------------------------------------------------===//
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index a51ffc645..8dd6864ed 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1616,20 +1616,22 @@ void generator::visit_make_range(ir::make_range* x) {
   }
 }

-
-
-void generator::visit_undef_value(ir::undef_value *ud) {
-  vals_[ud][{}] = llvm::UndefValue::get(cvt(ud->get_type()));
+void generator::visit_undef_value(ir::undef_value *x) {
+  Type* ty = cvt(x->get_type()->get_scalar_ty());
+  for(indices_t idx: idxs_.at(x))
+    vals_[x][idx] = llvm::UndefValue::get(ty);
 }

-void generator::visit_constant_int(ir::constant_int *cst){
-  Type *ty = cvt(cst->get_type()->get_scalar_ty());
-  vals_[cst][{}] = ConstantInt::get(ty, cst->get_value());
+void generator::visit_constant_int(ir::constant_int *x){
+  Type *ty = cvt(x->get_type()->get_scalar_ty());
+  for(indices_t idx: idxs_.at(x))
+    vals_[x][idx] = ConstantInt::get(ty, x->get_value());
 }

-void generator::visit_constant_fp(ir::constant_fp *cst){
-  Type *ty = cvt(cst->get_type()->get_scalar_ty());
-  vals_[cst][{}] = ConstantFP::get(ty, cst->get_value());
+void generator::visit_constant_fp(ir::constant_fp *x){
+  Type *ty = cvt(x->get_type()->get_scalar_ty());
+  for(indices_t idx: idxs_.at(x))
+    vals_[x][idx] = ConstantFP::get(ty, x->get_value());
 }

 void generator::visit_alloc_const(ir::alloc_const *alloc) {
diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc
index 2855674a9..392f6ea94 100644
--- a/lib/codegen/transform/peephole.cc
+++ b/lib/codegen/transform/peephole.cc
@@ -193,6 +193,22 @@ bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder){
   return false;
 }

+bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
+  auto select = dynamic_cast<ir::select_inst*>(value);
+  if(!select)
+    return false;
+  auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
+  if(!if_value)
+    return false;
+  if(select->get_pred_op() != if_value->get_mask_operand())
+    return false;
+  builder.set_insert_point(select);
+  ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
+                                                   if_value->get_mask_operand(),
+                                                   select->get_else_value_op());
+  select->replace_all_uses_with(new_load);
+  return true;
+}

 void peephole::run(ir::module &mod) {
   ir::builder &builder = mod.get_builder();
@@ -230,6 +246,7 @@ void peephole::run(ir::module &mod) {
 //      was_modified = was_modified || rewrite_trans_phi(i, builder);
       was_modified = was_modified || rewrite_unit_red(i, builder);
       was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
+      was_modified = was_modified || rewrite_select_masked_load(i, builder);
       if(tgt_->as_nvidia()->sm() >= 80)
         was_modified = was_modified || rewrite_load_to_shared(i, builder);
       if(was_modified)
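
Note: the new rewrite_select_masked_load pass folds the pattern select(mask, masked_load(ptr, mask, undef), other) into a single masked_load(ptr, mask, other), which is what lets the conditional operator lower to one properly masked load. A minimal NumPy sketch of why the rewrite is sound (a stand-in for the IR semantics, not Triton code; np.where plays the role of both the select and the masked load's else-value handling):

    import numpy as np

    # Stand-in for a masked load: yields `other` wherever `mask` is False.
    def masked_load(data, mask, other):
        return np.where(mask, data, other)

    data = np.arange(8, dtype=np.float32)
    mask = data < 5
    other = np.full_like(data, -1.0)
    undef = np.empty_like(data)  # placeholder for an undef else-value

    # pattern before the peephole: select(mask, masked_load(ptr, mask, undef), other)
    before = np.where(mask, masked_load(data, mask, undef), other)
    # pattern after the peephole: masked_load(ptr, mask, other)
    after = masked_load(data, mask, other)

    assert np.array_equal(before, after)
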
diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc
index 0a1a15eab..77568014f 100644
--- a/lib/lang/code_gen.cc
+++ b/lib/lang/code_gen.cc
@@ -267,23 +267,56 @@ void Generator::VisitTransOp(TransOp *trans) {
 }

 void Generator::VisitConditionalOp(ConditionalOp* condOp) {
-//  auto &instructions = bld_->get_insert_block()->get_inst_list();
+  auto &instructions = bld_->get_insert_block()->get_inst_list();
   VisitExpr(condOp->cond_);
-  ir::value* cond = ret_;
+  ir::value* true_cond = ret_;
+  ir::instruction* start = instructions.back();
   VisitExpr(condOp->exprTrue_);
   ir::value* true_val = ret_;
   VisitExpr(condOp->exprFalse_);
   ir::value* false_val = ret_;
-  if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
-    if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
-      false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
-    ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
-    ld->replace_all_uses_with(new_ld);
-    ld->erase_from_parent();
-    return set_ret(new_ld);
+  auto begin = std::find(instructions.begin(), instructions.end(), start);
+  bool is_in_true_cond = true;
+  for(auto it = begin; it != instructions.end(); it++){
+    ir::instruction* instr = *it;
+    // we mask load with `cond` when used to compute true_value
+    // we mask load with `!cond` when used to compute false_value
+    if(auto ld = dynamic_cast<ir::unmasked_load_inst*>(instr)){
+      bld_->set_insert_point(ld);
+      ir::type* ty = ld->get_type();
+      ir::value* cond = is_in_true_cond ? true_cond : true_cond;
+      ir::value* ptr = ld->get_pointer_operand();
+      ir::value* else_val = ir::undef_value::get(ty);
+      ir::value* masked_ld = bld_->create_masked_load(ptr, cond, else_val);
+      ld->replace_all_uses_with(masked_ld);
+      ld->erase_from_parent();
+      if(true_val == ld)
+        true_val = masked_ld;
+      if(false_val == ld)
+        false_val = masked_ld;
+      it = std::find(instructions.begin(), instructions.end(), masked_ld);
+    }
+    if(instr == true_val)
+      is_in_true_cond = false;
   }
-  return set_ret(bld_->create_select(cond, true_val, false_val));
-//  return error_not_implemented();
+  bld_->set_insert_point(bld_->get_insert_block());
+  return set_ret(bld_->create_select(true_cond, true_val, false_val));
+
+//  VisitExpr(condOp->cond_);
+//  ir::value* cond = ret_;
+//  VisitExpr(condOp->exprTrue_);
+//  ir::value* true_val = ret_;
+//  VisitExpr(condOp->exprFalse_);
+//  ir::value* false_val = ret_;
+//  if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
+//    if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
+//      false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
+//    ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
+//    ld->replace_all_uses_with(new_ld);
+//    ld->erase_from_parent();
+//    return set_ret(new_ld);
+//  }
+//  return set_ret(bld_->create_select(cond, true_val, false_val));
 }

 void Generator::VisitFuncCall(FuncCall* funcCall) {
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index 69799e557..962832fa6 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -156,6 +156,7 @@ std::tuple,
   layouts.run(ir);
   peephole.run(ir);
   dce.run(ir);
+//  ir::print(ir, std::cout);
   if(target->is_gpu())
     cts.run(ir);
   align.run(ir);
diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py
index b79030c40..d83c17c06 100644
--- a/python/bench/bench_matmul.py
+++ b/python/bench/bench_matmul.py
@@ -2,33 +2,39 @@ import triton
 import torch
 import os

+
 def rounded_linspace(low, high, steps, div):
     ret = torch.linspace(low, high, steps)
     ret = (ret.int() + div - 1) // div * div
     ret = torch.unique(ret)
     return list(map(int, ret))

+
 # Square benchmarks
 nt = {False: "n", True: "t"}
 square_confs = [
     triton.testing.Benchmark(
         x_names=["M", "N", "K"],
-        x_vals=rounded_linspace(512, 8192, 17, 128),
+        x_vals=rounded_linspace(512, 8192, 32, 128),
         y_name="provider",
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         loglog=False,
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
-        args={"AT": AT, "BT": BT, "dtype": torch.float16},
-    ) for AT in [False, True] for BT in [False, True]
+        args={
+            "AT": AT,
+            "BT": BT,
+            "dtype": torch.float16
+        },
+    ) for AT in [False] for BT in [False]
 ]

 # Transformer training benchmarks
 transformer_confs = [
     triton.testing.Benchmark(
         x_names=[x],
-        x_vals = rounded_linspace(NK//16, NK, 33, 128),
+        x_vals = rounded_linspace(NK//16, NK, 32, 128),
         y_name="provider",
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
@@ -41,21 +47,21 @@ transformer_confs = [
         for M in [2048]
 ]

-@triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
+
+@triton.testing.perf_report(transformer_confs)
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
     if AT: a = a.t()
     if BT: b = b.t()
     num_flops = 2 * M * N * K
+    tflops = lambda ms: 2. * M * N * K / ms * 1e-9
     if provider == "torch":
-        torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
-        torch_tflops = num_flops / torch_ms * 1e-9
-        return torch_tflops
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
+        return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "triton":
-        triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
-        triton_tflops = num_flops / triton_ms * 1e-9
-        return triton_tflops
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
+        return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
         import subprocess
         import tempfile
@@ -87,6 +93,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
     subprocess.run(cmd, stdout=subprocess.PIPE)
     # read CSV output
     df_c = pd.read_csv(f"{fname}.gemm.csv")
-    cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
-    return cutlass_tflops
+    tflops = max(df_c["GFLOPs"]) / 1e3
+    return tflops
     return None
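
Note: rounded_linspace (unchanged above, copied here so the sketch is self-contained) snaps an evenly spaced grid up to multiples of div and de-duplicates it, so the benchmark only sweeps tile-friendly sizes and the exact steps value mostly controls the density of the sweep. A quick CPU-only sketch of what it produces (requires only torch; the printed values are what the formula yields for these arguments):

    import torch

    def rounded_linspace(low, high, steps, div):
        ret = torch.linspace(low, high, steps)
        ret = (ret.int() + div - 1) // div * div
        ret = torch.unique(ret)
        return list(map(int, ret))

    sizes = rounded_linspace(512, 8192, 32, 128)
    print(sizes[:4], "...", sizes[-1])      # [512, 768, 1024, 1280] ... 8192
    assert all(s % 128 == 0 for s in sizes)
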
diff --git a/python/triton/ops/cross_entropy.c b/python/triton/ops/cross_entropy.c
index 2de793448..b906c8a05 100644
--- a/python/triton/ops/cross_entropy.c
+++ b/python/triton/ops/cross_entropy.c
@@ -26,7 +26,7 @@ __global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs,
   TYPE local_dn = *(dneg_logprobs + row);
   // We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
   // and we have -log(p[k]) stored, so this is easy
-  TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * ? (check)px) : 0;
+  TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * px) : 0;
   // selected_logit_idx is selected logit index for our token
   bool find_one[TILE] = ((0 ... TILE) == local_ind);
   intermediate = intermediate - ((TYPE[TILE])find_one);
diff --git a/python/triton/testing.py b/python/triton/testing.py
index c31ebe4dd..22c97cc7f 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -26,35 +26,34 @@ def allclose(x, y):
     return err < tol


-def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False):
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    # warmup to put the clock in a stable regime
-    ret = fn()
-    for i in range(warmup):
-        fn()
-    torch.cuda.synchronize()
-    total_ms = 0
+def do_bench(fn, warmup=10, rep=50, grad_to_none=None):
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
+    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
+    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
     cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
-    for i in range(rep):
+    for i in range(warmup + rep):
         # we don't want `fn` to accumulate gradient values
         # if it contains a backward pass. So we clear the
         # provided gradients
         if grad_to_none is not None:
             grad_to_none.grad = None
-        # reset L2
+        # we clear the L2 cache before each run
        cache.zero_()
         # record time of `fn`
-        start_event.record()
+        if i >= warmup:
+            start_event[i - warmup].record()
         fn()
-        end_event.record()
-        torch.cuda.synchronize()
-        total_ms += start_event.elapsed_time(end_event)
-    # return the average runtime of `fn`
-    return total_ms / rep
+        if i >= warmup:
+            end_event[i - warmup].record()
+    torch.cuda.synchronize()
+    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
+    q = torch.quantile(times, torch.tensor([0.1, 0.5, 0.9]))
+    min_ms = q[0].item()
+    mean_ms = q[1].item()
+    max_ms = q[2].item()
+    return mean_ms, min_ms, max_ms


 class Benchmark:
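
Note: do_bench now times each of the rep measured iterations with its own pair of CUDA events and reports the 0.5, 0.1 and 0.9 quantiles of those timings (returned as mean_ms, min_ms, max_ms) instead of a single average. A small usage sketch, matching how bench_matmul.py consumes the new return value (assumes a CUDA device; the matrix sizes are arbitrary):

    import torch
    import triton

    a = torch.rand((1024, 1024), device="cuda", dtype=torch.float16)
    b = torch.rand((1024, 1024), device="cuda", dtype=torch.float16)

    # three timings in milliseconds: median, 0.1-quantile, 0.9-quantile
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=50)
    tflops = lambda t: 2. * 1024 ** 3 / t * 1e-9
    print(f"{tflops(ms):.1f} TFLOPS (best {tflops(min_ms):.1f}, worst {tflops(max_ms):.1f})")
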
@@ -79,22 +78,42 @@ class Mark:
         import matplotlib.pyplot as plt
         import pandas as pd
         import os
-
-        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
+        y_mean = bench.y_lines
+        y_min = [f'{x}-min' for x in bench.y_lines]
+        y_max = [f'{x}-max' for x in bench.y_lines]
+        df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
         for x in bench.x_vals:
             x_args = {x_name: x for x_name in bench.x_names}
-            row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
-            df.loc[len(df)] = [x] + row
+            row_mean, row_min, row_max = [], [], []
+            for y in bench.y_vals:
+                ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args)
+                try:
+                    y_mean, y_min, y_max = ret
+                except TypeError:
+                    y_mean, y_min, y_max = ret, None, None
+                row_mean += [y_mean]
+                row_min += [y_min]
+                row_max += [y_max]
+            df.loc[len(df)] = [x] + row_mean + row_min + row_max
         if with_plot and bench.plot_name:
+            plt.figure()
+            ax = plt.subplot()
             xlabel = " = ".join(bench.x_names)
-            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
-            plot.set_xlabel(xlabel)
-            plot.set_ylabel(bench.ylabel)
-            plot.set_title(bench.plot_name)
-            plot.set_xscale("log" if bench.loglog else "linear")
-            plot.set_yscale("log" if bench.loglog else "linear")
+            x = bench.x_names[0]
+            for y in bench.y_lines:
+                y_min, y_max = df[y + '-min'], df[y + '-max']
+                ax.plot(df[x], df[y], label=y)
+                if y_min is not None and y_max is not None:
+                    ax.fill_between(df[x], y_min, y_max, alpha=0.5)
+            ax.legend()
+            ax.set_xlabel(xlabel)
+            ax.set_ylabel(bench.ylabel)
+            ax.set_title(bench.plot_name)
+            ax.set_xscale("log" if bench.loglog else "linear")
+            ax.set_yscale("log" if bench.loglog else "linear")
             plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
-        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
+        df = df[[bench.x_names[0]] + bench.y_lines]
+        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)

     def run(self, result_path, with_plot):
         with open(os.path.join(result_path, "results.html"), "w") as html:
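
Note: with this change, a function decorated with triton.testing.perf_report is expected to return a (mean, min, max) triple per data point (a plain scalar still works and is padded with None), and Mark._run then writes one "-min"/"-max" column pair per line and shades the min-max band around each curve. A sketch of a compatible benchmark, modeled on bench_matmul.py above; the bench name, sizes and output directory are illustrative, and it assumes perf_report wraps the function in a Mark whose run(result_path, with_plot) method drives it, as the run() method shown above suggests:

    import os
    import torch
    import triton

    confs = [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[256, 512, 1024, 2048],
            y_name="provider",
            y_vals=["torch"],
            y_lines=["Torch"],
            ylabel="TFLOPS",
            loglog=False,
            plot_name="matmul-demo",
            args={"dtype": torch.float16},
        )
    ]

    @triton.testing.perf_report(confs)
    def bench(N, provider, dtype):
        a = torch.rand((N, N), device="cuda", dtype=dtype)
        b = torch.rand((N, N), device="cuda", dtype=dtype)
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
        tflops = lambda t: 2. * N * N * N / t * 1e-9
        # the slowest run (max_ms) yields the lowest TFLOPS, hence the swap below
        return tflops(ms), tflops(max_ms), tflops(min_ms)

    os.makedirs("results", exist_ok=True)  # run() writes results.html, .png and .csv here
    bench.run("results", with_plot=True)
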