xo-pyjit: + pycaller_store, streamlining c++ function signatures

This commit is contained in:
Roland Conybeare 2024-06-17 19:35:02 -04:00
commit c6da0fb58b
2 changed files with 99 additions and 10 deletions

View file

@ -39,14 +39,85 @@ namespace xo {
using xo::ref::unowned_ptr;
namespace py = pybind11;
/** storage for pycaller glue functions for different function signatures.
* each pycaller instance embodies captures a canonical (architecture-dependent)
* calling sequence for a C/C++ function with that signature.
**/
struct pycaller_store {
public:
/** singleton instance **/
static pycaller_store * instance() { return &s_instance; }
/** establish caller for signature @p prototype_str.
* This needs to be called at most once for each distinct signature.
*
* Although it takes module as argument, the module being used
* doesn't (shoudn't ??) matter
**/
template <typename Retval, typename... Args>
pycaller_base::factory_function_type
require_prototype(py::module & m,
const std::string & prototype_str)
{
using caller_type = pycaller<Retval, Args...>;
caller_type::declare_once(m);
/* factory function takes function pointer of type
* Retval(*)(Args...)
* and returns new instance of caller_type for that function
*/
auto ix = pycaller_map_.find(prototype_str);
auto retval = &caller_type::make;
if(ix == pycaller_map_.end())
pycaller_map_[prototype_str] = retval;
return retval;
}
/** lookup caller for signature @p prototype_str **/
pycaller_base::factory_function_type
lookup_prototype(const std::string & prototype_str) const
{
auto ix = pycaller_map_.find(prototype_str);
if (ix == pycaller_map_.end())
return nullptr;
else
return ix->second;
}
private:
static pycaller_store s_instance;
/** map prototype string to pycaller factory for that prototype.
* For example
* "double(double)" -> pycaller<double,double>()
**/
std::unordered_map<std::string,
pycaller_base::factory_function_type> pycaller_map_;
}; /*pycaller_store*/
pycaller_store
pycaller_store::s_instance;
PYBIND11_MODULE(XO_PYJIT_MODULE_NAME(), m) {
// e.g. for xo::ast::Expression
XO_PYEXPRESSION_IMPORT_MODULE(); // py::module_::import("pyexpression");
m.doc() = "pybind11 plugin for xo-jit";
pycaller<double, double>::declare_once(m);
pycaller<double, double, double>::declare_once(m);
pycaller_store::instance()
->require_prototype<double, double>(m, "double(double)");
pycaller_store::instance()
->require_prototype<double, double, double>(m, "double(double,double)");
//pycaller<double, double>::declare_once(m);
//pycaller<double, double, double>::declare_once(m);
py::class_<MachPipeline, rp<MachPipeline>>(m, "MachPipeline")
.def_static("make", &MachPipeline::make,
@ -104,25 +175,43 @@ namespace xo {
[](MachPipeline & jit, const std::string & prototype, const std::string & symbol) -> pycaller_base* {
auto llvm_addr = jit.lookup_symbol(symbol);
/* llvm doesn't know the actual function signature,
* so any function type will appear to succeed here.
* We cast to particular function type within the pycaller<..> template
*/
auto fn_addr = llvm_addr.toPtr<void(*)()>();
/* note: llvm_addr.toPtr<..> always succeeds,
* event if pointer refers to an object of incompatible type
*
* note: return value policy is for python to own the wrapper
*
* note: pycaller signatures need to have been introduced in advance
* (in practice determined at compile time,
* since they encode a function-signature-specific calling sequence)
* by calling pycaller_store::instance()->require_prototype<Retval, Args...>(prototype);
*/
auto factory = pycaller_store::instance()->lookup_prototype(prototype);
if (!factory) {
throw std::runtime_error(tostr("MachPipeline.lookup_fn: unknown function prototype",
xtag("p", prototype)));
}
return (*factory)(fn_addr);
#ifdef OBSOLETE
if((prototype == "double(double,double)") || (prototype == "double(*)(double,double)")) {
auto fn_addr = llvm_addr.toPtr<double(*)(double,double)>();
return new pycaller<double, double, double>(fn_addr);
//return new XferDblDbl2DblFn(fn_addr);
} else if ((prototype == "double(double)") || (prototype == "double(*)(double)")) {
auto fn_addr = llvm_addr.toPtr<double(*)(double)>();
return new pycaller<double, double>(fn_addr);
} else {
throw std::runtime_error(tostr("MachPipeline.lookup_fn: unknown function prototype",
xtag("p", prototype)));
}})
}
#endif
})
;