[CODEGEN] Fixes masked load exception (#342)
committed by GitHub
parent bfacc191b3
commit c2e6b90ff1

@@ -27,7 +27,8 @@ private:
   void update_graph_trans(ir::instruction *i);
   void update_graph_broadcast(ir::instruction *i);
   void update_graph_dot(ir::instruction *i);
-  void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
+  void update_graph_elementwise(ir::instruction *i,
+                                bool is_masked_load_async=false);
   void update_graph_no_edge(ir::instruction *i);
   void update_graph(ir::instruction *i);
 
@@ -79,21 +79,30 @@ void axes::update_graph_dot(ir::instruction *i) {
     graph_.add_edge({dot, d}, {D, d});
 }
 
-void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
+void axes::update_graph_elementwise(ir::instruction *i,
+                                    bool is_masked_load_async) {
   if(i->get_num_operands() == 0)
     return;
   ir::value *op = i->get_operand(0);
   if(!op->get_type()->is_block_ty())
     return;
   auto rank = op->get_type()->get_tile_rank();
-  for(unsigned d = 0; d < rank; d++)
+  for(unsigned d = 0; d < rank; d++) {
+    // If we are dealing with a masked async load we need to attach the
+    // dimensions so we match the behaviour of the copy_to_shared instruction
+    // which async masked load replaces.
+    if (is_masked_load_async) {
+      graph_.add_edge({i, d}, {i, d});
+    }
+
     for(ir::value* opx: i->ops())
     for(ir::value* opy: i->ops()) {
-    if(connect_ret && !i->get_type()->is_void_ty())
+      if(!is_masked_load_async && !i->get_type()->is_void_ty())
         graph_.add_edge({i, d}, {opx, d});
       graph_.add_edge({opx, d}, {opy, d});
     }
   }
+}
 
 void axes::update_graph_no_edge(ir::instruction *i) {
   if(!i->get_type()->is_block_ty())
@@ -112,7 +121,7 @@ void axes::update_graph(ir::instruction *i) {
     case ir::INST_BROADCAST:         return update_graph_broadcast(i);
     case ir::INST_DOT:               return update_graph_dot(i);
     case ir::INST_COPY_TO_SHARED:    return update_graph_no_edge(i);
-    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
+    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
     case ir::INST_COPY_FROM_SHARED:  return update_graph_no_edge(i);
     case ir::INST_CVT_LAYOUT:        return update_graph_no_edge(i);
     default:                         return update_graph_elementwise(i);
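Taken together, these two hunks change how the axes analysis handles INST_MASKED_LOAD_ASYNC: the dispatch now passes true for the new is_masked_load_async flag, so the pass adds a self-edge for every dimension of the load instead of connecting the load's result axes to its operands' axes, matching the copy_to_shared instruction that the async masked load replaces. The sketch below is only a toy model of that idea, written in plain Python rather than against Triton's actual graph helper; the AxisGraph class and the "x"/"y"/"add"/"ld" value names are invented for this illustration.

# Toy model of the axes analysis: (value, dim) nodes, undirected edges, and
# connected components standing in for shared axes. AxisGraph and the
# "x"/"y"/"add"/"ld" names are invented for this illustration only.
from collections import defaultdict

class AxisGraph:
    def __init__(self):
        self.adj = defaultdict(set)

    def add_edge(self, u, v):
        # A self-edge (u == v) simply registers the node in the graph.
        self.adj[u].add(v)
        self.adj[v].add(u)

    def components(self):
        seen, comps = set(), []
        for start in list(self.adj):
            if start in seen:
                continue
            stack, comp = [start], set()
            while stack:
                n = stack.pop()
                if n not in comp:
                    comp.add(n)
                    stack.extend(self.adj[n])
            seen |= comp
            comps.append(comp)
        return comps

g = AxisGraph()
rank = 2

# Ordinary elementwise instruction "add" over operands "x" and "y":
# the result's axes are tied to the operands' axes.
for d in range(rank):
    g.add_edge(("add", d), ("x", d))
    g.add_edge(("x", d), ("y", d))

# Masked async load "ld": only a self-edge per dimension, so its axes are
# registered but stay in their own component, like copy_to_shared.
for d in range(rank):
    g.add_edge(("ld", d), ("ld", d))

for comp in g.components():
    print(sorted(comp))

Without the self-edge, the ("ld", d) nodes never appear in the graph at all, so a later query for the load's axes has nothing to return, which is presumably the kind of missing-axis failure behind the exception in the commit title.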
@@ -549,6 +549,55 @@ def test_arange(start, device='cuda'):
 # ---------------
 # test load
 # ---------------
+# 'bfloat16': torch.bfloat16,
+# Testing masked loads with an intermediate copy to shared memory run.
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_masked_load_shared_memory(dtype, device='cuda'):
+    M = 32
+    N = 32
+    K = 8
+
+    in1 = torch.rand((M, K), dtype=dtype, device=device)
+    in2 = torch.rand((K, N), dtype=dtype, device=device)
+    out = torch.zeros((M, N), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in1_ptr, in2_ptr, output_ptr,
+                in_stride, in2_stride, out_stride,
+                in_numel, in2_numel, out_numel, **meta):
+        M = meta['M']
+        N = meta['N']
+        K = meta['K']
+
+        M_offsets = tl.arange(0, M)
+        N_offsets = tl.arange(0, N)
+        K_offsets = tl.arange(0, K)
+
+        in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
+        in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
+
+        # Load inputs.
+        x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
+        w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
+
+        # Without a dot product the memory doesn't get promoted to shared.
+        o = tl.dot(x, w)
+
+        # Store output.
+        output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
+        tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
+
+    pgm = _kernel[(1,)](in1, in2, out,
+                  in1.stride()[0],
+                  in2.stride()[0],
+                  out.stride()[0],
+                  in1.numel(),
+                  in2.numel(),
+                  out.numel(),
+                  M=M, N=N, K=K)
+
+    reference_out = torch.matmul(in1, in2)
+    triton.testing.allclose(out, reference_out)
 
 # ---------------
 # test store
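The new test can be selected on its own with pytest's -k filter (for example, pytest -k masked_load_shared_memory run against the test module) on a machine with a CUDA device. Because it only takes a dtype and a device, it can also be called directly once the function is imported or pasted into a session; the snippet below is a hypothetical invocation, not part of the commit.

# Hypothetical direct invocation of the test added in the hunk above;
# assumes a CUDA device and that test_masked_load_shared_memory is in scope.
import torch

for dt in (torch.float16, torch.float32):
    test_masked_load_shared_memory(dt, device='cuda')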