DiSMEC++
statistics.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_STATISTICS_H
7 #define DISMEC_STATISTICS_H
8 
9 #include <unordered_map>
10 #include <memory>
11 #include <mutex>
12 #include <vector>
13 #include "fwd.h"
14 #include "matrix_types.h"
15 #include "stats/tracked.h"
16 #include <nlohmann/json_fwd.hpp>
17 
18 namespace dismec
19 {
21  public:
24  virtual void record_result(const DenseRealVector& weights, const solvers::MinimizationResult& result) = 0;
25  virtual void start_label(label_id_t label) = 0;
26  virtual void start_training(const DenseRealVector& init_weights) = 0;
27  };
28 
31  public:
32  TrainingStatsGatherer(std::string source, std::string target_file);
34 
36  void setup_minimizer(thread_id_t thread, stats::Tracked& minimizer);
37  void setup_initializer(thread_id_t thread, stats::Tracked& initializer);
40  std::unique_ptr<ResultStatsGatherer> create_results_gatherer(thread_id_t thread, const std::shared_ptr<const TrainingSpec>& spec);
41 
42  void finalize();
43 
44  nlohmann::json to_json() const;
45  private:
46  struct StatData {
48  std::unique_ptr<stats::Statistics> Stat;
49  };
50  std::unordered_map<std::string, StatData> m_Merged;
51  using collection_ptr_t = std::shared_ptr<stats::StatisticsCollection>;
52  // we need to have this data per thread to 1) correctly associate per-thread tags and 2) ensure consistent order when merging.
53  std::vector<std::unordered_map<std::string, collection_ptr_t>> m_PerThreadCollections;
54 
55  std::mutex m_Lock;
56 
57  std::string m_TargetFile;
58 
59  void add_accu(const std::string& key, thread_id_t thread, const std::shared_ptr<stats::StatisticsCollection>& accumulator);
60 
61  std::unique_ptr<nlohmann::json> m_Config;
62  };
63 }
64 
65 #endif //DISMEC_STATISTICS_H
virtual void start_label(label_id_t label)=0
virtual void record_result(const DenseRealVector &weights, const solvers::MinimizationResult &result)=0
virtual void start_training(const DenseRealVector &init_weights)=0
void add_accu(const std::string &key, thread_id_t thread, const std::shared_ptr< stats::StatisticsCollection > &accumulator)
Definition: statistics.cpp:168
void setup_postproc(thread_id_t thread, stats::Tracked &objective)
Definition: statistics.cpp:42
std::unique_ptr< nlohmann::json > m_Config
Definition: statistics.h:61
void setup_minimizer(thread_id_t thread, stats::Tracked &minimizer)
NOTE: these functions will be called concurrently.
Definition: statistics.cpp:31
std::shared_ptr< stats::StatisticsCollection > collection_ptr_t
Definition: statistics.h:51
void setup_initializer(thread_id_t thread, stats::Tracked &initializer)
Definition: statistics.cpp:35
nlohmann::json to_json() const
Definition: statistics.cpp:72
std::vector< std::unordered_map< std::string, collection_ptr_t > > m_PerThreadCollections
Definition: statistics.h:53
void setup_objective(thread_id_t thread, stats::Tracked &objective)
Definition: statistics.cpp:39
std::unordered_map< std::string, StatData > m_Merged
Definition: statistics.h:50
TrainingStatsGatherer(std::string source, std::string target_file)
Definition: statistics.cpp:21
std::unique_ptr< ResultStatsGatherer > create_results_gatherer(thread_id_t thread, const std::shared_ptr< const TrainingSpec > &spec)
Definition: statistics.cpp:162
Strong typedef for an int to signify a label id.
Definition: types.h:20
Strong typedef for an int to signify a thread id.
Definition: thread_id.h:20
A base class to be used for all types that implement some form of statistics tracking.
Definition: tracked.h:42
Forward-declares types.
nlohmann::json json
Definition: model-io.cpp:22
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseVector< real_t > DenseRealVector
Any dense, real-valued vector.
Definition: matrix_types.h:40
stats::StatisticMetaData Meta
Definition: statistics.h:47
std::unique_ptr< stats::Statistics > Stat
Definition: statistics.h:48
Data that is associated with each declared statistic.
Definition: stat_id.h:33