CoDiPack  2.2.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Example 10 - External function helper

Goal: Add external functions to the tape via a helper structure.

Prerequisite: Tutorial 2 - Reverse mode AD

Function: Linear system solve

template<typename Number>
void solve2(const Number* A, const Number* b, Number* x) {
  // A = a[0] a[1]   A^-1 = 1/det *  a[3] -a[1]
  //     a[2] a[3]                  -a[2]  a[0]
  Number det = A[0] * A[3] - A[1] * A[2];
  x[0] = (A[3] * b[0] - A[1] * b[1]) / det;
  x[1] = (-A[2] * b[0] + A[0] * b[1]) / det;
}

Full code:

#include <codi.hpp>
#include <iostream>

using Real = codi::RealReverse;
using Tape = typename Real::Tape;
using BaseReal = typename Real::Real;

template<typename Number>
void solve2(const Number* A, const Number* b, Number* x) {
  // A = a[0] a[1]   A^-1 = 1/det *  a[3] -a[1]
  //     a[2] a[3]                  -a[2]  a[0]
  Number det = A[0] * A[3] - A[1] * A[2];
  x[0] = (A[3] * b[0] - A[1] * b[1]) / det;
  x[1] = (-A[2] * b[0] + A[0] * b[1]) / det;
}

void solve2_primal(const BaseReal* x, size_t m, BaseReal* y, size_t n, codi::ExternalFunctionUserData* d) {
  // x[0-3] holds A, x[4-5] holds b
  solve2(&x[0], &x[4], y);
}

void solve2_rev(const BaseReal* x, BaseReal* x_b, size_t m, const BaseReal* y, const BaseReal* y_b, size_t n, codi::ExternalFunctionUserData* d) {
  BaseReal ATrans[4] = {x[0], x[2], x[1], x[3]};
  BaseReal s[2];
  solve2(ATrans, y_b, s);

  // Adjoint of A (\bar A = -s*x^T) (In local terms x_b[0-3] = -s*y^T)
  x_b[0] = -s[0] * y[0];
  x_b[1] = -s[0] * y[1];
  x_b[2] = -s[1] * y[0];
  x_b[3] = -s[1] * y[1];

  // Adjoint of b (\bar b = s) (In local terms x_b[4-5] = s)
  x_b[4] = s[0];
  x_b[5] = s[1];
}

void runExample(int mode) {
  Real u = 3.0;

  Tape& tape = Real::getTape();
  tape.setActive();
  tape.registerInput(u);

  Real A[4] = { u * 1.0, 0.5, 0.25, u * -1.0};
  Real b[2] = {u * 10.0, u * 20.0};
  Real x[2];

  if(1 == mode) { // No special handling
    std::cout << "Running regular differentiation without external functions." << std::endl;
    solve2(A, b, x);
  } else if(2 == mode) { // External function with primal function implementation
    std::cout << "Running differentiation with external function, primal is called via a special function implementation." << std::endl;
    codi::ExternalFunctionHelper<codi::RealReverse> eh; // Step 1: Create the helper
    for(int i = 0; i < 4; ++i) {
      eh.addInput(A[i]); // Step 2: Add inputs of the function
    }
    for(int i = 0; i < 2; ++i) {
      eh.addInput(b[i]); // Step 2: Add inputs of the function
    }
    for(int i = 0; i < 2; ++i) {
      eh.addOutput(x[i]); // Step 3: Add outputs of the function
    }
    eh.callPrimalFunc(solve2_primal); // Step 4: Call the primal with a special implementation
    eh.addToTape(solve2_rev); // Step 5: Add the specialized reverse function to the tape
  } else if(3 == mode) { // External function with passive primal call
    std::cout << "Running differentiation with external function, primal is called via a passive AD evaluation." << std::endl;
    codi::ExternalFunctionHelper<codi::RealReverse> eh(true); // Step 1: Create the helper
    for(int i = 0; i < 4; ++i) {
      eh.addInput(A[i]); // Step 2: Add inputs of the function
    }
    for(int i = 0; i < 2; ++i) {
      eh.addInput(b[i]); // Step 2: Add inputs of the function
    }
    for(int i = 0; i < 2; ++i) {
      eh.addOutput(x[i]); // Step 3: Add outputs of the function
    }
    eh.callPrimalFuncWithADType(solve2<Real>, A, b, x); // Step 4: Call the primal with a regular function call that is not recorded
    eh.addToTape(solve2_rev); // Step 5: Add the specialized reverse function to the tape
  } else {
    std::cerr << "Error: Unknown mode '" << mode << "'." << std::endl;
  }

  Real w = sqrt(x[0] * x[0] + x[1] * x[1]);

  tape.registerOutput(w);
  tape.setPassive();

  w.setGradient(1.0);
  tape.evaluate();

  std::cout << "Solution w: " << w << std::endl;
  std::cout << "Adjoint u: " << u.getGradient() << std::endl;

  tape.reset();
}

int main(int nargs, char** args) {
  runExample(1);
  runExample(2);
  runExample(3);

  return 0;
}

The forward and reverse mode equations for the linear system solve are defined in Linear system solve. The function solve2_rev implements the reverse mode equation from that definition and has to match the function signature ReverseFunc. Its argument names follow the usual source transformation notation, where the suffix _b marks the bar (adjoint) value of the corresponding primal: x and y hold the primal inputs and outputs, y_b holds the incoming adjoints of the outputs, and x_b receives the adjoints of the inputs. The sizes of x and y are given by m and n, respectively. The parameter d holds user-defined data. This data is provided by the helper and can be retrieved with a call to getExternalFunctionUserData. For details on how to use this structure, see Example 11 - External function user data.
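
For orientation, the reverse mode formulas used in solve2_rev can be recovered with the standard differential argument sketched below. It is written in the local notation of the code, where y is the solution of the system and \bar y corresponds to y_b; the symbol s matches the array s in solve2_rev. This is a short sketch, and the Linear system solve page remains the reference definition.

```latex
\begin{aligned}
A\,y &= b \\
\mathrm{d}A\,y + A\,\mathrm{d}y &= \mathrm{d}b
  \;\Longrightarrow\;
  \mathrm{d}y = A^{-1}\bigl(\mathrm{d}b - \mathrm{d}A\,y\bigr) \\
\bar y^{T}\,\mathrm{d}y &= s^{T}\,\mathrm{d}b - s^{T}\,\mathrm{d}A\,y,
  \qquad s := A^{-T}\,\bar y \\
\bar b &= s, \qquad \bar A = -\,s\,y^{T}
\end{aligned}
```

Reading off the coefficients of \mathrm{d}b and \mathrm{d}A gives the two assignments in the code: x_b[4-5] = s and x_b[0-3] = -s*y^T, stored row by row.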

The steps for using the codi::ExternalFunctionHelper are then quite simple. The user has to:

  • Create the helper
  • Add the inputs
  • Add the outputs
  • Call the primal function
  • Add the reverse function to the tape

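For the fourth step, the user has two choices. Either a primal evaluation function with the signature PrimalFunc is implemented and passed to callPrimalFunc (mode 2 in the code above), or the original function is evaluated with the CoDiPack type via callPrimalFuncWithADType (mode 3 in the code above), in which case the helper is created with true as its constructor argument. With the second option, nothing is recorded on the tape during this evaluation.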

After the function is registered on the tape, the helper is reset so that it can be used to push another external function.