Goal: Learn how side effects can affect the tangent computation.
Prerequisite: Tutorial 1 - Forward mode AD
Function:
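using Real = codi::RealForward;

// Global state that func() may overwrite as a side effect.
Real global = 0.0;

Real func(const Real& x, bool updateGlobal) {
  if(updateGlobal) {
    global = x * x;  // global now depends on x and carries a tangent
  }
  return x * global;
}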
Full code:
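#include <codi.hpp>
#include <iostream>

using Real = codi::RealForward;

// Global state that func() may overwrite as a side effect.
Real global = 0.0;

Real func(const Real& x, bool updateGlobal) {
  if(updateGlobal) {
    global = x * x;  // global now depends on x and carries a tangent
  }
  return x * global;
}

int main(int nargs, char** args) {
  Real x = 4.0;
  x.setGradient(1.0);  // seed the tangent of the input

  // First call: global is recomputed from x, so the tangent is correct.
  Real y = func(x, true);
  std::cout << "Update global:" << std::endl;
  std::cout << "f(4.0, true) = " << y << std::endl;
  std::cout << "df/dx(4.0, true) = " << y.getGradient() << std::endl << std::endl;

  // Second call: global still holds the tangent from the first call.
  y = func(x, false);
  std::cout << "No update global:" << std::endl;
  std::cout << "f(4.0, false) = " << y << std::endl;
  std::cout << "df/dx(4.0, false) = " << y.getGradient() << std::endl << std::endl;

  // Third call: reset the stale tangent of global before evaluating.
  global.setGradient(0.0);
  y = func(x, false);
  std::cout << "No update global with reset:" << std::endl;
  std::cout << "f(4.0, false) = " << y << std::endl;
  std::cout << "df/dx(4.0, false) = " << y.getGradient() << std::endl << std::endl;

  return 0;
}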
The computational path of the function is changed via the parameter updateGlobal. If this parameter is true then x also enters the computation of the global variable. If the parameter is false then the global value is seen as a constant with respect to AD.
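The two paths can be written out as a short worked derivation. The symbol g for the value stored in global and the dot notation for tangents are notation introduced here for illustration; the numbers assume x = 4 with a tangent seed of 1, and g = 16 from an earlier update.

```latex
\begin{align*}
\text{updateGlobal = true:}\quad  & g = x^2,\ \dot g = 2x\,\dot x,
  & f &= x\,g = x^3, & \dot f &= \dot x\,g + x\,\dot g = 3x^2\,\dot x = 48 \\
\text{updateGlobal = false:}\quad & g = 16,\ \dot g = 0,
  & f &= x\,g, & \dot f &= \dot x\,g = 16
\end{align*}
```

If the tangent of g is not actually zero in the second case, the product rule still adds the term x times that leftover tangent, which is how a side effect can leak into the result.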
The three calls demonstrate the error in the tangent computation. The first and third calls are correct; the second one is wrong because global still carries the tangent from the first call. Here, we fix the issue by directly resetting the tangent of the global variable. Another option would be to call func(x, true) again, with a tangent value of zero for x.
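The effect can be reproduced without CoDiPack. The following is a minimal sketch that assumes a hand-written dual number in place of codi::RealForward; the names Dual, DemoResult and runDemo are hypothetical and only serve the illustration. It evaluates the three calls from the example, plus the alternative fix of recomputing global with a zero tangent for x.

```cpp
// Hypothetical stand-in for a forward-mode AD type (not CoDiPack's RealForward):
// a primal value together with its tangent.
struct Dual {
  double value;
  double tangent;
};

// Product rule: (a * b)' = a' * b + a * b'
Dual operator*(const Dual& a, const Dual& b) {
  return {a.value * b.value, a.tangent * b.value + a.value * b.tangent};
}

Dual global{0.0, 0.0};

Dual func(const Dual& x, bool updateGlobal) {
  if (updateGlobal) {
    global = x * x;  // side effect: global now stores the tangent 2 * x * dx
  }
  return x * global;
}

// Tangents of the result for the different call sequences.
struct DemoResult {
  double updated;   // func(x, true)
  double stale;     // func(x, false) directly afterwards
  double reset;     // func(x, false) after resetting global's tangent
  double reseeded;  // func(x, false) after recomputing global with dx = 0
};

DemoResult runDemo() {
  global = {0.0, 0.0};
  Dual x{4.0, 1.0};  // x = 4 with tangent seed dx = 1
  DemoResult r;

  r.updated = func(x, true).tangent;  // d(x^3)/dx = 3 * 16 = 48
  r.stale = func(x, false).tangent;   // also 48: the old tangent of global leaks in

  global.tangent = 0.0;               // fix 1: reset the tangent by hand
  r.reset = func(x, false).tangent;   // 16, global acts as a constant

  func(x, true);                      // make global's tangent stale again
  Dual xNoSeed{4.0, 0.0};
  func(xNoSeed, true);                // fix 2: recompute global with a zero tangent
  r.reseeded = func(x, false).tangent;  // 16 again
  return r;
}
```

Calling runDemo() from a main function and printing the four fields shows that only the second evaluation deviates from the expected derivative.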
Notes
This problem is not specific to CoDiPack. Nearly all tapeless forward AD tools suffer from it. One possible permanent fix is to implement a Gradient type that tags tangents with an epoch. The user would then need to manage the globally valid epoch, and the Gradient type can ignore tangents from older epochs.
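The epoch idea might look roughly like the sketch below. It is not part of CoDiPack; the names currentEpoch, EpochTangent, EpochDual, epochFunc and runEpochDemo are all hypothetical, and the hand-written dual number again stands in for a real AD type. Each tangent remembers the epoch in which it was written and reads as zero once the user advances the global epoch.

```cpp
// Hypothetical, user-managed epoch counter. Advancing it invalidates all
// tangents written before.
unsigned long currentEpoch = 0;

// Tangent value tagged with the epoch in which it was written.
struct EpochTangent {
  double raw;
  unsigned long epoch;

  // Tangents from older epochs are ignored, i.e. read as zero.
  double get() const { return epoch == currentEpoch ? raw : 0.0; }
};

// Create a tangent that is valid in the current epoch.
EpochTangent makeTangent(double v) {
  return {v, currentEpoch};
}

struct EpochDual {
  double value;
  EpochTangent tangent;
};

// Product rule on the epoch-filtered tangents.
EpochDual operator*(const EpochDual& a, const EpochDual& b) {
  return {a.value * b.value,
          makeTangent(a.tangent.get() * b.value + a.value * b.tangent.get())};
}

EpochDual epochGlobal{0.0, {0.0, 0}};

EpochDual epochFunc(const EpochDual& x, bool updateGlobal) {
  if (updateGlobal) {
    epochGlobal = x * x;
  }
  return x * epochGlobal;
}

struct EpochDemoResult {
  double updated;        // tangent of epochFunc(x, true)
  double afterNewEpoch;  // tangent of epochFunc(x, false) in the next epoch
};

EpochDemoResult runEpochDemo() {
  currentEpoch = 0;
  epochGlobal = {0.0, makeTangent(0.0)};
  EpochDual x{4.0, makeTangent(1.0)};
  EpochDemoResult r;

  r.updated = epochFunc(x, true).tangent.get();  // 48, as before

  ++currentEpoch;                // new evaluation: all old tangents expire
  x.tangent = makeTangent(1.0);  // inputs must be re-seeded in the new epoch
  r.afterNewEpoch = epochFunc(x, false).tangent.get();  // 16, no manual reset
  return r;
}
```

The trade-off in this sketch is that every tangent read costs an extra comparison and every tangent stores an extra integer, and the user still has to remember to advance the epoch and re-seed the inputs before each new derivative evaluation.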