[GENERAL] Various bugfixes

This commit is contained in:
Philippe Tillet
2020-11-11 14:44:56 -05:00
committed by Philippe Tillet
parent 50587bbf4b
commit 8f8d36c7a4
11 changed files with 103 additions and 59 deletions

View File

@@ -3,16 +3,16 @@ import triton
class _dot(torch.autograd.Function):
src = """
__global__ void dot(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
int M __retune,
int N __retune,
int K __retune,
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M __retune,
int N __retune,
int K __retune __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
@@ -95,11 +95,12 @@ class _dot(torch.autograd.Function):
if dtype not in _dot.kernel:
defines = {
'TYPE' : dtype,
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [64, 128],
'TN' : [64, 128],
'TK' : [8, 16],
'TM' : [128],
'TN' : [128],
'TK' : [16],
'TZ' : [1]
}
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
@@ -120,7 +121,7 @@ dot = _dot.apply
torch.manual_seed(0)
M, N, K = 2048, 2048, 2048
M, N, K = 4096, 4096, 4096
a = torch.rand((M, K)).cuda().half()
b = torch.rand((K, N)).cuda().half()
@@ -130,4 +131,5 @@ b = torch.rand((K, N)).cuda().half()
zc = torch.matmul(a,b)
zc_ = dot(a,b)
print(torch.allclose(zc, zc_))