cast to amd device when as_nvidia shows up

This commit is contained in:
Michael Melesse
2022-10-26 18:12:18 +00:00
parent 4464dfcc18
commit abe0d3e1b1
2 changed files with 22 additions and 3 deletions

View File

@@ -36,6 +36,7 @@ namespace triton{
namespace codegen{
class nvidia_cu_target;
class amd_cl_target;
class target {
public:
@@ -49,7 +50,12 @@ public:
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
#ifdef USE_ROCM
amd_cl_target* as_nvidia();
amd_cl_target* as_amd();
#else
nvidia_cu_target* as_nvidia();
#endif
bool is_gpu() const;
private:
@@ -67,6 +73,7 @@ public:
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
int sm() { return 0; } // treat as if old CUDA device
};
class nvidia_cu_target: public target {

View File

@@ -15,10 +15,22 @@ namespace codegen{
// base
nvidia_cu_target* target::as_nvidia() {
return dynamic_cast<nvidia_cu_target*>(this);
#ifdef USE_ROCM
amd_cl_target *target::as_amd()
{
return dynamic_cast<amd_cl_target *>(this);
}
amd_cl_target *target::as_nvidia()
{
return this->as_amd();
}
#else
// causes segfault on ROCM
nvidia_cu_target *target::as_nvidia()
{
return dynamic_cast<nvidia_cu_target *>(this);
}
#endif
bool target::is_gpu() const {
return is_gpu_;