From abe0d3e1b10162ee95ed12718188cf0191796fbb Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 26 Oct 2022 18:12:18 +0000 Subject: [PATCH] cast to amd device when as_nvidia shows up --- include/triton/codegen/target.h | 7 +++++++ lib/codegen/target.cc | 18 +++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/include/triton/codegen/target.h b/include/triton/codegen/target.h index 96e4d5c31..d134befa7 100644 --- a/include/triton/codegen/target.h +++ b/include/triton/codegen/target.h @@ -36,6 +36,7 @@ namespace triton{ namespace codegen{ class nvidia_cu_target; +class amd_cl_target; class target { public: @@ -49,7 +50,12 @@ public: virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0; virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0; virtual unsigned guaranteed_alignment() = 0; +#ifdef USE_ROCM + amd_cl_target* as_nvidia(); + amd_cl_target* as_amd(); +#else nvidia_cu_target* as_nvidia(); +#endif bool is_gpu() const; private: @@ -67,6 +73,7 @@ public: Value* get_block_id(Module *module, Builder& builder, unsigned ax); Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); unsigned guaranteed_alignment() { return 16; } + int sm() { return 0; } // treat as if old CUDA device }; class nvidia_cu_target: public target { diff --git a/lib/codegen/target.cc b/lib/codegen/target.cc index c775938ba..d2663080c 100644 --- a/lib/codegen/target.cc +++ b/lib/codegen/target.cc @@ -15,10 +15,22 @@ namespace codegen{ // base - -nvidia_cu_target* target::as_nvidia() { - return dynamic_cast(this); +#ifdef USE_ROCM +amd_cl_target *target::as_amd() +{ + return dynamic_cast(this); } +amd_cl_target *target::as_nvidia() +{ + return this->as_amd(); +} +#else +// causes segfault on ROCM +nvidia_cu_target *target::as_nvidia() +{ + return dynamic_cast(this); +} +#endif bool target::is_gpu() const { return is_gpu_;