Hack to make OpenCL for AMD work

This commit is contained in:
Philippe Tillet
2019-03-23 18:58:25 -07:00
parent be55b3a081
commit deb7a1cc5c
3 changed files with 28 additions and 45 deletions

View File

@@ -111,7 +111,7 @@ int main() {
triton::jit jit(context); triton::jit jit(context);
// matrix multiplication parameters // matrix multiplication parameters
size_t M = 512, N = 512, K = 512; int32_t M = 128, N = 128, K = 128;
std::vector<float> hc(M*N); std::vector<float> hc(M*N);
std::vector<float> rc(M*N); std::vector<float> rc(M*N);
std::vector<float> ha(M*K); std::vector<float> ha(M*K);
@@ -163,8 +163,9 @@ int main() {
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize(); stream->synchronize();
// benchmark // benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});}, // double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }); // [&](){ stream->synchronize(); });
double ts = 1;
ts = ts * 1e-9; ts = ts * 1e-9;
double tflops = 2*M*N*K / ts * 1e-12; double tflops = 2*M*N*K / ts * 1e-12;
return tflops; return tflops;

View File

@@ -748,6 +748,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
indices_t b_idx = {idx[1], builder.getInt32(K)}; indices_t b_idx = {idx[1], builder.getInt32(K)};
Value *a = TA->get_value(a_idx); Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx); Value *b = TB->get_value(b_idx);
// a = ConstantFP::get(builder.getFloatTy(), 1);
// b = ConstantFP::get(builder.getFloatTy(), 1);
res = builder.CreateCall(f_mul_add, {a, b, res}); res = builder.CreateCall(f_mul_add, {a, b, res});
} }
result->set_value(idx, res); result->set_value(idx, res);

View File

@@ -185,37 +185,17 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) { ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
init_llvm(); init_llvm();
// std::vector<std::string> files = {
// "opencl.amdgcn.bc",
// "ocml.amdgcn.bc",
// "ockl.amdgcn.bc",
// "oclc_correctly_rounded_sqrt_off.amdgcn.bc",
// "oclc_daz_opt_on.amdgcn.bc",
// "oclc_finite_only_off.amdgcn.bc",
// "oclc_isa_version_902.amdgcn.bc",
// "oclc_unsafe_math_off.amdgcn.bc"
// };
// for(auto&x : files)
// x = "/opt/rocm/lib/" + x;
llvm::LLVMContext ctx;
// llvm::IRBuilder<> builder(ctx);
// auto dummy = new llvm::Module("matmul", ctx);
// llvm::Function *fn = llvm::Function::Create(llvm::FunctionType::get(builder.getVoidTy(), {}, false), llvm::Function::ExternalLinkage, "matmul", dummy);
// llvm::BasicBlock *entry = llvm::BasicBlock::Create(ctx, "entry", fn);
// builder.SetInsertPoint(entry);
// builder.CreateRetVoid();
llvm::SmallVector<char, 0> buffer; llvm::SmallVector<char, 0> buffer;
llvm::SMDiagnostic error; module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
auto dummy = llvm::parseIRFile("test.bc", error, ctx);
module::compile_llvm_module(dummy.get(), "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
std::ofstream output("tmp.o", std::ios::binary);
std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
system("ld.lld tmp.o -shared -o test.o");
// std::ifstream fin("test.o", std::ios::in | std::ios::binary ); std::ifstream input("test.o", std::ios::in | std::ios::binary );
// std::vector<char> buffer(9296); std::vector<unsigned char> in_buffer(std::istreambuf_iterator<char>(input), {});
// fin.read(buffer.data(), buffer.size()); size_t sizes[] = {in_buffer.size()};
size_t sizes[] = {buffer.size()}; const unsigned char* data[] = {(unsigned char*)in_buffer.data()};
const unsigned char* data[] = {(unsigned char*)buffer.data()};
cl_int status; cl_int status;
cl_int err; cl_int err;
*cl_ = dispatch::clCreateProgramWithBinary(*context->cl(), 1, &*context->device()->cl(), sizes, data, &status, &err); *cl_ = dispatch::clCreateProgramWithBinary(*context->cl(), 1, &*context->device()->cl(), sizes, data, &status, &err);