diff --git a/include/xo/jit/MachPipeline.hpp b/include/xo/jit/MachPipeline.hpp index 84b01a9d..3f217db3 100644 --- a/include/xo/jit/MachPipeline.hpp +++ b/include/xo/jit/MachPipeline.hpp @@ -96,7 +96,7 @@ namespace xo { void dump_current_module(); /** lookup symbol in jit-associated output library **/ - llvm::orc::ExecutorAddr lookup_symbol(const std::string & x); + llvm::Expected lookup_symbol(const std::string & x); virtual void display(std::ostream & os) const; virtual std::string display_string() const; diff --git a/src/jit/MachPipeline.cpp b/src/jit/MachPipeline.cpp index 3192e531..e1851474 100644 --- a/src/jit/MachPipeline.cpp +++ b/src/jit/MachPipeline.cpp @@ -6,11 +6,13 @@ namespace xo { using xo::ast::exprtype; using xo::ast::Expression; using xo::ast::ConstantInterface; + using xo::ast::FunctionInterface; using xo::ast::PrimitiveInterface; using xo::ast::Lambda; using xo::ast::Variable; using xo::ast::Apply; using xo::ast::IfExpr; + using xo::reflect::Reflect; using xo::reflect::TypeDescr; using std::cerr; using std::endl; @@ -120,17 +122,51 @@ namespace xo { { TypeDescr td = expr->value_td(); - if (td->is_native()) { + if (Reflect::is_native(td)) { return llvm::ConstantFP::get(llvm_cx_->llvm_cx_ref(), llvm::APFloat(*(expr->value_tp().recover_native()))); - } else if (td->is_native()) { + } else if (Reflect::is_native(td)) { return llvm::ConstantFP::get(llvm_cx_->llvm_cx_ref(), llvm::APFloat(*(expr->value_tp().recover_native()))); + } else if (Reflect::is_native(td)) { + return llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(), + llvm::APSInt(*(expr->value_tp().recover_native()))); + } else if (Reflect::is_native(td)) { + return llvm::ConstantInt::get(llvm_cx_->llvm_cx_ref(), + llvm::APSInt(*(expr->value_tp().recover_native()))); } return nullptr; } + namespace { + llvm::Type * + td_to_llvm_type(xo::ref::brw llvm_cx, TypeDescr td) { + auto & llvm_cx_ref = llvm_cx->llvm_cx_ref(); + + if (Reflect::is_native(td)) { + return llvm::Type::getInt1Ty(llvm_cx_ref); + } else if (Reflect::is_native(td)) { + return llvm::Type::getInt8Ty(llvm_cx_ref); + } else if (Reflect::is_native(td)) { + return llvm::Type::getInt16Ty(llvm_cx_ref); + } else if (Reflect::is_native(td)) { + return llvm::Type::getInt32Ty(llvm_cx_ref); + } else if (Reflect::is_native(td)) { + return llvm::Type::getInt64Ty(llvm_cx_ref); + } else if (Reflect::is_native(td)) { + return llvm::Type::getFloatTy(llvm_cx_ref); + } else if (Reflect::is_native(td)) { + return llvm::Type::getDoubleTy(llvm_cx_ref); + } else { + cerr << "td_to_llvm_type: no llvm type available for T" + << xtag("T", td->short_name()) + << endl; + return nullptr; + } + } + } + llvm::Function * MachPipeline::codegen_primitive(ref::brw expr) { @@ -158,7 +194,7 @@ namespace xo { // PLACEHOLDER // just make prototype for function :: double^n -> double - TypeDescr fn_td = expr->value_td(); + TypeDescr fn_td = expr->valuetype(); int n_fn_arg = fn_td->n_fn_arg(); scope log(XO_DEBUG(c_debug_flag), @@ -174,18 +210,12 @@ namespace xo { log && log(xtag("i_arg", i), xtag("arg_td", arg_td->short_name())); - if (arg_td->is_native()) { - llvm_argtype_v.push_back(llvm::Type::getDoubleTy(llvm_cx_->llvm_cx_ref())); + llvm::Type * llvm_argtype = td_to_llvm_type(llvm_cx_.borrow(), arg_td); - // TODO: extend with other native types here... - } else { - cerr << "MachPipeline::codegen_primitive: error: primitive f with arg i of type T where double expected" - << xtag("f", expr->name()) - << xtag("i", i) - << xtag("T", arg_td->short_name()) - << endl; + if (!llvm_argtype) return nullptr; - } + + llvm_argtype_v.push_back(llvm_argtype); } //std::vector double_v(n_fn_arg, llvm::Type::getDoubleTy(*llvm_cx_)); @@ -194,17 +224,10 @@ namespace xo { log && log(xtag("retval_td", retval_td->short_name())); - llvm::Type * llvm_retval = nullptr; + llvm::Type * llvm_retval = td_to_llvm_type(llvm_cx_.borrow(), retval_td); - if (retval_td->is_native()) { - llvm_retval = llvm::Type::getDoubleTy(llvm_cx_->llvm_cx_ref()); - } else { - cerr << "MachPipeline::codegen_primitive: error: primitive f returning T where double expected" - << xtag("f", expr->name()) - << xtag("T", retval_td->short_name()) - << endl; + if (!llvm_retval) return nullptr; - } auto * llvm_fn_type = llvm::FunctionType::get(llvm_retval, llvm_argtype_v, @@ -247,9 +270,30 @@ namespace xo { * * For now, finesse by only handling PrimitiveInterface in function-callee position */ - if (apply->fn()->extype() == exprtype::primitive) { - auto pm = PrimitiveInterface::from(apply->fn()); - auto * fn = this->codegen_primitive(pm); + if (apply->fn()->extype() == exprtype::primitive + || apply->fn()->extype() == exprtype::lambda) + { + llvm::Function * llvm_fn = nullptr; + FunctionInterface * fn = nullptr; + { + // TODO: codgen_function() + + auto pm = PrimitiveInterface::from(apply->fn()); + if (pm) { + fn = pm.get(); + llvm_fn = this->codegen_primitive(pm); + } + + auto lm = Lambda::from(apply->fn()); + if (lm) { + fn = lm.get(); + llvm_fn = this->codegen_lambda(lm); + } + } + + if (!llvm_fn) { + return nullptr; + } #ifdef NOT_USING_DEBUG cerr << "MachPipeline::codegen_apply: fn:" << endl; @@ -257,15 +301,30 @@ namespace xo { cerr << endl; #endif - if (fn->arg_size() != apply->argv().size()) { + if (llvm_fn->arg_size() != apply->argv().size()) { cerr << "MachPipeline::codegen_apply: error: callee f expecting n1 args where n2 supplied" - << xtag("f", pm->name()) - << xtag("n1", pm->n_arg()) + << xtag("f", fn->name()) + << xtag("n1", fn->n_arg()) << xtag("n2", apply->argv().size()) << endl; + return nullptr; } + /** also check argument types **/ + for (size_t i = 0, n = fn->n_arg(); i < n; ++i) { + if (apply->argv()[i]->valuetype() != fn->fn_arg(i)) { + cerr << "MachPipeline::codegen_apply: error: callee f for arg i seeeing U instead of expected T" + << xtag("f", fn->name()) + << xtag("i", i) + << xtag("U", apply->argv()[i]->valuetype()->short_name()) + << xtag("T", fn->fn_arg(i)->short_name()) + << endl; + + return nullptr; + } + } + std::vector args; args.reserve(apply->argv().size()); @@ -281,7 +340,7 @@ namespace xo { args.push_back(arg); } - return llvm_ir_builder_->CreateCall(fn, args, "calltmp"); + return llvm_ir_builder_->CreateCall(llvm_fn, args, "calltmp"); } else { cerr << "MachPipeline::codegen_apply: error: only allowing call to known primitives at present" << endl; return nullptr; @@ -318,11 +377,16 @@ namespace xo { // PLACEHOLDER // just handle double arguments + return type for now - std::vector double_v(lambda->n_arg(), - llvm::Type::getDoubleTy(llvm_cx_->llvm_cx_ref())); + llvm::Type * llvm_retval = td_to_llvm_type(llvm_cx_.borrow(), lambda->fn_retval()); - auto * llvm_fn_type = llvm::FunctionType::get(llvm::Type::getDoubleTy(llvm_cx_->llvm_cx_ref()), - double_v, + std::vector arg_type_v(lambda->n_arg()); + + for (size_t i = 0, n = lambda->n_arg(); i < n; ++i) { + arg_type_v[i] = td_to_llvm_type(llvm_cx_.borrow(), lambda->fn_arg(i)); + } + + auto * llvm_fn_type = llvm::FunctionType::get(llvm_retval, + arg_type_v, false /*!varargs*/); /* create (initially empty) function */ @@ -336,7 +400,7 @@ namespace xo { for (auto & arg : fn->args()) { log && log("llvm format param names", xtag("i", i), xtag("param", lambda->argv().at(i))); - arg.setName(lambda->argv().at(i)); + arg.setName(lambda->argv().at(i)->name()); ++i; } } @@ -464,6 +528,7 @@ namespace xo { parent_fn->insert(parent_fn->end(), merge_bb); llvm_ir_builder_->SetInsertPoint(merge_bb); + /** TODO: switch to getInt1Ty here **/ llvm::PHINode * phi_node = llvm_ir_builder_->CreatePHI(llvm::Type::getDoubleTy(llvm_cx_->llvm_cx_ref()), 2 /*#of branches being merged (?)*/, @@ -533,18 +598,21 @@ namespace xo { this->recreate_llvm_ir_pipeline(); } - llvm::orc::ExecutorAddr + llvm::Expected MachPipeline::lookup_symbol(const std::string & sym) { static llvm::ExitOnError llvm_exit_on_err; /* llvm_sym: ExecutorSymbolDef */ - auto llvm_sym = llvm_exit_on_err(this->jit_->lookup(sym)); + auto llvm_sym_expected = this->jit_->lookup(sym); - /* llvm_addr: ExecutorAddr */ - auto llvm_addr = llvm_sym.getAddress(); + if (llvm_sym_expected) { + auto llvm_addr = llvm_sym_expected.get().getAddress(); - return llvm_addr; + return llvm_addr; + } else { + return llvm_sym_expected.takeError(); + } } /*lookup_symbol*/ void