Segment Treeを使用して部分和を計算する。
Segment Tree とは
Segment Tree
はバイナリツリー(二分木)を使用して区間やあるリストの範囲を貯蔵するに使われるデータ構造である。Segment Tree
は実装するリストの長さに対していつも平衡二分探索木である。なので探索に最悪O(lgN)
がかかる。Segment Tree
を使用すれば、クエリーを用いてある特定の区間の部分和をO(lgN + K)
で求めることができる。K
とは求められる区間の部分区間の個数を指す。
Segment Tree | 数式 |
---|---|
要素の数 | |
ツリーの深さ | |
ツリーのメモリ | |
ツリーの空間複雑度 | |
生成の時間複雑度 | |
部分和の時間複雑度 |
になる。
Segment Tree
はリストが長く、そしてリストの要素の数が固定されて頻繁にリストの要素が値が切り替わるときに有用になる。
コード
Segment Tree 実装
- Segment Treeの実装に注意する点は、インプットとして入れたリストの要素の位置はツリーの葉部分であることだ。
- そして葉が生成される位置は、ツリーの左から寄せて生成するのではなく再帰に沿って仮想の完全二分木の葉の部分かそれとも葉の親部分を葉として扱ってそれらからの一つに葉が生成される。なので
Segment Tree
を生成する時には、空白の葉ノードが必要となる。
using TSegmentTree = std::vector<int>; using TValueList = std::vector<int>; TSegmentTree::value_type MakeSegmentTree( const TValueList& valueList, TSegmentTree& segmentTree, const unsigned nodeIndex, const unsigned valueStart, const unsigned valueEnd) { if (valueStart == valueEnd) { segmentTree[nodeIndex] = valueList[valueStart]; } else { const auto sep = (valueStart + valueEnd) / 2; segmentTree[nodeIndex] = MakeSegmentTree(valueList, segmentTree, 2 * nodeIndex + 1, valueStart, sep) + MakeSegmentTree(valueList, segmentTree, 2 * nodeIndex + 2, sep + 1, valueEnd); } return segmentTree[nodeIndex]; }
部分和の関数実装
- 部分和を実装する関数を作る時には、区間の範囲だけでなくツリーの範囲まで入れないとならない。何故なら再帰によって既に計算された部分区間の和が入った親ノードを選ぶかないかを決定しなければならない為だ。
- もし再帰で今求められる区間が部分ツリーが持っている葉の要素のインデックスの範囲の相応していたら、今指している親ノードの部分和の値を返す。
- それとも区間が外されている場合にはまた掘り下げて葉まで下る。
- 区間が範囲の完全に外していたらその部分ツリーでは値を求む必要がないのでを返す。
[[nodiscard]] TSegmentTree::value_type GetPartialSum( const TSegmentTree& segmentTree, const unsigned targetStart, const unsigned targetEnd, const unsigned treeStart, const unsigned treeEnd, const unsigned id = 0) { // Special case if (targetEnd < treeStart || treeEnd < targetStart) { return 0; } if (targetStart <= treeStart && treeEnd <= targetEnd) { return segmentTree[id]; } // General case const auto sep = (treeStart + treeEnd) / 2; return GetPartialSum(segmentTree, targetStart, targetEnd, treeStart, sep, 2 * id + 1) + GetPartialSum(segmentTree, targetStart, targetEnd, sep + 1, treeEnd, 2 * id + 2); }
値切り替えの関数実装
- 値を切り替える場合にも再帰を積極的に使う。
- ただし、切り替えをで終わらせてまた
Constant Factor
を減らすためには今切り替えをする値に対し、既存の値との差を求めてその値が入る予定の葉までの親ノードに入っている部分和の値を新たに演算しなければいけない。
void UpdateValue( TSegmentTree& segmentTree, const unsigned valueId, const int difference, const unsigned treeStart, const unsigned treeEnd, const unsigned treeId = 0) { // Special case if (valueId < treeStart || valueId > treeEnd) { return; } // General case segmentTree[treeId] += difference; if (treeStart == treeEnd) { return; } const auto sep = (treeStart + treeEnd) / 2; UpdateValue(segmentTree, valueId, difference, treeStart, sep, 2 * treeId + 1); UpdateValue(segmentTree, valueId, difference, sep + 1, treeEnd, 2 * treeId + 2); }
結果
// 最初のインプットリスト 5 6 8 3 1 3 1 4 2 4 // Segment Treeの各ノードの値 37 23 14 19 4 8 6 11 8 3 1 4 4 2 4 5 6 0 0 0 0 0 0 3 1 0 0 0 0 0 0 // 部分和の結果リスト Start : 0 ~ End : 0 = Sum : 5 Start : 0 ~ End : 1 = Sum : 11 ... Start : 0 ~ End : 7 = Sum : 31 Start : 0 ~ End : 8 = Sum : 33 Start : 0 ~ End : 9 = Sum : 37 Start : 1 ~ End : 9 = Sum : 32 Start : 2 ~ End : 9 = Sum : 26 ... Start : 8 ~ End : 9 = Sum : 6 Start : 9 ~ End : 9 = Sum : 4 // 新たな値リスト 17 -34 -59 71 86 25 83 33 53 -89 // 更新したSegment Treeの各ノードの値 186 81 105 -76 157 141 -36 -17 -59 71 86 108 33 53 -89 17 -34 0 0 0 0 0 0 25 83 0 0 0 0 0 0
コード全文
#include <cstdio> #include <vector> #include <algorithm> #include <utility> #include "../../Common/CheckTime.hpp" using TSegmentTree = std::vector<int>; using TValueList = std::vector<int>; TSegmentTree::value_type MakeSegmentTree( const TValueList& valueList, TSegmentTree& segmentTree, const unsigned nodeIndex, const unsigned valueStart, const unsigned valueEnd) { if (valueStart == valueEnd) { segmentTree[nodeIndex] = valueList[valueStart]; return segmentTree[nodeIndex]; } else { const auto sep = (valueStart + valueEnd) / 2; segmentTree[nodeIndex] = MakeSegmentTree(valueList, segmentTree, 2 * nodeIndex + 1, valueStart, sep) + MakeSegmentTree(valueList, segmentTree, 2 * nodeIndex + 2, sep + 1, valueEnd); return segmentTree[nodeIndex]; } } [[nodiscard]] TSegmentTree::value_type GetPartialSum( const TSegmentTree& segmentTree, const unsigned targetStart, const unsigned targetEnd, const unsigned treeStart, const unsigned treeEnd, const unsigned id = 0) { // Special case if (targetEnd < treeStart || treeEnd < targetStart) { return 0; } if (targetStart <= treeStart && treeEnd <= targetEnd) { return segmentTree[id]; } // General case const auto sep = (treeStart + treeEnd) / 2; return GetPartialSum(segmentTree, targetStart, targetEnd, treeStart, sep, 2 * id + 1) + GetPartialSum(segmentTree, targetStart, targetEnd, sep + 1, treeEnd, 2 * id + 2); } void UpdateValue( TSegmentTree& segmentTree, const unsigned valueId, const int difference, const unsigned treeStart, const unsigned treeEnd, const unsigned treeId = 0) { // Special case if (valueId < treeStart || valueId > treeEnd) { return; } // General case segmentTree[treeId] += difference; if (treeStart == treeEnd) { return; } const auto sep = (treeStart + treeEnd) / 2; UpdateValue(segmentTree, valueId, difference, treeStart, sep, 2 * treeId + 1); UpdateValue(segmentTree, valueId, difference, sep + 1, treeEnd, 2 * treeId + 2); } int main() { // Make valueList; std::vector<int> valueList = {}; for (int i = 0; i < 10; ++i) { valueList.emplace_back(neu::test::GetRandomValue<int>(1, 10)); } // Print valueList; for (const auto& value : valueList) { std::printf("%d ", value); } std::puts(""); // Make segment Tree; const auto height = static_cast<unsigned>(std::ceil(std::log2(valueList.size()))) + 1; std::vector<int> segmentTree((1 << height) - 1, 0); MakeSegmentTree(valueList, segmentTree, 0, 0, valueList.size() - 1); // Pritn segment Tree. for (const auto& value : segmentTree) { std::printf("%d ", value); } std::puts(""); // Get Sum from segmentTree rapidly O(lgN) unsigned start = 0; unsigned end = 0; while (end <= valueList.size() - 1) { std::printf("Start : %u ~ End : %u = Sum : %d\n", start, end, GetPartialSum(segmentTree, start, end, 0, valueList.size() - 1) ); ++end; } end--; while (start <= end) { std::printf("Start : %u ~ End : %u = Sum : %d\n", start, end, GetPartialSum(segmentTree, start, end, 0, valueList.size() - 1) ); ++start; } // Update value and get overall sum. std::vector<int> diff(valueList.size()); for (auto i = 0u; i < valueList.size(); ++i) { const auto newValue = (neu::test::GetRandomValue<int>(-100, 100)); diff[i] = newValue - valueList[i]; valueList[i] = newValue; } // Print valueList; for (const auto& value : valueList) { std::printf("%d ", value); } std::puts(""); for (auto i = 0u; i < valueList.size(); ++i) { UpdateValue(segmentTree, i, diff[i], 0, valueList.size() - 1); } // Print segment Tree. for (const auto& value : segmentTree) { std::printf("%d ", value); } std::puts(""); };
参照リンク
Segment tree - Wikipedia 세그먼트 트리(Segment Tree) 요약 정리 (C++) – Jang