DiSMEC++
test.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "io/xmc.h"
7 #include "io/common.h"
8 #include "data/data.h"
9 #include <vector>
10 #include "doctest.h"
11 using namespace dismec;
12 
13 
// Small dataset in XMC text form. The header line `4 10 5` matches the 4 x 10 feature matrix and
// the 5 label lists that the round-trip test below builds by hand. Each following line is one
// example: a comma-separated label list, then `feature:value` pairs.
// The third example has no labels (its line starts with a space), and the fourth has a space
// after the comma in its label list -- presumably deliberate edge cases for the reader.
constexpr const char* TEST_FILE =
    "4 10 5\n"
    "2,3 4:1.0 5:-0.5 8:0.25\n"
    "0 2:1.0\n"
    " 6:-2.0 5:1.5\n"
    "1, 2 3:-3.0\n";
21 
22 
26 TEST_CASE("xmc round trip") {
27  std::stringstream original_source;
28  original_source.str(TEST_FILE);
29 
30  SparseFeatures features(4, 10);
31  features.coeffRef(0, 4) = 1.0;
32  features.coeffRef(0, 5) = -0.5;
33  features.coeffRef(0, 8) = 0.25;
34  features.coeffRef(1, 2) = 1.0;
35  features.coeffRef(2, 6) = -2.0;
36  features.coeffRef(2, 5) = 1.5;
37  features.coeffRef(3, 3) = -3.0;
38 
39  std::vector<std::vector<long>> label_ex(5);
40  label_ex[0].push_back(1);
41  label_ex[1].push_back(3);
42  label_ex[2].push_back(0);
43  label_ex[3].push_back(0);
44 
45  MultiLabelData data(features, label_ex);
46 
47  std::stringstream canonical_save;
48  io::save_xmc_dataset(canonical_save, data);
49 
50  auto re_read = io::read_xmc_dataset(canonical_save, "test");
51  REQUIRE(re_read.get_features()->rows() == features.rows());
52  REQUIRE(re_read.get_features()->cols() == features.cols());
53  CHECK(types::DenseColMajor<real_t>(re_read.get_features()->sparse()) == types::DenseColMajor<real_t>(features));
54 
55  std::stringstream round_trip;
56  io::save_xmc_dataset(round_trip, re_read);
57 
58  CHECK(round_trip.str() == canonical_save.str());
59 }
building blocks for io procedures that are used by multiple io subsystems
constexpr const char * TEST_FILE
Definition: test.cpp:14
TEST_CASE("xmc round trip")
Definition: test.cpp:26
MultiLabelData read_xmc_dataset(const std::filesystem::path &source, IndexMode mode=IndexMode::ZERO_BASED)
Reads a dataset given in the extreme multilabel classification format.
Definition: xmc.cpp:216
void save_xmc_dataset(std::ostream &target, const MultiLabelData &data)
Saves the given dataset in XMC format.
Definition: xmc.cpp:294
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50