Find k’th smallest node in a BST
Given a BST and a positive number k, find the k'th smallest node in it.
For example, the 4th smallest node in the following BST is 15, and the 6th smallest is 20. The 8th smallest node does not exist since the tree has only seven nodes.

           15
         /    \
       10      20
      /  \    /  \
     8   12  16   25

The idea is to traverse the BST in an inorder fashion since the inorder traversal visits the nodes of a BST in the sorted order. Maintain a counter along with recursion that keeps track of the visited nodes, and when that counter reaches k, return that node.
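The sorted-order property that this idea relies on can be checked with a short sketch. It is not the counter-based solution from this article: it is a generator-based variant, and the names `inorder`, `kth_smallest_lazy`, and `build_example_tree` are illustrative choices made here. It assumes the seven-node example tree rooted at 15.

```python
from itertools import islice


class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def inorder(root):
    """Yield the nodes of a BST in ascending key order."""
    if root is not None:
        yield from inorder(root.left)
        yield root
        yield from inorder(root.right)


def kth_smallest_lazy(root, k):
    """Return the k'th node of the inorder sequence, or None if it is missing."""
    if k < 1:
        return None
    # islice skips the first k - 1 nodes; next() then pulls the k'th one, if any
    return next(islice(inorder(root), k - 1, None), None)


def build_example_tree():
    # The example BST used in this article (assumed shape)
    return Node(15,
                Node(10, Node(8), Node(12)),
                Node(20, Node(16), Node(25)))


if __name__ == '__main__':
    root = build_example_tree()
    print([node.data for node in inorder(root)])
    for k in (4, 6, 8):
        node = kth_smallest_lazy(root, k)
        print(k, node.data if node else None)
```

Because the generator is lazy, the traversal stops once the k'th node has been produced, which mirrors the early exit of the counter-based approach.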
The algorithm can be implemented as follows in C, Java, and Python. The code is optimized to visit the right subtree only when the k'th smallest is not found in the left subtree.
C
    #include <stdio.h>
    #include <stdlib.h>

    // Data structure to store a binary tree node
    struct Node
    {
        int data;
        struct Node *left, *right;
    };

    // Function to create a new binary tree node having a given key
    struct Node* newNode(int key)
    {
        struct Node* node = (struct Node*)malloc(sizeof(struct Node));
        node->data = key;
        node->left = node->right = NULL;

        return node;
    }

    // Recursive function to find the k'th smallest node
    // in the BST (using inorder traversal)
    struct Node* kthSmallest(struct Node* root, int* counter, int k)
    {
        // base case
        if (root == NULL) {
            return NULL;
        }

        // recur for the left subtree
        struct Node* left = kthSmallest(root->left, counter, k);

        // if k'th smallest node is found
        if (left != NULL) {
            return left;
        }

        // if the root is k'th smallest node
        if (++(*counter) == k) {
            return root;
        }

        // recur for the right subtree only if k'th smallest node is not found
        // in the left subtree
        return kthSmallest(root->right, counter, k);
    }

    // Function to find the k'th smallest node in a BST
    struct Node* findKthSmallest(struct Node* root, int k)
    {
        // counter to keep track of the total number of the visited nodes
        int counter = 0;

        // recursively find the k'th smallest node
        return kthSmallest(root, &counter, k);
    }

    int main(void)
    {
        /* Construct the following BST
                   15
                 /    \
               10      20
              /  \    /  \
             8   12  16   25
        */

        struct Node* root = newNode(15);
        root->left = newNode(10);
        root->right = newNode(20);
        root->left->left = newNode(8);
        root->left->right = newNode(12);
        root->right->left = newNode(16);
        root->right->right = newNode(25);

        int k = 4;

        // find the k'th smallest node
        struct Node* result = findKthSmallest(root, k);

        if (result) {
            printf("%d'th smallest node is %d", k, result->data);
        }
        else {
            printf("%d'th smallest node does not exist.", k);
        }

        return 0;
    }
Output:
4'th smallest node is 15
Java
    import java.util.concurrent.atomic.AtomicInteger;

    // A class to store a binary tree node
    class Node
    {
        int data;
        Node left, right;

        Node(int data) {
            this.data = data;
        }
    }

    class Main
    {
        // Recursive function to find the k'th smallest node
        // in the BST (using inorder traversal)
        public static Node kthSmallest(Node root, AtomicInteger counter, int k)
        {
            // base case
            if (root == null) {
                return null;
            }

            // recur for the left subtree
            Node left = kthSmallest(root.left, counter, k);

            // if k'th smallest node is found
            if (left != null) {
                return left;
            }

            // if the root is k'th smallest node
            if (counter.incrementAndGet() == k) {
                return root;
            }

            // recur for the right subtree only if k'th smallest node is not found
            // in the left subtree
            return kthSmallest(root.right, counter, k);
        }

        // Function to find the k'th smallest node in the BST
        public static Node findKthSmallest(Node root, int k)
        {
            // Counter to keep track of the total number of the visited nodes.
            // `AtomicInteger` is used here since `Integer` is immutable and
            // Java passes arguments by value
            AtomicInteger counter = new AtomicInteger(0);

            // recursively find the k'th smallest node
            return kthSmallest(root, counter, k);
        }

        public static void main(String[] args)
        {
            /* Construct the following BST
                       15
                     /    \
                   10      20
                  /  \    /  \
                 8   12  16   25
            */

            Node root = new Node(15);
            root.left = new Node(10);
            root.right = new Node(20);
            root.left.left = new Node(8);
            root.left.right = new Node(12);
            root.right.left = new Node(16);
            root.right.right = new Node(25);

            int k = 4;

            // find the k'th smallest node
            Node result = findKthSmallest(root, k);

            if (result != null) {
                System.out.printf("%d'th smallest node is %d", k, result.data);
            }
            else {
                System.out.printf("%d'th smallest node does not exist.", k);
            }
        }
    }
Output:
4'th smallest node is 15
Python
    # A class to store a binary tree node
    class Node:
        def __init__(self, data, left=None, right=None):
            self.data = data
            self.left = left
            self.right = right


    # Recursive function to find the k'th smallest node
    # in the BST (using inorder traversal)
    def kthSmallest(root, counter, k):

        # base case
        if root is None:
            return None, counter

        # recur for the left subtree
        left, counter = kthSmallest(root.left, counter, k)

        # if k'th smallest node is found
        if left:
            return left, counter

        # if the root is k'th smallest node
        counter = counter + 1
        if counter == k:
            return root, counter

        # recur for the right subtree only if k'th smallest node is not found
        # in the left subtree
        return kthSmallest(root.right, counter, k)


    # Function to find the k'th smallest node in a BST
    def findKthSmallest(root, k):

        # counter to keep track of the total number of the visited nodes
        counter = 0

        # recursively find the k'th smallest node
        return kthSmallest(root, counter, k)[0]


    if __name__ == '__main__':

        # Construct the following BST
        #            15
        #          /    \
        #        10      20
        #       /  \    /  \
        #      8   12  16   25

        root = Node(15)
        root.left = Node(10)
        root.right = Node(20)
        root.left.left = Node(8)
        root.left.right = Node(12)
        root.right.left = Node(16)
        root.right.right = Node(25)

        k = 4

        # find the k'th smallest node
        result = findKthSmallest(root, k)

        if result:
            print(f"{k}'th smallest node is {result.data}")
        else:
            print(f"{k}'th smallest node does not exist.")
Output:
4'th smallest node is 15
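The time complexity of the above solution is O(n) in the worst case, where n is the number of nodes in the BST. Since the traversal stops as soon as the k'th node is found, it visits O(h + k) nodes for a tree of height h. The solution requires space proportional to the tree’s height for the call stack.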
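To make the space bound concrete, here is a minimal sketch of an iterative variant that replaces the call stack with an explicit one and reports the largest size that stack reached. This is not the recursive solution shown above. The function name `kth_smallest_iterative` and the returned `peak` value are assumptions introduced for illustration.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def kth_smallest_iterative(root, k):
    """Return (node, peak), where node is the k'th smallest node (or None)
    and peak is the largest size the explicit stack reached."""
    stack = []
    peak = 0
    count = 0
    current = root

    while stack or current is not None:
        # push the current node and all of its left descendants
        while current is not None:
            stack.append(current)
            current = current.left
        peak = max(peak, len(stack))

        # the top of the stack is the next node in sorted order
        current = stack.pop()
        count += 1
        if count == k:
            return current, peak

        # continue with the right subtree of the node just visited
        current = current.right

    # fewer than k nodes in the tree
    return None, peak


if __name__ == '__main__':
    root = Node(15,
                Node(10, Node(8), Node(12)),
                Node(20, Node(16), Node(25)))
    node, peak = kth_smallest_iterative(root, 4)
    print(node.data, peak)
```

The stack only ever holds nodes that lie on a single root-to-leaf path, so its size never exceeds the tree's height measured in nodes, and the loop exits right after the k'th pop.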
Exercise:
1. Modify the solution to find the k'th largest node in the BST.
2. Modify the solution to print the first k smallest nodes in the BST.
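One possible way to approach both exercises is sketched below. It is not a reference solution, and the function names `kth_largest` and `first_k_smallest` are hypothetical. The first uses a reverse inorder traversal (right, root, left), which visits keys in descending order. The second collects keys during a normal inorder traversal and stops once k have been gathered.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def kth_largest(root, k):
    """Return the k'th largest node, or None if the tree has fewer than k nodes."""
    count = 0

    def visit(node):
        nonlocal count
        if node is None:
            return None
        # larger keys live in the right subtree, so search it first
        found = visit(node.right)
        if found is not None:
            return found
        count += 1
        if count == k:
            return node
        return visit(node.left)

    return visit(root)


def first_k_smallest(root, k):
    """Return the keys of the first k nodes in ascending order."""
    result = []

    def visit(node):
        if node is None or len(result) >= k:
            return
        visit(node.left)
        if len(result) < k:
            result.append(node.data)
            visit(node.right)

    visit(root)
    return result


if __name__ == '__main__':
    root = Node(15,
                Node(10, Node(8), Node(12)),
                Node(20, Node(16), Node(25)))
    print(kth_largest(root, 2).data)
    print(first_k_smallest(root, 4))
```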