[FRONTEND] Expose end-to-end compile to python frontend (#58)

This commit is contained in:
Yan Chunwei
2022-08-18 01:42:48 +08:00
committed by GitHub
parent 95bbac41e7
commit b1673caaf6
15 changed files with 228 additions and 165 deletions

View File

@@ -1 +1,2 @@
add_subdirectory(LLVMIR)
add_subdirectory(LLVMIR)
add_subdirectory(PTX)

View File

@@ -4,6 +4,8 @@
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
@@ -77,7 +79,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
}
std::unique_ptr<llvm::Module>
TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
registerLLVMDialectTranslation(registry);
@@ -114,5 +116,26 @@ TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
return llvmModule;
}
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module) {
mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);
pm.addPass(createConvertTritonGPUToLLVMPass());
if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";
return nullptr;
}
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}
return llvmir;
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,9 @@
add_mlir_translation_library(TritonPTX
PTXTranslation.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
TritonLLVMIR
)

View File

@@ -0,0 +1,41 @@
#include "triton/Target/PTX/PTXTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/llvm.h"
namespace triton {
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
std::string *ptxasPath) {
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
*cc = major * 10 + minor;
*ptxasPath = driver::path_to_ptxas(*version); // assign version
}
std::tuple<std::string, size_t, int, std::string>
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device) {
int cc;
int version;
std::string ptxasPath;
getCuCCAndVersionFromDevice(device, &cc, &version, &ptxasPath);
llvm::LLVMContext ctx;
auto llModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, module);
auto ptxCode = driver::llir_to_ptx(llModule.get(), cc, version);
return std::make_tuple(ptxCode, cc, version, ptxasPath);
}
} // namespace triton