DiSMEC++
stats_base.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_STATS_BASE_H
7 #define DISMEC_STATS_BASE_H
8 
#include "matrix_types.h"
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <nlohmann/json_fwd.hpp>
13 
14 namespace dismec::stats {
15  class StatisticsCollection;
16 
/// A tag container combines a name with a shared pointer, which points to the tag value.
///
/// Copies of a container share the same underlying value, because only the `shared_ptr`
/// is copied. An "empty" container has a name but no value storage attached.
class TagContainer {
public:
    /// Creates a container called `name` without any value storage; `is_empty()` is true.
    static TagContainer create_empty_container(std::string name) {
        return TagContainer{std::move(name), std::shared_ptr<int>{}};
    }

    /// Creates a container called `name` that owns a newly allocated, zero-initialized value.
    static TagContainer create_full_container(std::string name) {
        auto storage = std::make_shared<int>();
        return TagContainer{std::move(name), std::move(storage)};
    }

    /// Returns the name of the associated tag.
    [[nodiscard]] const std::string& get_name() const { return m_Name; }

    /// Returns true if no value storage is attached to this container.
    [[nodiscard]] bool is_empty() const { return !m_Value; }

    /// Returns the current value of the tag. Requires the container to not be empty.
    [[nodiscard]] int get_value() const {
        assert(!is_empty());
        return *m_Value;
    }

    /// Updates the value of the tag. Requires the container to not be empty.
    /// The new value is visible through every copy of this container.
    void set_value(int value) {
        assert(!is_empty());
        *m_Value = value;
    }

private:
    // Only reachable through the two factory functions above.
    explicit TagContainer(std::string name, std::shared_ptr<int> val)
        : m_Name(std::move(name)), m_Value(std::move(val)) {}

    std::string m_Name;              ///< Name of the tag.
    std::shared_ptr<int> m_Value;    ///< Shared tag value; null for an empty container.
};
63 
65  class Statistics {
66  public:
67  virtual ~Statistics() = default;
68 
69  // this overload is provided to prevent the int -> long / float ambiguity
70  void record(int integer) { record(long(integer)); }
71 
72  void record(long integer) { record_int(integer); }
73  void record(real_t real) { record_real(real); }
74  void record(const DenseRealVector& vector) { record_vec(vector); }
75 
76  [[nodiscard]] virtual std::unique_ptr<Statistics> clone() const = 0;
77 
84  virtual void setup(const StatisticsCollection& source) { };
85 
95  virtual void merge(const Statistics& other) = 0;
96 
98  [[nodiscard]] virtual nlohmann::json to_json() const = 0;
99 
100  private:
101  // Virtual functions to actually implement.
102  virtual void record_int(long integer) { throw std::logic_error("Not implemented"); }
103  virtual void record_real(real_t real) { throw std::logic_error("Not implemented"); }
104  virtual void record_vec(const DenseRealVector& vector) { throw std::logic_error("Not implemented"); }
105  };
106 
112  std::unique_ptr<stats::Statistics> make_stat_from_json(const nlohmann::json& source);
113 
126  template<class Derived>
127  class StatImplBase : public Statistics {
128  public:
129  void merge(const Statistics& other) override {
130  static_assert(std::is_final_v<Derived>, "Derived needs to be declared final, because further derived classes would break the merge code.");
131  static_cast<Derived*>(this)->merge_imp(dynamic_cast<const Derived&>(other));
132  }
133 
134  void record_vec(const DenseRealVector& vector) override {
135  for(int i = 0; i < vector.size(); ++i) {
136  static_cast<Derived*>(this)->record(vector.coeff(i));
137  }
138  }
139  };
140 
141 }
142 
143 #endif //DISMEC_STATS_BASE_H
Helper class for implementing Statistics classes.
Definition: stats_base.h:127
void merge(const Statistics &other) override
Merges this statistics object with another one of the same type and settings.
Definition: stats_base.h:129
void record_vec(const DenseRealVector &vector) override
Definition: stats_base.h:134
This class manages a collection of named Statistics objects.
Definition: collection.h:47
TODO: maybe we should solve this with a variant that dispatches on the expected type and tag.
Definition: stats_base.h:65
void record(real_t real)
Definition: stats_base.h:73
virtual void record_vec(const DenseRealVector &vector)
Definition: stats_base.h:104
virtual void merge(const Statistics &other)=0
Merges this statistics object with another one of the same type and settings.
void record(long integer)
Definition: stats_base.h:72
virtual void setup(const StatisticsCollection &source)
This function has to be called before the Statistics object is used to collect data for the first time.
Definition: stats_base.h:84
virtual ~Statistics()=default
virtual void record_real(real_t real)
Definition: stats_base.h:103
virtual nlohmann::json to_json() const =0
Converts the statistic's current value into a JSON object.
virtual void record_int(long integer)
Definition: stats_base.h:102
void record(const DenseRealVector &vector)
Definition: stats_base.h:74
void record(int integer)
Definition: stats_base.h:70
virtual std::unique_ptr< Statistics > clone() const =0
A tag container combines a name with a shared pointer, which points to the tag value.
Definition: stats_base.h:30
static TagContainer create_full_container(std::string name)
Definition: stats_base.h:53
int get_value() const
Returns the current value of the tag. Requires the container to not be empty.
Definition: stats_base.h:36
std::shared_ptr< int > m_Value
Definition: stats_base.h:61
const std::string & get_name() const
Returns the name of the associated tag.
Definition: stats_base.h:33
static TagContainer create_empty_container(std::string name)
Definition: stats_base.h:50
void set_value(int value)
Updates the value of the tag. Requires the container to not be empty.
Definition: stats_base.h:45
TagContainer(std::string name, std::shared_ptr< int > val)
Definition: stats_base.h:58
nlohmann::json json
Definition: model-io.cpp:22
std::unique_ptr< stats::Statistics > make_stat_from_json(const nlohmann::json &source)
Generates a stats::Statistics object based on a json configuration.
Definition: statistics.cpp:218
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
float real_t
The default type for floating point values.
Definition: config.h:17