Analyze shared memory alias (#81)
This PR analyzes shared-memory aliases so that we can fix memory-allocation bugs and avoid redundant memory allocations in Triton code involving complex control flow. Changes to the membar and allocation analyses are on the way.

Co-authored-by: Philippe Tillet <phil@openai.com>
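To illustrate the kind of aliasing the analysis has to track, here is a minimal sketch (not part of this diff; @swap, %init_a, %init_b, %r0, and %r1 are hypothetical names, while the #A shared layout and the ops are borrowed from the tests below). Because the loop yields its shared-memory iter_args in swapped order, each region argument may refer to either underlying buffer, so a barrier placed only before the loop no longer covers the reads inside it:

// Hypothetical sketch: %a and %b trade buffers every iteration, so the
// membar analysis must treat each of them as a possible alias of both.
func @swap(%lb : index, %ub : index, %step : index) {
  %init_a = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
  %init_b = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%a = %init_a, %b = %init_b) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
    // Reading %a and %b here touches whichever buffer the previous
    // iteration yielded, so a barrier is needed inside the loop body.
    %c = tt.cat %a, %b {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A>
    scf.yield %b, %a : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
  }
  return
}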
@@ -176,3 +176,36 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
   %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL>
   return
 }
+
+// CHECK-LABEL: for
+func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+    // CHECK-NEXT: Membar 3
+    %cst0 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A>
+    scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+  }
+  return
+}
+
+// Although a_shared and b_shared are synced before entering the loop,
+// they are reassociated with aliases (c_shared) and thus require a barrier.
+// CHECK-LABEL: for_alias
+func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  // CHECK-NEXT: Membar 2
+  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
+    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A>
+    // CHECK-NEXT: Membar 6
+    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A>
+    scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
+  }
+  // CHECK-NEXT: Membar 9
+  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A>, tensor<256x32xf16, #A>) -> tensor<512x32xf16, #A>
+  return
+}