DiSMEC++
predict.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "prediction/metrics.h"
7 #include "parallel/runner.h"
9 #include "prediction/evaluate.h"
10 #include "model/model.h"
11 #include "data/data.h"
12 #include "data/transform.h"
13 #include "io/model-io.h"
14 #include "io/prediction.h"
15 #include "io/xmc.h"
16 #include "CLI/CLI.hpp"
17 #include "app.h"
18 #include "spdlog/spdlog.h"
19 #include "nlohmann/json.hpp"
20 
21 using namespace dismec;
22 
23 
24 prediction::MacroMetricReporter* add_macro_metrics(prediction::EvaluateMetrics& metrics, int k) {
25  auto* macro = metrics.add_macro_at_k(k);
26  macro->add_coverage(0.0);
27  macro->add_confusion_matrix();
28  macro->add_precision(prediction::MacroMetricReporter::MACRO);
29  macro->add_precision(prediction::MacroMetricReporter::MICRO);
30  macro->add_recall(prediction::MacroMetricReporter::MACRO);
31  macro->add_recall(prediction::MacroMetricReporter::MICRO);
32  macro->add_f_measure(prediction::MacroMetricReporter::MACRO);
33  macro->add_f_measure(prediction::MacroMetricReporter::MICRO);
34  macro->add_accuracy(prediction::MacroMetricReporter::MICRO);
35  macro->add_accuracy(prediction::MacroMetricReporter::MACRO);
36  macro->add_balanced_accuracy(prediction::MacroMetricReporter::MICRO);
37  macro->add_balanced_accuracy(prediction::MacroMetricReporter::MACRO);
38  macro->add_specificity(prediction::MacroMetricReporter::MICRO);
39  macro->add_specificity(prediction::MacroMetricReporter::MACRO);
40  macro->add_informedness(prediction::MacroMetricReporter::MICRO);
41  macro->add_informedness(prediction::MacroMetricReporter::MACRO);
42  macro->add_markedness(prediction::MacroMetricReporter::MICRO);
43  macro->add_markedness(prediction::MacroMetricReporter::MACRO);
44  macro->add_fowlkes_mallows(prediction::MacroMetricReporter::MICRO);
45  macro->add_fowlkes_mallows(prediction::MacroMetricReporter::MACRO);
46  macro->add_negative_predictive_value(prediction::MacroMetricReporter::MICRO);
47  macro->add_negative_predictive_value(prediction::MacroMetricReporter::MACRO);
48  macro->add_matthews(prediction::MacroMetricReporter::MICRO);
49  macro->add_matthews(prediction::MacroMetricReporter::MACRO);
50  macro->add_positive_likelihood_ratio(prediction::MacroMetricReporter::MICRO);
51  macro->add_positive_likelihood_ratio(prediction::MacroMetricReporter::MACRO);
52  macro->add_negative_likelihood_ratio(prediction::MacroMetricReporter::MICRO);
53  macro->add_negative_likelihood_ratio(prediction::MacroMetricReporter::MACRO);
54  macro->add_diagnostic_odds_ratio(prediction::MacroMetricReporter::MICRO);
55  macro->add_diagnostic_odds_ratio(prediction::MacroMetricReporter::MACRO);
56  return macro;
57 };
58 
59 void setup_metrics(prediction::EvaluateMetrics& metrics, int top_k) {
60  metrics.add_precision_at_k(1);
61  metrics.add_abandonment_at_k(1);
62  metrics.add_dcg_at_k(1, false);
63  metrics.add_dcg_at_k(1, true);
64 
65  add_macro_metrics(metrics, 1);
66 
67  if(top_k >= 3) {
68  metrics.add_precision_at_k(3);
69  metrics.add_abandonment_at_k(3);
70  metrics.add_dcg_at_k(3, false);
71  metrics.add_dcg_at_k(3, true);
72  add_macro_metrics(metrics, 3);
73  }
74  if(top_k >= 5) {
75  metrics.add_precision_at_k(5);
76  metrics.add_abandonment_at_k(5);
77  metrics.add_dcg_at_k(5, false);
78  metrics.add_dcg_at_k(5, true);
79  add_macro_metrics(metrics, 5);
80  }
81 }
82 
83 
// Program entry point: parses the command line, loads the test set, and then
// either computes sparse top-k predictions (streaming over partial weight
// files and reporting evaluation metrics) or dense full-score predictions.
// NOTE(review): this listing is missing a few original lines (121, 124, 171,
// 179, 201-204); comments below hedge where behavior depends on them.
84 int main(int argc, const char** argv) {
85  CLI::App app{"DiSMEC"};
86 
87  std::string problem_file;
88  std::string model_file;
89  std::string result_file;
90  std::string labels_file;
91  std::filesystem::path save_metrics;
// -1 lets the parallel runner auto-detect the thread count.
92  int threads = -1;
93  int top_k = 5;
94  bool save_as_npy = false;
95 
96  DataProcessing DataProc;
97  DataProc.setup_data_args(app);
98 
// NOTE(review): the trailing second ';' on the next line is harmless but stray.
99  app.add_option("model-file", model_file, "The file from which the model will be read.")->required()->check(CLI::ExistingFile);;
100  app.add_option("result-file", result_file, "The file to which the predictions will be written.")->required();
101  app.add_option("--threads", threads, "Number of threads to use. -1 means auto-detect");
102  app.add_option("--save-metrics", save_metrics, "Target file in which the metric values are saved");
103  app.add_option("--topk, --top-k", top_k, "Only the top k predictions will be saved. "
104  "Set to -1 if you need all predictions. (Warning: This may result in very large files!)");
105  app.add_flag("--save-as-npy", save_as_npy, "Save the predictions as a numpy file instead of plain text.");
// Each repetition of -v increases verbosity.
106  int Verbose = 0;
107  app.add_flag("-v", Verbose);
108 
109  try {
110  app.parse(argc, argv);
111  } catch (const CLI::ParseError &e) {
112  return app.exit(e);
113  }
114 
115  auto test_set = DataProc.load(Verbose);
116 
117  parallel::ParallelRunner runner(threads);
118  if(Verbose > 0)
119  runner.set_logger(spdlog::default_logger());
120 
122 
// Top-k mode: only the k highest-scoring labels per example are kept.
// NOTE(review): 'loader' used below is declared on a line omitted from this
// listing (presumably an io::PartialModelLoader over model_file) — confirm
// against the original source.
123  if(top_k > 0) {
125  if(!loader.validate()) {
126  return EXIT_FAILURE;
127  }
128 
// Index of the weight file currently being processed.
129  int wf_it = 0;
130  if(loader.num_weight_files() == 0) {
131  spdlog::error("No weight files");
132  return EXIT_FAILURE;
133  }
134 
135  spdlog::info("Calculating top-{} predictions", top_k);
136 
137  // generate a transpose of the label matrix
138  std::vector<std::vector<label_id_t>> examples_to_labels(test_set->num_examples());
139  for(label_id_t label{0}; label.to_index() < test_set->num_labels(); ++label) {
140  for(auto example : test_set->get_label_instances(label)) {
141  examples_to_labels[example].push_back(label);
142  }
143  }
144 
145  auto initial_model = loader.load_model(wf_it);
146  spdlog::info("Using {} representation for model weights", initial_model->has_sparse_weights() ? "sparse" : "dense");
147 
// Stream over all weight files: predict with the current partial model while
// the next file is loaded asynchronously in the background.
148  prediction::TopKPredictionTaskGenerator task(test_set.get(), initial_model, top_k);
149  while(true) {
150  ++wf_it;
// The preload lambda returns an empty shared_ptr once every weight file
// has been consumed; that is the loop's termination signal below.
151  auto preload_weights = std::async(std::launch::async, [iter=wf_it, &loader]() {
152  if(iter != loader.num_weight_files()) {
153  return loader.load_model(iter);
154  } else {
155  return std::shared_ptr<dismec::model::Model>{};
156  }
157  });
158  auto result = runner.run(task);
159  if(!result.IsFinished) {
160  spdlog::error("Something went wrong, prediction computation was not finished!");
161  std::exit(1);
162  }
163  spdlog::info("Finished prediction in {}s", result.Duration.count());
164  if(wf_it == loader.num_weight_files()) {
165  break;
166  }
167  task.update_model(preload_weights.get());
168  }
169 
// NOTE(review): the call head (presumably io::prediction::save_sparse_predictions)
// on the omitted line 171 is missing from this listing.
170  spdlog::info("Saving to '{}'", result_file);
172  task.get_top_k_values(),
173  task.get_top_k_indices());
174 
175  prediction::EvaluateMetrics metrics{&examples_to_labels, &task.get_top_k_indices(), test_set->num_labels()};
176  setup_metrics(metrics, top_k);
177 
178  spdlog::info("Calculating metrics");
180  auto result_info = runner.run(metrics);
181  spdlog::info("Calculated metrics in {}ms", std::chrono::duration_cast<std::chrono::milliseconds>(result_info.Duration).count());
182 
183  // sort the results and present them
184  std::vector<std::pair<std::string, double>> results =metrics.get_metrics();
185  std::sort(results.begin(), results.end());
186 
187  for(const auto& [name, value] : results ) {
188  std::cout << fmt::format("{:15} = {:.4}", name, value) << "\n";
189  }
190 
// Optionally dump all metric values as pretty-printed JSON.
191  if(!save_metrics.empty()) {
192  nlohmann::json data;
193  for(const auto& [name, value] : results ) {
194  data[name] = value;
195  }
196  std::ofstream file(save_metrics);
197  file << std::setw(4) << data;
198  }
199 
// NOTE(review): the extraction of tp/fp/tn/fn from cm (omitted lines 201-204,
// presumably via the TRUE_/FALSE_POSITIVES/NEGATIVES indices) is missing here.
200  const auto& cm = task.get_confusion_matrix();
205  std::int64_t total = tp + fp + tn + fn;
206 
207  std::cout << fmt::format("Confusion matrix is: \n"
208  "TP: {:15L} FP: {:15L}\n"
209  "FN: {:15L} TN: {:15L}\n", tp, fp, fn, tn);
210 
211  // calculates a percentage with decimals for extremely large integers.
212  // we do the division still as integers, with two additional digits,
213  // and only then convert to floating point.
214  auto percentage = [](std::int64_t enumerator, std::int64_t denominator) {
215  std::int64_t base_result = (std::int64_t{10'000} * enumerator) / denominator;
216  return double(base_result) / 100.0;
217  };
218 
219  std::cout << fmt::format("Accuracy: {:.3}%\n", percentage(tp + tn, total));
220  std::cout << fmt::format("Precision: {:.3}%\n", percentage(tp, tp + fp));
221  std::cout << fmt::format("Recall: {:.3}%\n", percentage(tp, tp + fn));
// NOTE(review): (fp + fn) / 2 is integer division — F1's denominator is
// truncated when fp + fn is odd; likely negligible but worth confirming.
222  std::cout << fmt::format("F1: {:.3}%\n", percentage(tp, tp + (fp + fn) / 2));
223 
// Full-prediction mode: compute and save the complete dense score matrix.
224  } else {
225  spdlog::info("Reading model file from '{}'", model_file);
226  auto model = io::load_model(model_file);
227 
228  spdlog::info("Calculating full predictions");
229  prediction::FullPredictionTaskGenerator task(test_set.get(), model);
230  auto result = runner.run(task);
231  if(!result.IsFinished) {
232  spdlog::error("Something went wrong, prediction computation was not finished!");
233  std::exit(1);
234  }
235  const auto& predictions = task.get_predictions();
236 
237  if(save_as_npy) {
238  io::prediction::save_dense_predictions_as_npy(result_file, predictions);
239  } else {
240  io::prediction::save_dense_predictions_as_txt(result_file, predictions);
241  }
242  }
243 }
std::shared_ptr< MultiLabelData > load(int verbose)
Definition: app.cpp:47
void setup_data_args(CLI::App &app)
Definition: app.cpp:14
This class allows loading only a subset of the weights of a large model.
Definition: model-io.h:317
std::shared_ptr< Model > load_model(label_id_t label_begin, label_id_t label_end) const
Loads part of the model.
Definition: model-io.cpp:373
long num_weight_files() const
Returns the number of available weight files.
Definition: model-io.cpp:395
bool validate() const
Validates that all weight files exist.
Definition: model-io.cpp:441
Strong typedef for an int to signify a label id.
Definition: types.h:20
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
void set_logger(std::shared_ptr< spdlog::logger > logger)
sets the logger object that is used for reporting. Set to nullptr for quiet mode.
Definition: runner.cpp:28
RunResult run(TaskGenerator &tasks, long start=0)
Definition: runner.cpp:39
void set_chunk_size(long chunk_size)
Definition: runner.cpp:24
void add_coverage(double threshold, std::string name={})
Definition: metrics.cpp:179
static constexpr const int TRUE_POSITIVES
Definition: prediction.h:101
static constexpr const int TRUE_NEGATIVES
Definition: prediction.h:103
static constexpr const int FALSE_POSITIVES
Definition: prediction.h:102
static constexpr const int FALSE_NEGATIVES
Definition: prediction.h:104
nlohmann::json json
Definition: model-io.cpp:22
std::shared_ptr< Model > load_model(path source)
Definition: model-io.cpp:334
void save_dense_predictions_as_txt(const path &target, const PredictionMatrix &values)
Saves predictions as a dense txt matrix.
void save_dense_predictions_as_npy(const path &target, const PredictionMatrix &values)
Saves predictions as a dense npy file.
void save_sparse_predictions(const path &target, const PredictionMatrix &values, const IndexMatrix &indices)
Saves sparse predictions as a text file.
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
constexpr const int PREDICTION_RUN_CHUNK_SIZE
Default chunk size for predicting scores.
Definition: config.h:39
constexpr const int PREDICTION_METRICS_CHUNK_SIZE
Default chunk size for calculating metrics.
Definition: config.h:42
int main(int argc, const char **argv)
Definition: predict.cpp:84
prediction::MacroMetricReporter * add_macro_metrics(prediction::EvaluateMetrics &metrics, int k)
Definition: predict.cpp:24
void setup_metrics(prediction::EvaluateMetrics &metrics, int top_k)
Definition: predict.cpp:59