From a60374a5979f4a68025a5f1fb17d3d1c79332317 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 3 Jun 2022 11:36:06 -0700 Subject: [PATCH] Revert "[BACKEND] Various bug fixes; making reductions faster (#533)". This is a more stable commit that produce bitwise identical code to earlier versions. Using commits after this one may lead to slightly different numerics --- include/triton/codegen/analysis/layout.h | 2 +- lib/codegen/analysis/align.cc | 4 +- lib/codegen/analysis/layout.cc | 8 +- lib/codegen/pass.cc | 1 - lib/codegen/selection/generator.cc | 170 ++++++------------- lib/codegen/transform/coalesce.cc | 19 --- python/setup.py | 2 +- python/test/unit/language/test_core.py | 23 --- python/triton/language/core.py | 2 - python/triton/language/semantic.py | 2 +- python/tutorials/03-matrix-multiplication.py | 5 +- 11 files changed, 65 insertions(+), 173 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 050ac6956..28dfad18d 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -224,7 +224,7 @@ struct scanline_layout: public distributed_layout { int nts(size_t k) { return nts_.at(k); } int contig_per_thread(size_t k) { return nts_.at(k); } - int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);} + int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);} public: // micro tile size. The size of a tile held by a thread block. std::vector mts_; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 1c48a4c05..37b609228 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -319,8 +319,8 @@ std::vector align::populate_max_contiguous_binop(ir::binary_operator* } if(x->is_int_add_sub()){ unsigned lvalue = 1, rvalue = 1; - lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst); - rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst); + lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]); + rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]); value = std::max(lvalue, rvalue); } result.push_back(value); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 86473dc54..cec512fec 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -209,15 +209,14 @@ mma_layout::mma_layout(size_t num_warps, rep_ = {2*pack_size_0, 2*pack_size_1, 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; contig_per_thread_ = {1, 1}; - order_ = {0, 1}; } else{ // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 contig_per_thread_ = {1, 2}; - order_ = {1, 0}; // rep_ = {2, 2, 1}; } + order_ = {0, 1}; /* warps per tile */ wpt_ = {1, 1, 1}; @@ -617,9 +616,8 @@ void layouts::run(ir::module &mod) { unsigned axis = red->get_axis(); // shape auto shapes = arg->get_type()->get_block_shapes(); - distributed_layout* layout = dynamic_cast(get(arg)); - shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis); - + scanline_layout *layout = get(arg)->to_scanline(); + shapes[axis] = layout->mts(axis); // create layout layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); tmp_[red] = id; diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 4ba423d20..e2cd6d228 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -88,7 +88,6 @@ std::unique_ptr 
add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); - // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); return llvm; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 5397ceefe..cf51a3b4c 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -88,7 +88,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define f16_ty builder_->getHalfTy() #define bf16_ty builder_->getBFloatTy() #define f32_ty builder_->getFloatTy() -#define i1_ty builder_->getInt1Ty() #define i8_ty builder_->getInt8Ty() #define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() @@ -737,9 +736,6 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { * \brief Code Generation for a (synchronous) `load` */ void generator::visit_load_inst(ir::load_inst* x){ - BasicBlock *current = builder_->GetInsertBlock(); - Module *module = current->getModule(); - Value *tid = tgt_->get_local_id(module, *builder_, 0); ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); @@ -779,9 +775,6 @@ void generator::visit_load_inst(ir::load_inst* x){ in_off = 0; } Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); - // if(!op->get_type()->is_block_ty()){ - // pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0))); - // } Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; size_t nbits = dtsize*8; // pack sub-words (< 32/64bits) into words @@ -885,18 +878,6 @@ void generator::visit_load_inst(ir::load_inst* x){ Value *_ret = call(inlineAsm, args); - // if(!op->get_type()->is_block_ty()){ - // Value* cond = icmp_eq(tid, i32(0)); - // Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3)); - // Instruction* bar = add_barrier(); - // Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false); - // builder_->SetInsertPoint(term); - // store(_ret, shptr); - // builder_->SetInsertPoint(bar->getParent()); - // _ret = load(shptr); - // add_barrier(); - // } - // --- // extract and store return values // --- @@ -2052,12 +2033,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: // create mma & unpack result, m, n, k are offsets in mat auto call_mma = [&](unsigned m, unsigned n, unsigned k) { - unsigned cols_per_thread = num_rep_n * 2; + unsigned cols_per_thread = num_rep_m * 2; std::vector idx = { - (m + 0)*cols_per_thread + (n*2 + 0), - (m + 0)*cols_per_thread + (n*2 + 1), - (m + 1)*cols_per_thread + (n*2 + 0), - (m + 1)*cols_per_thread + (n*2 + 1) + (m + 0) + (n*2 + 0)*cols_per_thread, + (m + 0) + (n*2 + 1)*cols_per_thread, + (m + 1) + (n*2 + 0)*cols_per_thread, + (m + 1) + (n*2 + 1)*cols_per_thread }; Value *nc = call(mma_ty, mma_fn, {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], @@ -2335,93 +2316,62 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral){ // ir::value *arg = x->get_operand(0); - analysis::distributed_layout* layout = dynamic_cast(layouts_->get(arg)); + analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline(); std::vector shapes = layout->get_shape(); - - Type* sca_ty = cvt(arg->get_type()->get_scalar_ty()); - size_t n_bits = sca_ty->getPrimitiveSizeInBits(); - - std::string 
n_bits_str = std::to_string(n_bits); - std::string cst = (n_bits == 64) ? "l" : "r"; - - FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false); - InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true); - FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false); - InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true); - - - Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - Value* warp = udiv(thread, i32(32)); - Value* lane = urem(thread, i32(32)); - - unsigned shuffle_width = 0; - unsigned warps_per_inner = 0; - auto arg_vals = vals_.at(arg); - std::vector arg_idxs = idxs_.at(arg); - size_t n_elts = arg_idxs.size(); - unsigned col_per_thread; - Value* warp_i; - Value* warp_j; - if(analysis::scanline_layout* scanline = layout->to_scanline()){ - std::vector order = layout->get_order(); - unsigned mts = scanline->mts(order[0]); - shuffle_width = std::min(mts, 32); - warps_per_inner = std::max(mts/32, 1); - col_per_thread = shapes[order[0]] / mts; - warp_i = udiv(warp, i32(warps_per_inner)); - warp_j = urem(warp, i32(warps_per_inner)); - } - else if(layout->to_mma()){ - shuffle_width = 4; - warps_per_inner = layout->to_mma()->wpt(1); - col_per_thread = 16; - warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id; - warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; - } - - // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); + std::vector order = layout->get_order(); + unsigned mts = layout->mts(order[0]); + unsigned nts = layout->nts(order[0]); + unsigned col_per_thread = shapes[order[0]] / mts; + auto idxs = idxs_.at(arg); + size_t n_elts = idxs.size(); // Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); - // preds - Value* is_lane0 = icmp_eq(lane, i32(0)); - Value* is_warp0 = icmp_eq(warp, i32(0)); - Value* is_thread0 = icmp_eq(thread, i32(0)); - Value* lane_j = urem(lane, i32(shuffle_width)); - Value* first_lane_in_col = icmp_eq(lane_j, i32(0)); - add_barrier(); - // compute partial sum for each warp, and store to shared memory + Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + Value* warp = udiv(thread, i32(32)); + Value* lane = urem(thread, i32(32)); + size_t warps_per_inner = std::max(mts/32, 1); + Value* warp_i = udiv(warp, i32(warps_per_inner)); + unsigned row_per_thread = std::max(32/mts, 1); + for(size_t i = 0; i < n_elts/col_per_thread; i++){ Value* acc; // reduce within thread for(size_t j = 0; j < col_per_thread; j++){ - Value* val = arg_vals[arg_idxs[i*col_per_thread + j]]; - // acc = (j == 0) ? val : do_acc(acc, val); + Value* val = vals_[arg][idxs[i*col_per_thread + j]]; acc = (j == 0) ? val : do_acc(acc, val); } // reduce within warp - for(int k = shuffle_width/2 ; k > 0; k >>= 1) + for(int k = std::min(mts, 32)/2 ; k > 0; k >>= 1) acc = do_acc(acc, shfl_sync(acc, k)); - // store partial result to shared memory - auto x_idxs = idxs_[x][i]; - Value* x_idx = x_idxs.empty() ? 
builder_->getInt32(0) : x_idxs[0]; - Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); - call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc}); + // store warp result in shared memory + Value* ret = acc; + if(mts >= 32){ + add_barrier(); + store(neutral, gep(base, lane)); + add_barrier(); + store(acc, gep(base, warp)); + add_barrier(); + // reduce across warps + Value *cond = icmp_eq(warp, i32(0)); + Instruction *barrier = add_barrier(); + builder_->SetInsertPoint(barrier->getParent()); + Instruction* dummy = builder_->CreateRet(nullptr); + Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); + dummy->removeFromParent(); + builder_->SetInsertPoint(term); + ret = load(gep(base, thread)); + for(int k = (mts/32)/2; k > 0; k >>= 1){ + Value *current = shfl_sync(ret, k); + ret = do_acc(ret, current); + } + store(ret, gep(base, thread)); + builder_->SetInsertPoint(barrier->getParent()); + ret = load(gep(base, warp)); + } + vals_[x][idxs_[x][i]] = ret; } - add_barrier(); - // at this point, partial accumulator synchronized in shared memory - // Just need to reduce `warp_per_inner` numbers in shared memory - for(size_t i = 0; i < n_elts/col_per_thread; i++){ - auto x_idxs = idxs_[x][i]; - Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; - Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner))); - Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)}); - for(int k = warps_per_inner/2; k > 0; k >>= 1) - acc = do_acc(acc, shfl_sync(acc, k)); - vals_[x][idxs_[x][i]] = acc; - } - // add_barrier(); } void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { @@ -2521,12 +2471,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); - int cc = tgt_->as_nvidia()->sm(); analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline(); - analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma(); - bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis()); - bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1); - if(is_coalesced_scanline || is_a100_mma) + if(scanline && scanline->get_order()[0] == x->get_axis()) visit_reducend_inst_fast(x, do_acc, neutral); else visit_reducend_inst(x, do_acc, neutral); @@ -2719,12 +2665,12 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { unsigned in_vec = 1; ir::value *arg = cts->get_operand(0); analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); - analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(arg)); + analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); auto out_order = out_layout->get_order(); auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) - in_vec = in_layout->contig_per_thread(in_order[0]); + in_vec = in_layout->nts(in_order[0]); int out_vec = swizzle_->get_vec(out_layout); int min_vec = std::min(out_vec, in_vec); int s = std::max(out_vec / in_vec, 1); @@ -2732,11 +2678,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { int per_phase = swizzle_->get_per_phase(out_layout); int max_phase = swizzle_->get_max_phase(out_layout); // - int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); - int mts_1 = in_layout->shape_per_cta(in_order[1]) / 
in_layout->contig_per_thread(in_order[1]); - - int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; - int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); + int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); + int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); int n_shared_0 = std::max(in_vec / out_vec, 1); BasicBlock* CurrBB = builder_->GetInsertBlock(); @@ -2757,8 +2700,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; - int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; + int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); + int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); int off = (off_1*shapes[in_order[0]] + off_0); std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ @@ -3083,7 +3026,8 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { else{ /* warp offset */ Value *warp_0 = urem(warp, i32(layout->wpt(0))); - Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1))); + Value *warp_12 = udiv(warp, i32(layout->wpt(0))); + Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); Value *off_lane_m = urem(lane, _16); @@ -3208,9 +3152,7 @@ void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ - // i->print(std::cout); visit_value(i); - // std::cout << "done" << std::endl; } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index d969139f1..ae8ce034d 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -52,7 +52,6 @@ coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) //} void coalesce::run(ir::module &mod) { - std::set invalidated; ir::builder& builder = mod.get_builder(); // add layout conversion instructions for(ir::function *fn: mod.get_function_list()) @@ -62,29 +61,12 @@ void coalesce::run(ir::module &mod) { if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(1)) if(op->get_type()->is_block_ty()) - if(op->get_type()->get_tile_rank() == 2) - if(invalidated.find(layout_->get(op)) == invalidated.end()) if(layout_->get(op)->to_mma()){ ir::instruction* new_op = ir::cvt_layout_inst::create(op); builder.set_insert_point(i); builder.insert(new_op); i->replace_uses_of_with(op, new_op); } - // coalesce before copy_to_shared - // It's dirty, but the backend is being rewritten from scratch. 
:) - if(dynamic_cast(i)) - if(ir::value* op = i->get_operand(0)) - if(op->get_type()->is_block_ty()) - if(op->get_type()->get_tile_rank() == 2) - if(invalidated.find(layout_->get(op)) == invalidated.end()) - if(layout_->get(op)->to_mma()){ - ir::instruction* new_op = ir::cvt_layout_inst::create(op); - builder.set_insert_point(i); - builder.insert(new_op); - op->replace_all_uses_with(new_op); - new_op->replace_uses_of_with(new_op, op); - invalidated.insert(layout_->get(op)); - } // uncoalesce after load if(auto x = dynamic_cast(i)) if(x->get_type()->is_block_ty()) @@ -138,7 +120,6 @@ void coalesce::run(ir::module &mod) { } if(in_contig.size() <= 1 || out_contig==in_contig) continue; - std::cout << "3!!" << std::endl; builder.set_insert_point_after(val_inst); auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); x->replace_uses_of_with(val_inst, new_val); diff --git a/python/setup.py b/python/setup.py index 6a04a4e42..9179baa5b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -79,7 +79,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - self.debug = True + # self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 'release' diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 71df6d73b..50bfb9d1c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -698,7 +698,6 @@ def test_reduce1d(dtype_str, shape, device='cuda'): rs = RandomState(17) x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) - x[:] = 1 # numpy result z_ref = np.sum(x).astype(getattr(np, dtype_str)) # triton result @@ -1133,25 +1132,3 @@ def test_constexpr_shape(): x_tri = to_triton(np.empty((256, ), dtype=np.int32)) kernel[(1,)](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) - -# ------------- -# test if -# ------------- - - -def test_if(): - - @triton.jit - def kernel(Cond, XTrue, XFalse, Ret): - pid = tl.program_id(0) - cond = tl.load(Cond) - if pid % 2: - tl.store(Ret, tl.load(XTrue)) - else: - tl.store(Ret, tl.load(XFalse)) - - cond = torch.ones(1, dtype=torch.int32, device='cuda') - x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') - x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') - ret = torch.empty(1, dtype=torch.float32, device='cuda') - kernel[(1,)](cond, x_true, x_false, ret) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f0cc02e66..f81645a36 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -32,8 +32,6 @@ def _to_tensor(x, builder): return _to_tensor(x.value, builder) elif isinstance(x, tensor): return x - elif x is None: - return None assert False, f'cannot convert {x} to tensor' diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index e1c8e6028..2af25cbb2 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -559,7 +559,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type - if src_ty.is_block() and not dst_ty.is_block(): + if src_ty.is_block(): dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) if src_ty == dst_ty: return input diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 912833c52..f773a3787 100644 --- 
a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -252,7 +252,6 @@ def matmul_kernel( # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul` @triton.jit def leaky_relu(x): - x = x + 1 return tl.where(x >= 0, x, 0.01 * x) @@ -297,7 +296,7 @@ def matmul(a, b, activation=None): torch.manual_seed(0) a = torch.randn((512, 512), device='cuda', dtype=torch.float16) b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -triton_output = matmul(a, b, activation=leaky_relu) +triton_output = matmul(a, b, activation=None) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") @@ -306,8 +305,6 @@ if triton.testing.allclose(triton_output, torch_output): else: print("❌ Triton and Torch differ") -print(matmul_kernel.cache_key) -exit() # %% # Benchmark # --------------
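
Editor's note (not part of the patch): the in-warp reduction emitted by visit_reducend_inst_fast above, the loop of the form `for(int k = ...; k > 0; k >>= 1) acc = do_acc(acc, shfl_sync(acc, k));`, is a shuffle-tree reduction. Assuming shfl_sync lowers to a butterfly (XOR-lane) shuffle, every lane combines its accumulator with the lane whose index differs in one bit, and after log2(width) steps all lanes hold the full result. A minimal Python simulation of that pattern, for illustration only:

    # Simulates one warp: vals[i] is lane i's accumulator, `op` is the reduction op.
    # Assumes a butterfly (XOR) shuffle, so every lane ends up with the same result.
    def warp_reduce(vals, op, width=32):
        assert len(vals) == width and width & (width - 1) == 0
        k = width // 2
        while k > 0:
            # each lane reads the accumulator of lane (i ^ k), then accumulates
            vals = [op(vals[i], vals[i ^ k]) for i in range(width)]
            k //= 2
        return vals

    print(warp_reduce(list(range(32)), lambda a, b: a + b)[0])  # 496 == sum(range(32))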
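
Editor's note (not part of the patch): the commit message's distinction between "bitwise identical code" and "slightly different numerics" comes down to floating-point addition not being associative. A reduction that accumulates the same float32 values in a different order (for example, a different shuffle tree or warp layout) can legally return a result that passes allclose, as in the tutorial check above, yet differs in the last bits. A small NumPy sketch of that effect, illustrative only:

    import numpy as np

    # Summing the same float32 data in two different orders: the results are
    # usually close (allclose passes) but not bitwise identical.
    rng = np.random.RandomState(17)
    x = rng.standard_normal(4096).astype(np.float32)

    fwd = np.float32(0.0)
    for v in x:
        fwd = np.float32(fwd + v)

    rev = np.float32(0.0)
    for v in x[::-1]:
        rev = np.float32(rev + v)

    print(np.allclose(fwd, rev))           # typically True
    print(fwd.tobytes() == rev.tobytes())  # often False: different rounding order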