From 2f593d15d55cd00dc6c6ecee44374ecc38fa59d9 Mon Sep 17 00:00:00 2001 From: Roland Conybeare Date: Mon, 24 Jun 2024 14:03:19 -0400 Subject: [PATCH] xo-jit: + 2 examples --- example/CMakeLists.txt | 2 + example/ex2_jit/CMakeLists.txt | 12 +++ example/ex2_jit/ex2_jit.cpp | 170 ++++++++++++++++++++++++++++++++ example/ex3_fptr/CMakeLists.txt | 12 +++ example/ex3_fptr/ex3_fptr.cpp | 45 +++++++++ 5 files changed, 241 insertions(+) create mode 100644 example/ex2_jit/CMakeLists.txt create mode 100644 example/ex2_jit/ex2_jit.cpp create mode 100644 example/ex3_fptr/CMakeLists.txt create mode 100644 example/ex3_fptr/ex3_fptr.cpp diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 0cb5ee51..168ffa7c 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,2 +1,4 @@ add_subdirectory(ex1) +add_subdirectory(ex2_jit) +add_subdirectory(ex3_fptr) add_subdirectory(ex_kaleidoscope4) diff --git a/example/ex2_jit/CMakeLists.txt b/example/ex2_jit/CMakeLists.txt new file mode 100644 index 00000000..1a6d1e88 --- /dev/null +++ b/example/ex2_jit/CMakeLists.txt @@ -0,0 +1,12 @@ +# xo-jit/example/ex2_jit/CMakeLists.txt + +set(SELF_EXE xo_jit_ex2) +set(SELF_SRCS ex2_jit.cpp) + +if (XO_ENABLE_EXAMPLES) + xo_add_executable(${SELF_EXE} ${SELF_SRCS}) + xo_self_dependency(${SELF_EXE} xo_jit) + #xo_dependency(${SELF_EXE} xo_expression) +endif() + +# end CMakeLists.txt diff --git a/example/ex2_jit/ex2_jit.cpp b/example/ex2_jit/ex2_jit.cpp new file mode 100644 index 00000000..3131f823 --- /dev/null +++ b/example/ex2_jit/ex2_jit.cpp @@ -0,0 +1,170 @@ +/** @file ex2_jit.cpp **/ + +#include "xo/jit/MachPipeline.hpp" +#include "xo/jit/activation_record.hpp" +#include "xo/expression/Constant.hpp" +#include "xo/expression/Primitive.hpp" +#include "xo/expression/Apply.hpp" +#include "xo/expression/Lambda.hpp" +#include "xo/expression/Variable.hpp" +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/Reassociate.h" +#include "llvm/Transforms/Scalar/SimplifyCFG.h" + +//double foo(double x) { return x; } + +int +main() { + using xo::scope; + using xo::jit::MachPipeline; + using xo::jit::activation_record; + using xo::ast::make_constant; + using xo::ast::make_primitive; + using xo::ast::llvmintrinsic; + using xo::ast::make_apply; + using xo::ast::make_var; + using xo::ast::make_lambda; + using xo::reflect::Reflect; + using xo::xtag; + using std::cerr; + using std::endl; + + //using xo::ast::make_constant; + + static llvm::ExitOnError llvm_exit_on_err; + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + //auto jit = llvm_exit_on_err(Jit::make_aux()); + auto jit = MachPipeline::make(); + + //static_assert(std::is_function_v); + + scope log(XO_DEBUG(true)); + + /* try spelling everything out */ + + { + auto sqrt = make_primitive("sqrt", + ::sqrt, + false /*!explicit_symbol_def*/, + llvmintrinsic::fp_sqrt); + + { + auto llvm_ircode = jit->codegen_toplevel(sqrt); + + if (llvm_ircode) { + /* note: llvm:errs() is 'raw stderr stream' */ + cerr << "ex1 llvm_ircode:" << endl; + llvm_ircode->print(llvm::errs()); + cerr << endl; + } else { + cerr << "ex1: code generation failed" + << xtag("expr", sqrt) + << endl; + } + } + +#define CHOICE 2 + +#if CHOICE == 0 +#define FUNCTION_SYMBOL "callit" + /* def callit(f :: double->double, x :: double) { f(x); } */ + + auto f_var = make_var("f", Reflect::require()); + auto x_var = make_var("x", Reflect::require()); + auto call1 = make_apply(f_var, {x_var}); /* (f x) */ + //auto call2 = make_apply(f_var, {call1}); /* (f (f x)) */ + + auto lambda = make_lambda("callit", + {f_var, x_var}, + call1); +#elif CHOICE == 1 +#define FUNCTION_SYMBOL "root4" + /* def root4(x : double) { sqrt(sqrt(x)) } */ + + auto x_var = make_var("x", Reflect::require()); + auto call1 = make_apply(sqrt, {x_var}); + auto call2 = make_apply(sqrt, {call1}); + + auto lambda = make_lambda("root4", + {x_var}, + call2); +#elif CHOICE == 2 +#define FUNCTION_SYMBOL "twice" + auto root = make_primitive("sqrt", + ::sqrt, + false /*!explicit_symbol_def*/, + llvmintrinsic::fp_sqrt); + + /* def twice(f :: int->int, x :: int) { f(f(x)) } */ + auto f_var = make_var("f", Reflect::require()); + auto x_var = make_var("x", Reflect::require()); + auto call1 = make_apply(f_var, {x_var}); /* (f x) */ + auto call2 = make_apply(f_var, {call1}); /* (f (f x)) */ + + /* (define (twice f ::int->int x ::int) (f (f x))) */ + auto lambda = make_lambda("twice", + {f_var, x_var}, + call2); +#endif + + log && log(xtag("lambda", lambda)); + + auto llvm_ircode = jit->codegen_toplevel(lambda); + + if (llvm_ircode) { + /* note: llvm:errs() is 'raw stderr stream' */ + cerr << "ex1 llvm_ircode:" << endl; + llvm_ircode->print(llvm::errs()); + cerr << endl; + } else { + cerr << "ex1: code generation failed" + << xtag("expr", lambda) + << endl; + } + + jit->machgen_current_module(); + + /* is this in module? */ + cerr << "ex2: jit execution session" << endl; + jit->dump_execution_session(); + + auto fn = jit->lookup_symbol(FUNCTION_SYMBOL); + + if (!fn) { + cerr << "ex2: lookup: symbol not found" + << xtag("symbol", FUNCTION_SYMBOL) + << endl; + } else { + cerr << "ex2: lookup: symbol found" + << xtag("fn", fn.get().getValue()) + << xtag("symbol", FUNCTION_SYMBOL) + << endl; + } + } +} + +/** end ex2_jit.cpp **/ diff --git a/example/ex3_fptr/CMakeLists.txt b/example/ex3_fptr/CMakeLists.txt new file mode 100644 index 00000000..b66e7f86 --- /dev/null +++ b/example/ex3_fptr/CMakeLists.txt @@ -0,0 +1,12 @@ +# xo-jit/example/ex3_fptr/CMakeLists.txt + +set(SELF_EXE xo_fptr_ex3) +set(SELF_SRCS ex3_fptr.cpp) + +if (XO_ENABLE_EXAMPLES) + xo_add_executable(${SELF_EXE} ${SELF_SRCS}) + xo_self_dependency(${SELF_EXE} xo_jit) + #xo_dependency(${SELF_EXE} xo_expression) +endif() + +# end CMakeLists.txt diff --git a/example/ex3_fptr/ex3_fptr.cpp b/example/ex3_fptr/ex3_fptr.cpp new file mode 100644 index 00000000..ef6de58c --- /dev/null +++ b/example/ex3_fptr/ex3_fptr.cpp @@ -0,0 +1,45 @@ +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/IRBuilder.h" + +#include "llvm/Support/raw_ostream.h" + +int main() { + llvm::LLVMContext context; + llvm::IRBuilder<> builder(context); + llvm::Module *module = new llvm::Module("top", context); + + // Create main function and basic block + llvm::FunctionType *functionType = llvm::FunctionType::get(builder.getInt32Ty(), false); + llvm::Function *mainFunction = llvm::Function::Create(functionType, llvm::Function::ExternalLinkage, "main", module); + llvm::BasicBlock *entry = llvm::BasicBlock::Create(context, "entrypoint", mainFunction); + builder.SetInsertPoint(entry); + + // Create a global string pointer + llvm::Value *helloWorld = builder.CreateGlobalStringPtr("hello world\n"); + + // Create function pointer for puts + std::vector putArgs; + putArgs.push_back(builder.getInt8Ty()->getPointerTo()); + llvm::ArrayRef argsRef(putArgs); + llvm::FunctionType *putsType = llvm::FunctionType::get(builder.getInt32Ty(), argsRef, false); + /* = FunctionType + Callee-pointer */ + llvm::FunctionCallee putFunction_callee = module->getOrInsertFunction("puts", putsType); + +#ifdef NOT_YET + llvm::Constant * putFunction = llvm::Constant + + // Allocate memory for the function pointer + llvm::Value *p = builder.CreateAlloca(putFunction->getType(), nullptr, "p"); + builder.CreateStore(putFunction, p, false); + + // Load the function pointer and call it + llvm::Value *temp = builder.CreateLoad(p); + builder.CreateCall(temp, helloWorld); + + // Return 0 to complete the main function + builder.CreateRet(llvm::ConstantInt::get(builder.getInt32Ty(), 0)); + + // Print the module (IR code) + module->print(llvm::errs(), nullptr); +#endif +}