DiSMEC++
data.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include <fstream>
7 #include "data.h"
8 #include "utils/conversion.h"
9 #include "spdlog/spdlog.h"
10 
11 using namespace dismec;
12 
14  return (get_labels(id)->array() == 1.0).count();
15 }
16 
18  return num_examples() - num_positives(id);
19 }
20 
21 std::shared_ptr<const BinaryLabelVector> DatasetBase::get_labels(label_id_t id) const {
22  // convert sparse to dense
23  auto label_vector = std::make_shared<BinaryLabelVector>(num_examples());
24  get_labels(id, *label_vector);
25  return label_vector;
26 }
27 
28 void BinaryData::get_labels(label_id_t i, Eigen::Ref<BinaryLabelVector> target) const {
29  if(i != label_id_t{0}) {
30  throw std::out_of_range("Binary problems only have a single class with id `0`");
31  }
32  target = *m_Labels;
33 }
34 
35 long BinaryData::num_labels() const noexcept {
36  return 1;
37 }
38 
39 std::shared_ptr<const GenericFeatureMatrix> DatasetBase::get_features() const {
40  return std::const_pointer_cast<const GenericFeatureMatrix>(m_Features);
41 }
42 
43 std::shared_ptr<GenericFeatureMatrix> DatasetBase::edit_features() {
44  return m_Features;
45 }
46 
47 
48 long DatasetBase::num_features() const noexcept {
49  return m_Features->cols();
50 }
51 
52 long DatasetBase::num_examples() const noexcept {
53  return m_Features->rows();
54 }
55 
56 DatasetBase::DatasetBase(SparseFeatures x) : m_Features(std::make_shared<GenericFeatureMatrix>(x.markAsRValue())) {}
57 DatasetBase::DatasetBase(DenseFeatures x) : m_Features(std::make_shared<GenericFeatureMatrix>(std::move(x))) {}
58 
59 long MultiLabelData::num_labels() const noexcept {
60  return ssize(m_Labels);
61 }
62 
63 void MultiLabelData::get_labels(label_id_t label, Eigen::Ref<BinaryLabelVector> target) const {
64  // convert sparse to dense
65  const auto& examples = m_Labels.at(label.to_index());
66  target.setConstant(-1);
67  for(const auto& ex : examples) {
68  target.coeffRef(ex) = 1;
69  }
70 }
71 
72 const std::vector<long>& MultiLabelData::get_label_instances(label_id_t label) const {
73  return m_Labels.at(label.to_index());
74 }
75 
77  return ssize(m_Labels.at(id.to_index()));
78 }
79 
81  return num_examples() - ssize(m_Labels.at(id.to_index()));
82 }
83 
85  if(end.to_index() < 0 || end.to_index() > num_labels()) {
86  end = label_id_t{static_cast<int_fast32_t>(m_Labels.size())};
87  }
88 
89  std::vector<std::vector<long>> sub_labels;
90  sub_labels.reserve(end-start);
91  std::move(begin(m_Labels) + start.to_index(), std::begin(m_Labels) + end.to_index(), std::back_inserter(sub_labels));
92  m_Labels = std::move(sub_labels);
93 }
std::shared_ptr< BinaryLabelVector > m_Labels
Definition: data.h:82
void get_labels(label_id_t i, Eigen::Ref< BinaryLabelVector > target) const override
Definition: data.cpp:28
long num_labels() const noexcept override
Definition: data.cpp:35
virtual long num_negatives(label_id_t id) const
Definition: data.cpp:17
long num_examples() const noexcept
Get the total number of instances, i.e. the number of rows in the feature matrix.
Definition: data.cpp:52
DatasetBase(const DatasetBase &)=default
std::shared_ptr< const BinaryLabelVector > get_labels(label_id_t id) const
Definition: data.cpp:21
virtual long num_positives(label_id_t id) const
Definition: data.cpp:13
std::shared_ptr< const GenericFeatureMatrix > get_features() const
Get a shared pointer to the (immutable) feature data.
Definition: data.cpp:39
std::shared_ptr< GenericFeatureMatrix > m_Features
Definition: data.h:60
long num_features() const noexcept
Get the total number of features, i.e. the number of columns in the feature matrix.
Definition: data.cpp:48
std::shared_ptr< GenericFeatureMatrix > edit_features()
Get a shared pointer to mutable feature data. Use with care.
Definition: data.cpp:43
std::vector< std::vector< long > > m_Labels
Definition: data.h:111
long num_positives(label_id_t id) const override
Definition: data.cpp:76
long num_labels() const noexcept override
Definition: data.cpp:59
void get_labels(label_id_t label, Eigen::Ref< BinaryLabelVector > target) const override
Definition: data.cpp:63
const std::vector< long > & get_label_instances(label_id_t label) const
Definition: data.cpp:72
void select_labels(label_id_t start, label_id_t end)
Definition: data.cpp:84
long num_negatives(label_id_t id) const override
Definition: data.cpp:80
Strong typedef for an int to signify a label id.
Definition: types.h:20
constexpr T to_index() const
Explicitly convert to an integer.
Definition: opaque_int.h:32
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
types::DenseRowMajor< real_t > DenseFeatures
Dense Feature Matrix in Row Major format.
Definition: matrix_types.h:58
constexpr auto ssize(const C &c) -> std::common_type_t< std::ptrdiff_t, std::make_signed_t< decltype(c.size())>>
Free function returning the size of a container as a signed integer. Taken from https://en.cppreference.com/w/cpp/iterator/size
Definition: conversion.h:42
types::SparseRowMajor< real_t > SparseFeatures
Sparse Feature Matrix in Row Major format.
Definition: matrix_types.h:50