xo-pyexpression: + Primitive + PrimitiveInterface + examples

This commit is contained in:
Roland Conybeare 2024-06-14 16:11:37 -04:00
commit 27a0dd591f

View file

@ -5,7 +5,10 @@
#include "xo/expression/Expression.hpp" #include "xo/expression/Expression.hpp"
#include "xo/expression/ConstantInterface.hpp" #include "xo/expression/ConstantInterface.hpp"
#include "xo/expression/Constant.hpp" #include "xo/expression/Constant.hpp"
#include "xo/expression/PrimitiveInterface.hpp"
#include "xo/expression/Primitive.hpp"
#include "xo/pyutil/pyutil.hpp" #include "xo/pyutil/pyutil.hpp"
#include <cmath>
namespace xo { namespace xo {
namespace ast { namespace ast {
@ -13,6 +16,9 @@ namespace xo {
using xo::ast::Expression; using xo::ast::Expression;
using xo::ast::ConstantInterface; using xo::ast::ConstantInterface;
using xo::ast::Constant; using xo::ast::Constant;
using xo::ast::PrimitiveInterface;
using xo::ast::Primitive;
using xo::ast::make_primitive;
using xo::reflect::TaggedPtr; using xo::reflect::TaggedPtr;
using xo::ref::rp; using xo::ref::rp;
namespace py = pybind11; namespace py = pybind11;
@ -38,6 +44,8 @@ namespace xo {
.def("__repr__", &Expression::display_string); .def("__repr__", &Expression::display_string);
; ;
// ----- Constants -----
py::class_<ConstantInterface, py::class_<ConstantInterface,
Expression, Expression,
rp<ConstantInterface>>(m, "ConstantInterface") rp<ConstantInterface>>(m, "ConstantInterface")
@ -69,6 +77,31 @@ namespace xo {
py::doc("make_constant(x) creates constant expression holding x [wip - only works for double")) py::doc("make_constant(x) creates constant expression holding x [wip - only works for double"))
; ;
// ----- Primitives -----
py::class_<PrimitiveInterface,
Expression,
rp<PrimitiveInterface>>(m, "PrimitiveInterface")
.def("name", &PrimitiveInterface::name,
py::doc("name of this primitive function; use this name to invoke the function"))
.def("n_arg", &PrimitiveInterface::n_arg,
py::doc("number of arguments to this function (not counting return value)"))
;
using Fn_dbl_dbl_type = double (*)(double);
m.def("make_sqrt_pm", []() { return make_primitive<Fn_dbl_dbl_type>("sqrt", sqrt); },
py::doc("create primitive representing the ::sqrt() function"));
m.def("make_sin_pm", []() { return make_primitive<Fn_dbl_dbl_type>("sin", ::sin); },
py::doc("create primitive representing the ::sin() function"));
m.def("make_cos_pm", []() { return make_primitive<Fn_dbl_dbl_type>("cos", ::cos); },
py::doc("create primitive representing the ::cos() function"));
py::class_<Primitive<double (*)(double)>,
PrimitiveInterface,
rp<Primitive<double (*)(double)>>>(m, "Primitive_double_double")
;
} /*pyexpresion*/ } /*pyexpresion*/
} /*namespace ast*/ } /*namespace ast*/
} /*namespace xo*/ } /*namespace xo*/