[FRONTEND] Alignment fix-up (#428)
This commit is contained in:
@@ -101,12 +101,12 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
|
||||
|
||||
}
|
||||
|
||||
std::string pow2_divisor(long N){
|
||||
if(N % 16 == 0) return "16";
|
||||
if(N % 8 == 0) return "8";
|
||||
if(N % 4 == 0) return "4";
|
||||
if(N % 2 == 0) return "2";
|
||||
return "1";
|
||||
long pow2_divisor(long N){
|
||||
if(N % 16 == 0) return 16;
|
||||
if(N % 8 == 0) return 8;
|
||||
if(N % 4 == 0) return 4;
|
||||
if(N % 2 == 0) return 2;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Returns something like "int16", whether dtype is a torch.dtype or
|
||||
@@ -127,6 +127,14 @@ std::string dtype_cache_key_part(const py::object& dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
size_t get_pointer_range_size(uint64_t addr){
|
||||
if(addr == 0)
|
||||
return 0;
|
||||
size_t size;
|
||||
drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr);
|
||||
return size;
|
||||
}
|
||||
|
||||
// Launch
|
||||
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
|
||||
@@ -187,7 +195,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
continue;
|
||||
// values divisible by small powers of 2 are specialized
|
||||
cache_key += "[multipleof(";
|
||||
cache_key += pow2_divisor(value);
|
||||
cache_key += std::to_string(pow2_divisor(value));
|
||||
cache_key += ")]";
|
||||
continue;
|
||||
}
|
||||
@@ -213,12 +221,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
py::object data_ptr = arg.attr("data_ptr")();
|
||||
long value = data_ptr.cast<long>();
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||
// copy param
|
||||
std::memcpy(params_ptr, &value, 8);
|
||||
params_ptr += 8;
|
||||
// udpate cache key
|
||||
cache_key += dtype_cache_key_part(arg.attr("dtype"));
|
||||
cache_key += "*";
|
||||
cache_key += "[multipleof(";
|
||||
cache_key += pow2_divisor(value);
|
||||
size_t range_size = get_pointer_range_size(value);
|
||||
cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size)));
|
||||
cache_key += ")]";
|
||||
continue;
|
||||
}
|
||||
@@ -268,6 +279,10 @@ void init_triton_runtime(py::module &&m) {
|
||||
}
|
||||
);
|
||||
|
||||
// get range size for the given pointer
|
||||
m.def("get_pointer_range_size", &get_pointer_range_size);
|
||||
|
||||
|
||||
// cache key
|
||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||
|
Reference in New Issue
Block a user