Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-19 08:55:35

0001 /*
0002  *  Copyright (c), 2017, Adrien Devresse <adrien.devresse@epfl.ch>
0003  *
0004  *  Distributed under the Boost Software License, Version 1.0.
0005  *    (See accompanying file LICENSE_1_0.txt or copy at
0006  *          http://www.boost.org/LICENSE_1_0.txt)
0007  *
0008  */
0009 #pragma once
0010 
0011 #include "../H5Easy.hpp"
0012 #include "H5Easy_misc.hpp"
0013 #include "H5Easy_scalar.hpp"
0014 
0015 #ifdef H5_USE_EIGEN
0016 
0017 namespace H5Easy {
0018 
0019 namespace detail {
0020 
0021 template <typename T>
0022 struct io_impl<T, typename std::enable_if<std::is_base_of<Eigen::DenseBase<T>, T>::value>::type> {
0023     // abbreviate row-major <-> col-major conversions
0024     template <typename S>
0025     struct types {
0026         using row_major = Eigen::Ref<
0027             const Eigen::Array<typename std::decay<T>::type::Scalar,
0028                                std::decay<T>::type::RowsAtCompileTime,
0029                                std::decay<T>::type::ColsAtCompileTime,
0030                                std::decay<T>::type::ColsAtCompileTime == 1 ? Eigen::ColMajor
0031                                                                            : Eigen::RowMajor,
0032                                std::decay<T>::type::MaxRowsAtCompileTime,
0033                                std::decay<T>::type::MaxColsAtCompileTime>,
0034             0,
0035             Eigen::InnerStride<1>>;
0036 
0037         using col_major =
0038             Eigen::Map<Eigen::Array<typename std::decay<T>::type::Scalar,
0039                                     std::decay<T>::type::RowsAtCompileTime,
0040                                     std::decay<T>::type::ColsAtCompileTime,
0041                                     std::decay<T>::type::ColsAtCompileTime == 1 ? Eigen::ColMajor
0042                                                                                 : Eigen::RowMajor,
0043                                     std::decay<T>::type::MaxRowsAtCompileTime,
0044                                     std::decay<T>::type::MaxColsAtCompileTime>>;
0045     };
0046 
0047     // return the shape of Eigen::DenseBase<T> object as size 1 or 2 "std::vector<size_t>"
0048     inline static std::vector<size_t> shape(const T& data) {
0049         if (std::decay<T>::type::RowsAtCompileTime == 1) {
0050             return {static_cast<size_t>(data.cols())};
0051         }
0052         if (std::decay<T>::type::ColsAtCompileTime == 1) {
0053             return {static_cast<size_t>(data.rows())};
0054         }
0055         return {static_cast<size_t>(data.rows()), static_cast<size_t>(data.cols())};
0056     }
0057 
0058     using EigenIndex = Eigen::DenseIndex;
0059 
0060     // get the shape of a "DataSet" as size 2 "std::vector<Eigen::Index>"
0061     template <class D>
0062     inline static std::vector<EigenIndex> shape(const File& file,
0063                                                 const std::string& path,
0064                                                 const D& dataset,
0065                                                 int RowsAtCompileTime) {
0066         std::vector<size_t> dims = dataset.getDimensions();
0067 
0068         if (dims.size() == 1 && RowsAtCompileTime == 1) {
0069             return std::vector<EigenIndex>{1u, static_cast<EigenIndex>(dims[0])};
0070         }
0071         if (dims.size() == 1) {
0072             return std::vector<EigenIndex>{static_cast<EigenIndex>(dims[0]), 1u};
0073         }
0074         if (dims.size() == 2) {
0075             return std::vector<EigenIndex>{static_cast<EigenIndex>(dims[0]),
0076                                            static_cast<EigenIndex>(dims[1])};
0077         }
0078 
0079         throw detail::error(file, path, "H5Easy::load: Inconsistent rank");
0080     }
0081 
0082     inline static DataSet dump(File& file,
0083                                const std::string& path,
0084                                const T& data,
0085                                const DumpOptions& options) {
0086         using row_major_type = typename types<T>::row_major;
0087         using value_type = typename std::decay<T>::type::Scalar;
0088         row_major_type row_major(data);
0089         DataSet dataset = initDataset<value_type>(file, path, shape(data), options);
0090         dataset.write_raw(row_major.data());
0091         if (options.flush()) {
0092             file.flush();
0093         }
0094         return dataset;
0095     }
0096 
0097     inline static T load(const File& file, const std::string& path) {
0098         DataSet dataset = file.getDataSet(path);
0099         std::vector<typename T::Index> dims = shape(file, path, dataset, T::RowsAtCompileTime);
0100         T data(dims[0], dims[1]);
0101         dataset.read_raw(data.data());
0102         if (data.IsVectorAtCompileTime || data.IsRowMajor) {
0103             return data;
0104         }
0105         using col_major = typename types<T>::col_major;
0106         return col_major(data.data(), dims[0], dims[1]);
0107     }
0108 
0109     inline static Attribute dumpAttribute(File& file,
0110                                           const std::string& path,
0111                                           const std::string& key,
0112                                           const T& data,
0113                                           const DumpOptions& options) {
0114         using row_major_type = typename types<T>::row_major;
0115         using value_type = typename std::decay<T>::type::Scalar;
0116         row_major_type row_major(data);
0117         Attribute attribute = initAttribute<value_type>(file, path, key, shape(data), options);
0118         attribute.write_raw(row_major.data());
0119         if (options.flush()) {
0120             file.flush();
0121         }
0122         return attribute;
0123     }
0124 
0125     inline static T loadAttribute(const File& file,
0126                                   const std::string& path,
0127                                   const std::string& key) {
0128         DataSet dataset = file.getDataSet(path);
0129         Attribute attribute = dataset.getAttribute(key);
0130         DataSpace dataspace = attribute.getSpace();
0131         std::vector<typename T::Index> dims = shape(file, path, dataspace, T::RowsAtCompileTime);
0132         T data(dims[0], dims[1]);
0133         attribute.read_raw(data.data());
0134         if (data.IsVectorAtCompileTime || data.IsRowMajor) {
0135             return data;
0136         }
0137         using col_major = typename types<T>::col_major;
0138         return col_major(data.data(), dims[0], dims[1]);
0139     }
0140 };
0141 
0142 }  // namespace detail
0143 }  // namespace H5Easy
0144 
0145 #endif  // H5_USE_EIGEN