File indexing completed on 2025-04-19 08:55:35
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "../H5Easy.hpp"
0012 #include "H5Easy_misc.hpp"
0013 #include "H5Easy_scalar.hpp"
0014
0015 #ifdef H5_USE_EIGEN
0016
0017 namespace H5Easy {
0018
0019 namespace detail {
0020
0021 template <typename T>
0022 struct io_impl<T, typename std::enable_if<std::is_base_of<Eigen::DenseBase<T>, T>::value>::type> {
0023
0024 template <typename S>
0025 struct types {
0026 using row_major = Eigen::Ref<
0027 const Eigen::Array<typename std::decay<T>::type::Scalar,
0028 std::decay<T>::type::RowsAtCompileTime,
0029 std::decay<T>::type::ColsAtCompileTime,
0030 std::decay<T>::type::ColsAtCompileTime == 1 ? Eigen::ColMajor
0031 : Eigen::RowMajor,
0032 std::decay<T>::type::MaxRowsAtCompileTime,
0033 std::decay<T>::type::MaxColsAtCompileTime>,
0034 0,
0035 Eigen::InnerStride<1>>;
0036
0037 using col_major =
0038 Eigen::Map<Eigen::Array<typename std::decay<T>::type::Scalar,
0039 std::decay<T>::type::RowsAtCompileTime,
0040 std::decay<T>::type::ColsAtCompileTime,
0041 std::decay<T>::type::ColsAtCompileTime == 1 ? Eigen::ColMajor
0042 : Eigen::RowMajor,
0043 std::decay<T>::type::MaxRowsAtCompileTime,
0044 std::decay<T>::type::MaxColsAtCompileTime>>;
0045 };
0046
0047
0048 inline static std::vector<size_t> shape(const T& data) {
0049 if (std::decay<T>::type::RowsAtCompileTime == 1) {
0050 return {static_cast<size_t>(data.cols())};
0051 }
0052 if (std::decay<T>::type::ColsAtCompileTime == 1) {
0053 return {static_cast<size_t>(data.rows())};
0054 }
0055 return {static_cast<size_t>(data.rows()), static_cast<size_t>(data.cols())};
0056 }
0057
0058 using EigenIndex = Eigen::DenseIndex;
0059
0060
0061 template <class D>
0062 inline static std::vector<EigenIndex> shape(const File& file,
0063 const std::string& path,
0064 const D& dataset,
0065 int RowsAtCompileTime) {
0066 std::vector<size_t> dims = dataset.getDimensions();
0067
0068 if (dims.size() == 1 && RowsAtCompileTime == 1) {
0069 return std::vector<EigenIndex>{1u, static_cast<EigenIndex>(dims[0])};
0070 }
0071 if (dims.size() == 1) {
0072 return std::vector<EigenIndex>{static_cast<EigenIndex>(dims[0]), 1u};
0073 }
0074 if (dims.size() == 2) {
0075 return std::vector<EigenIndex>{static_cast<EigenIndex>(dims[0]),
0076 static_cast<EigenIndex>(dims[1])};
0077 }
0078
0079 throw detail::error(file, path, "H5Easy::load: Inconsistent rank");
0080 }
0081
0082 inline static DataSet dump(File& file,
0083 const std::string& path,
0084 const T& data,
0085 const DumpOptions& options) {
0086 using row_major_type = typename types<T>::row_major;
0087 using value_type = typename std::decay<T>::type::Scalar;
0088 row_major_type row_major(data);
0089 DataSet dataset = initDataset<value_type>(file, path, shape(data), options);
0090 dataset.write_raw(row_major.data());
0091 if (options.flush()) {
0092 file.flush();
0093 }
0094 return dataset;
0095 }
0096
0097 inline static T load(const File& file, const std::string& path) {
0098 DataSet dataset = file.getDataSet(path);
0099 std::vector<typename T::Index> dims = shape(file, path, dataset, T::RowsAtCompileTime);
0100 T data(dims[0], dims[1]);
0101 dataset.read_raw(data.data());
0102 if (data.IsVectorAtCompileTime || data.IsRowMajor) {
0103 return data;
0104 }
0105 using col_major = typename types<T>::col_major;
0106 return col_major(data.data(), dims[0], dims[1]);
0107 }
0108
0109 inline static Attribute dumpAttribute(File& file,
0110 const std::string& path,
0111 const std::string& key,
0112 const T& data,
0113 const DumpOptions& options) {
0114 using row_major_type = typename types<T>::row_major;
0115 using value_type = typename std::decay<T>::type::Scalar;
0116 row_major_type row_major(data);
0117 Attribute attribute = initAttribute<value_type>(file, path, key, shape(data), options);
0118 attribute.write_raw(row_major.data());
0119 if (options.flush()) {
0120 file.flush();
0121 }
0122 return attribute;
0123 }
0124
0125 inline static T loadAttribute(const File& file,
0126 const std::string& path,
0127 const std::string& key) {
0128 DataSet dataset = file.getDataSet(path);
0129 Attribute attribute = dataset.getAttribute(key);
0130 DataSpace dataspace = attribute.getSpace();
0131 std::vector<typename T::Index> dims = shape(file, path, dataspace, T::RowsAtCompileTime);
0132 T data(dims[0], dims[1]);
0133 attribute.read_raw(data.data());
0134 if (data.IsVectorAtCompileTime || data.IsRowMajor) {
0135 return data;
0136 }
0137 using col_major = typename types<T>::col_major;
0138 return col_major(data.data(), dims[0], dims[1]);
0139 }
0140 };
0141
0142 }
0143 }
0144
0145 #endif