Find all nodes at a given distance from leaf nodes in a binary tree
Given a binary tree, write an efficient algorithm to find all nodes present at a given distance from any leaf node. Only nodes on that leaf's own root-to-leaf path count: a node qualifies if it lies on some leaf's root-to-leaf path exactly the given number of edges above that leaf.
For example, consider the following binary tree (its leaves are 8, 12, 18 and 25):

            15
          /    \
        10      20
       /  \    /  \
      8   12  16   25
             /
            18

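The nodes present at a distance of 2 from any leaf node are 15 and 20 (15 is two edges above leaves 8, 12 and 25, and 20 is two edges above leaf 18).
The only node present at a distance of 3 from any leaf node is 15 (three edges above leaf 18).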
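The idea is to traverse the tree in preorder while keeping a list of the current node's ancestors (the path from the root to the current node). When a leaf is reached, its parent is the last element of the list, and in general the ancestor at distance k is the k-th element from the end, so it can be read off directly. Because several leaves can share the same ancestor, insert the matching nodes into a set to avoid duplicates, and print the set once the traversal finishes.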
Following is the implementation of the idea in C++, Java, and Python:
C++
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
#include <iostream>
#include <vector>
#include <unordered_set>
using namespace std;

// Data structure to store a binary tree node
struct Node
{
    int data;
    Node *left, *right;

    Node(int data): data(data), left(nullptr), right(nullptr) {}
};

// Function to check if a given node is a leaf node or not
bool isLeaf(Node* node) {
    return node->left == nullptr && node->right == nullptr;
}

// Recursive worker shared by both `leafNodeDistance` overloads.
// Finds every node lying exactly `dist` edges above some leaf on that leaf's
// root-to-leaf path (`dist == 0` selects the leaves themselves; a negative
// `dist` matches nothing).
//   path  - ancestors of `node`, root first. Used as a stack: the current node
//           is pushed before recursing and popped afterwards, so the vector is
//           back to its original contents on return (no per-call copies).
//   seen  - nodes found so far; used to skip duplicates.
//   order - receives each newly found node, in preorder discovery order.
static void collectLeafDistanceNodes(Node* node, vector<Node*> &path, int dist,
                                     unordered_set<Node*> &seen, vector<Node*> &order)
{
    // base case: empty tree
    if (node == nullptr) {
        return;
    }

    if (isLeaf(node))
    {
        Node* match = nullptr;
        if (dist == 0) {
            // a leaf is at distance 0 from itself
            match = node;
        }
        else if (dist > 0 && path.size() >= static_cast<size_t>(dist)) {
            // path.back() is the leaf's parent (distance 1), so the ancestor at
            // distance `dist` is the dist-th element from the end
            match = path[path.size() - static_cast<size_t>(dist)];
        }

        if (match != nullptr && seen.insert(match).second) {
            order.push_back(match);
        }
        return;
    }

    // include the current node in the current path, recur, then backtrack
    path.push_back(node);
    collectLeafDistanceNodes(node->left, path, dist, seen, order);
    collectLeafDistanceNodes(node->right, path, dist, seen, order);
    path.pop_back();
}

// Recursive function to find all nodes at a given distance from leaf nodes.
// `path` holds the ancestors of `node` (root first); matching nodes are
// inserted into `set`. `dist == 0` reports the leaves themselves and a
// negative `dist` reports nothing (previously `dist == 0` threw
// std::out_of_range).
void leafNodeDistance(Node* node, vector<Node*> path, unordered_set<Node*> &set, int dist)
{
    // `path` is already this call's own copy, so the helper may use it as a
    // mutable stack without affecting the caller
    vector<Node*> discovered;   // discovery order is not needed here
    collectLeafDistanceNodes(node, path, dist, set, discovered);
}

// Find all distinct nodes at a given distance from leaf nodes and print them,
// space-separated, in the order they are first found by a preorder traversal.
void leafNodeDistance(Node* node, int dist)
{
    vector<Node*> path;           // root-to-current-node path
    unordered_set<Node*> seen;    // de-duplicates shared ancestors
    vector<Node*> order;          // deterministic output order

    collectLeafDistanceNodes(node, path, dist, seen, order);

    for (Node* match: order) {
        cout << match->data << " ";
    }
}

// Releases every node of the tree rooted at `node` (children before parent).
void deleteTree(Node* node)
{
    if (node == nullptr) {
        return;
    }
    deleteTree(node->left);
    deleteTree(node->right);
    delete node;
}

int main()
{
    /* Construct the following tree
               15
             /    \
            /      \
           10       20
          /  \     /  \
         8    12  16   25
                 /
                18
    */

    Node* root = new Node(15);
    root->left = new Node(10);
    root->right = new Node(20);
    root->left->left = new Node(8);
    root->left->right = new Node(12);
    root->right->left = new Node(16);
    root->right->right = new Node(25);
    root->right->left->left = new Node(18);

    int dist = 1;
    leafNodeDistance(root, dist);

    deleteTree(root);
    return 0;
}
Output:
10 16 20
Java
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; // A class to store a binary tree node class Node { int data; Node left = null, right = null; Node(int data) { this.data = data; } } class Main { // Function to check if a given node is a leaf node or not public static boolean isLeaf(Node node) { return (node.left == null && node.right == null); } // Recursive function to find all nodes at a given distance from leaf nodes public static void leafNodeDistance(Node node, List<Node> path, Set<Node> set, int dist) { // base case: empty tree if (node == null) { return; } // if a leaf node is found, insert the node at a distance `dist` from the // leaf node into the set if (isLeaf(node) && path.size() >= dist) { set.add(path.get(path.size() - dist)); return; } // include the current node in the current path path.add(node); // recur for the left and right subtree leafNodeDistance(node.left, path, set, dist); leafNodeDistance(node.right, path, set, dist); // remove the current node from the current path path.remove(node); } // Find all distinct nodes at a given distance from leaf nodes public static void leafNodeDistance(Node node, int dist) { // list to store root-to-leaf path List<Node> path = new ArrayList<>(); // create an empty set to store distinct nodes at a given // distance from leaf nodes Set<Node> set = new HashSet<>(); // find all nodes leafNodeDistance(node, path, set, dist); // print output for (Node e: set) { System.out.print(e.data + " "); } } public static void main(String[] args) { /* Construct the following tree 15 / \ / \ 10 20 / \ / \ 8 12 16 25 / 18 */ Node root = new Node(15); root.left = new Node(10); root.right = new Node(20); root.left.left = new Node(8); root.left.right = new Node(12); root.right.left = new Node(16); root.right.right = new Node(25); root.right.left.left = new Node(18); int dist = 1; leafNodeDistance(root, dist); } } |
Output:
10 16 20
Python
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
# A class to store a binary tree node
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


# Function to check if a given node is a leaf node or not
def isLeaf(node):
    return node.left is None and node.right is None


class _InsertionOrderedSet(dict):
    """Minimal set-like container (``add`` + iteration) that remembers the
    order in which items were first added, so results print deterministically."""

    def add(self, item):
        self[item] = None


def leafNodeDistance(node, path, set, dist):
    """Recursively collect every node lying exactly `dist` edges above some
    leaf, on that leaf's root-to-leaf path.

    :param node: root of the (sub)tree to search; ``None`` is an empty tree.
    :param path: ancestors of `node`, root first. Used as a stack and restored
                 to its original contents before returning.
    :param set: collection that receives matching nodes via ``set.add``. (The
                name shadows the builtin only inside this function and is kept
                so keyword callers keep working.)
    :param dist: required distance from the leaf; 0 selects the leaves
                 themselves and a negative value matches nothing.
    """
    # base case: empty tree
    if node is None:
        return

    if isLeaf(node):
        if dist == 0:
            # a leaf is at distance 0 from itself
            set.add(node)
        elif 0 < dist <= len(path):
            # path[-1] is the leaf's parent (distance 1); guarding dist > 0
            # matters because path[-0] / path[-(-k)] would index from the front
            set.add(path[-dist])
        return

    # include the current node in the current path
    path.append(node)

    # recur for the left and right subtree
    leafNodeDistance(node.left, path, set, dist)
    leafNodeDistance(node.right, path, set, dist)

    # pop the node appended above in O(1); list.remove() would search for it
    path.pop()


def printLeafNodeDistance(node, dist):
    """Print the data of all distinct nodes at distance `dist` from a leaf,
    in the order a preorder traversal first finds them."""
    # insertion-ordered so shared ancestors are de-duplicated and the printed
    # order does not depend on object hashes
    found = _InsertionOrderedSet()

    # find all nodes
    leafNodeDistance(node, [], found, dist)

    # print output
    print([e.data for e in found])


if __name__ == '__main__':

    ''' Construct the following tree
               15
             /    \
            /      \
           10       20
          /  \     /  \
         8    12  16   25
                 /
                18
    '''

    root = Node(15)
    root.left = Node(10)
    root.right = Node(20)
    root.left.left = Node(8)
    root.left.right = Node(12)
    root.right.left = Node(16)
    root.right.right = Node(25)
    root.right.left.left = Node(18)

    dist = 1
    printLeafNodeDistance(root, dist)
Output:
[16, 10, 20]
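Every node is visited once, so the solution runs in O(n) time provided the path is maintained as a stack (append the node on entry, remove the last element when backtracking). Copying the path on every call, or searching the list to remove a node, costs O(h) per node and raises the worst case to O(n·h), where h is the height of the tree. The extra space is O(n): the path and the recursion stack take O(h), and the result set takes O(n), where n is the number of nodes in the binary tree.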
Thanks for reading.
To share your code in the comments, please use our online compiler that supports C, C++, Java, Python, JavaScript, C#, PHP, and many more popular programming languages.
Like us? Refer us to your friends and support our growth. Happy coding :)