xo-jit: + primitive wrapper (accept+ignore envptr as 1st argument)

This commit is contained in:
Roland Conybeare 2024-07-07 16:57:05 -04:00
commit 4c8289336d
4 changed files with 138 additions and 20 deletions

View file

@ -111,6 +111,16 @@ namespace xo {
llvm::Type * codegen_type(TypeDescr td);
llvm::Value * codegen_constant(ref::brw<xo::ast::ConstantInterface> expr);
llvm::Function * codegen_primitive(ref::brw<xo::ast::PrimitiveInterface> expr);
/** like @ref codegen_primitive , but create wrapper function that accepts (and discards)
* environment pointer as first argument.
*
* Implementation consists of tail call to natural primitive, that skips the unused
* environment pointer
**/
llvm::Function * codegen_primitive_wrapper(ref::brw<xo::ast::PrimitiveInterface> expr,
llvm::IRBuilder<> & ir_builder);
llvm::Value * codegen_apply(ref::brw<xo::ast::Apply> expr, llvm::IRBuilder<> & ir_builder);
/* NOTE: codegen_lambda() needs to be reentrant too.
* for example can have a lambda in apply position.

View file

@ -29,9 +29,17 @@ namespace xo {
/** establish llvm representation for a function type
* described by @p fn_td
*
* @param wrapper_flag If true, create function type for a wrapper
* to be associated with a closure.
* The wrapper accepts (and ignores) an envapi pointer as first argument.
* Necessary to (for example) support function pointers that may refer
* to either {primitive functions, functions-requiring-closures},
* with choice deferred until runtime
**/
static llvm::FunctionType * function_td_to_llvm_type(xo::ref::brw<LlvmContext> llvm_cx,
TypeDescr fn_td);
TypeDescr fn_td,
bool wrapper_flag = false);
/** establish llvm concrete representation for a particular lambda's
* runtime local environment:

View file

@ -20,6 +20,7 @@ namespace xo {
using xo::reflect::Reflect;
using xo::reflect::StructMember;
using xo::reflect::TypeDescr;
using xo::scope;
using llvm::orc::ExecutionSession;
using llvm::DataLayout;
using std::cerr;
@ -167,7 +168,6 @@ namespace xo {
MachPipeline::codegen_primitive(ref::brw<PrimitiveInterface> expr)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag));
@ -249,12 +249,123 @@ namespace xo {
return fn;
} /*codegen_primitive*/
llvm::Function *
MachPipeline::codegen_primitive_wrapper(ref::brw<PrimitiveInterface> expr,
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
scope log(XO_DEBUG(c_debug_flag),
xtag("primitive-name", expr->name()));
constexpr const char * c_prefix = "w.";
/* unique name for wrapper. Note we don't allow period in schematica identifiers
* (though we could if we replace . with .. when lowering)
*/
std::string wrap_name = std::string(c_prefix) + expr->name();
/* original primitive */
auto * native_lvfn = codegen_primitive(expr);
/* wrapped primitive */
auto * wrap_lvfn = llvm_module_->getFunction(wrap_name);
if (wrap_lvfn) {
/* wrapper already defined */
return wrap_lvfn;
}
TypeDescr fn_td = expr->valuetype();
llvm::FunctionType * native_lvtype
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(), fn_td);
if (!native_lvtype)
return nullptr;
llvm::FunctionType * wrapper_lvtype
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(),
fn_td,
true /*wrapper_flag (for closure)*/);
wrap_lvfn = llvm::Function::Create(wrapper_lvtype,
llvm::Function::ExternalLinkage,
wrap_name,
llvm_module_.get());
/* at least we know the name of the 1st argument :) */
auto ix = wrap_lvfn->args().begin();
ix->setName(".env");
auto block = llvm::BasicBlock::Create(llvm_cx_->llvm_cx_ref(),
"entry", wrap_lvfn);
ir_builder.SetInsertPoint(block);
std::vector<llvm::Value *> args;
/* call to native_lvfn,
* forwarding all args of wrap_lvfn, except the first
*/
{
args.reserve(expr->n_arg());
int i_wrap_arg = 0;
for (auto & arg : wrap_lvfn->args()) {
if (i_wrap_arg > 0)
args.push_back(&arg);
++i_wrap_arg;
}
}
/* {caller,callee} must agree on calling convention,
* so for primitives we need to assume c.
*/
llvm::CallInst * call = ir_builder.CreateCall(native_lvtype,
native_lvfn,
args,
"w.calltmp");
if (call) {
call->setTailCall(true);
/* does this work if call returns void? Is this needed with tail call? */
ir_builder.CreateRet(call);
llvm::verifyFunction(*wrap_lvfn);
if (log) {
std::string buf;
llvm::raw_string_ostream ss(buf);
wrap_lvfn->print(ss);
log(xtag("IR-before-opt", buf));
}
/* optimize! */
ir_pipeline_->run_pipeline(*wrap_lvfn);
if (log) {
std::string buf;
llvm::raw_string_ostream ss(buf);
wrap_lvfn->print(ss);
log(xtag("IR-after-opt", buf));
}
} else {
wrap_lvfn->eraseFromParent();
wrap_lvfn = nullptr;
}
return wrap_lvfn;
} /*codegen_primitive_wrapper*/
llvm::Value *
MachPipeline::codegen_apply(ref::brw<Apply> apply,
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("apply", apply));
@ -418,7 +529,6 @@ namespace xo {
TypeDescr var_type)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("llvm_fn", (void*)llvm_fn),
@ -475,7 +585,6 @@ namespace xo {
MachPipeline::codegen_lambda_decl(ref::brw<Lambda> lambda)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("lambda-name", lambda->name()));
@ -491,18 +600,6 @@ namespace xo {
/* establish prototype for this function */
#ifdef OBSOLETE
llvm::Type * llvm_retval = td_to_llvm_type(llvm_cx_.borrow(),
lambda->fn_retval());
std::vector<llvm::Type *> arg_type_v(lambda->n_arg());
for (size_t i = 0, n = lambda->n_arg(); i < n; ++i) {
arg_type_v[i] = td_to_llvm_type(llvm_cx_.borrow(),
lambda->fn_arg(i));
}
#endif
llvm::FunctionType * llvm_fn_type
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(),
lambda->valuetype());
@ -533,7 +630,6 @@ namespace xo {
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("lambda-name", lambda->name()));

View file

@ -58,12 +58,16 @@ namespace xo {
**/
llvm::FunctionType *
type2llvm::function_td_to_llvm_type(xo::ref::brw<LlvmContext> llvm_cx,
TypeDescr fn_td)
TypeDescr fn_td,
bool wrapper_flag)
{
int n_fn_arg = fn_td->n_fn_arg();
std::vector<llvm::Type *> llvm_argtype_v;
llvm_argtype_v.reserve(n_fn_arg);
llvm_argtype_v.reserve(n_fn_arg + (wrapper_flag ? 1 : 0));
if (wrapper_flag)
llvm_argtype_v.push_back(env_api_llvm_ptr_type(llvm_cx));
/** check function args are all known **/
for (int i = 0; i < n_fn_arg; ++i) {