[CODEGEN] Fixes masked load exception (#342)

Authored by Stephen McGroarty on 2021-10-13 21:31:52 +01:00, committed by GitHub
parent bfacc191b3, commit c2e6b90ff1
3 changed files with 68 additions and 9 deletions


@@ -27,7 +27,8 @@ private:
  void update_graph_trans(ir::instruction *i);
  void update_graph_broadcast(ir::instruction *i);
  void update_graph_dot(ir::instruction *i);
- void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
+ void update_graph_elementwise(ir::instruction *i,
+                               bool is_masked_load_async=false);
  void update_graph_no_edge(ir::instruction *i);
  void update_graph(ir::instruction *i);


@@ -79,21 +79,30 @@ void axes::update_graph_dot(ir::instruction *i) {
    graph_.add_edge({dot, d}, {D, d});
}

-void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
+void axes::update_graph_elementwise(ir::instruction *i,
+                                    bool is_masked_load_async) {
  if(i->get_num_operands() == 0)
    return;
  ir::value *op = i->get_operand(0);
  if(!op->get_type()->is_block_ty())
    return;
  auto rank = op->get_type()->get_tile_rank();
- for(unsigned d = 0; d < rank; d++)
+ for(unsigned d = 0; d < rank; d++) {
+   // If we are dealing with a masked async load we need to attach the
+   // dimensions so we match the behaviour of the copy_to_shared instruction
+   // which the masked async load replaces.
+   if (is_masked_load_async) {
+     graph_.add_edge({i, d}, {i, d});
+   }
+
    for(ir::value* opx: i->ops())
    for(ir::value* opy: i->ops()) {
-     if(connect_ret && !i->get_type()->is_void_ty())
+     if(!is_masked_load_async && !i->get_type()->is_void_ty())
        graph_.add_edge({i, d}, {opx, d});
      graph_.add_edge({opx, d}, {opy, d});
    }
+ }
}

void axes::update_graph_no_edge(ir::instruction *i) {
  if(!i->get_type()->is_block_ty())
@@ -112,7 +121,7 @@ void axes::update_graph(ir::instruction *i) {
    case ir::INST_BROADCAST: return update_graph_broadcast(i);
    case ir::INST_DOT: return update_graph_dot(i);
    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
-   case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
+   case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
    case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
    case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
    default: return update_graph_elementwise(i);


@@ -549,6 +549,55 @@ def test_arange(start, device='cuda'):
# ---------------
# test load
# ---------------
+# 'bfloat16': torch.bfloat16,
+# Testing masked loads with an intermediate copy to shared memory run.
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_masked_load_shared_memory(dtype, device='cuda'):
+    M = 32
+    N = 32
+    K = 8
+
+    in1 = torch.rand((M, K), dtype=dtype, device=device)
+    in2 = torch.rand((K, N), dtype=dtype, device=device)
+    out = torch.zeros((M, N), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in1_ptr, in2_ptr, output_ptr,
+                in_stride, in2_stride, out_stride,
+                in_numel, in2_numel, out_numel, **meta):
+        M = meta['M']
+        N = meta['N']
+        K = meta['K']
+
+        M_offsets = tl.arange(0, M)
+        N_offsets = tl.arange(0, N)
+        K_offsets = tl.arange(0, K)
+
+        in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
+        in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
+
+        # Load inputs.
+        x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
+        w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
+
+        # Without a dot product the memory doesn't get promoted to shared.
+        o = tl.dot(x, w)
+
+        # Store output.
+        output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
+        tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
+
+    pgm = _kernel[(1,)](in1, in2, out,
+                        in1.stride()[0],
+                        in2.stride()[0],
+                        out.stride()[0],
+                        in1.numel(),
+                        in2.numel(),
+                        out.numel(),
+                        M=M, N=N, K=K)
+
+    reference_out = torch.matmul(in1, in2)
+    triton.testing.allclose(out, reference_out)
+
# ---------------
# test store
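
The new test above can be exercised on its own once this branch is built and installed. Below is a minimal runner sketch, assuming pytest is available and a CUDA-capable GPU is present; the "python/test" path and the helper script name are assumptions, not part of this commit, and may differ in a given checkout.

# run_masked_load_test.py -- hypothetical helper, not part of this commit.
# Selects only the new masked-load test by name and forwards pytest's exit code.
import sys

import pytest

if __name__ == "__main__":
    # Point pytest at wherever the load tests live in your checkout.
    sys.exit(pytest.main(["-v", "-k", "test_masked_load_shared_memory", "python/test"]))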