// PointKDTree.cs
  1. using System.Collections.Generic;
  2. namespace Pathfinding {
  3. using Pathfinding.Util;
  4. /// <summary>
  5. /// Represents a collection of GraphNodes.
  6. /// It allows for fast lookups of the closest node to a point.
  7. ///
  8. /// See: https://en.wikipedia.org/wiki/K-d_tree
  9. /// </summary>
  10. public class PointKDTree {
  11. // TODO: Make constant
  12. public const int LeafSize = 10;
  13. public const int LeafArraySize = LeafSize*2 + 1;
  14. Node[] tree = new Node[16];
  15. int numNodes = 0;
  16. readonly List<GraphNode> largeList = new List<GraphNode>();
  17. readonly Stack<GraphNode[]> arrayCache = new Stack<GraphNode[]>();
  18. static readonly IComparer<GraphNode>[] comparers = new IComparer<GraphNode>[] { new CompareX(), new CompareY(), new CompareZ() };
  19. struct Node {
  20. /// <summary>Nodes in this leaf node (null if not a leaf node)</summary>
  21. public GraphNode[] data;
  22. /// <summary>Split point along the <see cref="splitAxis"/> if not a leaf node</summary>
  23. public int split;
  24. /// <summary>Number of non-null entries in <see cref="data"/></summary>
  25. public ushort count;
  26. /// <summary>Axis to split along if not a leaf node (x=0, y=1, z=2)</summary>
  27. public byte splitAxis;
  28. }
  29. // Pretty ugly with one class for each axis, but it has been verified to make the tree around 5% faster
  30. class CompareX : IComparer<GraphNode> {
  31. public int Compare (GraphNode lhs, GraphNode rhs) { return lhs.position.x.CompareTo(rhs.position.x); }
  32. }
  33. class CompareY : IComparer<GraphNode> {
  34. public int Compare (GraphNode lhs, GraphNode rhs) { return lhs.position.y.CompareTo(rhs.position.y); }
  35. }
  36. class CompareZ : IComparer<GraphNode> {
  37. public int Compare (GraphNode lhs, GraphNode rhs) { return lhs.position.z.CompareTo(rhs.position.z); }
  38. }
  39. public PointKDTree() {
  40. tree[1] = new Node { data = GetOrCreateList() };
  41. }
  42. /// <summary>Add the node to the tree</summary>
  43. public void Add (GraphNode node) {
  44. numNodes++;
  45. Add(node, 1);
  46. }
  47. /// <summary>Rebuild the tree starting with all nodes in the array between index start (inclusive) and end (exclusive)</summary>
  48. public void Rebuild (GraphNode[] nodes, int start, int end) {
  49. if (start < 0 || end < start || end > nodes.Length)
  50. throw new System.ArgumentException();
  51. for (int i = 0; i < tree.Length; i++) {
  52. var data = tree[i].data;
  53. if (data != null) {
  54. for (int j = 0; j < LeafArraySize; j++) data[j] = null;
  55. arrayCache.Push(data);
  56. tree[i].data = null;
  57. }
  58. }
  59. numNodes = end - start;
  60. Build(1, new List<GraphNode>(nodes), start, end);
  61. }
  62. GraphNode[] GetOrCreateList () {
  63. // Note, the lists will never become larger than this initial capacity, so possibly they should be replaced by arrays
  64. return arrayCache.Count > 0 ? arrayCache.Pop() : new GraphNode[LeafArraySize];
  65. }
  66. int Size (int index) {
  67. return tree[index].data != null ? tree[index].count : Size(2 * index) + Size(2 * index + 1);
  68. }
  69. void CollectAndClear (int index, List<GraphNode> buffer) {
  70. var nodes = tree[index].data;
  71. var count = tree[index].count;
  72. if (nodes != null) {
  73. tree[index] = new Node();
  74. for (int i = 0; i < count; i++) {
  75. buffer.Add(nodes[i]);
  76. nodes[i] = null;
  77. }
  78. arrayCache.Push(nodes);
  79. } else {
  80. CollectAndClear(index*2, buffer);
  81. CollectAndClear(index*2 + 1, buffer);
  82. }
  83. }
  84. static int MaxAllowedSize (int numNodes, int depth) {
  85. // Allow a node to be 2.5 times as full as it should ideally be
  86. // but do not allow it to contain more than 3/4ths of the total number of nodes
  87. // (important to make sure nodes near the top of the tree also get rebalanced).
  88. // A node should ideally contain numNodes/(2^depth) nodes below it (^ is exponentiation, not xor)
  89. return System.Math.Min(((5 * numNodes) / 2) >> depth, (3 * numNodes) / 4);
  90. }
  91. void Rebalance (int index) {
  92. CollectAndClear(index, largeList);
  93. Build(index, largeList, 0, largeList.Count);
  94. largeList.ClearFast();
  95. }
  96. void EnsureSize (int index) {
  97. if (index >= tree.Length) {
  98. var newLeaves = new Node[System.Math.Max(index + 1, tree.Length*2)];
  99. tree.CopyTo(newLeaves, 0);
  100. tree = newLeaves;
  101. }
  102. }
  103. void Build (int index, List<GraphNode> nodes, int start, int end) {
  104. EnsureSize(index);
  105. if (end - start <= LeafSize) {
  106. var leafData = tree[index].data = GetOrCreateList();
  107. tree[index].count = (ushort)(end - start);
  108. for (int i = start; i < end; i++)
  109. leafData[i - start] = nodes[i];
  110. } else {
  111. Int3 mn, mx;
  112. mn = mx = nodes[start].position;
  113. for (int i = start; i < end; i++) {
  114. var p = nodes[i].position;
  115. mn = new Int3(System.Math.Min(mn.x, p.x), System.Math.Min(mn.y, p.y), System.Math.Min(mn.z, p.z));
  116. mx = new Int3(System.Math.Max(mx.x, p.x), System.Math.Max(mx.y, p.y), System.Math.Max(mx.z, p.z));
  117. }
  118. Int3 diff = mx - mn;
  119. var axis = diff.x > diff.y ? (diff.x > diff.z ? 0 : 2) : (diff.y > diff.z ? 1 : 2);
  120. nodes.Sort(start, end - start, comparers[axis]);
  121. int mid = (start+end)/2;
  122. tree[index].split = (nodes[mid-1].position[axis] + nodes[mid].position[axis] + 1)/2;
  123. tree[index].splitAxis = (byte)axis;
  124. Build(index*2 + 0, nodes, start, mid);
  125. Build(index*2 + 1, nodes, mid, end);
  126. }
  127. }
  128. void Add (GraphNode point, int index, int depth = 0) {
  129. // Move down in the tree until the leaf node is found that this point is inside of
  130. while (tree[index].data == null) {
  131. index = 2 * index + (point.position[tree[index].splitAxis] < tree[index].split ? 0 : 1);
  132. depth++;
  133. }
  134. // Add the point to the leaf node
  135. tree[index].data[tree[index].count++] = point;
  136. // Check if the leaf node is large enough that we need to do some rebalancing
  137. if (tree[index].count >= LeafArraySize) {
  138. int levelsUp = 0;
  139. // Search upwards for nodes that are too large and should be rebalanced
  140. // Rebalance the node above the node that had a too large size so that it can
  141. // move children over to the sibling
  142. while (depth - levelsUp > 0 && Size(index >> levelsUp) > MaxAllowedSize(numNodes, depth-levelsUp)) {
  143. levelsUp++;
  144. }
  145. Rebalance(index >> levelsUp);
  146. }
  147. }
  148. /// <summary>Closest node to the point which satisfies the constraint</summary>
  149. public GraphNode GetNearest (Int3 point, NNConstraint constraint) {
  150. GraphNode best = null;
  151. long bestSqrDist = long.MaxValue;
  152. GetNearestInternal(1, point, constraint, ref best, ref bestSqrDist);
  153. return best;
  154. }
  155. void GetNearestInternal (int index, Int3 point, NNConstraint constraint, ref GraphNode best, ref long bestSqrDist) {
  156. var data = tree[index].data;
  157. if (data != null) {
  158. for (int i = tree[index].count - 1; i >= 0; i--) {
  159. var dist = (data[i].position - point).sqrMagnitudeLong;
  160. if (dist < bestSqrDist && (constraint == null || constraint.Suitable(data[i]))) {
  161. bestSqrDist = dist;
  162. best = data[i];
  163. }
  164. }
  165. } else {
  166. var dist = (long)(point[tree[index].splitAxis] - tree[index].split);
  167. var childIndex = 2 * index + (dist < 0 ? 0 : 1);
  168. GetNearestInternal(childIndex, point, constraint, ref best, ref bestSqrDist);
  169. // Try the other one if it is possible to find a valid node on the other side
  170. if (dist*dist < bestSqrDist) {
  171. // childIndex ^ 1 will flip the last bit, so if childIndex is odd, then childIndex ^ 1 will be even
  172. GetNearestInternal(childIndex ^ 0x1, point, constraint, ref best, ref bestSqrDist);
  173. }
  174. }
  175. }
  176. /// <summary>Closest node to the point which satisfies the constraint</summary>
  177. public GraphNode GetNearestConnection (Int3 point, NNConstraint constraint, long maximumSqrConnectionLength) {
  178. GraphNode best = null;
  179. long bestSqrDist = long.MaxValue;
  180. // Given a found point at a distance of r world units
  181. // then any node that has a connection on which a closer point lies must have a squared distance lower than
  182. // d^2 < (maximumConnectionLength/2)^2 + r^2
  183. // Note: (x/2)^2 = (x^2)/4
  184. // Note: (x+3)/4 to round up
  185. long offset = (maximumSqrConnectionLength+3)/4;
  186. GetNearestConnectionInternal(1, point, constraint, ref best, ref bestSqrDist, offset);
  187. return best;
  188. }
  189. void GetNearestConnectionInternal (int index, Int3 point, NNConstraint constraint, ref GraphNode best, ref long bestSqrDist, long distanceThresholdOffset) {
  190. var data = tree[index].data;
  191. if (data != null) {
  192. var pointv3 = (UnityEngine.Vector3)point;
  193. for (int i = tree[index].count - 1; i >= 0; i--) {
  194. var dist = (data[i].position - point).sqrMagnitudeLong;
  195. // Note: the subtraction is important. If we used an addition on the RHS instead the result might overflow as bestSqrDist starts as long.MaxValue
  196. if (dist - distanceThresholdOffset < bestSqrDist && (constraint == null || constraint.Suitable(data[i]))) {
  197. // This node may contains the closest connection
  198. // Check all connections
  199. var conns = (data[i] as PointNode).connections;
  200. if (conns != null) {
  201. var nodePos = (UnityEngine.Vector3)data[i].position;
  202. for (int j = 0; j < conns.Length; j++) {
  203. // Find the closest point on the connection, but only on this node's side of the connection
  204. // This ensures that we will find the closest node with the closest connection.
  205. var connectionMidpoint = ((UnityEngine.Vector3)conns[j].node.position + nodePos) * 0.5f;
  206. float sqrConnectionDistance = VectorMath.SqrDistancePointSegment(nodePos, connectionMidpoint, pointv3);
  207. // Convert to Int3 space
  208. long sqrConnectionDistanceInt = (long)(sqrConnectionDistance*Int3.FloatPrecision*Int3.FloatPrecision);
  209. if (sqrConnectionDistanceInt < bestSqrDist) {
  210. bestSqrDist = sqrConnectionDistanceInt;
  211. best = data[i];
  212. }
  213. }
  214. }
  215. // Also check if the node itself is close enough.
  216. // This is important if the node has no connections at all.
  217. if (dist < bestSqrDist) {
  218. bestSqrDist = dist;
  219. best = data[i];
  220. }
  221. }
  222. }
  223. } else {
  224. var dist = (long)(point[tree[index].splitAxis] - tree[index].split);
  225. var childIndex = 2 * index + (dist < 0 ? 0 : 1);
  226. GetNearestConnectionInternal(childIndex, point, constraint, ref best, ref bestSqrDist, distanceThresholdOffset);
  227. // Try the other one if it is possible to find a valid node on the other side
  228. // Note: the subtraction is important. If we used an addition on the RHS instead the result might overflow as bestSqrDist starts as long.MaxValue
  229. if (dist*dist - distanceThresholdOffset < bestSqrDist) {
  230. // childIndex ^ 1 will flip the last bit, so if childIndex is odd, then childIndex ^ 1 will be even
  231. GetNearestConnectionInternal(childIndex ^ 0x1, point, constraint, ref best, ref bestSqrDist, distanceThresholdOffset);
  232. }
  233. }
  234. }
  235. /// <summary>Add all nodes within a squared distance of the point to the buffer.</summary>
  236. /// <param name="point">Nodes around this point will be added to the buffer.</param>
  237. /// <param name="sqrRadius">squared maximum distance in Int3 space. If you are converting from world space you will need to multiply by Int3.Precision:
  238. /// <code> var sqrRadius = (worldSpaceRadius * Int3.Precision) * (worldSpaceRadius * Int3.Precision); </code></param>
  239. /// <param name="buffer">All nodes will be added to this list.</param>
  240. public void GetInRange (Int3 point, long sqrRadius, List<GraphNode> buffer) {
  241. GetInRangeInternal(1, point, sqrRadius, buffer);
  242. }
  243. void GetInRangeInternal (int index, Int3 point, long sqrRadius, List<GraphNode> buffer) {
  244. var data = tree[index].data;
  245. if (data != null) {
  246. for (int i = tree[index].count - 1; i >= 0; i--) {
  247. var dist = (data[i].position - point).sqrMagnitudeLong;
  248. if (dist < sqrRadius) {
  249. buffer.Add(data[i]);
  250. }
  251. }
  252. } else {
  253. var dist = (long)(point[tree[index].splitAxis] - tree[index].split);
  254. // Pick the first child to enter based on which side of the splitting line the point is
  255. var childIndex = 2 * index + (dist < 0 ? 0 : 1);
  256. GetInRangeInternal(childIndex, point, sqrRadius, buffer);
  257. // Try the other one if it is possible to find a valid node on the other side
  258. if (dist*dist < sqrRadius) {
  259. // childIndex ^ 1 will flip the last bit, so if childIndex is odd, then childIndex ^ 1 will be even
  260. GetInRangeInternal(childIndex ^ 0x1, point, sqrRadius, buffer);
  261. }
  262. }
  263. }
  264. }
  265. }