This question is available on both LeetCode (No. 272, Closest Binary Search Tree Value II) and LintCode (No. 901). It is not too hard if we do not chase the best time complexity. For example, we can traverse all the nodes and keep track of the k closest values.
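For reference, here is a minimal sketch of that brute-force idea in Python. It is my own illustration, not part of the solution later in this post. `TreeNode` mirrors the judge's node type, and `closest_k_values_brute_force` is a hypothetical name. It visits all N nodes and lets `heapq.nsmallest` keep the k values closest to the target, so it runs in O(N log k).

```python
import heapq
from typing import Iterator, List, Optional


class TreeNode:
    # Minimal node type, mirroring the TreeNode used by the online judges.
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def all_values(root: Optional[TreeNode]) -> Iterator[int]:
    """Yield every value in the tree (iterative pre-order, order irrelevant)."""
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        yield node.val
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)


def closest_k_values_brute_force(root: Optional[TreeNode], target: float, k: int) -> List[int]:
    """O(N log k): visit all N nodes, keep the k values nearest to target."""
    # heapq.nsmallest maintains a bounded heap of size k while scanning.
    return heapq.nsmallest(k, all_values(root), key=lambda v: abs(v - target))
```

This is simple and correct, but it always touches every node, which is exactly what the follow-up question asks us to avoid.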

The challenge is the follow-up: “Assume that the BST is balanced, could you solve it in less than O(n) runtime (where n = total nodes)?” The answer is yes: we can do better than O(n) and reach O(log N + k). However, it took me quite a few days to figure out the right solution.
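A rough cost breakdown behind the O(log N + k) bound, assuming a balanced tree of height h = O(log N). This is my own accounting of the generator-based approach in this post, not an official analysis. The key fact is that each lazy in-order step pushes and pops every node at most once, so k steps cost O(k) on top of the at most h nodes already sitting on the stacks.

```latex
\begin{aligned}
T(N, k) &= \underbrace{O(h)}_{\text{search path to the closest node}}
         + \underbrace{O(h)}_{\text{split the path into two stacks}}
         + \underbrace{O(h + k)}_{\text{$k$ lazy in-order steps}} \\
        &= O(\log N + k) \qquad \text{when } h = O(\log N).
\end{aligned}
```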

When I got stuck on this problem, I searched Google for other blogs. Unfortunately, among the pages I visited, some so-called O(log N + k) solutions are essentially O(N + k). For example, some solutions prepare all the predecessors and successors in functions like getPredecessor(TreeNode root, double target, Stack predecessor), where the function fills the stack with every qualified node. Since every node in the tree is either a predecessor or a successor of the target, filling the two stacks touches essentially the whole tree, so these functions are O(N).
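To make that O(N) cost concrete, here is a Python sketch of the eager pattern. `get_predecessors_eager` and `build_balanced` are hypothetical helpers modeled on the description above, not code taken from any particular blog. The function returns how many nodes it visited, so we can see that even on a perfectly balanced tree of 1023 nodes it walks all of them when the target is large, and still about half of them when the target sits in the middle.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(values: List[int]) -> Optional[TreeNode]:
    """Build a height-balanced BST from a sorted list of values."""
    if not values:
        return None
    mid = len(values) // 2
    return TreeNode(values[mid], build_balanced(values[:mid]), build_balanced(values[mid + 1:]))


def get_predecessors_eager(root: Optional[TreeNode], target: float, out: List[int]) -> int:
    """Hypothetical eager fill: append every value <= target in ascending order.

    Returns the number of nodes visited. The whole left part of the tree is
    walked before anything is returned, which is where the O(N) comes from.
    """
    if root is None:
        return 0
    visited = 1 + get_predecessors_eager(root.left, target, out)
    if root.val > target:
        return visited  # everything further right is larger than target
    out.append(root.val)
    return visited + get_predecessors_eager(root.right, target, out)


root = build_balanced(list(range(1023)))  # 1023 nodes, height 10
preds = []
visited = get_predecessors_eager(root, 511.5, preds)
print(len(preds), visited)
```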

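In fact, the key is how to get the predecessors and successors without visiting the whole tree. The answer is a variant of the iterative in-order traversal: resume it from the stack left behind by the search path, and produce one node at a time, only when asked.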
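A minimal, self-contained sketch of that idea with Python generators. The names `successors` and `predecessors` are my own, and this version seeds the stacks directly while searching for the target instead of trimming the path afterwards as the full solution does; both approaches leave the same nodes on the stack. Seeding costs O(h), and every later `next()` call is an ordinary iterative in-order step, amortized O(1).

```python
from typing import Iterator, List, Optional


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def successors(root: Optional[TreeNode], target: float) -> Iterator[int]:
    """Lazily yield the values > target in ascending order."""
    stack: List[TreeNode] = []
    node = root
    # Keep only the ancestors larger than target: an in-order walk
    # has not emitted them yet.
    while node:
        if node.val > target:
            stack.append(node)
            node = node.left
        else:
            node = node.right
    while stack:
        node = stack.pop()
        yield node.val
        # Standard in-order step: push the leftmost path of the right subtree.
        node = node.right
        while node:
            stack.append(node)
            node = node.left


def predecessors(root: Optional[TreeNode], target: float) -> Iterator[int]:
    """Lazily yield the values <= target in descending order (mirror image)."""
    stack: List[TreeNode] = []
    node = root
    while node:
        if node.val <= target:
            stack.append(node)
            node = node.right
        else:
            node = node.left
    while stack:
        node = stack.pop()
        yield node.val
        # Reverse in-order step: push the rightmost path of the left subtree.
        node = node.left
        while node:
            stack.append(node)
            node = node.right
```

Because nothing is computed until `next()` is called, taking k items from these generators never forces a walk over the rest of the tree.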

```python
from math import isclose


class TreeNode:
    # Definition of TreeNode. Lintcode provides it; it is repeated here so
    # the file also runs standalone.
    def __init__(self, val):
        self.val = val
        self.left, self.right = None, None


def Previous(input_stack):
    # Reverse in-order traversal of the binary tree: each step yields the
    # previous node's value (the next smaller value).
    stack = list(input_stack)  # Shallow copy: the caller's list is untouched
    current = None
    while stack or current:
        while current:
            stack.append(current)
            current = current.right
        current = stack.pop()
        yield current.val
        current = current.left


def Next(input_stack):
    # In-order traversal of the binary tree: each step yields the next
    # node's value (the next larger value).
    stack = list(input_stack)  # Shallow copy: the caller's list is untouched
    current = None
    while stack or current:
        while current:
            stack.append(current)
            current = current.left
        current = stack.pop()
        yield current.val
        current = current.right


class Solution:
    """
    @param root: the given BST
    @param target: the given target
    @param k: the given k
    @return: k values in the BST that are closest to the target
    """
    def closestKValues(self, root, target, k):
        if root is None:
            return []

        # Find the node with the closest value.
        # If the BST is balanced, this takes O(log N) time.
        stack = []
        current = root
        min_diff = float("inf")
        min_diff_node = None
        while current:
            stack.append(current)
            if abs(current.val - target) < min_diff:
                min_diff = abs(current.val - target)
                min_diff_node = current
            if isclose(current.val, target):  # Found the target itself
                break
            elif current.val > target:
                current = current.left
            else:
                current = current.right

        # Pop back up the path until the closest node is on top.
        while stack[-1] is not min_diff_node:
            stack.pop()

        # In-order traversal stack: drop every ancestor we turned right at,
        # so the values decrease from bottom to top.
        next_stack = []
        for node in stack:
            if next_stack and node.val >= next_stack[-1].val:
                next_stack.pop()
            next_stack.append(node)

        # Reverse in-order traversal stack: drop every ancestor we turned
        # left at, so the values increase from bottom to top.
        prev_stack = []
        for node in stack:
            if prev_stack and node.val < prev_stack[-1].val:
                prev_stack.pop()
            prev_stack.append(node)

        # Both generators start with the closest node. Skip it on the "next"
        # side so that it is counted exactly once.
        next_generator = Next(next_stack)
        next(next_generator)
        pre_generator = Previous(prev_stack)
        next_item = next(next_generator, None)
        pre_item = next(pre_generator, None)

        # Merge the two sorted streams in O(k) steps. The generators are the
        # key point: we never traverse the whole tree.
        result = []
        # Since k <= N is guaranteed, `len(result) < k` alone would suffice;
        # the None checks are only a safety net.
        while len(result) < k and (next_item is not None or pre_item is not None):
            # Compare against None explicitly: a value of 0 is falsy.
            if next_item is None:
                result.append(pre_item)
                pre_item = next(pre_generator, None)
            elif pre_item is None:
                result.append(next_item)
                next_item = next(next_generator, None)
            elif abs(next_item - target) < abs(pre_item - target):
                result.append(next_item)
                next_item = next(next_generator, None)
            else:
                result.append(pre_item)
                pre_item = next(pre_generator, None)
        return result
```
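The least obvious step in the solution is turning the search path into `next_stack` and `prev_stack`. Below is a standalone sketch of the same split. `split_path` is a hypothetical helper, and it decides the direction of each turn by node identity (`child is parent.left`) rather than by comparing values as the solution does. The two are equivalent when all values in the BST are distinct.

```python
from typing import List, Tuple


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_path(path: List[TreeNode]) -> Tuple[List[TreeNode], List[TreeNode]]:
    """Split a root-to-node search path into the two traversal stacks.

    next_stack keeps the ancestors we turned LEFT at (an in-order walk has not
    emitted them yet); prev_stack keeps the ancestors we turned RIGHT at (a
    reverse in-order walk has not emitted them yet). The last node of the path
    goes on top of both stacks.
    """
    next_stack: List[TreeNode] = []
    prev_stack: List[TreeNode] = []
    for parent, child in zip(path, path[1:]):
        if child is parent.left:
            next_stack.append(parent)
        else:
            prev_stack.append(parent)
    if path:
        next_stack.append(path[-1])
        prev_stack.append(path[-1])
    return next_stack, prev_stack


# Example: the search path 10 -> 5 -> 7 -> 8 turns left, right, right.
n10, n5, n7, n8 = TreeNode(10), TreeNode(5), TreeNode(7), TreeNode(8)
n10.left, n5.right, n7.right = n5, n7, n8
next_stack, prev_stack = split_path([n10, n5, n7, n8])
print([n.val for n in next_stack])  # [10, 8]
print([n.val for n in prev_stack])  # [5, 7, 8]
```

Read from bottom to top, `next_stack` decreases and `prev_stack` increases, which matches the invariants stated in the comments of the solution.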