File indexing completed on 2025-09-17 09:13:46
0001 #ifndef BVH_V2_SWEEP_SAH_BUILDER_H
0002 #define BVH_V2_SWEEP_SAH_BUILDER_H
0003
0004 #include "bvh/v2/top_down_sah_builder.h"
0005
0006 #include <stack>
0007 #include <tuple>
0008 #include <algorithm>
0009 #include <optional>
0010 #include <numeric>
0011 #include <cassert>
0012
0013 namespace bvh::v2 {
0014
0015
0016
0017 template <typename Node>
0018 class SweepSahBuilder : public TopDownSahBuilder<Node> {
0019 using typename TopDownSahBuilder<Node>::Scalar;
0020 using typename TopDownSahBuilder<Node>::Vec;
0021 using typename TopDownSahBuilder<Node>::BBox;
0022
0023 using TopDownSahBuilder<Node>::build;
0024 using TopDownSahBuilder<Node>::config_;
0025 using TopDownSahBuilder<Node>::bboxes_;
0026
0027 public:
0028 using typename TopDownSahBuilder<Node>::Config;
0029
0030 BVH_ALWAYS_INLINE static Bvh<Node> build(
0031 std::span<const BBox> bboxes,
0032 std::span<const Vec> centers,
0033 const Config& config = {})
0034 {
0035 return SweepSahBuilder(bboxes, centers, config).build();
0036 }
0037
0038 protected:
0039 struct Split {
0040 size_t pos;
0041 Scalar cost;
0042 size_t axis;
0043 };
0044
0045 std::vector<bool> marks_;
0046 std::vector<Scalar> accum_;
0047 std::vector<size_t> prim_ids_[Node::dimension];
0048
0049 BVH_ALWAYS_INLINE SweepSahBuilder(
0050 std::span<const BBox> bboxes,
0051 std::span<const Vec> centers,
0052 const Config& config)
0053 : TopDownSahBuilder<Node>(bboxes, centers, config)
0054 {
0055 marks_.resize(bboxes.size());
0056 accum_.resize(bboxes.size());
0057 for (size_t axis = 0; axis < Node::dimension; ++axis) {
0058 prim_ids_[axis].resize(bboxes.size());
0059 std::iota(prim_ids_[axis].begin(), prim_ids_[axis].end(), 0);
0060 std::sort(prim_ids_[axis].begin(), prim_ids_[axis].end(), [&] (size_t i, size_t j) {
0061 return centers[i][axis] < centers[j][axis];
0062 });
0063 }
0064 }
0065
0066 std::vector<size_t>& get_prim_ids() override { return prim_ids_[0]; }
0067
0068 void find_best_split(size_t axis, size_t begin, size_t end, Split& best_split) {
0069 size_t first_right = begin;
0070
0071
0072 auto right_bbox = BBox::make_empty();
0073 for (size_t i = end - 1; i > begin;) {
0074 static constexpr size_t chunk_size = 32;
0075 size_t next = i - std::min(i - begin, chunk_size);
0076 auto right_cost = static_cast<Scalar>(0.);
0077 for (; i > next; --i) {
0078 right_bbox.extend(bboxes_[prim_ids_[axis][i]]);
0079 accum_[i] = right_cost = config_.sah.get_leaf_cost(i, end, right_bbox);
0080 }
0081
0082 if (right_cost > best_split.cost) {
0083 first_right = i;
0084 break;
0085 }
0086 }
0087
0088
0089 auto left_bbox = BBox::make_empty();
0090 for (size_t i = begin; i < first_right; ++i)
0091 left_bbox.extend(bboxes_[prim_ids_[axis][i]]);
0092 for (size_t i = first_right; i < end - 1; ++i) {
0093 left_bbox.extend(bboxes_[prim_ids_[axis][i]]);
0094 auto left_cost = config_.sah.get_leaf_cost(begin, i + 1, left_bbox);
0095 auto cost = left_cost + accum_[i + 1];
0096 if (cost < best_split.cost)
0097 best_split = Split { i + 1, cost, axis };
0098 else if (left_cost > best_split.cost)
0099 break;
0100 }
0101 }
0102
0103 BVH_ALWAYS_INLINE void mark_primitives(size_t axis, size_t begin, size_t split_pos, size_t end) {
0104 for (size_t i = begin; i < split_pos; ++i) marks_[prim_ids_[axis][i]] = true;
0105 for (size_t i = split_pos; i < end; ++i) marks_[prim_ids_[axis][i]] = false;
0106 }
0107
0108 std::optional<size_t> try_split(const BBox& bbox, size_t begin, size_t end) override {
0109
0110 auto leaf_cost = config_.sah.get_non_split_cost(begin, end, bbox);
0111 auto best_split = Split { (begin + end + 1) / 2, leaf_cost, 0 };
0112 for (size_t axis = 0; axis < Node::dimension; ++axis)
0113 find_best_split(axis, begin, end, best_split);
0114
0115
0116 if (best_split.cost >= leaf_cost) {
0117 if (end - begin <= config_.max_leaf_size)
0118 return std::nullopt;
0119
0120
0121
0122 best_split.pos = (begin + end + 1) / 2;
0123 best_split.axis = bbox.get_diagonal().get_largest_axis();
0124 }
0125
0126
0127
0128 mark_primitives(best_split.axis, begin, best_split.pos, end);
0129 for (size_t axis = 0; axis < Node::dimension; ++axis) {
0130 if (axis == best_split.axis)
0131 continue;
0132 std::stable_partition(
0133 prim_ids_[axis].begin() + begin,
0134 prim_ids_[axis].begin() + end,
0135 [&] (size_t i) { return marks_[i]; });
0136 }
0137
0138 return std::make_optional(best_split.pos);
0139 }
0140 };
0141
0142 }
0143
0144 #endif