DiSMEC++
xmc.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "io/xmc.h"
7 #include "io/common.h"
8 #include "data/data.h"
9 #include <fstream>
10 #include "spdlog/spdlog.h"
11 #include "spdlog/fmt/fmt.h"
12 #include "spdlog/stopwatch.h"
13 
14 using namespace dismec;
15 
16 namespace {
/*!
 * \brief Collects the three counts stored in the header line of an XMC dataset file.
 * \details The members are listed in the order in which they appear in the header and in which
 * `parse_xmc_header` aggregate-initializes them.
 */
struct XMCHeader {
    long NumExamples;   //!< number of examples declared in the header
    long NumFeatures;   //!< number of features declared in the header
    long NumLabels;     //!< number of labels declared in the header
};
23 
31  XMCHeader parse_xmc_header(const std::string& content) {
32  std::stringstream parse_header{content};
33  long NumExamples = -1;
34  long NumFeatures = -1;
35  long NumLabels = -1;
36 
37  parse_header >> NumExamples >> NumFeatures >> NumLabels;
38  if (parse_header.fail()) {
39  THROW_ERROR("Error parsing dataset header: '{}'", content);
40  }
41 
42  // check validity of numbers
43  if(NumExamples <= 0) {
44  THROW_ERROR("Invalid number of examples {} in specified in header '{}'", NumExamples, content);
45  }
46  if(NumFeatures <= 0) {
47  THROW_ERROR("Invalid number of features {} in specified in header '{}'", NumFeatures, content);
48  }
49  if(NumLabels <= 0) {
50  THROW_ERROR("Invalid number of labels {} in specified in header '{}'", NumLabels, content);
51  }
52 
53  std::string rest;
54  parse_header >> rest;
55  if(!rest.empty()) {
56  THROW_ERROR("Found additional text '{}' in header '{}'", rest, content);
57  }
58 
59  return {NumExamples, NumFeatures, NumLabels};
60  }
61 
/*!
 * \brief Counts the `index:value` entries of every example line in `source`.
 * \details Reads `source` line by line until `std::getline` fails. Empty lines and lines whose
 * first character is `#` are skipped. For every other line, the number of `:` characters is
 * recorded, since each feature is written as `index:value`.
 * \param source Stream to read from. It is consumed up to its end.
 * \param num_examples Capacity that is reserved for the result up front. This is only an
 * allocation hint; the returned vector may be shorter or longer.
 * \return One count per non-skipped line, in the order in which the lines were read.
 */
std::vector<long> count_features_per_example(std::istream& source, std::size_t num_examples = 100'000)
{
    std::vector<long> counts;
    counts.reserve(num_examples);

    for (std::string line; std::getline(source, line);) {
        // neither empty lines nor comment lines describe an example
        const bool is_skipped = line.empty() || line.front() == '#';
        if (is_skipped) {
            continue;
        }

        // one colon per feature entry
        counts.push_back(static_cast<long>(std::count(line.begin(), line.end(), ':')));
    }

    return counts;
}
99 
113  template<class F>
114  const char* parse_labels(const char* line, F&& callback) {
115  const char *last = line;
116  if (!std::isspace(*line)) {
117  // then read as many integers as we can, always skipping exactly one character between. If an integer
118  // is followed by a colon, it was a feature id
119  while (true) {
120  const char *result = nullptr;
121  errno = 0;
122  long read = dismec::io::parse_long(last, &result);
123  // was there a number to read?
124  if(result == last) {
125  THROW_ERROR("Error parsing label. Expected a number.");
126  } else if(errno != 0) {
127  THROW_ERROR("Error parsing label. Errno={}: '{}'", errno, strerror(errno));
128  }
129  if (*result == ',') {
130  // fine, more labels to come
131  } else if (std::isspace(*result) != 0 || *result == '\0') {
132  // fine, this was the last label
133  callback(read);
134  return result;
135  } else {
136  // everything else is not accepted
137  THROW_ERROR("Error parsing label. Expected ',', got '{}', '{}'", errno, *result ? *result : '0', line);
138  }
139  callback(read);
140  last = result + 1;
141  }
142  }
143  return last;
144  }
145 
163  template<long IndexOffset>
164  void read_into_buffers(std::istream& source,
165  SparseFeatures& feature_buffer,
166  std::vector<std::vector<long>>& label_buffer)
167  {
168  std::string line_buffer;
169  auto num_labels = ssize(label_buffer);
170  auto num_features = feature_buffer.cols();
171  auto num_examples = feature_buffer.rows();
172  long example = 0;
173 
174  while (std::getline(source, line_buffer)) {
175  if (line_buffer.empty())
176  continue;
177  if (line_buffer.front() == '#')
178  continue;
179 
180  if(example >= num_examples) {
181  THROW_ERROR("Encountered example number index {:5} but buffers only expect {:5} examples.", example, num_examples);
182  }
183 
184  try {
185  auto label_end = parse_labels(line_buffer.data(), [&](long lbl) {
186  long adjusted_label = lbl - IndexOffset;
187  if (adjusted_label >= num_labels || adjusted_label < 0) {
188  THROW_ERROR("Encountered label {:5}, but number of labels "
189  "was specified as {}.", lbl, num_labels);
190  }
191  label_buffer[adjusted_label].push_back(example);
192  });
193 
194  dismec::io::parse_sparse_vector_from_text(label_end, [&](long index, double value) {
195  long adjusted_index = index - IndexOffset;
196  if (adjusted_index >= num_features || adjusted_index < 0) {
197  THROW_ERROR("Encountered feature index {:5} with value {}. Number of features "
198  "was specified as {}.", index, value, num_features);
199  }
200  // filter out explicit zeros
201  if (value != 0) {
202  if(std::isnan(value)) {
203  THROW_ERROR("Encountered feature index {:5} with value {}.", index, value);
204  }
205  feature_buffer.insert(example, adjusted_index) = static_cast<real_t>(value);
206  }
207  });
208  } catch (std::runtime_error& e) {
209  THROW_ERROR("Error reading example {}: {}.", example + 1, e.what());
210  }
211  ++example;
212  }
213  }
214 }
215 
216 dismec::MultiLabelData dismec::io::read_xmc_dataset(const std::filesystem::path& source_path, IndexMode mode) {
217  std::fstream source(source_path, std::fstream::in);
218  if (!source.is_open()) {
219  throw std::runtime_error(fmt::format("Cannot open input file {}", source_path.c_str()));
220  }
221 
222  return read_xmc_dataset(source, source_path.c_str(), mode);
223 }
224 
225 dismec::MultiLabelData dismec::io::read_xmc_dataset(std::istream& source, std::string_view name, IndexMode mode) {
226  // for now, do what the old code does: iterate twice, once to count and once to read
227  std::string line_buffer;
228  spdlog::stopwatch timer;
229 
230  std::getline(source, line_buffer);
231  XMCHeader header = parse_xmc_header(line_buffer);
232 
233  spdlog::info("Loading dataset '{}' with {} examples, {} features and {} labels.",
234  name, header.NumExamples, header.NumFeatures, header.NumLabels);
235 
236  std::vector<long> features_per_example = count_features_per_example(source, header.NumExamples);
237  if (ssize(features_per_example) != header.NumExamples) {
238  THROW_EXCEPTION(std::runtime_error, "Dataset '{}' declared {} examples, but {} where found!",
239  name, header.NumExamples, features_per_example.size());
240  }
241 
242  // reset to beginning
243  source.clear();
244  source.seekg(0);
245 
246  // reserve space for all the features
247  SparseFeatures x(header.NumExamples, header.NumFeatures);
248  x.reserve(features_per_example);
249 
250  std::vector<std::vector<long>> label_data;
251  label_data.resize(header.NumLabels);
252  // TODO reserve space for correct number of labels
253 
254  // skip header this time
255  std::getline(source, line_buffer);
256 
257  if(mode == IndexMode::ZERO_BASED) {
258  read_into_buffers<0>(source, x, label_data);
259  } else {
260  read_into_buffers<1>(source, x, label_data);
261  }
262 
263  x.makeCompressed();
264 
265  // remove excess memory
266  for (auto& instance_list : label_data) {
267  instance_list.shrink_to_fit();
268  }
269 
270  spdlog::info("Finished loading dataset '{}' in {:.3}s.", name, timer);
271 
272  return {x.markAsRValue(), std::move(label_data)};
273 }
274 
namespace {
    /*!
     * \brief Writes `labels` to `stream` as a comma-separated list.
     * \details Nothing is written for an empty list. No separator and no whitespace is written
     * after the last label. Iterators are used instead of an `int` index so that the loop
     * counter cannot be narrower than the size of the list.
     * \param stream Stream to write to.
     * \param labels The label ids to write, in the order given.
     * \return `stream`, to allow chaining.
     */
    std::ostream& write_label_list(std::ostream& stream, const std::vector<int>& labels)
    {
        auto current = labels.begin();
        const auto end = labels.end();
        if (current == end) {
            return stream;
        }

        // the first label is written without separator, every later one is preceded by a comma
        stream << *current;
        for (++current; current != end; ++current) {
            stream << ',' << *current;
        }

        return stream;
    }
}
293 
294 void dismec::io::save_xmc_dataset(std::ostream& target, const MultiLabelData& data) {
296  target << data.num_examples() << " " << data.num_features() << " " << data.num_labels() << "\n";
297  // for efficient saving, we need the labels in sparse row format, but for training we have them
298  // in sparse column format.
300  std::vector<std::vector<int>> all_labels(data.num_examples());
301  for(label_id_t label{0}; label.to_index() < data.num_labels(); ++label) {
302  for(const auto& instance : data.get_label_instances(label)) {
303  all_labels[instance].push_back(label.to_index());
304  }
305  }
306 
307  if(!data.get_features()->is_sparse()) {
308  throw std::runtime_error(fmt::format("XMC format requires sparse labels"));
309  }
310  const auto& feature_ptr = data.get_features()->sparse();
311 
312  for(int example = 0; example < data.num_examples(); ++example) {
313  // first, write the label list
314  write_label_list(target, all_labels[example]);
315  // then, write the sparse features
316  for (SparseFeatures::InnerIterator it(feature_ptr, example); it; ++it) {
317  target << ' ' << it.col() << ':' << it.value();
318  }
319  target << '\n';
320  }
321 }
322 
323 void dismec::io::save_xmc_dataset(const std::filesystem::path& target_path, const MultiLabelData& data, int precision) {
324  std::fstream target(target_path, std::fstream::out);
325  if (!target.is_open()) {
326  throw std::runtime_error(fmt::format("Cannot open output file {}", target_path.c_str()));
327  }
328 
329  target.setf(std::fstream::fmtflags::_S_fixed, std::fstream::floatfield);
330  target.precision(precision);
331  save_xmc_dataset(target, data);
332 }
333 
334 #include "doctest.h"
335 
337 TEST_CASE("parse valid header") {
338  std::string input;
339  SUBCASE("minimal") {
340  input = "12 54 43";
341  }
342  SUBCASE("trailing space") {
343  input = "12 54 43 ";
344  }
345  SUBCASE("tab separated") {
346  input = "12\t54 \t 43 ";
347  }
348  auto valid = parse_xmc_header(input);
349  CHECK(valid.NumExamples == 12);
350  CHECK(valid.NumFeatures == 54);
351  CHECK(valid.NumLabels == 43);
352 }
353 
354 
357 TEST_CASE("parse invalid header") {
358  // check number of arguments
359  CHECK_THROWS(parse_xmc_header("6 1"));
360  CHECK_THROWS(parse_xmc_header("6 1 5 1"));
361 
362  // we also know that something is wrong if any of the counts are <= 0
363  CHECK_THROWS(parse_xmc_header("0 5 5"));
364  CHECK_THROWS(parse_xmc_header("5 0 5"));
365  CHECK_THROWS(parse_xmc_header("5 5 0"));
366  CHECK_THROWS(parse_xmc_header("-1 5 5"));
367  CHECK_THROWS(parse_xmc_header("5 -1 5"));
368  CHECK_THROWS(parse_xmc_header("5 5 -1"));
369 }
370 
376 TEST_CASE("count features") {
377  auto do_test = [](const std::string& source) {
378  std::stringstream sstr(source);
379  auto count = count_features_per_example(sstr, 10);
380  REQUIRE(count.size() == 3);
381  CHECK(count[0] == 2);
382  CHECK(count[1] == 1);
383  CHECK(count[2] == 4);
384  };
385 
386  SUBCASE("minimal") {
387  std::string source = R"(12 5:5.3 6:34
388  4 6:4
389  1 3:4 5:1 10:43 5:3)";
390  do_test(source);
391  }
392 
393  SUBCASE("comment") {
394  std::string source = R"(12 5:5.3 6:34
395  4 6:4
396 # 65:4
397  1 3:4 5:1 10:43 5:3)";
398  do_test(source);
399  }
400 
401  SUBCASE("empty line") {
402  std::string source = R"(12 5:5.3 6:34
403  4 6:4
404 
405  1 3:4 5:1 10:43 5:3)";
406  do_test(source);
407  }
408 
409 }
410 
421 TEST_CASE("parse labels errors") {
422  // trailing comma
423  CHECK_THROWS(parse_labels("5,1, 5:2.0", [&](long v) {}));
424  // not a number
425  CHECK_THROWS(parse_labels("5, x", [&](long v) {}));
426  // floating point
427  CHECK_THROWS(parse_labels("5.5,1 10:3.0", [&](long v) {}));
428  // wrong separator
429  CHECK_THROWS(parse_labels("5;1 10:3.0", [&](long v) {}));
430  // wrong spacing
431  // an error like this will only be problematic for the subsequent feature parsing
432  // the label parsing will already stop after the 5
433  // CHECK_THROWS(parse_labels("5 ,1 10:3.0", [&](long v) {}));
434 }
435 
436 
443 TEST_CASE("parse labels") {
444  auto run_test = [&](std::string source, const std::vector<long>& expect){
445  int pos = 0;
446  CAPTURE(source);
447  try {
448  parse_labels(source.data(), [&](long v) {
449  CHECK(expect.at(pos) == v);
450  ++pos;
451  });
452  } catch (std::runtime_error& err) {
453  FAIL("parsing failed");
454  }
455  CHECK(expect.size() == pos);
456  };
457 
458  SUBCASE("simple valid line") {
459  run_test("1,3,4 12:4", {1, 3, 4});
460  }
461  SUBCASE("with space") {
462  run_test("1, 3,\t4 12:4", {1, 3, 4});
463  }
464  SUBCASE("leading +") {
465  run_test("+1, 3,\t4 12:4", {1, 3, 4});
466  }
467  SUBCASE("separated by space") {
468  run_test("1,3,4\t12:4", {1, 3, 4});
469  }
470  SUBCASE("empty labels space") {
471  run_test(" 12:4", {});
472  }
473  SUBCASE("empty labels tab") {
474  run_test("\t12:4", {});
475  }
476  SUBCASE("missing features") {
477  run_test("5, 1", {5, 1});
478  }
479 }
480 
481 
488 TEST_CASE("read into buffers bounds checks") {
489  auto x = std::make_shared<SparseFeatures>(2, 3);
490 
491  std::vector<std::vector<long>> labels;
492  labels.resize(2);
493  std::stringstream source;
494 
495  SUBCASE("invalid feature") {
496  source.str("1 2:0.5 3:0.5");
497  SUBCASE("zero-base") {
498  CHECK_THROWS(read_into_buffers<0>(source, *x, labels));
499  }
500  SUBCASE("one-base") {
501  CHECK_NOTHROW(read_into_buffers<1>(source, *x, labels));
502  }
503  }
504 
505  SUBCASE("negative feature") {
506  source.str("1 -1:0.5 1:0.5");
507  SUBCASE("zero-base") {
508  CHECK_THROWS(read_into_buffers<0>(source, *x, labels));
509  }
510  SUBCASE("one-base") {
511  CHECK_THROWS(read_into_buffers<1>(source, *x, labels));
512  }
513  }
514 
515  SUBCASE("invalid label") {
516  source.str("2 2:0.5");
517  SUBCASE("zero-base") {
518  CHECK_THROWS(read_into_buffers<0>(source, *x, labels));
519  }
520  SUBCASE("one-base") {
521  CHECK_NOTHROW(read_into_buffers<1>(source, *x, labels));
522  }
523  }
524 
525  SUBCASE("negative label") {
526  source.str("-1 2:0.5");
527  SUBCASE("zero-base") {
528  CHECK_THROWS(read_into_buffers<0>(source, *x, labels));
529  }
530  SUBCASE("one-base") {
531  CHECK_THROWS(read_into_buffers<1>(source, *x, labels));
532  }
533  }
534 
535  SUBCASE("invalid example") {
536  source.str("0 0:0.5\n0 0:0.5\n0 0:0.5");
537  SUBCASE("zero-base") {
538  CHECK_THROWS(read_into_buffers<0>(source, *x, labels));
539  }
540  SUBCASE("one-base") {
541  CHECK_THROWS(read_into_buffers<1>(source, *x, labels));
542  }
543  }
544 
545  SUBCASE("invalid zero label in one-based indexing") {
546  source.str("0 2:0.5 2:0.5");
547  SUBCASE("zero-base") {
548  CHECK_NOTHROW(read_into_buffers<0>(source, *x, labels));
549  }
550  SUBCASE("one-base") {
551  CHECK_THROWS(read_into_buffers<1>(source, *x, labels));
552  }
553  }
554 
555  SUBCASE("invalid zero feature in one-based indexing") {
556  source.str("1 0:0.5 2:0.5");
557  SUBCASE("zero-base") {
558  CHECK_NOTHROW(read_into_buffers<0>(source, *x, labels));
559  }
560  SUBCASE("one-base") {
561  CHECK_THROWS(read_into_buffers<1>(source, *x, labels));
562  }
563  }
564 }
long num_examples() const noexcept
Get the total number of instances, i.e. the number of rows in the feature matrix.
Definition: data.cpp:52
std::shared_ptr< const GenericFeatureMatrix > get_features() const
get a shared pointer to the (immutable) feature data
Definition: data.cpp:39
long num_features() const noexcept
Get the total number of features, i.e. the number of columns in the feature matrix.
Definition: data.cpp:48
long num_labels() const noexcept override
Definition: data.cpp:59
const std::vector< long > & get_label_instances(label_id_t label) const
Definition: data.cpp:72
Strong typedef for an int to signify a label id.
Definition: types.h:20
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
building blocks for io procedures that are used by multiple io subsystems
#define THROW_ERROR(...)
Definition: common.h:23
XMCHeader parse_xmc_header(const std::string &content)
Parses the header (number of examples, features, labels) of an XMC dataset file.
Definition: xmc.cpp:31
std::ostream & write_label_list(std::ostream &stream, const std::vector< int > &labels)
Definition: xmc.cpp:276
std::vector< long > count_features_per_example(std::istream &source, std::size_t num_examples=100'000)
Extracts number of nonzero features for each instance.
Definition: xmc.cpp:75
const char * parse_labels(const char *line, F &&callback)
parses the labels part of a xmc dataset line.
Definition: xmc.cpp:114
void read_into_buffers(std::istream &source, SparseFeatures &feature_buffer, std::vector< std::vector< long >> &label_buffer)
iterates over the lines in source and puts the corresponding features and labels into the given buffers.
Definition: xmc.cpp:164
constexpr double precision(const ConfusionMatrixBase< T > &matrix)
MultiLabelData read_xmc_dataset(const std::filesystem::path &source, IndexMode mode=IndexMode::ZERO_BASED)
Reads a dataset given in the extreme multilabel classification format.
Definition: xmc.cpp:216
MatrixHeader parse_header(const std::string &content)
Definition: common.cpp:49
void save_xmc_dataset(std::ostream &target, const MultiLabelData &data)
Saves the given dataset in XMC format.
Definition: xmc.cpp:294
IndexMode
Enum to decide whether indices in an xmc file are starting from 0 or from 1.
Definition: xmc.h:67
long parse_long(const char *string, const char **out)
Definition: common.h:34
void parse_sparse_vector_from_text(const char *feature_part, F &&callback)
parses sparse features given in index:value text format.
Definition: common.h:52
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
constexpr auto ssize(const C &c) -> std::common_type_t< std::ptrdiff_t, std::make_signed_t< decltype(c.size())>>
signed size free function. Taken from https://en.cppreference.com/w/cpp/iterator/size
Definition: conversion.h:42
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50
float real_t
The default type for floating point values.
Definition: config.h:17
Collects the data from the header of an xmc file XMC data format.
Definition: xmc.cpp:18
#define THROW_EXCEPTION(exception_type,...)
Definition: throw_error.h:16
TEST_CASE("parse valid header")
Definition: xmc.cpp:337