[JIT] re-added nvidia compatibility

This commit is contained in:
Philippe Tillet
2019-03-27 21:12:01 -04:00
parent fdf8559806
commit 2c3ae0675e
19 changed files with 106 additions and 53 deletions

View File

@@ -8,7 +8,7 @@ const char* src =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8, 16};
const tunable int32 TK = {8};
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K, int32 bound){
@@ -19,18 +19,35 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
fp32 C[TM, TN] = 0;
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0;){
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
C = dot(a, b, C);
pa = pa + TK*M;
pb = pb + TK*K;
k = k - TK;
int1 checka[TM, TK] = k > bound;
int1 checkb[TN, TK] = k > bound;
@checka a = *pa;
@checkb b = *pb;
if(k > bound)
continue;
int1 checka0[TM] = rxa < M;
int1 checka1[TK] = rka < k;
int1 checkb0[TN] = ryb < N;
int1 checkb1[TK] = rkb < k;
checka = checka0[:, newaxis] && checka1[newaxis, :];
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
*pc = C;
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
}
)";
@@ -89,7 +106,7 @@ int main() {
triton::jit jit(context);
// matrix multiplication parameters
int32_t M = 256, N = 256, K = 256;
int32_t M = 512, N = 512, K = 512;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
@@ -151,13 +168,12 @@ int main() {
// just-in-time compile source-code
std::vector<unsigned> params = {
1, 4, 8,
1, 4, 8,
1, 1, 4, 4,
1, 8,
1,
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 8,
4
};
params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
// jit.autotune(src, benchmark);
jit.add_module(src, params);

View File

@@ -29,6 +29,11 @@
namespace triton
{
namespace codegen
{
class target;
}
namespace driver
{
@@ -40,6 +45,7 @@ public:
using polymorphic_resource::polymorphic_resource;
virtual size_t max_threads_per_block() const = 0;
virtual size_t max_shared_memory() const = 0;
virtual std::unique_ptr<codegen::target> make_target() const = 0;
};
// Host device
@@ -48,6 +54,7 @@ public:
host_device(): device(host_device_t(), true){ }
size_t max_threads_per_block() const { return 1; }
size_t max_shared_memory() const { return 0; }
std::unique_ptr<codegen::target> make_target() const;
};
// OpenCL device
@@ -56,6 +63,7 @@ public:
ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
std::unique_ptr<codegen::target> make_target() const;
};
// CUDA device
@@ -87,26 +95,28 @@ private:
public:
cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
//Accessors
// Accessors
Architecture architecture() const;
//Informations
// Informations
std::string infos() const;
size_t address_bits() const;
std::vector<size_t> max_block_dim() const;
size_t warp_size() const;
//Compute Capability
// Compute Capability
void interpret_as(std::pair<size_t, size_t> cc);
std::pair<size_t, size_t> compute_capability() const;
//Identifier
// Identifier
std::string name() const;
std::string pci_bus_id() const;
//Clocks
// Clocks
size_t current_sm_clock() const;
size_t current_mem_clock() const;
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
size_t max_sm_clock() const;
size_t max_mem_clock() const;
// Target
std::unique_ptr<codegen::target> make_target() const;
private:
std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;

View File

@@ -49,6 +49,11 @@ class module: public polymorphic_resource<CUmodule, cl_program, host_module_t> {
protected:
void init_llvm();
enum file_type_t{
Object,
Assembly
};
public:
module(driver::context* ctx, CUmodule mod, bool has_ownership);
module(driver::context* ctx, cl_program mod, bool has_ownership);
@@ -57,7 +62,9 @@ public:
driver::context* context() const;
void compile_llvm_module(llvm::Module* module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer, std::vector<std::string> files = {});
llvm::SmallVectorImpl<char> &buffer,
const std::string &features,
file_type_t file_type);
protected:
driver::context* ctx_;

View File

@@ -32,7 +32,7 @@
#ifdef __APPLE__
#include <OpenCL/cl_platform.h>
#else
#include <CL/cl_platform.h>
#include "cl_platform.h"
#endif
#ifdef __cplusplus

View File

@@ -32,8 +32,8 @@
#define __OPENCL_CL_D3D10_H
#include <d3d10.h>
#include <CL/cl.h>
#include <CL/cl_platform.h>
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {

View File

@@ -32,8 +32,8 @@
#define __OPENCL_CL_D3D11_H
#include <d3d11.h>
#include <CL/cl.h>
#include <CL/cl_platform.h>
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {

View File

@@ -31,8 +31,8 @@
#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_H
#define __OPENCL_CL_DX9_MEDIA_SHARING_H
#include <CL/cl.h>
#include <CL/cl_platform.h>
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {

View File

@@ -32,7 +32,7 @@
#ifdef __APPLE__
#else
#include <CL/cl.h>
#include "cl.h"
#endif
#ifdef __cplusplus

View File

@@ -42,7 +42,7 @@ extern "C" {
#include <OpenCL/cl.h>
#include <AvailabilityMacros.h>
#else
#include <CL/cl.h>
#include "cl.h"
#endif
/* cl_khr_fp64 extension - no extension #define since it has no functions */

View File

@@ -56,8 +56,8 @@ Notes:
#include <OpenCL/cl.h>
#include <OpenCL/cl_platform.h>
#else
#include <CL/cl.h>
#include <CL/cl_platform.h>
#include "cl.h"
#include "cl_platform.h"
#endif
#ifdef __cplusplus

View File

@@ -32,7 +32,7 @@
#ifdef __APPLE__
#include <OpenCL/cl.h>
#else
#include <CL/cl.h>
#include "cl.h"
#endif
#ifdef __cplusplus

View File

@@ -41,7 +41,7 @@ extern "C" {
#ifdef __APPLE__
#include <OpenCL/cl_gl.h>
#else
#include <CL/cl_gl.h>
#include "cl_gl.h"
#endif
/*

View File

@@ -53,8 +53,8 @@ Notes:
#ifndef __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
#define __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
#include <CL/cl.h>
#include <CL/cl_platform.h>
#include "cl.h"
#include "cl_platform.h"
#include <va/va.h>
#ifdef __cplusplus

View File

@@ -44,10 +44,10 @@ extern "C" {
#else
#include <CL/cl.h>
#include <CL/cl_gl.h>
#include <CL/cl_gl_ext.h>
#include <CL/cl_ext.h>
#include "cl.h"
#include "cl_gl.h"
#include "cl_gl_ext.h"
#include "cl_ext.h"
#endif

View File

@@ -61,7 +61,7 @@ public:
barriers.run(module);
}
vectorize.run(module);
// triton::ir::print(module, std::cout);
triton::ir::print(module, std::cout);
}
codegen::tune tune;

View File

@@ -145,7 +145,7 @@ void tune::run(ir::module &mod) {
// Layout parameters
while(!nodes_.empty()){
ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
}

View File

@@ -28,6 +28,7 @@
#include "triton/driver/helpers/CL/infos.hpp"
#include "triton/driver/device.h"
#include "triton/driver/context.h"
#include "triton/codegen/target.h"
namespace triton
{
@@ -35,6 +36,14 @@ namespace triton
namespace driver
{
/* ------------------------ */
// Host //
/* ------------------------ */
std::unique_ptr<codegen::target> host_device::make_target() const {
return std::unique_ptr<codegen::cpu_target>(new codegen::cpu_target());
}
/* ------------------------ */
// OpenCL //
@@ -49,6 +58,10 @@ size_t ocl_device::max_threads_per_block() const {
return ocl::info<CL_DEVICE_MAX_WORK_ITEM_SIZES>(*cl_).at(0);
}
std::unique_ptr<codegen::target> ocl_device::make_target() const {
return std::unique_ptr<codegen::amd_cl_target>(new codegen::amd_cl_target());
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
@@ -216,6 +229,12 @@ std::string cu_device::infos() const{
return oss.str();
}
// target
std::unique_ptr<codegen::target> cu_device::make_target() const {
return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target());
}
}
}

View File

@@ -53,10 +53,6 @@
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include "llvm/Transforms/Utils/Cloning.h"
#include "lld/Common/Driver.h"
#include "lld/Common/Args.h"
#include "lld/Common/ErrorHandler.h"
#include "lld/Common/LLVM.h"
namespace triton
{
@@ -107,12 +103,9 @@ module* module::create(driver::context* ctx, llvm::Module *src) {
void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer,
std::vector<std::string> paths) {
const std::string& features,
file_type_t ft) {
init_llvm();
// llvm::legacy::PassManager passes;
// passes.add(llvm::createPrintModulePass(llvm::outs()));
// passes.add(llvm::createVerifierPass());
// passes.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
@@ -122,7 +115,7 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "code-object-v3", opt,
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
@@ -134,7 +127,14 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_ObjectFile);
// convert triton file type to llvm file type
auto ll_file_type = [&](module::file_type_t type){
if(type == Object)
return llvm::TargetMachine::CGFT_ObjectFile;
return llvm::TargetMachine::CGFT_AssemblyFile;
};
// emit
machine->addPassesToEmitFile(pass, stream, nullptr, ll_file_type(ft));
pass.run(*module);
}
@@ -149,7 +149,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
// std::string triple = llvm::sys::getDefaultTargetTriple();
// std::string cpu = llvm::sys::getHostCPUName();
// llvm::SmallVector<char, 0> buffer;
// module::compile_llvm_module(src, triple, cpu, "", buffer);
// module::compile_llvm_module(src, triple, cpu, "", buffer, "", Assembly);
// create kernel wrapper
llvm::LLVMContext &ctx = src->getContext();
@@ -202,7 +202,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
init_llvm();
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer, "code-object-v3", Object);
std::ofstream output("/tmp/tmp.o", std::ios::binary);
std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o");
@@ -243,7 +243,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
// create
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_52", layout, buffer);
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_52", layout, buffer, "", Assembly);
return std::string(buffer.begin(), buffer.end());
}

View File

@@ -92,7 +92,8 @@ std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
}
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::amd_cl_target()) {
jit::jit(driver::context *context): driver_context_(context),
target_(context->device()->make_target()) {
}