[OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
This commit is contained in:
@@ -292,6 +292,16 @@ void init_triton_runtime(py::module &&m) {
|
||||
return bin;
|
||||
});
|
||||
|
||||
m.def("cc", [](backend_t backend, uint64_t device) -> int {
|
||||
if (backend == CUDA) {
|
||||
CUdevice dev = (CUdevice)device;
|
||||
int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
|
||||
int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
||||
return major*10 + minor;
|
||||
}
|
||||
return -1;
|
||||
});
|
||||
|
||||
// query maximum shared memory
|
||||
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
|
||||
if (backend == HOST)
|
||||
@@ -303,6 +313,31 @@ void init_triton_runtime(py::module &&m) {
|
||||
return -1;
|
||||
});
|
||||
|
||||
// query DRAM & L2 cache
|
||||
m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
|
||||
return -1;
|
||||
});
|
||||
m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
|
||||
return -1;
|
||||
});
|
||||
m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
// query clock rate (in kilohertz)
|
||||
m.def("clock_rate", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
m.def("num_sm", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
// enqueue
|
||||
m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
|
||||
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
|
||||
|
Reference in New Issue
Block a user