[FRONTEND] Alignment fix-up (#428)

This commit is contained in:
Philippe Tillet
2022-01-11 23:11:58 -08:00
committed by GitHub
parent bbc78f6516
commit 4c94359199
4 changed files with 5386 additions and 871 deletions

View File

@@ -101,12 +101,12 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
}
std::string pow2_divisor(long N){
if(N % 16 == 0) return "16";
if(N % 8 == 0) return "8";
if(N % 4 == 0) return "4";
if(N % 2 == 0) return "2";
return "1";
long pow2_divisor(long N){
if(N % 16 == 0) return 16;
if(N % 8 == 0) return 8;
if(N % 4 == 0) return 4;
if(N % 2 == 0) return 2;
return 1;
}
// Returns something like "int16", whether dtype is a torch.dtype or
@@ -127,6 +127,14 @@ std::string dtype_cache_key_part(const py::object& dtype) {
}
}
size_t get_pointer_range_size(uint64_t addr){
if(addr == 0)
return 0;
size_t size;
drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr);
return size;
}
// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -187,7 +195,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
continue;
// values divisible by small powers of 2 are specialized
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
cache_key += std::to_string(pow2_divisor(value));
cache_key += ")]";
continue;
}
@@ -213,12 +221,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// udpate cache key
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
size_t range_size = get_pointer_range_size(value);
cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size)));
cache_key += ")]";
continue;
}
@@ -268,6 +279,10 @@ void init_triton_runtime(py::module &&m) {
}
);
// get range size for the given pointer
m.def("get_pointer_range_size", &get_pointer_range_size);
// cache key
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,