[IR] Added special-purpose dequantize instruction (#759)
Optimal performance in quantized workloads currently requires adding a special-purpose dequantize instruction to the IR. Backward compatibility for this instruction is *NOT* guaranteed.
This commit is contained in:
@@ -83,8 +83,8 @@ void cu_enqueue(uint64_t stream, uint64_t kernel,
|
||||
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
CU_LAUNCH_PARAM_END
|
||||
};
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
shared_mem, (CUstream)stream, nullptr, config);
|
||||
}
|
||||
|
||||
@@ -97,8 +97,8 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
|
||||
HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
HIP_LAUNCH_PARAM_END
|
||||
};
|
||||
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
|
||||
block_0, block_1, block_2,
|
||||
shared_mem, (hipStream_t)stream, nullptr, config);
|
||||
|
||||
}
|
||||
@@ -302,8 +302,8 @@ void init_triton_runtime(py::module &&m) {
|
||||
|
||||
|
||||
// cache key
|
||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||
py::dict extern_libs, py::function add_to_cache, py::object grid){
|
||||
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
||||
long _num_warps = PyLong_AsLong(num_warps.ptr());
|
||||
@@ -351,8 +351,8 @@ void init_triton_runtime(py::module &&m) {
|
||||
// release the gil in case the enqueue blocks
|
||||
// cuda will block if too many ops are enqueued
|
||||
py::gil_scoped_release allow_threads;
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
nullptr, config);
|
||||
}
|
||||
return bin;
|
||||
@@ -372,7 +372,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
|
||||
if (backend == HOST)
|
||||
return 0;
|
||||
if(backend == CUDA)
|
||||
if(backend == CUDA)
|
||||
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
|
||||
if(backend == ROCM)
|
||||
return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
|
||||
@@ -422,7 +422,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||
});
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
@@ -430,9 +430,9 @@ void init_triton_runtime(py::module &&m) {
|
||||
/*****************************************************************************/
|
||||
typedef std::map<std::string, py::object> asm_map_t;
|
||||
|
||||
// ---------------------------------------
|
||||
// ---------------------------------------
|
||||
// Compile Triton-IR to assembly
|
||||
// ---------------------------------------
|
||||
// ---------------------------------------
|
||||
|
||||
void init_triton_codegen(py::module &&m) {
|
||||
m.def("compile_ttir",
|
||||
@@ -550,13 +550,13 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("CA", ir::load_inst::CA)
|
||||
.value("CG", ir::load_inst::CG)
|
||||
.export_values();
|
||||
|
||||
|
||||
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
|
||||
.value("NORMAL", ir::load_inst::NORMAL)
|
||||
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
|
||||
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
|
||||
.export_values();
|
||||
|
||||
|
||||
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
|
||||
.value("ADD", ir::reduce_inst::ADD)
|
||||
.value("FADD", ir::reduce_inst::FADD)
|
||||
@@ -573,7 +573,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
|
||||
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
|
||||
.value("XOR", ir::reduce_inst::XOR);
|
||||
|
||||
|
||||
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
|
||||
.value("ADD", ir::atomic_rmw_op_t::Add)
|
||||
.value("FADD", ir::atomic_rmw_op_t::FAdd)
|
||||
@@ -704,7 +704,7 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<ir::function_type, ir::type>(m, "function_type")
|
||||
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
|
||||
.def_property_readonly("arg_tys", [](ir::function_type* self){
|
||||
.def_property_readonly("arg_tys", [](ir::function_type* self){
|
||||
return std::vector<ir::type*>(self->params_begin(), self->params_end());
|
||||
});
|
||||
|
||||
@@ -713,7 +713,7 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<ir::block_type, ir::type>(m, "block_type")
|
||||
.def_property_readonly("shape", &ir::block_type::get_shapes)
|
||||
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
|
||||
|
||||
|
||||
py::class_<ir::struct_type, ir::type>(m, "struct_type")
|
||||
.def("get", &ir::struct_type::get, ret::reference)
|
||||
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
|
||||
@@ -834,6 +834,8 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_br", &ir::builder::create_br, ret::reference)
|
||||
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// Dequantize instructions
|
||||
.def("create_dequantize", &ir::builder::create_dequantize, ret::reference)
|
||||
// Cast instructions
|
||||
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
|
||||
.def("create_cast", &ir::builder::create_cast, ret::reference)
|
||||
@@ -857,27 +859,27 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_frem", &ir::builder::create_frem, ret::reference)
|
||||
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
|
||||
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
|
||||
.def("create_mul", &ir::builder::create_mul, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
.def("create_mul", &ir::builder::create_mul, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
|
||||
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
|
||||
.def("create_srem", &ir::builder::create_srem, ret::reference)
|
||||
.def("create_urem", &ir::builder::create_urem, ret::reference)
|
||||
.def("create_add", &ir::builder::create_add, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
.def("create_add", &ir::builder::create_add, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sub", &ir::builder::create_sub, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_shl", &ir::builder::create_shl, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
// GEP
|
||||
.def("create_gep", &ir::builder::create_gep, ret::reference)
|
||||
|
Reference in New Issue
Block a user