DiSMEC++
labelstats.cpp
Go to the documentation of this file.
1 //
2 // Created by erik on 28.1.2022.
3 //
4 
#include <algorithm>
#include <array>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <numeric>
#include <random>
#include <string>
#include <vector>

#include "io/xmc.h"
#include "data/data.h"
#include "nlohmann/json.hpp"
#include "CLI/CLI.hpp"
#include "utils/conversion.h"
12 
13 using namespace dismec;
15 
16 double obesity(const std::vector<long>& values, int num_samples);
17 
18 int main(int argc, const char** argv) {
19  std::string DataSetFile;
20  std::string OutputFile;
21  bool OneBasedIndex = false;
22  CLI::App app{"labelstats"};
23  app.add_option("dataset", DataSetFile,
24  "The file from which the data will be loaded.")->required()->check(CLI::ExistingFile);
25  app.add_option("target", OutputFile,
26  "The file to which the result will be saved.")->required();
27 
28  app.add_flag("--one-based-index", OneBasedIndex,
29  "If this flag is given, then we assume that the input dataset in xmc format and"
30  " has one-based indexing, i.e. the first label and feature are at index 1 (as opposed to the usual 0)");
31 
32  try {
33  app.parse(argc, argv);
34  } catch (const CLI::ParseError &e) {
35  return app.exit(e);
36  }
37 
38  auto data = read_xmc_dataset(DataSetFile, OneBasedIndex ? io::IndexMode::ONE_BASED : io::IndexMode::ZERO_BASED);
39  std::vector<long> label_counts;
40  for(long id = 0; id < data.num_labels(); ++id) {
41  label_counts.push_back(static_cast<long>(data.num_positives(label_id_t{id})));
42  }
43 
44  std::sort(begin(label_counts), end(label_counts));
45 
46  json result;
47  result["num-labels"] = data.num_labels();
48  result["num-instances"] = data.num_examples();
49  result["most-frequent"] = label_counts.back();
50  result["least-frequent"] = label_counts.front();
51  result["intra-IR-min"] = double(data.num_examples()) / double(std::max(1l, label_counts.back()));
52  result["intra-IR-max"] = double(data.num_examples()) / double(std::max(1l, label_counts.front()));
53  result["inter-IR"] = double(label_counts.back()) / double(std::max(1l, label_counts.front()));
54 
55  // check where the 80-20 (and similar) rule would bring us
56  std::vector<long> cumulative;
57  std::partial_sum(label_counts.rbegin(), label_counts.rend(), std::back_inserter(cumulative));
58  int target = 10;
59  std::cout << cumulative[0] << " " << cumulative[1] << " " << cumulative[cumulative.size() - 1] << "\n";
60  for(int i = 0; i < ssize(cumulative); ++i) {
61  if(cumulative[i] / target >= cumulative.back() / 100) {
62  result["cumulative-" + std::to_string(target)] = i;
63  result["cumulative-rel-" + std::to_string(target)] = 100.0 * double(i) / double(data.num_labels());
64  target += 10;
65  }
66  }
67 
68  result["obesity"] = obesity(label_counts, 10000);
69 
70  std::fstream result_file(OutputFile, std::fstream::out);
71  result_file << std::setw(4) << result << "\n";
72 
73 
74 }
75 
76 
/**
 * Monte-Carlo estimate of the "obesity" of the distribution of `values`.
 *
 * Each sample draws four entries of `values` uniformly at random (with replacement, using a
 * default-seeded `std::ranlux48`, so results are reproducible). It orders them as x1 <= x2 <= x3 <= x4
 * and checks whether x1 + x4 > x2 + x3.
 *
 * @param values      Values whose distribution is examined. They do not need to be sorted.
 * @param num_samples Number of random samples to draw.
 * @return Percentage (0-100) of samples for which x1 + x4 > x2 + x3 holds. Returns NaN if `values`
 *         is empty or `num_samples` is not positive, because the estimate is undefined then.
 */
double obesity(const std::vector<long>& values, int num_samples) {
    if(values.empty() || num_samples <= 0) {
        // also avoids constructing uniform_int_distribution(0, -1), which is undefined behaviour
        return std::numeric_limits<double>::quiet_NaN();
    }

    std::ranlux48 rng;
    std::uniform_int_distribution<long> dist(0, static_cast<long>(values.size()) - 1);
    std::array<long, 4> sample{};
    int larger = 0;
    for(int i = 0; i < num_samples; ++i) {
        // Sort the sampled *values* (not their indices) so the result does not rely on `values` being sorted.
        // For sorted input this yields exactly the same ordering as sorting the indices.
        for(auto& s : sample) {
            s = values[static_cast<std::size_t>(dist(rng))];
        }
        std::sort(sample.begin(), sample.end());
        if(sample[0] + sample[3] > sample[1] + sample[2]) {
            ++larger;
        }
    }

    // Floating-point division: the former `num_samples / 100` truncated and was zero for num_samples < 100.
    return 100.0 * double(larger) / double(num_samples);
}
Strong typedef for an int to signify a label id.
Definition: types.h:20
double obesity(const std::vector< long > &values, int num_samples)
Definition: labelstats.cpp:77
int main(int argc, const char **argv)
Definition: labelstats.cpp:18
nlohmann::json json
Definition: labelstats.cpp:14
nlohmann::json json
Definition: model-io.cpp:22
const char * to_string(WeightFormat format)
Definition: model-io.cpp:89
MultiLabelData read_xmc_dataset(const std::filesystem::path &source, IndexMode mode=IndexMode::ZERO_BASED)
Reads a dataset given in the extreme multilabel classification format.
Definition: xmc.cpp:216
@ ONE_BASED
labels and feature indices are 1, 2, ..., num
@ ZERO_BASED
labels and feature indices are 0, 1, ..., num - 1
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
constexpr auto ssize(const C &c) -> std::common_type_t< std::ptrdiff_t, std::make_signed_t< decltype(c.size())>>
signed size free function. Taken from https://en.cppreference.com/w/cpp/iterator/size
Definition: conversion.h:42