DiSMEC++
binding.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_SRC_UTILS_BINDING_H
7 #define DISMEC_SRC_UTILS_BINDING_H
8 
#include "pybind11/pybind11.h"
#include "pybind11/eigen.h"
#include "pybind11/stl.h"
#include "pybind11/stl/filesystem.h"
#include "fwd.h"

#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
16 
17 // Utilities for pybind11 binding
/**
 * \brief Utility class used to wrap all objects we provide to python.
 * \details The wrapped object is held through a `std::shared_ptr<T>`, so copies of a
 * `PyWrapper` refer to the same underlying object. Accessing the object through
 * `access()` or `operator->()` throws a `std::runtime_error` if no object is held.
 * \tparam T Static type through which the object is accessed. The object actually stored
 * may be of a type whose lvalues convert to `T&` (e.g. a class derived from `T`).
 */
template<class T>
class PyWrapper {
public:
    /**
     * \brief Creates a wrapper owning a new object constructed from `source`.
     * \details The new object has the decayed type of `source`: rvalues are moved into it,
     * lvalues are copied. (Previously, passing an lvalue instantiated
     * `std::make_shared<X&>`, which does not compile.) The overload only participates
     * if an lvalue of that decayed type converts to `T&`.
     * \param source Value used to construct the wrapped object.
     */
    template<class U>
    PyWrapper(U&& source, std::enable_if_t<std::is_convertible_v<std::decay_t<U>&, T&>>* = nullptr) :
        m_Data(std::make_shared<std::decay_t<U>>(std::forward<U>(source))) {
    }

    /**
     * \brief Creates a wrapper sharing ownership of an existing object.
     * \param d Pointer to the object. May be empty, in which case every later access throws.
     */
    PyWrapper(std::shared_ptr<T> d) : m_Data(std::move(d)) {}

    /// Member access to the wrapped object. Throws `std::runtime_error` if empty.
    T* operator->() {
        return &access();
    }

    /// Const member access to the wrapped object. Throws `std::runtime_error` if empty.
    const T* operator->() const {
        return &access();
    }

    /**
     * \brief Returns a reference to the wrapped object.
     * \throws std::runtime_error if the wrapper holds no object.
     */
    T& access() {
        if(m_Data) {
            return *m_Data;
        }
        throw std::runtime_error("Trying to access empty object");
    }

    /**
     * \brief Returns a const reference to the wrapped object.
     * \throws std::runtime_error if the wrapper holds no object.
     */
    [[nodiscard]] const T& access() const {
        if(m_Data) {
            return *m_Data;
        }
        throw std::runtime_error("Trying to access empty object");
    }

    /// Read access to the underlying shared pointer (may be empty).
    [[nodiscard]] const std::shared_ptr<T>& ptr() const {
        return m_Data;
    }

    /// Mutable access to the underlying shared pointer (may be empty).
    [[nodiscard]] std::shared_ptr<T>& ptr() {
        return m_Data;
    }
private:
    std::shared_ptr<T> m_Data;  ///< Owning handle to the wrapped object; may be empty.
};
72 
/**
 * \brief Moves an rvalue into a newly allocated `std::shared_ptr`.
 * \details Lvalue arguments are rejected at overload resolution, so the argument is moved
 * (not copied) into the shared object. This also makes the function usable with
 * move-only types such as `std::unique_ptr`.
 * \param source The value to take over; afterwards it is typically in a moved-from state.
 * \return A shared pointer owning the newly constructed object.
 */
template<class T, class = std::enable_if_t<!std::is_lvalue_reference_v<T>>>
std::shared_ptr<T> wrap_shared(T&& source) {
    // `source` is a named rvalue reference and therefore an lvalue expression;
    // without std::move the object would be copied.
    return std::make_shared<T>(std::move(source));
}
78 
79 namespace py = pybind11;
80 
81 using PyDataSet = std::shared_ptr<dismec::DatasetBase>;
82 using PyWeighting = std::shared_ptr<dismec::WeightingScheme>;
84 
85 #endif //DISMEC_SRC_UTILS_BINDING_H
std::shared_ptr< T > wrap_shared(T &&source)
Definition: binding.h:75
std::shared_ptr< dismec::DatasetBase > PyDataSet
Definition: binding.h:81
std::shared_ptr< dismec::WeightingScheme > PyWeighting
Definition: binding.h:82
Utility class used to wrap all objects we provide to python.
Definition: binding.h:32
T & access()
Definition: binding.h:48
const T & access() const
Definition: binding.h:55
const std::shared_ptr< T > & ptr() const
Definition: binding.h:62
std::shared_ptr< T > & ptr()
Definition: binding.h:66
std::shared_ptr< T > m_Data
Definition: binding.h:70
const T * operator->() const
Definition: binding.h:44
PyWrapper(U &&source, std::enable_if_t< std::is_convertible_v< U &, T & >> *a=nullptr)
Definition: binding.h:35
T * operator->()
Definition: binding.h:40
PyWrapper(std::shared_ptr< T > d)
Definition: binding.h:38
Forward-declares types.