[DOCS] Transposition fix
This commit is contained in:
committed by
Philippe Tillet
parent
0920da6fae
commit
180ed26b61
@@ -14,10 +14,15 @@ class _transpose(torch.autograd.Function):
|
||||
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
||||
|
||||
// create 2D array of pointers
|
||||
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
|
||||
TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:,newaxis]; //(6)
|
||||
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
|
||||
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
|
||||
|
||||
*py = ^*px;
|
||||
// create bounds-checking mask
|
||||
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
|
||||
bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)
|
||||
|
||||
// conditional write-back using the conditional dereferencing operatior '*?()'
|
||||
*?(checky)py = ^(*?(checkx)px); //(7)
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -40,6 +45,7 @@ class _transpose(torch.autograd.Function):
|
||||
'TM' : [32,64,128],
|
||||
'TN' : [32,64,128],
|
||||
}
|
||||
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||
|
||||
if _transpose.kernel is None:
|
||||
@@ -53,7 +59,7 @@ transpose = _transpose.apply
|
||||
|
||||
# test
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(128,200).cuda()
|
||||
x = torch.randn(1024,128).cuda()
|
||||
|
||||
print(x)
|
||||
|
||||
|
Reference in New Issue
Block a user