DiSMEC++
app.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
#include "app.h"
#include "io/xmc.h"
#include "io/slice.h"
#include "data/data.h"
#include <spdlog/spdlog.h>

#include <stdexcept>
11 
12 using namespace dismec;
13 
14 void DataProcessing::setup_data_args(CLI::App& app) {
15  app.add_option("data-file", DataSetFile,
16  "The file from which the data will be loaded.")->required()->check(CLI::ExistingFile);
17 
18  app.add_flag("--xmc-one-based-index", OneBasedIndex,
19  "If this flag is given, then we assume that the input dataset in xmc format"
20  " has one-based indexing, i.e. the first label and feature are at index 1 (as opposed to the usual 0)");
21  AugmentForBias = app.add_flag("--augment-for-bias", Bias,
22  "If this flag is given, then all training examples will be augmented with an additional"
23  "feature of value 1 or the specified value.")->default_val(1.0);
24  app.add_flag("--normalize-instances", NormalizeInstances,
25  "If this flag is given, then the feature vectors of all instances are normalized to one.");
26  app.add_option("--transform", TransformData, "Apply a transformation to the features of the dataset.")->default_str("identity")
27  ->transform(CLI::Transformer(std::map<std::string, DatasetTransform>{
28  {"identity", DatasetTransform::IDENTITY},
29  {"log-one-plus", DatasetTransform::LOG_ONE_PLUS},
30  {"one-plus-log", DatasetTransform::ONE_PLUS_LOG},
31  {"sqrt", DatasetTransform::SQRT}
32  },CLI::ignore_case));
33 
34  app.add_option("--label-file", LabelFile, "For SLICE-type datasets, this specifies where the labels can be found")->check(CLI::ExistingFile);
35 
36 
37  auto* hash_option = app.add_flag("--hash-features", "If this Flag is given, then feature hashing is performed.");
38  auto* bucket_option = app.add_option("--hash-buckets", HashBuckets, "Number of buckets for each hash function when feature hashing is enabled.")
39  ->needs(hash_option)->check(CLI::PositiveNumber);
40  app.add_option("--hash-repeat", HashRepeats, "Number of hash functions to use for feature hashing.")
41  ->needs(hash_option)->default_val(32)->check(CLI::PositiveNumber);
42  app.add_option("--hash-seed", HashSeed, "Seed to use when feature hashing.")
43  ->needs(hash_option)->default_val(42);
44  hash_option->needs(bucket_option);
45 }
46 
47 std::shared_ptr<MultiLabelData> DataProcessing::load(int verbose) {
48  if(verbose >= 0) {
49  spdlog::info("Loading training data from file '{}'", DataSetFile);
50  }
51  auto data = std::make_shared<MultiLabelData>([&]() {
52  if(LabelFile.empty()) {
53  return read_xmc_dataset(DataSetFile, OneBasedIndex ? io::IndexMode::ONE_BASED : io::IndexMode::ZERO_BASED);
54  } else {
55  return io::read_slice_dataset(DataSetFile, LabelFile);
56  }
57  } ());
58 
59  if(HashBuckets > 0) {
60  if(!data->get_features()->is_sparse()) {
61  spdlog::error("Feature hashing is currently only implemented for sparse features.");
62  }
63  if(verbose >= 0) {
64  spdlog::info("Hashing features");
65  }
66  hash_sparse_features(data->edit_features()->sparse(), HashSeed, HashBuckets, HashRepeats);
67  }
68 
70  if(verbose >= 0)
71  spdlog::info("Applying data transformation");
73  }
74 
75  if(NormalizeInstances) {
76  if(verbose >= 0)
77  spdlog::info("Normalizing instances.");
78  normalize_instances(*data);
79  }
80 
81  if(!AugmentForBias->empty()) {
82  if(verbose >= 0)
83  spdlog::info("Appending bias features with value {}", Bias);
85  }
86 
87  if(verbose >= 0) {
88  if(data->get_features()->is_sparse()) {
89  double total = data->num_features() * data->num_examples();
90  auto nnz = data->get_features()->sparse().nonZeros();
91  spdlog::info("Processed feature matrix has {} rows and {} columns. Contains {} non-zeros ({:.3} %)", data->num_examples(),
92  data->num_features(), nnz, 100.0 * (nnz / total));
93  } else {
94  spdlog::info("Processed feature matrix has {} rows and {} columns", data->num_examples(),
95  data->num_features());
96  }
97  }
98 
99  return data;
100 }
101 
103  return AugmentForBias->count() > 0;
104 }
105 
std::shared_ptr< MultiLabelData > load(int verbose)
Definition: app.cpp:47
std::string DataSetFile
The file from which the dataset should be read.
Definition: app.h:23
bool augment_for_bias() const
Definition: app.cpp:102
void setup_data_args(CLI::App &app)
Definition: app.cpp:14
DatasetTransform TransformData
Definition: app.h:27
CLI::Option * AugmentForBias
Definition: app.h:28
bool NormalizeInstances
Definition: app.h:26
bool OneBasedIndex
Definition: app.h:25
std::string LabelFile
Definition: app.h:24
unsigned HashSeed
Definition: app.h:34
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
void normalize_instances(DatasetBase &data)
Definition: transform.cpp:88
void augment_features_with_bias(DatasetBase &data, real_t bias=1)
Definition: transform.cpp:25
void transform_features(DatasetBase &data, DatasetTransform transform)
Definition: transform.cpp:152
void hash_sparse_features(SparseFeatures &features, unsigned seed, int buckets, int repeats)
Definition: transform.cpp:183