#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <future>
#include <memory>
#include <string>
#include <vector>

#include "CLI/CLI.hpp"
#include "spdlog/spdlog.h"
#include "spdlog/stopwatch.h"
// ... (project headers declaring MultiLabelData, GenericFeatureMatrix,
//      CascadeTrainingConfig, and the io::/init:: helpers are elided)
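// This file implements the command-line training driver for DiSMEC-Cascade:
// it registers all CLI options, loads the sparse (tf-idf) and dense features,
// applies optional preprocessing, and trains the requested label range in
// batches, writing partial models to disk as it goes.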
// Driver class for the DiSMEC-Cascade training executable. The class wrapper,
// the access specifiers, and the member list after the second elision marker
// are reconstructed from context; the excerpt itself only shows the
// declarations up to make_config().
class TrainingProgram {
public:
    TrainingProgram();
    int run(int argc, const char** argv);

private:
    CLI::App app{"DiSMEC-Cascade"};

    void setup_source_cmdline();
    void setup_save_cmdline();
    void setup_label_range();
    bool ContinueRun = false;
    void parse_label_range();
    void setup_hyper_params();

    std::string StatsOutFile = "stats.json";
    std::string StatsLevelFile = {};

    bool NormalizeSparse = false;
    bool NormalizeDense = false;
    bool InitSparseMSI = false;
    bool InitDenseMSI = false;
    bool AugmentDenseWithBias = false;
    bool AugmentSparseWithBias = false;

    CascadeTrainingConfig make_config(const std::shared_ptr<MultiLabelData>& data,
                                      std::shared_ptr<const GenericFeatureMatrix> dense);

    // members referenced below whose types are documented elsewhere in this excerpt:
    io::model::SaveOption SaveOptions;
    std::filesystem::path ModelFile;
    std::filesystem::path DenseWeightsFile;
    std::filesystem::path DenseBiasesFile;
    std::filesystem::path ExportProcessedData;
    std::string ShortlistFile;
    CLI::Option* FirstLabelOpt;
    CLI::Option* NumLabelsOpt;
    HyperParameters hps;
    label_id_t LabelsBegin;
    label_id_t LabelsEnd;
    // ... (further members elided: TfIdfFile, DenseFile, TransformSparse,
    //      FirstLabel, NumLabels, NumThreads, BatchSize, Timeout, Verbose,
    //      RegScaleSparse, RegScaleDense, ...)
};
int main(int argc, const char** argv) {
    TrainingProgram program;
    program.run(argc, argv);
}
void TrainingProgram::setup_save_cmdline() {
    SaveOptions.Culling = 0.01;

    app.add_option("output,--model-file", ModelFile,
                   "The file to which the model will be written. Note that models are saved in multiple different files, so this "
                   "just specifies the base name of the metadata file.")->required();

    app.add_option("--weight-culling", SaveOptions.Culling,
                   "When saving in a sparse format, any weight lower than this will be omitted.")->check(CLI::NonNegativeNumber);

    app.add_option("--save-precision", SaveOptions.Precision,
                   "The number of digits to write for real numbers in text file format.")->check(CLI::NonNegativeNumber);
}
void TrainingProgram::setup_source_cmdline() {
    app.add_option("tfidf-file", TfIdfFile,
                   "The file from which the tfidf data will be loaded.")->required()->check(CLI::ExistingFile);
    app.add_option("dense-file", DenseFile,
                   "The file from which the dense data will be loaded.")->required()->check(CLI::ExistingFile);
    app.add_option("--shortlist", ShortlistFile,
                   "A file containing the shortlist of hard-negative instances for each label.")->check(CLI::ExistingFile);
}
void TrainingProgram::setup_label_range() {
    FirstLabelOpt = app.add_option(
        "--first-label", FirstLabel,
        "If you want to train only a subset of labels, this is the id of the first label to be trained. "
        "The subset of labels trained is `[first_label, first_label + num-labels)`.")->check(CLI::NonNegativeNumber);
    NumLabelsOpt = app.add_option(
        "--num-labels", NumLabels,
        "If you want to train only a subset of labels, this is the total number of labels to be trained.")
        ->check(CLI::NonNegativeNumber);
    app.add_flag(
        "--continue", ContinueRun,
        "If this flag is given, the new weights will be appended to the model "
        "file, instead of overwriting it. You can use the --first-label option to explicitly specify "
        "at which label to start. If omitted, training starts at the first label for which no "
        "weight vector is known.");
}
void TrainingProgram::setup_hyper_params() {
    // default minimizer tolerance
    hps.set("epsilon", 0.01);

    // helper that registers a CLI option which writes directly into the
    // hyper-parameter set under the given name
    auto add_hyper_param_option = [&](const char* option, const char* name, const char* desc) {
        return app.add_option_function<double>(
            option,
            [this, name](double value) { hps.set(name, value); },
            desc)->group("hyper-parameters");
    };

    add_hyper_param_option("--epsilon", "epsilon",
        "Tolerance for the minimizer. Will be adjusted by the number of positive/negative instances.")
        ->check(CLI::NonNegativeNumber);
    add_hyper_param_option("--alpha-pcg", "alpha-pcg",
        "Interpolation parameter for preconditioning of CG optimization.")->check(CLI::Range(0.0, 1.0));
    add_hyper_param_option("--line-search-step-size", "search.step-size",
        "Step size for the line search.")->check(CLI::NonNegativeNumber);
    add_hyper_param_option("--line-search-alpha", "search.alpha",
        "Shrink factor for updating the line search step.")->check(CLI::Range(0.0, 1.0));
    add_hyper_param_option("--line-search-eta", "search.eta",
        "Acceptance criterion for the line search.")->check(CLI::Range(0.0, 1.0));
    add_hyper_param_option("--cg-epsilon", "cg.epsilon",
        "Stopping criterion for the CG solver.")->check(CLI::PositiveNumber);
    app.add_option_function<long>(
        "--max-steps",      // option name reconstructed to match the "max-steps" key
        [this](long value) { hps.set("max-steps", value); },
        "Maximum number of newton steps.")->check(CLI::PositiveNumber)->group("hyper-parameters");
    app.add_option_function<long>(
        "--line-search-max-steps",
        [this](long value) { hps.set("search.max-steps", value); },
        "Maximum number of line search steps.")->check(CLI::PositiveNumber)->group("hyper-parameters");
}
void TrainingProgram::parse_label_range() {
    if (ContinueRun) {
        // `saver` is the partial-model saver opened on ModelFile (construction elided)
        if (FirstLabelOpt->count() == 0) {
            auto missing = saver.get_missing_weights();
            spdlog::info("Model is missing weight vectors {} to {}.",
                         missing.first.to_index(), missing.second.to_index() - 1);
            LabelsBegin = missing.first;
            LabelsEnd = missing.second;
            if (NumLabelsOpt->count() > 0) {
                if (LabelsEnd - LabelsBegin >= NumLabels) {
                    LabelsEnd = LabelsBegin + NumLabels;
                } else {
                    spdlog::warn("Number of labels to train was specified as {}, but only {} labels will be trained",
                                 NumLabels, LabelsEnd - LabelsBegin);
                }
            }
            return;
        }
        // explicit --first-label together with --continue
        LabelsBegin = label_id_t{FirstLabel};
        if (NumLabelsOpt->count() > 0) {
            LabelsEnd = LabelsBegin + NumLabels;
        }
        if (saver.any_weight_vector_for_interval(LabelsBegin, LabelsEnd)) {
            spdlog::error("Specified continuation of training weight vectors for labels {}-{}, "
                          "which overlaps with existing weight vectors.",
                          LabelsBegin.to_index(), LabelsEnd.to_index() - 1);
            exit(EXIT_FAILURE);
        }
        return;
    }

    // fresh run: the range comes directly from the command line
    if (FirstLabelOpt->count()) {
        LabelsBegin = label_id_t{FirstLabel};
    }
    if (NumLabelsOpt->count() > 0) {
        LabelsEnd = LabelsBegin + NumLabels;
    }
}
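// After parse_label_range(), [LabelsBegin, LabelsEnd) is the half-open range
// of label ids this process will train, whether it came from the command line
// or from the gaps in an existing model file.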
TrainingProgram::TrainingProgram() {
    setup_source_cmdline();
    setup_save_cmdline();
    setup_label_range();    // (call reconstructed)
    setup_hyper_params();
    app.add_option("--threads", NumThreads, "Number of threads to use. -1 means auto-detect");
    app.add_option("--batch-size", BatchSize,
                   "If this is given, training is split into batches "
                   "and results are written to disk after each batch.");
    app.add_option("--timeout", Timeout,
                   "No new training tasks will be started after this time. "
                   "This can be used e.g. on a cluster system to ensure that the training finishes properly "
                   "even if not all work could be done in the allotted time.")
        ->transform(CLI::AsNumberWithUnit(std::map<std::string, float>{
                        {"ms", 1},
                        {"s", 1'000}, {"sec", 1'000},
                        {"m", 60'000}, {"min", 60'000},
                        // (additional units elided)
                    },
                    CLI::AsNumberWithUnit::UNIT_REQUIRED, "TIME"));
    app.add_option("--record-stats", StatsLevelFile,
                   "Record some statistics and save to file. The argument is a json file which describes which statistics are gathered.")
        ->check(CLI::ExistingFile);
    app.add_option("--stats-file", StatsOutFile, "Target file for recorded statistics");
    app.add_option("--init-dense-weights", DenseWeightsFile,
                   "File from which the initial weights for the dense part will be loaded.")->check(CLI::ExistingFile);
    app.add_option("--init-dense-biases", DenseBiasesFile,
                   "File from which the initial biases for the dense part will be loaded.")->check(CLI::ExistingFile);
    app.add_flag("--init-sparse-msi", InitSparseMSI,
                 "If this flag is given, then the sparse part will use mean-separating initialization.");
    app.add_flag("--init-dense-msi", InitDenseMSI,
                 "If this flag is given, then the dense part will use mean-separating initialization.");
    app.add_option("--sparse-reg-scale", RegScaleSparse,
                   "Scaling factor for the sparse-part regularizer")->check(CLI::NonNegativeNumber);
    app.add_option("--dense-reg-scale", RegScaleDense,
                   "Scaling factor for the dense-part regularizer")->check(CLI::NonNegativeNumber);
    app.add_flag("--normalize-dense", NormalizeDense, "Normalize the dense part of the feature matrix");
    app.add_flag("--normalize-sparse", NormalizeSparse, "Normalize the sparse part of the feature matrix");
    app.add_option("--transform-sparse", TransformSparse, "Apply a transformation to the sparse features.")
        ->default_str("identity")
        ->transform(CLI::Transformer(std::map<std::string, DatasetTransform>{
            // (the name -> DatasetTransform entries are elided in this excerpt)
        }, CLI::ignore_case));
    app.add_flag("--augment-dense-bias", AugmentDenseWithBias,
                 "Add an additional feature column to the dense matrix with values one.");
    app.add_flag("--augment-sparse-bias", AugmentSparseWithBias,
                 "Add an additional feature column to the sparse matrix with values one.");
    app.add_option("--export-dataset", ExportProcessedData,
                   "Exports the preprocessed dataset to the given file.");
    app.add_flag("-v,-q{-1}", Verbose);
}
MultiLabelData join_data(const std::shared_ptr<MultiLabelData>& data,
                         std::shared_ptr<const GenericFeatureMatrix> dense_data) {
    const auto& sparse = data->get_features()->sparse();   // (locals reconstructed)
    const auto& dense = dense_data->dense();

    // combined matrix: dense columns first, then the original sparse columns
    SparseFeatures new_sparse(data->num_examples(), data->num_features() + dense.cols());
    new_sparse.reserve(sparse.nonZeros() + dense.size());
    for (int k = 0; k < data->num_examples(); ++k) {
        new_sparse.startVec(k);
        for (DenseFeatures::InnerIterator it(dense, k); it; ++it) {
            new_sparse.insertBack(it.row(), it.col()) = it.value();
        }
        for (SparseFeatures::InnerIterator it(sparse, k); it; ++it) {
            new_sparse.insertBack(it.row(), it.col() + dense.cols()) = it.value();
        }
    }
    new_sparse.finalize();
    return {new_sparse, data->all_labels()};
}
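// join_data uses Eigen's incremental fill API for row-major sparse matrices:
// startVec() begins a new row, insertBack() appends entries whose coordinates
// must arrive in increasing order, and finalize() compresses the result.
// Both inner loops visit columns in ascending order, so the precondition
// holds and no intermediate triplet list is needed.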
CascadeTrainingConfig TrainingProgram::make_config(const std::shared_ptr<MultiLabelData>& data,
                                                   std::shared_ptr<const GenericFeatureMatrix> dense) {
    CascadeTrainingConfig config;
    // ... (sparse-part initialization elided)

    if (!DenseWeightsFile.empty()) {
        if (InitDenseMSI) {   // (guard reconstructed)
            spdlog::error("Cannot use MSI and pretrained weights at the same time!");
            exit(EXIT_FAILURE);
        }
        if (DenseBiasesFile.empty()) {
            // ... (load pretrained dense weights without a bias file)
        }
        // ...
    } else if (InitDenseMSI) {
        // mean-separating initialization needs the dense features as a dataset
        auto dense_ds = std::make_shared<MultiLabelData>(dense->dense(), data->all_labels());
        // ... (mean-separating initializer built from dense_ds)
    }
    // ...

    config.StatsGatherer = std::make_shared<TrainingStatsGatherer>(StatsLevelFile, StatsOutFile);
    return config;
}
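// The dense sub-model can thus be initialized from pretrained numpy
// weights/biases (--init-dense-weights/--init-dense-biases) or by
// mean-separating initialization (--init-dense-msi); --init-sparse-msi plays
// the same role for the sparse part.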
int TrainingProgram::run(int argc, const char** argv) {
    try {
        app.parse(argc, argv);
    } catch (const CLI::ParseError& e) {
        return app.exit(e);
    }

    // make sure the directory for the model files exists
    auto parent = std::filesystem::absolute(ModelFile).parent_path();
    if (!std::filesystem::exists(parent)) {
        spdlog::warn("Save directory '{}' does not exist. Trying to create it.", parent.c_str());
        std::filesystem::create_directories(parent);
        if (!std::filesystem::exists(parent)) {
            spdlog::error("Could not create directory -- exiting.");
            return EXIT_FAILURE;
        }
    }
    auto start_time = std::chrono::steady_clock::now();
    auto timeout_time = start_time + std::chrono::milliseconds(Timeout);

    spdlog::info("Loading training data from file '{}'", TfIdfFile);
    auto data = std::make_shared<MultiLabelData>([&]() {
        auto dataset = read_xmc_dataset(TfIdfFile);             // (call reconstructed)
        if (TransformSparse != DatasetTransform::Identity) {    // (guard and enumerator name assumed)
            spdlog::info("Applying data transformation");
            transform_features(dataset, TransformSparse);
        }
        return dataset;
    }());
    std::shared_ptr<const GenericFeatureMatrix> dense_data;    // loaded from DenseFile (loading elided)

    if (NormalizeSparse) {
        spdlog::stopwatch timer;
        normalize_instances(*data);                 // (call reconstructed)
        spdlog::info("Normalized sparse features in {:.3} seconds.", timer);
    }
    if (AugmentSparseWithBias) {
        spdlog::stopwatch timer;
        augment_features_with_bias(*data);          // (call reconstructed)
        spdlog::info("Added bias column to sparse features in {:.3} seconds.", timer);
    }
    if (NormalizeDense) {                           // (guard reconstructed)
        spdlog::stopwatch timer;
        // ... (normalization of the dense matrix elided)
        spdlog::info("Normalized dense features in {:.3} seconds.", timer);
    }
    if (AugmentDenseWithBias) {
        spdlog::stopwatch timer;
        // ... (bias augmentation of the dense matrix elided)
        spdlog::info("Added bias column to dense features in {:.3} seconds.", timer);
    }

    if (!ExportProcessedData.empty()) {
        spdlog::stopwatch timer;
        auto exported = join_data(data, dense_data);
        // ... (written out via save_xmc_dataset; elided)
        spdlog::info("Saved preprocessed data to {} in {:.3} seconds", ExportProcessedData.string(), timer);
    }
    std::shared_ptr<const std::vector<std::vector<long>>> shortlist;
    if (!ShortlistFile.empty()) {
        auto stream = std::fstream(ShortlistFile, std::fstream::in);
        auto result = read_binary_matrix_as_lol(stream);    // (call reconstructed)
        if (result.NumCols != data->num_labels()) {
            spdlog::error("Mismatch between number of labels in shortlist {} and in dataset {}",
                          result.NumCols, data->num_labels());
            return EXIT_FAILURE;
        }
        if (result.NumRows != data->num_examples()) {
            spdlog::error("Mismatch between number of examples in shortlist {} and in dataset {}",
                          result.NumRows, data->num_examples());
            return EXIT_FAILURE;
        }
        shortlist = std::make_shared<std::vector<std::vector<long>>>(std::move(result.NonZeros));
    }
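    // The shortlist file is parsed as a binary matrix in list-of-lists form
    // and validated against the dataset shape; its nonzero structure supplies
    // the hard-negative shortlists used by the cascade trainer.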
    parallel::ParallelRunner runner(NumThreads);    // (construction reconstructed)
    runner.set_logger(spdlog::default_logger());

    auto config = make_config(data, dense_data);

    // ensure a minimal amount of culling when saving in a sparse text format
    // (the enclosing condition is elided in this excerpt)
    SaveOptions.Culling = 1e-10;

    // without an explicit --batch-size, train all labels as a single batch
    // (the enclosing condition is elided in this excerpt)
    BatchSize = data->num_labels();

    spdlog::info("handled preprocessing in {} seconds",
                 std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start_time).count());

    spdlog::info("Start training");
    label_id_t first_label = LabelsBegin;    // (initialization reconstructed)
    label_id_t next_label = std::min(LabelsEnd, first_label + BatchSize);
    std::future<io::model::WeightFileEntry> saving;

    auto post_proc = create_culling(SaveOptions.Culling);    // (reconstructed; see create_culling)
    config.PostProcessing = post_proc;
    config.DenseReg = RegScaleDense;
    config.SparseReg = RegScaleSparse;
    while (true) {
        spdlog::info("Starting batch {} - {}", first_label.to_index(), next_label.to_index());

        // no new tasks are scheduled once the timeout has expired
        runner.set_time_limit(std::chrono::duration_cast<std::chrono::milliseconds>(
            timeout_time - std::chrono::steady_clock::now()));

        auto train_spec = create_cascade_training(data, dense_data, shortlist, hps, config);   // (reconstructed)
        train_spec->set_logger(spdlog::default_logger());
        auto result = run_training(runner, train_spec, first_label, next_label);
        // ... (asynchronous saving of the partial model elided)

        first_label = next_label;
        if (first_label == LabelsEnd) {
            break;
        }
        next_label = std::min(LabelsEnd, first_label + BatchSize);
        // avoid a tiny trailing batch: if less than half a batch would remain
        // after this one, extend it to the end
        if (next_label + BatchSize / 2 > LabelsEnd) {
            next_label = LabelsEnd;
        }
    }

    spdlog::info("program finished after {} seconds",
                 std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start_time).count());
    return EXIT_SUCCESS;
}