CoDiPack  2.2.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
primalValueReuseTape.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://www.scicomp.uni-kl.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://www.scicomp.uni-kl.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <algorithm>
38#include <functional>
39#include <type_traits>
40
41#include "../config.h"
42#include "../expressions/lhsExpressionInterface.hpp"
43#include "../expressions/logic/compileTimeTraversalLogic.hpp"
44#include "../expressions/logic/constructStaticContext.hpp"
45#include "../expressions/logic/traversalLogic.hpp"
46#include "../misc/macros.hpp"
47#include "../misc/memberStore.hpp"
48#include "../traits/expressionTraits.hpp"
49#include "data/chunk.hpp"
50#include "indices/indexManagerInterface.hpp"
51#include "primalValueBaseTape.hpp"
52
54namespace codi {
55
63 template<typename T_TapeTypes>
64 struct PrimalValueReuseTape : public PrimalValueBaseTape<T_TapeTypes, PrimalValueReuseTape<T_TapeTypes>> {
65 public:
66
67 using TapeTypes =
68 CODI_DD(T_TapeTypes,
71
73 friend Base;
74
// Type aliases pulled from TapeTypes and from the base class.
// Real, Gradient and Identifier: see TapeTypesInterface.
75 using Real = typename TapeTypes::Real;
76 using Gradient = typename TapeTypes::Gradient;
77 using Identifier = typename TapeTypes::Identifier;
// NOTE(review): embedded line 78 is missing from this listing; the member summary shows it declares
// PassiveReal as RealTraits::PassiveReal<Real> ("Basic computation type").
// StatementEvaluator and EvalHandle: see PrimalValueTapeTypes. Position: see TapeTypesInterface.
79 using StatementEvaluator = typename TapeTypes::StatementEvaluator;
80 using EvalHandle = typename TapeTypes::EvalHandle;
81 using Position = typename Base::Position;
82
// StatementData: see PrimalValueTapeTypes. It provides the Position type and the forEachReverse/pushData calls
// used by the methods of this struct.
83 using StatementData = typename TapeTypes::StatementData;
84
87
89
// Clear all adjoints that would be set in a tape evaluation from start to end.
// NOTE(review): the line closing the parameter list (embedded line 94) is missing from this listing; per the member
// summary the third parameter is `AdjointsManagement adjointsManagement = AdjointsManagement::Automatic`.
// That parameter is only passed to CODI_UNUSED here, i.e. it is not consulted by this overload.
93 void clearAdjoints(Position const& start, Position const& end,
95 CODI_UNUSED(adjointsManagement);
96
// Per-statement callback. Only the left-hand-side identifier is used: if it lies inside the current adjoint vector,
// the adjoint stored there is reset to a default-constructed Gradient; identifiers beyond the vector are skipped.
97 auto clearFunc = [this](Identifier* lhsIndex, Config::ArgumentSize* passiveArgs, Real* oldPrimal,
98 EvalHandle* evalHandle) {
99 CODI_UNUSED(passiveArgs, oldPrimal, evalHandle);
100
101 if (*lhsIndex < (Identifier)this->adjoints.size()) {
102 this->adjoints[*lhsIndex] = Gradient();
103 }
104 };
105
// Obtain the StatementData positions that belong to the two tape positions.
106 using StmtPosition = typename StatementData::Position;
107 StmtPosition startStmt = this->llfByteData.template extractPosition<StmtPosition>(start);
108 StmtPosition endStmt = this->llfByteData.template extractPosition<StmtPosition>(end);
109
// Apply clearFunc to the statement entries between startStmt and endStmt (reverse order, going by the name of
// forEachReverse).
110 this->statementData.forEachReverse(startStmt, endStmt, clearFunc);
111 }
112
113 protected:
114
// Parameter list and body of internalEvaluateForward_EvalStatements (the line carrying the function name, embedded
// line 116, is missing from this listing). Per the member summary this performs a forward evaluation of the tape;
// the loop below walks the statement entries from curStatementPos up to endStatementPos.
117 /* data from call */
118 PrimalValueReuseTape& tape, Real* primalVector, ADJOINT_VECTOR_TYPE* adjointVector,
119 /* data from low level function byte data vector */
120 size_t& curLLFByteDataPos, size_t const& endLLFByteDataPos, char* dataPtr,
121 /* data from low level function info data vector */
122 size_t& curLLFInfoDataPos, size_t const& endLLFInfoDataPos, Config::LowLevelFunctionToken* const tokenPtr,
123 Config::LowLevelFunctionDataSize* const dataSizePtr,
124 /* data from constantValueData */
125 size_t& curConstantPos, size_t const& endConstantPos, PassiveReal const* const constantValues,
126 /* data from passiveValueData */
127 size_t& curPassivePos, size_t const& endPassivePos, Real const* const passiveValues,
128 /* data from rhsIdentifiersData */
129 size_t& curRhsIdentifiersPos, size_t const& endRhsIdentifiersPos, Identifier const* const rhsIdentifiers,
130 /* data from statementData */
131 size_t& curStatementPos, size_t const& endStatementPos, Identifier const* const lhsIdentifiers,
132 Config::ArgumentSize const* const numberOfPassiveArguments, Real* const oldPrimalValues,
133 EvalHandle const* const stmtEvalhandle) {
// Only the statement position bounds the loop; the end positions of the other data streams are not read.
134 CODI_UNUSED(endLLFByteDataPos, endLLFInfoDataPos, endConstantPos, endPassivePos, endRhsIdentifiersPos);
135
// In this configuration the adjoint and primal vectors are wrapped in a VectorAccess object, which is what the low
// level function call below receives (&vectorAccess).
136#if !CODI_VariableAdjointInterfaceInPrimalTapes
137 typename Base::template VectorAccess<Gradient> vectorAccess(adjointVector, primalVector);
138#endif
139
140 while (curStatementPos < endStatementPos) CODI_Likely {
141 Config::ArgumentSize nPassiveValues = numberOfPassiveArguments[curStatementPos];
142
// NOTE(review): the condition opening this branch (embedded line 143) and the #if matching the #else below (embedded
// line 146) are missing from this listing. The condition is presumably a comparison of nPassiveValues with
// Config::StatementLowLevelFunctionTag; confirm against the original header.
// In this branch the entry is handed to a low level function in forward mode.
144 Base::template callLowLevelFunction<LowLevelFunctionEntryCallKind::Forward>(
145 tape, true, curLLFByteDataPos, dataPtr, curLLFInfoDataPos, tokenPtr, dataSizePtr,
147 adjointVector
148#else
149 &vectorAccess
150#endif
151 );
152 } else CODI_Likely {
// Regular statement entry.
153 Identifier const lhsIdentifier = lhsIdentifiers[curStatementPos];
154
155 Gradient lhsTangent = Gradient();
156
// Save the primal value that is about to be overwritten into the statement entry, then store the result of the
// forward evaluation as the new primal value of the left-hand side. lhsTangent is passed along to callForward.
157 oldPrimalValues[curStatementPos] = primalVector[lhsIdentifier];
158 primalVector[lhsIdentifier] = StatementEvaluator::template callForward<PrimalValueReuseTape>(
159 stmtEvalhandle[curStatementPos], primalVector, adjointVector, lhsTangent, nPassiveValues,
160 curConstantPos, constantValues, curPassivePos, passiveValues, curRhsIdentifiersPos, rhsIdentifiers);
161
// Store the tangent of the left-hand side. NOTE(review): the lines that name the calls taking the argument lists
// below (embedded lines 164, 168 and 171) are missing from this listing; the member summary lists
// notifyStatementEvaluateListeners and notifyStatementEvaluatePrimalListeners with matching parameters. Confirm.
162#if CODI_VariableAdjointInterfaceInPrimalTapes
163 adjointVector->setLhsTangent(lhsIdentifier);
165 tape, lhsIdentifier, adjointVector->getVectorSize(), adjointVector->getAdjointVec(lhsIdentifier));
166#else
167 adjointVector[lhsIdentifier] = lhsTangent;
169 tape, lhsIdentifier, GradientTraits::dim<Gradient>(), GradientTraits::toArray(lhsTangent).data());
170#endif
172 primalVector[lhsIdentifier]);
173 }
174
// Advance to the next statement entry.
175 curStatementPos += 1;
176 }
177 }
178
// Parameter list and body of internalEvaluatePrimal_EvalStatements (the line carrying the function name, embedded
// line 180, is missing from this listing). Per the member summary this performs a primal evaluation of the tape;
// the loop below walks the statement entries from curStatementPos up to endStatementPos.
181 /* data from call */
182 PrimalValueReuseTape& tape, Real* primalVector,
183 /* data from low level function byte data vector */
184 size_t& curLLFByteDataPos, size_t const& endLLFByteDataPos, char* dataPtr,
185 /* data from low level function info data vector */
186 size_t& curLLFInfoDataPos, size_t const& endLLFInfoDataPos, Config::LowLevelFunctionToken* const tokenPtr,
187 Config::LowLevelFunctionDataSize* const dataSizePtr,
188 /* data from constantValueData */
189 size_t& curConstantPos, size_t const& endConstantPos, PassiveReal const* const constantValues,
190 /* data from passiveValueData */
191 size_t& curPassivePos, size_t const& endPassivePos, Real const* const passiveValues,
192 /* data from rhsIdentifiersData */
193 size_t& curRhsIdentifiersPos, size_t const& endRhsIdentifiersPos, Identifier const* const rhsIdentifiers,
194 /* data from statementData */
195 size_t& curStatementPos, size_t const& endStatementPos, Identifier const* const lhsIdentifiers,
196 Config::ArgumentSize const* const numberOfPassiveArguments, Real* const oldPrimalValues,
197 EvalHandle const* const stmtEvalhandle) {
// Only the statement position bounds the loop; the end positions of the other data streams are not read.
198 CODI_UNUSED(endLLFByteDataPos, endLLFInfoDataPos, endConstantPos, endPassivePos, endRhsIdentifiersPos);
199
// This function has no adjoint vector parameter, so the VectorAccess object is built with a null adjoint pointer
// and the primal vector only.
200 typename Base::template VectorAccess<Gradient> vectorAccess(nullptr, primalVector);
201
202 while (curStatementPos < endStatementPos) CODI_Likely {
203 Config::ArgumentSize nPassiveValues = numberOfPassiveArguments[curStatementPos];
204
// NOTE(review): the condition opening this branch (embedded line 205) is missing from this listing. It is presumably
// a comparison of nPassiveValues with Config::StatementLowLevelFunctionTag; confirm against the original header.
// In this branch the entry is handed to a low level function in primal mode.
206 Base::template callLowLevelFunction<LowLevelFunctionEntryCallKind::Primal>(
207 tape, true, curLLFByteDataPos, dataPtr, curLLFInfoDataPos, tokenPtr, dataSizePtr, &vectorAccess);
208 } else CODI_Likely {
// Regular statement entry.
209 Identifier const lhsIdentifier = lhsIdentifiers[curStatementPos];
210
// Save the primal value that is about to be overwritten into the statement entry, then store the result of the
// primal evaluation as the new primal value of the left-hand side.
211 oldPrimalValues[curStatementPos] = primalVector[lhsIdentifier];
212 primalVector[lhsIdentifier] = StatementEvaluator::template callPrimal<PrimalValueReuseTape>(
213 stmtEvalhandle[curStatementPos], primalVector, numberOfPassiveArguments[curStatementPos],
214 curConstantPos, constantValues, curPassivePos, passiveValues, curRhsIdentifiersPos, rhsIdentifiers);
215
// NOTE(review): the line naming the call whose last argument follows (embedded line 216) is missing from this
// listing; the member summary lists notifyStatementEvaluatePrimalListeners(tape, lhsIdentifier, lhsValue). Confirm.
217 primalVector[lhsIdentifier]);
218 }
219
// Advance to the next statement entry.
220 curStatementPos += 1;
221 }
222 }
223
// Parameter list and body of internalEvaluateReverse_EvalStatements (the line carrying the function name, embedded
// line 225, is missing from this listing). Per the member summary this performs a reverse evaluation of the tape;
// the loop below walks the statement entries from curStatementPos down to endStatementPos.
// Unlike the forward and primal variants, oldPrimalValues is read-only here (Real const* const).
226 /* data from call */
227 PrimalValueReuseTape& tape, Real* primalVector, ADJOINT_VECTOR_TYPE* adjointVector,
228 /* data from low level function byte data vector */
229 size_t& curLLFByteDataPos, size_t const& endLLFByteDataPos, char* dataPtr,
230 /* data from low level function info data vector */
231 size_t& curLLFInfoDataPos, size_t const& endLLFInfoDataPos, Config::LowLevelFunctionToken* const tokenPtr,
232 Config::LowLevelFunctionDataSize* const dataSizePtr,
233 /* data from constantValueData */
234 size_t& curConstantPos, size_t const& endConstantPos, PassiveReal const* const constantValues,
235 /* data from passiveValueData */
236 size_t& curPassivePos, size_t const& endPassivePos, Real const* const passiveValues,
237 /* data from rhsIdentifiersData */
238 size_t& curRhsIdentifiersPos, size_t const& endRhsIdentifiersPos, Identifier const* const rhsIdentifiers,
239 /* data from statementData */
240 size_t& curStatementPos, size_t const& endStatementPos, Identifier const* const lhsIdentifiers,
241 Config::ArgumentSize const* const numberOfPassiveArguments, Real const* const oldPrimalValues,
242 EvalHandle const* const stmtEvalhandle) {
// Only the statement position bounds the loop; the end positions of the other data streams are not read.
243 CODI_UNUSED(endLLFByteDataPos, endLLFInfoDataPos, endConstantPos, endPassivePos, endRhsIdentifiersPos);
244
// In this configuration the adjoint and primal vectors are wrapped in a VectorAccess object, which is what the low
// level function call below receives (&vectorAccess).
245#if !CODI_VariableAdjointInterfaceInPrimalTapes
246 typename Base::template VectorAccess<Gradient> vectorAccess(adjointVector, primalVector);
247#endif
248
249 while (curStatementPos > endStatementPos) CODI_Likely {
// Decrement before use: the entry processed in this iteration is the one just below the incoming position.
250 curStatementPos -= 1;
251
252 Config::ArgumentSize nPassiveValues = numberOfPassiveArguments[curStatementPos];
253
// NOTE(review): the condition opening this branch (embedded line 254) and the #if matching the #else below (embedded
// line 257) are missing from this listing. The condition is presumably a comparison of nPassiveValues with
// Config::StatementLowLevelFunctionTag; confirm against the original header.
// In this branch the entry is handed to a low level function in reverse mode (second argument false, whereas the
// forward and primal variants pass true).
255 Base::template callLowLevelFunction<LowLevelFunctionEntryCallKind::Reverse>(
256 tape, false, curLLFByteDataPos, dataPtr, curLLFInfoDataPos, tokenPtr, dataSizePtr,
258 adjointVector
259#else
260 &vectorAccess
261#endif
262 );
263 } else CODI_Likely {
// Regular statement entry.
264 Identifier const lhsIdentifier = lhsIdentifiers[curStatementPos];
265
// Obtain the adjoint of the left-hand side. In the #else configuration it is copied out of adjointVector and the
// stored entry is then reset to a default-constructed Gradient; in the #if configuration lhsAdjoint stays
// default-constructed and the adjoint vector interface is told about the left-hand side via setLhsAdjoint.
// NOTE(review): the lines that name the calls taking the argument lists below (embedded lines 267, 273 and 277) are
// missing from this listing; the member summary lists notifyStatementEvaluateListeners and
// notifyStatementEvaluatePrimalListeners with matching parameters. Confirm.
266#if CODI_VariableAdjointInterfaceInPrimalTapes
268 tape, lhsIdentifier, adjointVector->getVectorSize(), adjointVector->getAdjointVec(lhsIdentifier));
269 Gradient const lhsAdjoint{};
270 adjointVector->setLhsAdjoint(lhsIdentifier);
271#else
272 Gradient const lhsAdjoint = adjointVector[lhsIdentifier];
274 tape, lhsIdentifier, GradientTraits::dim<Gradient>(), GradientTraits::toArray(lhsAdjoint).data());
275 adjointVector[lhsIdentifier] = Gradient();
276#endif
278 primalVector[lhsIdentifier]);
279
// Put back the primal value that was saved in this statement entry before the reverse evaluation of the statement
// reads the primal vector.
280 primalVector[lhsIdentifier] = oldPrimalValues[curStatementPos];
281
282 StatementEvaluator::template callReverse<PrimalValueReuseTape>(
283 stmtEvalhandle[curStatementPos], primalVector, adjointVector, lhsAdjoint,
284 numberOfPassiveArguments[curStatementPos], curConstantPos, constantValues, curPassivePos, passiveValues,
285 curRhsIdentifiersPos, rhsIdentifiers);
286 }
287 }
288 }
289
// Body of internalResetPrimalValues(Position const& pos): reset the primal values to the given position.
// NOTE(review): the signature line (embedded line 291) is missing from this listing; name and parameter are taken
// from the member summary.
292 // Reset primals.
// Per-statement callback: write the old primal value stored in the statement entry back into the primal vector at
// the left-hand-side identifier. The passive argument count and the evaluation handle are not needed.
293 auto clearFunc = [this](Identifier* lhsIndex, Config::ArgumentSize* passiveArgs, Real* oldPrimal,
294 EvalHandle* evalHandle) {
295 CODI_UNUSED(passiveArgs, evalHandle);
296
297 this->primals[*lhsIndex] = *oldPrimal;
298 };
299
// Obtain the StatementData positions of the current tape position (start) and of pos (end).
300 using StmtPosition = typename StatementData::Position;
301 StmtPosition startStmt = this->llfByteData.template extractPosition<StmtPosition>(this->getPosition());
302 StmtPosition endStmt = this->llfByteData.template extractPosition<StmtPosition>(pos);
303
// Apply clearFunc to the statement entries between the current position and pos (reverse order, going by the name
// of forEachReverse).
304 this->statementData.forEachReverse(startStmt, endStmt, clearFunc);
305 }
306
// Add statement specific data to the data streams.
// The four values (left-hand-side identifier, number of passive arguments, old primal value, evaluation handle) are
// forwarded unchanged and in this order to Base::statementData.pushData; nothing is returned.
309 CODI_INLINE void pushStmtData(Identifier const& index, Config::ArgumentSize const& numberOfPassiveArguments,
310 Real const& oldPrimalValue, EvalHandle evalHandle) {
311 Base::statementData.pushData(index, numberOfPassiveArguments, oldPrimalValue, evalHandle);
312 }
313
314 public:
// Revert the primals to the state indicated by pos.
// NOTE(review): the single body statement (embedded line 317) is missing from this listing. It is presumably a call
// to internalResetPrimalValues(pos); confirm against the original header.
316 void revertPrimals(Position const& pos) {
318 }
319 };
320}
#define CODI_Unlikely
Declare unlikely evaluation of an execution path.
Definition: config.h:399
#define CODI_INLINE
See codi::Config::ForcedInlines.
Definition: config.h:457
#define CODI_VariableAdjointInterfaceInPrimalTapes
See codi::Config::VariableAdjointInterfaceInPrimalTapes.
Definition: config.h:269
#define ADJOINT_VECTOR_TYPE
See codi::Config::VariableAdjointInterfaceInPrimalTapes.
Definition: config.h:277
#define CODI_Likely
Declare likely evaluation of an execution path.
Definition: config.h:397
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition: macros.hpp:94
#define CODI_T(...)
Abbreviation for CODI_TEMPLATE.
Definition: macros.hpp:111
uint16_t LowLevelFunctionDataSize
Size store type for a low level function.
Definition: config.h:98
uint16_t LowLevelFunctionToken
Token type for low level functions in the tapes.
Definition: config.h:108
size_t constexpr StatementLowLevelFunctionTag
Statement tag for low level functions.
Definition: config.h:126
uint8_t ArgumentSize
Type for the number of arguments in statements.
Definition: config.h:117
std::array< AtomicTraits::RemoveAtomic< typename TraitsImplementation< Gradient >::Real >, TraitsImplementation< Gradient >::dim > toArray(Gradient const &gradient)
Converts the (possibly multi-component) gradient to an array of Reals.
Definition: gradientTraits.hpp:116
typename TraitsImplementation< Type >::PassiveReal PassiveReal
The original computation type, that was used in the application.
Definition: realTraits.hpp:117
CoDiPack - Code Differentiation Package.
Definition: codi.hpp:90
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition: macros.hpp:46
AdjointsManagement
Policies for management of the tape's internal adjoints.
Definition: tapeParameters.hpp:98
@ Automatic
Manage internal adjoints automatically, including locking, bounds checking, and resizing.
Data is stored chunk-wise in this DataInterface implementation. If a chunk runs out of space,...
Definition: chunkedData.hpp:64
Position getPosition() const
Current position of the tape.
Definition: commonTapeImplementation.hpp:629
LowLevelFunctionByteData llfByteData
Byte data for low level functions.
Definition: commonTapeImplementation.hpp:155
static void notifyStatementEvaluatePrimalListeners(Tape &tape, Identifier const &lhsIdentifier, Real const &lhsValue)
Invoke callbacks for StatementEvaluatePrimal events.
Definition: eventSystem.hpp:745
static void notifyStatementEvaluateListeners(Tape &tape, Identifier const &lhsIdentifier, size_t sizeLhsAdjoint, Real const *lhsAdjoint)
Invoke callbacks for StatementEvaluate events.
Definition: eventSystem.hpp:712
Indices enable the mapping of primal values to their adjoint counterparts.
Definition: indexManagerInterface.hpp:78
Implementation of VectorAccessInterface for adjoint and primal vectors.
Definition: primalAdjointVectorAccess.hpp:58
Base class for all standard Primal value tape implementations.
Definition: primalValueBaseTape.hpp:136
std::vector< Real > primals
Current state of primal values in the program.
Definition: primalValueBaseTape.hpp:189
std::vector< Gradient > adjoints
Evaluation vector for AD.
Definition: primalValueBaseTape.hpp:188
typename Base::Position Position
See TapeTypesInterface.
Definition: primalValueBaseTape.hpp:165
void clearAdjoints(AdjointsManagement adjointsManagement=AdjointsManagement::Automatic)
Clear all adjoint values, that is, set them to zero.
Definition: primalValueBaseTape.hpp:522
StatementData statementData
Data stream for statement specific data.
Definition: primalValueBaseTape.hpp:183
Final implementation for a primal value tape with a reuse index management.
Definition: primalValueReuseTape.hpp:64
void internalResetPrimalValues(Position const &pos)
Reset the primal values to the given position.
Definition: primalValueReuseTape.hpp:291
static void internalEvaluateReverse_EvalStatements(PrimalValueReuseTape &tape, Real *primalVector, Gradient *adjointVector, size_t &curLLFByteDataPos, size_t const &endLLFByteDataPos, char *dataPtr, size_t &curLLFInfoDataPos, size_t const &endLLFInfoDataPos, Config::LowLevelFunctionToken *const tokenPtr, Config::LowLevelFunctionDataSize *const dataSizePtr, size_t &curConstantPos, size_t const &endConstantPos, PassiveReal const *const constantValues, size_t &curPassivePos, size_t const &endPassivePos, Real const *const passiveValues, size_t &curRhsIdentifiersPos, size_t const &endRhsIdentifiersPos, Identifier const *const rhsIdentifiers, size_t &curStatementPos, size_t const &endStatementPos, Identifier const *const lhsIdentifiers, Config::ArgumentSize const *const numberOfPassiveArguments, Real const *const oldPrimalValues, EvalHandle const *const stmtEvalhandle)
Perform a reverse evaluation of the tape. Arguments are from the recursive eval methods of the DataIn...
Definition: primalValueReuseTape.hpp:225
friend Base
Allow the base class to call protected and private methods.
Definition: primalValueReuseTape.hpp:73
static void internalEvaluateForward_EvalStatements(PrimalValueReuseTape &tape, Real *primalVector, Gradient *adjointVector, size_t &curLLFByteDataPos, size_t const &endLLFByteDataPos, char *dataPtr, size_t &curLLFInfoDataPos, size_t const &endLLFInfoDataPos, Config::LowLevelFunctionToken *const tokenPtr, Config::LowLevelFunctionDataSize *const dataSizePtr, size_t &curConstantPos, size_t const &endConstantPos, PassiveReal const *const constantValues, size_t &curPassivePos, size_t const &endPassivePos, Real const *const passiveValues, size_t &curRhsIdentifiersPos, size_t const &endRhsIdentifiersPos, Identifier const *const rhsIdentifiers, size_t &curStatementPos, size_t const &endStatementPos, Identifier const *const lhsIdentifiers, Config::ArgumentSize const *const numberOfPassiveArguments, Real *const oldPrimalValues, EvalHandle const *const stmtEvalhandle)
Perform a forward evaluation of the tape. Arguments are from the recursive eval methods of the DataIn...
Definition: primalValueReuseTape.hpp:116
typename TapeTypes::Real Real
See TapeTypesInterface.
Definition: primalValueReuseTape.hpp:75
void pushStmtData(Identifier const &index, Config::ArgumentSize const &numberOfPassiveArguments, Real const &oldPrimalValue, EvalHandle evalHandle)
Add statement specific data to the data streams.
Definition: primalValueReuseTape.hpp:309
typename TapeTypes::Gradient Gradient
See TapeTypesInterface.
Definition: primalValueReuseTape.hpp:76
typename TapeTypes::StatementData StatementData
See PrimalValueTapeTypes.
Definition: primalValueReuseTape.hpp:83
static void internalEvaluatePrimal_EvalStatements(PrimalValueReuseTape &tape, Real *primalVector, size_t &curLLFByteDataPos, size_t const &endLLFByteDataPos, char *dataPtr, size_t &curLLFInfoDataPos, size_t const &endLLFInfoDataPos, Config::LowLevelFunctionToken *const tokenPtr, Config::LowLevelFunctionDataSize *const dataSizePtr, size_t &curConstantPos, size_t const &endConstantPos, PassiveReal const *const constantValues, size_t &curPassivePos, size_t const &endPassivePos, Real const *const passiveValues, size_t &curRhsIdentifiersPos, size_t const &endRhsIdentifiersPos, Identifier const *const rhsIdentifiers, size_t &curStatementPos, size_t const &endStatementPos, Identifier const *const lhsIdentifiers, Config::ArgumentSize const *const numberOfPassiveArguments, Real *const oldPrimalValues, EvalHandle const *const stmtEvalhandle)
Perform a primal evaluation of the tape. Arguments are from the recursive eval methods of the DataInt...
Definition: primalValueReuseTape.hpp:180
T_TapeTypes TapeTypes
See PrimalValueReuseTape.
Definition: primalValueReuseTape.hpp:70
void revertPrimals(Position const &pos)
Revert the primals to the state indicated by pos.
Definition: primalValueReuseTape.hpp:316
void clearAdjoints(Position const &start, Position const &end, AdjointsManagement adjointsManagement=AdjointsManagement::Automatic)
Clear all adjoints that would be set in a tape evaluation from start to end. It has to hold start >= ...
Definition: primalValueReuseTape.hpp:93
typename TapeTypes::Identifier Identifier
See TapeTypesInterface.
Definition: primalValueReuseTape.hpp:77
typename TapeTypes::StatementEvaluator StatementEvaluator
See PrimalValueTapeTypes.
Definition: primalValueReuseTape.hpp:79
typename Base::Position Position
See TapeTypesInterface.
Definition: primalValueReuseTape.hpp:81
RealTraits::PassiveReal< Real > PassiveReal
Basic computation type.
Definition: primalValueReuseTape.hpp:78
typename TapeTypes::EvalHandle EvalHandle
See PrimalValueTapeTypes.
Definition: primalValueReuseTape.hpp:80
PrimalValueReuseTape()
Constructor.
Definition: primalValueReuseTape.hpp:86
Type definitions for the primal value tapes.
Definition: primalValueBaseTape.hpp:77
Creation of handles for the evaluation of expressions in a context where the expression type is not a...
Definition: statementEvaluatorInterface.hpp:103