DiSMEC++
slice.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include <iostream>
7 #include <fstream>
8 #include "slice.h"
9 #include "data/data.h"
10 #include "io/numpy.h"
11 #include "io/common.h"
12 #include <spdlog/spdlog.h>
13 #include <spdlog/stopwatch.h>
14 
15 using namespace dismec;
16 namespace io = dismec::io;
17 
18 namespace {
19  DenseFeatures load_features(std::istream& features) {
20  if(io::is_npy(features)) {
21  return io::load_matrix_from_npy(features);
22  }
23 
24  std::string line_buffer;
25  std::getline(features, line_buffer);
26  io::MatrixHeader header = io::parse_header(line_buffer);
27  DenseFeatures target(header.NumRows, header.NumCols);
28 
29  for(int row = 0; row < header.NumRows; ++row) {
30  io::read_vector_from_text(features, target.row(row));
31  }
32  return target;
33  }
34 }
35 
36 dismec::MultiLabelData io::read_slice_dataset(std::istream& features, std::istream& labels) {
37  spdlog::stopwatch timer;
38  DenseFeatures feature_matrix = load_features(features);
39 
40  auto label_data = read_binary_matrix_as_lol(labels);
41 
42  if(label_data.NumRows != feature_matrix.rows()) {
43  THROW_ERROR("Mismatch between number of examples in feature file ({}) and in label file ({})",
44  feature_matrix.rows(), label_data.NumRows);
45  }
46 
47  spdlog::info("Finished loading dataset with {} examples in {:.3}s.", label_data.NumCols, timer);
48 
49  return MultiLabelData(std::move(feature_matrix), std::move(label_data.NonZeros));
50 }
51 
52 dismec::MultiLabelData io::read_slice_dataset(const std::filesystem::path& features, const std::filesystem::path& labels) {
53  std::fstream features_file(features, std::fstream::in);
54  if (!features_file.is_open()) {
55  throw std::runtime_error(fmt::format("Cannot open input file {}", features.c_str()));
56  }
57  std::fstream labels_file(labels, std::fstream::in);
58  if (!labels_file.is_open()) {
59  throw std::runtime_error(fmt::format("Cannot open input file {}", labels.c_str()));
60  }
61 
62  return read_slice_dataset(features_file, labels_file);
63 }
64 
65 
66 #include "doctest.h"
67 
68 using namespace dismec;
69 
70 TEST_CASE("small dataset") {
71  std::stringstream features;
72  std::stringstream labels;
73 
74  features.str("3 5\n"
75  "1.0 2.5 -1.0 3.5 4.4\n"
76  "-1.0 0.0 0.5 2.5 1.5\n"
77  "0.0 5.4\t 3.4 2.5 1.6\n");
78 
79  labels.str("3 3\n"
80  "1:1\n"
81  "0:1\n"
82  "0:1 2:1"
83  );
84 
85  auto ds = io::read_slice_dataset(features, labels);
86 
87  auto df = ds.get_features()->dense();
88  REQUIRE(df.rows() == 3);
89  REQUIRE(df.cols() == 5);
90  float true_features[] = {1.0, 2.5, -1.0, 3.5, 4.4, -1.0, 0.0, 0.5, 2.5, 1.5, 0.0, 5.4, 3.4, 2.5, 1.6};
91  for(int i = 0; i < df.size(); ++i) {
92  CHECK(df.coeff(i) == true_features[i]);
93  }
94 
95  // check the labels
96  const auto& l0 = ds.get_label_instances(label_id_t{0});
97  REQUIRE(l0.size() == 2);
98  CHECK(l0[0] == 1);
99  CHECK(l0[1] == 2);
100 
101  const auto& l1 = ds.get_label_instances(label_id_t{1});
102  REQUIRE(l1.size() == 1);
103  CHECK(l1[0] == 0);
104 
105  const auto& l2 = ds.get_label_instances(label_id_t{2});
106  REQUIRE(l2.size() == 1);
107  CHECK(l2[0] == 2);
108 }
Strong typedef for an int to signify a label id.
Definition: types.h:20
Building blocks for io procedures that are used by multiple io subsystems.
#define THROW_ERROR(...)
Definition: common.h:23
DenseFeatures load_features(std::istream &features)
Definition: slice.cpp:19
MultiLabelData read_slice_dataset(std::istream &features, std::istream &labels)
Reads a dataset given in slice format.
Definition: slice.cpp:36
std::istream & read_vector_from_text(std::istream &stream, Eigen::Ref< DenseRealVector > data)
Reads the given vector as space-separated human-readable numbers.
Definition: common.cpp:37
MatrixHeader parse_header(const std::string &content)
Definition: common.cpp:49
LoLBinarySparse read_binary_matrix_as_lol(std::istream &source)
Definition: common.cpp:76
bool is_npy(std::istream &target)
Check whether the stream is an npy file.
Definition: numpy.cpp:22
types::DenseRowMajor< real_t > load_matrix_from_npy(std::istream &source)
Loads a matrix from a numpy array.
Definition: numpy.cpp:342
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseRowMajor< real_t > DenseFeatures
Dense Feature Matrix in Row Major format.
Definition: matrix_types.h:58
TEST_CASE("small dataset")
Definition: slice.cpp:70
Collects the rows and columns parsed from a plain-text matrix file.
Definition: common.h:130