Presets: Now checking device type when grabbing a preset

This commit is contained in:
Philippe Tillet
2015-08-13 13:43:26 -07:00
parent 0bb73602f9
commit 29e752c765
9 changed files with 22 additions and 20 deletions

View File

@@ -27,12 +27,7 @@ enum ISAACAPI backend_type
#endif
};
enum ISAACAPI device_type
{
DEVICE_TYPE_GPU = CL_DEVICE_TYPE_GPU,
DEVICE_TYPE_CPU = CL_DEVICE_TYPE_CPU,
DEVICE_TYPE_ACCELERATOR = CL_DEVICE_TYPE_ACCELERATOR
};
#ifdef ISAAC_WITH_CUDA

View File

@@ -20,6 +20,12 @@ private:
friend class CommandQueue;
public:
enum ISAACAPI Type
{
GPU = CL_DEVICE_TYPE_GPU,
CPU = CL_DEVICE_TYPE_CPU,
ACCELERATOR = CL_DEVICE_TYPE_ACCELERATOR
};
enum class Vendor
{
AMD,
@@ -59,7 +65,7 @@ public:
std::string name() const;
std::string vendor_str() const;
std::vector<size_t> max_work_item_sizes() const;
device_type type() const;
Type type() const;
std::string extensions() const;
size_t max_work_group_size() const;
size_t local_mem_size() const;

View File

@@ -17,7 +17,7 @@ namespace isaac
struct profiles
{
typedef std::map<std::tuple<driver::Device::Vendor, driver::Device::Architecture> , const char *> presets_type;
typedef std::map<std::tuple<driver::Device::Type, driver::Device::Vendor, driver::Device::Architecture> , const char *> presets_type;
public:
class value_type
{

View File

@@ -147,14 +147,14 @@ std::vector<size_t> Device::max_work_item_sizes() const
}
}
device_type Device::type() const
Device::Type Device::type() const
{
switch(backend_)
{
#ifdef ISAAC_WITH_CUDA
case CUDA: return DEVICE_TYPE_GPU;
#endif
case OPENCL: return static_cast<device_type>(ocl::info<CL_DEVICE_TYPE>(h_.cl()));
case OPENCL: return static_cast<Type>(ocl::info<CL_DEVICE_TYPE>(h_.cl()));
default: throw;
}
}

View File

@@ -7,11 +7,11 @@ namespace isaac
{
#define DATABASE_ENTRY(VENDOR, ARCHITECTURE, STRING) \
{std::make_tuple(driver::Device::Vendor::VENDOR, driver::Device::Architecture::ARCHITECTURE), STRING}
#define DATABASE_ENTRY(TYPE, VENDOR, ARCHITECTURE, STRING) \
{std::make_tuple(driver::Device::Type::TYPE, driver::Device::Vendor::VENDOR, driver::Device::Architecture::ARCHITECTURE), STRING}
const std::map<std::tuple<driver::Device::Vendor, driver::Device::Architecture> , const char *> profiles::presets_ =
{ DATABASE_ENTRY(INTEL, BROADWELL, presets::broadwell) };
const profiles::presets_type profiles::presets_ =
{ DATABASE_ENTRY(GPU, INTEL, BROADWELL, presets::broadwell) };
#undef DATABASE_ENTRY

View File

@@ -209,9 +209,10 @@ profiles::map_type& profiles::init(driver::CommandQueue const & queue)
result[std::make_pair(etype, dtype)] = std::shared_ptr<value_type>(new value_type(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
driver::Device const & device = queue.device();
presets_type::const_iterator it = presets_.find(std::make_tuple(device.vendor(), device.architecture()));
if(it==presets_.end())
import(presets_.at(std::make_tuple(device.vendor(), driver::Device::Architecture::UNKNOWN)), queue);
presets_type::const_iterator it = presets_.find(std::make_tuple(device.type(), device.vendor(), device.architecture()));
if(it==presets_.end()){
// import(presets_.at(std::make_tuple(device.type(), device.vendor(), driver::Device::Architecture::UNKNOWN)), queue);
}
else
import(it->second, queue);
std::string homepath = tools::getenv("HOME");

View File

@@ -115,7 +115,7 @@ def main():
include =' src/include src/lib/external'.split() + ['external/boost/', 'external/boost/boost/', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files
src = 'src/lib/profiles/profiles.cpp src/lib/profiles/presets.cpp src/lib/profiles/predictors/random_forest.cpp src/lib/kernels/templates/ger.cpp src/lib/kernels/templates/gemv.cpp src/lib/kernels/templates/gemm.cpp src/lib/kernels/templates/dot.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/axpy.cpp src/lib/kernels/stream.cpp src/lib/kernels/parse.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/wrap/clBLAS.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/command_queue.cpp src/lib/driver/backend.cpp src/lib/driver/kernel.cpp src/lib/driver/program_cache.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/buffer.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/event.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']]
src = 'src/lib/profiles/profiles.cpp src/lib/profiles/presets.cpp src/lib/profiles/predictors/random_forest.cpp src/lib/kernels/templates/gemv.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/ger.cpp src/lib/kernels/templates/gemm.cpp src/lib/kernels/templates/dot.cpp src/lib/kernels/templates/axpy.cpp src/lib/kernels/stream.cpp src/lib/kernels/parse.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/wrap/clBLAS.cpp src/lib/array.cpp src/lib/value_scalar.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/io.cpp src/lib/symbolic/expression.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/command_queue.cpp src/lib/driver/backend.cpp src/lib/driver/kernel.cpp src/lib/driver/program_cache.cpp src/lib/driver/platform.cpp src/lib/driver/ndrange.cpp src/lib/driver/handle.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/buffer.cpp src/lib/driver/check.cpp src/lib/driver/program.cpp src/lib/driver/event.cpp '.split() + [os.path.join('src', 'bind', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'kernels.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]

View File

@@ -144,7 +144,7 @@ int main()
for(isaac::driver::Context const * context : data)
{
sc::driver::Device device = sc::driver::backend::queues::get(*context,0).device();
if(device.type() != sc::driver::DEVICE_TYPE_GPU)
if(device.type() != sc::driver::Device::Type::GPU)
continue;
std::cout << "Device: " << device.name() << " on " << device.platform().name() << " " << device.platform().version() << std::endl;
std::cout << "---" << std::endl;

View File

@@ -130,7 +130,7 @@ int main()
for(isaac::driver::Context const * context : data)
{
sc::driver::Device device = sc::driver::backend::queues::get(*context,0).device();
if(device.type() != sc::driver::DEVICE_TYPE_GPU)
if(device.type() != sc::driver::Device::Type::GPU)
continue;
std::cout << "Device: " << device.name() << " on " << device.platform().name() << " " << device.platform().version() << std::endl;
std::cout << "---" << std::endl;