From 4c8289336d4a291432f2fa4fcb0f2c113e3bad91 Mon Sep 17 00:00:00 2001 From: Roland Conybeare Date: Sun, 7 Jul 2024 16:57:05 -0400 Subject: [PATCH] xo-jit: + primitive wrapper (accept+ignore envptr as 1st argument) --- include/xo/jit/MachPipeline.hpp | 10 +++ include/xo/jit/type2llvm.hpp | 10 ++- src/jit/MachPipeline.cpp | 130 +++++++++++++++++++++++++++----- src/jit/type2llvm.cpp | 8 +- 4 files changed, 138 insertions(+), 20 deletions(-) diff --git a/include/xo/jit/MachPipeline.hpp b/include/xo/jit/MachPipeline.hpp index 88c9f544..ed81c6c8 100644 --- a/include/xo/jit/MachPipeline.hpp +++ b/include/xo/jit/MachPipeline.hpp @@ -111,6 +111,16 @@ namespace xo { llvm::Type * codegen_type(TypeDescr td); llvm::Value * codegen_constant(ref::brw expr); llvm::Function * codegen_primitive(ref::brw expr); + + /** like @ref codegen_primitive , but create wrapper function that accepts (and discards) + * environment pointer as first argument. + * + * Implementation consists of tail call to natural primitive, that skips the unused + * environment pointer + **/ + llvm::Function * codegen_primitive_wrapper(ref::brw expr, + llvm::IRBuilder<> & ir_builder); + llvm::Value * codegen_apply(ref::brw expr, llvm::IRBuilder<> & ir_builder); /* NOTE: codegen_lambda() needs to be reentrant too. * for example can have a lambda in apply position. diff --git a/include/xo/jit/type2llvm.hpp b/include/xo/jit/type2llvm.hpp index 2d532893..dffe155b 100644 --- a/include/xo/jit/type2llvm.hpp +++ b/include/xo/jit/type2llvm.hpp @@ -29,9 +29,17 @@ namespace xo { /** establish llvm representation for a function type * described by @p fn_td + * + * @param wrapper_flag If true, create function type for a wrapper + * to be associated with a closure. + * The wrapper accepts (and ignores) an envapi pointer as first argument. + * Necessary to (for example) support function pointers that may refer + * to either {primitive functions, functions-requiring-closures}, + * with choice deferred until runtime **/ static llvm::FunctionType * function_td_to_llvm_type(xo::ref::brw llvm_cx, - TypeDescr fn_td); + TypeDescr fn_td, + bool wrapper_flag = false); /** establish llvm concrete representation for a particular lambda's * runtime local environment: diff --git a/src/jit/MachPipeline.cpp b/src/jit/MachPipeline.cpp index 3bd895f0..fc44ae88 100644 --- a/src/jit/MachPipeline.cpp +++ b/src/jit/MachPipeline.cpp @@ -20,6 +20,7 @@ namespace xo { using xo::reflect::Reflect; using xo::reflect::StructMember; using xo::reflect::TypeDescr; + using xo::scope; using llvm::orc::ExecutionSession; using llvm::DataLayout; using std::cerr; @@ -167,7 +168,6 @@ namespace xo { MachPipeline::codegen_primitive(ref::brw expr) { constexpr bool c_debug_flag = true; - using xo::scope; scope log(XO_DEBUG(c_debug_flag)); @@ -249,12 +249,123 @@ namespace xo { return fn; } /*codegen_primitive*/ + llvm::Function * + MachPipeline::codegen_primitive_wrapper(ref::brw expr, + llvm::IRBuilder<> & ir_builder) + { + constexpr bool c_debug_flag = true; + + scope log(XO_DEBUG(c_debug_flag), + xtag("primitive-name", expr->name())); + + constexpr const char * c_prefix = "w."; + + /* unique name for wrapper. Note we don't allow period in schematica identifiers + * (though we could if we replace . with .. when lowering) + */ + std::string wrap_name = std::string(c_prefix) + expr->name(); + + /* original primitive */ + auto * native_lvfn = codegen_primitive(expr); + + /* wrapped primitive */ + auto * wrap_lvfn = llvm_module_->getFunction(wrap_name); + + if (wrap_lvfn) { + /* wrapper already defined */ + return wrap_lvfn; + } + + TypeDescr fn_td = expr->valuetype(); + + llvm::FunctionType * native_lvtype + = type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(), fn_td); + + if (!native_lvtype) + return nullptr; + + llvm::FunctionType * wrapper_lvtype + = type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(), + fn_td, + true /*wrapper_flag (for closure)*/); + + wrap_lvfn = llvm::Function::Create(wrapper_lvtype, + llvm::Function::ExternalLinkage, + wrap_name, + llvm_module_.get()); + + /* at least we know the name of the 1st argument :) */ + auto ix = wrap_lvfn->args().begin(); + ix->setName(".env"); + + auto block = llvm::BasicBlock::Create(llvm_cx_->llvm_cx_ref(), + "entry", wrap_lvfn); + + ir_builder.SetInsertPoint(block); + + std::vector args; + + /* call to native_lvfn, + * forwarding all args of wrap_lvfn, except the first + */ + { + args.reserve(expr->n_arg()); + + int i_wrap_arg = 0; + for (auto & arg : wrap_lvfn->args()) { + if (i_wrap_arg > 0) + args.push_back(&arg); + + ++i_wrap_arg; + } + } + + /* {caller,callee} must agree on calling convention, + * so for primitives we need to assume c. + */ + llvm::CallInst * call = ir_builder.CreateCall(native_lvtype, + native_lvfn, + args, + "w.calltmp"); + if (call) { + call->setTailCall(true); + + /* does this work if call returns void? Is this needed with tail call? */ + ir_builder.CreateRet(call); + + llvm::verifyFunction(*wrap_lvfn); + + if (log) { + std::string buf; + llvm::raw_string_ostream ss(buf); + wrap_lvfn->print(ss); + + log(xtag("IR-before-opt", buf)); + } + + /* optimize! */ + ir_pipeline_->run_pipeline(*wrap_lvfn); + + if (log) { + std::string buf; + llvm::raw_string_ostream ss(buf); + wrap_lvfn->print(ss); + + log(xtag("IR-after-opt", buf)); + } + } else { + wrap_lvfn->eraseFromParent(); + wrap_lvfn = nullptr; + } + + return wrap_lvfn; + } /*codegen_primitive_wrapper*/ + llvm::Value * MachPipeline::codegen_apply(ref::brw apply, llvm::IRBuilder<> & ir_builder) { constexpr bool c_debug_flag = true; - using xo::scope; scope log(XO_DEBUG(c_debug_flag), xtag("apply", apply)); @@ -418,7 +529,6 @@ namespace xo { TypeDescr var_type) { constexpr bool c_debug_flag = true; - using xo::scope; scope log(XO_DEBUG(c_debug_flag), xtag("llvm_fn", (void*)llvm_fn), @@ -475,7 +585,6 @@ namespace xo { MachPipeline::codegen_lambda_decl(ref::brw lambda) { constexpr bool c_debug_flag = true; - using xo::scope; scope log(XO_DEBUG(c_debug_flag), xtag("lambda-name", lambda->name())); @@ -491,18 +600,6 @@ namespace xo { /* establish prototype for this function */ -#ifdef OBSOLETE - llvm::Type * llvm_retval = td_to_llvm_type(llvm_cx_.borrow(), - lambda->fn_retval()); - - std::vector arg_type_v(lambda->n_arg()); - - for (size_t i = 0, n = lambda->n_arg(); i < n; ++i) { - arg_type_v[i] = td_to_llvm_type(llvm_cx_.borrow(), - lambda->fn_arg(i)); - } -#endif - llvm::FunctionType * llvm_fn_type = type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(), lambda->valuetype()); @@ -533,7 +630,6 @@ namespace xo { llvm::IRBuilder<> & ir_builder) { constexpr bool c_debug_flag = true; - using xo::scope; scope log(XO_DEBUG(c_debug_flag), xtag("lambda-name", lambda->name())); diff --git a/src/jit/type2llvm.cpp b/src/jit/type2llvm.cpp index 23ee48c4..fdfda7d5 100644 --- a/src/jit/type2llvm.cpp +++ b/src/jit/type2llvm.cpp @@ -58,12 +58,16 @@ namespace xo { **/ llvm::FunctionType * type2llvm::function_td_to_llvm_type(xo::ref::brw llvm_cx, - TypeDescr fn_td) + TypeDescr fn_td, + bool wrapper_flag) { int n_fn_arg = fn_td->n_fn_arg(); std::vector llvm_argtype_v; - llvm_argtype_v.reserve(n_fn_arg); + llvm_argtype_v.reserve(n_fn_arg + (wrapper_flag ? 1 : 0)); + + if (wrapper_flag) + llvm_argtype_v.push_back(env_api_llvm_ptr_type(llvm_cx)); /** check function args are all known **/ for (int i = 0; i < n_fn_arg; ++i) {