CoDiPack  2.2.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Example 20 - Aggregated active type handling

Goal: Learn how to generalize the data extraction for external functions.

Prerequisites: Example 11 - External function user data

Function: Simple real-valued function

template<typename Type>
Type func(const Type& x) {
  return x * x;
}

Full code:

#include <codi.hpp>
#include <complex>
#include <iostream>

using Real = codi::RealReverse;
using Tape = typename Real::Tape;
using Identifier = typename Real::Identifier;
using RealBase = typename Real::Real;

template<typename Type>
Type func(const Type& x) {
  return x * x;
}

template<typename Type>
void extFunc_rev(Tape* t, void* d, codi::VectorAccessInterface<RealBase, Identifier>* va) {
  codi::ExternalFunctionUserData* data = static_cast<codi::ExternalFunctionUserData*>(d);

  // Step 3: Create a wrapped vector access interface.
  using Factory = codi::AggregatedTypeVectorAccessWrapperFactory<Type>;
  using VectorWrapper = typename Factory::RType;
  VectorWrapper* vaType = Factory::create(va);

  using TypeIdentifier = typename VectorWrapper::Identifier;
  using TypeReal = typename VectorWrapper::Real;

  // Step 4: Get the external function data (same order as it was added in Step 2).
  TypeReal x_v = data->getData<TypeReal>();
  TypeIdentifier x_i = data->getData<TypeIdentifier>();
  TypeIdentifier w_i = data->getData<TypeIdentifier>();

  // Step 5: Use the wrapped vector access interface and perform the adjoint operation.
  TypeReal w_b = vaType->getAdjoint(w_i, 0);
  TypeReal t_b = 2.0 * codi::ComputationTraits::transpose(x_v) * w_b;
  vaType->updateAdjoint(x_i, 0, t_b);
  vaType->resetAdjoint(w_i, 0);

  // Step 6: Delete the created wrapper.
  Factory::destroy(vaType);
}

void extFunc_del(Tape* t, void* d) {
  codi::ExternalFunctionUserData* data = static_cast<codi::ExternalFunctionUserData*>(d);
  delete data;
  std::cout << " Reset: data is deleted." << std::endl;
}

template<typename Type>
Type addExternalFunc(Type const& x) {
  Tape& tape = Real::getTape();

  // Step 1: Perform the passive function evaluation.
  tape.setPassive();
  Type w = func(x);
  tape.setActive();

  // Step 2: Use the general access routines on the values to extract the primal and identifier data.
  codi::ExternalFunctionUserData* data = new codi::ExternalFunctionUserData;
  data->addData(codi::RealTraits::getValue(x));
  data->addData(codi::RealTraits::getIdentifier(x));
  data->addData(codi::RealTraits::registerExternalFunctionOutput(w));

  tape.pushExternalFunction(codi::ExternalFunction<Tape>::create(extFunc_rev<Type>, data, extFunc_del));
  return w;
}

int main(int nargs, char** args) {
  Real x = 3.0;

  Tape& tape = Real::getTape();
  tape.setActive();
  tape.registerInput(x);

  // The same external function is used for a real and for a complex argument.
  Real t1 = addExternalFunc(x);
  std::complex<Real> c(t1, -t1);
  std::complex<Real> t2 = addExternalFunc(c);
  Real y = std::abs(t2);

  tape.registerOutput(y);
  tape.setPassive();

  y.setGradient(1.0);
  tape.evaluate();

  std::cout << "x = " << x << std::endl;
  std::cout << "y = " << y << std::endl;
  std::cout << "dy/dx = " << x.getGradient() << std::endl;

  tape.reset();
  return 0;
}
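As a cross-check for the program above, the values it should report can be derived by hand. The derivation below follows the statements in main; it is an analytic sanity check and not output captured from a run.

```latex
\begin{aligned}
t_1 &= x^2, \qquad c = t_1 - i\,t_1 = t_1 (1 - i), \\
t_2 &= c^2 = t_1^2 (1 - i)^2 = -2i\,t_1^2, \\
y   &= |t_2| = 2\,t_1^2 = 2x^4, \qquad \frac{dy}{dx} = 8x^3 .
\end{aligned}
```

For x = 3 this gives y = 162 and dy/dx = 216.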

The example shows how a function that can be called with both real and complex arguments (double and std::complex<double> in the passive case) can be differentiated with external functions. The differentiation is implemented once and generalized over the template parameter of the function. During recording, the helper functions in codi::RealTraits (getValue, getIdentifier and registerExternalFunctionOutput) extract the primal values and identifiers of all active values contained in the argument and the result. For the reverse evaluation in the external function, codi::AggregatedTypeVectorAccessWrapperFactory creates a wrapped version of the codi::VectorAccessInterface that reads and updates adjoints in terms of the aggregated type. In addition, codi::ComputationTraits is used to generalize the transpose in the adjoint update. The advantage of these traits and the wrapper is that aggregated types can be handled in the same way as standard CoDiPack types. In this case, the same code covers codi::RealReverse and std::complex<codi::RealReverse>.