11 #include "spdlog/fmt/fmt.h"
12 #include "spdlog/spdlog.h"
/// Byte sequence that this module expects at the very start of an npy stream.
///
/// The first six bytes are the `\x93NUMPY` signature. The trailing two bytes
/// (0x03, 0x00) look like the major/minor format version -- NOTE(review):
/// confirm against the code that writes and parses the header.
/// The array is intentionally *not* null-terminated, so `sizeof(MAGIC)` is
/// exactly the number of bytes listed here (8).
constexpr const char MAGIC[] = {'\x93', 'N', 'U', 'M', 'P', 'Y', '\x03', '\x00'};
24 auto current_position = source.tellg();
25 if(
auto num_read = source.readsome(buffer,
sizeof(buffer)); num_read !=
sizeof(buffer)) {
26 THROW_ERROR(
"Error when trying to read magic bytes. Read on {} bytes.", num_read);
28 source.seekg(current_position);
29 return (std::memcmp(buffer,
MAGIC,
sizeof(buffer)) == 0);
34 std::size_t total_length =
sizeof(
MAGIC) +
sizeof(std::uint32_t) + description.size() + 1;
37 std::uint32_t header_length = description.size() + padding + 1;
38 target.sputn(
reinterpret_cast<const char*
>(&header_length),
sizeof(header_length));
39 target.sputn(description.data(), description.size());
40 for(
unsigned i = 0; i < padding; ++i) {
43 if(target.sputc(
'\n') !=
'\n') {
44 THROW_ERROR(
"Could not write terminating newline to npy header");
49 return fmt::format(R
"({{"descr": "{}", "fortran_order": {}, "shape": ({},)}})", dtype_desc, column_major ? "True" :
"False", size);
53 return fmt::format(R
"({{"descr": "{}", "fortran_order": {}, "shape": ({}, {})}})", dtype_desc, column_major ? "True" :
"False", rows, cols);
56 #define REGISTER_DTYPE(TYPE, STRING) \
58 const char* data_type_string<TYPE>() { \
73 auto read_raw = [&](
auto& target){
74 auto num_read = source.sgetn(
reinterpret_cast<char*
>(&target),
sizeof(target));
75 if(num_read !=
sizeof(target)) {
76 THROW_ERROR(
"Unexpected end of data while reading header length");
80 int major = source.sbumpc();
81 int minor = source.sbumpc();
83 if(major == 2 || major == 3) {
84 std::uint32_t header_length;
85 read_raw(header_length);
89 std::uint16_t short_header_length;
90 read_raw(short_header_length);
91 return short_header_length;
93 THROW_ERROR(
"Unknown npy file format version {}.{} -- {}", major, minor, source.pubseekoff(0, std::ios_base::cur, std::ios_base::in));
97 while(std::isspace(source[position]) != 0 && position <
ssize(source)) {
105 std::pair<std::string_view, std::string_view>
read_key_value(std::string_view source) {
106 auto source_end =
ssize(source);
110 if(position == source_end) {
114 char open_quote = source[position];
115 long key_start = position;
117 if(open_quote !=
'"' && open_quote !=
'\'') {
118 THROW_ERROR(
"Expected begin of string ' or \" for parsing dictionary key. Got {}.", open_quote);
121 std::size_t key_end = source.find(open_quote, key_start + 1);
122 if(key_end == std::string_view::npos) {
123 THROW_ERROR(
"Could not find matching closing quotation mark `{}` for key string", open_quote);
128 if(position == source_end) {
129 THROW_ERROR(
"Could not find : that separates key and value");
132 if(source[position] !=
':') {
133 THROW_ERROR(
"Expected : to separate key and value, got {}", source[position]);
137 if(position == source_end) {
141 const char openers[] = {
'"',
'\'',
'(',
'[',
'{'};
142 const char closers[] = {
'"',
'\'',
')',
']',
'}'};
148 long value_start = position;
149 char expect_close = 0;
150 while(position < source_end) {
151 char current = source[position];
152 if(expect_close == 0) {
155 if(current ==
',' || current ==
'}') {
156 return {{source.begin() + key_start + 1, key_end - key_start - 1},
157 {source.begin() + value_start, std::size_t(position - value_start)}};
168 for(
int i = 0; i <
to_long(
sizeof(openers)); ++i) {
169 if(openers[i] == current) {
170 expect_close = closers[i];
176 }
else if(current == expect_close) {
182 if(expect_close != 0) {
183 THROW_ERROR(
"Expected closing {}, but reached end of input", expect_close);
185 THROW_ERROR(
"Expected } or , to signal end of input");
189 view.remove_prefix(1);
193 bool has_descr =
false;
194 bool has_order =
false;
195 bool has_shape =
false;
196 for(
int i = 0; i < 3; ++i) {
200 auto value = kv.second;
201 view = view.substr(value.end() - view.begin() + 1);
204 if(value.front() !=
'\'' && value.front() !=
'"') {
205 THROW_ERROR(
"expected string for descr, got '{}'", value);
207 result.
DataType = value.substr(1, value.size() - 2);
209 }
else if (key ==
"fortran_order") {
210 if(value ==
"False" || value ==
"0") {
212 }
else if (value ==
"True" || value ==
"1") {
215 std::string val_str{value};
216 THROW_ERROR(
"unexpected value '{}' for fortran_order", val_str);
219 }
else if(key ==
"shape") {
220 if(value.at(0) !=
'(') {
221 THROW_ERROR(
"expected ( to start tuple for shape");
223 auto sep = value.find(
',');
224 if(sep == std::string::npos) {
228 const char* endptr =
nullptr;
231 if(errno != 0 || endptr == value.begin() + 1) {
232 THROW_ERROR(
"error while trying to parse number for size");
234 if(result.
Rows < 0) {
240 THROW_ERROR(
"error while trying to parse number for size");
243 if(result.
Cols < 0) {
249 std::string key_str{key};
254 bool closed_dict =
false;
255 for(
const auto& c : view) {
256 if(std::isspace(c) != 0)
continue;
257 if(c ==
'}' && !closed_dict) {
269 THROW_ERROR(
"Missing 'fortran_order' entry in dict");
281 std::array<char, MAGIC_SIZE> magic{};
284 if(magic[i] !=
MAGIC[i]) {
291 std::string header_buffer(header_length,
'\0');
292 if(
auto num_read = source.sgetn(header_buffer.data(), header_length); num_read != header_length) {
293 THROW_ERROR(
"Expected to read a header of size {}, but only got {} elements", header_length, num_read);
297 if(header_buffer.at(0) !=
'{') {
298 THROW_ERROR(
"Expected data description dict to start with '{{', got '{}'. Header is: {}", header_buffer.at(0), header_buffer );
301 if(header_buffer.back() !=
'\n') {
302 THROW_ERROR(
"Expected newline \\n at end of header \"{}\"", header_buffer );
312 if(header.DataType != io::data_type_string<T>()) {
313 THROW_ERROR(
"Unsupported data type {}", header.DataType);
315 if(header.ColumnMajor) {
316 THROW_ERROR(
"Currently, only row-major npy files can be read");
320 Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> target(header.Rows, header.Cols);
321 for(
int row = 0; row < target.rows(); ++row) {
322 auto row_data = target.row(row);
323 io::binary_load(source, row_data.data(), row_data.data() + row_data.size());
330 void save_matrix_to_npy_imp(std::streambuf& target,
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
334 for(
int row = 0; row < matrix.rows(); ++row) {
335 const auto& row_data = matrix.row(row);
336 io::binary_dump(target, row_data.data(), row_data.data() + row_data.size());
343 return load_matrix_from_npy_imp<real_t>(*source.rdbuf());
347 std::ifstream file(path);
348 if(!file.is_open()) {
349 THROW_ERROR(
"Could not open file {} for reading.", path)
355 const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
360 const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
361 std::ofstream file(path);
362 if(!file.is_open()) {
363 THROW_ERROR(
"Could not open file {} for writing.", path)
373 std::stringstream target;
374 std::string description =
"{'descr': '<f8', 'fortran_order': False, 'shape': (3,), }";
375 std::string ground_truth(
"\x93NUMPY\x03\x00\x74\x00\x00\x00{'descr': '<f8', 'fortran_order': False, 'shape': (3,), } \n", 128);
378 std::string new_str = target.str();
379 CHECK(new_str == ground_truth);
383 std::stringstream src;
385 SUBCASE(
"read valid length v2/3") {
386 src.str(std::string(
"\x03\x00s\x00\x00\x00",
MAGIC_SIZE));
388 CHECK(src.rdbuf()->pubseekoff(0, std::ios_base::cur, std::ios_base::in) == 6);
391 SUBCASE(
"read valid length v1") {
392 src.str(std::string(
"\x01\x00s\x00\x00\x00",
MAGIC_SIZE));
394 CHECK(src.rdbuf()->pubseekoff(0, std::ios_base::cur, std::ios_base::in) == 4);
397 SUBCASE(
"invalid version") {
398 src.str(std::string(
"\x04\x00s\x00\x00\x00",
MAGIC_SIZE));
402 SUBCASE(
"end of data") {
403 src.str(std::string(
"\x03\x00s\x00", 4));
424 SUBCASE(
"double quotes") {
425 input =
"{\"key\": value}";
430 SUBCASE(
"single quotes") {
431 input =
"{'key': value}";
436 SUBCASE(
"tuple value") {
437 input =
"{'key': (1, 2, 3)}";
442 SUBCASE(
"multiple entries") {
443 input =
"{'key': a, \"other key\": b}";
448 SUBCASE(
"nested quotes") {
449 input =
"{\"key_with'\": value}";
454 SUBCASE(
"quoted value") {
455 input =
"{'key': 'a value that contains } and \" and ) and ]'}";
457 value =
"'a value that contains } and \" and ) and ]'";
460 std::string_view dict_contents = input;
461 dict_contents.remove_prefix(1);
463 CHECK(got_key == key);
464 CHECK(got_value == value);
468 SUBCASE(
"f8 c order vector") {
469 auto data =
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': (3,), }");
470 CHECK(data.ColumnMajor ==
false);
471 CHECK(data.Rows == 3);
472 CHECK(data.Cols == 0);
473 CHECK(data.DataType ==
"<f8");
475 SUBCASE(
"reordered") {
476 auto data =
parse_description(
"{'fortran_order': False, 'shape': (3,), 'descr': '<f8'}");
477 CHECK(data.ColumnMajor ==
false);
478 CHECK(data.Rows == 3);
479 CHECK(data.Cols == 0);
480 CHECK(data.DataType ==
"<f8");
482 SUBCASE(
"i4 f order matrix no trailing comma") {
483 auto data =
parse_description(
"{'descr': \"<i4\", 'fortran_order': 1, 'shape': (5 , 7)}");
484 CHECK(data.ColumnMajor ==
true);
485 CHECK(data.Rows == 5);
486 CHECK(data.Cols == 7);
487 CHECK(data.DataType ==
"<i4");
489 SUBCASE(
"f8 c order matrix no whitespace") {
490 auto data =
parse_description(
"{'descr':'<f8','fortran_order':0,'shape':(5,7)}");
491 CHECK(data.ColumnMajor ==
false);
492 CHECK(data.Rows == 5);
493 CHECK(data.Cols == 7);
494 CHECK(data.DataType ==
"<f8");
499 SUBCASE(
"wrong value") {
500 CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': Unknown, 'shape': (3,), }"));
501 CHECK_THROWS(
parse_description(
"{'descr': (5, 4), 'fortran_order': False, 'shape': (3,), }"));
502 CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': 8 }"));
503 CHECK_THROWS(
parse_description(
"{'descr': 5, 'fortran_order': False, 'shape': (3,) }"));
504 CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': (3) }"));
505 CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': (a,) }"));
508 SUBCASE(
"missing key") {
514 SUBCASE(
"unexpected key") {
515 CHECK_THROWS(
parse_description(
"{'descr':'<f8','fortran_order':0,'shape':(5,7), 'other': 'value'}"));
516 CHECK_THROWS(
parse_description(
"{'descr':'<f8','other': 'value', 'fortran_order':0,'shape':(5,7)}"));
521 CHECK(
io::make_npy_description(
"<f8",
false, 5) ==
"{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (5,)}");
522 CHECK(
io::make_npy_description(
">i4",
true, 17) ==
"{\"descr\": \">i4\", \"fortran_order\": True, \"shape\": (17,)}");
523 CHECK(
io::make_npy_description(
"<f8",
false, 7, 5) ==
"{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (7, 5)}");
527 std::ostringstream save_stream;
528 types::DenseRowMajor<real_t> matrix = types::DenseRowMajor<real_t>::Random(4, 5);
531 std::istringstream load_stream;
532 load_stream.str(save_stream.str());
535 CHECK( matrix == ref );
building blocks for io procedures that are used by multiple io subsystems
TEST_CASE("numpy header with given description")
constexpr const unsigned NPY_PADDING
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > load_matrix_from_npy_imp(std::streambuf &source)
void save_matrix_to_npy_imp(std::streambuf &target, const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > &matrix)
long skip_whitespace(std::string_view source, long position)
constexpr const int MAGIC_SIZE
std::uint32_t read_header_length(std::streambuf &source)
io::NpyHeaderData parse_description(std::string_view view)
std::pair< std::string_view, std::string_view > read_key_value(std::string_view source)
This function parses a single element from a Python dict literal.
constexpr const char MAGIC[]
void binary_dump(std::streambuf &target, const T *begin, const T *end)
std::string make_npy_description(std::string_view dtype_desc, bool column_major, std::size_t size)
Creates a string with the data description dictionary for (1 dimensional) arrays.
REGISTER_DTYPE(float, "<f4")
void binary_load(std::streambuf &target, T *begin, T *end)
void write_npy_header(std::streambuf &target, std::string_view description)
Writes the header for an npy file.
bool is_npy(std::istream &target)
Checks whether the stream is an npy file.
types::DenseRowMajor< real_t > load_matrix_from_npy(std::istream &source)
Loads a matrix from a numpy array.
long parse_long(const char *string, const char **out)
void save_matrix_to_npy(std::ostream &source, const types::DenseRowMajor< real_t > &)
Saves a matrix to a numpy array.
NpyHeaderData parse_npy_header(std::streambuf &source)
Parses the header of the npy file given by source.
Main namespace in which all types, classes, and functions are defined.
constexpr auto ssize(const C &c) -> std::common_type_t< std::ptrdiff_t, std::make_signed_t< decltype(c.size())>>
signed size free function. Taken from https://en.cppreference.com/w/cpp/iterator/size
constexpr long to_long(T value)
Convert the given value to long, throwing an error if the conversion is not possible.