xo-jit: fnptr -> closures for primitives+lambdas throughout

This commit is contained in:
Roland Conybeare 2024-07-10 16:05:00 -04:00
commit 09f5c141df
6 changed files with 295 additions and 124 deletions

View file

@ -192,7 +192,7 @@ namespace xo {
TypeDescr fn_td = expr->valuetype();
llvm::FunctionType * llvm_fn_type
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(), fn_td);
= type2llvm::function_td_to_lvtype(llvm_cx_.borrow(), fn_td);
if (!llvm_fn_type)
return nullptr;
@ -251,7 +251,7 @@ namespace xo {
llvm::Function *
MachPipeline::codegen_primitive_wrapper(ref::brw<PrimitiveInterface> expr,
llvm::IRBuilder<> & ir_builder)
llvm::IRBuilder<> & /*ir_builder*/)
{
constexpr bool c_debug_flag = true;
@ -266,7 +266,7 @@ namespace xo {
std::string wrap_name = std::string(c_prefix) + expr->name();
/* original primitive */
auto * native_lvfn = codegen_primitive(expr);
auto * native_lvfn = this->codegen_primitive(expr);
/* wrapped primitive */
auto * wrap_lvfn = llvm_module_->getFunction(wrap_name);
@ -279,13 +279,15 @@ namespace xo {
TypeDescr fn_td = expr->valuetype();
llvm::FunctionType * native_lvtype
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(), fn_td);
= type2llvm::function_td_to_lvtype(llvm_cx_.borrow(),
fn_td,
false /*!wrapper_flag*/);
if (!native_lvtype)
return nullptr;
llvm::FunctionType * wrapper_lvtype
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(),
= type2llvm::function_td_to_lvtype(llvm_cx_.borrow(),
fn_td,
true /*wrapper_flag (for closure)*/);
@ -301,7 +303,11 @@ namespace xo {
auto block = llvm::BasicBlock::Create(llvm_cx_->llvm_cx_ref(),
"entry", wrap_lvfn);
ir_builder.SetInsertPoint(block);
/* don't call SetInsertPoint() on incoming ir_builder argument.
* Want to avoid disturbing top-to-bottom flow
*/
llvm::IRBuilder<> tmp_ir_builder(llvm_cx_->llvm_cx_ref());
tmp_ir_builder.SetInsertPoint(block);
std::vector<llvm::Value *> args;
@ -323,15 +329,15 @@ namespace xo {
/* {caller,callee} must agree on calling convention,
* so for primitives we need to assume c.
*/
llvm::CallInst * call = ir_builder.CreateCall(native_lvtype,
native_lvfn,
args,
"w.calltmp");
llvm::CallInst * call = tmp_ir_builder.CreateCall(native_lvtype,
native_lvfn,
args,
"w.calltmp");
if (call) {
call->setTailCall(true);
/* does this work if call returns void? Is this needed with tail call? */
ir_builder.CreateRet(call);
tmp_ir_builder.CreateRet(call);
llvm::verifyFunction(*wrap_lvfn);
@ -365,6 +371,9 @@ namespace xo {
MachPipeline::codegen_primitive_closure(ref::brw<xo::ast::PrimitiveInterface> expr,
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
scope log(XO_DEBUG(c_debug_flag));
llvm::StructType * closure_lvtype
= type2llvm::create_closureapi_lvtype(llvm_cx_.borrow(), expr);
@ -402,6 +411,8 @@ namespace xo {
* - MachPipeline::codegen_primitive_closure
* - MachPipeline::codegen_lambda_closure
* - type2llvm::create_closure_lvtype
*
* although this refers to a closure, llvm doesn't know that
*/
llvm::Value * llvm_closure = nullptr;
llvmintrinsic intrinsic = llvmintrinsic::invalid;
@ -413,7 +424,7 @@ namespace xo {
auto pm = PrimitiveInterface::from(apply->fn());
if (pm) {
llvm_closure = this->codegen_primitive(pm);
llvm_closure = this->codegen_primitive_closure(pm, ir_builder);
/* hint, when available. use faster alternative to IRBuilder::CreateCall below */
intrinsic = pm->intrinsic();
}
@ -442,69 +453,93 @@ namespace xo {
/* function type in apply node's function position */
TypeDescr ast_fn_td = apply->fn()->valuetype();
#ifdef NOT_USING_DEBUG
cerr << "MachPipeline::codegen_apply: fn:" << endl;
fn->print(llvm::errs());
cerr << endl;
#endif
if (log) {
log("MachPipeline::codegen_apply: fn in apply pos...");
llvm_closure->print(llvm::errs());
log("...done");
log("llvm type...");
llvm_closure->getType()->dump();
log("...done");
}
/* checks here will be redundant */
#ifdef REDUNDANT_TYPECHECK
if (apply->argv().size() != ast_fn_td->n_fn_arg()) {
cerr << "MachPipeline::codegen_apply: error: callee f expecting n1 args where n2 supplied"
//<< xtag("f", ast_fn->name())
<< xtag("n1", ast_fn_td->n_fn_arg())
<< xtag("n2", apply->argv().size())
<< endl;
return nullptr;
}
/** also check argument types **/
for (size_t i = 0, n = ast_fn_td->n_fn_arg(); i < n; ++i) {
if (apply->argv()[i]->valuetype() != ast_fn_td->fn_arg(i)) {
cerr << "MachPipeline::codegen_apply: error: callee F for arg# I seeeing U instead of expected T"
<< xtag("F", apply->fn())
<< xtag("I", i)
<< xtag("U", apply->argv()[i]->valuetype()->short_name())
<< xtag("T", ast_fn_td->fn_arg(i)->short_name())
<< endl;
return nullptr;
}
}
#endif
#ifdef OBSOLETE
llvm::StructType * closure_lvtype
= type2llvm::function_td_to_closureapi_lvtype(llvm_cx_,
ast_fn_td,
"" /*name - not required*/);
#endif
llvm::Value * lv_fnptr = nullptr;
{
#ifdef MAYBE_VERBOSE
llvm::Value * i0_slot
= llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(),
llvm::APInt(32 /*bits*/, 0 /*value*/));
llvm::Value * fnptr_slot
= llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(),
llvm::APInt(32 /*bits*/, 0 /*value*/));
std::array<llvm::Value*, 1> index_v = {{fnptr_slot /*fnptr slot = closure[0]*/}};
std::array<llvm::Value*, 2> index_v
= {{i0_slot,
fnptr_slot /*fnptr slot = closure[0]*/}};
lv_fnptr = ir_builder.CreateInBoundsGEP(closure_lvtype,
llvm_closure,
index_v);
llvm::Value * lv_fnptr_addr
= ir_builder.CreateInBoundsGEP(llvm_closure->getType(), //closure_lvtype,
llvm_closure,
index_v);
llvm::Type * fnptr_lvtype
= type2llvm::function_td_to_llvm_fnptr_type(llvm_cx_,
apply->fn()->valuetype(),
true /*wrapper_flag*/);
/* the thing we're going to call */
lv_fnptr = ir_builder.CreateLoad(fnptr_lvtype, lv_fnptr_addr);
#endif
std::array<unsigned int, 1> index_v = {{ 0 }};
//ir_builder.CreateExtractValue(Value *Agg, ArrayRef<unsigned int> Idxs)
lv_fnptr = ir_builder.CreateExtractValue(llvm_closure,
index_v,
"fnptr");
}
llvm::Value * lv_fnenvptr = nullptr;
{
#ifdef MAYBE_VERBOSE
llvm::Value * i0_slot
= llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(),
llvm::APInt(32 /*bits*/, 0 /*value*/));
llvm::Value * envptr_slot
= llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(),
llvm::APInt(32 /*bits*/, 1 /*value*/));
std::array<llvm::Value*, 1> index_v = {{envptr_slot /*envptr slot = closure[1]*/}};
std::array<llvm::Value*, 2> index_v
= {{i0_slot,
envptr_slot /*envptr slot = closure[1]*/}};
lv_fnenvptr = ir_builder.CreateInBoundsGEP(closure_lvtype,
llvm_closure,
index_v);
llvm::Value * lv_fnenvptr_addr
= ir_builder.CreateInBoundsGEP(llvm_closure->getType(), //closure_lvtype,
llvm_closure,
index_v);
llvm::Type * fnenvptr_lvtype
= type2llvm::env_api_llvm_ptr_type(llvm_cx_);
lv_fnenvptr = ir_builder.CreateLoad(fnenvptr_lvtype, lv_fnenvptr_addr);
#endif
std::array<unsigned int, 1> index_v = {{ 1 }};
lv_fnenvptr = ir_builder.CreateExtractValue(llvm_closure,
index_v,
"envptr");
}
std::vector<llvm::Value *> args;
@ -524,8 +559,13 @@ namespace xo {
if (log) {
/* TODO: print helper for llvm::Value* */
std::string llvm_value_str;
llvm::raw_string_ostream ss(llvm_value_str);
arg->print(ss);
if (arg) {
llvm::raw_string_ostream ss(llvm_value_str);
arg->print(ss);
} else {
llvm_value_str = "<null llvm::Value>";
}
log(xtag("i_arg", i),
xtag("arg", llvm_value_str));
@ -533,6 +573,14 @@ namespace xo {
args.push_back(arg);
++i;
if (!arg) {
cerr << "MachPipeline::codegen_apply: failed for i'th argument"
<< xtag("i", i)
<< endl;
return nullptr;
}
}
/* if we have an intrinsic hint,
@ -571,9 +619,9 @@ namespace xo {
}
llvm::FunctionType * llvm_fn_type
= type2llvm::function_td_to_llvm_type(this->llvm_cx_,
ast_fn_td,
true /*wrapper_flag*/);
= type2llvm::function_td_to_lvtype(this->llvm_cx_,
ast_fn_td,
true /*wrapper_flag*/);
return ir_builder.CreateCall(llvm_fn_type,
lv_fnptr,
@ -623,13 +671,13 @@ namespace xo {
* Note that this argument is not present in lambda,
* so we need care. lambda->fn_arg(i) -> lvfn->arg [i+1]
*/
llvm::FunctionType * llvm_fn_type
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(),
lambda->valuetype(),
true /*wrapper_flag*/);
llvm::FunctionType * fn_lvtype
= type2llvm::function_td_to_lvtype(llvm_cx_.borrow(),
lambda->valuetype(),
true /*wrapper_flag*/);
/* create (initially empty) function */
fn = llvm::Function::Create(llvm_fn_type,
fn = llvm::Function::Create(fn_lvtype,
llvm::Function::ExternalLinkage,
lambda->name(),
llvm_module_.get());
@ -660,7 +708,7 @@ namespace xo {
llvm::Function *
MachPipeline::codegen_lambda_defn(ref::brw<Lambda> lambda,
llvm::IRBuilder<> & ir_builder)
llvm::IRBuilder<> & /*ir_builder*/)
{
constexpr bool c_debug_flag = true;
@ -690,14 +738,19 @@ namespace xo {
auto block = llvm::BasicBlock::Create(llvm_cx_->llvm_cx_ref(), "entry", llvm_fn);
ir_builder.SetInsertPoint(block);
/* since we need to explictly set builder's insert point,
* make a new builder instead of disturbing the top-to-bottom flow of the
* called ir_builder
*/
llvm::IRBuilder<> tmp_ir_builder(llvm_cx_->llvm_cx_ref());
tmp_ir_builder.SetInsertPoint(block);
/** Actual parameters will need their own activation record.
* Track its shape + setup/teardown here.
**/
this->env_stack_.push(activation_record(lambda.get()));
bool ok_flag = this->env_stack_.top().bind_locals(llvm_cx_, llvm_fn, ir_builder);
bool ok_flag = this->env_stack_.top().bind_locals(llvm_cx_, llvm_fn, tmp_ir_builder);
if (!ok_flag) {
this->env_stack_.pop();
@ -706,11 +759,11 @@ namespace xo {
llvm::Value * retval = this->codegen(lambda->body(),
envptr,
ir_builder);
tmp_ir_builder);
if (retval) {
/* completes the function.. */
ir_builder.CreateRet(retval);
tmp_ir_builder.CreateRet(retval);
/* validate! always validate! */
llvm::verifyFunction(*llvm_fn);
@ -742,7 +795,9 @@ namespace xo {
this->env_stack_.pop();
log && log("after pop, env stack size Z", xtag("Z", env_stack_.size()));
log && log("after pop, env stack size Z",
xtag("Z", env_stack_.size()),
xtag("llvm_fn", (void*)llvm_fn));
return llvm_fn;
} /*codegen_lambda_defn*/
@ -755,14 +810,21 @@ namespace xo {
llvm::StructType * closure_lvtype
= type2llvm::create_closureapi_lvtype(llvm_cx_.borrow(), lambda);
llvm::Function * lvfn = codegen_lambda_decl(lambda);
llvm::Function * lvfn = codegen_lambda_defn(lambda, ir_builder);
if (!lvfn) {
cerr << "MachPipeline::codegen_lambda_closure: codegen lambda failed"
<< endl;
return nullptr;
}
llvm::Value * lv_closure = nullptr;
lv_closure = ir_builder.CreateInsertValue(llvm::UndefValue::get(closure_lvtype),
lvfn, {0}, "lmfnptr" /*name*/);
lv_closure = ir_builder.CreateInsertValue(lv_closure,
envptr, {1}, "envptr" /*name*/);
{
lv_closure = ir_builder.CreateInsertValue(llvm::UndefValue::get(closure_lvtype),
lvfn, {0}); //, "lmfnptr" /*name*/);
lv_closure = ir_builder.CreateInsertValue(lv_closure,
envptr, {1}, "closure" /*name*/);
}
return lv_closure;
} /*codegen_lambda_closure*/
@ -836,44 +898,45 @@ namespace xo {
when_false_bb);
/* populate when_true_bb */
ir_builder.SetInsertPoint(when_true_bb);
llvm::IRBuilder<> tmp_ir_builder(llvm_cx_->llvm_cx_ref());
tmp_ir_builder.SetInsertPoint(when_true_bb);
llvm::Value * when_true_ir = this->codegen(expr->when_true(),
envptr,
ir_builder);
tmp_ir_builder);
if (!when_true_ir)
return nullptr;
/* at end of when-true sequence, jump to merge suffix */
ir_builder.CreateBr(merge_bb);
tmp_ir_builder.CreateBr(merge_bb);
/* note: codegen for expr->when_true() may have altered builder's "current block" */
when_true_bb = ir_builder.GetInsertBlock();
when_true_bb = tmp_ir_builder.GetInsertBlock();
/* populate when_false_bb */
parent_fn->insert(parent_fn->end(), when_false_bb);
ir_builder.SetInsertPoint(when_false_bb);
tmp_ir_builder.SetInsertPoint(when_false_bb);
llvm::Value * when_false_ir = this->codegen(expr->when_false(),
envptr,
ir_builder);
tmp_ir_builder);
if (!when_false_ir)
return nullptr;
/* at end of when-false sequence, jump to merge suffix */
ir_builder.CreateBr(merge_bb);
tmp_ir_builder.CreateBr(merge_bb);
/* note: codegen for expr->when_false() may have altered builder's "current block" */
when_false_bb = ir_builder.GetInsertBlock();
when_false_bb = tmp_ir_builder.GetInsertBlock();
/* merged suffix sequence */
parent_fn->insert(parent_fn->end(), merge_bb);
ir_builder.SetInsertPoint(merge_bb);
tmp_ir_builder.SetInsertPoint(merge_bb);
/** TODO: switch to getInt1Ty here **/
llvm::PHINode * phi_node
= ir_builder.CreatePHI(llvm::Type::getDoubleTy(llvm_cx_->llvm_cx_ref()),
2 /*#of branches being merged (?)*/,
"iftmp");
= tmp_ir_builder.CreatePHI(llvm::Type::getDoubleTy(llvm_cx_->llvm_cx_ref()),
2 /*#of branches being merged (?)*/,
"iftmp");
phi_node->addIncoming(when_true_ir, when_true_bb);
phi_node->addIncoming(when_false_ir, when_false_bb);