[IR] Added special-purpose dequantize instruction (#759)

A special-purpose dequantize instruction in the IR is currently necessary for optimal performance in quantized workloads. Backward compatibility with this instruction is *NOT* guaranteed.
Yu Guo
2022-10-12 14:14:45 -07:00
committed by GitHub
parent 33e6f0df7f
commit 71b46acc42
16 changed files with 728 additions and 73 deletions
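
For orientation, a minimal sketch of how a kernel calls the new builtin. It mirrors the tests added in this diff; the kernel and buffer names are illustrative, and the packed input is assumed to hold four 8-bit values per int32 word.

import triton
import triton.language as tl


@triton.jit
def dequant_int8_sketch(out_ptr, packed_ptr, scale_ptr, shift_ptr, size, BLOCK: tl.constexpr):
    # Each int32 word packs four 8-bit values, so only BLOCK // 4 words are loaded.
    w_off = tl.arange(0, BLOCK // 4)
    packed = tl.load(packed_ptr + w_off, mask=w_off < (size // 4), other=0)
    scale = tl.load(scale_ptr)  # float16 scalar
    shift = tl.load(shift_ptr)  # float16 scalar
    out = tl.dequantize(packed, scale, shift, 8)  # expands to BLOCK float16 values
    off = tl.arange(0, BLOCK)
    tl.store(out_ptr + off, out, mask=off < size)

On the host side nothing changes: allocate size float16 outputs, pass the packed int32 buffer plus float16 scale and shift scalars, and pick BLOCK as a power of two covering size, as the tests added below do.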


@@ -83,8 +83,8 @@ void cu_enqueue(uint64_t stream, uint64_t kernel,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END
};
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (CUstream)stream, nullptr, config);
}
@@ -97,8 +97,8 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
HIP_LAUNCH_PARAM_END
};
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (hipStream_t)stream, nullptr, config);
}
@@ -302,8 +302,8 @@ void init_triton_runtime(py::module &&m) {
// cache key
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
py::dict extern_libs, py::function add_to_cache, py::object grid){
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
long _num_warps = PyLong_AsLong(num_warps.ptr());
@@ -351,8 +351,8 @@ void init_triton_runtime(py::module &&m) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
nullptr, config);
}
return bin;
@@ -372,7 +372,7 @@ void init_triton_runtime(py::module &&m) {
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
return 0;
if(backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
if(backend == ROCM)
return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
@@ -422,7 +422,7 @@ void init_triton_runtime(py::module &&m) {
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
});
}
/*****************************************************************************/
@@ -430,9 +430,9 @@ void init_triton_runtime(py::module &&m) {
/*****************************************************************************/
typedef std::map<std::string, py::object> asm_map_t;
// ---------------------------------------
// Compile Triton-IR to assembly
// ---------------------------------------
void init_triton_codegen(py::module &&m) {
m.def("compile_ttir",
@@ -550,13 +550,13 @@ void init_triton_ir(py::module &&m) {
.value("CA", ir::load_inst::CA)
.value("CG", ir::load_inst::CG)
.export_values();
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
.value("NORMAL", ir::load_inst::NORMAL)
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
.export_values();
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
.value("ADD", ir::reduce_inst::ADD)
.value("FADD", ir::reduce_inst::FADD)
@@ -573,7 +573,7 @@ void init_triton_ir(py::module &&m) {
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
.value("XOR", ir::reduce_inst::XOR);
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
.value("ADD", ir::atomic_rmw_op_t::Add)
.value("FADD", ir::atomic_rmw_op_t::FAdd)
@@ -704,7 +704,7 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::function_type, ir::type>(m, "function_type")
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
.def_property_readonly("arg_tys", [](ir::function_type* self){
.def_property_readonly("arg_tys", [](ir::function_type* self){
return std::vector<ir::type*>(self->params_begin(), self->params_end());
});
@@ -713,7 +713,7 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::block_type, ir::type>(m, "block_type")
.def_property_readonly("shape", &ir::block_type::get_shapes)
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
py::class_<ir::struct_type, ir::type>(m, "struct_type")
.def("get", &ir::struct_type::get, ret::reference)
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
@@ -834,6 +834,8 @@ void init_triton_ir(py::module &&m) {
.def("create_br", &ir::builder::create_br, ret::reference)
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
// Dequantize instructions
.def("create_dequantize", &ir::builder::create_dequantize, ret::reference)
// Cast instructions
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
.def("create_cast", &ir::builder::create_cast, ret::reference)
@@ -857,27 +859,27 @@ void init_triton_ir(py::module &&m) {
.def("create_frem", &ir::builder::create_frem, ret::reference)
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
.def("create_mul", &ir::builder::create_mul, ret::reference,
py::arg("lhs"), py::arg("rhs"),
.def("create_mul", &ir::builder::create_mul, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
.def("create_srem", &ir::builder::create_srem, ret::reference)
.def("create_urem", &ir::builder::create_urem, ret::reference)
.def("create_add", &ir::builder::create_add, ret::reference,
py::arg("lhs"), py::arg("rhs"),
.def("create_add", &ir::builder::create_add, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sub", &ir::builder::create_sub, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_shl", &ir::builder::create_shl, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
// GEP
.def("create_gep", &ir::builder::create_gep, ret::reference)


@@ -0,0 +1,261 @@
# flake8: noqa: F821,F841
import random
import torch
import triton
import triton.language as tl


@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
mask = w_offsets < (size // 4)
input_ptrs = input_ptr + 1 + w_offsets
input = tl.load(input_ptrs, mask=mask, other=0)
scale_shift = tl.load(input_ptr)
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 8)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int8(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
mask = w_offsets < (size // 4)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 8)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = input_ptr + 1 + w_offsets
input = tl.load(input_ptrs, mask=mask, other=0)
scale_shift = tl.load(input_ptr)
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 4)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int4(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 4)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 2)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int2(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 2)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)


def test_dequantize_int8() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 128, 4)
else:
size = random.randrange(132, 1024, 4)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int32)
input_int8 = torch.randint(
0, 256, (size,), dtype=torch.uint8, device=device
)
input_int32 = input_int8.view(torch.int32)
input = torch.cat((scale_shift, input_int32))
expected = (input_int8 * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 128)
grid = (1,)
dequantize_kernel_int8[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int8[grid](
output,
input_int32,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int4() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 256, 8)
else:
size = random.randrange(264, 1024, 8)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int32)
input_int8 = torch.randint(
0, 256, (size // 2,), dtype=torch.uint8, device=device
)
input_int32 = input_int8.view(torch.int32)
input_int8_h1 = input_int8 >> 4
input_int8_h0 = input_int8 & 15
input_int4_val = torch.stack(
(input_int8_h0, input_int8_h1), dim=1
).flatten()
input = torch.cat((scale_shift, input_int32))
expected = (input_int4_val * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 256)
grid = (1,)
dequantize_kernel_int4[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int4[grid](
output,
input_int32,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int2() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 256, 8)
else:
size = random.randrange(264, 1024, 8)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int16)
input_int8 = torch.randint(
0, 256, (size // 4,), dtype=torch.uint8, device=device
)
input_int16 = input_int8.view(torch.int16)
input_int8_q3 = input_int8 >> 6
input_int8_q2 = (input_int8 >> 4) & 3
input_int8_q1 = (input_int8 >> 2) & 3
input_int8_q0 = input_int8 & 3
input_int2_val = torch.stack(
(input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
).flatten()
input = torch.cat((scale_shift, input_int16))
expected = (input_int2_val * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 256)
grid = (1,)
dequantize_kernel_int2[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int2[grid](
output,
input_int16,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)
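
These tests assume a little-endian lane order: lane 0 of each packed word is its lowest byte (or lowest 4-/2-bit group), which is why viewing the uint8 buffer as int32/int16 and stacking the shifted values low-bits-first reproduces the expected output. A small host-side sketch of that assumption, runnable on CPU:

import torch

vals = torch.tensor([1, 2, 3, 4], dtype=torch.uint8)
word = vals.view(torch.int32)  # one packed 32-bit word, lowest byte in lane 0
lanes = torch.stack([(word >> (8 * i)) & 0xFF for i in range(4)], dim=1).flatten()
assert torch.equal(lanes.to(torch.uint8), vals)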


@@ -685,6 +685,20 @@ def zeros(shape, dtype, _builder=None):
return semantic.zeros(shape, dtype, _builder)

# -----------------------
# dequantize
# -----------------------

@builtin
def dequantize(input, scale, shift, nbit, dst_ty=float16, _builder=None):
"""
Tries to dequantize the input to given dtype
"""
nbit = _constexpr_to_value(nbit)
return semantic.dequantize(input, scale, shift, nbit, dst_ty, _builder)
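
A minimal call-site sketch (names illustrative; not part of the diff): inside a @triton.jit kernel, with packed a block of int32 words and scale / shift float16 scalars loaded via tl.load:

    out = tl.dequantize(packed, scale, shift, 8)  # four float16 outputs per int32 word
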
# -----------------------
# Shape Manipulation
# -----------------------


@@ -544,6 +544,31 @@ def broadcast_impl_value(lhs: tl.tensor,
# (scalar, scalar) => returns original blocks
return lhs, rhs

#######
# dequantize
#######

def dequantize(input: tl.tensor,
scale: tl.tensor,
shift: tl.tensor,
nbit: int,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
input_ty = input.type
assert input_ty.is_block()
assert input_ty.element_ty.is_int32() or input_ty.element_ty.is_int16()
assert nbit in [2, 4, 8]
assert dst_ty == tl.float16
shape = input_ty.get_block_shapes()
factor = input_ty.element_ty.primitive_bitwidth // nbit
dst_shape = shape[:-1] + [factor * shape[-1]]
dst_ty = tl.block_type(dst_ty, dst_shape)
return tl.tensor(builder.create_dequantize(input.handle, scale.handle, shift.handle, dst_ty.to_ir(builder)), dst_ty)
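
A worked example of the shape rule above (a sketch, not part of the diff): an int32 input block expands its last dimension by a factor of 32 // nbit.

primitive_bitwidth = 32               # element type of the packed input: int32
nbit = 4
factor = primitive_bitwidth // nbit   # 8 four-bit values per int32 word
shape = [128]                         # input block shape
dst_shape = shape[:-1] + [factor * shape[-1]]
assert dst_shape == [1024]            # resulting float16 block shape
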
#######
# cast
#######