11 #include "spdlog/fmt/fmt.h" 
   12 #include "spdlog/spdlog.h" 
   17     constexpr 
const char MAGIC[] = {
'\x93', 
'N', 
'U', 
'M', 
'P', 
'Y', 
'\x03', 
'\x00'};
 
   24     auto current_position = source.tellg();
 
   25     if(
auto num_read = source.readsome(buffer, 
sizeof(buffer)); num_read != 
sizeof(buffer)) {
 
   26         THROW_ERROR(
"Error when trying to read magic bytes. Read on {} bytes.", num_read);
 
   28     source.seekg(current_position);
 
   29     return (std::memcmp(buffer, 
MAGIC, 
sizeof(buffer)) == 0);
 
   34     std::size_t total_length = 
sizeof(
MAGIC) + 
sizeof(std::uint32_t) + description.size() + 1;
 
   37     std::uint32_t header_length = description.size() + padding + 1;
 
   38     target.sputn(
reinterpret_cast<const char*
>(&header_length), 
sizeof(header_length));
 
   39     target.sputn(description.data(), description.size());
 
   40     for(
unsigned i = 0; i < padding; ++i) {
 
   43     if(target.sputc(
'\n') != 
'\n') {
 
   44         THROW_ERROR(
"Could not write terminating newline to npy header");
 
   49     return fmt::format(R
"({{"descr": "{}", "fortran_order": {}, "shape": ({},)}})", dtype_desc, column_major ? "True" : 
"False", size);
 
   53     return fmt::format(R
"({{"descr": "{}", "fortran_order": {}, "shape": ({}, {})}})", dtype_desc, column_major ? "True" : 
"False", rows, cols);
 
   56 #define REGISTER_DTYPE(TYPE, STRING)        \ 
   58 const char* data_type_string<TYPE>() {      \ 
   73         auto read_raw = [&](
auto& target){
 
   74             auto num_read = source.sgetn(
reinterpret_cast<char*
>(&target), 
sizeof(target));
 
   75             if(num_read != 
sizeof(target)) {
 
   76                 THROW_ERROR(
"Unexpected end of data while reading header length");
 
   80         int major = source.sbumpc();
 
   81         int minor = source.sbumpc();
 
   83         if(major == 2 || major == 3) {
 
   84             std::uint32_t header_length;
 
   85             read_raw(header_length);
 
   89             std::uint16_t short_header_length;
 
   90             read_raw(short_header_length);
 
   91             return short_header_length;
 
   93         THROW_ERROR(
"Unknown npy file format version {}.{} -- {}", major, minor, source.pubseekoff(0, std::ios_base::cur, std::ios_base::in));
 
   97         while(std::isspace(source[position]) != 0 && position < 
ssize(source)) {
 
  105     std::pair<std::string_view, std::string_view> 
read_key_value(std::string_view source) {
 
  106         auto source_end = 
ssize(source);
 
  110         if(position == source_end) {
 
  114         char open_quote = source[position];
 
  115         long key_start = position;
 
  117         if(open_quote != 
'"' && open_quote != 
'\'') {
 
  118             THROW_ERROR(
"Expected begin of string ' or \" for parsing dictionary key. Got {}.", open_quote);
 
  121         std::size_t key_end = source.find(open_quote, key_start + 1);
 
  122         if(key_end == std::string_view::npos) {
 
  123             THROW_ERROR(
"Could not find matching closing quotation mark `{}` for key string", open_quote);
 
  128         if(position == source_end) {
 
  129             THROW_ERROR(
"Could not find : that separates key and value");
 
  132         if(source[position] != 
':') {
 
  133             THROW_ERROR(
"Expected : to separate key and value, got {}", source[position]);
 
  137         if(position == source_end) {
 
  141         const char openers[] = {
'"', 
'\'', 
'(', 
'[', 
'{'};
 
  142         const char closers[] = {
'"', 
'\'', 
')', 
']', 
'}'};
 
  148         long value_start = position;
 
  149         char expect_close = 0;
 
  150         while(position < source_end) {
 
  151             char current = source[position];
 
  152             if(expect_close == 0) {
 
  155                 if(current == 
',' || current == 
'}') {
 
  156                     return {{source.begin() + key_start + 1, key_end - key_start - 1},
 
  157                             {source.begin() + value_start, std::size_t(position - value_start)}};
 
  168                         for(
int i = 0; i < 
to_long(
sizeof(openers)); ++i) {
 
  169                             if(openers[i] == current) {
 
  170                                 expect_close = closers[i];
 
  176             } 
else if(current == expect_close) {
 
  182         if(expect_close != 0) {
 
  183             THROW_ERROR(
"Expected closing {}, but reached end of input", expect_close);
 
  185         THROW_ERROR(
"Expected } or , to signal end of input");
 
  189         view.remove_prefix(1);
 
  193         bool has_descr = 
false;
 
  194         bool has_order = 
false;
 
  195         bool has_shape = 
false;
 
  196         for(
int i = 0; i < 3; ++i) {
 
  200             auto value = kv.second;
 
  201             view = view.substr(value.end() - view.begin() + 1);
 
  204                 if(value.front() != 
'\'' && value.front() != 
'"') {
 
  205                     THROW_ERROR(
"expected string for descr, got '{}'", value);
 
  207                 result.
DataType = value.substr(1, value.size() - 2);
 
  209             } 
else if (key == 
"fortran_order") {
 
  210                 if(value == 
"False" || value == 
"0") {
 
  212                 } 
else if (value == 
"True" || value == 
"1") {
 
  215                     std::string val_str{value};
 
  216                     THROW_ERROR(
"unexpected value '{}' for fortran_order", val_str);
 
  219             } 
else if(key == 
"shape") {
 
  220                 if(value.at(0) != 
'(') {
 
  221                     THROW_ERROR(
"expected ( to start tuple for shape");
 
  223                 auto sep = value.find(
',');
 
  224                 if(sep == std::string::npos) {
 
  228                 const char* endptr = 
nullptr;
 
  231                 if(errno != 0 || endptr == value.begin() + 1) {
 
  232                     THROW_ERROR(
"error while trying to parse number for size");
 
  234                 if(result.
Rows < 0) {
 
  240                     THROW_ERROR(
"error while trying to parse number for size");
 
  243                 if(result.
Cols < 0) {
 
  249                 std::string key_str{key};
 
  254         bool closed_dict = 
false;
 
  255         for(
const auto& c : view) {
 
  256             if(std::isspace(c) != 0) 
continue;
 
  257             if(c == 
'}' && !closed_dict) {
 
  269             THROW_ERROR(
"Missing 'fortran_order' entry in dict");
 
  281     std::array<char, MAGIC_SIZE> magic{};
 
  284         if(magic[i] != 
MAGIC[i]) {
 
  291     std::string header_buffer(header_length, 
'\0');
 
  292     if(
auto num_read = source.sgetn(header_buffer.data(), header_length); num_read != header_length) {
 
  293         THROW_ERROR(
"Expected to read a header of size {}, but only got {} elements", header_length, num_read);
 
  297     if(header_buffer.at(0) != 
'{') {
 
  298         THROW_ERROR(
"Expected data description dict to start with '{{', got '{}'. Header is: {}", header_buffer.at(0), header_buffer );
 
  301     if(header_buffer.back() != 
'\n') {
 
  302         THROW_ERROR(
"Expected newline \\n at end of header \"{}\"", header_buffer );
 
  312         if(header.DataType != io::data_type_string<T>()) {
 
  313             THROW_ERROR(
"Unsupported data type {}", header.DataType);
 
  315         if(header.ColumnMajor) {
 
  316             THROW_ERROR(
"Currently, only row-major npy files can be read");
 
  320         Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> target(header.Rows, header.Cols);
 
  321         for(
int row = 0; row < target.rows(); ++row) {
 
  322             auto row_data = target.row(row);
 
  323             io::binary_load(source, row_data.data(), row_data.data() + row_data.size());
 
  330     void save_matrix_to_npy_imp(std::streambuf& target, 
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
 
  334         for(
int row = 0; row < matrix.rows(); ++row) {
 
  335             const auto& row_data = matrix.row(row);
 
  336             io::binary_dump(target, row_data.data(), row_data.data() + row_data.size());
 
  343     return load_matrix_from_npy_imp<real_t>(*source.rdbuf());
 
  347     std::ifstream file(path);
 
  348     if(!file.is_open()) {
 
  349         THROW_ERROR(
"Could not open file {} for reading.", path)
 
  355                             const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
 
  360                             const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
 
  361     std::ofstream file(path);
 
  362     if(!file.is_open()) {
 
  363         THROW_ERROR(
"Could not open file {} for writing.", path)
 
  373     std::stringstream target;
 
  374     std::string description = 
"{'descr': '<f8', 'fortran_order': False, 'shape': (3,), }";
 
  375     std::string ground_truth(
"\x93NUMPY\x03\x00\x74\x00\x00\x00{'descr': '<f8', 'fortran_order': False, 'shape': (3,), }                                                          \n", 128);
 
  378     std::string new_str = target.str();
 
  379     CHECK(new_str == ground_truth);
 
  383     std::stringstream src;
 
  385     SUBCASE(
"read valid length v2/3") {
 
  386         src.str(std::string(
"\x03\x00s\x00\x00\x00", 
MAGIC_SIZE));
 
  388         CHECK(src.rdbuf()->pubseekoff(0, std::ios_base::cur, std::ios_base::in) == 6);
 
  391     SUBCASE(
"read valid length v1") {
 
  392         src.str(std::string(
"\x01\x00s\x00\x00\x00", 
MAGIC_SIZE));
 
  394         CHECK(src.rdbuf()->pubseekoff(0, std::ios_base::cur, std::ios_base::in) == 4);
 
  397     SUBCASE(
"invalid version") {
 
  398         src.str(std::string(
"\x04\x00s\x00\x00\x00", 
MAGIC_SIZE));
 
  402     SUBCASE(
"end of data") {
 
  403         src.str(std::string(
"\x03\x00s\x00", 4));
 
  424     SUBCASE(
"double quotes") {
 
  425         input = 
"{\"key\": value}";
 
  430     SUBCASE(
"single quotes") {
 
  431         input = 
"{'key': value}";
 
  436     SUBCASE(
"tuple value") {
 
  437         input = 
"{'key': (1, 2, 3)}";
 
  442     SUBCASE(
"multiple entries") {
 
  443         input = 
"{'key': a, \"other key\": b}";
 
  448     SUBCASE(
"nested quotes") {
 
  449         input = 
"{\"key_with'\":  value}";
 
  454     SUBCASE(
"quoted value") {
 
  455         input = 
"{'key': 'a value that contains } and \" and ) and ]'}";
 
  457         value = 
"'a value that contains } and \" and ) and ]'";
 
  460     std::string_view dict_contents = input;
 
  461     dict_contents.remove_prefix(1);
 
  463     CHECK(got_key == key);
 
  464     CHECK(got_value == value);
 
  468     SUBCASE(
"f8 c order vector") {
 
  469         auto data = 
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': (3,), }");
 
  470         CHECK(data.ColumnMajor == 
false);
 
  471         CHECK(data.Rows == 3);
 
  472         CHECK(data.Cols == 0);
 
  473         CHECK(data.DataType == 
"<f8");
 
  475     SUBCASE(
"reordered") {
 
  476         auto data = 
parse_description(
"{'fortran_order': False, 'shape': (3,), 'descr': '<f8'}");
 
  477         CHECK(data.ColumnMajor == 
false);
 
  478         CHECK(data.Rows == 3);
 
  479         CHECK(data.Cols == 0);
 
  480         CHECK(data.DataType == 
"<f8");
 
  482     SUBCASE(
"i4 f order matrix no trailing comma") {
 
  483         auto data = 
parse_description(
"{'descr': \"<i4\", 'fortran_order': 1, 'shape': (5 , 7)}");
 
  484         CHECK(data.ColumnMajor == 
true);
 
  485         CHECK(data.Rows == 5);
 
  486         CHECK(data.Cols == 7);
 
  487         CHECK(data.DataType == 
"<i4");
 
  489     SUBCASE(
"f8 c order matrix no whitespace") {
 
  490         auto data = 
parse_description(
"{'descr':'<f8','fortran_order':0,'shape':(5,7)}");
 
  491         CHECK(data.ColumnMajor == 
false);
 
  492         CHECK(data.Rows == 5);
 
  493         CHECK(data.Cols == 7);
 
  494         CHECK(data.DataType == 
"<f8");
 
  499     SUBCASE(
"wrong value") {
 
  500         CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': Unknown, 'shape': (3,), }"));
 
  501         CHECK_THROWS(
parse_description(
"{'descr': (5, 4), 'fortran_order': False, 'shape': (3,), }"));
 
  502         CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': 8 }"));
 
  503         CHECK_THROWS(
parse_description(
"{'descr': 5, 'fortran_order': False, 'shape': (3,) }"));
 
  504         CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': (3) }"));
 
  505         CHECK_THROWS(
parse_description(
"{'descr': '<f8', 'fortran_order': False, 'shape': (a,) }"));
 
  508     SUBCASE(
"missing key") {
 
  514     SUBCASE(
"unexpected key") {
 
  515         CHECK_THROWS(
parse_description(
"{'descr':'<f8','fortran_order':0,'shape':(5,7), 'other': 'value'}"));
 
  516         CHECK_THROWS(
parse_description(
"{'descr':'<f8','other': 'value', 'fortran_order':0,'shape':(5,7)}"));
 
  521     CHECK(
io::make_npy_description(
"<f8", 
false, 5) == 
"{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (5,)}");
 
  522     CHECK(
io::make_npy_description(
">i4", 
true, 17) == 
"{\"descr\": \">i4\", \"fortran_order\": True, \"shape\": (17,)}");
 
  523     CHECK(
io::make_npy_description(
"<f8", 
false, 7, 5) == 
"{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (7, 5)}");
 
  527     std::ostringstream save_stream;
 
  528     types::DenseRowMajor<real_t> matrix = types::DenseRowMajor<real_t>::Random(4, 5);
 
  531     std::istringstream load_stream;
 
  532     load_stream.str(save_stream.str());
 
  535     CHECK( matrix == ref );
 
Building blocks for I/O procedures that are used by multiple I/O subsystems.
TEST_CASE("numpy header with given description")
constexpr const unsigned NPY_PADDING
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > load_matrix_from_npy_imp(std::streambuf &source)
void save_matrix_to_npy_imp(std::streambuf &target, const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > &matrix)
long skip_whitespace(std::string_view source, long position)
constexpr const int MAGIC_SIZE
std::uint32_t read_header_length(std::streambuf &source)
io::NpyHeaderData parse_description(std::string_view view)
std::pair< std::string_view, std::string_view > read_key_value(std::string_view source)
This function parses a single element from a Python dict literal.
constexpr const char MAGIC[]
void binary_dump(std::streambuf &target, const T *begin, const T *end)
std::string make_npy_description(std::string_view dtype_desc, bool column_major, std::size_t size)
Creates a string with the data description dictionary for (1 dimensional) arrays.
REGISTER_DTYPE(float, "<f4")
void binary_load(std::streambuf &target, T *begin, T *end)
void write_npy_header(std::streambuf &target, std::string_view description)
Writes the header for an npy file.
bool is_npy(std::istream &target)
Checks whether the stream is an npy file.
types::DenseRowMajor< real_t > load_matrix_from_npy(std::istream &source)
Loads a matrix from a numpy array.
long parse_long(const char *string, const char **out)
void save_matrix_to_npy(std::ostream &source, const types::DenseRowMajor< real_t > &)
Saves a matrix to a numpy array.
NpyHeaderData parse_npy_header(std::streambuf &source)
Parses the header of the npy file given by source.
Main namespace in which all types, classes, and functions are defined.
constexpr auto ssize(const C &c) -> std::common_type_t< std::ptrdiff_t, std::make_signed_t< decltype(c.size())>>
Signed-size free function, taken from https://en.cppreference.com/w/cpp/iterator/size
constexpr long to_long(T value)
Convert the given value to long, throwing an error if the conversion is not possible.