Here we adapt an automatic differentiation library to use Boost.YAP for specifying the equations it operates on.
Autodiff is a pretty small library, and doesn't cover every possible input
expression. What it covers is simple arithmetic, and the well-known functions
`sin`, `cos`, `sqrt`, and `pow`.
Here is how you would form an input to the library using its API. This is taken from the test program that comes with the library.
// Builds f(x1,x2,x3) = -5*x1 + sin(10)*x1 + 10*x2 - x3/6 by hand, one node at
// a time, using the library's node-creation functions.
//
// The three variable nodes are given the values -1.9, 2 and 5/6 and appended
// to `list` in the order x1, x2, x3.  Returns the root node of the expression.
Node* build_linear_fun1_manually(vector<Node*>& list)
{
    // Constant (parameter) leaves.
    PNode* minus_five = create_param_node(-5);
    PNode* ten = create_param_node(10);
    PNode* six = create_param_node(6);

    // Variable leaves.
    VNode* x1 = create_var_node();
    VNode* x2 = create_var_node();
    VNode* x3 = create_var_node();

    // Interior nodes, built bottom-up.
    OPNode* neg5_x1 = create_binary_op_node(OP_TIMES, minus_five, x1);          // -5*x1
    OPNode* sin_10 = create_uary_op_node(OP_SIN, ten);                          // sin(10)
    OPNode* sin10_x1 = create_binary_op_node(OP_TIMES, sin_10, x1);             // sin(10)*x1
    OPNode* x1_terms = create_binary_op_node(OP_PLUS, neg5_x1, sin10_x1);       // -5*x1 + sin(10)*x1
    OPNode* ten_x2 = create_binary_op_node(OP_TIMES, ten, x2);                  // 10*x2
    OPNode* x1_x2_terms = create_binary_op_node(OP_PLUS, x1_terms, ten_x2);     // ... + 10*x2
    OPNode* x3_over_6 = create_binary_op_node(OP_DIVID, x3, six);               // x3/6
    OPNode* root = create_binary_op_node(OP_MINUS, x1_x2_terms, x3_over_6);     // ... - x3/6

    // Point at which the function is evaluated.
    x1->val = -1.9;
    x2->val = 2;
    x3->val = 5./6.;

    // Hand the variables back to the caller, in placeholder order.
    for (VNode* variable : {x1, x2, x3})
        list.push_back(variable);

    return root;
}
I have a lot of trouble understanding what's going on here, and even more verifying that the expression written in the comment is actually what the code produces. Let's see if we can do better.
First, we start with a custom expression template, `autodiff_expr`.
It supports simple arithmetic, but notice it has no call operator —
we don't want `(a + b)()` to be a valid expression.
// Expression template used for every node of an Autodiff expression.
// `Kind` identifies the node's operation and `elements` holds its operands.
// No call operator is declared here, so `(a + b)()` does not compile.
template <boost::yap::expr_kind Kind, class Tuple>
struct autodiff_expr
{
    static const boost::yap::expr_kind kind = Kind;

    Tuple elements;
};

// The only operators enabled on autodiff_expr: unary minus and the four
// binary arithmetic operators.
BOOST_YAP_USER_UNARY_OPERATOR(negate, autodiff_expr, autodiff_expr)
BOOST_YAP_USER_BINARY_OPERATOR(plus, autodiff_expr, autodiff_expr)
BOOST_YAP_USER_BINARY_OPERATOR(minus, autodiff_expr, autodiff_expr)
BOOST_YAP_USER_BINARY_OPERATOR(multiplies, autodiff_expr, autodiff_expr)
BOOST_YAP_USER_BINARY_OPERATOR(divides, autodiff_expr, autodiff_expr)
We're going to be using a lot of placeholders in our Autodiff expressions,
and it sure would be nice if they were `autodiff_expr`s and not
`expression<>`s, so that only our desired operators are in play. To do this,
we define an operator that produces placeholder literals, using the
`BOOST_YAP_USER_LITERAL_PLACEHOLDER_OPERATOR` macro:
namespace autodiff_placeholders { // This defines a placeholder literal operator that creates autodiff_expr // placeholders. BOOST_YAP_USER_LITERAL_PLACEHOLDER_OPERATOR(autodiff_expr) }
Now, how about the functions we need to support, and where do we put the call operator? In other examples we created terminal subclasses or templates to get special behavior on terminals. In this case, we want to create a function-terminal template:
template <OPCODE Opcode> struct autodiff_fn_expr : autodiff_expr<boost::yap::expr_kind::terminal, boost::hana::tuple<OPCODE>> { autodiff_fn_expr () : autodiff_expr {boost::hana::tuple<OPCODE>{Opcode}} {} BOOST_YAP_USER_CALL_OPERATOR_N(::autodiff_expr, 1); }; // Someone included <math.h>, so we have to add trailing underscores. autodiff_fn_expr<OP_SIN> const sin_; autodiff_fn_expr<OP_COS> const cos_; autodiff_fn_expr<OP_SQRT> const sqrt_;
`OPCODE` is an enumeration in Autodiff. We use it as a non-type template
parameter for convenience when declaring `sin_` and friends. All we really
need is for the `OPCODE` to be the value of the terminals we produce, and for
these function-terminals to have the call operator.
Note | |
---|---|
Using |
Now, some transforms:
struct xform { // Create a var-node for each placeholder when we see it for the first // time. template <long long I> Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::terminal>, boost::yap::placeholder<I>) { if (list_.size() < I) list_.resize(I); auto & retval = list_[I - 1]; if (retval == nullptr) retval = create_var_node(); return retval; } // Create a param-node for every numeric terminal in the expression. Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::terminal>, double x) { return create_param_node(x); } // Create a "uary" node for each call expression, using its OPCODE. template <typename Expr> Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::call>, OPCODE opcode, Expr const & expr) { return create_uary_op_node( opcode, boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr), *this) ); } template <typename Expr> Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::negate>, Expr const & expr) { return create_uary_op_node( OP_NEG, boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr), *this) ); } // Define a mapping from binary arithmetic expr_kind to OPCODE... static OPCODE op_for_kind (boost::yap::expr_kind kind) { switch (kind) { case boost::yap::expr_kind::plus: return OP_PLUS; case boost::yap::expr_kind::minus: return OP_MINUS; case boost::yap::expr_kind::multiplies: return OP_TIMES; case boost::yap::expr_kind::divides: return OP_DIVID; default: assert(!"This should never execute"); return OPCODE{}; } assert(!"This should never execute"); return OPCODE{}; } // ... and use it to handle all the binary arithmetic operators. 
template <boost::yap::expr_kind Kind, typename Expr1, typename Expr2> Node * operator() (boost::yap::expr_tag<Kind>, Expr1 const & expr1, Expr2 const & expr2) { return create_binary_op_node( op_for_kind(Kind), boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr1), *this), boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr2), *this) ); } vector<Node *> & list_; };
We need a function to tie everything together, since the transforms cannot fill in the values for the placeholders.
// Converts `expr` to an Autodiff node tree and assigns a value to each of its
// variables.
//
// expr: the autodiff_expr to convert.
// list: receives the variable nodes; after the call, the node for
//       placeholder K is at index K - 1.
// args: one value per variable node, in placeholder order.  The number of
//       values must equal list.size() after conversion (checked by assert).
//
// Returns the root node of the converted expression.
template <typename Expr, typename ...T>
Node * to_auto_diff_node (Expr const & expr, vector<Node *> & list, T ... args)
{
    // This fills in list as a side effect.
    Node * const retval = boost::yap::transform(expr, xform{list});

    assert(list.size() == sizeof...(args));

    // Fill in the values of the value-nodes in list with the "args"
    // parameter pack.
    auto it = list.begin();
    boost::hana::for_each(
        boost::hana::make_tuple(args ...),
        [&it](auto x) {
            Node * n = *it;
            // A slot stays null when the expression skips a placeholder
            // number (e.g. it uses 1_p and 3_p but not 2_p).
            assert(n != nullptr && "every placeholder from 1_p up to the highest one must appear in the expression");
            VNode * v = boost::polymorphic_downcast<VNode *>(n);
            v->val = x;
            ++it;
        }
    );

    return retval;
}
Finally, here is the Boost.YAP version of the function we started with:
// Boost.YAP version of build_linear_fun1_manually():
// f(x1,x2,x3) = -5*x1 + sin(10)*x1 + 10*x2 - x3/6
//
// The variable nodes for x1, x2 and x3 are placed in `list` and given the
// values -1.9, 2 and 5/6.  Returns the root node of the expression.
Node* build_linear_fun1(vector<Node*>& list)
{
    using namespace autodiff_placeholders;

    // 1_p, 2_p and 3_p stand for x1, x2 and x3.
    auto const f = -5 * 1_p + sin_(10) * 1_p + 10 * 2_p - 3_p / 6;

    return to_auto_diff_node(f, list, -1.9, 2, 5./6.);
}