[IR] Added special-purpose dequantize
instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
This commit is contained in:
@@ -83,8 +83,8 @@ void cu_enqueue(uint64_t stream, uint64_t kernel,
|
||||
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
CU_LAUNCH_PARAM_END
|
||||
};
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
shared_mem, (CUstream)stream, nullptr, config);
|
||||
}
|
||||
|
||||
@@ -97,8 +97,8 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
|
||||
HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
HIP_LAUNCH_PARAM_END
|
||||
};
|
||||
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
shared_mem, (hipStream_t)stream, nullptr, config);
|
||||
|
||||
}
|
||||
@@ -302,8 +302,8 @@ void init_triton_runtime(py::module &&m) {
|
||||
|
||||
|
||||
// cache key
|
||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||
py::dict extern_libs, py::function add_to_cache, py::object grid){
|
||||
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
||||
long _num_warps = PyLong_AsLong(num_warps.ptr());
|
||||
@@ -351,8 +351,8 @@ void init_triton_runtime(py::module &&m) {
|
||||
// release the gil in case the enqueue blocks
|
||||
// cuda will block if too many ops are enqueued
|
||||
py::gil_scoped_release allow_threads;
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
nullptr, config);
|
||||
}
|
||||
return bin;
|
||||
@@ -372,7 +372,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
|
||||
if (backend == HOST)
|
||||
return 0;
|
||||
if(backend == CUDA)
|
||||
if(backend == CUDA)
|
||||
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
|
||||
if(backend == ROCM)
|
||||
return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
|
||||
@@ -422,7 +422,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||
});
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
@@ -430,9 +430,9 @@ void init_triton_runtime(py::module &&m) {
|
||||
/*****************************************************************************/
|
||||
typedef std::map<std::string, py::object> asm_map_t;
|
||||
|
||||
// ---------------------------------------
|
||||
// ---------------------------------------
|
||||
// Compile Triton-IR to assembly
|
||||
// ---------------------------------------
|
||||
// ---------------------------------------
|
||||
|
||||
void init_triton_codegen(py::module &&m) {
|
||||
m.def("compile_ttir",
|
||||
@@ -550,13 +550,13 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("CA", ir::load_inst::CA)
|
||||
.value("CG", ir::load_inst::CG)
|
||||
.export_values();
|
||||
|
||||
|
||||
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
|
||||
.value("NORMAL", ir::load_inst::NORMAL)
|
||||
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
|
||||
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
|
||||
.export_values();
|
||||
|
||||
|
||||
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
|
||||
.value("ADD", ir::reduce_inst::ADD)
|
||||
.value("FADD", ir::reduce_inst::FADD)
|
||||
@@ -573,7 +573,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
|
||||
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
|
||||
.value("XOR", ir::reduce_inst::XOR);
|
||||
|
||||
|
||||
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
|
||||
.value("ADD", ir::atomic_rmw_op_t::Add)
|
||||
.value("FADD", ir::atomic_rmw_op_t::FAdd)
|
||||
@@ -704,7 +704,7 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<ir::function_type, ir::type>(m, "function_type")
|
||||
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
|
||||
.def_property_readonly("arg_tys", [](ir::function_type* self){
|
||||
.def_property_readonly("arg_tys", [](ir::function_type* self){
|
||||
return std::vector<ir::type*>(self->params_begin(), self->params_end());
|
||||
});
|
||||
|
||||
@@ -713,7 +713,7 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<ir::block_type, ir::type>(m, "block_type")
|
||||
.def_property_readonly("shape", &ir::block_type::get_shapes)
|
||||
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
|
||||
|
||||
|
||||
py::class_<ir::struct_type, ir::type>(m, "struct_type")
|
||||
.def("get", &ir::struct_type::get, ret::reference)
|
||||
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
|
||||
@@ -834,6 +834,8 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_br", &ir::builder::create_br, ret::reference)
|
||||
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// Dequantize instructions
|
||||
.def("create_dequantize", &ir::builder::create_dequantize, ret::reference)
|
||||
// Cast instructions
|
||||
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
|
||||
.def("create_cast", &ir::builder::create_cast, ret::reference)
|
||||
@@ -857,27 +859,27 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_frem", &ir::builder::create_frem, ret::reference)
|
||||
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
|
||||
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
|
||||
.def("create_mul", &ir::builder::create_mul, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
.def("create_mul", &ir::builder::create_mul, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
|
||||
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
|
||||
.def("create_srem", &ir::builder::create_srem, ret::reference)
|
||||
.def("create_urem", &ir::builder::create_urem, ret::reference)
|
||||
.def("create_add", &ir::builder::create_add, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
.def("create_add", &ir::builder::create_add, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sub", &ir::builder::create_sub, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_shl", &ir::builder::create_shl, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
// GEP
|
||||
.def("create_gep", &ir::builder::create_gep, ret::reference)
|
||||
|
261
python/test/unit/language/test_dequantize.py
Normal file
261
python/test/unit/language/test_dequantize.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# flake8: noqa: F821,F841
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
|
||||
mask = w_offsets < (size // 4)
|
||||
input_ptrs = input_ptr + 1 + w_offsets
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale_shift = tl.load(input_ptr)
|
||||
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 8)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int8(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
|
||||
mask = w_offsets < (size // 4)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 8)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = input_ptr + 1 + w_offsets
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale_shift = tl.load(input_ptr)
|
||||
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 4)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int4(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 4)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
|
||||
shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 2)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int2(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 2)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
def test_dequantize_int8() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 128, 4)
|
||||
else:
|
||||
size = random.randrange(132, 1024, 4)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int32)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int32 = input_int8.view(torch.int32)
|
||||
|
||||
input = torch.cat((scale_shift, input_int32))
|
||||
expected = (input_int8 * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 128)
|
||||
grid = (1,)
|
||||
dequantize_kernel_int8[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int8[grid](
|
||||
output,
|
||||
input_int32,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
|
||||
def test_dequantize_int4() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 256, 8)
|
||||
else:
|
||||
size = random.randrange(264, 1024, 8)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int32)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size // 2,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int32 = input_int8.view(torch.int32)
|
||||
|
||||
input_int8_h1 = input_int8 >> 4
|
||||
input_int8_h0 = input_int8 & 15
|
||||
|
||||
input_int4_val = torch.stack(
|
||||
(input_int8_h0, input_int8_h1), dim=1
|
||||
).flatten()
|
||||
|
||||
input = torch.cat((scale_shift, input_int32))
|
||||
expected = (input_int4_val * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 256)
|
||||
grid = (1,)
|
||||
dequantize_kernel_int4[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int4[grid](
|
||||
output,
|
||||
input_int32,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
|
||||
def test_dequantize_int2() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 256, 8)
|
||||
else:
|
||||
size = random.randrange(264, 1024, 8)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int16)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size // 4,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int16 = input_int8.view(torch.int16)
|
||||
|
||||
input_int8_q3 = input_int8 >> 6
|
||||
input_int8_q2 = (input_int8 >> 4) & 3
|
||||
input_int8_q1 = (input_int8 >> 2) & 3
|
||||
input_int8_q0 = input_int8 & 3
|
||||
|
||||
input_int2_val = torch.stack(
|
||||
(input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
|
||||
).flatten()
|
||||
|
||||
input = torch.cat((scale_shift, input_int16))
|
||||
expected = (input_int2_val * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 256)
|
||||
grid = (1,)
|
||||
|
||||
dequantize_kernel_int2[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int2[grid](
|
||||
output,
|
||||
input_int16,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
@@ -685,6 +685,20 @@ def zeros(shape, dtype, _builder=None):
|
||||
return semantic.zeros(shape, dtype, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# dequantize
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def dequantize(input, scale, shift, nbit, dst_ty=float16, _builder=None):
|
||||
"""
|
||||
Tries to dequantize the input to given dtype
|
||||
"""
|
||||
nbit = _constexpr_to_value(nbit)
|
||||
return semantic.dequantize(input, scale, shift, nbit, dst_ty, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Shape Manipulation
|
||||
# -----------------------
|
||||
|
@@ -544,6 +544,31 @@ def broadcast_impl_value(lhs: tl.tensor,
|
||||
# (scalar, scalar) => returns original blocks
|
||||
return lhs, rhs
|
||||
|
||||
|
||||
#######
|
||||
# dequantize
|
||||
#######
|
||||
|
||||
def dequantize(input: tl.tensor,
|
||||
scale: tl.tensor,
|
||||
shift: tl.tensor,
|
||||
nbit: int,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
input_ty = input.type
|
||||
assert input_ty.is_block()
|
||||
assert input_ty.element_ty.is_int32() or input_ty.element_ty.is_int16()
|
||||
assert nbit in [2, 4, 8]
|
||||
assert dst_ty == tl.float16
|
||||
|
||||
shape = input_ty.get_block_shapes()
|
||||
factor = input_ty.element_ty.primitive_bitwidth // nbit
|
||||
dst_shape = shape[:-1] + [factor * shape[-1]]
|
||||
|
||||
dst_ty = tl.block_type(dst_ty, dst_shape)
|
||||
return tl.tensor(builder.create_dequantize(input.handle, scale.handle, shift.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
|
||||
#######
|
||||
# cast
|
||||
#######
|
||||
|
Reference in New Issue
Block a user