DiSMEC++
model-io.h
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #ifndef DISMEC_MODEL_IO_H
7 #define DISMEC_MODEL_IO_H
8 
9 #include <Eigen/Dense>
10 #include <vector>
11 #include <memory>
12 #include <filesystem>
13 #include <optional>
14 #include <future>
15 #include "fwd.h"
16 #include "data/types.h"
17 #include <boost/iterator/iterator_adaptor.hpp>
18 
87 namespace dismec::io
88 {
89 
91  namespace model
92  {
94  using std::filesystem::path;
95 
/*!
 * \brief Describes the format in which the weight data has been saved.
 *
 * The numeric values are spelled out explicitly so that the mapping is
 * visible at a glance.
 */
enum class WeightFormat {
    DENSE_TXT   = 0,    //!< Dense Text Format
    SPARSE_TXT  = 1,    //!< Sparse Text Format
    DENSE_NPY   = 2,    //!< Dense Numpy Format
    NULL_FORMAT = 3     //!< NOTE(review): undocumented upstream; confirm its meaning in model-io.cpp
};
106 
108  WeightFormat parse_weights_format(std::string_view name);
109  const char* to_string(WeightFormat format);
110 
111  struct SaveOption {
112  int Precision = 6;
113  double Culling = 0;
114  int SplitFiles = 4096;
116  };
117 
118  // \todo should we by default overwrite files, or refuse?
128  void save_model(const path& target_file, const std::shared_ptr<const Model>& model, SaveOption options);
129 
130 
131  std::shared_ptr<Model> load_model(path source);
132 
141  long Count;
142  std::string FileName;
144  };
145 
146 
152  public:
158  [[nodiscard]] long num_labels() const noexcept { return m_TotalLabels; }
159 
165  [[nodiscard]] long num_features() const noexcept { return m_NumFeatures; }
166 
167  protected:
168  PartialModelIO() = default;
169  ~PartialModelIO() = default;
170 
171  void read_metadata_file(const path& meta_file);
172 
long m_TotalLabels{-1};     //!< Total number of labels; -1 until set (presumably "unknown").
long m_NumFeatures{-1};     //!< Total number of features; -1 until set (presumably "unknown").
175 
176  std::vector<WeightFileEntry> m_SubFiles;
177  using weight_file_iter_t = std::vector<WeightFileEntry>::const_iterator;
178 
184  void insert_sub_file(const WeightFileEntry& data);
185 
194  [[nodiscard]] weight_file_iter_t label_lower_bound(label_id_t pos) const;
195  };
196 
237  public:
252  PartialModelSaver(path target_file, SaveOption options, bool load_partial=false);
253 
270  std::future<WeightFileEntry> add_model(const std::shared_ptr<const Model>& model,
271  const std::optional<std::string>& file_path={});
272 
274 
/*!
 * \brief Updates the metadata file.
 */
void update_meta_file();
282 
/*!
 * \brief Checks that all weights have been written and updates the metadata file.
 */
void finalize();
292 
298  [[nodiscard]] std::pair<label_id_t, label_id_t> get_missing_weights() const;
299 
305  [[nodiscard]] bool any_weight_vector_for_interval(label_id_t begin, label_id_t end) const;
306 
307  private:
310  };
311 
318  public:
319  enum ESparseMode {
323  };
324 
333  explicit PartialModelLoader(path meta_file, ESparseMode mode=DEFAULT);
334 
338  const path& meta_file_path() const { return m_MetaFileName; }
339 
349  [[nodiscard]] std::shared_ptr<Model> load_model(label_id_t label_begin, label_id_t label_end) const;
350 
355  [[nodiscard]] std::shared_ptr<Model> load_model(int index) const;
356 
360  bool validate() const;
361 
363  [[nodiscard]] long num_weight_files() const;
364 
370  };
371 
372  [[nodiscard]] SubModelRangeSpec get_loading_range(label_id_t label_begin, label_id_t label_end) const;
373  private:
376  };
377  }
378 
379  using model::WeightFormat;
380  using model::SaveOption;
381  using model::save_model;
382  using model::load_model;
385 }
386 
387 #endif //DISMEC_MODEL_IO_H
This class is used as an implementation detail to capture the common code of PartialModelSaver and PartialModelLoader.
Definition: model-io.h:151
weight_file_iter_t label_lower_bound(label_id_t pos) const
Gets an iterator into the weight-file list that points to the first element whose starting label is l...
Definition: model-io.cpp:116
void read_metadata_file(const path &meta_file)
Definition: model-io.cpp:93
void insert_sub_file(const WeightFileEntry &data)
Inserts a new sub-file entry into the metadata object.
Definition: model-io.cpp:123
std::vector< WeightFileEntry >::const_iterator weight_file_iter_t
Definition: model-io.h:177
std::vector< WeightFileEntry > m_SubFiles
Definition: model-io.h:176
long num_labels() const noexcept
Gets the total number of labels.
Definition: model-io.h:158
long num_features() const noexcept
Gets the total number of features.
Definition: model-io.h:165
This class allows loading only a subset of the weights of a large model.
Definition: model-io.h:317
std::shared_ptr< Model > load_model(label_id_t label_begin, label_id_t label_end) const
Loads part of the model.
Definition: model-io.cpp:373
long num_weight_files() const
Returns the number of available weight files.
Definition: model-io.cpp:395
bool validate() const
Validates that all weight files exist.
Definition: model-io.cpp:441
const path & meta_file_path() const
The path to the metadata file.
Definition: model-io.h:338
SubModelRangeSpec get_loading_range(label_id_t label_begin, label_id_t label_end) const
Definition: model-io.cpp:346
PartialModelLoader(path meta_file, ESparseMode mode=DEFAULT)
Create a new PartialModelLoader for the given metadata file.
Definition: model-io.cpp:341
Manage saving a model consisting of multiple partial models.
Definition: model-io.h:236
void finalize()
Checks that all weights have been written and updates the metadata file.
Definition: model-io.cpp:277
PartialModelSaver(path target_file, SaveOption options, bool load_partial=false)
Create a new PartialModelSaver.
Definition: model-io.cpp:160
bool any_weight_vector_for_interval(label_id_t begin, label_id_t end) const
Checks if there are any weight vectors for the given interval.
Definition: model-io.cpp:255
std::pair< label_id_t, label_id_t > get_missing_weights() const
Gets an interval of labels for which weights are missing.
Definition: model-io.cpp:292
std::future< WeightFileEntry > add_model(const std::shared_ptr< const Model > &model, const std::optional< std::string > &file_path={})
Adds the weights of a partial model asynchronously.
Definition: model-io.cpp:172
void update_meta_file()
Updates the metadata file.
Definition: model-io.cpp:234
Strong typedef for an int to signify a label id.
Definition: types.h:20
A model combines a set of weights with some meta-information about these weights.
Definition: model.h:63
Forward-declares types.
std::shared_ptr< Model > load_model(path source)
Definition: model-io.cpp:334
WeightFormat parse_weights_format(std::string_view name)
Parses the name of a weight format into the corresponding WeightFormat value.
Definition: model-io.cpp:84
void save_model(const path &target_file, const std::shared_ptr< const Model > &model, SaveOption options)
Saves a complete model to a file.
Definition: model-io.cpp:308
WeightFormat
Describes the format in which the weight data has been saved.
Definition: model-io.h:99
@ DENSE_TXT
Dense Text Format
@ SPARSE_TXT
Sparse Text Format
@ DENSE_NPY
Dense Numpy Format
const char * to_string(WeightFormat format)
Definition: model-io.cpp:89
WeightFormat Format
Format in which the weights will be saved.
Definition: model-io.h:115
double Culling
If saving in sparse mode, threshold below which weights will be omitted.
Definition: model-io.h:113
int Precision
Precision with which the weights will be saved.
Definition: model-io.h:112
int SplitFiles
Maximum number of weight vectors per file.
Definition: model-io.h:114
Collect the data about a weight file.
Definition: model-io.h:139