Merge branch 'master' into rcom52_fixes

Michael Melesse
2022-10-17 17:53:48 +00:00
151 changed files with 20150 additions and 19097 deletions

File diff suppressed because it is too large


@@ -0,0 +1,261 @@
# flake8: noqa: F821,F841
import random

import torch

import triton
import triton.language as tl


@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    # Layout: the first int32 word packs the float16 scale (low half) and
    # shift (high half); each remaining word carries four packed uint8 values.
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int8(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # Variant that reads scale and shift from dedicated tensors; the input
    # holds only the packed payload words.
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    # Same packed (scale, shift) header word as the int8 kernel; each payload
    # word carries eight 4-bit values.
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int4(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    # The input is int16-typed here: words 0 and 1 hold the float16 bit
    # patterns of scale and shift, and each following word packs eight
    # 2-bit values.
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
    shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int2(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # Payload-only int16 words, eight 2-bit values each.
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


def test_dequantize_int8() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 128, 4)
        else:
            size = random.randrange(132, 1024, 4)

        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)
        input = torch.cat((scale_shift, input_int32))
        expected = (input_int8 * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 128)
        grid = (1,)
        dequantize_kernel_int8[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int8[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int4() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)

        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size // 2,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)
        input_int8_h1 = input_int8 >> 4
        input_int8_h0 = input_int8 & 15
        input_int4_val = torch.stack(
            (input_int8_h0, input_int8_h1), dim=1
        ).flatten()
        input = torch.cat((scale_shift, input_int32))
        expected = (input_int4_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)
        dequantize_kernel_int4[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int4[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int2() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)

        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int16)

        input_int8 = torch.randint(
            0, 256, (size // 4,), dtype=torch.uint8, device=device
        )
        input_int16 = input_int8.view(torch.int16)
        input_int8_q3 = input_int8 >> 6
        input_int8_q2 = (input_int8 >> 4) & 3
        input_int8_q1 = (input_int8 >> 2) & 3
        input_int8_q0 = input_int8 & 3
        input_int2_val = torch.stack(
            (input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
        ).flatten()
        input = torch.cat((scale_shift, input_int16))
        expected = (input_int2_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)
        dequantize_kernel_int2[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int2[grid](
            output,
            input_int16,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)
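
For orientation, the packed layout that the fused int8 kernel above consumes can be assembled on the host the same way test_dequantize_int8 does. A minimal sketch (sizes and values here are illustrative, not from the tests):

import torch

# One int32 header word holding the float16 (scale, shift) bit patterns,
# followed by the uint8 payload reinterpreted as int32 words.
scale_shift = torch.tensor([0.5, -1.0], dtype=torch.float16).view(torch.int32)
payload = torch.randint(0, 256, (64,), dtype=torch.uint8).view(torch.int32)
packed = torch.cat((scale_shift, payload))  # what dequantize_kernel_int8 reads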


@@ -1,16 +1,16 @@
-import torch
-import triton
-import triton.language as tl
-import numpy as np
import pytest
import scipy.stats
+import numpy as np
+import torch
from numpy.random import Philox
+import triton
+import triton.language as tl
#####################################
-## Reference Philox Implementation
+# Reference Philox Implementation
#####################################
class PhiloxConfig:
    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -74,9 +74,8 @@ class CustomPhilox4x:
        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)

    def _raise_key(self, key):
-        ret0 = key[0] + self._config.PHILOX_KEY_A
-        ret1 = key[1] + self._config.PHILOX_KEY_B
-        return np.array([ret0, ret1], dtype=self._dtype)
+        pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
+        return key + np.array(pk, dtype=self._dtype)

    def random_raw(self):
        counter = self._counter
@@ -104,18 +103,21 @@ class CustomPhilox(CustomPhilox4x):
#####################################
-## Unit Tests
+# Unit Tests
#####################################
BLOCK = 1024


# test generation of random uint32
@pytest.mark.parametrize('size, seed',
-                        [(size, seed) for size in ['10', '4,53', '10000']\
-                        for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
-                        )
+                        [(size, seed) for size in ['10', '4,53', '10000']
+                         for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
+                        )
def test_randint(size, seed, device='cuda'):
    size = list(map(int, size.split(',')))

    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -133,10 +135,12 @@ def test_randint(size, seed, device='cuda'):
    assert out_tri == out_ref


# test uniform PRNG
@pytest.mark.parametrize('size, seed',
-                        [(size, seed) for size in [1000000]\
-                        for seed in [0, 42, 124, 54]]
-                        )
+                        [(size, seed) for size in [1000000]
+                         for seed in [0, 42, 124, 54]]
+                        )
def test_rand(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
@@ -148,13 +152,16 @@ def test_rand(size, seed, device='cuda'):
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
+    assert all((x >= 0) & (x <= 1))
    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
-                        [(size, seed) for size in [1000000]\
-                        for seed in [0, 42, 124, 54]]
-                        )
+                        [(size, seed) for size in [1000000]
+                         for seed in [0, 42, 124, 54]]
+                        )
def test_randn(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
@@ -168,3 +175,24 @@ def test_randn(size, seed, device='cuda'):
    kernel[grid](x, N, seed)
    assert abs(x.mean()) < 1e-2
    assert abs(x.std() - 1) < 1e-2


+# tl.rand() should never produce >=1.0
+def test_rand_limits():
+    @triton.jit
+    def kernel(input, output, n: tl.constexpr):
+        idx = tl.arange(0, n)
+        x = tl.load(input + idx)
+        y = tl.random.uint32_to_uniform_float(x)
+        tl.store(output + idx, y)
+
+    min_max_int32 = torch.tensor([
+        torch.iinfo(torch.int32).min,
+        torch.iinfo(torch.int32).max,
+    ], dtype=torch.int32, device='cuda')
+    output = torch.empty(2, dtype=torch.float32, device='cuda')
+    kernel[(1,)](min_max_int32, output, 2)
+
+    assert output[0] == output[1]
+    assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0
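
The two extreme int32 bit patterns exercise the top of uint32_to_uniform_float's output range: the test expects both to collapse to the same float32 value that is at least 1.0 - eps yet strictly below 1.0, i.e. the conversion stays inside the half-open interval [0, 1).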