DiSMEC++
training.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_TRAINING_H
7 #define DISMEC_TRAINING_H
8 
9 #include "spec.h"
10 #include "parallel/task.h"
11 #include "solver/minimizer.h"
12 #include <memory>
13 #include <functional>
14 #include "fwd.h"
15 #include "data/types.h"
16 
17 namespace dismec
18 {
27  public:
28  explicit TrainingTaskGenerator(std::shared_ptr<TrainingSpec> spec, label_id_t begin_label=label_id_t{0},
29  label_id_t end_label=label_id_t{-1});
31 
32  void run_tasks(long begin, long end, thread_id_t thread_id) override;
33  void prepare(long num_threads, long chunk_size) override;
34  void init_thread(thread_id_t thread_id) override;
35  void finalize() override;
36  [[nodiscard]] long num_tasks() const override;
37 
38  [[nodiscard]] const std::shared_ptr<model::Model>& get_model() const { return m_Model; }
39  [[nodiscard]] const std::vector<solvers::MinimizationResult>& get_results() const { return m_Results; }
40 
41  private:
42  void run_task(long task_id, thread_id_t thread_id);
43 
51 
52  //
53  std::shared_ptr<TrainingSpec> m_TaskSpec;
54 
55  // training only a partial model?
58 
59  // result variables
60  std::shared_ptr<model::Model> m_Model;
61  std::vector<solvers::MinimizationResult> m_Results;
62 
63  // thread-local caches
64  std::vector<DenseRealVector> m_ThreadLocalWorkingVector;
65  std::vector<std::unique_ptr<solvers::Minimizer>> m_ThreadLocalMinimizer;
66  std::vector<std::shared_ptr<objective::Objective>> m_ThreadLocalObjective;
67  std::vector<std::unique_ptr<init::WeightsInitializer>> m_ThreadLocalWeightInit;
68  std::vector<std::unique_ptr<postproc::PostProcessor>> m_ThreadLocalPostProc;
69  std::vector<std::unique_ptr<ResultStatsGatherer>> m_ResultGatherers;
70  };
71 
72  struct TrainingResult {
73  bool IsFinished = false;
74  std::shared_ptr<model::Model> Model;
77  };
78 
79  TrainingResult run_training(parallel::ParallelRunner& runner, std::shared_ptr<TrainingSpec> spec,
80  label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1});
81 
82 }
83 
84 #endif //DISMEC_TRAINING_H
Generates tasks for training the weights of the i'th label.
Definition: training.h:26
long num_tasks() const override
Definition: training.cpp:118
std::vector< DenseRealVector > m_ThreadLocalWorkingVector
Definition: training.h:64
std::vector< std::unique_ptr< postproc::PostProcessor > > m_ThreadLocalPostProc
Definition: training.h:68
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads and the chunk size.
Definition: training.cpp:84
void run_task(long task_id, thread_id_t thread_id)
Definition: training.cpp:43
const std::shared_ptr< model::Model > & get_model() const
Definition: training.h:38
std::shared_ptr< TrainingSpec > m_TaskSpec
Definition: training.h:53
const std::vector< solvers::MinimizationResult > & get_results() const
Definition: training.h:39
void finalize() override
Called after all threads have finished their tasks.
Definition: training.cpp:108
std::vector< std::unique_ptr< ResultStatsGatherer > > m_ResultGatherers
Definition: training.h:69
std::vector< std::unique_ptr< init::WeightsInitializer > > m_ThreadLocalWeightInit
Definition: training.h:67
std::vector< std::unique_ptr< solvers::Minimizer > > m_ThreadLocalMinimizer
Definition: training.h:65
TrainingTaskGenerator(std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
Definition: training.cpp:20
solvers::MinimizationResult train_label(label_id_t label_id, thread_id_t thread_id)
Runs the training of a single label.
Definition: training.cpp:50
std::shared_ptr< model::Model > m_Model
Definition: training.h:60
void run_tasks(long begin, long end, thread_id_t thread_id) override
Definition: training.cpp:37
std::vector< solvers::MinimizationResult > m_Results
Definition: training.h:61
std::vector< std::shared_ptr< objective::Objective > > m_ThreadLocalObjective
Definition: training.h:66
void init_thread(thread_id_t thread_id) override
Called once a thread has spun up, but before it runs its first task.
Definition: training.cpp:93
Strong typedef for an int to signify a label id.
Definition: types.h:20
Base class for all parallelized operations.
Definition: task.h:21
dismec::parallel::thread_id_t thread_id_t
Definition: task.h:23
Strong typedef for an int to signify a thread id.
Definition: thread_id.h:20
Forward-declares types.
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
TrainingResult run_training(parallel::ParallelRunner &runner, std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
Definition: training.cpp:122
float real_t
The default type for floating point values.
Definition: config.h:17
std::shared_ptr< model::Model > Model
Definition: training.h:74