fix bfloat failure
This commit is contained in:
@@ -607,11 +607,11 @@ void init_triton_codegen(py::module &&m) {
|
||||
// set dynamic shared memory if necessary
|
||||
std::cout << "\t" << "// set dynamic shared memory if necessary" << std::endl;
|
||||
int shared_optin;
|
||||
drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
|
||||
// drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
|
||||
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
||||
// drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
|
||||
int shared_total, shared_static;
|
||||
drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
|
||||
// drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
|
||||
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
||||
shared_total = attr.sharedSizeBytes;
|
||||
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
|
||||
|
def check_type_supported(dtype):
    '''
    skip test if dtype is not supported on the current device
    '''
    # NOTE(review): reconstructed from the "fix bfloat failure" hunk. The
    # pre-fix code queried CUDA compute capability unconditionally, which
    # breaks on ROCm builds; the HIP branch must be checked first.
    if torch.version.hip is not None:
        # ROCm/HIP build: bfloat16 kernels are not available at all, so skip
        # for any spelling of the dtype (triton, string, or torch form).
        if dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16:
            pytest.skip("bfloat16 is not supported on AMDGPU")
    else:
        # CUDA build: bfloat16 needs compute capability >= 80 (Ampere or newer).
        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
        if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
||||
|
||||
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
|
||||
|
@@ -1024,7 +1024,7 @@ static inline hipDeviceptr_t getPointer(PyObject *obj, int idx) {{
|
||||
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
printf("launch(PyObject* self, PyObject* args)");
|
||||
// printf("launch(PyObject* self, PyObject* args)");
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
|
Reference in New Issue
Block a user