Merge branch 'master' into rcom52_fixes
File diff suppressed because it is too large
261    python/test/unit/language/test_dequantize.py    Normal file
@@ -0,0 +1,261 @@
# flake8: noqa: F821,F841
import random

import torch

import triton
import triton.language as tl


@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)
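
# Editor's note (layout inferred from the tests below, not from documented
# tl.dequantize behaviour): the int8 kernel expects a flat int32 tensor whose first
# element packs the fp16 scale in its low 16 bits and the fp16 shift in its high 16
# bits; every following int32 packs four uint8 codes. tl.dequantize(..., 8) then
# expands each word into four fp16 values code * scale + shift, e.g. a code of 200
# with scale 0.5 and shift -3.0 becomes 97.0.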


@triton.jit
def dequantize_kernel_scale_shift_int8(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)
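
# Editor's note: this variant differs from dequantize_kernel_int8 only in how the
# affine parameters arrive -- scale and shift are read from dedicated pointers
# rather than unpacked from the first input word, so input_ptr holds nothing but
# packed codes.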


@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)
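
# Editor's note (inferred packing): in the int4 path each int32 word after the
# scale/shift header holds eight 4-bit codes, hence the size // 8 loads producing
# size fp16 outputs; the matching test builds its reference by splitting every byte
# into a low nibble (value & 15) and a high nibble (value >> 4).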


@triton.jit
def dequantize_kernel_scale_shift_int4(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
    shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)
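
# Editor's note (inferred packing): the int2 path works on int16 words -- the first
# two elements are the raw bit patterns of the fp16 scale and shift (hence the two
# bitcast loads above), and each following int16 packs eight 2-bit codes, matching
# the size // 8 masked loads.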


@triton.jit
def dequantize_kernel_scale_shift_int2(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


def test_dequantize_int8() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 128, 4)
        else:
            size = random.randrange(132, 1024, 4)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int8 * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 128)
        grid = (1,)
        dequantize_kernel_int8[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int8[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)
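
# Editor's note: the reference is plain affine dequantization, expected = code * scale
# + shift computed in fp16, and both kernel variants are checked against it with a
# loose 1e-2 tolerance (scale and shift only carry float16 precision). block_size is
# rounded up to a power of two (at least 128) so a single program instance with
# grid = (1,) covers the whole tensor.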


def test_dequantize_int4() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size // 2,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input_int8_h1 = input_int8 >> 4
        input_int8_h0 = input_int8 & 15

        input_int4_val = torch.stack(
            (input_int8_h0, input_int8_h1), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int4_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)
        dequantize_kernel_int4[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int4[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)
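
# Editor's note: torch.stack((h0, h1), dim=1).flatten() interleaves the nibbles
# low-half first, so the reference ordering is low nibble of byte 0, high nibble of
# byte 0, low nibble of byte 1, ... -- the order the int4 kernels are expected to
# emit.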


def test_dequantize_int2() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int16)

        input_int8 = torch.randint(
            0, 256, (size // 4,), dtype=torch.uint8, device=device
        )
        input_int16 = input_int8.view(torch.int16)

        input_int8_q3 = input_int8 >> 6
        input_int8_q2 = (input_int8 >> 4) & 3
        input_int8_q1 = (input_int8 >> 2) & 3
        input_int8_q0 = input_int8 & 3

        input_int2_val = torch.stack(
            (input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int16))
        expected = (input_int2_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)

        dequantize_kernel_int2[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int2[grid](
            output,
            input_int16,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)
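
# Editor's note: same idea for int2 -- q0 is the least significant bit pair of each
# byte and q3 the most significant, so stacking (q0, q1, q2, q3) along dim=1 and
# flattening yields the LSB-first ordering the int2 kernels are checked against.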
@@ -1,16 +1,16 @@
-import torch
-import triton
-import triton.language as tl
-import numpy as np
 import pytest
 import scipy.stats
+import numpy as np
+import torch
 from numpy.random import Philox
+import triton
+import triton.language as tl
 
 #####################################
-## Reference Philox Implementation
+# Reference Philox Implementation
 #####################################
 
 
 class PhiloxConfig:
     def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
         self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -74,9 +74,8 @@ class CustomPhilox4x:
         return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
 
     def _raise_key(self, key):
-        ret0 = key[0] + self._config.PHILOX_KEY_A
-        ret1 = key[1] + self._config.PHILOX_KEY_B
-        return np.array([ret0, ret1], dtype=self._dtype)
+        pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
+        return key + np.array(pk, dtype=self._dtype)
 
     def random_raw(self):
         counter = self._counter
@@ -104,18 +103,21 @@ class CustomPhilox(CustomPhilox4x):
 
 
 #####################################
-## Unit Tests
+# Unit Tests
 #####################################
 
 BLOCK = 1024
+
 # test generation of random uint32
+
+
 @pytest.mark.parametrize('size, seed',
-                         [(size, seed) for size in ['10', '4,53', '10000']\
-                          for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
-                         )
+                         [(size, seed) for size in ['10', '4,53', '10000']
+                          for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
+                         )
 def test_randint(size, seed, device='cuda'):
     size = list(map(int, size.split(',')))
 
     @triton.jit
     def kernel(X, N, seed):
         offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -133,10 +135,12 @@ def test_randint(size, seed, device='cuda'):
     assert out_tri == out_ref
 
 # test uniform PRNG
+
+
 @pytest.mark.parametrize('size, seed',
-                         [(size, seed) for size in [1000000]\
-                          for seed in [0, 42, 124, 54]]
-                         )
+                         [(size, seed) for size in [1000000]
+                          for seed in [0, 42, 124, 54]]
+                         )
 def test_rand(size, seed, device='cuda'):
     @triton.jit
     def kernel(X, N, seed):
@@ -148,13 +152,16 @@ def test_rand(size, seed, device='cuda'):
     N = x.numel()
     grid = (triton.cdiv(N, BLOCK),)
     kernel[grid](x, N, seed)
+    assert all((x >= 0) & (x <= 1))
     assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
 
 # test normal PRNG
+
+
 @pytest.mark.parametrize('size, seed',
-                         [(size, seed) for size in [1000000]\
-                          for seed in [0, 42, 124, 54]]
-                         )
+                         [(size, seed) for size in [1000000]
+                          for seed in [0, 42, 124, 54]]
+                         )
 def test_randn(size, seed, device='cuda'):
     @triton.jit
     def kernel(X, N, seed):
@@ -168,3 +175,24 @@ def test_randn(size, seed, device='cuda'):
     kernel[grid](x, N, seed)
     assert abs(x.mean()) < 1e-2
     assert abs(x.std() - 1) < 1e-2
+
+
+# tl.rand() should never produce >=1.0
+
+def test_rand_limits():
+    @triton.jit
+    def kernel(input, output, n: tl.constexpr):
+        idx = tl.arange(0, n)
+        x = tl.load(input + idx)
+        y = tl.random.uint32_to_uniform_float(x)
+        tl.store(output + idx, y)
+
+    min_max_int32 = torch.tensor([
+        torch.iinfo(torch.int32).min,
+        torch.iinfo(torch.int32).max,
+    ], dtype=torch.int32, device='cuda')
+    output = torch.empty(2, dtype=torch.float32, device='cuda')
+    kernel[(1,)](min_max_int32, output, 2)
+
+    assert output[0] == output[1]
+    assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0
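
Editor's note (not part of the commit): test_rand_limits feeds the two extreme int32
bit patterns through tl.random.uint32_to_uniform_float and expects both to land on
the same value just below 1.0, i.e. tl.rand() should stay inside [0, 1). Assuming a
CUDA-enabled build, the new dequantize tests added above can be run directly with:

    pytest python/test/unit/language/test_dequantize.py -v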