[IR] Added special-purpose dequantize
instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
This commit is contained in:
261
python/test/unit/language/test_dequantize.py
Normal file
261
python/test/unit/language/test_dequantize.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# flake8: noqa: F821,F841
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
|
||||
mask = w_offsets < (size // 4)
|
||||
input_ptrs = input_ptr + 1 + w_offsets
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale_shift = tl.load(input_ptr)
|
||||
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 8)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int8(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
|
||||
mask = w_offsets < (size // 4)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 8)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = input_ptr + 1 + w_offsets
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale_shift = tl.load(input_ptr)
|
||||
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 4)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int4(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 4)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
|
||||
shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 2)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int2(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 2)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
def test_dequantize_int8() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 128, 4)
|
||||
else:
|
||||
size = random.randrange(132, 1024, 4)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int32)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int32 = input_int8.view(torch.int32)
|
||||
|
||||
input = torch.cat((scale_shift, input_int32))
|
||||
expected = (input_int8 * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 128)
|
||||
grid = (1,)
|
||||
dequantize_kernel_int8[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int8[grid](
|
||||
output,
|
||||
input_int32,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
|
||||
def test_dequantize_int4() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 256, 8)
|
||||
else:
|
||||
size = random.randrange(264, 1024, 8)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int32)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size // 2,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int32 = input_int8.view(torch.int32)
|
||||
|
||||
input_int8_h1 = input_int8 >> 4
|
||||
input_int8_h0 = input_int8 & 15
|
||||
|
||||
input_int4_val = torch.stack(
|
||||
(input_int8_h0, input_int8_h1), dim=1
|
||||
).flatten()
|
||||
|
||||
input = torch.cat((scale_shift, input_int32))
|
||||
expected = (input_int4_val * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 256)
|
||||
grid = (1,)
|
||||
dequantize_kernel_int4[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int4[grid](
|
||||
output,
|
||||
input_int32,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
|
||||
def test_dequantize_int2() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 256, 8)
|
||||
else:
|
||||
size = random.randrange(264, 1024, 8)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int16)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size // 4,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int16 = input_int8.view(torch.int16)
|
||||
|
||||
input_int8_q3 = input_int8 >> 6
|
||||
input_int8_q2 = (input_int8 >> 4) & 3
|
||||
input_int8_q1 = (input_int8 >> 2) & 3
|
||||
input_int8_q0 = input_int8 & 3
|
||||
|
||||
input_int2_val = torch.stack(
|
||||
(input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
|
||||
).flatten()
|
||||
|
||||
input = torch.cat((scale_shift, input_int16))
|
||||
expected = (input_int2_val * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 256)
|
||||
grid = (1,)
|
||||
|
||||
dequantize_kernel_int2[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int2[grid](
|
||||
output,
|
||||
input_int16,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
Reference in New Issue
Block a user