[GENERAL] Various bugfixes
This commit is contained in:
committed by
Philippe Tillet
parent
50587bbf4b
commit
8f8d36c7a4
@@ -3,16 +3,16 @@ import triton
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void dot(TYPE *A __noalias __readonly __aligned(16),
|
||||
TYPE *B __noalias __readonly __aligned(16),
|
||||
TYPE *C __noalias __aligned(16),
|
||||
float alpha,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int K __retune,
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8)) {
|
||||
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __aligned(16),
|
||||
float alpha,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int K __retune __multipleof(16),
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
@@ -95,11 +95,12 @@ class _dot(torch.autograd.Function):
|
||||
if dtype not in _dot.kernel:
|
||||
defines = {
|
||||
'TYPE' : dtype,
|
||||
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
|
||||
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
|
||||
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
|
||||
'TM' : [64, 128],
|
||||
'TN' : [64, 128],
|
||||
'TK' : [8, 16],
|
||||
'TM' : [128],
|
||||
'TN' : [128],
|
||||
'TK' : [16],
|
||||
'TZ' : [1]
|
||||
}
|
||||
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
|
||||
@@ -120,7 +121,7 @@ dot = _dot.apply
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
M, N, K = 2048, 2048, 2048
|
||||
M, N, K = 4096, 4096, 4096
|
||||
a = torch.rand((M, K)).cuda().half()
|
||||
b = torch.rand((K, N)).cuda().half()
|
||||
|
||||
@@ -130,4 +131,5 @@ b = torch.rand((K, N)).cuda().half()
|
||||
zc = torch.matmul(a,b)
|
||||
zc_ = dot(a,b)
|
||||
|
||||
|
||||
print(torch.allclose(zc, zc_))
|
||||
|
Reference in New Issue
Block a user