[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

autodiff.hxx
1/************************************************************************/
2/* */
3/* Copyright 2012-2013 by Ullrich Koethe */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35
36
37#ifndef VIGRA_AUTODIFF_HXX
38#define VIGRA_AUTODIFF_HXX
39
40#include "tinyvector.hxx"
41#include "mathutil.hxx"
42#include <cmath>
43
44namespace vigra {
45
46namespace autodiff {
47
48/** Number type for automatic differentiation.
49
50 <a href="http://en.wikipedia.org/wiki/Automatic_differentiation">Automatic differentiation</a>
51 allows one to compute the value of a numeric expression and its gradient
52 with respect to the expression's arguments automatically and in one go.
53 To support this, one needs a special number type that holds a scalar value
54 and the corresponding gradient vector of appropriate length. This is the
55 purpose of hte template class <tt>DualVector<T, N></tt>, where <tt>T</tt> is the
56 underlying numerical type (usually 'double'), and <tt>N</tt> denotes the
57 length of the gradient vector.
58
59 The standard arithmetic and algebraic functions are overloaded for
60 <tt>DualVector</tt> in order to implement the required arithmetic of
61 dual numbers. When you replace all arguments in a numeric expression
62 with the appropriate <tt>DualVector</tt> instances, the result will be
63 a <tt>DualVector</tt> that contains the result value and gradient of
64 the expression, evaluated at the point defined by the input values.
65
66 <b> Usage:</b>
67
68 <b>\#include</b> <vigra/autodiff.hxx><br>
69 Namespace: vigra::autodiff
70
71 \code
72 typedef DualVector<double, 2> N; // for expressions with two arguments
73
74 N x(1.0, 0); // first argument of the expression
75 N s(2.0, 1); // second argument of the expression
76
77 N y = exp(-0.5 * sq(x / s));
78
79 std::cout << "Evaluated exp(- x^2 / (2 s^2)) at x=1 and s = 2:\n";
80 std::cout << "result = " << y.value() <<", gradient = " << y.gradient() << "\n";
81 \endcode
82 Note that the second argument of the <tt>DualVector</tt> constructors specifies that
83 the derivative w.r.t 'x' shall be the element 0 of the gradient vector, and the
84 derivative w.r.t. 's' shall be element 1.
85*/
86template <class T, int N>
87class DualVector
88{
89 public:
90 typedef T value_type; ///< type of function values and gradient elements
91 typedef TinyVector<T, N> Gradient; ///< type of the gradient vector
92
93 T v;
94 Gradient d;
95
96 /** Zero initialization.
97 */
99 : v(), d()
100 {}
101
102 /** Provide a value, but zero-initialize the gradient.
103 */
104 explicit DualVector(T const & val)
105 : v(val), d()
106 {}
107
108 /** Initialize with given value and gradient.
109 */
110 DualVector(T const & val, Gradient const & grad)
111 : v(val), d(grad)
112 {}
113
114 /** Shorthand for <tt>DualVector(val, Gradient(g0))</tt> when <tt>N == 1</tt>.
115
116 Not to be used when <tt>N != 1</tt>.
117 */
118 DualVector(T const & val, T const & g0)
119 : v(val), d(g0)
120 {}
121
122 /** Shorthand for <tt>DualVector(val, Gradient(g0, g1))</tt> when <tt>N == 2</tt>.
123
124 Not to be used when <tt>N != 2</tt>.
125 */
126 DualVector(T const & val, T const & g0, T const & g1)
127 : v(val), d(g0, g1)
128 {}
129
130 /** Initialize value to represent the argument number 'targetElement' in an
131 expression.
132
133 The derivative of the expression w.r.t. this variable will be element 'targetElement'
134 of the resulting gradient vector.
135 */
136 DualVector(T const & val, int targetElement)
137 : v(val), d()
138 {
139 d[targetElement] = T(1.0);
140 }
141
142 /** Get current value.
143 */
144 T value() const
145 {
146 return v;
147 }
148
149 /** Get current gradient.
150 */
151 Gradient const & gradient() const
152 {
153 return d;
154 }
155
156 DualVector operator+() const
157 {
158 return *this;
159 }
160
161 DualVector operator-() const
162 {
163 return DualVector(-v, -d);
164 }
165
166 DualVector & operator+=(DualVector const & o)
167 {
168 d += o.d;
169 v += o.v;
170 return *this;
171 }
172
173 DualVector & operator+=(T const & o)
174 {
175 v += o;
176 return *this;
177 }
178
179 DualVector & operator-=(DualVector const & o)
180 {
181 d -= o.d;
182 v -= o.v;
183 return *this;
184 }
185
186 DualVector & operator-=(T const & o)
187 {
188 v -= o;
189 return *this;
190 }
191
192 DualVector & operator*=(DualVector const & o)
193 {
194 d = o.v * d + v * o.d;
195 v *= o.v;
196 return *this;
197 }
198
199 DualVector & operator*=(T const & o)
200 {
201 d *= o;
202 v *= o;
203 return *this;
204 }
205
206 DualVector & operator/=(DualVector const & o)
207 {
208 d = (o.v * d - v * o.d) / sq(o.v);
209 v /= o.v;
210 return *this;
211 }
212
213 DualVector & operator/=(T const & o)
214 {
215 d /= o;
216 v /= o;
217 return *this;
218 }
219};
220
221 /** Given a vector 'v' of expression arguments, create the corresponding
222 vector of dual numbers for automatic differentiation.
223 */
224template <class T, int N>
225TinyVector<DualVector<T, N>, N>
226dualMatrix(TinyVector<T, N> const & v)
227{
228 TinyVector<DualVector<T, N>, N> res;
229 for(int k=0; k<N; ++k)
230 {
231 res[k].v = v[k];
232 res[k].d[k] = T(1.0);
233 }
234 return res;
235}
236
237template <class T, int N>
238inline DualVector<T, N> operator+(DualVector<T, N> v1, DualVector<T, N> const & v2)
239{
240 return v1 += v2;
241}
242
243template <class T, int N>
244inline DualVector<T, N> operator+(DualVector<T, N> v1, T v2)
245{
246 return v1 += v2;
247}
248
249template <class T, int N>
250inline DualVector<T, N> operator+(T v1, DualVector<T, N> v2)
251{
252 return v2 += v1;
253}
254
255template <class T, int N>
256inline DualVector<T, N> operator-(DualVector<T, N> v1, DualVector<T, N> const & v2)
257{
258 return v1 -= v2;
259}
260
261template <class T, int N>
262inline DualVector<T, N> operator-(DualVector<T, N> v1, T v2)
263{
264 return v1 -= v2;
265}
266
267template <class T, int N>
268inline DualVector<T, N> operator-(T v1, DualVector<T, N> const & v2)
269{
270 return DualVector<T, N>(v1 - v2.v, -v2.d);
271}
272
273template <class T, int N>
274inline DualVector<T, N> operator*(DualVector<T, N> v1, DualVector<T, N> const & v2)
275{
276 return v1 *= v2;
277}
278
279template <class T, int N>
280inline DualVector<T, N> operator*(DualVector<T, N> v1, T v2)
281{
282 return v1 *= v2;
283}
284
285template <class T, int N>
286inline DualVector<T, N> operator*(T v1, DualVector<T, N> v2)
287{
288 return v2 *= v1;
289}
290
291template <class T, int N>
292inline DualVector<T, N> operator/(DualVector<T, N> v1, DualVector<T, N> const & v2)
293{
294 return v1 /= v2;
295}
296
297template <class T, int N>
298inline DualVector<T, N> operator/(DualVector<T, N> v1, T v2)
299{
300 return v1 /= v2;
301}
302
303template <class T, int N>
304inline DualVector<T, N> operator/(T v1, DualVector<T, N> const & v2)
305{
306 return DualVector<T, N>(v1 / v2.v, -v1*v2.d / sq(v2.v));
307}
308
309using vigra::abs;
310// abs(x + h) => x + h or -(x + h)
311template <typename T, int N>
312inline DualVector<T, N> abs(DualVector<T, N> const & v)
313{
314 return v.v < T(0.0) ? -v : v;
315}
316
317using std::fabs;
318// abs(x + h) => x + h or -(x + h)
319template <typename T, int N>
320inline DualVector<T, N> fabs(DualVector<T, N> const & v)
321{
322 return v.v < T(0.0) ? -v : v;
323}
324
325using std::log;
326// log(a + h) => log(a) + h / a
327template <typename T, int N>
328inline DualVector<T, N> log(DualVector<T, N> v)
329{
330 v.d /= v.v;
331 v.v = log(v.v);
332 return v;
333}
334
335using std::exp;
336// exp(a + h) => exp(a) + exp(a) h
337template <class T, int N>
338inline DualVector<T, N> exp(DualVector<T, N> v)
339{
340 v.v = exp(v.v);
341 v.d *= v.v;
342 return v;
343}
344
345using vigra::sqrt;
346// sqrt(a + h) => sqrt(a) + h / (2 sqrt(a))
347template <typename T, int N>
348inline DualVector<T, N> sqrt(DualVector<T, N> v)
349{
350 v.v = sqrt(v.v);
351 v.d /= T(2.0) * v.v;
352 return v;
353}
354
355using std::sin;
356using std::cos;
357// sin(a + h) => sin(a) + cos(a) h
358template <typename T, int N>
359inline DualVector<T, N> sin(DualVector<T, N> v)
360{
361 v.d *= cos(v.v);
362 v.v = sin(v.v);
363 return v;
364}
365
366// cos(a + h) => cos(a) - sin(a) h
367template <typename T, int N>
368inline DualVector<T, N> cos(DualVector<T, N> v)
369{
370 v.d *= -sin(v.v);
371 v.v = cos(v.v);
372 return v;
373}
374
375using vigra::sin_pi;
376using vigra::cos_pi;
377// sin_pi(a + h) => sin_pi(a) + pi cos_pi(a) h
378template <typename T, int N>
379inline DualVector<T, N> sin_pi(DualVector<T, N> v)
380{
381 v.d *= M_PI*cos_pi(v.v);
382 v.v = sin_pi(v.v);
383 return v;
384}
385
386// cos_pi(a + h) => cos_pi(a) - pi sin_pi(a) h
387template <typename T, int N>
388inline DualVector<T, N> cos_pi(DualVector<T, N> v)
389{
390 v.d *= -M_PI*sin_pi(v.v);
391 v.v = cos_pi(v.v);
392 return v;
393}
394
395using std::asin;
396// asin(a + h) => asin(a) + 1 / sqrt(1 - a^2) h
397template <typename T, int N>
398inline DualVector<T, N> asin(DualVector<T, N> v)
399{
400 v.d /= sqrt(T(1.0) - sq(v.v));
401 v.v = asin(v.v);
402 return v;
403}
404
405using std::acos;
406// acos(a + h) => acos(a) - 1 / sqrt(1 - a^2) h
407template <typename T, int N>
408inline DualVector<T, N> acos(DualVector<T, N> v)
409{
410 v.d /= -sqrt(T(1.0) - sq(v.v));
411 v.v = acos(v.v);
412 return v;
413}
414
415using std::tan;
416// tan(a + h) => tan(a) + (1 + tan(a)^2) h
417template <typename T, int N>
418inline DualVector<T, N> tan(DualVector<T, N> v)
419{
420 v.v = tan(v.v);
421 v.d *= T(1.0) + sq(v.v);
422 return v;
423}
424
425using std::atan;
426// atan(a + h) => atan(a) + 1 / (1 + a^2) h
427template <typename T, int N>
428inline DualVector<T, N> atan(DualVector<T, N> v)
429{
430 v.d /= T(1.0) + sq(v.v);
431 v.v = atan(v.v);
432 return v;
433}
434
435using std::sinh;
436using std::cosh;
437// sinh(a + h) => sinh(a) + cosh(a) h
438template <typename T, int N>
439inline DualVector<T, N> sinh(DualVector<T, N> v)
440{
441 v.d *= cosh(v.v);
442 v.v = sinh(v.v);
443 return v;
444}
445
446// cosh(a + h) => cosh(a) + sinh(a) h
447template <typename T, int N>
448inline DualVector<T, N> cosh(DualVector<T, N> v)
449{
450 v.d *= sinh(v.v);
451 v.v = cosh(v.v);
452 return v;
453}
454
455using std::tanh;
456// tanh(a + h) => tanh(a) + (1 - tanh(a)^2) h
457template <typename T, int N>
458inline DualVector<T, N> tanh(DualVector<T, N> v)
459{
460 v.v = tanh(v.v);
461 v.d *= T(1.0) - sq(v.v);
462 return v;
463}
464
465using vigra::sq;
466// (a + h)^2 => a^2 + 2 a h
467template <class T, int N>
468inline DualVector<T, N> sq(DualVector<T, N> v)
469{
470 v.d *= T(2.0)*v.v;
471 v.v *= v.v;
472 return v;
473}
474
475using std::atan2;
476// atan2(b + db, a + da) => atan2(b, a) + (- b da + a db) / (a^2 + b^2)
477template <typename T, int N>
478inline DualVector<T, N> atan2(DualVector<T, N> v1, DualVector<T, N> const & v2)
479{
480 v1.d = (v2.v * v1.d - v1.v * v2.d) / (sq(v1.v) + sq(v2.v));
481 v1.v = atan2(v1.v, v2.v);
482 return v1;
483}
484
485
486using vigra::pow;
487// (a+da)^p => a^p + p*a^(p-1) da
488template <typename T, int N>
489inline DualVector<T, N> pow(DualVector<T, N> v, T p)
490{
491 T pow_p_1 = pow(v.v, p-T(1.0));
492 v.d *= p * pow_p_1;
493 v.v *= pow_p_1;
494 return v;
495}
496
497// (a)^(p+dp) => a^p + a^p log(a) dp
498template <typename T, int N>
499inline DualVector<T, N> pow(T v, DualVector<T, N> p)
500{
501 p.v = pow(v, p.v);
502 p.d *= p.v * log(v);
503 return p;
504}
505
506
507// (a+da)^(b+db) => a^b + b * a^(b-1) da + a^b log(a) * db
508template <typename T, int N>
509inline DualVector<T, N> pow(DualVector<T, N> v, DualVector<T, N> const & p)
510{
511 T pow_p_1 = pow(v.v, p.v-T(1.0)),
512 pow_p = v.v * pow_p_1;
513 v.d = p.v * pow_p_1 * v.d + pow_p * log(v.v) * p.d;
514 v.v = pow_p;
515 return v;
516}
517
518using vigra::min;
519template <class T, int N>
520inline DualVector<T, N> min(DualVector<T, N> const & v1, DualVector<T, N> const & v2)
521{
522 return v1.v < v2.v
523 ? v1
524 : v2;
525}
526
527template <class T, int N>
528inline DualVector<T, N> min(T v1, DualVector<T, N> const & v2)
529{
530 return v1 < v2.v
531 ? DualVector<T, N>(v1)
532 : v2;
533}
534
535template <class T, int N>
536inline DualVector<T, N> min(DualVector<T, N> const & v1, T v2)
537{
538 return v1.v < v2
539 ? v1
540 : DualVector<T, N>(v2);
541}
542
543using vigra::max;
544template <class T, int N>
545inline DualVector<T, N> max(DualVector<T, N> const & v1, DualVector<T, N> const & v2)
546{
547 return v1.v > v2.v
548 ? v1
549 : v2;
550}
551
552template <class T, int N>
553inline DualVector<T, N> max(T v1, DualVector<T, N> const & v2)
554{
555 return v1 > v2.v
556 ? DualVector<T, N>(v1)
557 : v2;
558}
559
560template <class T, int N>
561inline DualVector<T, N> max(DualVector<T, N> const & v1, T v2)
562{
563 return v1.v > v2
564 ? v1
565 : DualVector<T, N>(v2);
566}
567
568template <class T, int N>
569inline bool
570operator==(DualVector<T, N> const & v1, DualVector<T, N> const & v2)
571{
572 return v1.v == v2.v && v1.d == v2.d;
573}
574
575template <class T, int N>
576inline bool
577operator!=(DualVector<T, N> const & v1, DualVector<T, N> const & v2)
578{
579 return v1.v != v2.v || v1.d != v2.d;
580}
581
582#define VIGRA_DUALVECTOR_RELATIONAL_OPERATORS(op) \
583template <class T, int N> \
584inline bool \
585operator op(DualVector<T, N> const & v1, DualVector<T, N> const & v2) \
586{ \
587 return v1.v op v2.v; \
588} \
589 \
590template <class T, int N> \
591inline bool \
592operator op(T v1, DualVector<T, N> const & v2) \
593{ \
594 return v1 op v2.v; \
595} \
596 \
597template <class T, int N> \
598inline bool \
599operator op(DualVector<T, N> const & v1, T v2) \
600{ \
601 return v1.v op v2; \
602}
603
604VIGRA_DUALVECTOR_RELATIONAL_OPERATORS(<)
605VIGRA_DUALVECTOR_RELATIONAL_OPERATORS(<=)
606VIGRA_DUALVECTOR_RELATIONAL_OPERATORS(>)
607VIGRA_DUALVECTOR_RELATIONAL_OPERATORS(>=)
608
609#undef VIGRA_DUALVECTOR_RELATIONAL_OPERATORS
610
611template <class T, int N>
612inline bool
613closeAtTolerance(DualVector<T, N> const & v1, DualVector<T, N> const & v2,
614 T epsilon = NumericTraits<T>::epsilon())
615{
616 return vigra::closeAtTolerance(v1.v, v2.v, epsilon) && vigra::closeAtTolerance(v1.d, v2.d, epsilon);
617}
618
619} // namespace autodiff
620
621} // namespace vigra
622
623namespace std {
624
625 /// stream output
626template <class T, int N>
627ostream &
628operator<<(ostream & out, vigra::autodiff::DualVector<T, N> const & l)
629{
630 out << l.v << " " << l.d;
631 return out;
632}
633
634} // namespace std
635
636#endif // VIGRA_AUTODIFF_HXX
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Definition splines.hxx:50
TinyVector< T, N > Gradient
type of the gradient vector
Definition autodiff.hxx:91
DualVector(T const &val, T const &g0, T const &g1)
Definition autodiff.hxx:126
T value_type
type of function values and gradient elements
Definition autodiff.hxx:90
DualVector(T const &val, int targetElement)
Definition autodiff.hxx:136
Gradient const & gradient() const
Definition autodiff.hxx:151
DualVector(T const &val)
Definition autodiff.hxx:104
DualVector(T const &val, T const &g0)
Definition autodiff.hxx:118
DualVector(T const &val, Gradient const &grad)
Definition autodiff.hxx:110
DualVector()
Definition autodiff.hxx:98
T value() const
Definition autodiff.hxx:144
REAL cos_pi(REAL x)
cos(pi*x).
Definition mathutil.hxx:1242
REAL sin_pi(REAL x)
sin(pi*x).
Definition mathutil.hxx:1204
NumericTraits< T >::Promote sq(T t)
The square function.
Definition mathutil.hxx:382
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality.
Definition mathutil.hxx:1638

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.2