16 #include "CLI/CLI.hpp"
18 #include "spdlog/spdlog.h"
19 #include "nlohmann/json.hpp"
// Registers macro-averaged metrics at cutoff `k` on the given metrics
// collector and returns the reporter so further measurements can be attached.
// NOTE(review): this listing is a fragment — source lines are missing between
// the visible statements (e.g. original line 26), so comments only describe
// what is actually shown here.
24 prediction::MacroMetricReporter*
add_macro_metrics(prediction::EvaluateMetrics& metrics,
int k) {
// Create (or fetch) the macro-at-k reporter owned by `metrics`.
25 auto* macro = metrics.add_macro_at_k(k);
// Track a per-label confusion matrix as part of the macro metrics.
27 macro->add_confusion_matrix();
// Body fragment of setup_metrics(prediction::EvaluateMetrics&, int top_k)
// (see the declaration listed at the end of this file). Registers the standard
// ranking metrics — precision@k, abandonment@k, DCG@k and normalized DCG@k —
// at the cutoffs k = 1, 3, 5.
// NOTE(review): the enclosing function header and intervening source lines are
// not visible in this listing; the boolean second argument of add_dcg_at_k
// presumably toggles normalization (nDCG) — confirm against the header.
// Metrics at cutoff k = 1.
60 metrics.add_precision_at_k(1);
61 metrics.add_abandonment_at_k(1);
62 metrics.add_dcg_at_k(1,
false);
63 metrics.add_dcg_at_k(1,
true);
// Metrics at cutoff k = 3.
68 metrics.add_precision_at_k(3);
69 metrics.add_abandonment_at_k(3);
70 metrics.add_dcg_at_k(3,
false);
71 metrics.add_dcg_at_k(3,
true);
// Metrics at cutoff k = 5.
75 metrics.add_precision_at_k(5);
76 metrics.add_abandonment_at_k(5);
77 metrics.add_dcg_at_k(5,
false);
78 metrics.add_dcg_at_k(5,
true);
// Entry point of the DiSMEC prediction tool: parses command-line options,
// loads the test set and model weights, computes (top-k or full) predictions
// in parallel, saves them, and optionally evaluates/reports metrics.
// NOTE(review): this listing is heavily fragmented — many original source
// lines (declarations of `threads`, `top_k`, `Verbose`, `DataProc`, `loader`,
// `runner`, `wf_it`, `data`, `tp/fp/tn/fn`, `model`, loop headers, the `try`
// opener and most closing braces) are missing. Comments below annotate only
// the visible statements.
84 int main(
int argc,
const char** argv) {
// CLI11 application object; the string is the program description.
85 CLI::App app{
"DiSMEC"};
// Paths supplied on the command line.
87 std::string problem_file;
88 std::string model_file;
89 std::string result_file;
90 std::string labels_file;
// If non-empty, metric values are written to this file.
91 std::filesystem::path save_metrics;
// Default output format is plain text, not numpy.
94 bool save_as_npy =
false;
// Positional, required arguments (model must already exist on disk).
99 app.add_option(
"model-file", model_file,
"The file from which the model will be read.")->required()->check(CLI::ExistingFile);;
100 app.add_option(
"result-file", result_file,
"The file to which the predictions will be written.")->required();
// Optional tuning flags. `threads` and `top_k` are declared on lines not
// visible in this fragment.
101 app.add_option(
"--threads", threads,
"Number of threads to use. -1 means auto-detect");
102 app.add_option(
"--save-metrics", save_metrics,
"Target file in which the metric values are saved");
103 app.add_option(
"--topk, --top-k", top_k,
"Only the top k predictions will be saved. "
104 "Set to -1 if you need all predictions. (Warning: This may result in very large files!)");
105 app.add_flag(
"--save-as-npy", save_as_npy,
"Save the predictions as a numpy file instead of plain text.");
// Verbosity flag; `Verbose` is declared outside this fragment.
107 app.add_flag(
"-v", Verbose);
// Parse argv; CLI11 throws CLI::ParseError on bad input (the enclosing
// `try` opener is on a line not shown here).
110 app.parse(argc, argv);
111 }
catch (
const CLI::ParseError &e) {
// Load the test dataset (DataProc is a data-loading helper configured on
// lines not shown; see setup_data_args in the declarations below).
115 auto test_set = DataProc.
load(Verbose);
// Abort path: the model directory contained no weight files.
131 spdlog::error(
"No weight files");
// --- Top-k prediction branch ---
135 spdlog::info(
"Calculating top-{} predictions", top_k);
// Ground truth: for each example, the list of its true label ids.
138 std::vector<std::vector<label_id_t>> examples_to_labels(test_set->num_examples());
// Invert the label->instances mapping into example->labels (the outer loop
// over `label` is on a line not shown here).
140 for(
auto example : test_set->get_label_instances(label)) {
141 examples_to_labels[example].push_back(label);
// Load the first partial model (weight-file iterator `wf_it` declared
// outside this fragment).
145 auto initial_model = loader.
load_model(wf_it);
146 spdlog::info(
"Using {} representation for model weights", initial_model->has_sparse_weights() ?
"sparse" :
"dense");
// Task generator that keeps the best top_k scores per example.
148 prediction::TopKPredictionTaskGenerator task(test_set.get(), initial_model, top_k);
// Overlap I/O with compute: asynchronously preload the next weight file
// while the current chunk of predictions is being computed. Captures
// `loader` by reference — safe only because the future is awaited below,
// inside loader's lifetime.
151 auto preload_weights = std::async(std::launch::async, [iter=wf_it, &loader]() {
153 return loader.load_model(iter);
// No more weight files: signal the end with a null model pointer.
155 return std::shared_ptr<dismec::model::Model>{};
// Run the prediction task on the (externally configured) parallel runner.
158 auto result = runner.
run(task);
159 if(!result.IsFinished) {
160 spdlog::error(
"Something went wrong, prediction computation was not finished!");
163 spdlog::info(
"Finished prediction in {}s", result.Duration.count());
// Swap in the preloaded next model part (blocks until the async load is
// done).
167 task.update_model(preload_weights.get());
// Persist the top-k scores and label indices (the save call spans lines
// not fully visible here).
170 spdlog::info(
"Saving to '{}'", result_file);
172 task.get_top_k_values(),
173 task.get_top_k_indices());
// Evaluate ranking metrics against the ground truth built above.
175 prediction::EvaluateMetrics metrics{&examples_to_labels, &task.get_top_k_indices(), test_set->num_labels()};
178 spdlog::info(
"Calculating metrics");
180 auto result_info = runner.
run(metrics);
181 spdlog::info(
"Calculated metrics in {}ms", std::chrono::duration_cast<std::chrono::milliseconds>(result_info.Duration).count());
// Print all metrics sorted by name, aligned for readability.
184 std::vector<std::pair<std::string, double>> results =metrics.get_metrics();
185 std::sort(results.begin(), results.end());
187 for(
const auto& [name, value] : results ) {
188 std::cout << fmt::format(
"{:15} = {:.4}", name, value) <<
"\n";
// Optionally serialize the metric values to a file (`data` is presumably
// an nlohmann::json object filled on lines not shown — confirm upstream).
191 if(!save_metrics.empty()) {
193 for(
const auto& [name, value] : results ) {
196 std::ofstream file(save_metrics);
197 file << std::setw(4) << data;
// Aggregate confusion-matrix counts (tp/fp/tn/fn extracted from `cm` on
// lines not visible in this fragment).
200 const auto& cm = task.get_confusion_matrix();
205 std::int64_t total = tp + fp + tn + fn;
207 std::cout << fmt::format(
"Confusion matrix is: \n"
208 "TP: {:15L} FP: {:15L}\n"
209 "FN: {:15L} TN: {:15L}\n", tp, fp, fn, tn);
// Integer-based percentage with two decimal digits; avoids accumulating
// floating-point error on very large counts. Assumes denominator != 0.
214 auto percentage = [](std::int64_t enumerator, std::int64_t denominator) {
215 std::int64_t base_result = (std::int64_t{10'000} * enumerator) / denominator;
216 return double(base_result) / 100.0;
// Derived classification measures; F1 is computed via the equivalent form
// tp / (tp + (fp + fn)/2).
219 std::cout << fmt::format(
"Accuracy: {:.3}%\n", percentage(tp + tn, total));
220 std::cout << fmt::format(
"Precision: {:.3}%\n", percentage(tp, tp + fp));
221 std::cout << fmt::format(
"Recall: {:.3}%\n", percentage(tp, tp + fn));
222 std::cout << fmt::format(
"F1: {:.3}%\n", percentage(tp, tp + (fp + fn) / 2));
// --- Full (non-top-k) prediction branch ---
225 spdlog::info(
"Reading model file from '{}'", model_file);
228 spdlog::info(
"Calculating full predictions");
// Computes the full score matrix; `model` is loaded on lines not shown.
229 prediction::FullPredictionTaskGenerator task(test_set.get(), model);
230 auto result = runner.run(task);
231 if(!result.IsFinished) {
232 spdlog::error(
"Something went wrong, prediction computation was not finished!");
// Dense prediction matrix, saved as txt or npy depending on save_as_npy
// (the save call is on lines not visible here).
235 const auto& predictions = task.get_predictions();
std::shared_ptr< MultiLabelData > load(int verbose)
void setup_data_args(CLI::App &app)
This class allows loading only a subset of the weights of a large model.
std::shared_ptr< Model > load_model(label_id_t label_begin, label_id_t label_end) const
Loads part of the model.
long num_weight_files() const
Returns the number of available weight files.
bool validate() const
Validates that all weight files exist.
Strong typedef for an int to signify a label id.
constexpr T to_index() const
Explicitly convert to an integer.
void set_logger(std::shared_ptr< spdlog::logger > logger)
sets the logger object that is used for reporting. Set to nullptr for quiet mode.
RunResult run(TaskGenerator &tasks, long start=0)
void set_chunk_size(long chunk_size)
void add_coverage(double threshold, std::string name={})
static constexpr const int TRUE_POSITIVES
static constexpr const int TRUE_NEGATIVES
static constexpr const int FALSE_POSITIVES
static constexpr const int FALSE_NEGATIVES
std::shared_ptr< Model > load_model(path source)
void save_dense_predictions_as_txt(const path &target, const PredictionMatrix &values)
Saves predictions as a dense txt matrix.
void save_dense_predictions_as_npy(const path &target, const PredictionMatrix &values)
Saves predictions as a dense npy file.
void save_sparse_predictions(const path &target, const PredictionMatrix &values, const IndexMatrix &indices)
Saves sparse predictions as a text file.
Main namespace in which all types, classes, and functions are defined.
constexpr const int PREDICTION_RUN_CHUNK_SIZE
Default chunk size for predicting scores.
constexpr const int PREDICTION_METRICS_CHUNK_SIZE
Default chunk size for calculating metrics.
int main(int argc, const char **argv)
prediction::MacroMetricReporter * add_macro_metrics(prediction::EvaluateMetrics &metrics, int k)
void setup_metrics(prediction::EvaluateMetrics &metrics, int top_k)