/* @file MachPipeline.test.cpp */

/* Unit tests for xo::jit::MachPipeline:
 *  - machpipeline.fptr   : compile a lambda AST, call it through a function pointer
 *  - machpipeline.wrap   : compile the wrapper generated for a primitive
 *  - machpipeline.struct : compile a function returning a reflected struct,
 *                          and compare llvm member offsets with the c++ ones
 *
 * NOTE(review): in the listing this revision was prepared from, text between
 * angle brackets appears to have been lost (the bare '#include' below, and
 * what look like template arguments on rp, brw, Reflect::require,
 * reflect_struct, std::vector, toPtr and the casts).  Those tokens are
 * reproduced exactly as found -- confirm against the checked-in file.
 */

#include "xo/jit/MachPipeline.hpp"
#include "xo/expression/Primitive.hpp"
#include "xo/ratio/ratio.hpp"
#include "xo/ratio/ratio_reflect.hpp"
#include "xo/reflect/reflect_struct.hpp"
#include "xo/indentlog/scope.hpp"
#include

namespace xo {
    using xo::jit::MachPipeline;
    using xo::ast::make_apply;
    using xo::ast::make_var;
    using xo::ast::make_primitive;
    using xo::ast::llvmintrinsic;
    using xo::ast::Expression;
    using xo::ast::Lambda;
    using xo::ast::exprtype;
    using xo::reflect::Reflect;
    using xo::reflect::reflect_struct;
    using xo::ref::rp;
    using xo::ref::brw;
    using std::cerr;
    using std::endl;

    namespace ut {
        /* abstract syntax tree for a function:
         *   def root4(x :: double) { sqrt(sqrt(x)); }
         *
         * sqrt is a primitive built on ::sqrt, created with
         * explicit_symbol_def=false and intrinsic hint llvmintrinsic::fp_sqrt.
         *
         * Returns the lambda named "root4".
         */
        rp
        root4_ast() {
            auto sqrt = make_primitive("sqrt",
                                       ::sqrt,
                                       false /*!explicit_symbol_def*/,
                                       llvmintrinsic::fp_sqrt);
            auto x_var = make_var("x", Reflect::require());
            auto call1 = make_apply(sqrt, {x_var});  /* (sqrt x) */
            auto call2 = make_apply(sqrt, {call1});  /* (sqrt (sqrt x)) */

            auto fn_ast = make_lambda("root4",
                                      {x_var},
                                      call2);

            return fn_ast;
        }

        /* abstract syntax tree for a function that computes the same thing
         * as root4, but by passing sqrt to a nested lambda:
         *
         *   def root_2x(x2 :: double) {
         *     def twice(f :: double->double, x :: double) { f(f(x)); };
         *     twice(sqrt, x2)
         *   }
         *
         * Returns the lambda named "root_2x".
         */
        rp
        root_2x_ast() {
            auto root = make_primitive("sqrt",
                                       ::sqrt,
                                       false /*!explicit_symbol_def*/,
                                       llvmintrinsic::fp_sqrt);

            /* def twice(f :: double->double, x :: double) { f(f(x)) } */
            auto f_var = make_var("f", Reflect::require());
            auto x_var = make_var("x", Reflect::require());
            auto call1 = make_apply(f_var, {x_var}); /* (f x) */
            auto call2 = make_apply(f_var, {call1}); /* (f (f x)) */

            /* def twice(f :: double->double, x :: double) { f(f(x)); } */
            auto twice = make_lambda("twice",
                                     {f_var, x_var},
                                     call2);

            auto x2_var = make_var("x2", Reflect::require());
            auto call3 = make_apply(twice, {root, x2_var});

            /* def root_2x(x2 :: double) {
             *   def twice(f :: double->double, x :: double) { f(f(x)); };
             *   twice(sqrt, x2)
             * }
             */
            auto fn_ast = make_lambda("root_2x",
                                      {x2_var},
                                      call3);

            return fn_ast;
        }

        /* One table-driven test case:
         * a factory for the AST to compile + the calls to make on the result
         */
        struct TestCase {
            /* builds the AST under test */
            rp (*make_ast_)();
            /* each pair is (input, output) for function double->double */
            std::vector> call_v_;
        };

        /* both ASTs compute a 4th root, so they share the same expectations;
         * the inputs are perfect 4th powers (1, 16, 81)
         */
        static std::vector
        s_testcase_v = {
            {&root4_ast,
             {std::make_pair(1.0, 1.0),
              std::make_pair(16.0, 2.0),
              std::make_pair(81.0, 3.0)}},
            {&root_2x_ast,
             {std::make_pair(1.0, 1.0),
              std::make_pair(16.0, 2.0),
              std::make_pair(81.0, 3.0)}}
        };

        /** testcase root4_ast tests:
         *  - nested function call
         *
         *  testcase root_2x_ast relies on:
         *  - lambda in function position
         *  - argument with function type
         *
         *  For each test case: build AST -> generate llvm IR -> generate
         *  machine code -> lookup compiled function by name -> call it on
         *  each input and compare with the expected output.
         **/
        TEST_CASE("machpipeline.fptr", "[llvm][llvm_fnptr]") {
            constexpr bool c_debug_flag = true;

            // can get bits from /dev/random by uncommenting the 2nd line below
            //uint64_t seed = xxx;
            //rng::Seed seed;

            //auto rng = xo::rng::xoshiro256ss(seed);

            scope log(XO_DEBUG2(c_debug_flag, "TEST_CASE.machpipeline"));
            //log && log("(A)", xtag("foo", foo));

            /* one jit instance shared across all test cases */
            auto jit = MachPipeline::make();

            for (std::size_t i_tc = 0, n_tc = s_testcase_v.size(); i_tc < n_tc; ++i_tc) {
                TestCase const & testcase = s_testcase_v[i_tc];

                INFO(tostr(xtag("i_tc", i_tc)));

                auto ast = (*testcase.make_ast_)();

                REQUIRE(ast.get());

                log && log(xtag("ast", ast));

                /* must be a lambda before narrowing with Lambda::from() */
                REQUIRE(ast->extype() == exprtype::lambda);

                brw fn_ast = Lambda::from(ast);

                llvm::Value * llvm_ircode = jit->codegen_toplevel(fn_ast);

                /* TODO: printer for llvm::Value* */
                if (llvm_ircode) {
                    /* note: llvm:errs() is 'raw stderr stream' */
                    cerr << "llvm_ircode:" << endl;
                    llvm_ircode->print(llvm::errs());
                    cerr << endl;
                } else {
                    cerr << "code generation failed"
                         << xtag("fn_ast", fn_ast)
                         << endl;
                }

                REQUIRE(llvm_ircode);

                jit->machgen_current_module();

                /** lookup compiled function pointer in jit **/
                auto llvm_addr = jit->lookup_symbol(fn_ast->name());

                /* remember lookup outcome, so that a failed lookup fails
                 * the test (as machpipeline.wrap does) instead of
                 * reaching llvm_addr.get() below
                 */
                bool llvm_addr_flag = true;

                if (!llvm_addr) {
                    llvm_addr_flag = false;

                    cerr << "ex2: lookup: symbol not found"
                         << xtag("symbol", fn_ast->name())
                         << endl;
                } else {
                    cerr << "ex2: lookup: symbol found"
                         << xtag("llvm_addr", llvm_addr.get().getValue())
                         << xtag("symbol", fn_ast->name())
                         << endl;
                }

                REQUIRE(llvm_addr_flag);

                auto fn_ptr = llvm_addr.get().toPtr();

                REQUIRE(fn_ptr);

                for (std::size_t j_call = 0, n_call = testcase.call_v_.size(); j_call < n_call; ++j_call) {
                    double input = testcase.call_v_[j_call].first;
                    double expected = testcase.call_v_[j_call].second;

                    INFO(tostr(xtag("j_call", j_call), xtag("input", input), xtag("expected", expected)));

                    auto actual = (*fn_ptr)(input);

                    /* exact comparison: table inputs are perfect 4th powers */
                    REQUIRE(actual == expected);
                }
            }
        } /*TEST_CASE(machpipeline.fptr)*/

        /** Generate the wrapper function for the sqrt primitive,
         *  look it up under the name "w." + primitive name,
         *  then call it as wrapper(nullptr, 4.0) and expect 2.0.
         *
         *  NOTE(review): the first wrapper argument is passed as nullptr;
         *  presumably an unused closure/environment pointer (going by the
         *  [llvm_closure] tag) -- confirm against codegen_primitive_wrapper.
         **/
        TEST_CASE("machpipeline.wrap", "[llvm][llvm_closure]") {
            constexpr bool c_debug_flag = true;

            /* NOTE(review): label reads "machpipelin" -- probably meant
             * "machpipeline"; left unchanged since it is runtime output
             */
            scope log(XO_DEBUG2(c_debug_flag, "TEST_CASE.machpipelin.wrap"));

            auto jit = MachPipeline::make();

            auto root = make_primitive("sqrt",
                                       ::sqrt,
                                       false /*!explicit_symbol_def*/,
                                       llvmintrinsic::fp_sqrt);

            llvm::Value * llvm_ircode
                = jit->codegen_primitive_wrapper(root,
                                                 *(jit->llvm_current_ir_builder()));

            /* TODO: printer for llvm::Value* */
            if (llvm_ircode) {
                /* note: llvm:errs() is 'raw stderr stream' */
                cerr << "llvm_ircode for primitive wrapper:" << endl;
                llvm_ircode->print(llvm::errs());
                cerr << endl;
            } else {
                cerr << "code generation failed"
                     << xtag("root", root)
                     << endl;
            }

            REQUIRE(llvm_ircode);

            /* wrapper is looked up under the primitive's name prefixed with "w." */
            std::string wrapper_name = std::string("w.") + root->name();

            jit->machgen_current_module();

            auto llvm_addr = jit->lookup_symbol(wrapper_name);

            bool llvm_addr_flag = static_cast(llvm_addr);

            if (!llvm_addr_flag) {
                cerr << "ex2: lookup: symbol not found"
                     << xtag("symbol", wrapper_name)
                     << endl;
            } else {
                cerr << "ex2: lookup: symbol found"
                     << xtag("llvm_addr", llvm_addr.get().getValue())
                     << xtag("symbol", wrapper_name)
                     << endl;
            }

            /* stop here on failed lookup, before llvm_addr.get() below */
            REQUIRE(llvm_addr_flag);

            auto fn_ptr = llvm_addr.get().toPtr();

            REQUIRE(fn_ptr);

            auto actual = (*fn_ptr)(nullptr, 4.0);

            REQUIRE(actual == 2.0);
        }

        /* abstract syntax tree for a function:
         *   def make_ratio(n, d) { make_ratio_impl(n, d); }
         *
         * make_ratio_impl is a primitive built on xo::ratio::make_ratio,
         * created with explicit_symbol_def=true and no llvm intrinsic.
         *
         * Uses REQUIRE, so must be called from inside a TEST_CASE.
         */
        rp
        make_ratio() {
            auto make_ratio_impl = make_primitive("make_ratio_impl",
                                                  xo::ratio::make_ratio,
                                                  true /*explicit_symbol_def*/,
                                                  llvmintrinsic::invalid);

            REQUIRE(make_ratio_impl.get());
            REQUIRE(make_ratio_impl->explicit_symbol_def());

            /* jit-prepared library:
             * 1. *uses* make_ratio_impl
             * 2. *provides* make_ratio (can do jit->lookup_symbol("make_ratio"))
             */

            auto n_var = make_var("n", Reflect::require());
            auto d_var = make_var("d", Reflect::require());
            auto call1 = make_apply(make_ratio_impl, {n_var, d_var}); /*make_ratio(n,d)*/

            auto make_ratio = make_lambda("make_ratio",
                                          {n_var, d_var},
                                          call1);

            return make_ratio;
        }

        /** Compile make_ratio() (a function returning xo::ratio::ratio by value):
         *  1. reflect the ratio struct
         *  2. generate llvm IR for the make_ratio lambda
         *  3. check each member's llvm offset equals the reflected c++ offset
         *  4. generate machine code, lookup "make_ratio", invoke it with (2, 3)
         **/
        TEST_CASE("machpipeline.struct", "[llvm][llvm_struct]") {
            constexpr bool c_debug_flag = true;

            // can get bits from /dev/random by uncommenting the 2nd line below
            //uint64_t seed = xxx;
            //rng::Seed seed;

            //auto rng = xo::rng::xoshiro256ss(seed);

            scope log(XO_DEBUG2(c_debug_flag, "TEST_CASE.machpipeline.struct"));
            //log && log("(A)", xtag("foo", foo));

            auto jit = MachPipeline::make();

            /* let's reflect xo::ratio::ratio */
            using ratio_type = xo::ratio::ratio;
            auto struct_td = reflect_struct();

            REQUIRE(struct_td);

            // ----- build AST -----

            auto fn_ast = make_ratio();

            // ----- convert AST -> llvm IR datastructure -----

            llvm::Value * llvm_ircode = jit->codegen_toplevel(fn_ast);

            /* TODO: printer for llvm::Value* */
            if (llvm_ircode) {
                /* note: llvm:errs() is 'raw stderr stream' */
                cerr << "llvm_ircode:" << endl;
                llvm_ircode->print(llvm::errs());
                cerr << endl;
            } else {
                cerr << "code generation failed"
                     << xtag("fn_ast", fn_ast)
                     << endl;
            }

            REQUIRE(llvm_ircode);

            // ----- inspect alignment -----

            llvm::StructType * struct_llvm_type
                = static_cast(jit->codegen_type(struct_td));

            auto struct_layout = jit->data_layout().getStructLayout(struct_llvm_type);

            log && log(xtag("struct-size", struct_layout->getSizeInBytes()),
                       xtag("struct-alignment", struct_layout->getAlignment().value()));

            for (int i = 0, n = struct_llvm_type->getNumElements(); i < n; ++i) {
                llvm::TypeSize llvm_tz = struct_layout->getElementOffset(i);

                /* member address relative to an object at nullptr;
                 * presumably this yields the member's c++ byte offset
                 * -- confirm against get_member_tp()
                 */
                auto offset = reinterpret_cast(struct_td->struct_member(i).get_member_tp(nullptr).address());

                log && log(xtag("i", i),
                           xtag("name(c++)", struct_td->struct_member(i).member_name()),
                           xtag("type(c++)", struct_td->struct_member(i).get_member_td()->short_name()),
                           xtag("offset(c++)", offset),
                           xtag("offset(llvm)", llvm_tz.getKnownMinValue()));

                /* jit layout must agree with the c++ layout, member by member */
                REQUIRE(offset == llvm_tz.getKnownMinValue());
            }

            // ----- generate JIT machine code -----

            jit->machgen_current_module();

            log && log("execution session after codegen:");
            //log && log(jit->xsession()); // segfaults
            jit->dump_execution_session();

            // ----- verify: lookup symbol

            /** lookup compiled function pointer in jit **/
            auto llvm_addr = jit->lookup_symbol(fn_ast->name());

            log && log("execution session after lookup attempt:");
            jit->dump_execution_session();

            /* remember lookup outcome, so that a failed lookup fails
             * the test (as machpipeline.wrap does) instead of
             * reaching llvm_addr.get() below
             */
            bool llvm_addr_flag = true;

            if (!llvm_addr) {
                llvm_addr_flag = false;

                cerr << "ex2: lookup: symbol not found"
                     << xtag("symbol", fn_ast->name())
                     << endl;
            } else {
                cerr << "ex2: lookup: symbol found"
                     << xtag("llvm_addr", llvm_addr.get().getValue())
                     << xtag("symbol", fn_ast->name())
                     << endl;
            }

            REQUIRE(llvm_addr_flag);

            auto fn_ptr = llvm_addr.get().toPtr();

            REQUIRE(fn_ptr);

            // ---- invoke compiled function -----

            /* NOTE(review): result is only logged, never asserted;
             * consider checking num/den once the expected values are confirmed
             */
            auto value = (*fn_ptr)(2, 3);

            log && log(xtag("value.num", value.num()),
                       xtag("value.den", value.den()));

        } /*TEST_CASE(machpipeline.struct)*/
    } /*namespace ut*/
} /*namespace xo*/

/* end MachPipeline.test.cpp */