diff --git a/src/pyjit/pyjit.cpp b/src/pyjit/pyjit.cpp index 624f8882..2c76f219 100644 --- a/src/pyjit/pyjit.cpp +++ b/src/pyjit/pyjit.cpp @@ -7,15 +7,25 @@ #include namespace xo { - struct XferFn : public ref::Refcount { + struct XferDbl2DblFn : public ref::Refcount { using fptr_type = double (*) (double); - explicit XferFn(fptr_type fptr) : fptr_{fptr} {} + explicit XferDbl2DblFn(fptr_type fptr) : fptr_{fptr} {} double operator() (double x) { return (*fptr_)(x); } fptr_type fptr_; - }; + }; /*XferDbl2DblFn*/ + + struct XferDblDbl2DblFn : public ref::Refcount { + using fptr_type = double (*) (double, double); + + explicit XferDblDbl2DblFn(fptr_type fptr) : fptr_{fptr} {} + + double operator() (double x, double y) { return (*fptr_)(x, y); } + + fptr_type fptr_; + }; /*XferDblDbl2DblFn*/ namespace jit { using xo::ast::Expression; @@ -56,19 +66,36 @@ namespace xo { .def("machgen_current_module", &MachPipeline::machgen_current_module, py::doc("Make current module available for execution via the jit.\n" "Adds all functions generated since last call to this method.")) - .def("lookup_dbl_dbl_fn", + /* double -> double */ + .def("lookup_dbl2dbl_fn", [](MachPipeline & jit, const std::string & symbol) { auto llvm_addr = jit.lookup_symbol(symbol); auto fn_addr = llvm_addr.toPtr(); - return new XferFn(fn_addr); + return new XferDbl2DblFn(fn_addr); + }) + + /* (double x double) -> double */ + .def("lookup_dbldbl2dbl_fn", + [](MachPipeline & jit, const std::string & symbol) { + auto llvm_addr = jit.lookup_symbol(symbol); + + auto fn_addr = llvm_addr.toPtr(); + + return new XferDblDbl2DblFn(fn_addr); }) ; - py::class_>(m, "XferFn") + + py::class_>(m, "XferDbl2DblFn") .def("__call__", - [](XferFn & self, double x) { return self(x); } + [](XferDbl2DblFn & self, double x) { return self(x); } + ) + ; + py::class_>(m, "XferDblDbl2DblFn") + .def("__call__", + [](XferDblDbl2DblFn & self, double x, double y) { return self(x, y); } ) ;