xo-jit: + primitive wrapper (accept+ignore envptr as 1st argument)

This commit is contained in:
Roland Conybeare 2024-07-07 16:57:05 -04:00
commit 4c8289336d
4 changed files with 138 additions and 20 deletions

View file

@ -20,6 +20,7 @@ namespace xo {
using xo::reflect::Reflect;
using xo::reflect::StructMember;
using xo::reflect::TypeDescr;
using xo::scope;
using llvm::orc::ExecutionSession;
using llvm::DataLayout;
using std::cerr;
@ -167,7 +168,6 @@ namespace xo {
MachPipeline::codegen_primitive(ref::brw<PrimitiveInterface> expr)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag));
@ -249,12 +249,123 @@ namespace xo {
return fn;
} /*codegen_primitive*/
llvm::Function *
MachPipeline::codegen_primitive_wrapper(ref::brw<PrimitiveInterface> expr,
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
scope log(XO_DEBUG(c_debug_flag),
xtag("primitive-name", expr->name()));
constexpr const char * c_prefix = "w.";
/* unique name for wrapper. Note we don't allow period in schematica identifiers
* (though we could if we replace . with .. when lowering)
*/
std::string wrap_name = std::string(c_prefix) + expr->name();
/* original primitive */
auto * native_lvfn = codegen_primitive(expr);
/* wrapped primitive */
auto * wrap_lvfn = llvm_module_->getFunction(wrap_name);
if (wrap_lvfn) {
/* wrapper already defined */
return wrap_lvfn;
}
TypeDescr fn_td = expr->valuetype();
llvm::FunctionType * native_lvtype
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(), fn_td);
if (!native_lvtype)
return nullptr;
llvm::FunctionType * wrapper_lvtype
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(),
fn_td,
true /*wrapper_flag (for closure)*/);
wrap_lvfn = llvm::Function::Create(wrapper_lvtype,
llvm::Function::ExternalLinkage,
wrap_name,
llvm_module_.get());
/* at least we know the name of the 1st argument :) */
auto ix = wrap_lvfn->args().begin();
ix->setName(".env");
auto block = llvm::BasicBlock::Create(llvm_cx_->llvm_cx_ref(),
"entry", wrap_lvfn);
ir_builder.SetInsertPoint(block);
std::vector<llvm::Value *> args;
/* call to native_lvfn,
* forwarding all args of wrap_lvfn, except the first
*/
{
args.reserve(expr->n_arg());
int i_wrap_arg = 0;
for (auto & arg : wrap_lvfn->args()) {
if (i_wrap_arg > 0)
args.push_back(&arg);
++i_wrap_arg;
}
}
/* {caller,callee} must agree on calling convention,
* so for primitives we need to assume c.
*/
llvm::CallInst * call = ir_builder.CreateCall(native_lvtype,
native_lvfn,
args,
"w.calltmp");
if (call) {
call->setTailCall(true);
/* does this work if call returns void? Is this needed with tail call? */
ir_builder.CreateRet(call);
llvm::verifyFunction(*wrap_lvfn);
if (log) {
std::string buf;
llvm::raw_string_ostream ss(buf);
wrap_lvfn->print(ss);
log(xtag("IR-before-opt", buf));
}
/* optimize! */
ir_pipeline_->run_pipeline(*wrap_lvfn);
if (log) {
std::string buf;
llvm::raw_string_ostream ss(buf);
wrap_lvfn->print(ss);
log(xtag("IR-after-opt", buf));
}
} else {
wrap_lvfn->eraseFromParent();
wrap_lvfn = nullptr;
}
return wrap_lvfn;
} /*codegen_primitive_wrapper*/
llvm::Value *
MachPipeline::codegen_apply(ref::brw<Apply> apply,
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("apply", apply));
@ -418,7 +529,6 @@ namespace xo {
TypeDescr var_type)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("llvm_fn", (void*)llvm_fn),
@ -475,7 +585,6 @@ namespace xo {
MachPipeline::codegen_lambda_decl(ref::brw<Lambda> lambda)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("lambda-name", lambda->name()));
@ -491,18 +600,6 @@ namespace xo {
/* establish prototype for this function */
#ifdef OBSOLETE
llvm::Type * llvm_retval = td_to_llvm_type(llvm_cx_.borrow(),
lambda->fn_retval());
std::vector<llvm::Type *> arg_type_v(lambda->n_arg());
for (size_t i = 0, n = lambda->n_arg(); i < n; ++i) {
arg_type_v[i] = td_to_llvm_type(llvm_cx_.borrow(),
lambda->fn_arg(i));
}
#endif
llvm::FunctionType * llvm_fn_type
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(),
lambda->valuetype());
@ -533,7 +630,6 @@ namespace xo {
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
using xo::scope;
scope log(XO_DEBUG(c_debug_flag),
xtag("lambda-name", lambda->name()));