16 #include "CLI/CLI.hpp"
17 #include "spdlog/spdlog.h"
31 int run(
int argc,
const char** argv);
33 CLI::App app{
"DiSMEC"};
38 bool ReorderFeatures =
false;
44 std::filesystem::path ModelFile;
51 bool ContinueRun =
false;
57 CLI::Option* FirstLabelOpt;
58 CLI::Option* NumLabelsOpt;
70 void setup_regularization();
91 std::string StatsOutFile =
"stats.json";
92 std::string StatsLevelFile = {};
106 int main(
int argc,
const char** argv) {
109 program.
run(argc, argv);
116 app.add_option(
"output,--model-file", ModelFile,
117 "The file to which the model will be written. Note that models are saved in multiple different files, so this"
118 "just specifies the base name of the metadata file.")->required();
121 auto* dense_txt_flag = app.add_flag(
"--save-dense-txt", [&](std::size_t){ SaveOptions.Format = io::WeightFormat::DENSE_TXT; },
122 "Save dense weights in a human-readable text format")->take_last();
123 auto* dense_npy_flag = app.add_flag(
"--save-dense-npy", [&](std::size_t){ SaveOptions.Format = io::WeightFormat::DENSE_NPY; },
124 "Save dense weights in a npy file")->take_last();
125 auto* sparse_flag = app.add_flag(
"--save-sparse-txt", [&](std::size_t){ SaveOptions.Format = io::WeightFormat::SPARSE_TXT; },
126 "Save sparse weights in a human-readable text format. Sparsity can be adjusted using the --weight-culling option")->take_last();
128 dense_npy_flag->excludes(dense_txt_flag, sparse_flag);
129 dense_txt_flag->excludes(dense_npy_flag, sparse_flag);
130 sparse_flag->excludes(dense_txt_flag, dense_npy_flag);
132 app.add_option(
"--weight-culling", SaveOptions.Culling,
133 "When saving in a sparse format, any weight lower than this will be omitted.")->needs(sparse_flag)->check(CLI::NonNegativeNumber);
135 app.add_option(
"--save-precision", SaveOptions.Precision,
136 "The number of digits to write for real numbers in text file format.")->check(CLI::NonNegativeNumber)->excludes(dense_npy_flag);
141 DataProc.setup_data_args(app);
142 app.add_flag(
"--reorder-features", ReorderFeatures,
143 "If this flag is given, then the feature columns are sorted by the frequency before training. "
144 "This can lead to fast computations in case the number of features is very large and their frequencies imbalanced, "
145 "because it may improve data locality.");
149 FirstLabelOpt = app.add_option(
"--first-label", FirstLabel,
150 "If you want to train only a subset of labels, this is the id of the first label to be trained."
151 "The subset of labels trained is `[first_label, first_label + num-labels)`")->check(CLI::NonNegativeNumber);
152 NumLabelsOpt = app.add_option(
"--num-labels", NumLabels,
153 "If you want to train only a subset of labels, this is the total number of labels to be trained.")->check(CLI::NonNegativeNumber);
154 app.add_flag(
"--continue", ContinueRun,
155 "If this flag is given, the new weights will be appended to the model "
156 "file, instead of overwriting it. You can use the --first-label option to explicitly specify "
157 "at which label to start. If omitted, training starts at the first label for which no "
158 "weight vector is known.");
165 hps.set(
"epsilon", 0.01);
167 auto add_hyper_param_option = [&](
const char* option,
const char* name,
const char* desc) {
168 return app.add_option_function<
double>(
170 [
this, name](
double value) { hps.set(name, value); },
171 desc)->group(
"hyper-parameters");
174 add_hyper_param_option(
"--epsilon",
"epsilon",
175 "Tolerance for the minimizer. Will be adjusted by the number of positive/negative instances")
176 ->check(CLI::NonNegativeNumber);
178 add_hyper_param_option(
"--alpha-pcg",
"alpha-pcg",
179 "Interpolation parameter for preconditioning of CG optimization.")->check(CLI::Range(0.0, 1.0));
181 add_hyper_param_option(
"--line-search-step-size",
"search.step-size",
182 "Step size for the line search.")->check(CLI::NonNegativeNumber);
184 add_hyper_param_option(
"--line-search-alpha",
"search.alpha",
185 "Shrink factor for updating the line search step")->check(CLI::Range(0.0, 1.0));
187 add_hyper_param_option(
"--line-search-eta",
"search.eta",
188 "Acceptance criterion for the line search")->check(CLI::Range(0.0, 1.0));
189 add_hyper_param_option(
"--cg-epsilon",
"cg.epsilon",
190 "Stopping criterion for the CG solver")->check(CLI::PositiveNumber);
192 app.add_option_function<
long>(
194 [
this](
long value) { hps.set(
"max-steps", value); },
195 "Maximum number of newton steps.")->check(CLI::PositiveNumber)->group(
"hyper-parameters");
197 app.add_option_function<
long>(
198 "--line-search-max-steps",
199 [
this](
long value) { hps.set(
"search.max-steps", value); },
200 "Maximum number of line search steps.")->check(CLI::PositiveNumber)->group(
"hyper-parameters");
209 if(FirstLabelOpt->count() == 0)
211 auto missing = saver.get_missing_weights();
212 spdlog::info(
"Model is missing weight vectors {} to {}.", missing.first.to_index(), missing.second.to_index() - 1);
213 LabelsBegin = missing.first;
214 LabelsEnd = missing.second;
215 if (NumLabelsOpt->count() > 0) {
216 if (LabelsEnd - LabelsBegin >= NumLabels) {
217 LabelsEnd = LabelsBegin + NumLabels;
219 spdlog::warn(
"Number of labels to train was specified as {}, but only {} labels will be trained",
220 NumLabels, LabelsEnd - LabelsBegin);
227 if (NumLabelsOpt->count() > 0) {
228 LabelsBegin = LabelsBegin + NumLabels;
230 if(saver.any_weight_vector_for_interval(LabelsBegin, LabelsEnd)) {
231 spdlog::error(
"Specified continuation of training weight vectors for labels {}-{}, "
232 "which overlaps with existing weight vectors.", LabelsBegin.to_index(), LabelsEnd.to_index()-1);
244 if(FirstLabelOpt->count()) {
250 if (NumLabelsOpt->count() > 0) {
251 LabelsEnd = LabelsBegin + NumLabels;
259 app.add_option(
"--regularizer", Regularizer,
"The weight regularizer")->default_str(
"l2")
266 },CLI::ignore_case));
267 app.add_option(
"--reg-scale", RegScale,
"Scaling factor for the regularizer")->check(CLI::NonNegativeNumber);
268 app.add_flag(
"--reg-bias", RegBias,
"Include bias in regularization")->default_val(
false);
273 setup_source_cmdline();
274 setup_save_cmdline();
276 setup_hyper_params();
277 setup_regularization();
279 app.add_option(
"--threads", NumThreads,
"Number of threads to use. -1 means auto-detect");
280 app.add_option(
"--batch-size", BatchSize,
"If this is given, training is split into batches "
281 "and results are written to disk after each batch.");
282 app.add_option(
"--timeout", Timeout,
"No new training tasks will be started after this time. "
283 "This can be used e.g. on a cluster system to ensure that the training finishes properly "
284 "even if not all work could be done in the allotted time.")
285 ->transform(CLI::AsNumberWithUnit(std::map<std::string, float>{{
"ms", 1},
286 {
"s", 1'000}, {
"sec", 1'000},
287 {
"m", 60'000}, {
"min", 60'000},
289 CLI::AsNumberWithUnit::UNIT_REQUIRED,
"TIME"));
291 auto WMOpt = app.add_option(
"--weighting-mode", WeightingMode,
292 "Determines the re-weighting algorithm used to address missing labels.");
293 app.add_option(
"--propensity-a", PropA,
294 "Parameter a for the propensity model when using propensity based weighting")->needs(WMOpt);
295 app.add_option(
"--propensity-b", PropB,
296 "Parameter b for the propensity model when using propensity based weighting")->needs(WMOpt);
297 app.add_option(
"--weighting-pos-file", WeightingPosFile,
298 "File (npz or txt) that contains the weights for the positive instances for each label.")->needs(WMOpt);
299 app.add_option(
"--weighting-neg-file", WeightingNegFile,
300 "File (npz or txt) that contains the weights for the negative instances for each label.")->needs(WMOpt);
304 PreTrainedOpt = app.add_option(
"--pretrained", SourceModel,
"The model file which will be "
305 "used to initialize the weights.");
306 PreTrainedOpt->check(CLI::ExistingFile);
308 app.add_option(
"--loss", Loss,
"The loss function")->default_str(
"squared-hinge")
313 },CLI::ignore_case));
315 app.add_option(
"--sparsify", Sparsify,
"Feedback-driven sparsification. Specify the maximum amount (in %) up to which the binary loss "
316 "is allowed to increase.");
318 app.add_option(
"--init-mode", InitMode,
"How to initialize the weight vectors")
319 ->check(CLI::IsMember({
"zero",
"mean",
"bias",
"msi",
"multi-pos",
"ova-primal"}));
320 app.add_option(
"--bias-init-value", BiasInitValue,
"The value that is assigned to the bias weight for bias-init.");
321 app.add_option(
"--msi-pos", MSI_PFac,
"Positive target for msi init");
322 app.add_option(
"--msi-neg", MSI_NFac,
"Negative target for msi init");
323 app.add_option(
"--max-num-pos", InitMaxPos,
"Number of positives to consider for `multi-pos` initialization")->check(CLI::NonNegativeNumber);
325 app.add_option(
"--record-stats", StatsLevelFile,
326 "Record some statistics and save to file. The argument is a json file which describes which statistics are gathered.")
327 ->check(CLI::ExistingFile);
328 app.add_option(
"--stats-file", StatsOutFile,
"Target file for recorded statistics");
330 app.add_flag(
"-v,-q{-1}", Verbose);
337 if(WeightingMode ==
"2pm1") {
339 }
else if(WeightingMode ==
"p2mp") {
341 }
else if(WeightingMode ==
"from-file") {
342 auto load_vec = [&](std::string source){
344 if(!source.empty()) {
345 std::fstream file(source, std::fstream::in);
346 if(!file.is_open()) {
351 if(header.DataType != io::data_type_string<real_t>()) {
352 THROW_ERROR(
"Unsupported data type {}", header.DataType);
354 if(header.Rows != 1 && header.Cols != 1) {
355 THROW_ERROR(
"Expected a vector for weighting data");
357 io::binary_load(*file.rdbuf(), wgt.data(), wgt.data() + header.Rows * header.Cols);
364 config.
Weighting = std::make_shared<CustomWeighting>(load_vec(WeightingPosFile),
365 load_vec(WeightingNegFile));
366 }
else if (!WeightingMode.empty()) {
367 spdlog::error(
"Unknown weighting mode {}. Aborting.", WeightingMode);
370 config.
Weighting = std::make_shared<ConstantWeighting>(1.0, 1.0);
374 switch(Regularizer) {
394 spdlog::error(
"Unknown regularization mode {}. Aborting.", Regularizer);
399 std::shared_ptr<init::WeightInitializationStrategy> init_strategy;
400 if(InitMode ==
"mean") {
402 }
else if(InitMode ==
"msi") {
404 }
else if(InitMode ==
"multi-pos") {
406 }
else if(InitMode ==
"ova-primal") {
408 }
else if(InitMode ==
"bias" || (InitMode.empty() && BiasInitValue.has_value())) {
409 if(DataProc.augment_for_bias()) {
412 init_vec.coeffRef(init_vec.size()-1) = BiasInitValue.value_or(-1.0);
415 spdlog::error(
"--init-mode=bias requires --augment-for-bias");
420 config.
StatsGatherer = std::make_shared<TrainingStatsGatherer>(StatsLevelFile, StatsOutFile);
429 app.parse(argc, argv);
430 }
catch (
const CLI::ParseError &e) {
431 std::exit(app.exit(e));
435 auto parent = std::filesystem::absolute(ModelFile).parent_path();
436 if(!std::filesystem::exists(parent)) {
437 spdlog::warn(
"Save directory '{}' does not exist. Trying to create it.", parent.c_str());
438 std::filesystem::create_directories(parent);
439 if(!std::filesystem::exists(parent)) {
440 spdlog::error(
"Could not create directory -- exiting.");
449 auto start_time = std::chrono::steady_clock::now();
450 auto timeout_time = start_time + std::chrono::milliseconds(Timeout);
452 auto data = DataProc.load(Verbose);
454 std::shared_ptr<postproc::PostProcessFactory> permute_post_proc;
455 if(ReorderFeatures) {
464 runner.set_logger(spdlog::default_logger());
466 auto config = make_config(data);
468 std::shared_ptr<postproc::PostProcessFactory> post_proc;
469 bool use_sparse_model =
false;
470 switch (SaveOptions.Format) {
471 case io::WeightFormat::SPARSE_TXT:
473 use_sparse_model =
true;
475 case io::WeightFormat::DENSE_TXT:
484 use_sparse_model =
true;
485 SaveOptions.Culling = 1e-10;
489 if(permute_post_proc) {
493 post_proc = permute_post_proc;
499 BatchSize = data->num_labels();
503 spdlog::info(
"handled preprocessing in {} seconds",
504 std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start_time).count() );
508 spdlog::info(
"Start training");
510 std::optional<io::PartialModelLoader> loader;
512 loader.emplace(SourceModel);
518 label_id_t next_label = std::min(LabelsEnd, first_label + BatchSize);
519 std::future<io::model::WeightFileEntry> saving;
521 config.PostProcessing = post_proc;
522 config.Sparse = use_sparse_model;
525 spdlog::info(
"Starting batch {} - {}", first_label.
to_index(), next_label.
to_index());
527 if(loader.has_value()) {
528 auto initial_weights = loader->load_model(first_label, next_label);
533 runner.set_time_limit(std::chrono::duration_cast<std::chrono::milliseconds>(timeout_time - std::chrono::steady_clock::now()));
537 train_spec->set_logger(spdlog::default_logger());
540 first_label, next_label);
555 saver.update_meta_file();
558 saving = saver.add_model(result.Model);
560 first_label = next_label;
561 if(first_label == LabelsEnd) {
564 saver.update_meta_file();
567 next_label = std::min(LabelsEnd, first_label + BatchSize);
570 if(next_label + BatchSize/2 > LabelsEnd) {
571 next_label = LabelsEnd;
575 spdlog::info(
"program finished after {} seconds", std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start_time).count() );
void setup_source_cmdline()
std::string WeightingNegFile
std::filesystem::path SourceModel
void setup_hyper_params()
std::string WeightingMode
std::optional< real_t > BiasInitValue
void setup_regularization()
CLI::Option * PreTrainedOpt
void setup_save_cmdline()
CascadeTrainingConfig make_config(const std::shared_ptr< MultiLabelData > &data, std::shared_ptr< const GenericFeatureMatrix > dense)
int run(int argc, const char **argv)
std::string WeightingPosFile
This class represents a set of hyper-parameters.
Manage saving a model consisting of multiple partial models.
Strong typedef for an int to signify a label id.
constexpr T to_index() const
Explicitly convert to an integer.
Building blocks for I/O procedures that are used by multiple I/O subsystems.
std::shared_ptr< WeightInitializationStrategy > create_pretrained_initializer(std::shared_ptr< model::Model > model)
Creates an initialization strategy that uses an already trained model to set the initial weights.
std::shared_ptr< WeightInitializationStrategy > create_feature_mean_initializer(std::shared_ptr< DatasetBase > data, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
std::shared_ptr< WeightInitializationStrategy > create_multi_pos_mean_strategy(std::shared_ptr< DatasetBase > data, int max_pos, real_t pos=1, real_t neg=-2)
Creates an initialization strategy based on the mean of positive and negative features.
std::shared_ptr< WeightInitializationStrategy > create_constant_initializer(DenseRealVector vec)
std::shared_ptr< WeightInitializationStrategy > create_ova_primal_initializer(const std::shared_ptr< DatasetBase > &data, RegularizerSpec regularizer, LossType loss)
DENSE_TXT
Dense Text Format
std::istream & read_vector_from_text(std::istream &stream, Eigen::Ref< DenseRealVector > data)
Reads the given vector as space-separated human-readable numbers.
void binary_load(std::streambuf &target, T *begin, T *end)
bool is_npy(std::istream &target)
Check whether the stream is a npy file.
NpyHeaderData parse_npy_header(std::streambuf &source)
Parses the header of the npy file given by source.
FactoryPtr create_combined(std::vector< FactoryPtr > processor)
FactoryPtr create_reordering(Eigen::PermutationMatrix< Eigen::Dynamic, Eigen::Dynamic, int > ordering)
FactoryPtr create_sparsify(real_t tolerance)
FactoryPtr create_culling(real_t eps)
Main namespace in which all types, classes, and functions are defined.
TrainingResult run_training(parallel::ParallelRunner &runner, std::shared_ptr< TrainingSpec > spec, label_id_t begin_label=label_id_t{0}, label_id_t end_label=label_id_t{-1})
types::DenseVector< real_t > DenseRealVector
Any dense, real values vector.
std::shared_ptr< TrainingSpec > create_dismec_training(std::shared_ptr< const DatasetBase > data, HyperParameters params, DismecTrainingConfig config)
Eigen::PermutationMatrix< Eigen::Dynamic, Eigen::Dynamic, int > sort_features_by_frequency(DatasetBase &data)
DenseRealVector get_mean_feature(const GenericFeatureMatrix &features)
float real_t
The default type for floating point values.
RegularizerSpec Regularizer
std::shared_ptr< init::WeightInitializationStrategy > Init
std::shared_ptr< WeightingScheme > Weighting
std::shared_ptr< TrainingStatsGatherer > StatsGatherer
int main(int argc, const char **argv)