[JIT] re-added nvidia compatibility
This commit is contained in:
@@ -8,7 +8,7 @@ const char* src =
|
|||||||
R"(
|
R"(
|
||||||
const tunable int32 TM = {16, 32, 64};
|
const tunable int32 TM = {16, 32, 64};
|
||||||
const tunable int32 TN = {16, 32, 64};
|
const tunable int32 TN = {16, 32, 64};
|
||||||
const tunable int32 TK = {8, 16};
|
const tunable int32 TK = {8};
|
||||||
|
|
||||||
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||||
int32 M, int32 N, int32 K, int32 bound){
|
int32 M, int32 N, int32 K, int32 bound){
|
||||||
@@ -19,18 +19,35 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
|||||||
fp32 C[TM, TN] = 0;
|
fp32 C[TM, TN] = 0;
|
||||||
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];
|
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];
|
||||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];
|
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];
|
||||||
|
fp32 a[TM, TK] = *pa;
|
||||||
|
fp32 b[TN, TK] = *pb;
|
||||||
for(int32 k = K; k > 0;){
|
for(int32 k = K; k > 0;){
|
||||||
fp32 a[TM, TK] = *pa;
|
|
||||||
fp32 b[TN, TK] = *pb;
|
|
||||||
C = dot(a, b, C);
|
C = dot(a, b, C);
|
||||||
pa = pa + TK*M;
|
pa = pa + TK*M;
|
||||||
pb = pb + TK*K;
|
pb = pb + TK*K;
|
||||||
k = k - TK;
|
k = k - TK;
|
||||||
|
int1 checka[TM, TK] = k > bound;
|
||||||
|
int1 checkb[TN, TK] = k > bound;
|
||||||
|
@checka a = *pa;
|
||||||
|
@checkb b = *pb;
|
||||||
|
if(k > bound)
|
||||||
|
continue;
|
||||||
|
int1 checka0[TM] = rxa < M;
|
||||||
|
int1 checka1[TK] = rka < k;
|
||||||
|
int1 checkb0[TN] = ryb < N;
|
||||||
|
int1 checkb1[TK] = rkb < k;
|
||||||
|
checka = checka0[:, newaxis] && checka1[newaxis, :];
|
||||||
|
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];
|
||||||
|
a = checka ? *pa : 0;
|
||||||
|
b = checkb ? *pb : 0;
|
||||||
}
|
}
|
||||||
int32 rxc[TM] = get_global_range[TM](0);
|
int32 rxc[TM] = get_global_range[TM](0);
|
||||||
int32 ryc[TN] = get_global_range[TN](1);
|
int32 ryc[TN] = get_global_range[TN](1);
|
||||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
||||||
*pc = C;
|
int1 checkc0[TM] = rxc < M;
|
||||||
|
int1 checkc1[TN] = ryc < N;
|
||||||
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
|
@checkc *pc = C;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
@@ -89,7 +106,7 @@ int main() {
|
|||||||
triton::jit jit(context);
|
triton::jit jit(context);
|
||||||
|
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
int32_t M = 256, N = 256, K = 256;
|
int32_t M = 512, N = 512, K = 512;
|
||||||
std::vector<float> hc(M*N);
|
std::vector<float> hc(M*N);
|
||||||
std::vector<float> rc(M*N);
|
std::vector<float> rc(M*N);
|
||||||
std::vector<float> ha(M*K);
|
std::vector<float> ha(M*K);
|
||||||
@@ -151,13 +168,12 @@ int main() {
|
|||||||
|
|
||||||
// just-in-time compile source-code
|
// just-in-time compile source-code
|
||||||
std::vector<unsigned> params = {
|
std::vector<unsigned> params = {
|
||||||
1, 4, 8,
|
16, 2, 64,
|
||||||
1, 4, 8,
|
32, 2, 64,
|
||||||
1, 1, 4, 4,
|
16, 8, 2, 2,
|
||||||
1, 8,
|
8, 8,
|
||||||
1,
|
4
|
||||||
};
|
};
|
||||||
params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
|
|
||||||
|
|
||||||
// jit.autotune(src, benchmark);
|
// jit.autotune(src, benchmark);
|
||||||
jit.add_module(src, params);
|
jit.add_module(src, params);
|
||||||
|
@@ -29,6 +29,11 @@
|
|||||||
namespace triton
|
namespace triton
|
||||||
{
|
{
|
||||||
|
|
||||||
|
namespace codegen
|
||||||
|
{
|
||||||
|
class target;
|
||||||
|
}
|
||||||
|
|
||||||
namespace driver
|
namespace driver
|
||||||
{
|
{
|
||||||
|
|
||||||
@@ -40,6 +45,7 @@ public:
|
|||||||
using polymorphic_resource::polymorphic_resource;
|
using polymorphic_resource::polymorphic_resource;
|
||||||
virtual size_t max_threads_per_block() const = 0;
|
virtual size_t max_threads_per_block() const = 0;
|
||||||
virtual size_t max_shared_memory() const = 0;
|
virtual size_t max_shared_memory() const = 0;
|
||||||
|
virtual std::unique_ptr<codegen::target> make_target() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Host device
|
// Host device
|
||||||
@@ -48,6 +54,7 @@ public:
|
|||||||
host_device(): device(host_device_t(), true){ }
|
host_device(): device(host_device_t(), true){ }
|
||||||
size_t max_threads_per_block() const { return 1; }
|
size_t max_threads_per_block() const { return 1; }
|
||||||
size_t max_shared_memory() const { return 0; }
|
size_t max_shared_memory() const { return 0; }
|
||||||
|
std::unique_ptr<codegen::target> make_target() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// OpenCL device
|
// OpenCL device
|
||||||
@@ -56,6 +63,7 @@ public:
|
|||||||
ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
|
ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
|
||||||
size_t max_threads_per_block() const;
|
size_t max_threads_per_block() const;
|
||||||
size_t max_shared_memory() const;
|
size_t max_shared_memory() const;
|
||||||
|
std::unique_ptr<codegen::target> make_target() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// CUDA device
|
// CUDA device
|
||||||
@@ -87,26 +95,28 @@ private:
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
|
cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
|
||||||
//Accessors
|
// Accessors
|
||||||
Architecture architecture() const;
|
Architecture architecture() const;
|
||||||
//Informations
|
// Informations
|
||||||
std::string infos() const;
|
std::string infos() const;
|
||||||
size_t address_bits() const;
|
size_t address_bits() const;
|
||||||
std::vector<size_t> max_block_dim() const;
|
std::vector<size_t> max_block_dim() const;
|
||||||
size_t warp_size() const;
|
size_t warp_size() const;
|
||||||
//Compute Capability
|
// Compute Capability
|
||||||
void interpret_as(std::pair<size_t, size_t> cc);
|
void interpret_as(std::pair<size_t, size_t> cc);
|
||||||
std::pair<size_t, size_t> compute_capability() const;
|
std::pair<size_t, size_t> compute_capability() const;
|
||||||
//Identifier
|
// Identifier
|
||||||
std::string name() const;
|
std::string name() const;
|
||||||
std::string pci_bus_id() const;
|
std::string pci_bus_id() const;
|
||||||
//Clocks
|
// Clocks
|
||||||
size_t current_sm_clock() const;
|
size_t current_sm_clock() const;
|
||||||
size_t current_mem_clock() const;
|
size_t current_mem_clock() const;
|
||||||
size_t max_threads_per_block() const;
|
size_t max_threads_per_block() const;
|
||||||
size_t max_shared_memory() const;
|
size_t max_shared_memory() const;
|
||||||
size_t max_sm_clock() const;
|
size_t max_sm_clock() const;
|
||||||
size_t max_mem_clock() const;
|
size_t max_mem_clock() const;
|
||||||
|
// Target
|
||||||
|
std::unique_ptr<codegen::target> make_target() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
|
std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
|
||||||
|
@@ -49,6 +49,11 @@ class module: public polymorphic_resource<CUmodule, cl_program, host_module_t> {
|
|||||||
protected:
|
protected:
|
||||||
void init_llvm();
|
void init_llvm();
|
||||||
|
|
||||||
|
enum file_type_t{
|
||||||
|
Object,
|
||||||
|
Assembly
|
||||||
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
module(driver::context* ctx, CUmodule mod, bool has_ownership);
|
module(driver::context* ctx, CUmodule mod, bool has_ownership);
|
||||||
module(driver::context* ctx, cl_program mod, bool has_ownership);
|
module(driver::context* ctx, cl_program mod, bool has_ownership);
|
||||||
@@ -57,7 +62,9 @@ public:
|
|||||||
driver::context* context() const;
|
driver::context* context() const;
|
||||||
void compile_llvm_module(llvm::Module* module, const std::string& triple,
|
void compile_llvm_module(llvm::Module* module, const std::string& triple,
|
||||||
const std::string &proc, std::string layout,
|
const std::string &proc, std::string layout,
|
||||||
llvm::SmallVectorImpl<char> &buffer, std::vector<std::string> files = {});
|
llvm::SmallVectorImpl<char> &buffer,
|
||||||
|
const std::string &features,
|
||||||
|
file_type_t file_type);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
driver::context* ctx_;
|
driver::context* ctx_;
|
||||||
|
2
include/triton/external/CL/cl.h
vendored
2
include/triton/external/CL/cl.h
vendored
@@ -32,7 +32,7 @@
|
|||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
#include <OpenCL/cl_platform.h>
|
#include <OpenCL/cl_platform.h>
|
||||||
#else
|
#else
|
||||||
#include <CL/cl_platform.h>
|
#include "cl_platform.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
4
include/triton/external/CL/cl_d3d10.h
vendored
4
include/triton/external/CL/cl_d3d10.h
vendored
@@ -32,8 +32,8 @@
|
|||||||
#define __OPENCL_CL_D3D10_H
|
#define __OPENCL_CL_D3D10_H
|
||||||
|
|
||||||
#include <d3d10.h>
|
#include <d3d10.h>
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#include <CL/cl_platform.h>
|
#include "cl_platform.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
4
include/triton/external/CL/cl_d3d11.h
vendored
4
include/triton/external/CL/cl_d3d11.h
vendored
@@ -32,8 +32,8 @@
|
|||||||
#define __OPENCL_CL_D3D11_H
|
#define __OPENCL_CL_D3D11_H
|
||||||
|
|
||||||
#include <d3d11.h>
|
#include <d3d11.h>
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#include <CL/cl_platform.h>
|
#include "cl_platform.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@@ -31,8 +31,8 @@
|
|||||||
#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_H
|
#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_H
|
||||||
#define __OPENCL_CL_DX9_MEDIA_SHARING_H
|
#define __OPENCL_CL_DX9_MEDIA_SHARING_H
|
||||||
|
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#include <CL/cl_platform.h>
|
#include "cl_platform.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
2
include/triton/external/CL/cl_egl.h
vendored
2
include/triton/external/CL/cl_egl.h
vendored
@@ -32,7 +32,7 @@
|
|||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
|
|
||||||
#else
|
#else
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
2
include/triton/external/CL/cl_ext.h
vendored
2
include/triton/external/CL/cl_ext.h
vendored
@@ -42,7 +42,7 @@ extern "C" {
|
|||||||
#include <OpenCL/cl.h>
|
#include <OpenCL/cl.h>
|
||||||
#include <AvailabilityMacros.h>
|
#include <AvailabilityMacros.h>
|
||||||
#else
|
#else
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/* cl_khr_fp64 extension - no extension #define since it has no functions */
|
/* cl_khr_fp64 extension - no extension #define since it has no functions */
|
||||||
|
4
include/triton/external/CL/cl_ext_intel.h
vendored
4
include/triton/external/CL/cl_ext_intel.h
vendored
@@ -56,8 +56,8 @@ Notes:
|
|||||||
#include <OpenCL/cl.h>
|
#include <OpenCL/cl.h>
|
||||||
#include <OpenCL/cl_platform.h>
|
#include <OpenCL/cl_platform.h>
|
||||||
#else
|
#else
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#include <CL/cl_platform.h>
|
#include "cl_platform.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
2
include/triton/external/CL/cl_gl.h
vendored
2
include/triton/external/CL/cl_gl.h
vendored
@@ -32,7 +32,7 @@
|
|||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
#include <OpenCL/cl.h>
|
#include <OpenCL/cl.h>
|
||||||
#else
|
#else
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
2
include/triton/external/CL/cl_gl_ext.h
vendored
2
include/triton/external/CL/cl_gl_ext.h
vendored
@@ -41,7 +41,7 @@ extern "C" {
|
|||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
#include <OpenCL/cl_gl.h>
|
#include <OpenCL/cl_gl.h>
|
||||||
#else
|
#else
|
||||||
#include <CL/cl_gl.h>
|
#include "cl_gl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@@ -53,8 +53,8 @@ Notes:
|
|||||||
#ifndef __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
|
#ifndef __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
|
||||||
#define __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
|
#define __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
|
||||||
|
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#include <CL/cl_platform.h>
|
#include "cl_platform.h"
|
||||||
#include <va/va.h>
|
#include <va/va.h>
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
8
include/triton/external/CL/opencl.h
vendored
8
include/triton/external/CL/opencl.h
vendored
@@ -44,10 +44,10 @@ extern "C" {
|
|||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
#include <CL/cl.h>
|
#include "cl.h"
|
||||||
#include <CL/cl_gl.h>
|
#include "cl_gl.h"
|
||||||
#include <CL/cl_gl_ext.h>
|
#include "cl_gl_ext.h"
|
||||||
#include <CL/cl_ext.h>
|
#include "cl_ext.h"
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@@ -61,7 +61,7 @@ public:
|
|||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
}
|
}
|
||||||
vectorize.run(module);
|
vectorize.run(module);
|
||||||
// triton::ir::print(module, std::cout);
|
triton::ir::print(module, std::cout);
|
||||||
}
|
}
|
||||||
|
|
||||||
codegen::tune tune;
|
codegen::tune tune;
|
||||||
|
@@ -145,7 +145,7 @@ void tune::run(ir::module &mod) {
|
|||||||
// Layout parameters
|
// Layout parameters
|
||||||
while(!nodes_.empty()){
|
while(!nodes_.empty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
|
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||||
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
||||||
}
|
}
|
||||||
|
@@ -28,6 +28,7 @@
|
|||||||
#include "triton/driver/helpers/CL/infos.hpp"
|
#include "triton/driver/helpers/CL/infos.hpp"
|
||||||
#include "triton/driver/device.h"
|
#include "triton/driver/device.h"
|
||||||
#include "triton/driver/context.h"
|
#include "triton/driver/context.h"
|
||||||
|
#include "triton/codegen/target.h"
|
||||||
|
|
||||||
namespace triton
|
namespace triton
|
||||||
{
|
{
|
||||||
@@ -35,6 +36,14 @@ namespace triton
|
|||||||
namespace driver
|
namespace driver
|
||||||
{
|
{
|
||||||
|
|
||||||
|
/* ------------------------ */
|
||||||
|
// Host //
|
||||||
|
/* ------------------------ */
|
||||||
|
|
||||||
|
std::unique_ptr<codegen::target> host_device::make_target() const {
|
||||||
|
return std::unique_ptr<codegen::cpu_target>(new codegen::cpu_target());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
// OpenCL //
|
// OpenCL //
|
||||||
@@ -49,6 +58,10 @@ size_t ocl_device::max_threads_per_block() const {
|
|||||||
return ocl::info<CL_DEVICE_MAX_WORK_ITEM_SIZES>(*cl_).at(0);
|
return ocl::info<CL_DEVICE_MAX_WORK_ITEM_SIZES>(*cl_).at(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<codegen::target> ocl_device::make_target() const {
|
||||||
|
return std::unique_ptr<codegen::amd_cl_target>(new codegen::amd_cl_target());
|
||||||
|
}
|
||||||
|
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
// CUDA //
|
// CUDA //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
@@ -216,6 +229,12 @@ std::string cu_device::infos() const{
|
|||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// target
|
||||||
|
std::unique_ptr<codegen::target> cu_device::make_target() const {
|
||||||
|
return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -53,10 +53,6 @@
|
|||||||
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
|
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
|
||||||
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
||||||
#include "llvm/Transforms/Utils/Cloning.h"
|
#include "llvm/Transforms/Utils/Cloning.h"
|
||||||
#include "lld/Common/Driver.h"
|
|
||||||
#include "lld/Common/Args.h"
|
|
||||||
#include "lld/Common/ErrorHandler.h"
|
|
||||||
#include "lld/Common/LLVM.h"
|
|
||||||
|
|
||||||
namespace triton
|
namespace triton
|
||||||
{
|
{
|
||||||
@@ -107,12 +103,9 @@ module* module::create(driver::context* ctx, llvm::Module *src) {
|
|||||||
void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
|
void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
|
||||||
const std::string &proc, std::string layout,
|
const std::string &proc, std::string layout,
|
||||||
llvm::SmallVectorImpl<char> &buffer,
|
llvm::SmallVectorImpl<char> &buffer,
|
||||||
std::vector<std::string> paths) {
|
const std::string& features,
|
||||||
|
file_type_t ft) {
|
||||||
init_llvm();
|
init_llvm();
|
||||||
// llvm::legacy::PassManager passes;
|
|
||||||
// passes.add(llvm::createPrintModulePass(llvm::outs()));
|
|
||||||
// passes.add(llvm::createVerifierPass());
|
|
||||||
// passes.run(*module);
|
|
||||||
// create machine
|
// create machine
|
||||||
module->setTargetTriple(triple);
|
module->setTargetTriple(triple);
|
||||||
std::string error;
|
std::string error;
|
||||||
@@ -122,7 +115,7 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
|||||||
opt.UnsafeFPMath = false;
|
opt.UnsafeFPMath = false;
|
||||||
opt.NoInfsFPMath = false;
|
opt.NoInfsFPMath = false;
|
||||||
opt.NoNaNsFPMath = true;
|
opt.NoNaNsFPMath = true;
|
||||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "code-object-v3", opt,
|
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||||
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||||
// set data layout
|
// set data layout
|
||||||
if(layout.empty())
|
if(layout.empty())
|
||||||
@@ -134,7 +127,14 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
|||||||
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
||||||
llvm::legacy::PassManager pass;
|
llvm::legacy::PassManager pass;
|
||||||
llvm::raw_svector_ostream stream(buffer);
|
llvm::raw_svector_ostream stream(buffer);
|
||||||
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_ObjectFile);
|
// convert triton file type to llvm file type
|
||||||
|
auto ll_file_type = [&](module::file_type_t type){
|
||||||
|
if(type == Object)
|
||||||
|
return llvm::TargetMachine::CGFT_ObjectFile;
|
||||||
|
return llvm::TargetMachine::CGFT_AssemblyFile;
|
||||||
|
};
|
||||||
|
// emit
|
||||||
|
machine->addPassesToEmitFile(pass, stream, nullptr, ll_file_type(ft));
|
||||||
pass.run(*module);
|
pass.run(*module);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,7 +149,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
|
|||||||
// std::string triple = llvm::sys::getDefaultTargetTriple();
|
// std::string triple = llvm::sys::getDefaultTargetTriple();
|
||||||
// std::string cpu = llvm::sys::getHostCPUName();
|
// std::string cpu = llvm::sys::getHostCPUName();
|
||||||
// llvm::SmallVector<char, 0> buffer;
|
// llvm::SmallVector<char, 0> buffer;
|
||||||
// module::compile_llvm_module(src, triple, cpu, "", buffer);
|
// module::compile_llvm_module(src, triple, cpu, "", buffer, "", Assembly);
|
||||||
|
|
||||||
// create kernel wrapper
|
// create kernel wrapper
|
||||||
llvm::LLVMContext &ctx = src->getContext();
|
llvm::LLVMContext &ctx = src->getContext();
|
||||||
@@ -202,7 +202,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
|
|||||||
ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
|
ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
|
||||||
init_llvm();
|
init_llvm();
|
||||||
llvm::SmallVector<char, 0> buffer;
|
llvm::SmallVector<char, 0> buffer;
|
||||||
module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
|
module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer, "code-object-v3", Object);
|
||||||
std::ofstream output("/tmp/tmp.o", std::ios::binary);
|
std::ofstream output("/tmp/tmp.o", std::ios::binary);
|
||||||
std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
|
std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
|
||||||
system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o");
|
system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o");
|
||||||
@@ -243,7 +243,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
|||||||
layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
|
layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
|
||||||
// create
|
// create
|
||||||
llvm::SmallVector<char, 0> buffer;
|
llvm::SmallVector<char, 0> buffer;
|
||||||
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_52", layout, buffer);
|
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_52", layout, buffer, "", Assembly);
|
||||||
return std::string(buffer.begin(), buffer.end());
|
return std::string(buffer.begin(), buffer.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -92,7 +92,8 @@ std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::amd_cl_target()) {
|
jit::jit(driver::context *context): driver_context_(context),
|
||||||
|
target_(context->device()->make_target()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user