diff --git a/CMakeLists.txt b/CMakeLists.txt index 094529c0..3e1f3fb4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,7 @@ xo_export_cmake_config(${PROJECT_NAME} ${PROJECT_VERSION} ${PROJECT_NAME}Targets # ---------------------------------------------------------------- add_subdirectory(example) -#add_subdirectory(utest) +add_subdirectory(utest) # ---------------------------------------------------------------- # docs targets depend on all the other library/utest targets diff --git a/utest/CMakeLists.txt b/utest/CMakeLists.txt new file mode 100644 index 00000000..dbb64642 --- /dev/null +++ b/utest/CMakeLists.txt @@ -0,0 +1,15 @@ +# xo-jit/utest/CMakeLists.txt + +set(SELF_EXE utest.jit) +set(SELF_SRCS + jit_utest_main.cpp + MachPipeline.test.cpp +) + +if (ENABLE_TESTING) + xo_add_utest_executable(${SELF_EXE} ${SELF_SRCS}) + xo_self_dependency(${SELF_EXE} xo_jit) + xo_external_target_dependency(${SELF_EXE} Catch2 Catch2::Catch2) +endif() + +# end CMakeLists.txt diff --git a/utest/MachPipeline.test.cpp b/utest/MachPipeline.test.cpp new file mode 100644 index 00000000..072cde2a --- /dev/null +++ b/utest/MachPipeline.test.cpp @@ -0,0 +1,188 @@ +/* @file MachPipeline.test.cpp */ + +#include "xo/jit/MachPipeline.hpp" +#include "xo/expression/Primitive.hpp" +#include "xo/indentlog/scope.hpp" +#include + +namespace xo { + using xo::jit::MachPipeline; + using xo::ast::make_apply; + using xo::ast::make_var; + using xo::ast::make_primitive; + using xo::ast::llvmintrinsic; + using xo::ast::Expression; + using xo::ast::Lambda; + using xo::ast::exprtype; + using xo::reflect::Reflect; + using xo::ref::rp; + using xo::ref::brw; + using std::cerr; + using std::endl; + + namespace ut { + + /* abstract syntax tree for a function: + * def root4(x :: double) { sqrt(sqrt(x)); } + */ + rp + root4_ast() { + auto sqrt = make_primitive("sqrt", + ::sqrt, + false /*!explicit_symbol_def*/, + llvmintrinsic::fp_sqrt); + auto x_var = make_var("x", Reflect::require()); + auto call1 = make_apply(sqrt, {x_var}); + auto call2 = make_apply(sqrt, {call1}); + + auto fn_ast = make_lambda("root4", + {x_var}, + call2); + + return fn_ast; + } + + /* abstract syntax tree for a function: + * def twice(f :: double->double, x :: double) { f(f(x)); } + */ + rp + root_2x_ast() { + auto root = make_primitive("sqrt", + ::sqrt, + false /*!explicit_symbol_def*/, + llvmintrinsic::fp_sqrt); + + /* def twice(f :: double->double, x :: double) { f(f(x)) } */ + auto f_var = make_var("f", Reflect::require()); + auto x_var = make_var("x", Reflect::require()); + auto call1 = make_apply(f_var, {x_var}); /* (f x) */ + auto call2 = make_apply(f_var, {call1}); /* (f (f x)) */ + + /* def twice(f :: double->double, x :: double) { f(f(x)); } */ + auto twice = make_lambda("twice", + {f_var, x_var}, + call2); + + auto x2_var = make_var("x2", Reflect::require()); + auto call3 = make_apply(twice, {root, x2_var}); + + /* def root4(x2 :: double) { + * def twice(f :: double->double, x :: double) { f(f(x)); }}; + * twice(sqrt, x2) + * } + */ + auto fn_ast = make_lambda("root_2x", + {x2_var}, + call3); + + return fn_ast; + } + + struct TestCase { + rp (*make_ast_)(); + /* each pair is (input, output) for function double->double */ + std::vector> call_v_; + }; + + static std::vector + s_testcase_v = { + {&root4_ast, + {std::make_pair(1.0, 1.0), + std::make_pair(16.0, 2.0), + std::make_pair(81.0, 3.0)}}, + {&root_2x_ast, + {std::make_pair(1.0, 1.0), + std::make_pair(16.0, 2.0), + std::make_pair(81.0, 3.0)}} + }; + + /** testcase root_ast tests: + * - nested function call + * + * testcase root_2x_ast relies on: + * - lambda in function position + * - argument with function type + **/ + TEST_CASE("machpipeline.fptr", "[llvm][llvm_fnptr]") { + constexpr bool c_debug_flag = true; + + // can get bits from /dev/random by uncommenting the 2nd line below + //uint64_t seed = xxx; + //rng::Seed seed; + + //auto rng = xo::rng::xoshiro256ss(seed); + + scope log(XO_DEBUG2(c_debug_flag, "TEST_CASE.machpipeline")); + //log && log("(A)", xtag("foo", foo)); + + auto jit = MachPipeline::make(); + + for (std::size_t i_tc = 0, n_tc = s_testcase_v.size(); i_tc < n_tc; ++i_tc) { + TestCase const & testcase = s_testcase_v[i_tc]; + + INFO(tostr(xtag("i_tc", i_tc))); + + auto ast = (*testcase.make_ast_)(); + + REQUIRE(ast.get()); + + log && log(xtag("ast", ast)); + + REQUIRE(ast->extype() == exprtype::lambda); + + brw fn_ast = Lambda::from(ast); + + llvm::Value * llvm_ircode = jit->codegen_toplevel(fn_ast); + + /* TODO: printer for llvm::Value* */ + if (llvm_ircode) { + /* note: llvm:errs() is 'raw stderr stream' */ + cerr << "llvm_ircode:" << endl; + llvm_ircode->print(llvm::errs()); + cerr << endl; + } else { + cerr << "code generation failed" + << xtag("fn_ast", fn_ast) + << endl; + } + + REQUIRE(llvm_ircode); + + jit->machgen_current_module(); + + /** lookup compiled function pointer in jit **/ + auto llvm_addr = jit->lookup_symbol(fn_ast->name()); + + if (!llvm_addr) { + cerr << "ex2: lookup: symbol not found" + << xtag("symbol", fn_ast->name()) + << endl; + } else { + cerr << "ex2: lookup: symbol found" + << xtag("llvm_addr", llvm_addr.get().getValue()) + << xtag("symbol", fn_ast->name()) + << endl; + } + + auto fn_ptr = llvm_addr.get().toPtr(); + + REQUIRE(fn_ptr); + + for (std::size_t j_call = 0, n_call = testcase.call_v_.size(); j_call < n_call; ++j_call) { + double input = testcase.call_v_[j_call].first; + double expected = testcase.call_v_[j_call].second; + + INFO(tostr(xtag("j_call", j_call), xtag("input", input), xtag("expected", expected))); + + auto actual = (*fn_ptr)(input); + + REQUIRE(actual == expected); + } + } + + } /*TEST_CASE(machpipeline)*/ + } /*namespace ut*/ +} /*namespace xo*/ + + +/* end MachPipeline.test.cpp */ diff --git a/utest/jit_utest_main.cpp b/utest/jit_utest_main.cpp new file mode 100644 index 00000000..1bea47b1 --- /dev/null +++ b/utest/jit_utest_main.cpp @@ -0,0 +1,6 @@ +/* @file jit_utest_main.cpp */ + +#define CATCH_CONFIG_MAIN +#include + +/* end jit_utest_main.cpp */