-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathKdTree.java
More file actions
273 lines (256 loc) · 10.2 KB
/
KdTree.java
File metadata and controls
273 lines (256 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;
import java.util.LinkedList;
/**
 * Mutable set of points in the unit square, implemented with a 2d-tree: a BST whose
 * comparison key alternates between the x-coordinate (root level and every other level,
 * "vertical" splits) and the y-coordinate (the levels in between, "horizontal" splits).
 *
 * <p>Every node also stores the axis-aligned rectangle of the region its subtree covers,
 * which lets {@link #range(RectHV)} and {@link #nearest(Point2D)} skip whole subtrees.
 * Points are assumed to lie in the unit square (the root's region); behaviour for points
 * outside it is not defined by this class.
 */
public class KdTree {
    /** Region covered by the root; every other node's rectangle is a sub-rectangle of it. */
    private static final RectHV UNIT_SQUARE = new RectHV(0.0, 0.0, 1.0, 1.0);

    /** One stored point plus the (closed) region of the plane its subtree covers. */
    private static class Node {
        private final Point2D pt;
        private final RectHV rect;
        private Node lessNode;    // points whose split coordinate is strictly smaller
        private Node greaterNode; // points whose split coordinate is equal or larger

        Node(Point2D pt, RectHV rect) {
            this.pt = pt;
            this.rect = rect;
        }
    }

    private Node root; // null while the set is empty
    private int treeSize = 0;

    /** Constructs an empty set of points. */
    public KdTree() {
        // root stays null until the first insert
    }

    /**
     * Returns whether the set is empty.
     *
     * @return {@code true} if no point has been inserted
     */
    public boolean isEmpty() {
        return treeSize == 0;
    }

    /**
     * Returns the number of distinct points in the set.
     *
     * @return the point count
     */
    public int size() {
        return treeSize;
    }

    /**
     * Adds {@code p} to the set; does nothing if an equal point is already present.
     *
     * @param p the point to add
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        requireNonNull(p);
        if (root == null) {
            root = new Node(p, UNIT_SQUARE);
            treeSize++;
            return;
        }
        // Single descent: either find an equal point (duplicate) or attach at the first empty link.
        Node node = root;
        boolean isVertical = true;
        while (!node.pt.equals(p)) {
            boolean goLess = isLess(p, node.pt, isVertical);
            Node child = goLess ? node.lessNode : node.greaterNode;
            if (child == null) {
                Node added = new Node(p, childRect(node, goLess, isVertical));
                if (goLess) {
                    node.lessNode = added;
                } else {
                    node.greaterNode = added;
                }
                treeSize++;
                return;
            }
            node = child;
            isVertical = !isVertical;
        }
        // Reaching here means p is already stored; a set ignores duplicates.
    }

    /**
     * Returns whether the set contains a point equal to {@code p}.
     *
     * @param p the point to look up
     * @return {@code true} if {@code p} is in the set
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        requireNonNull(p);
        Node node = root;
        boolean isVertical = true;
        while (node != null) {
            if (node.pt.equals(p)) {
                return true;
            }
            // Follow the same split rule insert() used, so only one root-to-leaf path is checked.
            node = isLess(p, node.pt, isVertical) ? node.lessNode : node.greaterNode;
            isVertical = !isVertical;
        }
        return false;
    }

    /**
     * Draws every point in black and every splitting line (red for vertical splits, blue for
     * horizontal splits) to standard draw. Draws nothing when the set is empty.
     */
    public void draw() {
        draw(root, true);
    }

    /** In-order traversal that draws the point and splitting segment of each node. */
    private void draw(Node node, boolean isVertical) {
        if (node == null) {
            return;
        }
        draw(node.lessNode, !isVertical);

        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        node.pt.draw();

        StdDraw.setPenRadius();
        Point2D start;
        Point2D end;
        if (isVertical) {
            StdDraw.setPenColor(StdDraw.RED);
            start = new Point2D(node.pt.x(), node.rect.ymin());
            end = new Point2D(node.pt.x(), node.rect.ymax());
        } else {
            StdDraw.setPenColor(StdDraw.BLUE);
            start = new Point2D(node.rect.xmin(), node.pt.y());
            end = new Point2D(node.rect.xmax(), node.pt.y());
        }
        start.drawTo(end);

        draw(node.greaterNode, !isVertical);
    }

    /**
     * Returns all points inside {@code rect} or on its boundary, in tree preorder.
     *
     * @param rect the query rectangle
     * @return the matching points (empty if none)
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        requireNonNull(rect);
        LinkedList<Point2D> found = new LinkedList<>();
        collectRange(root, rect, found);
        return found;
    }

    /** Appends the points of {@code node}'s subtree lying in {@code query}, pruning disjoint regions. */
    private void collectRange(Node node, RectHV query, LinkedList<Point2D> found) {
        // A subtree whose region misses the query cannot contribute any point.
        if (node == null || !query.intersects(node.rect)) {
            return;
        }
        if (query.contains(node.pt)) {
            found.add(node.pt);
        }
        collectRange(node.lessNode, query, found);
        collectRange(node.greaterNode, query, found);
    }

    /**
     * Returns a point of the set closest to {@code p} (any one of them on ties).
     *
     * @param p the query point
     * @return a nearest neighbour, or {@code null} if the set is empty
     * @throws IllegalArgumentException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        requireNonNull(p);
        if (root == null) {
            return null;
        }
        // Seeding with a real point avoids any "infinitely far" sentinel.
        return nearest(root, p, root.pt, true);
    }

    /**
     * Returns the closer of {@code best} and the nearest point in {@code node}'s subtree.
     * The subtree on the query's side of the split is searched first so that the champion
     * shrinks early and the far side is more likely to be pruned.
     */
    private Point2D nearest(Node node, Point2D query, Point2D best, boolean isVertical) {
        if (node == null) {
            return best;
        }
        double bestDistance = best.distanceSquaredTo(query);
        // Nothing in this region can beat `best` if even the region's closest point cannot.
        if (node.rect.distanceSquaredTo(query) >= bestDistance) {
            return best;
        }
        Point2D champion = best;
        if (node.pt.distanceSquaredTo(query) < bestDistance) {
            champion = node.pt;
        }
        boolean queryIsLess = isLess(query, node.pt, isVertical);
        Node nearSide = queryIsLess ? node.lessNode : node.greaterNode;
        Node farSide = queryIsLess ? node.greaterNode : node.lessNode;
        champion = nearest(nearSide, query, champion, !isVertical);
        return nearest(farSide, query, champion, !isVertical);
    }

    /** Split rule shared by insert/contains/nearest: compare x on vertical levels, y otherwise. */
    private static boolean isLess(Point2D p, Point2D splitter, boolean isVertical) {
        return isVertical ? p.x() < splitter.x() : p.y() < splitter.y();
    }

    /** Region of the {@code less} (or greater) child of {@code parent}, cut at the parent's point. */
    private static RectHV childRect(Node parent, boolean less, boolean isVertical) {
        RectHV r = parent.rect;
        if (isVertical) {
            return less
                    ? new RectHV(r.xmin(), r.ymin(), parent.pt.x(), r.ymax())
                    : new RectHV(parent.pt.x(), r.ymin(), r.xmax(), r.ymax());
        }
        return less
                ? new RectHV(r.xmin(), r.ymin(), r.xmax(), parent.pt.y())
                : new RectHV(r.xmin(), parent.pt.y(), r.xmax(), r.ymax());
    }

    /** Rejects null arguments with the message used by every public method. */
    private static void requireNonNull(Object arg) {
        if (arg == null) {
            throw new IllegalArgumentException("Argument can not be null");
        }
    }

    /**
     * Small smoke test: builds a tree and prints the nearest neighbour of (0.375, 0.75).
     * See also the course-provided KdTreeGenerator, KdTreeVisualizer and
     * NearestNeighborVisualizer clients.
     */
    public static void main(String[] args) {
        KdTree testPointTree = new KdTree();
        testPointTree.insert(new Point2D(1.0, 0.75));
        testPointTree.insert(new Point2D(0.75, 0.0));
        testPointTree.insert(new Point2D(0.375, 0.12));
        testPointTree.insert(new Point2D(0.375, 0.0));
        testPointTree.insert(new Point2D(0.75, 0.125));
        testPointTree.insert(new Point2D(1.0, 0.875));
        testPointTree.insert(new Point2D(0.0, 0.125));
        testPointTree.insert(new Point2D(0.25, 0.625));
        testPointTree.insert(new Point2D(0.25, 1.0));
        testPointTree.insert(new Point2D(0.5, 0.75));
        testPointTree.insert(new Point2D(0.0, 1.0));
        testPointTree.insert(new Point2D(0.375, 0.87));
        testPointTree.insert(new Point2D(0.0, 0.5));
        testPointTree.insert(new Point2D(0.25, 0.875));
        testPointTree.insert(new Point2D(0.0, 0.625));
        Point2D nearestPoint = testPointTree.nearest(new Point2D(0.375, 0.75));
        System.out.printf("%f, %f", nearestPoint.x(), nearestPoint.y());
    }
}