DiSMEC++
eigen_generic.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_EIGEN_GENERIC_H
7 #define DISMEC_EIGEN_GENERIC_H
8 
9 #include <variant>
10 #include "utils/type_helpers.h"
11 
12 namespace Eigen {
13  template<typename Derived>
14  struct EigenBase;
15 }
16 
/// Generates a stateless function object named `VISITOR`. Its call operator accepts any
/// `Eigen::BASE<Derived>` together with optional extra arguments, and returns the result of
/// calling member `CALL` on that object with the extra arguments perfectly forwarded.
/// The generated types are meant to be passed to `std::visit`; the macro is #undef'd
/// again further down in this header.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define EIGEN_VISITORS_IMPLEMENT_VISITOR(VISITOR, BASE, CALL) \
struct VISITOR { \
    template<class EigenDerived, class... Extra> \
    auto operator()(Eigen::BASE<EigenDerived> const& expr, Extra&&... extra) const { \
        return expr.CALL(std::forward<Extra>(extra)...); \
    } \
}
25 
// Visitor function objects generated by EIGEN_VISITORS_IMPLEMENT_VISITOR, used with
// std::visit: each call operator takes an Eigen::EigenBase<Derived> and returns that
// object's cols(), rows() or size() respectively.
27  EIGEN_VISITORS_IMPLEMENT_VISITOR(ColsVisitor, EigenBase, cols);
28  EIGEN_VISITORS_IMPLEMENT_VISITOR(RowsVisitor, EigenBase, rows);
29  EIGEN_VISITORS_IMPLEMENT_VISITOR(SizeVisitor, EigenBase, size);
30 }
31 
32 #undef EIGEN_VISITORS_IMPLEMENT_VISITOR
33 
34 namespace dismec::types {
/// Empty marker base class. Any type deriving from it is reported as a variant wrapper
/// by `is_variant_wrapper`, which tests `std::is_base_of_v<VarWrapBase, ...>`.
class VarWrapBase {
};
36 
// EigenVariantWrapper<Types...>: stores a std::variant<Types...> and exposes size(), rows()
// and cols() of whichever alternative is active, plus typed access to the stored value.
// NOTE(review): the class-head line is missing from this listing; presumably the class
// derives from VarWrapBase so that is_variant_wrapper recognises it — confirm.
37  template<class... Types>
39  public:
// Type of the stored variant.
40  using variant_t = std::variant<Types...>;
41 
// Constructs the stored variant from `source`, perfectly forwarded.
// NOTE(review): this constructor is unconstrained. For a non-const lvalue
// EigenVariantWrapper argument, T deduces to EigenVariantWrapper&, which is a better match
// than the implicit copy constructor, so it would try to build m_Variant from the wrapper
// itself. Consider constraining it if direct copies of this class are needed.
42  template<class T>
43  explicit EigenVariantWrapper(T&& source) : m_Variant(std::forward<T>(source)) {}
44 
// size(), rows() and cols() call the member of the same name on the active alternative,
// dispatched through std::visit. std::visit throws std::bad_variant_access if the variant
// is valueless by exception.
45  [[nodiscard]] auto size() const {
46  return std::visit(eigen_visitors::SizeVisitor{}, m_Variant);
47  }
48 
49  [[nodiscard]] auto rows() const {
50  return std::visit(eigen_visitors::RowsVisitor{}, m_Variant);
51  }
52 
53  [[nodiscard]] auto cols() const {
54  return std::visit(eigen_visitors::ColsVisitor{}, m_Variant);
55  }
56 
// Presumably the body of the non-const unpack_variant() overload; its signature line is
// missing from this listing.
58  return m_Variant;
59  }
60 
// Read-only access to the stored variant.
61  const variant_t& unpack_variant() const {
62  return m_Variant;
63  }
64 
// get<T>(): returns the active alternative as T. std::get throws
// std::bad_variant_access if T is not the active alternative.
65  template<class T>
66  T& get() {
67  return std::get<T>(m_Variant);
68  }
69 
70  template<class T>
71  const T& get() const {
72  return std::get<T>(m_Variant);
73  }
74 
75  protected:
77  };
78 
79  template<class T>
80  constexpr bool is_variant_wrapper = std::is_base_of_v<VarWrapBase, std::decay_t<T>>;
81 
82  // TODO figure out the correct return type specification here
83  template<class T>
84  decltype(auto) unpack_variant_wrapper(T&& source, std::enable_if_t<!is_variant_wrapper<T>, void*> dispatch = nullptr) {
85  return source;
86  }
87 
88  template<class T>
89  decltype(auto) unpack_variant_wrapper(T&& source, std::enable_if_t<is_variant_wrapper<T>, void*> dispatch = nullptr) {
90  return source.unpack_variant();
91  }
92 
93 
/// Variant-wrapper-aware counterpart of `std::visit`.
///
/// Each operand is first passed through `unpack_variant_wrapper`: wrapper objects are
/// replaced by their stored variant, and anything else is handed on unchanged. The visitor
/// and the operands are perfectly forwarded. The result is returned by value, as deduced
/// by `auto` from the `std::visit` call.
template<class Visitor, class... VariantLikes>
auto visit(Visitor&& visitor, VariantLikes&&... operands) {
    return std::visit(
        std::forward<Visitor>(visitor),
        unpack_variant_wrapper(std::forward<VariantLikes>(operands))...
    );
}
98 
// GenericMatrix<Dense, Sparse>: an EigenVariantWrapper<Dense, Sparse> holding either a dense
// or a sparse matrix, with named accessors for the two alternatives.
99  template<class Dense, class Sparse>
100  class GenericMatrix : public EigenVariantWrapper<Dense, Sparse> {
101  public:
// Inherits the base-class constructors.
// NOTE(review): the base_t alias line is missing from this listing; presumably it names
// EigenVariantWrapper<Dense, Sparse>.
103  using base_t::base_t;
104 
// dense() and sparse() return the stored matrix. std::get throws std::bad_variant_access
// if the other alternative is active.
105  [[nodiscard]] const Dense& dense() const {
106  return std::get<Dense>(this->m_Variant);
107  }
108 
109  [[nodiscard]] Dense& dense() {
110  return std::get<Dense>(this->m_Variant);
111  }
112 
113  [[nodiscard]] const Sparse& sparse() const {
114  return std::get<Sparse>(this->m_Variant);
115  }
116 
117  [[nodiscard]] Sparse& sparse() {
118  return std::get<Sparse>(this->m_Variant);
119  }
120 
// True iff the active alternative has index 1, i.e. the Sparse type (the second type
// argument of the underlying variant).
121  [[nodiscard]] bool is_sparse() const {
122  return this->m_Variant.index() == 1;
123  }
124  };
125 
126 
// RefVariant<Types...>: an EigenVariantWrapper whose alternatives are Eigen::Ref<Types>...,
// i.e. it stores Eigen::Ref objects rather than the matrix types themselves.
127  template<class... Types>
128  class RefVariant : public EigenVariantWrapper<Eigen::Ref<Types>...> {
129  public:
// Inherits the base-class constructors.
// NOTE(review): the base_t alias line is missing from this listing; presumably it names
// EigenVariantWrapper<Eigen::Ref<Types>...>.
131  using base_t::base_t;
132  };
133 
134 
// GenericVectorRef<T>: a RefVariant holding either an Eigen::Ref to a DenseVector<T> or an
// Eigen::Ref to a SparseVector<T>, constructed from one of those vectors.
135  template<class T>
136  class GenericVectorRef : public RefVariant<DenseVector<T>, SparseVector<T>> {
137  public:
// NOTE(review): the base_t alias line is missing from this listing; presumably it names
// RefVariant<DenseVector<T>, SparseVector<T>>.
139  using DenseRef = Eigen::Ref<DenseVector<T>>;
140  using SparseRef = Eigen::Ref<SparseVector<T>>;
141 
// Wrap an existing vector in the matching Ref alternative.
// NOTE(review): binding a const lvalue to Eigen::Ref<DenseVector<T>> or
// Eigen::Ref<SparseVector<T>> presumably requires those to be const-qualified types
// (type_helpers.h builds them via outer_const<T, ...>) — confirm.
142  explicit GenericVectorRef(const DenseVector<T>& m) : base_t(DenseRef(m)) {}
143  explicit GenericVectorRef(const SparseVector<T>& m) : base_t(SparseRef(m)) {}
144 
// dense() and sparse() return the stored Ref through get<>(), whose std::get throws
// std::bad_variant_access if the other alternative is active.
145  [[nodiscard]] const DenseRef& dense() const {
146  return this->template get<DenseRef>();
147  }
148 
149  [[nodiscard]] DenseRef& dense() {
150  return this->template get<DenseRef>();
151  }
152 
153  [[nodiscard]] const SparseRef& sparse() const {
154  return this->template get<SparseRef>();
155  }
156 
157  [[nodiscard]] SparseRef& sparse() {
158  return this->template get<SparseRef>();
159  }
160  };
161 
// GenericMatrixRef<T>: presumably holds an Eigen::Ref to a dense or sparse, row- or
// column-major matrix (see the four Ref aliases below).
// NOTE(review): the class head, the base_t alias and the constructors are missing from this
// listing. Per the generated index at the end of the page, it has constructors taking each of
// the four Ref aliases, plus constructors taking const references to DenseRowMajor<T>,
// DenseColMajor<T>, SparseRowMajor<T> and SparseColMajor<T>.
162  template<class T>
164  public:
166 
167  using DenseRowMajorRef = Eigen::Ref<DenseRowMajor<T>>;
168  using DenseColMajorRef = Eigen::Ref<DenseColMajor<T>>;
169  using SparseRowMajorRef = Eigen::Ref<SparseRowMajor<T>>;
170  using SparseColMajorRef = Eigen::Ref<SparseColMajor<T>>;
171 
176 
181  };
182 }
183 
184 #endif //DISMEC_EIGEN_GENERIC_H
const variant_t & unpack_variant() const
Definition: eigen_generic.h:61
std::variant< Types... > variant_t
Definition: eigen_generic.h:40
Eigen::Ref< SparseColMajor< T > > SparseColMajorRef
GenericMatrixRef(SparseRowMajorRef m)
Eigen::Ref< SparseRowMajor< T > > SparseRowMajorRef
GenericMatrixRef(const DenseRowMajor< T > &m)
Eigen::Ref< DenseColMajor< T > > DenseColMajorRef
GenericMatrixRef(const SparseRowMajor< T > &m)
GenericMatrixRef(const SparseColMajor< T > &m)
Eigen::Ref< DenseRowMajor< T > > DenseRowMajorRef
GenericMatrixRef(SparseColMajorRef m)
GenericMatrixRef(DenseColMajorRef m)
GenericMatrixRef(DenseRowMajorRef m)
GenericMatrixRef(const DenseColMajor< T > &m)
const Dense & dense() const
const Sparse & sparse() const
Eigen::Ref< DenseVector< T > > DenseRef
GenericVectorRef(const SparseVector< T > &m)
const SparseRef & sparse() const
const DenseRef & dense() const
Eigen::Ref< SparseVector< T > > SparseRef
GenericVectorRef(const DenseVector< T > &m)
EIGEN_VISITORS_IMPLEMENT_VISITOR(ColsVisitor, EigenBase, cols)
outer_const< T, sparse_col_major_h > SparseColMajor
Definition: type_helpers.h:52
outer_const< T, sparse_row_major_h > SparseRowMajor
Definition: type_helpers.h:49
outer_const< T, dense_row_major_h > DenseRowMajor
Definition: type_helpers.h:43
outer_const< T, dense_vector_h > DenseVector
Definition: type_helpers.h:37
outer_const< T, sparse_vector_h > SparseVector
Definition: type_helpers.h:40
outer_const< T, dense_col_major_h > DenseColMajor
Definition: type_helpers.h:46
auto visit(F &&f, Variants &&... variants)
Definition: eigen_generic.h:95
decltype(auto) unpack_variant_wrapper(T &&source, std::enable_if_t<!is_variant_wrapper< T >, void * > dispatch=nullptr)
Definition: eigen_generic.h:84
constexpr bool is_variant_wrapper
Definition: eigen_generic.h:80