DiSMEC++
numpy.cpp
Go to the documentation of this file.
1 // Copyright (c) 2021, Aalto University, developed by Erik Schultheis
2 // All rights reserved.
3 //
4 // SPDX-License-Identifier: MIT
5 
6 #include "numpy.h"
7 #include "io/common.h"
8 #include <ostream>
9 #include <fstream>
10 #include <cstdint>
11 #include "spdlog/fmt/fmt.h"
12 #include "spdlog/spdlog.h"
13 
14 using namespace dismec;
15 
namespace {
    // Bytes that start every file written by this module: the six-byte npy signature,
    // followed by the major (3) and minor (0) format version.
    constexpr char MAGIC[] = {'\x93', 'N', 'U', 'M', 'P', 'Y', '\x03', '\x00'};
    // Number of leading bytes of MAGIC that form the signature proper, i.e. without the version.
    constexpr int MAGIC_SIZE = 6;
    // The header written by write_npy_header is padded with spaces to a multiple of this many bytes.
    constexpr unsigned NPY_PADDING = 64u;
}
21 
22 bool io::is_npy(std::istream& source) {
23  char buffer[MAGIC_SIZE];
24  auto current_position = source.tellg();
25  if(auto num_read = source.readsome(buffer, sizeof(buffer)); num_read != sizeof(buffer)) {
26  THROW_ERROR("Error when trying to read magic bytes. Read on {} bytes.", num_read);
27  }
28  source.seekg(current_position);
29  return (std::memcmp(buffer, MAGIC, sizeof(buffer)) == 0);
30 }
31 
32 void io::write_npy_header(std::streambuf& target, std::string_view description) {
33  target.sputn(MAGIC, sizeof(MAGIC));
34  std::size_t total_length = sizeof(MAGIC) + sizeof(std::uint32_t) + description.size() + 1;
35  unsigned padding = NPY_PADDING - total_length % NPY_PADDING;
36 
37  std::uint32_t header_length = description.size() + padding + 1;
38  target.sputn(reinterpret_cast<const char*>(&header_length), sizeof(header_length));
39  target.sputn(description.data(), description.size());
40  for(unsigned i = 0; i < padding; ++i) {
41  target.sputc('\x20');
42  }
43  if(target.sputc('\n') != '\n') {
44  THROW_ERROR("Could not write terminating newline to npy header");
45  }
46 }
47 
48 std::string io::make_npy_description(std::string_view dtype_desc, bool column_major, std::size_t size) {
49  return fmt::format(R"({{"descr": "{}", "fortran_order": {}, "shape": ({},)}})", dtype_desc, column_major ? "True" : "False", size);
50 }
51 
52 std::string io::make_npy_description(std::string_view dtype_desc, bool column_major, std::size_t rows, std::size_t cols) {
53  return fmt::format(R"({{"descr": "{}", "fortran_order": {}, "shape": ({}, {})}})", dtype_desc, column_major ? "True" : "False", rows, cols);
54 }
55 
56 #define REGISTER_DTYPE(TYPE, STRING) \
57 template<> \
58 const char* data_type_string<TYPE>() { \
59  return STRING; \
60 }
61 
62 namespace dismec::io {
63  REGISTER_DTYPE(float, "<f4");
64  REGISTER_DTYPE(double, "<f8");
65  REGISTER_DTYPE(std::int32_t, "<i4");
66  REGISTER_DTYPE(std::int64_t, "<i8");
67  REGISTER_DTYPE(std::uint32_t, "<u4");
68  REGISTER_DTYPE(std::uint64_t, "<u8");
69 }
70 
71 namespace {
72  std::uint32_t read_header_length(std::streambuf& source) {
73  auto read_raw = [&](auto& target){
74  auto num_read = source.sgetn(reinterpret_cast<char*>(&target), sizeof(target));
75  if(num_read != sizeof(target)) {
76  THROW_ERROR("Unexpected end of data while reading header length");
77  }
78  };
79 
80  int major = source.sbumpc();
81  int minor = source.sbumpc();
82 
83  if(major == 2 || major == 3) {
84  std::uint32_t header_length;
85  read_raw(header_length);
86  return header_length;
87  }
88  if (major == 1) {
89  std::uint16_t short_header_length;
90  read_raw(short_header_length);
91  return short_header_length;
92  }
93  THROW_ERROR("Unknown npy file format version {}.{} -- {}", major, minor, source.pubseekoff(0, std::ios_base::cur, std::ios_base::in));
94  };
95 
96  long skip_whitespace(std::string_view source, long position) {
97  while(std::isspace(source[position]) != 0 && position < ssize(source)) {
98  ++position;
99  }
100 
101  return position;
102  }
103 
105  std::pair<std::string_view, std::string_view> read_key_value(std::string_view source) {
106  auto source_end = ssize(source);
107 
108  // skip all initial whitespace
109  long position = skip_whitespace(source, 0);
110  if(position == source_end) {
111  THROW_ERROR("received only whitespace");
112  }
113 
114  char open_quote = source[position];
115  long key_start = position;
116  // next, we should get a dictionary key. This is indicated by quotation marks
117  if(open_quote != '"' && open_quote != '\'') {
118  THROW_ERROR("Expected begin of string ' or \" for parsing dictionary key. Got {}.", open_quote);
119  }
120 
121  std::size_t key_end = source.find(open_quote, key_start + 1);
122  if(key_end == std::string_view::npos) {
123  THROW_ERROR("Could not find matching closing quotation mark `{}` for key string", open_quote);
124  }
125 
126  // next, we expect a colon to separate the value
127  position = skip_whitespace(source, to_long(key_end) + 1);
128  if(position == source_end) {
129  THROW_ERROR("Could not find : that separates key and value");
130  }
131 
132  if(source[position] != ':') {
133  THROW_ERROR("Expected : to separate key and value, got {}", source[position]);
134  }
135 
136  position = skip_whitespace(source, position + 1);
137  if(position == source_end) {
138  THROW_ERROR("Missing feature");
139  }
140 
141  const char openers[] = {'"', '\'', '(', '[', '{'};
142  const char closers[] = {'"', '\'', ')', ']', '}'};
143 
144  // to keep the code simple, we do not support nesting or escaping of
145  // delimiters. For the intended use case, that should be enough, but
146  // it means that this function cannot be used for general npy files
147 
148  long value_start = position;
149  char expect_close = 0;
150  while(position < source_end) {
151  char current = source[position];
152  if(expect_close == 0) {
153  // if we are not in a nested expression, the end of the current value is reached if we find a comma
154  // or closing brace
155  if(current == ',' || current == '}') {
156  return {{source.begin() + key_start + 1, key_end - key_start - 1},
157  {source.begin() + value_start, std::size_t(position - value_start)}};
158  }
159 
160  // if we are opening a nested expression, figure out which char we are waiting for next
161  switch (current) {
162  case '"':
163  case '\'':
164  case '(':
165  case '[':
166  case '{':
167  {
168  for(int i = 0; i < to_long(sizeof(openers)); ++i) {
169  if(openers[i] == current) {
170  expect_close = closers[i];
171  }
172  }
173  }
174  default: break;
175  }
176  } else if(current == expect_close) {
177  expect_close = 0;
178  }
179  ++position;
180  }
181 
182  if(expect_close != 0) {
183  THROW_ERROR("Expected closing {}, but reached end of input", expect_close);
184  }
185  THROW_ERROR("Expected } or , to signal end of input");
186  }
187 
188  io::NpyHeaderData parse_description(std::string_view view) {
189  view.remove_prefix(1);
190 
191  io::NpyHeaderData result;
192 
193  bool has_descr = false;
194  bool has_order = false;
195  bool has_shape = false;
196  for(int i = 0; i < 3; ++i) {
197  // can't use structured bindings here, because apparently they cannot be captured in the THROW_ERROR lambda
198  auto kv = read_key_value(view);
199  auto key = kv.first;
200  auto value = kv.second;
201  view = view.substr(value.end() - view.begin() + 1);
202 
203  if(key == "descr") {
204  if(value.front() != '\'' && value.front() != '"') {
205  THROW_ERROR("expected string for descr, got '{}'", value);
206  }
207  result.DataType = value.substr(1, value.size() - 2);
208  has_descr = true;
209  } else if (key == "fortran_order") {
210  if(value == "False" || value == "0") {
211  result.ColumnMajor = false;
212  } else if (value == "True" || value == "1") {
213  result.ColumnMajor = true;
214  } else {
215  std::string val_str{value};
216  THROW_ERROR("unexpected value '{}' for fortran_order", val_str);
217  }
218  has_order = true;
219  } else if(key == "shape") {
220  if(value.at(0) != '(') {
221  THROW_ERROR("expected ( to start tuple for shape");
222  }
223  auto sep = value.find(',');
224  if(sep == std::string::npos) {
225  THROW_ERROR("Expected comma in tuple definition");
226  }
227 
228  const char* endptr = nullptr;
229  errno = 0;
230  result.Rows = io::parse_long( value.begin() + 1, &endptr);
231  if(errno != 0 || endptr == value.begin() + 1) {
232  THROW_ERROR("error while trying to parse number for size");
233  }
234  if(result.Rows < 0) {
235  THROW_ERROR("Number of rows cannot be negative. Got {}", result.Rows);
236  }
237 
238  result.Cols = io::parse_long( value.begin() + sep + 1, &endptr);
239  if(errno != 0) {
240  THROW_ERROR("error while trying to parse number for size");
241  }
242 
243  if(result.Cols < 0) {
244  THROW_ERROR("Number of rows cannot be negative. Got {}", result.Cols);
245  }
246 
247  has_shape = true;
248  } else {
249  std::string key_str{key};
250  THROW_ERROR("unexpected key '{}'", key_str);
251  }
252  }
253 
254  bool closed_dict = false;
255  for(const auto& c : view) {
256  if(std::isspace(c) != 0) continue;
257  if(c == '}' && !closed_dict) {
258  closed_dict = true;
259  continue;
260  }
261  THROW_ERROR("Trailing '{}'", c);
262  }
263 
264  if(!has_descr) {
265  THROW_ERROR("Missing 'descr' entry in dict");
266  }
267 
268  if(!has_order) {
269  THROW_ERROR("Missing 'fortran_order' entry in dict");
270  }
271 
272  if(!has_shape) {
273  THROW_ERROR("Missing 'shape' entry in dict");
274  }
275 
276  return result;
277  }
278 }
279 #include <iostream>
280 io::NpyHeaderData io::parse_npy_header(std::streambuf& source) {
281  std::array<char, MAGIC_SIZE> magic{};
282  source.sgetn(magic.data(), MAGIC_SIZE);
283  for(int i = 0; i < MAGIC_SIZE; ++i) {
284  if(magic[i] != MAGIC[i]) {
285  THROW_ERROR("Magic bytes mismatch");
286  }
287  }
288 
289  std::uint32_t header_length = read_header_length(source);
290 
291  std::string header_buffer(header_length, '\0');
292  if(auto num_read = source.sgetn(header_buffer.data(), header_length); num_read != header_length) {
293  THROW_ERROR("Expected to read a header of size {}, but only got {} elements", header_length, num_read);
294  }
295 
296  // OK, now for the actual parsing of the dict.
297  if(header_buffer.at(0) != '{') {
298  THROW_ERROR("Expected data description dict to start with '{{', got '{}'. Header is: {}", header_buffer.at(0), header_buffer );
299  }
300 
301  if(header_buffer.back() != '\n') {
302  THROW_ERROR("Expected newline \\n at end of header \"{}\"", header_buffer );
303  }
304 
305  return parse_description(header_buffer);
306 }
307 
308 namespace {
309  template<class T>
310  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> load_matrix_from_npy_imp(std::streambuf& source) {
311  auto header = io::parse_npy_header(source);
312  if(header.DataType != io::data_type_string<T>()) {
313  THROW_ERROR("Unsupported data type {}", header.DataType);
314  }
315  if(header.ColumnMajor) {
316  THROW_ERROR("Currently, only row-major npy files can be read");
317  }
318 
319  // load the matrix row-by-row, to make sure this works even if Eigen decides to include padding
320  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> target(header.Rows, header.Cols);
321  for(int row = 0; row < target.rows(); ++row) {
322  auto row_data = target.row(row);
323  io::binary_load(source, row_data.data(), row_data.data() + row_data.size());
324  }
325 
326  return target;
327  }
328 
329  template<class T>
330  void save_matrix_to_npy_imp(std::streambuf& target, const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
332 
333  // save the matrix row-by-row, to make sure this works even if Eigen decides to include padding
334  for(int row = 0; row < matrix.rows(); ++row) {
335  const auto& row_data = matrix.row(row);
336  io::binary_dump(target, row_data.data(), row_data.data() + row_data.size());
337  }
338  }
339 
340 }
341 
342 Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> io::load_matrix_from_npy(std::istream& source) {
343  return load_matrix_from_npy_imp<real_t>(*source.rdbuf());
344 }
345 
346 Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> io::load_matrix_from_npy(const std::string& path) {
347  std::ifstream file(path);
348  if(!file.is_open()) {
349  THROW_ERROR("Could not open file {} for reading.", path)
350  }
351  return load_matrix_from_npy(file);
352 }
353 
354 void io::save_matrix_to_npy(std::ostream& source,
355  const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
356  save_matrix_to_npy_imp(*source.rdbuf(), matrix);
357 }
358 
359 void io::save_matrix_to_npy(const std::string& path,
360  const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
361  std::ofstream file(path);
362  if(!file.is_open()) {
363  THROW_ERROR("Could not open file {} for writing.", path)
364  }
365  return save_matrix_to_npy(file, matrix);
366 }
367 
368 
369 #include "doctest.h"
370 #include <sstream>
371 
372 TEST_CASE("numpy header with given description") {
373  std::stringstream target;
374  std::string description = "{'descr': '<f8', 'fortran_order': False, 'shape': (3,), }";
375  std::string ground_truth("\x93NUMPY\x03\x00\x74\x00\x00\x00{'descr': '<f8', 'fortran_order': False, 'shape': (3,), } \n", 128);
376  io::write_npy_header(*target.rdbuf(), description);
377 
378  std::string new_str = target.str();
379  CHECK(new_str == ground_truth);
380 }
381 
382 TEST_CASE("header length test") {
383  std::stringstream src;
384 
385  SUBCASE("read valid length v2/3") {
386  src.str(std::string("\x03\x00s\x00\x00\x00", MAGIC_SIZE));
387  CHECK(read_header_length(*src.rdbuf()) == 's');
388  CHECK(src.rdbuf()->pubseekoff(0, std::ios_base::cur, std::ios_base::in) == 6);
389  }
390 
391  SUBCASE("read valid length v1") {
392  src.str(std::string("\x01\x00s\x00\x00\x00", MAGIC_SIZE));
393  CHECK(read_header_length(*src.rdbuf()) == 's');
394  CHECK(src.rdbuf()->pubseekoff(0, std::ios_base::cur, std::ios_base::in) == 4);
395  }
396 
397  SUBCASE("invalid version") {
398  src.str(std::string("\x04\x00s\x00\x00\x00", MAGIC_SIZE));
399  CHECK_THROWS(read_header_length(*src.rdbuf()));
400  }
401 
402  SUBCASE("end of data") {
403  src.str(std::string("\x03\x00s\x00", 4));
404  CHECK_THROWS(read_header_length(*src.rdbuf()));
405  }
406 }
407 
408 TEST_CASE("read key value error check") {
409  CHECK_THROWS(read_key_value(" "));
410  CHECK_THROWS(read_key_value(" key'"));
411  CHECK_THROWS(read_key_value("'key' "));
412  CHECK_THROWS(read_key_value("'key': "));
413  CHECK_THROWS(read_key_value("'key "));
414  CHECK_THROWS(read_key_value("'key' error:"));
415  CHECK_THROWS(read_key_value("'key': 'value"));
416  CHECK_THROWS(read_key_value("'key': (1, 2]"));
417 }
418 
419 TEST_CASE("read key value test") {
420  std::string input;
421  std::string key;
422  std::string value;
423 
424  SUBCASE("double quotes") {
425  input = "{\"key\": value}";
426  key = "key";
427  value = "value";
428  }
429 
430  SUBCASE("single quotes") {
431  input = "{'key': value}";
432  key = "key";
433  value = "value";
434  }
435 
436  SUBCASE("tuple value") {
437  input = "{'key': (1, 2, 3)}";
438  key = "key";
439  value = "(1, 2, 3)";
440  }
441 
442  SUBCASE("multiple entries") {
443  input = "{'key': a, \"other key\": b}";
444  key = "key";
445  value = "a";
446  }
447 
448  SUBCASE("nested quotes") {
449  input = "{\"key_with'\": value}";
450  key = "key_with'";
451  value = "value";
452  }
453 
454  SUBCASE("quoted value") {
455  input = "{'key': 'a value that contains } and \" and ) and ]'}";
456  key = "key";
457  value = "'a value that contains } and \" and ) and ]'";
458  }
459 
460  std::string_view dict_contents = input;
461  dict_contents.remove_prefix(1);
462  auto [got_key, got_value] = read_key_value(dict_contents);
463  CHECK(got_key == key);
464  CHECK(got_value == value);
465 }
466 
467 TEST_CASE("parse description -- valid") {
468  SUBCASE("f8 c order vector") {
469  auto data = parse_description("{'descr': '<f8', 'fortran_order': False, 'shape': (3,), }");
470  CHECK(data.ColumnMajor == false);
471  CHECK(data.Rows == 3);
472  CHECK(data.Cols == 0);
473  CHECK(data.DataType == "<f8");
474  }
475  SUBCASE("reordered") {
476  auto data = parse_description("{'fortran_order': False, 'shape': (3,), 'descr': '<f8'}");
477  CHECK(data.ColumnMajor == false);
478  CHECK(data.Rows == 3);
479  CHECK(data.Cols == 0);
480  CHECK(data.DataType == "<f8");
481  }
482  SUBCASE("i4 f order matrix no trailing comma") {
483  auto data = parse_description("{'descr': \"<i4\", 'fortran_order': 1, 'shape': (5 , 7)}");
484  CHECK(data.ColumnMajor == true);
485  CHECK(data.Rows == 5);
486  CHECK(data.Cols == 7);
487  CHECK(data.DataType == "<i4");
488  }
489  SUBCASE("f8 c order matrix no whitespace") {
490  auto data = parse_description("{'descr':'<f8','fortran_order':0,'shape':(5,7)}");
491  CHECK(data.ColumnMajor == false);
492  CHECK(data.Rows == 5);
493  CHECK(data.Cols == 7);
494  CHECK(data.DataType == "<f8");
495  }
496 }
497 
498 TEST_CASE("parse description -- errors") {
499  SUBCASE("wrong value") {
500  CHECK_THROWS(parse_description("{'descr': '<f8', 'fortran_order': Unknown, 'shape': (3,), }"));
501  CHECK_THROWS(parse_description("{'descr': (5, 4), 'fortran_order': False, 'shape': (3,), }"));
502  CHECK_THROWS(parse_description("{'descr': '<f8', 'fortran_order': False, 'shape': 8 }"));
503  CHECK_THROWS(parse_description("{'descr': 5, 'fortran_order': False, 'shape': (3,) }"));
504  CHECK_THROWS(parse_description("{'descr': '<f8', 'fortran_order': False, 'shape': (3) }"));
505  CHECK_THROWS(parse_description("{'descr': '<f8', 'fortran_order': False, 'shape': (a,) }"));
506  }
507 
508  SUBCASE("missing key") {
509  CHECK_THROWS(parse_description("{'fortran_order':0,'shape':(5,7)}"));
510  CHECK_THROWS(parse_description("{'descr':'<f8','shape':(5,7)}"));
511  CHECK_THROWS(parse_description("{'descr':'<f8','fortran_order':0"));
512  }
513 
514  SUBCASE("unexpected key") {
515  CHECK_THROWS(parse_description("{'descr':'<f8','fortran_order':0,'shape':(5,7), 'other': 'value'}"));
516  CHECK_THROWS(parse_description("{'descr':'<f8','other': 'value', 'fortran_order':0,'shape':(5,7)}"));
517  }
518 }
519 
520 TEST_CASE("make description") {
521  CHECK(io::make_npy_description("<f8", false, 5) == "{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (5,)}");
522  CHECK(io::make_npy_description(">i4", true, 17) == "{\"descr\": \">i4\", \"fortran_order\": True, \"shape\": (17,)}");
523  CHECK(io::make_npy_description("<f8", false, 7, 5) == "{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (7, 5)}");
524 }
525 
526 TEST_CASE("save/load round trip") {
527  std::ostringstream save_stream;
528  types::DenseRowMajor<real_t> matrix = types::DenseRowMajor<real_t>::Random(4, 5);
529  io::save_matrix_to_npy(save_stream, matrix);
530 
531  std::istringstream load_stream;
532  load_stream.str(save_stream.str());
533  auto ref = io::load_matrix_from_npy(load_stream);
534 
535  CHECK( matrix == ref );
536 }
building blocks for io procedures that are used by multiple io subsystems
#define THROW_ERROR(...)
Definition: common.h:23
TEST_CASE("numpy header with given description")
Definition: numpy.cpp:372
constexpr const unsigned NPY_PADDING
Definition: numpy.cpp:19
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > load_matrix_from_npy_imp(std::streambuf &source)
Definition: numpy.cpp:310
void save_matrix_to_npy_imp(std::streambuf &target, const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > &matrix)
Definition: numpy.cpp:330
long skip_whitespace(std::string_view source, long position)
Definition: numpy.cpp:96
constexpr const int MAGIC_SIZE
Definition: numpy.cpp:18
std::uint32_t read_header_length(std::streambuf &source)
Definition: numpy.cpp:72
io::NpyHeaderData parse_description(std::string_view view)
Definition: numpy.cpp:188
std::pair< std::string_view, std::string_view > read_key_value(std::string_view source)
This function parses a single element from a python dict literal.
Definition: numpy.cpp:105
constexpr const char MAGIC[]
Definition: numpy.cpp:17
void binary_dump(std::streambuf &target, const T *begin, const T *end)
Definition: common.h:110
std::string make_npy_description(std::string_view dtype_desc, bool column_major, std::size_t size)
Creates a string with the data description dictionary for (1 dimensional) arrays.
Definition: numpy.cpp:48
REGISTER_DTYPE(float, "<f4")
void binary_load(std::streambuf &target, T *begin, T *end)
Definition: common.h:120
void write_npy_header(std::streambuf &target, std::string_view description)
Writes the header for a npy file.
Definition: numpy.cpp:32
bool is_npy(std::istream &target)
Check whether the stream is a npy file.
Definition: numpy.cpp:22
types::DenseRowMajor< real_t > load_matrix_from_npy(std::istream &source)
Loads a matrix from a numpy array.
Definition: numpy.cpp:342
long parse_long(const char *string, const char **out)
Definition: common.h:34
void save_matrix_to_npy(std::ostream &source, const types::DenseRowMajor< real_t > &)
Saves a matrix to a numpy array.
NpyHeaderData parse_npy_header(std::streambuf &source)
Parses the header of the npy file given by source.
Definition: numpy.cpp:280
Main namespace in which all types, classes, and functions are defined.
Definition: app.h:15
constexpr auto ssize(const C &c) -> std::common_type_t< std::ptrdiff_t, std::make_signed_t< decltype(c.size())>>
signed size free function. Taken from https://en.cppreference.com/w/cpp/iterator/size
Definition: conversion.h:42
constexpr long to_long(T value)
Convert the given value to long, throwing an error if the conversion is not possible.
Definition: conversion.h:14
Contains the data of the header of a npy file with an array that has at most 2 dimensions.
Definition: numpy.h:56
long Rows
The number of rows in the data.
Definition: numpy.h:59
std::string DataType
The data type descr
Definition: numpy.h:57
bool ColumnMajor
Whether the data is column major (Fortran)
Definition: numpy.h:58