Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 09:13:46

0001 #ifndef BVH_V2_MINI_TREE_BUILDER_H
0002 #define BVH_V2_MINI_TREE_BUILDER_H
0003 
#include "bvh/v2/sweep_sah_builder.h"
#include "bvh/v2/binned_sah_builder.h"
#include "bvh/v2/thread_pool.h"
#include "bvh/v2/executor.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <optional>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>
0015 
0016 namespace bvh::v2 {
0017 
0018 /// Multi-threaded top-down builder that partitions primitives using a grid. Multiple instances
0019 /// of a single-threaded builder are run in parallel on that partition, generating many small
0020 /// trees. Finally, a top-level tree is built on these smaller trees to form the final BVH.
0021 /// This builder is inspired by
0022 /// "Rapid Bounding Volume Hierarchy Generation using Mini Trees", by P. Ganestam et al.
0023 template <typename Node, typename MortonCode = uint32_t>
0024 class MiniTreeBuilder {
0025     using Scalar = typename Node::Scalar;
0026     using Vec  = bvh::v2::Vec<Scalar, Node::dimension>;
0027     using BBox = bvh::v2::BBox<Scalar, Node::dimension>;
0028 
0029 public:
0030     struct Config : TopDownSahBuilder<Node>::Config {
0031         /// Flag that turns on/off mini-tree pruning.
0032         bool enable_pruning = true;
0033 
0034         /// Threshold on the area of a mini-tree node above which it is pruned, expressed in
0035         /// fraction of the area of bounding box around the entire set of primitives.
0036         Scalar pruning_area_ratio = static_cast<Scalar>(0.01);
0037 
0038         /// Minimum number of primitives per parallel task.
0039         size_t parallel_threshold = 1024;
0040 
0041         /// Log of the dimension of the grid used to split the workload horizontally.
0042         size_t log2_grid_dim = 4;
0043     };
0044 
0045     /// Starts building a BVH with the given primitive data. The build algorithm is multi-threaded,
0046     /// and runs on the given thread pool.
0047     BVH_ALWAYS_INLINE static Bvh<Node> build(
0048         ThreadPool& thread_pool,
0049         std::span<const BBox> bboxes,
0050         std::span<const Vec> centers,
0051         const Config& config = {})
0052     {
0053         MiniTreeBuilder builder(thread_pool, bboxes, centers, config);
0054         auto mini_trees = builder.build_mini_trees();
0055         if (config.enable_pruning)
0056             mini_trees = builder.prune_mini_trees(std::move(mini_trees));
0057         return builder.build_top_bvh(mini_trees);
0058     }
0059 
0060 private:
0061     friend struct BuildTask;
0062 
0063     struct Bin {
0064         std::vector<size_t> ids;
0065 
0066         BVH_ALWAYS_INLINE void add(size_t id) { ids.push_back(id); }
0067 
0068         BVH_ALWAYS_INLINE void merge(Bin&& other) {
0069             if (ids.empty())
0070                 ids = std::move(other.ids);
0071             else {
0072                 ids.insert(ids.end(), other.ids.begin(), other.ids.end());
0073                 other.ids.clear();
0074             }
0075         }
0076     };
0077 
0078     struct LocalBins {
0079         std::vector<Bin> bins;
0080 
0081         BVH_ALWAYS_INLINE Bin& operator [] (size_t i) { return bins[i]; }
0082         BVH_ALWAYS_INLINE const Bin& operator [] (size_t i) const { return bins[i]; }
0083 
0084         BVH_ALWAYS_INLINE void merge_small_bins(size_t threshold) {
0085             for (size_t i = 0; i < bins.size();) {
0086                 size_t j = i + 1;
0087                 for (; j < bins.size() && bins[j].ids.size() + bins[i].ids.size() <= threshold; ++j)
0088                     bins[i].merge(std::move(bins[j]));
0089                 i = j;
0090             }
0091         }
0092 
0093         BVH_ALWAYS_INLINE void remove_empty_bins() {
0094             bins.resize(std::remove_if(bins.begin(), bins.end(),
0095                 [] (const Bin& bin) { return bin.ids.empty(); }) - bins.begin());
0096         }
0097 
0098         BVH_ALWAYS_INLINE void merge(LocalBins&& other) {
0099             bins.resize(std::max(bins.size(), other.bins.size()));
0100             for (size_t i = 0, n = std::min(bins.size(), other.bins.size()); i < n; ++i)
0101                 bins[i].merge(std::move(other[i]));
0102         }
0103     };
0104 
0105     struct BuildTask {
0106         MiniTreeBuilder* builder;
0107         Bvh<Node>& bvh;
0108         std::vector<size_t> prim_ids;
0109 
0110         std::vector<BBox> bboxes;
0111         std::vector<Vec> centers;
0112 
0113         BuildTask(
0114             MiniTreeBuilder* builder,
0115             Bvh<Node>& bvh,
0116             std::vector<size_t>&& prim_ids)
0117             : builder(builder)
0118             , bvh(bvh)
0119             , prim_ids(std::move(prim_ids))
0120         {}
0121 
0122         BVH_ALWAYS_INLINE void run() {
0123             // Make sure that rebuilds produce the same BVH
0124             std::sort(prim_ids.begin(), prim_ids.end());
0125 
0126             // Extract bounding boxes and centers for this set of primitives
0127             bboxes.resize(prim_ids.size());
0128             centers.resize(prim_ids.size());
0129             for (size_t i = 0; i < prim_ids.size(); ++i) {
0130                 bboxes[i] = builder->bboxes_[prim_ids[i]];
0131                 centers[i] = builder->centers_[prim_ids[i]];
0132             }
0133 
0134             bvh = BinnedSahBuilder<Node>::build(bboxes, centers, builder->config_);
0135 
0136             // Permute primitive indices so that they index the proper set of primitives
0137             for (size_t i = 0; i < bvh.prim_ids.size(); ++i)
0138                 bvh.prim_ids[i] = prim_ids[bvh.prim_ids[i]];
0139         }
0140     };
0141 
0142     ParallelExecutor executor_;
0143     std::span<const BBox> bboxes_;
0144     std::span<const Vec> centers_;
0145     const Config& config_;
0146 
0147     BVH_ALWAYS_INLINE MiniTreeBuilder(
0148         ThreadPool& thread_pool,
0149         std::span<const BBox> bboxes,
0150         std::span<const Vec> centers,
0151         const Config& config)
0152         : executor_(thread_pool)
0153         , bboxes_(bboxes)
0154         , centers_(centers)
0155         , config_(config)
0156     {
0157         assert(bboxes.size() == centers.size());
0158     }
0159 
0160     std::vector<Bvh<Node>> build_mini_trees() {
0161         // Compute the bounding box of all centers
0162         auto center_bbox = executor_.reduce(0, bboxes_.size(), BBox::make_empty(),
0163             [this] (BBox& bbox, size_t begin, size_t end) {
0164                 for (size_t i = begin; i < end; ++i)
0165                     bbox.extend(centers_[i]);
0166             },
0167             [] (BBox& bbox, const BBox& other) { bbox.extend(other); });
0168 
0169         assert(config_.log2_grid_dim <= std::numeric_limits<MortonCode>::digits / Node::dimension);
0170         auto bin_count = size_t{1} << (config_.log2_grid_dim * Node::dimension);
0171         auto grid_dim = size_t{1} << config_.log2_grid_dim;
0172         auto grid_scale = Vec(static_cast<Scalar>(grid_dim)) * safe_inverse(center_bbox.get_diagonal());
0173         auto grid_offset = -center_bbox.min * grid_scale;
0174 
0175         // Place primitives in bins
0176         auto final_bins = executor_.reduce(0, bboxes_.size(), LocalBins {},
0177             [&] (LocalBins& local_bins, size_t begin, size_t end) {
0178                 local_bins.bins.resize(bin_count);
0179                 for (size_t i = begin; i < end; ++i) {
0180                     auto p = robust_max(fast_mul_add(centers_[i], grid_scale, grid_offset), Vec(0));
0181                     auto x = std::min(grid_dim - 1, static_cast<size_t>(p[0]));
0182                     auto y = std::min(grid_dim - 1, static_cast<size_t>(p[1]));
0183                     auto z = std::min(grid_dim - 1, static_cast<size_t>(p[2]));
0184                     local_bins[morton_encode(x, y, z) & (bin_count - 1)].add(i);
0185                 }
0186             },
0187             [&] (LocalBins& result, LocalBins&& other) { result.merge(std::move(other)); });
0188 
0189         // Note: Merging small bins will deteriorate the quality of the top BVH if there is no
0190         // pruning, since it will then produce larger mini-trees. For this reason, it is only enabled
0191         // when mini-tree pruning is enabled.
0192         if (config_.enable_pruning)
0193             final_bins.merge_small_bins(config_.parallel_threshold);
0194         final_bins.remove_empty_bins();
0195 
0196         // Iterate over bins to collect groups of primitives and build BVHs over them in parallel
0197         std::vector<Bvh<Node>> mini_trees(final_bins.bins.size());
0198         for (size_t i = 0; i < final_bins.bins.size(); ++i) {
0199             auto task = new BuildTask(this, mini_trees[i], std::move(final_bins[i].ids));
0200             executor_.thread_pool.push([task] (size_t) { task->run(); delete task; });
0201         }
0202         executor_.thread_pool.wait();
0203 
0204         return mini_trees;
0205     }
0206 
0207     std::vector<Bvh<Node>> prune_mini_trees(std::vector<Bvh<Node>>&& mini_trees) {
0208         // Compute the area threshold based on the area of the entire set of primitives
0209         auto avg_area = static_cast<Scalar>(0.);
0210         for (auto& mini_tree : mini_trees)
0211             avg_area += mini_tree.get_root().get_bbox().get_half_area();
0212         avg_area /= static_cast<Scalar>(mini_trees.size());
0213         auto threshold = avg_area * config_.pruning_area_ratio;
0214 
0215         // Cull nodes whose area is above the threshold
0216         std::stack<size_t> stack;
0217         std::vector<std::pair<size_t, size_t>> pruned_roots;
0218         for (size_t i = 0; i < mini_trees.size(); ++i) {
0219             stack.push(0);
0220             auto& mini_tree = mini_trees[i];
0221             while (!stack.empty()) {
0222                 auto node_id = stack.top();
0223                 auto& node = mini_tree.nodes[node_id];
0224                 stack.pop();
0225                 if (node.get_bbox().get_half_area() < threshold || node.is_leaf()) {
0226                     pruned_roots.emplace_back(i, node_id);
0227                 } else {
0228                     stack.push(node.index.first_id());
0229                     stack.push(node.index.first_id() + 1);
0230                 }
0231             }
0232         }
0233 
0234         // Extract the BVHs rooted at the previously computed indices
0235         std::vector<Bvh<Node>> pruned_trees(pruned_roots.size());
0236         executor_.for_each(0, pruned_roots.size(),
0237             [&] (size_t begin, size_t end) {
0238                 for (size_t i = begin; i < end; ++i) {
0239                     if (pruned_roots[i].second == 0)
0240                         pruned_trees[i] = std::move(mini_trees[pruned_roots[i].first]);
0241                     else
0242                         pruned_trees[i] = mini_trees[pruned_roots[i].first]
0243                             .extract_bvh(pruned_roots[i].second);
0244                 }
0245             });
0246         return pruned_trees;
0247     }
0248 
0249     Bvh<Node> build_top_bvh(std::vector<Bvh<Node>>& mini_trees) {
0250         // Build a BVH using the mini trees as leaves
0251         std::vector<Vec> centers(mini_trees.size());
0252         std::vector<BBox> bboxes(mini_trees.size());
0253         for (size_t i = 0; i < mini_trees.size(); ++i) {
0254             bboxes[i] = mini_trees[i].get_root().get_bbox();
0255             centers[i] = bboxes[i].get_center();
0256         }
0257 
0258         typename SweepSahBuilder<Node>::Config config = config_;
0259         config.max_leaf_size = config.min_leaf_size = 1; // Needs to have only one mini-tree in each leaf
0260         auto bvh = SweepSahBuilder<Node>::build(bboxes, centers, config);
0261 
0262         // Compute the offsets to apply to primitive and node indices
0263         std::vector<size_t> node_offsets(mini_trees.size());
0264         std::vector<size_t> prim_offsets(mini_trees.size());
0265         size_t node_count = bvh.nodes.size();
0266         size_t prim_count = 0;
0267         for (size_t i = 0; i < mini_trees.size(); ++i) {
0268             node_offsets[i] = node_count - 1; // Skip root node
0269             prim_offsets[i] = prim_count;
0270             node_count += mini_trees[i].nodes.size() - 1; // idem
0271             prim_count += mini_trees[i].prim_ids.size();
0272         }
0273 
0274         // Helper function to copy and fix the child/primitive index of a node
0275         auto copy_node = [&] (size_t i, Node& dst_node, const Node& src_node) {
0276             dst_node = src_node;
0277             dst_node.index.set_first_id(dst_node.index.first_id() +
0278                 (src_node.is_leaf() ? prim_offsets[i] : node_offsets[i]));
0279         };
0280 
0281         // Make the leaves of the top BVH point to the right internal nodes
0282         for (auto& node : bvh.nodes) {
0283             if (!node.is_leaf())
0284                 continue;
0285             assert(node.index.prim_count() == 1);
0286             size_t tree_id = bvh.prim_ids[node.index.first_id()];
0287             copy_node(tree_id, node, mini_trees[tree_id].get_root());
0288         }
0289 
0290         bvh.nodes.resize(node_count);
0291         bvh.prim_ids.resize(prim_count);
0292         executor_.for_each(0, mini_trees.size(),
0293             [&] (size_t begin, size_t end) {
0294                 for (size_t i = begin; i < end; ++i) {
0295                     auto& mini_tree = mini_trees[i];
0296 
0297                     // Copy the nodes of the mini tree with the offsets applied, without copying
0298                     // the root node (since it is already copied to the top-level part of the BVH).
0299                     for (size_t j = 1; j < mini_tree.nodes.size(); ++j)
0300                         copy_node(i, bvh.nodes[node_offsets[i] + j], mini_tree.nodes[j]);
0301 
0302                     std::copy(
0303                         mini_tree.prim_ids.begin(),
0304                         mini_tree.prim_ids.end(),
0305                         bvh.prim_ids.begin() + prim_offsets[i]);
0306                 }
0307             });
0308 
0309         return bvh;
0310     }
0311 };
0312 
0313 } // namespace bvh::v2
0314 
0315 #endif