xo-jit: refactor: + closures [wip: not tested]

This commit is contained in:
Roland Conybeare 2024-07-08 18:31:06 -04:00
commit 659c0c400b
5 changed files with 376 additions and 280 deletions

View file

@ -361,8 +361,29 @@ namespace xo {
return wrap_lvfn;
} /*codegen_primitive_wrapper*/
/* Build the closure value that stands for a primitive.
 *
 * The result is an aggregate of the type returned by
 * type2llvm::create_closureapi_lvtype() for this primitive:
 *   slot 0: function returned by codegen_primitive_wrapper()
 *   slot 1: null environment pointer
 *
 * expr        primitive to represent as a closure
 * ir_builder  forwarded to codegen_primitive_wrapper(),
 *             and used to build the two insertvalue steps
 *
 * NOTE(review): result of codegen_primitive_wrapper() is used without
 * a null check -- confirm it cannot fail here.
 */
llvm::Value *
MachPipeline::codegen_primitive_closure(ref::brw<xo::ast::PrimitiveInterface> expr,
                                        llvm::IRBuilder<> & ir_builder)
{
    llvm::StructType * closure_type
        = type2llvm::create_closureapi_lvtype(llvm_cx_.borrow(), expr);

    llvm::Function * wrapper_fn
        = codegen_primitive_wrapper(expr, ir_builder);

    llvm::Value * null_env
        = llvm::ConstantPointerNull::get(type2llvm::env_api_llvm_ptr_type(llvm_cx_));

    /* start from an undefined aggregate, then fill in both slots in order */
    llvm::Value * blank_closure = llvm::UndefValue::get(closure_type);

    llvm::Value * with_fnptr
        = ir_builder.CreateInsertValue(blank_closure,
                                       wrapper_fn, {0}, "wrapfnptr" /*name*/);

    return ir_builder.CreateInsertValue(with_fnptr,
                                        null_env, {1}, "nullenvptr" /*name*/);
} /*codegen_primitive_closure*/
llvm::Value *
MachPipeline::codegen_apply(ref::brw<Apply> apply,
llvm::Value * envptr,
llvm::IRBuilder<> & ir_builder)
{
constexpr bool c_debug_flag = true;
@ -376,14 +397,14 @@ namespace xo {
using std::cerr;
using std::endl;
/* IR for value in function position.
* Although it will generate a function (or pointer-to-function),
* it need not have inherited type llvm::Function.
/* IR for closure in function position
* see:
* - MachPipeline::codegen_primitive_closure
* - MachPipeline::codegen_lambda_closure
* - type2llvm::create_closure_lvtype
*/
llvm::Value * llvm_fnval = nullptr;
llvm::Value * llvm_closure = nullptr;
llvmintrinsic intrinsic = llvmintrinsic::invalid;
/* function type in apply node's function position */
TypeDescr ast_fn_td = apply->fn()->valuetype();
{
/* special treatment for primitive in apply position:
* allows substituting LLVM intrinsic
@ -392,12 +413,12 @@ namespace xo {
auto pm = PrimitiveInterface::from(apply->fn());
if (pm) {
llvm_fnval = this->codegen_primitive(pm);
llvm_closure = this->codegen_primitive(pm);
/* hint, when available. use faster alternative to IRBuilder::CreateCall below */
intrinsic = pm->intrinsic();
}
} else {
llvm_fnval = this->codegen(apply->fn(), ir_builder);
llvm_closure = this->codegen(apply->fn(), envptr, ir_builder);
/* we don't need any special checking here.
* already know (from xo-level checking) that pointer has the right type.
@ -414,10 +435,13 @@ namespace xo {
}
}
if (!llvm_fnval) {
if (!llvm_closure) {
return nullptr;
}
/* function type in apply node's function position */
TypeDescr ast_fn_td = apply->fn()->valuetype();
#ifdef NOT_USING_DEBUG
cerr << "MachPipeline::codegen_apply: fn:" << endl;
fn->print(llvm::errs());
@ -452,12 +476,50 @@ namespace xo {
}
#endif
llvm::StructType * closure_lvtype
= type2llvm::function_td_to_closureapi_lvtype(llvm_cx_,
ast_fn_td,
"" /*name - not required*/);
llvm::Value * lv_fnptr = nullptr;
{
llvm::Value * fnptr_slot
= llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(),
llvm::APInt(32 /*bits*/, 0 /*value*/));
std::array<llvm::Value*, 1> index_v = {{fnptr_slot /*fnptr slot = closure[0]*/}};
lv_fnptr = ir_builder.CreateInBoundsGEP(closure_lvtype,
llvm_closure,
index_v);
}
llvm::Value * lv_fnenvptr = nullptr;
{
llvm::Value * envptr_slot
= llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(),
llvm::APInt(32 /*bits*/, 1 /*value*/));
std::array<llvm::Value*, 1> index_v = {{envptr_slot /*envptr slot = closure[1]*/}};
lv_fnenvptr = ir_builder.CreateInBoundsGEP(closure_lvtype,
llvm_closure,
index_v);
}
std::vector<llvm::Value *> args;
args.reserve(apply->argv().size());
/* +1 for envptr */
args.reserve(1 + apply->argv().size());
/* we must take envptr from closure,
* and we need to do this using some version of getelementptr
*/
args.push_back(lv_fnenvptr);
int i = 0;
for (const auto & arg_expr : apply->argv()) {
auto * arg = this->codegen(arg_expr, ir_builder);
auto * arg = this->codegen(arg_expr, envptr, ir_builder);
if (log) {
/* TODO: print helper for llvm::Value* */
@ -476,26 +538,28 @@ namespace xo {
/* if we have an intrinsic hint,
* then instead of invoking a function,
* we use some native machine instruction instead.
*
* args[0] not used here, that holds envptr from faux closure
*/
switch(intrinsic) {
case llvmintrinsic::i_neg:
return ir_builder.CreateNeg(args[0]);
return ir_builder.CreateNeg(args[1]);
case llvmintrinsic::i_add:
return ir_builder.CreateAdd(args[0], args[1]);
return ir_builder.CreateAdd(args[1], args[2]);
case llvmintrinsic::i_sub:
return ir_builder.CreateSub(args[0], args[1]);
return ir_builder.CreateSub(args[1], args[2]);
case llvmintrinsic::i_mul:
return ir_builder.CreateMul(args[0], args[1]);
return ir_builder.CreateMul(args[1], args[2]);
case llvmintrinsic::i_sdiv:
return ir_builder.CreateSDiv(args[0], args[1]);
return ir_builder.CreateSDiv(args[1], args[2]);
case llvmintrinsic::i_udiv:
return ir_builder.CreateUDiv(args[0], args[1]);
return ir_builder.CreateUDiv(args[1], args[2]);
case llvmintrinsic::fp_add:
return ir_builder.CreateFAdd(args[0], args[1]);
return ir_builder.CreateFAdd(args[1], args[2]);
case llvmintrinsic::fp_mul:
return ir_builder.CreateFMul(args[0], args[1]);
return ir_builder.CreateFMul(args[1], args[2]);
case llvmintrinsic::fp_div:
return ir_builder.CreateFDiv(args[0], args[1]);
return ir_builder.CreateFDiv(args[1], args[2]);
case llvmintrinsic::invalid:
case llvmintrinsic::fp_sqrt:
case llvmintrinsic::fp_pow:
@ -506,65 +570,18 @@ namespace xo {
break;
}
/* At least as of 18.1.5, LLVM needs us to supply function type
* when making a function call. In particular it doesn't remember
* the function type with each function pointer
*/
llvm::FunctionType * llvm_fn_type
= type2llvm::function_td_to_llvm_type(this->llvm_cx_, ast_fn_td);
= type2llvm::function_td_to_llvm_type(this->llvm_cx_,
ast_fn_td,
true /*wrapper_flag*/);
return ir_builder.CreateCall(llvm_fn_type,
llvm_fnval,
lv_fnptr,
args,
"calltmp");
} /*codegen_apply*/
#ifdef OBSOLETE
/* in kaleidoscope7.cpp: CreateEntryBlockAlloca
 *
 * Create an alloca instruction for a variable, inserted at the beginning
 * of the entry block of llvm_fn.
 *
 * llvm_fn   function whose entry block receives the alloca
 * var_name  name given to the alloca instruction
 * var_type  type description, converted to an llvm type
 *           via type2llvm::td_to_llvm_type()
 *
 * Returns the new alloca instruction,
 * or nullptr if var_type could not be converted to an llvm type.
 */
llvm::AllocaInst *
MachPipeline::create_entry_block_alloca(llvm::Function * llvm_fn,
                                        const std::string & var_name,
                                        TypeDescr var_type)
{
    constexpr bool c_debug_flag = true;

    scope log(XO_DEBUG(c_debug_flag),
              xtag("llvm_fn", (void*)llvm_fn),
              xtag("var_name", var_name),
              xtag("var_type", var_type->short_name()));

    /* temporary builder positioned at the start of the entry block */
    llvm::IRBuilder<> tmp_ir_builder(&llvm_fn->getEntryBlock(),
                                     llvm_fn->getEntryBlock().begin());

    llvm::Type * llvm_var_type = type2llvm::td_to_llvm_type(llvm_cx_.borrow(),
                                                            var_type);

    log && log(xtag("addr(llvm_var_type)", (void*)llvm_var_type));

    /* must reject a failed type lookup before anything dereferences it;
     * the logging branch below calls llvm_var_type->print()
     */
    if (!llvm_var_type)
        return nullptr;

    if (log) {
        std::string llvm_var_type_str;
        llvm::raw_string_ostream ss(llvm_var_type_str);
        llvm_var_type->print(ss);

        log(xtag("llvm_var_type", llvm_var_type_str));
    }

    llvm::AllocaInst * retval = tmp_ir_builder.CreateAlloca(llvm_var_type,
                                                            nullptr,
                                                            var_name);

    log && log(xtag("alloca", (void*)retval),
               xtag("align", retval->getAlign().value()),
               xtag("size", retval->getAllocationSize(jit_->data_layout()).value()));

    return retval;
} /*create_entry_block_alloca*/
#endif
std::vector<ref::brw<Lambda>>
MachPipeline::find_lambdas(ref::brw<Expression> expr) const
{
@ -600,24 +617,40 @@ namespace xo {
/* establish prototype for this function */
/* wrapper_flag: llvm function type takes extra first argument,
* supplying environment pointer from surrounding closure.
*
* Note that this argument is not present in lambda,
* so we need care. lambda->fn_arg(i) -> lvfn->arg [i+1]
*/
llvm::FunctionType * llvm_fn_type
= type2llvm::function_td_to_llvm_type(llvm_cx_.borrow(),
lambda->valuetype());
lambda->valuetype(),
true /*wrapper_flag*/);
/* create (initially empty) function */
fn = llvm::Function::Create(llvm_fn_type,
llvm::Function::ExternalLinkage,
lambda->name(),
llvm_module_.get());
/* also capture argument names */
/* also adopt lambda's formal argument names */
{
int i = 0;
for (auto & arg : fn->args()) {
log && log("llvm formal param names",
xtag("i", i),
xtag("param", lambda->argv().at(i)));
if (i == 0) {
log && log("llvm inserted env param",
xtag("i", i));
arg.setName(".env");
} else {
log && log("llvm formal param names",
xtag("i", i),
xtag("param", lambda->argv().at(i-1)));
arg.setName(lambda->argv().at(i-1)->name());
}
arg.setName(lambda->argv().at(i)->name());
++i;
}
}
@ -648,6 +681,10 @@ namespace xo {
return nullptr;
}
/* environment for this lambda's closure
* passed as extra 1st argument
*/
llvm::Value * envptr = llvm_fn->args().begin();
/* generate function body */
@ -667,45 +704,9 @@ namespace xo {
return nullptr;
}
#ifdef OBSOLETE
{
log && log("lambda: stack size Z", xtag("Z", env_stack_.size()));
int i = 0;
for (auto & arg : llvm_fn->args()) {
log && log("nested environment",
xtag("i", i),
xtag("param", std::string(arg.getName())));
std::string arg_name = std::string(arg.getName());
/* stack location for arg[i] */
llvm::AllocaInst * alloca
= create_entry_block_alloca(llvm_fn,
arg_name,
lambda->fn_arg(i));
if (!alloca) {
this->env_stack_.pop();
return nullptr;
}
/* store on function entry
* see codegen_variable() for corresponding load
*/
ir_builder.CreateStore(&arg, alloca);
/* remember stack location for reference + assignment
* in lambda body.
*
*/
env_stack_.top().alloc_var(i, arg_name, alloca);
++i;
}
}
#endif
llvm::Value * retval = this->codegen(lambda->body(), ir_builder);
llvm::Value * retval = this->codegen(lambda->body(),
envptr,
ir_builder);
if (retval) {
/* completes the function.. */
@ -746,10 +747,33 @@ namespace xo {
return llvm_fn;
} /*codegen_lambda_defn*/
/* Build the closure value for a lambda.
 *
 * The result is an aggregate of the type returned by
 * type2llvm::create_closureapi_lvtype() for this lambda:
 *   slot 0: function returned by codegen_lambda_decl()
 *   slot 1: the supplied environment pointer
 *
 * lambda      lambda to represent as a closure
 * envptr      environment pointer stored in the closure
 * ir_builder  used to build the two insertvalue steps
 *
 * NOTE(review): result of codegen_lambda_decl() is used without
 * a null check -- confirm it cannot fail here.
 */
llvm::Value *
MachPipeline::codegen_lambda_closure(ref::brw<Lambda> lambda,
                                     llvm::Value * envptr,
                                     llvm::IRBuilder<> & ir_builder)
{
    llvm::StructType * closure_type
        = type2llvm::create_closureapi_lvtype(llvm_cx_.borrow(), lambda);

    llvm::Function * lambda_fn = codegen_lambda_decl(lambda);

    /* start from an undefined aggregate, then fill in both slots in order */
    llvm::Value * blank_closure = llvm::UndefValue::get(closure_type);

    llvm::Value * with_fnptr
        = ir_builder.CreateInsertValue(blank_closure,
                                       lambda_fn, {0}, "lmfnptr" /*name*/);

    return ir_builder.CreateInsertValue(with_fnptr,
                                        envptr, {1}, "envptr" /*name*/);
} /*codegen_lambda_closure*/
llvm::Value *
MachPipeline::codegen_variable(ref::brw<Variable> var,
llvm::Value * /*envptr*/,
llvm::IRBuilder<> & ir_builder)
{
/* TODO: navigate envptr to handle non-local variables */
if (env_stack_.empty()) {
cerr << "MachPipeline::codegen_variable: expected non-empty environment stack"
<< xtag("x", var->name())
@ -772,9 +796,11 @@ namespace xo {
} /*codegen_variable*/
llvm::Value *
MachPipeline::codegen_ifexpr(ref::brw<IfExpr> expr, llvm::IRBuilder<> & ir_builder)
MachPipeline::codegen_ifexpr(ref::brw<IfExpr> expr,
llvm::Value * envptr,
llvm::IRBuilder<> & ir_builder)
{
llvm::Value * test_ir = this->codegen(expr->test(), ir_builder);
llvm::Value * test_ir = this->codegen(expr->test(), envptr, ir_builder);
/** need test result in a variable **/
llvm::Value * test_with_cmp_ir
@ -813,6 +839,7 @@ namespace xo {
ir_builder.SetInsertPoint(when_true_bb);
llvm::Value * when_true_ir = this->codegen(expr->when_true(),
envptr,
ir_builder);
if (!when_true_ir)
@ -827,7 +854,9 @@ namespace xo {
parent_fn->insert(parent_fn->end(), when_false_bb);
ir_builder.SetInsertPoint(when_false_bb);
llvm::Value * when_false_ir = this->codegen(expr->when_false(), ir_builder);
llvm::Value * when_false_ir = this->codegen(expr->when_false(),
envptr,
ir_builder);
if (!when_false_ir)
return nullptr;
@ -852,21 +881,24 @@ namespace xo {
} /*codegen_ifexpr*/
llvm::Value *
MachPipeline::codegen(ref::brw<Expression> expr, llvm::IRBuilder<> & ir_builder)
MachPipeline::codegen(ref::brw<Expression> expr,
llvm::Value * envptr,
llvm::IRBuilder<> & ir_builder)
{
switch(expr->extype()) {
case exprtype::constant:
return this->codegen_constant(ConstantInterface::from(expr));
case exprtype::primitive:
return this->codegen_primitive(PrimitiveInterface::from(expr));
return this->codegen_primitive_closure(PrimitiveInterface::from(expr), ir_builder);
case exprtype::apply:
return this->codegen_apply(Apply::from(expr), ir_builder);
return this->codegen_apply(Apply::from(expr), envptr, ir_builder);
case exprtype::lambda:
return this->codegen_lambda_decl(Lambda::from(expr));
return this->codegen_lambda_closure(Lambda::from(expr), envptr, ir_builder);
//return this->codegen_lambda_decl(Lambda::from(expr));
case exprtype::variable:
return this->codegen_variable(Variable::from(expr), ir_builder);
return this->codegen_variable(Variable::from(expr), envptr, ir_builder);
case exprtype::ifexpr:
return this->codegen_ifexpr(IfExpr::from(expr), ir_builder);
return this->codegen_ifexpr(IfExpr::from(expr), envptr, ir_builder);
case exprtype::invalid:
case exprtype::n_expr:
return nullptr;
@ -910,6 +942,8 @@ namespace xo {
this->codegen_lambda_decl(lambda);
}
#ifdef OBSOLETE /* don't do this anymore, obscures lexical context */
/* Pass 2 */
for (auto lambda : fn_v) {
this->codegen_lambda_defn(lambda,
@ -931,6 +965,19 @@ namespace xo {
return this->codegen(expr,
*(this->llvm_toplevel_ir_builder_.get()));
}
#endif
/* 1. using nullptr as runtime representation for global environment
* 2. may have to elaborate this later? not clear to me
*/
llvm::Value * env_0ptr
= (llvm::ConstantPointerNull::get
(type2llvm::env_api_llvm_ptr_type(llvm_cx_)));
return this->codegen(expr,
env_0ptr,
*(this->llvm_toplevel_ir_builder_.get()));
} /*codegen_toplevel*/
void