diff --git a/include/triton/codegen/analysis/axes.h b/include/triton/codegen/analysis/axes.h
index 1806ff725..759ed0f8f 100644
--- a/include/triton/codegen/analysis/axes.h
+++ b/include/triton/codegen/analysis/axes.h
@@ -27,7 +27,8 @@ private:
   void update_graph_trans(ir::instruction *i);
   void update_graph_broadcast(ir::instruction *i);
   void update_graph_dot(ir::instruction *i);
-  void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
+  void update_graph_elementwise(ir::instruction *i,
+                                bool is_masked_load_async=false);
   void update_graph_no_edge(ir::instruction *i);
   void update_graph(ir::instruction *i);
 
diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc
index 13b8f8d05..37b95eaa3 100644
--- a/lib/codegen/analysis/axes.cc
+++ b/lib/codegen/analysis/axes.cc
@@ -79,19 +79,28 @@ void axes::update_graph_dot(ir::instruction *i) {
     graph_.add_edge({dot, d}, {D, d});
 }
 
-void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
+void axes::update_graph_elementwise(ir::instruction *i,
+                                    bool is_masked_load_async) {
   if(i->get_num_operands() == 0)
     return;
   ir::value *op = i->get_operand(0);
   if(!op->get_type()->is_block_ty())
     return;
   auto rank = op->get_type()->get_tile_rank();
-  for(unsigned d = 0; d < rank; d++)
-  for(ir::value* opx: i->ops())
-  for(ir::value* opy: i->ops()){
-    if(connect_ret && !i->get_type()->is_void_ty())
-      graph_.add_edge({i, d}, {opx, d});
-    graph_.add_edge({opx, d}, {opy, d});
+  for(unsigned d = 0; d < rank; d++) {
+    // If we are dealing with a masked async load, we need to attach the
+    // dimensions so that we match the behaviour of the copy_to_shared
+    // instruction, which the masked async load replaces.
+    if (is_masked_load_async) {
+      graph_.add_edge({i, d}, {i, d});
+    }
+
+    for(ir::value* opx: i->ops())
+    for(ir::value* opy: i->ops()) {
+      if(!is_masked_load_async && !i->get_type()->is_void_ty())
+        graph_.add_edge({i, d}, {opx, d});
+      graph_.add_edge({opx, d}, {opy, d});
+    }
   }
 }
 
@@ -112,7 +121,7 @@ void axes::update_graph(ir::instruction *i) {
     case ir::INST_BROADCAST: return update_graph_broadcast(i);
     case ir::INST_DOT: return update_graph_dot(i);
     case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
-    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
+    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
     case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
     case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
     default: return update_graph_elementwise(i);
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 5c6072b3e..024807392 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -549,6 +549,55 @@ def test_arange(start, device='cuda'):
 # ---------------
 # test load
 # ---------------
+# 'bfloat16': torch.bfloat16,
+# Test masked loads with an intermediate copy to shared memory.
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_masked_load_shared_memory(dtype, device='cuda'):
+    M = 32
+    N = 32
+    K = 8
+
+    in1 = torch.rand((M, K), dtype=dtype, device=device)
+    in2 = torch.rand((K, N), dtype=dtype, device=device)
+    out = torch.zeros((M, N), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in1_ptr, in2_ptr, output_ptr,
+                in_stride, in2_stride, out_stride,
+                in_numel, in2_numel, out_numel, **meta):
+        M = meta['M']
+        N = meta['N']
+        K = meta['K']
+
+        M_offsets = tl.arange(0, M)
+        N_offsets = tl.arange(0, N)
+        K_offsets = tl.arange(0, K)
+
+        in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
+        in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
+
+        # Load inputs.
+        x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
+        w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
+
+        # Without a dot product the memory doesn't get promoted to shared.
+        o = tl.dot(x, w)
+
+        # Store output.
+        output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
+        tl.store(output_ptr + output_offsets, o, mask=output_offsets < out_numel)
+
+    pgm = _kernel[(1,)](in1, in2, out,
+                        in1.stride()[0],
+                        in2.stride()[0],
+                        out.stride()[0],
+                        in1.numel(),
+                        in2.numel(),
+                        out.numel(),
+                        M=M, N=N, K=K)
+
+    reference_out = torch.matmul(in1, in2)
+    assert triton.testing.allclose(out, reference_out)
 
 # ---------------
 # test store
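
Note: below is a minimal sketch of how the test added in this patch could be exercised in isolation. It assumes pytest is installed and a CUDA device is visible to PyTorch; the file path is the test file touched above, and the -k filter matches the test name introduced in this diff.

# run_masked_load_test.py -- hypothetical helper, not part of the patch.
# Runs only the new masked-load test from the repository root.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([
        "python/test/unit/language/test_core.py",
        "-k", "test_masked_load_shared_memory",
        "-v",
    ]))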