diff --git a/example/ex1/ex1.cpp b/example/ex1/ex1.cpp index bad239be..d86b77f4 100644 --- a/example/ex1/ex1.cpp +++ b/example/ex1/ex1.cpp @@ -4,6 +4,8 @@ #include "xo/expression/Constant.hpp" #include "xo/expression/Primitive.hpp" #include "xo/expression/Apply.hpp" +#include "xo/expression/Lambda.hpp" +#include "xo/expression/Variable.hpp" #include int @@ -12,6 +14,8 @@ main() { using xo::ast::make_constant; using xo::ast::make_primitive; using xo::ast::make_apply; + using xo::ast::make_var; + using xo::ast::make_lambda; using xo::xtag; using std::cerr; using std::endl; @@ -75,6 +79,35 @@ main() { << endl; } } + + { + /* (lambda (x) (sin (cos x))) */ + + auto sin = make_primitive("sin", ::sin); + auto cos = make_primitive("cos", ::cos); + + auto x_var = make_var("x"); + auto call1 = make_apply(cos, x_var); /* (cos x) */ + auto call2 = make_apply(sin, call1); /* (sin (cos x)) */ + + /* (define (lm_1 x) (sin (cos x))) */ + auto lambda = make_lambda("lm_1", + {"x"}, + call2); + + auto llvm_ircode = jit->codegen(lambda); + + if (llvm_ircode) { + /* note: llvm:errs() is 'raw stderr stream' */ + cerr << "ex1 llvm_ircode:" << endl; + llvm_ircode->print(llvm::errs()); + cerr << endl; + } else { + cerr << "ex1: code generation failed" + << xtag("expr", lambda) + << endl; + } + } } /** end ex1.cpp **/ diff --git a/include/xo/jit/Jit.hpp b/include/xo/jit/Jit.hpp index 7e32fb26..4c1f1534 100644 --- a/include/xo/jit/Jit.hpp +++ b/include/xo/jit/Jit.hpp @@ -12,6 +12,8 @@ #include "xo/expression/ConstantInterface.hpp" #include "xo/expression/PrimitiveInterface.hpp" #include "xo/expression/Apply.hpp" +#include "xo/expression/Lambda.hpp" +#include "xo/expression/Variable.hpp" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/BasicBlock.h" @@ -52,6 +54,8 @@ namespace xo { llvm::Value * codegen_constant(ref::brw expr); llvm::Function * codegen_primitive(ref::brw expr); llvm::Value * codegen_apply(ref::brw expr); + llvm::Function * codegen_lambda(ref::brw expr); + llvm::Value * codegen_variable(ref::brw var); llvm::Value * codegen(ref::brw expr); @@ -72,6 +76,13 @@ namespace xo { * - function names are unique within a module. **/ std::unique_ptr llvm_module_; + + /** map global names to functions/variables **/ + std::map> global_env_; + /** map variable names (formal parameters) to + * corresponding llvm interactor + **/ + std::map nested_env_; }; } /*namespace jit*/ } /*namespace xo*/ diff --git a/src/jit/Jit.cpp b/src/jit/Jit.cpp index 6bf39631..d9fc0f61 100644 --- a/src/jit/Jit.cpp +++ b/src/jit/Jit.cpp @@ -7,6 +7,8 @@ namespace xo { using xo::ast::Expression; using xo::ast::ConstantInterface; using xo::ast::PrimitiveInterface; + using xo::ast::Lambda; + using xo::ast::Variable; using xo::ast::Apply; using xo::reflect::TypeDescr; using std::cerr; @@ -133,6 +135,88 @@ namespace xo { } } /*codegen_apply*/ + llvm::Function * + Jit::codegen_lambda(ref::brw lambda) + { + /* reminder! this is the *expression*, not the *closure* */ + + global_env_[lambda->name()] = lambda.get(); + + /* do we already know a function with this name? */ + auto * fn = llvm_module_->getFunction(lambda->name()); + + if (fn) { + /** function with this name already defined?? **/ + return nullptr; + } + + /* establish prototype for this function */ + + // PLACEHOLDER + // just make prototype for function :: double -> double + + std::vector double_v(1, llvm::Type::getDoubleTy(*llvm_cx_)); + + auto * llvm_fn_type = llvm::FunctionType::get(llvm::Type::getDoubleTy(*llvm_cx_), + double_v, + false /*!varargs*/); + + /* create (initially empty) function */ + fn = llvm::Function::Create(llvm_fn_type, + llvm::Function::ExternalLinkage, + lambda->name(), + llvm_module_.get()); + /* also capture argument names */ + int i = 0; + for (auto & arg : fn->args()) + arg.setName(lambda->argv().at(i)); + + /* generate function body */ + + auto block = llvm::BasicBlock::Create(*llvm_cx_, "entry", fn); + + llvm_ir_builder_->SetInsertPoint(block); + + /* formal parameters need to appear in named_value_map_ */ + nested_env_.clear(); + for (auto & arg : fn->args()) + nested_env_[std::string(arg.getName())] = &arg; + + llvm::Value * retval = this->codegen(lambda->body()); + + if (retval) { + /* completes the function.. */ + llvm_ir_builder_->CreateRet(retval); + + /* validate! always validate! */ + llvm::verifyFunction(*fn); + + /* optimize! */ + // thefpm->run(*fn, *thefam); + + return fn; + } + + /* oops, something went wrong */ + fn->eraseFromParent(); + + return nullptr; + } /*codegen_lambda*/ + + llvm::Value * + Jit::codegen_variable(ref::brw var) + { + auto ix = nested_env_.find(var->name()); + + if (ix == nested_env_.end()) { + cerr << "Jit::codegen_variable: no binding for variable x" + << xtag("x", var->name()) + << endl; + } + + return ix->second; + } /*codegen_variable*/ + llvm::Value * Jit::codegen(ref::brw expr) { @@ -143,7 +227,10 @@ namespace xo { return this->codegen_primitive(PrimitiveInterface::from(expr)); case exprtype::apply: return this->codegen_apply(Apply::from(expr)); - break; + case exprtype::lambda: + return this->codegen_lambda(Lambda::from(expr)); + case exprtype::variable: + return this->codegen_variable(Variable::from(expr)); case exprtype::invalid: case exprtype::n_expr: return nullptr;