DiSMEC++
train.cpp
// Copyright (c) 2021, Aalto University, developed by Erik Schultheis
// All rights reserved.
//
// SPDX-License-Identifier: MIT

#include "parallel/runner.h"
#include "io/model-io.h"
#include "io/slice.h"
#include "data/data.h"
#include "data/transform.h"
#include "training/training.h"
#include "training/weighting.h"
#include "training/postproc.h"
#include "training/initializer.h"
#include "training/statistics.h"
#include "CLI/CLI.hpp"
#include "spdlog/spdlog.h"
#include "io/numpy.h"
#include "io/common.h"
#include "app.h"
#include <future>

using namespace dismec;

// extern "C" void openblas_set_num_threads(int num_threads);

class TrainingProgram {
public:
    TrainingProgram();
    int run(int argc, const char** argv);
private:
    CLI::App app{"DiSMEC"};

    // command line parameters
    // source data
    void setup_source_cmdline();
    bool ReorderFeatures = false;

    DataProcessing DataProc;

    // target model
    void setup_save_cmdline();
    std::filesystem::path ModelFile;
    io::model::SaveOption SaveOptions;

    // run range
    void setup_label_range();
    int FirstLabel = 0;
    int NumLabels = -1;
    bool ContinueRun = false;

    void parse_label_range();
    label_id_t LabelsBegin{0};
    label_id_t LabelsEnd{-1};

    CLI::Option* FirstLabelOpt;
    CLI::Option* NumLabelsOpt;

    // hyper params
    void setup_hyper_params();
    HyperParameters hps;
    std::string WeightingMode;
    double PropA = 0.55;
    double PropB = 1.5;
    std::string WeightingPosFile;
    std::string WeightingNegFile;

    // regularization
    void setup_regularization();

    // pre-trained model
    std::filesystem::path SourceModel;
    CLI::Option* PreTrainedOpt;

    std::string InitMode;
    std::optional<real_t> BiasInitValue;
    real_t MSI_PFac = 1;
    real_t MSI_NFac = -2;
    int InitMaxPos = 1;

    RegularizerType Regularizer = RegularizerType::REG_L2;
    real_t RegScale = 1.0;
    bool RegBias = false;

    real_t Sparsify = -1;

    LossType Loss = LossType::SQUARED_HINGE;

    // statistics
    std::string StatsOutFile = "stats.json";
    std::string StatsLevelFile = {};

    // others
    long NumThreads = -1;
    long Timeout = -1;
    long BatchSize = -1;

    int Verbose = 0;

    // config setup helpers
    DismecTrainingConfig make_config(const std::shared_ptr<MultiLabelData>& data);
};

int main(int argc, const char** argv) {
    // openblas_set_num_threads(1);
    TrainingProgram program;
    return program.run(argc, argv);
}

void TrainingProgram::setup_save_cmdline()
{
    SaveOptions.Format = io::model::WeightFormat::DENSE_TXT;

    app.add_option("output,--model-file", ModelFile,
                   "The file to which the model will be written. Note that models are saved in multiple different files, so this "
                   "just specifies the base name of the metadata file.")->required();

    // save format flags
    auto* dense_txt_flag = app.add_flag("--save-dense-txt", [&](std::size_t){ SaveOptions.Format = io::WeightFormat::DENSE_TXT; },
                                        "Save dense weights in a human-readable text format")->take_last();
    auto* dense_npy_flag = app.add_flag("--save-dense-npy", [&](std::size_t){ SaveOptions.Format = io::WeightFormat::DENSE_NPY; },
                                        "Save dense weights in a npy file")->take_last();
    auto* sparse_flag = app.add_flag("--save-sparse-txt", [&](std::size_t){ SaveOptions.Format = io::WeightFormat::SPARSE_TXT; },
                                     "Save sparse weights in a human-readable text format. Sparsity can be adjusted using the --weight-culling option")->take_last();

    dense_npy_flag->excludes(dense_txt_flag, sparse_flag);
    dense_txt_flag->excludes(dense_npy_flag, sparse_flag);
    sparse_flag->excludes(dense_txt_flag, dense_npy_flag);

    app.add_option("--weight-culling", SaveOptions.Culling,
                   "When saving in a sparse format, any weight lower than this will be omitted.")->needs(sparse_flag)->check(CLI::NonNegativeNumber);

    app.add_option("--save-precision", SaveOptions.Precision,
                   "The number of digits to write for real numbers in text file format.")->check(CLI::NonNegativeNumber)->excludes(dense_npy_flag);
}
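
// Illustrative usage (not part of the original source; the binary name is a placeholder):
//   dismec-train ... model.txt --save-sparse-txt --weight-culling 0.01
// selects the sparse text format and omits all weights below 0.01 when saving.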

void TrainingProgram::setup_source_cmdline()
{
    DataProc.setup_data_args(app);
    app.add_flag("--reorder-features", ReorderFeatures,
                 "If this flag is given, then the feature columns are sorted by frequency before training. "
                 "This can lead to faster computations in case the number of features is very large and their frequencies imbalanced, "
                 "because it may improve data locality.");
}

void TrainingProgram::setup_label_range()
{
    FirstLabelOpt = app.add_option("--first-label", FirstLabel,
                                   "If you want to train only a subset of labels, this is the id of the first label to be trained. "
                                   "The subset of labels trained is `[first_label, first_label + num-labels)`")->check(CLI::NonNegativeNumber);
    NumLabelsOpt = app.add_option("--num-labels", NumLabels,
                                  "If you want to train only a subset of labels, this is the total number of labels to be trained.")->check(CLI::NonNegativeNumber);
    app.add_flag("--continue", ContinueRun,
                 "If this flag is given, the new weights will be appended to the model "
                 "file, instead of overwriting it. You can use the --first-label option to explicitly specify "
                 "at which label to start. If omitted, training starts at the first label for which no "
                 "weight vector is known.");
}
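
// Illustrative example: `--first-label 100 --num-labels 50` trains exactly the
// half-open label interval [100, 150), matching the semantics documented above.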

void TrainingProgram::setup_hyper_params()
{
    // this needs to be set in all cases, because we need to adapt it dynamically and thus cannot rely on
    // default values
    hps.set("epsilon", 0.01);

    auto add_hyper_param_option = [&](const char* option, const char* name, const char* desc) {
        return app.add_option_function<double>(
            option,
            [this, name](double value) { hps.set(name, value); },
            desc)->group("hyper-parameters");
    };

    add_hyper_param_option("--epsilon", "epsilon",
                           "Tolerance for the minimizer. Will be adjusted by the number of positive/negative instances.")
                           ->check(CLI::NonNegativeNumber);

    add_hyper_param_option("--alpha-pcg", "alpha-pcg",
                           "Interpolation parameter for preconditioning of CG optimization.")->check(CLI::Range(0.0, 1.0));

    add_hyper_param_option("--line-search-step-size", "search.step-size",
                           "Step size for the line search.")->check(CLI::NonNegativeNumber);

    add_hyper_param_option("--line-search-alpha", "search.alpha",
                           "Shrink factor for updating the line search step.")->check(CLI::Range(0.0, 1.0));

    add_hyper_param_option("--line-search-eta", "search.eta",
                           "Acceptance criterion for the line search.")->check(CLI::Range(0.0, 1.0));

    add_hyper_param_option("--cg-epsilon", "cg.epsilon",
                           "Stopping criterion for the CG solver.")->check(CLI::PositiveNumber);

    app.add_option_function<long>(
        "--max-steps",
        [this](long value) { hps.set("max-steps", value); },
        "Maximum number of newton steps.")->check(CLI::PositiveNumber)->group("hyper-parameters");

    app.add_option_function<long>(
        "--line-search-max-steps",
        [this](long value) { hps.set("search.max-steps", value); },
        "Maximum number of line search steps.")->check(CLI::PositiveNumber)->group("hyper-parameters");
}
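
// Note: each flag above simply forwards its value into the named hyper-parameter,
// so e.g. `--epsilon 0.001 --max-steps 50` results in hps.set("epsilon", 0.001)
// and hps.set("max-steps", 50).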
202 
204 {
205  // continue with automatic first label selection
206  if(ContinueRun)
207  {
208  io::PartialModelSaver saver(ModelFile, SaveOptions, true);
209  if(FirstLabelOpt->count() == 0)
210  {
211  auto missing = saver.get_missing_weights();
212  spdlog::info("Model is missing weight vectors {} to {}.", missing.first.to_index(), missing.second.to_index() - 1);
213  LabelsBegin = missing.first;
214  LabelsEnd = missing.second;
215  if (NumLabelsOpt->count() > 0) {
216  if (LabelsEnd - LabelsBegin >= NumLabels) {
217  LabelsEnd = LabelsBegin + NumLabels;
218  } else {
219  spdlog::warn("Number of labels to train was specified as {}, but only {} labels will be trained",
220  NumLabels, LabelsEnd - LabelsBegin);
221  }
222  }
223  return;
224  } else {
225  // user has given us a label from which to start.
226  LabelsBegin = label_id_t{FirstLabel};
227  if (NumLabelsOpt->count() > 0) {
228  LabelsBegin = LabelsBegin + NumLabels;
229  // and a label count. Then we need to check is this is valid
230  if(saver.any_weight_vector_for_interval(LabelsBegin, LabelsEnd)) {
231  spdlog::error("Specified continuation of training weight vectors for labels {}-{}, "
232  "which overlaps with existing weight vectors.", LabelsBegin.to_index(), LabelsEnd.to_index()-1);
233  exit(EXIT_FAILURE);
234  }
235  return;
236  }
237  LabelsEnd = label_id_t{saver.num_labels()};
238  return;
239  }
240  }
241 
242  // OK, we are not continuing a run.
243 
244  if(FirstLabelOpt->count()) {
245  LabelsBegin = label_id_t{FirstLabel};
246  } else {
247  LabelsBegin = label_id_t{0};
248  }
249 
250  if (NumLabelsOpt->count() > 0) {
251  LabelsEnd = LabelsBegin + NumLabels;
252  } else {
253  LabelsEnd = label_id_t{-1};
254  }
255 }
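
// Worked example (illustrative numbers): suppose the model has 5000 labels and
// weight vectors 0-999 already exist. With `--continue` and no `--first-label`,
// the missing range [1000, 5000) is selected; adding `--num-labels 500` narrows
// this to [1000, 1500).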

void TrainingProgram::setup_regularization()
{
    app.add_option("--regularizer", Regularizer, "The weight regularizer")->default_str("l2")
        ->transform(CLI::Transformer(std::map<std::string, RegularizerType>{{"l2", RegularizerType::REG_L2},
                                                                            {"l1", RegularizerType::REG_L1},
                                                                            {"l1-relaxed", RegularizerType::REG_L1_RELAXED},
                                                                            {"huber", RegularizerType::REG_HUBER},
                                                                            {"elastic-50-50", RegularizerType::REG_ELASTIC_50_50},
                                                                            {"elastic-90-10", RegularizerType::REG_ELASTIC_90_10}
                                                                           }, CLI::ignore_case));
    app.add_option("--reg-scale", RegScale, "Scaling factor for the regularizer")->check(CLI::NonNegativeNumber);
    app.add_flag("--reg-bias", RegBias, "Include bias in regularization")->default_val(false);
}
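
// Illustrative example: `--regularizer huber --reg-scale 0.1` selects
// RegularizerType::REG_HUBER scaled by 0.1; the bias weight is excluded from
// the penalty unless `--reg-bias` is also given.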

TrainingProgram::TrainingProgram()
{
    setup_source_cmdline();
    setup_save_cmdline();
    setup_label_range();
    setup_hyper_params();
    setup_regularization();

    app.add_option("--threads", NumThreads, "Number of threads to use. -1 means auto-detect.");
    app.add_option("--batch-size", BatchSize, "If this is given, training is split into batches "
                   "and results are written to disk after each batch.");
    app.add_option("--timeout", Timeout, "No new training tasks will be started after this time. "
                   "This can be used e.g. on a cluster system to ensure that the training finishes properly "
                   "even if not all work could be done in the allotted time.")
        ->transform(CLI::AsNumberWithUnit(std::map<std::string, float>{{"ms", 1},
                                                                       {"s", 1'000}, {"sec", 1'000},
                                                                       {"m", 60'000}, {"min", 60'000},
                                                                       {"h", 60 * 60'000}},
                                          CLI::AsNumberWithUnit::UNIT_REQUIRED, "TIME"));
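
    // Illustrative example: `--timeout 2h` is converted to 7'200'000 ms by the unit
    // transformer above; a bare number is rejected because UNIT_REQUIRED is set.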

    auto WMOpt = app.add_option("--weighting-mode", WeightingMode,
                                "Determines the re-weighting algorithm used to address missing labels.");
    app.add_option("--propensity-a", PropA,
                   "Parameter a for the propensity model when using propensity-based weighting")->needs(WMOpt);
    app.add_option("--propensity-b", PropB,
                   "Parameter b for the propensity model when using propensity-based weighting")->needs(WMOpt);
    app.add_option("--weighting-pos-file", WeightingPosFile,
                   "File (npy or txt) that contains the weights for the positive instances for each label.")->needs(WMOpt);
    app.add_option("--weighting-neg-file", WeightingNegFile,
                   "File (npy or txt) that contains the weights for the negative instances for each label.")->needs(WMOpt);

    PreTrainedOpt = app.add_option("--pretrained", SourceModel, "The model file which will be "
                                   "used to initialize the weights.");
    PreTrainedOpt->check(CLI::ExistingFile);

    app.add_option("--loss", Loss, "The loss function")->default_str("squared-hinge")
        ->transform(CLI::Transformer(std::map<std::string, LossType>{{"squared-hinge", LossType::SQUARED_HINGE},
                                                                     {"logistic", LossType::LOGISTIC},
                                                                     {"huber-hinge", LossType::HUBER_HINGE},
                                                                     {"hinge", LossType::HINGE}
                                                                    }, CLI::ignore_case));

    app.add_option("--sparsify", Sparsify, "Feedback-driven sparsification. Specify the maximum amount (in %) by which the binary loss "
                   "is allowed to increase.");

    app.add_option("--init-mode", InitMode, "How to initialize the weight vectors")
        ->check(CLI::IsMember({"zero", "mean", "bias", "msi", "multi-pos", "ova-primal"}));
    app.add_option("--bias-init-value", BiasInitValue, "The value that is assigned to the bias weight for bias-init.");
    app.add_option("--msi-pos", MSI_PFac, "Positive target for msi init");
    app.add_option("--msi-neg", MSI_NFac, "Negative target for msi init");
    app.add_option("--max-num-pos", InitMaxPos, "Number of positives to consider for `multi-pos` initialization")->check(CLI::NonNegativeNumber);

    app.add_option("--record-stats", StatsLevelFile,
                   "Record some statistics and save to file. The argument is a json file which describes which statistics are gathered.")
        ->check(CLI::ExistingFile);
    app.add_option("--stats-file", StatsOutFile, "Target file for recorded statistics");

    app.add_flag("-v,-q{-1}", Verbose);
}
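
// Putting the options together (illustrative; file names are placeholders, and the
// positional data argument registered by DataProcessing::setup_data_args is assumed
// to be the training-set path):
//   dismec-train train.txt model.txt --threads 32 --save-sparse-txt \
//       --weight-culling 0.01 --regularizer l2 --timeout 12h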

DismecTrainingConfig TrainingProgram::make_config(const std::shared_ptr<MultiLabelData>& data) {
    DismecTrainingConfig config;

    // positive / negative weighting
    if(WeightingMode == "2pm1") {
        config.Weighting = std::make_shared<PropensityWeighting>(PropensityModel(data.get(), PropA, PropB));
    } else if(WeightingMode == "p2mp") {
        config.Weighting = std::make_shared<PropensityDownWeighting>(PropensityModel(data.get(), PropA, PropB));
    } else if(WeightingMode == "from-file") {
        auto load_vec = [&](std::string source){
            DenseRealVector wgt = DenseRealVector::Ones(data->num_labels());
            if(!source.empty()) {
                std::fstream file(source, std::fstream::in);
                if(!file.is_open()) {
                    THROW_ERROR("Could not open file {}", source);
                }
                if(io::is_npy(file)) {
                    auto header = io::parse_npy_header(*file.rdbuf());
                    if(header.DataType != io::data_type_string<real_t>()) {
                        THROW_ERROR("Unsupported data type {}", header.DataType);
                    }
                    if(header.Rows != 1 && header.Cols != 1) {
                        THROW_ERROR("Expected a vector for weighting data");
                    }
                    io::binary_load(*file.rdbuf(), wgt.data(), wgt.data() + header.Rows * header.Cols);
                } else {
                    io::read_vector_from_text(file, wgt);
                }
            }
            return wgt;
        };
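        // Note (inferred from load_vec above): a weight file is either a npy file
        // containing a single row or column vector of real_t entries, one per label,
        // or a text file of whitespace-separated numbers; an empty path yields
        // all-ones weights.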
        config.Weighting = std::make_shared<CustomWeighting>(load_vec(WeightingPosFile),
                                                             load_vec(WeightingNegFile));
    } else if (!WeightingMode.empty()) {
        spdlog::error("Unknown weighting mode {}. Aborting.", WeightingMode);
        exit(EXIT_FAILURE);
    } else {
        config.Weighting = std::make_shared<ConstantWeighting>(1.0, 1.0);
    }

    // regularizer
    switch(Regularizer) {
        case RegularizerType::REG_L2:
            config.Regularizer = objective::SquaredNormConfig{RegScale, !RegBias};
            break;
        case RegularizerType::REG_L1:
            config.Regularizer = objective::HuberConfig{RegScale, 1e-2, !RegBias};
            break;
        case RegularizerType::REG_L1_RELAXED:
            config.Regularizer = objective::HuberConfig{RegScale, 1e-1, !RegBias};
            break;
        case RegularizerType::REG_HUBER:
            config.Regularizer = objective::HuberConfig{RegScale, 1.0, !RegBias};
            break;
        case RegularizerType::REG_ELASTIC_50_50:
            config.Regularizer = objective::ElasticConfig{RegScale, 1e-1, 0.5, !RegBias};
            break;
        case RegularizerType::REG_ELASTIC_90_10:
            config.Regularizer = objective::ElasticConfig{RegScale, 1e-1, 0.9, !RegBias};
            break;
        default:
            spdlog::error("Unknown regularization mode {}. Aborting.", static_cast<int>(Regularizer));
            exit(EXIT_FAILURE);
    }

    std::shared_ptr<init::WeightInitializationStrategy> init_strategy;
    if(InitMode == "mean") {
        config.Init = init::create_constant_initializer(-get_mean_feature(*data->get_features()));
    } else if(InitMode == "msi") {
        config.Init = init::create_feature_mean_initializer(data, MSI_PFac, MSI_NFac);
    } else if(InitMode == "multi-pos") {
        config.Init = init::create_multi_pos_mean_strategy(data, InitMaxPos, MSI_PFac, MSI_NFac);
    } else if(InitMode == "ova-primal") {
        config.Init = init::create_ova_primal_initializer(data, config.Regularizer, Loss);
    } else if(InitMode == "bias" || (InitMode.empty() && BiasInitValue.has_value())) {
        if(DataProc.augment_for_bias()) {
            DenseRealVector init_vec(data->num_features());
            init_vec.setZero();
            init_vec.coeffRef(init_vec.size() - 1) = BiasInitValue.value_or(-1.0);
            config.Init = init::create_constant_initializer(std::move(init_vec));
        } else {
            spdlog::error("--init-mode=bias requires --augment-for-bias");
            exit(EXIT_FAILURE);
        }
    }

    config.StatsGatherer = std::make_shared<TrainingStatsGatherer>(StatsLevelFile, StatsOutFile);
    config.Loss = Loss;

    return config;
}

int TrainingProgram::run(int argc, const char** argv)
{
    try {
        app.parse(argc, argv);
    } catch (const CLI::ParseError& e) {
        std::exit(app.exit(e));
    }

    // check validity of save location
    auto parent = std::filesystem::absolute(ModelFile).parent_path();
    if(!std::filesystem::exists(parent)) {
        spdlog::warn("Save directory '{}' does not exist. Trying to create it.", parent.c_str());
        std::filesystem::create_directories(parent);
        if(!std::filesystem::exists(parent)) {
            spdlog::error("Could not create directory -- exiting.");
            return EXIT_FAILURE;
        }
    }

    // TODO At this point, we know that the target directory exists, but not whether it is writeable.
    //      Still, it's a start.

    auto start_time = std::chrono::steady_clock::now();
    auto timeout_time = start_time + std::chrono::milliseconds(Timeout);

    auto data = DataProc.load(Verbose);

    std::shared_ptr<postproc::PostProcessFactory> permute_post_proc;
    if(ReorderFeatures) {
        auto permute = sort_features_by_frequency(*data);
        permute_post_proc = postproc::create_reordering(permute);
    }

    parse_label_range();

    auto runner = parallel::ParallelRunner(NumThreads);
    if(Verbose > 0)
        runner.set_logger(spdlog::default_logger());

    auto config = make_config(data);

    std::shared_ptr<postproc::PostProcessFactory> post_proc;
    bool use_sparse_model = false;
    switch (SaveOptions.Format) {
        case io::WeightFormat::SPARSE_TXT:
            post_proc = postproc::create_culling(SaveOptions.Culling);
            use_sparse_model = true;
            break;
        case io::WeightFormat::DENSE_TXT:
        default:
            break;
    }

    // if we explicitly enable sparsification, we override the culling post-proc that
    // may implicitly be generated due to the WeightFormat::SPARSE_TXT
    if(Sparsify > 0) {
        post_proc = postproc::create_sparsify(Sparsify / real_t{100});
        use_sparse_model = true;
        SaveOptions.Culling = 1e-10;
    }

    // make sure to combine the post-processing operations
    if(permute_post_proc) {
        if (post_proc) {
            post_proc = postproc::create_combined({post_proc, permute_post_proc});
        } else {
            post_proc = permute_post_proc;
        }
    }

    if(BatchSize <= 0) {
        BatchSize = data->num_labels();
    }

    if(Verbose >= 0) {
        spdlog::info("handled preprocessing in {} seconds",
                     std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start_time).count());
    }

    // batched training
    spdlog::info("Start training");
    io::PartialModelSaver saver(ModelFile, SaveOptions, ContinueRun);
    std::optional<io::PartialModelLoader> loader;
    if(*PreTrainedOpt) {
        loader.emplace(SourceModel);
    }
    label_id_t first_label = LabelsBegin;
    if(LabelsEnd == label_id_t{-1}) {
        LabelsEnd = label_id_t{data->num_labels()};
    }
    label_id_t next_label = std::min(LabelsEnd, first_label + BatchSize);
    std::future<io::model::WeightFileEntry> saving;

    config.PostProcessing = post_proc;
    config.Sparse = use_sparse_model;

    while(true) {
        spdlog::info("Starting batch {} - {}", first_label.to_index(), next_label.to_index());

        if(loader.has_value()) {
            auto initial_weights = loader->load_model(first_label, next_label);
            config.Init = init::create_pretrained_initializer(initial_weights);
        }

        // update time limit to respect remaining time
        runner.set_time_limit(std::chrono::duration_cast<std::chrono::milliseconds>(timeout_time - std::chrono::steady_clock::now()));

        std::shared_ptr<TrainingSpec> train_spec = create_dismec_training(data, hps, config);
        if(Verbose >= 0) {
            train_spec->set_logger(spdlog::default_logger());
        }
        auto result = run_training(runner, train_spec, first_label, next_label);

        /* Do async saving. This has some advantages and some drawbacks:
           + all the I/O latency will be interleaved with actual new computation, and we don't waste
             much time in this essentially non-parallel code
           - we may overcommit the processor. If run_training uses all cores, then we will spawn an
             additional thread here
           - increased memory consumption. Instead of 1 model, we need to keep 2 in memory at the same
             time: the one that is currently worked on and the one that is still being saved.
        */
        // make sure we don't interleave saving, as we don't do any locking in `saver`. Also, rethrow
        // any exception that happened during saving.
        if(saving.valid()) {
            saving.get();
            // saving weights has finished, we can update the meta data
            saver.update_meta_file();
        }

        saving = saver.add_model(result.Model);

        first_label = next_label;
        if(first_label == LabelsEnd) {
            // wait for the last saving process to finish
            saving.get();
            saver.update_meta_file();
            break;
        }
        next_label = std::min(LabelsEnd, first_label + BatchSize);
        // special case -- if the remaining labels are less than half a batch, we add them to this batch
        if(next_label + BatchSize/2 > LabelsEnd) {
            next_label = LabelsEnd;
        }
    }
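
    // Worked example (illustrative numbers): with BatchSize = 1000 and 2300 labels,
    // the first batch covers [0, 1000); afterwards only 300 labels would remain
    // beyond [1000, 2000), which is less than half a batch, so the second batch is
    // extended to [1000, 2300) and training finishes in two batches instead of three.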

    spdlog::info("program finished after {} seconds",
                 std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start_time).count());

    return EXIT_SUCCESS;
}