CoDiPack  2.2.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
lhsExpressionInterface.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://www.scicomp.uni-kl.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://www.scicomp.uni-kl.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include "../config.h"
38#include "../misc/eventSystem.hpp"
39#include "../misc/macros.hpp"
40#include "../misc/toConst.hpp"
41#include "../tapes/interfaces/fullTapeInterface.hpp"
42#include "../traits/expressionTraits.hpp"
43#include "../traits/realTraits.hpp"
44#include "expressionInterface.hpp"
45
47namespace codi {
48
62 template<typename T_Real, typename T_Gradient, typename T_Tape, typename T_Impl>
63 struct LhsExpressionInterface : public ExpressionInterface<T_Real, T_Impl> {
64 public:
65
66 using Real = CODI_DD(T_Real, double);
67 using Gradient = CODI_DD(T_Gradient, Real);
68 using Tape = CODI_DD(T_Tape, CODI_DEFAULT_TAPE);
70
72
73 using Identifier = typename Tape::Identifier;
75
78
79 /*******************************************************************************/
82
83 Real const& value() const;
85
86 Identifier const& getIdentifier() const;
90
91 static Tape& getTape();
92
94 /*******************************************************************************/
97
100 return static_cast<Impl&>(*this);
101 }
102 using Base::cast;
103
106 return Impl::getTape().gradient(cast().getIdentifier());
107 }
108
111 return toConst(Impl::getTape()).gradient(cast().getIdentifier());
112 }
113
116 return cast().gradient();
117 }
118
121 cast().gradient() = g;
122 }
123
125 CODI_INLINE Real const& getValue() const {
126 return cast().value();
127 }
128
130 CODI_INLINE void setValue(Real const& v) {
131 cast().value() = v;
132 }
133
137 rhs, EventHints::Statement::Passive);
138 Impl::getTape().store(cast(), rhs);
139 return cast();
140 }
143 template<typename U = Real, typename = RealTraits::EnableIfNotPassiveReal<U>>
146 rhs, EventHints::Statement::Passive);
147 Impl::getTape().store(cast(), rhs);
148 return cast();
149 }
150
152 template<typename Rhs>
155 rhs.cast().getValue(), EventHints::Statement::Expression);
156 Impl::getTape().store(cast(), rhs.cast());
157 return cast();
158 }
159
161 template<typename Rhs, typename U = Real, typename = RealTraits::EnableIfNotPassiveReal<U>>
164 rhs.cast().getValue(), EventHints::Statement::Passive);
165 Impl::getTape().store(cast(), Real(rhs));
166 return cast();
167 }
168
172 rhs.cast().getValue(), EventHints::Statement::Copy);
173 Impl::getTape().store(cast(), rhs);
174 return cast();
175 }
176
178 template<typename Rhs>
181 rhs.cast().getValue(), EventHints::Statement::Copy);
182 Impl::getTape().store(cast(), rhs);
183 return cast();
184 }
185
187 /*******************************************************************************/
190
191 static bool constexpr EndPoint = true;
192
194 template<typename Logic, typename... Args>
195 CODI_INLINE void forEachLink(TraversalLogic<Logic>& logic, Args&&... args) const {
196 CODI_UNUSED(logic, args...);
197 }
198
200 template<typename Logic, typename... Args>
201 CODI_INLINE static typename Logic::ResultType constexpr forEachLinkConstExpr(Args&&... CODI_UNUSED_ARG(args)) {
202 return Logic::NeutralElement;
203 }
204
205 protected:
206
210 CODI_INLINE void init(Real const& newValue, EventHints::Statement statementType) {
211 Impl::getTape().initIdentifier(cast().value(), cast().getIdentifier());
213 statementType);
214 }
215
220 Impl::getTape().destroyIdentifier(cast().value(), cast().getIdentifier());
221 }
222
224 };
225
227 template<typename Expr>
229 typename Expr::Real temp;
230
231 stream >> temp;
232 v.setValue(temp);
233
234 return stream;
235 }
236
237#ifndef DOXYGEN_DISABLE
238
240 template<typename T_Type>
241 struct RealTraits::DataExtraction<T_Type, ExpressionTraits::EnableIfLhsExpression<T_Type>> {
242 public:
243 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
244
245 using Real = typename Type::Real;
246 using Identifier = typename Type::Identifier;
247
249 CODI_INLINE static Real getValue(Type const& v) {
250 return v.getValue();
251 }
252
254 CODI_INLINE static Identifier getIdentifier(Type const& v) {
255 return v.getIdentifier();
256 }
257
259 CODI_INLINE static void setValue(Type& v, Real const& value) {
260 v.setValue(value);
261 }
262 };
263
265 template<typename T_Type>
266 struct RealTraits::TapeRegistration<T_Type, ExpressionTraits::EnableIfLhsExpression<T_Type>> {
267 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
268
269 using Real = typename DataExtraction<Type>::Real;
270
272 CODI_INLINE static void registerInput(Type& v) {
273 Type::getTape().registerInput(v);
274 }
275
277 CODI_INLINE static void registerOutput(Type& v) {
278 Type::getTape().registerOutput(v);
279 }
280
283 return Type::getTape().registerExternalFunctionOutput(v);
284 }
285 };
286#endif
287}
#define CODI_INLINE
See codi::Config::ForcedInlines.
Definition: config.h:457
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition: macros.hpp:94
#define CODI_UNUSED_ARG(arg)
Used in a constexpr context, where using CODI_UNUSED spoils the constexpr.
Definition: macros.hpp:49
Statement
Classify statements.
Definition: eventSystem.hpp:65
typename std::enable_if< IsLhsExpression< Expr >::value, T >::type EnableIfLhsExpression
Enable if wrapper for IsLhsExpression.
Definition: expressionTraits.hpp:137
typename TraitsImplementation< Type >::PassiveReal PassiveReal
The original computation type that was used in the application.
Definition: realTraits.hpp:117
CoDiPack - Code Differentiation Package.
Definition: codi.hpp:90
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition: macros.hpp:46
ExpressionTraits::EnableIfLhsExpression< Expr, std::istream > & operator>>(std::istream &stream, Expr &v)
Read the primal value from a stream.
Definition: lhsExpressionInterface.hpp:228
T const toConst(T &&v)
Constant cast function that works with CUDA.
Definition: toConst.hpp:47
Represents a concrete lvalue in the CoDiPack expression tree.
Definition: activeType.hpp:52
static void notifyStatementPrimalListeners(Tape &tape, Real const &lhsValue, Identifier const &lhsIdentifier, Real const &newValue, EventHints::Statement statement)
Invoke callbacks for StatementPrimal events.
Definition: eventSystem.hpp:264
Base class for all CoDiPack expressions.
Definition: expressionInterface.hpp:59
Impl const & cast() const
Cast to the implementation.
Definition: expressionInterface.hpp:75
Base class for all CoDiPack lvalue expressions.
Definition: lhsExpressionInterface.hpp:63
static Logic::ResultType constexpr forEachLinkConstExpr(Args &&...)
Definition: lhsExpressionInterface.hpp:201
Real & value()
Get a reference to the lvalue represented by the expression.
LhsExpressionInterface(LhsExpressionInterface const &other)=default
Constructor.
Real const & value() const
Get a constant reference to the lvalue represented by the expression.
T_Gradient Gradient
See LhsExpressionInterface.
Definition: lhsExpressionInterface.hpp:67
Impl & operator=(LhsExpressionInterface< Real, Gradient, Tape, Rhs > const &rhs)
Assignment operator for lhs expressions. Calls store on the InternalStatementRecordingTapeInterface.
Definition: lhsExpressionInterface.hpp:179
Impl & operator=(ExpressionInterface< Real, Rhs > const &rhs)
Assignment operator for expressions. Calls store on the InternalStatementRecordingTapeInterface.
Definition: lhsExpressionInterface.hpp:153
RealTraits::PassiveReal< Real > PassiveReal
Basic computation type.
Definition: lhsExpressionInterface.hpp:74
T_Tape Tape
See LhsExpressionInterface.
Definition: lhsExpressionInterface.hpp:68
void setValue(Real const &v)
Set the primal value of this lvalue.
Definition: lhsExpressionInterface.hpp:130
typename Tape::Identifier Identifier
See GradientAccessTapeInterface.
Definition: lhsExpressionInterface.hpp:73
T_Real Real
See LhsExpressionInterface.
Definition: lhsExpressionInterface.hpp:66
void setGradient(Gradient const &g)
Set the gradient of this lvalue in the tape.
Definition: lhsExpressionInterface.hpp:120
Gradient const & gradient() const
Get the gradient of this lvalue from the tape.
Definition: lhsExpressionInterface.hpp:110
Gradient & gradient()
Get the gradient of this lvalue from the tape.
Definition: lhsExpressionInterface.hpp:105
Impl & operator=(PassiveReal const &rhs)
Assignment operator for passive values. Calls store on the InternalStatementRecordingTapeInterface.
Definition: lhsExpressionInterface.hpp:144
static bool constexpr EndPoint
If this expression is handled as a leaf in the tree.
Definition: lhsExpressionInterface.hpp:191
LhsExpressionInterface()=default
Constructor.
Impl & operator=(LhsExpressionInterface const &rhs)
Assignment operator for lhs expressions. Calls store on the InternalStatementRecordingTapeInterface.
Definition: lhsExpressionInterface.hpp:170
T_Impl Impl
See LhsExpressionInterface.
Definition: lhsExpressionInterface.hpp:69
Real const & getValue() const
Get the primal value of this lvalue.
Definition: lhsExpressionInterface.hpp:125
Gradient getGradient() const
Get the gradient of this lvalue from the tape.
Definition: lhsExpressionInterface.hpp:115
void destroy()
Definition: lhsExpressionInterface.hpp:219
Identifier const & getIdentifier() const
void forEachLink(TraversalLogic< Logic > &logic, Args &&... args) const
Definition: lhsExpressionInterface.hpp:195
static Tape & getTape()
Get a reference to the tape which manages this expression.
Impl & cast()
Cast to the implementation.
Definition: lhsExpressionInterface.hpp:99
Impl & operator=(Real const &rhs)
Assignment operator for passive values. Calls store on the InternalStatementRecordingTapeInterface.
Definition: lhsExpressionInterface.hpp:135
void init(Real const &newValue, EventHints::Statement statementType)
Definition: lhsExpressionInterface.hpp:210
Impl & operator=(ExpressionInterface< typename U::Real, Rhs > const &rhs)
Assignment operator for expressions. Calls store on the InternalStatementRecordingTapeInterface.
Definition: lhsExpressionInterface.hpp:162
typename Type::Real Real
Type of primal values extracted from the type with AD values.
Definition: realTraits.hpp:165
static Real getValue(Type const &v)
Extract the primal values from a type of aggregated active types.
static void setValue(Type &v, Real const &value)
Set the primal values of a type of aggregated active types.
T_Type Type
See DataExtraction.
Definition: realTraits.hpp:163
static Identifier getIdentifier(Type const &v)
Extract the identifiers from a type of aggregated active types.
static Real registerExternalFunctionOutput(Type &v)
Register all active types of an aggregated type as external function outputs.
static void registerOutput(Type &v)
Register all active types of an aggregated type as tape outputs.
static void registerInput(Type &v)
Register all active types of an aggregated type as tape inputs.
T_Type Type
See TapeRegistration.
Definition: realTraits.hpp:194
Traversal of CoDiPack expressions.
Definition: traversalLogic.hpp:56