[CODEGEN] Fixes masked load exception (#342)
This commit is contained in:
committed by
GitHub
parent
bfacc191b3
commit
c2e6b90ff1
@@ -27,7 +27,8 @@ private:
|
||||
void update_graph_trans(ir::instruction *i);
|
||||
void update_graph_broadcast(ir::instruction *i);
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
|
||||
void update_graph_elementwise(ir::instruction *i,
|
||||
bool is_masked_load_async=false);
|
||||
void update_graph_no_edge(ir::instruction *i);
|
||||
void update_graph(ir::instruction *i);
|
||||
|
||||
|
@@ -79,19 +79,28 @@ void axes::update_graph_dot(ir::instruction *i) {
|
||||
graph_.add_edge({dot, d}, {D, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
|
||||
void axes::update_graph_elementwise(ir::instruction *i,
|
||||
bool is_masked_load_async) {
|
||||
if(i->get_num_operands() == 0)
|
||||
return;
|
||||
ir::value *op = i->get_operand(0);
|
||||
if(!op->get_type()->is_block_ty())
|
||||
return;
|
||||
auto rank = op->get_type()->get_tile_rank();
|
||||
for(unsigned d = 0; d < rank; d++)
|
||||
for(ir::value* opx: i->ops())
|
||||
for(ir::value* opy: i->ops()){
|
||||
if(connect_ret && !i->get_type()->is_void_ty())
|
||||
graph_.add_edge({i, d}, {opx, d});
|
||||
graph_.add_edge({opx, d}, {opy, d});
|
||||
for(unsigned d = 0; d < rank; d++) {
|
||||
// If we are dealing with a masked async load we need to attach the
|
||||
// dimensions so we match the behaviour of the copy_to_shared instruction
|
||||
// which async masked load replaces.
|
||||
if (is_masked_load_async) {
|
||||
graph_.add_edge({i, d}, {i, d});
|
||||
}
|
||||
|
||||
for(ir::value* opx: i->ops())
|
||||
for(ir::value* opy: i->ops()) {
|
||||
if(!is_masked_load_async && !i->get_type()->is_void_ty())
|
||||
graph_.add_edge({i, d}, {opx, d});
|
||||
graph_.add_edge({opx, d}, {opy, d});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +121,7 @@ void axes::update_graph(ir::instruction *i) {
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
|
||||
case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
|
||||
case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
|
||||
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
|
||||
case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
|
||||
default: return update_graph_elementwise(i);
|
||||
|
@@ -549,6 +549,55 @@ def test_arange(start, device='cuda'):
|
||||
# ---------------
|
||||
# test load
|
||||
# ---------------
|
||||
# 'bfloat16': torch.bfloat16,
|
||||
# Testing masked loads with an intermate copy to shared memory run.
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
M = 32
|
||||
N = 32
|
||||
K = 8
|
||||
|
||||
in1 = torch.rand((M, K), dtype=dtype, device=device)
|
||||
in2 = torch.rand((K, N), dtype=dtype, device=device)
|
||||
out = torch.zeros((M, N), dtype=dtype, device=device)
|
||||
|
||||
@triton.jit
|
||||
def _kernel(in1_ptr, in2_ptr, output_ptr,
|
||||
in_stride, in2_stride, out_stride,
|
||||
in_numel, in2_numel, out_numel, **meta):
|
||||
M = meta['M']
|
||||
N = meta['N']
|
||||
K = meta['K']
|
||||
|
||||
M_offsets = tl.arange(0, M)
|
||||
N_offsets = tl.arange(0, N)
|
||||
K_offsets = tl.arange(0, K)
|
||||
|
||||
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
|
||||
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
|
||||
|
||||
# Load inputs.
|
||||
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
|
||||
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
|
||||
|
||||
# Without a dot product the memory doesn't get promoted to shared.
|
||||
o = tl.dot(x, w)
|
||||
|
||||
# Store output
|
||||
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
|
||||
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
|
||||
|
||||
pgm = _kernel[(1,)](in1, in2, out,
|
||||
in1.stride()[0],
|
||||
in2.stride()[0],
|
||||
out.stride()[0],
|
||||
in1.numel(),
|
||||
in2.numel(),
|
||||
out.numel(),
|
||||
M=M, N=N, K=K)
|
||||
|
||||
reference_out =torch.matmul(in1, in2)
|
||||
triton.testing.allclose(out, reference_out)
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
|
Reference in New Issue
Block a user