[CODEGEN] Fixes masked load exception (#342)
parent bfacc191b3
commit c2e6b90ff1
@@ -27,7 +27,8 @@ private:
   void update_graph_trans(ir::instruction *i);
   void update_graph_broadcast(ir::instruction *i);
   void update_graph_dot(ir::instruction *i);
-  void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
+  void update_graph_elementwise(ir::instruction *i,
+                                bool is_masked_load_async=false);
   void update_graph_no_edge(ir::instruction *i);
   void update_graph(ir::instruction *i);
 
@@ -79,21 +79,30 @@ void axes::update_graph_dot(ir::instruction *i) {
     graph_.add_edge({dot, d}, {D, d});
 }
 
-void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
+void axes::update_graph_elementwise(ir::instruction *i,
+                                    bool is_masked_load_async) {
   if(i->get_num_operands() == 0)
     return;
   ir::value *op = i->get_operand(0);
   if(!op->get_type()->is_block_ty())
     return;
   auto rank = op->get_type()->get_tile_rank();
-  for(unsigned d = 0; d < rank; d++)
+  for(unsigned d = 0; d < rank; d++) {
+    // If we are dealing with a masked async load we need to attach the
+    // dimensions so we match the behaviour of the copy_to_shared instruction
+    // which async masked load replaces.
+    if (is_masked_load_async) {
+      graph_.add_edge({i, d}, {i, d});
+    }
+
   for(ir::value* opx: i->ops())
   for(ir::value* opy: i->ops()) {
-    if(connect_ret && !i->get_type()->is_void_ty())
+    if(!is_masked_load_async && !i->get_type()->is_void_ty())
       graph_.add_edge({i, d}, {opx, d});
     graph_.add_edge({opx, d}, {opy, d});
   }
+  }
 }
 
 void axes::update_graph_no_edge(ir::instruction *i) {
   if(!i->get_type()->is_block_ty())
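
Note on the self-edge added above: here is a minimal, self-contained Python sketch of the idea, assuming (as the comment suggests) that the axes graph groups (value, dimension) nodes into connected components and that a node only participates in that grouping once it has at least one edge. The toy union-find and the node names below are my own illustration, not the graph_ implementation used in this codebase.

from collections import defaultdict

# Toy axes-style graph: nodes are (value_name, dimension) pairs and an "axis"
# corresponds to a connected component. A self-edge registers a node without
# constraining it to any other node.
parent = {}

def find(x):
    parent.setdefault(x, x)
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # path halving
        x = parent[x]
    return x

def add_edge(a, b):
    ra, rb = find(a), find(b)
    if ra != rb:
        parent[ra] = rb

# Self-edge only: the masked async load's dimension 0 becomes its own component.
add_edge(("masked_load_async", 0), ("masked_load_async", 0))
# Ordinary edges tie operand dimensions together, as in the elementwise case.
add_edge(("add", 0), ("x", 0))
add_edge(("add", 0), ("y", 0))

components = defaultdict(list)
for node in list(parent):
    components[find(node)].append(node)
print(list(components.values()))
# [[('masked_load_async', 0)], [('add', 0), ('x', 0), ('y', 0)]]
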
@@ -112,7 +121,7 @@ void axes::update_graph(ir::instruction *i) {
    case ir::INST_BROADCAST: return update_graph_broadcast(i);
    case ir::INST_DOT: return update_graph_dot(i);
    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
-   case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
+   case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
    case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
    case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
    default: return update_graph_elementwise(i);
@@ -549,6 +549,55 @@ def test_arange(start, device='cuda'):
 # ---------------
 # test load
 # ---------------
+# 'bfloat16': torch.bfloat16,
+# Testing masked loads with an intermediate copy to shared memory run.
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_masked_load_shared_memory(dtype, device='cuda'):
+    M = 32
+    N = 32
+    K = 8
+
+    in1 = torch.rand((M, K), dtype=dtype, device=device)
+    in2 = torch.rand((K, N), dtype=dtype, device=device)
+    out = torch.zeros((M, N), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in1_ptr, in2_ptr, output_ptr,
+                in_stride, in2_stride, out_stride,
+                in_numel, in2_numel, out_numel, **meta):
+        M = meta['M']
+        N = meta['N']
+        K = meta['K']
+
+        M_offsets = tl.arange(0, M)
+        N_offsets = tl.arange(0, N)
+        K_offsets = tl.arange(0, K)
+
+        in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
+        in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
+
+        # Load inputs.
+        x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
+        w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
+
+        # Without a dot product the memory doesn't get promoted to shared.
+        o = tl.dot(x, w)
+
+        # Store output
+        output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
+        tl.store(output_ptr + output_offsets, o, mask=output_offsets < out_numel)
+
+    pgm = _kernel[(1,)](in1, in2, out,
+                        in1.stride()[0],
+                        in2.stride()[0],
+                        out.stride()[0],
+                        in1.numel(),
+                        in2.numel(),
+                        out.numel(),
+                        M=M, N=N, K=K)
+
+    reference_out = torch.matmul(in1, in2)
+    triton.testing.allclose(out, reference_out)
 # ---------------
 # test store
 # ---------------
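
Note on the masks in the test above: with M = 32, N = 32, K = 8 every load offset stays below the corresponding element count, so the load masks are all true and the full tiles are read; the test exercises the masked-load-to-shared-memory code path rather than boundary handling. Below is a quick host-side check of that arithmetic, a sketch of my own and not part of the commit.

import torch

# Recompute the kernel's flat offsets and mask for the (M, K) input on the host.
M, K = 32, 8
in_stride = K                                   # row-major stride of a 32x8 tensor
M_offsets = torch.arange(M)[:, None]            # column of row indices, shape (32, 1)
K_offsets = torch.arange(K)[None, :]            # row of column indices, shape (1, 8)
in_offsets = M_offsets * in_stride + K_offsets  # flat element offsets, 0..255
mask = in_offsets < M * K                       # same predicate as mask=in_offsets < in_numel
print(int(in_offsets.max()), bool(mask.all()))  # 255 True
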