DiSMEC++
training.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "training.h"
7 
8 #include <utility>
9 #include "data/data.h"
10 #include "spdlog/fmt/chrono.h"
11 #include "model/submodel.h"
12 #include "parallel/runner.h"
13 #include "initializer.h"
14 #include "postproc.h"
15 #include "statistics.h"
16 #include "utils/eigen_generic.h"
17 
18 using namespace dismec;
19 
// Sets up a generator whose tasks train the labels starting at begin_label, up to (excluding) end_label.
// An end_label whose index is not positive (the declared default is label_id_t{-1}) is replaced by the
// number of labels in the spec's data set. The body then asks the spec to create the model
// (m_TaskSpec->make_model) for num_features() features, using model_spec.
// NOTE(review): part of the body (including how model_spec is built) is missing from this listing.
// NOTE(review): because the test is `> 0`, an explicit end_label of 0 also means "all labels" rather
// than an empty range — confirm this is intended.
20 TrainingTaskGenerator::TrainingTaskGenerator(std::shared_ptr<TrainingSpec> spec,
21  label_id_t begin_label, label_id_t end_label) :
22  m_TaskSpec(std::move(spec)),
23  m_LabelRangeBegin(begin_label),
24  m_LabelRangeEnd(end_label.to_index() > 0 ? end_label : label_id_t{m_TaskSpec->get_data().num_labels()})
25 {
27 
30  static_cast<long>(m_TaskSpec->get_data().num_labels())};
31 
32  m_Model = m_TaskSpec->make_model(m_TaskSpec->num_features(), model_spec);
33 }
34 
36 
37 void TrainingTaskGenerator::run_tasks(long begin, long end, thread_id_t thread_id) {
38  for(long t = begin; t < end; ++t) {
39  run_task(t, thread_id);
40  }
41 }
42 
43 void TrainingTaskGenerator::run_task(long task_id, thread_id_t thread_id) {
44  label_id_t label_id = m_LabelRangeBegin + task_id;
45  assert(0 <= label_id.to_index());
46  assert(label_id.to_index() < m_TaskSpec->get_data().num_labels());
47  m_Results.at(task_id) = train_label(label_id, thread_id);
48 }
49 
// train_label (its signature line is missing from this listing): runs the training of a single label
// on the given thread. It uses that thread's objective, minimizer, weight initializer, post-processor
// and result gatherer, stores the resulting weights in m_Model and returns the minimization result.
51  m_ResultGatherers.at(thread_id.to_index())->start_label(label_id);
52 
53  // first, update the thread local objective and minimizer
54  auto& objective = m_ThreadLocalObjective.at(thread_id.to_index());
55  m_TaskSpec->update_objective(*objective, label_id);
56  auto& minimizer = m_ThreadLocalMinimizer.at(thread_id.to_index());
57  m_TaskSpec->update_minimizer(*minimizer, label_id);
58 
59  // get a reference to the thread-local weight buffer and initialize the weight.
60  DenseRealVector& target = m_ThreadLocalWorkingVector.at(thread_id.to_index());
61  m_ThreadLocalWeightInit.at(thread_id.to_index())->get_initial_weight(label_id, target, *objective);
62  m_ResultGatherers.at(thread_id.to_index())->start_training(target);
63 
64  // run the minimizer and update the weights in the model
65  auto result = minimizer->minimize(*objective, target);
// statistics are recorded right after minimization, i.e. before the post-processor runs on target
66  m_ResultGatherers.at(thread_id.to_index())->record_result(target, result);
67  m_ThreadLocalPostProc.at(thread_id.to_index())->process(label_id, target, result);
68  m_Model->set_weights_for_label(label_id, model::Model::WeightVectorIn{target});
69 
70  // some logging
// failures are always reported through spdlog; the detailed message only if the spec has a logger
71  if(result.Outcome != solvers::MinimizerStatus::SUCCESS) {
72  spdlog::warn("Minimization for label {:5} failed after {:4} iterations", label_id.to_index(), result.NumIters);
73  }
74 
75  if(m_TaskSpec->get_logger()) {
76  m_TaskSpec->get_logger()->info(
77  "Thread {} finished minimization for label {:5} in {:4} iterations ({}) with loss {:6.3} -> {:6.3} and gradient {:6.3} -> {:6.3}.",
78  thread_id.to_index(), label_id.to_index(), result.NumIters, result.Duration, result.InitialValue,
79  result.FinalValue, result.InitialGrad, result.FinalGrad);
80  }
81  return result;
82 }
83 
84 void TrainingTaskGenerator::prepare(long num_threads, long chunk_size) {
85  m_ThreadLocalWorkingVector.resize(num_threads);
86  m_ThreadLocalMinimizer.resize(num_threads);
87  m_ThreadLocalObjective.resize(num_threads);
88  m_ThreadLocalWeightInit.resize(num_threads);
89  m_ThreadLocalPostProc.resize(num_threads);
90  m_ResultGatherers.resize(num_threads);
91 }
92 
94 {
// init_thread (its signature line is missing from this listing): builds the per-thread state in the
// slot that prepare() reserved for thread_id. This is a zeroed working vector of num_features()
// entries, plus a minimizer, objective, weight initializer, post-processor and result gatherer.
// All of them are then registered with the spec's statistics gatherer.
95  m_ThreadLocalWorkingVector.at(thread_id.to_index()) = DenseRealVector::Zero(m_TaskSpec->num_features());
96  m_ThreadLocalMinimizer.at(thread_id.to_index()) = m_TaskSpec->make_minimizer();
97  m_ThreadLocalObjective.at(thread_id.to_index()) = m_TaskSpec->make_objective();
98  m_ThreadLocalWeightInit.at(thread_id.to_index()) = m_TaskSpec->make_initializer();
// the objective has to exist already: the post-processor is constructed from this thread's objective
99  m_ThreadLocalPostProc.at(thread_id.to_index()) = m_TaskSpec->make_post_processor(m_ThreadLocalObjective.at(thread_id.to_index()));
100  m_ResultGatherers.at(thread_id.to_index()) = m_TaskSpec->get_statistics_gatherer().create_results_gatherer(thread_id, m_TaskSpec);
101 
102  m_TaskSpec->get_statistics_gatherer().setup_minimizer(thread_id, *m_ThreadLocalMinimizer.at(thread_id.to_index()));
103  m_TaskSpec->get_statistics_gatherer().setup_initializer(thread_id, *m_ThreadLocalWeightInit.at(thread_id.to_index()));
104  m_TaskSpec->get_statistics_gatherer().setup_objective(thread_id, *m_ThreadLocalObjective.at(thread_id.to_index()));
105  m_TaskSpec->get_statistics_gatherer().setup_postproc(thread_id, *m_ThreadLocalPostProc.at(thread_id.to_index()));
106 }
107 
// finalize (the start of its body is missing from this listing): releases the per-thread minimizers,
// objectives, weight initializers and post-processors, then finalizes the spec's statistics gatherer.
// NOTE(review): m_ResultGatherers is not cleared in the visible lines — confirm whether that is intended.
110  m_ThreadLocalMinimizer.clear();
111  m_ThreadLocalObjective.clear();
112  m_ThreadLocalWeightInit.clear();
113  m_ThreadLocalPostProc.clear();
114 
115  m_TaskSpec->get_statistics_gatherer().finalize();
116 }
117 
// num_tasks (its signature line is missing from this listing): there is one task per result slot, so
// the task count is m_Results.size(), converted to the long return type.
119  return m_Results.size();
120 }
121 
// Trains the labels in [begin_label, end_label) of spec's data set, using the given parallel runner.
// Returns a TrainingResult built from: whether all tasks finished, the trained model (restricted to the
// labels that were actually computed if training did not finish), and the sums of the final loss
// values and final gradient values over all per-task results.
122 TrainingResult dismec::run_training(parallel::ParallelRunner& runner, std::shared_ptr<TrainingSpec> spec,
123  label_id_t begin_label, label_id_t end_label)
124 {
125  auto task = TrainingTaskGenerator(std::move(spec), begin_label, end_label);
126  auto result = runner.run(task);
127 
128  real_t total_loss = 0.0;
129  real_t total_grad = 0.0;
// NOTE(review): if training did not finish, this loop also visits result slots of tasks that never ran;
// whether those contribute zero depends on the default state of MinimizationResult — confirm.
130  for(const auto& r : task.get_results()) {
131  total_loss += r.FinalValue;
132  total_grad += r.FinalGrad;
133  }
134 
135  auto model = task.get_model();
136  // if training did time out, we need to adapt the resulting model to only declare the weight vectors
137  // which have actually been calculated
138  if(!result.IsFinished)
139  {
// NOTE(review): result.NextTask is a task index, and run_task maps task t to label begin_label + t.
// If the wrapper's third argument is an absolute end label, this looks off by begin_label whenever
// begin_label != 0 (model->labels_begin() + result.NextTask may be intended) — confirm.
141  model = std::make_shared<SubWrapperType>(model, model->labels_begin(),
142  label_id_t{result.NextTask});
143  }
144 
145  return {result.IsFinished, std::move(model), total_loss, total_grad};
146 }
Generates tasks for training label weights; the i-th task trains the weights of the i-th label of the generator's label range.
Definition: training.h:26
long num_tasks() const override
Definition: training.cpp:118
std::vector< DenseRealVector > m_ThreadLocalWorkingVector
Definition: training.h:64
std::vector< std::unique_ptr< postproc::PostProcessor > > m_ThreadLocalPostProc
Definition: training.h:68
void prepare(long num_threads, long chunk_size) override
Called to notify the TaskGenerator about the number of threads.
Definition: training.cpp:84
void run_task(long task_id, thread_id_t thread_id)
Definition: training.cpp:43
std::shared_ptr< TrainingSpec > m_TaskSpec
Definition: training.h:53
void finalize() override
Called after all threads have finished their tasks.
Definition: training.cpp:108
std::vector< std::unique_ptr< ResultStatsGatherer > > m_ResultGatherers
Definition: training.h:69
std::vector< std::unique_ptr< init::WeightsInitializer > > m_ThreadLocalWeightInit
Definition: training.h:67
std::vector< std::unique_ptr< solvers::Minimizer > > m_ThreadLocalMinimizer
Definition: training.h:65
TrainingTaskGenerator(std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
Definition: training.cpp:20
solvers::MinimizationResult train_label(label_id_t label_id, thread_id_t thread_id)
Runs the training of a single label.
Definition: training.cpp:50
std::shared_ptr< model::Model > m_Model
Definition: training.h:60
void run_tasks(long begin, long end, thread_id_t thread_id) override
Definition: training.cpp:37
std::vector< solvers::MinimizationResult > m_Results
Definition: training.h:61
std::vector< std::shared_ptr< objective::Objective > > m_ThreadLocalObjective
Definition: training.h:66
void init_thread(thread_id_t thread_id) override
Called once a thread has spun up, but before it runs its first task.
Definition: training.cpp:93
Strong typedef for an int to signify a label id.
Definition: types.h:20
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
RunResult run(TaskGenerator &tasks, long start=0)
Definition: runner.cpp:39
Strong typedef for an int to signify a thread id.
Definition: thread_id.h:20
SUCCESS (enumerator of solvers::MinimizerStatus)
The returned result is a minimum according to the stopping criterion of the algorithm.
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
TrainingResult run_training(parallel::ParallelRunner &runner, std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
Definition: training.cpp:122
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
Definition: matrix_types.h:40
float real_t
The default type for floating point values.
Definition: config.h:17
Specifies how to interpret a weight matrix for a partial model.
Definition: model.h:22