Construct an ancestor matrix from a binary tree
Given a binary tree whose N nodes are labeled with the distinct integers 0 to N-1, construct its N × N ancestor matrix. An ancestor matrix is a boolean matrix whose cell (i, j) is true if and only if node i is an ancestor of node j in the binary tree.
For example, consider the following binary tree, rooted at node 4:

        4
      /   \
     3     1
    / \     \
   2   0     5

The output should be the following ancestor matrix, where row i marks the descendants of node i:
0 0 0 0 0 0
0 0 0 0 0 1
0 0 0 0 0 0
1 0 1 0 0 0
1 1 1 1 0 1
0 0 0 0 0 0
The idea is to traverse the tree in preorder while keeping track of the current node's ancestors in a container such as a set, a list, or an array. For each visited node, set cell (a, node) to true for every ancestor a currently in that container.
To keep the container up to date, add a node to it when the node is visited, and remove the node once both its left and right subtrees have been processed. The following C++, Java, and Python programs demonstrate this approach:
C++
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

// Data structure to store a binary tree node
struct Node
{
    int data;       // node label; also used as a row/column index of the matrix
    Node* left;
    Node* right;

    // Constructor: creates a leaf node holding `data`
    Node(int data): data(data), left(nullptr), right(nullptr) {}
};

// Recursive function to calculate the size (number of nodes) of the binary tree
int size(Node* root)
{
    // base case: an empty tree has no nodes
    if (root == nullptr) {
        return 0;
    }
    return size(root->left) + 1 + size(root->right);
}

// Traverse the tree in a preorder fashion and update the ancestors of
// all nodes in the boolean ancestor matrix.
//
// When called with an empty `ancestors` set, the set holds exactly the nodes on
// the path from the top-level root down to (but excluding) `root`; on normal
// return it is restored to the contents it had on entry.
//
// Node labels are used directly as matrix indices, so every label must lie in
// [0, ancestorMatrix.size()). A label outside that range throws
// std::out_of_range instead of writing outside the matrix.
void constructAncestorMatrix(Node* root, std::unordered_set<Node*> &ancestors,
                             std::vector<std::vector<bool>> &ancestorMatrix)
{
    // base case
    if (root == nullptr) {
        return;
    }

    // reject labels that cannot be used as matrix indices (would be UB otherwise)
    if (root->data < 0 ||
        static_cast<std::size_t>(root->data) >= ancestorMatrix.size())
    {
        throw std::out_of_range("node label " + std::to_string(root->data) +
                                " is outside the range [0, " +
                                std::to_string(ancestorMatrix.size()) + ")");
    }

    // every node currently on the path above `root` is one of its ancestors
    for (Node* node: ancestors) {
        ancestorMatrix[node->data][root->data] = true;
    }

    // add the current node to the set of ancestors
    ancestors.insert(root);

    // recur for the left and right subtree
    constructAncestorMatrix(root->left, ancestors, ancestorMatrix);
    constructAncestorMatrix(root->right, ancestors, ancestorMatrix);

    // remove the current node from the set of ancestors since all
    // descendants of the current node are already processed
    ancestors.erase(root);
}

// Function to construct an ancestor matrix from a given binary tree.
// Returns an n × n matrix (n = number of nodes) whose cell (i, j) is true iff
// node i is an ancestor of node j. Throws std::out_of_range if a node label
// is not in [0, n).
std::vector<std::vector<bool>> constructAncestorMatrix(Node* root)
{
    // calculate the size of the binary tree
    int n = size(root);

    // create an ancestor matrix of size `n × n`, initialized by false
    std::vector<std::vector<bool>> ancestorMatrix(n, std::vector<bool>(n));

    // stores ancestors of a node
    std::unordered_set<Node*> ancestors;

    // construct the ancestor matrix
    constructAncestorMatrix(root, ancestors, ancestorMatrix);

    return ancestorMatrix;
}

// Free every node of the tree; postorder, so children are released before
// their parent is deleted.
void deleteTree(Node* root)
{
    if (root == nullptr) {
        return;
    }
    deleteTree(root->left);
    deleteTree(root->right);
    delete root;
}

int main()
{
    /* Construct the following tree
               4
             /   \
            3     1
           / \     \
          2   0     5
    */

    Node* root = new Node(4);
    root->left = new Node(3);
    root->right = new Node(1);
    root->left->left = new Node(2);
    root->left->right = new Node(0);
    root->right->right = new Node(5);

    // construct the ancestor matrix
    std::vector<std::vector<bool>> ancestorMatrix = constructAncestorMatrix(root);

    // print the ancestor matrix ('\n' instead of endl: no per-row flush)
    for (const auto &row: ancestorMatrix)
    {
        for (bool val: row) {
            std::cout << val << " ";
        }
        std::cout << '\n';
    }

    // release the tree built above
    deleteTree(root);

    return 0;
}
Output:
0 0 0 0 0 0
0 0 0 0 0 1
0 0 0 0 0 0
1 0 1 0 0 0
1 1 1 1 0 1
0 0 0 0 0 0
Java
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// A class to store a binary tree node
class Node
{
    int data;           // node label; also used as a row/column index of the matrix
    Node left, right;

    // Constructor: creates a leaf node holding `data`
    Node(int data)
    {
        this.data = data;
        this.left = this.right = null;
    }
}

class Main
{
    // Recursive function to calculate the size (number of nodes) of the binary tree
    public static int size(Node root)
    {
        // base case: an empty tree has no nodes
        if (root == null) {
            return 0;
        }
        return size(root.left) + 1 + size(root.right);
    }

    // Traverse the tree in a preorder fashion and update the ancestors of
    // all nodes in the ancestor matrix (1 = ancestor, 0 = not).
    // Node labels are used directly as indices, so each must lie in
    // [0, ancestorMatrix.length); otherwise Java's array bounds check throws
    // ArrayIndexOutOfBoundsException when the label is used.
    public static void constructAncestorMatrix(Node root, Set<Node> ancestors,
                                               int[][] ancestorMatrix)
    {
        // base case
        if (root == null) {
            return;
        }

        // every node currently on the path above `root` is one of its ancestors
        for (Node node: ancestors) {
            ancestorMatrix[node.data][root.data] = 1;
        }

        // add the current node to the set of ancestors
        ancestors.add(root);

        // recur for the left and right subtree
        constructAncestorMatrix(root.left, ancestors, ancestorMatrix);
        constructAncestorMatrix(root.right, ancestors, ancestorMatrix);

        // remove the current node from the set of ancestors since all
        // descendants of the current node are already processed
        ancestors.remove(root);
    }

    // Function to construct an ancestor matrix from a given binary tree.
    // Returns an n × n matrix (n = number of nodes) whose cell (i, j) is 1
    // iff node i is an ancestor of node j.
    public static int[][] constructAncestorMatrix(Node root)
    {
        // calculate the size of the binary tree
        int n = size(root);

        // create an ancestor matrix of size `n × n`, initialized by 0
        int[][] ancestorMatrix = new int[n][n];

        // stores ancestors of a node
        Set<Node> ancestors = new HashSet<>();

        // construct the ancestor matrix
        constructAncestorMatrix(root, ancestors, ancestorMatrix);

        return ancestorMatrix;
    }

    public static void main(String[] args)
    {
        /* Construct the following tree
                   4
                 /   \
                3     1
               / \     \
              2   0     5
        */

        Node root = new Node(4);
        root.left = new Node(3);
        root.right = new Node(1);
        root.left.left = new Node(2);
        root.left.right = new Node(0);
        root.right.right = new Node(5);

        // construct the ancestor matrix
        int[][] ancestorMatrix = constructAncestorMatrix(root);

        // print the ancestor matrix
        for (int[] row: ancestorMatrix) {
            System.out.println(Arrays.toString(row));
        }
    }
}
Output:
[0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 1]
[0, 0, 0, 0, 0, 0]
[1, 0, 1, 0, 0, 0]
[1, 1, 1, 1, 0, 1]
[0, 0, 0, 0, 0, 0]
Python
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
# A class to store a binary tree node
class Node:
    # Constructor
    def __init__(self, data, left=None, right=None):
        self.data = data      # node label; also used as a row/column index
        self.left = left
        self.right = right


# Recursive function to calculate the size (number of nodes) of the binary tree
def size(root):
    # base case: an empty tree has no nodes
    if root is None:
        return 0
    return size(root.left) + 1 + size(root.right)


def constructAncestorMatrix(root, ancestors, ancestorMatrix):
    """Traverse the tree in preorder and mark every ancestor of each node.

    Args:
        root: current subtree root (``None`` for an empty subtree).
        ancestors: set of the nodes on the path above ``root``; it is restored
            to its entry contents on normal return.
        ancestorMatrix: square list-of-lists, updated in place so that
            ``ancestorMatrix[a][d] == 1`` when node ``a`` is an ancestor of ``d``.

    Raises:
        ValueError: if a node label is not in ``[0, len(ancestorMatrix))``.
            Checked explicitly because Python would silently accept a negative
            label as an index from the end and corrupt an unrelated cell.
    """
    # base case
    if root is None:
        return

    n = len(ancestorMatrix)
    if not 0 <= root.data < n:
        raise ValueError(f'node label {root.data} is outside the range [0, {n})')

    # every node currently on the path above `root` is one of its ancestors
    for node in ancestors:
        ancestorMatrix[node.data][root.data] = 1

    # add the current node to the set of ancestors
    ancestors.add(root)

    # recur for the left and right subtree
    constructAncestorMatrix(root.left, ancestors, ancestorMatrix)
    constructAncestorMatrix(root.right, ancestors, ancestorMatrix)

    # remove the current node from the set of ancestors since all
    # descendants of the current node are already processed
    ancestors.remove(root)


def construct(root):
    """Build the ancestor matrix of a binary tree labeled 0..n-1.

    Returns an n x n list of lists (n = number of nodes) whose cell [i][j] is
    1 if node i is an ancestor of node j and 0 otherwise. Raises ValueError if
    a label is not in [0, n).
    """
    # calculate the size of the binary tree
    n = size(root)

    # create an ancestor matrix of size `n × n`, initialized by 0
    # (one fresh row object per line, so rows are not aliased)
    ancestorMatrix = [[0] * n for _ in range(n)]

    # stores ancestors of a node
    ancestors = set()

    # construct the ancestor matrix
    constructAncestorMatrix(root, ancestors, ancestorMatrix)

    return ancestorMatrix


if __name__ == '__main__':
    # Construct the following tree
    #            4
    #          /   \
    #         3     1
    #        / \     \
    #       2   0     5
    root = Node(4)
    root.left = Node(3)
    root.right = Node(1)
    root.left.left = Node(2)
    root.left.right = Node(0)
    root.right.right = Node(5)

    # construct the ancestor matrix
    ancestorMatrix = construct(root)

    # print the ancestor matrix
    for row in ancestorMatrix:
        print(row)
Output:
[0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 1]
[0, 0, 0, 0, 0, 0]
[1, 0, 1, 0, 0, 0]
[1, 1, 1, 1, 0, 1]
[0, 0, 0, 0, 0, 0]
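The time complexity of the above solution is $O(N^2)$, where N is the number of nodes in the binary tree: allocating the N × N matrix alone takes $O(N^2)$ time, and each node is marked once per ancestor. Apart from the output matrix, the solution needs $O(N)$ extra space for the ancestor set and the recursion stack.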
Thanks for reading.
To share your code in the comments, please use our online compiler that supports C, C++, Java, Python, JavaScript, C#, PHP, and many more popular programming languages.
Like us? Refer us to your friends and support our growth. Happy coding :)