Calculate the sum of all elements in a submatrix in constant time
Given an M × N integer matrix and two coordinates (p, q) and (r, s) representing the top-left and bottom-right corners of a submatrix, calculate the sum of all elements present in the submatrix. Here, 0 <= p < r < M and 0 <= q < s < N.
For example,
[ 0 2 5 4 1 ]
[ 4 8 2 3 7 ]
[ 6 3 4 6 2 ]
[ 7 3 1 8 3 ]
[ 1 5 7 9 4 ]
(p, q) = (1, 1)
(r, s) = (3, 3)
Output: Sum is 38
Explanation:
The submatrix formed by coordinates (p, q), (p, s), (r, q), and (r, s) is shown below; its elements sum to 38.
[ 8 2 3 ]
[ 3 4 6 ]
[ 3 1 8 ]
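For comparison, the straightforward approach adds up every cell of the submatrix, just as the explanation does by hand (8 + 2 + 3 + 3 + 4 + 6 + 3 + 1 + 8 = 38). The minimal Python sketch below shows that baseline; the function name `naive_submatrix_sum` is hypothetical. Each call visits every cell in the submatrix, so a single lookup can cost up to O(M × N) time.

```python
def naive_submatrix_sum(mat, p, q, r, s):
    """Sum mat[i][j] for p <= i <= r and q <= j <= s by visiting every cell."""
    total = 0
    for i in range(p, r + 1):
        for j in range(q, s + 1):
            total += mat[i][j]
    return total


mat = [
    [0, 2, 5, 4, 1],
    [4, 8, 2, 3, 7],
    [6, 3, 4, 6, 2],
    [7, 3, 1, 8, 3],
    [1, 5, 7, 9, 4],
]

print(naive_submatrix_sum(mat, 1, 1, 3, 3))  # 38
```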
Assume that m such lookups are made on the same matrix; the task is to answer each lookup in O(1) time.
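The idea is to preprocess the matrix. Take an auxiliary matrix sum[][], where sum[i][j] stores the sum of elements in the matrix from (0, 0) to (i, j). We can calculate the value of sum[i][j] in constant time using the following relation, where any term with an index of -1 is taken as 0:

$$\text{sum}[i][j] = \text{mat}[i][j] + \text{sum}[i-1][j] + \text{sum}[i][j-1] - \text{sum}[i-1][j-1]$$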
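[Figure: the greyed portion represents the sum of elements in the matrix from (0, 0) to (i, j).]
Here sum[i - 1][j] covers rows 0 to i - 1 and columns 0 to j, while sum[i][j - 1] covers rows 0 to i and columns 0 to j - 1. Together with mat[i][j], they cover the whole rectangle from (0, 0) to (i, j), but the block from (0, 0) to (i - 1, j - 1) is counted twice, so sum[i - 1][j - 1] is subtracted once.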
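To see the relation at work on the example matrix, here is a minimal Python sketch that fills the table in a single pass. The names `prefix_sums` and `at` are hypothetical. Unlike the solutions below, which handle the first row and column in separate loops, this sketch uses a small helper that treats any index of -1 as an empty region contributing 0; both produce the same table.

```python
def prefix_sums(mat):
    """Return sums, where sums[i][j] is the sum of mat[0..i][0..j] (inclusive)."""
    M, N = len(mat), len(mat[0])
    sums = [[0] * N for _ in range(M)]

    def at(i, j):
        # A -1 index stands for an empty region, which contributes 0.
        # (Python would otherwise wrap sums[-1] around to the last row.)
        return sums[i][j] if i >= 0 and j >= 0 else 0

    for i in range(M):
        for j in range(N):
            # sums[i-1][j] and sums[i][j-1] overlap on the block up to
            # (i-1, j-1), so that block is subtracted once
            sums[i][j] = mat[i][j] + at(i - 1, j) + at(i, j - 1) - at(i - 1, j - 1)
    return sums


mat = [
    [0, 2, 5, 4, 1],
    [4, 8, 2, 3, 7],
    [6, 3, 4, 6, 2],
    [7, 3, 1, 8, 3],
    [1, 5, 7, 9, 4],
]

for row in prefix_sums(mat):
    print(row)
# [0, 2, 7, 11, 12]
# [4, 14, 21, 28, 36]
# [10, 23, 34, 47, 57]
# [17, 33, 45, 66, 79]
# [18, 39, 58, 88, 105]
```

The bottom-right entry, 105, is the sum of the whole matrix, which is a quick sanity check on any prefix-sum table.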
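Now, to calculate the sum of elements in the submatrix formed by coordinates (p, q), (p, s), (r, q), and (r, s) in constant time, we can directly apply the relation below, again taking any term with an index of -1 as 0:

$$\text{total} = \text{sum}[r][s] - \text{sum}[r][q-1] - \text{sum}[p-1][s] + \text{sum}[p-1][q-1]$$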
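[Figure: the greyed portion represents the submatrix.]
sum[r][s] covers everything from (0, 0) to (r, s). Subtracting sum[r][q - 1] removes the columns to the left of the submatrix, and subtracting sum[p - 1][s] removes the rows above it. The block from (0, 0) to (p - 1, q - 1) lies in both removed regions, so it was subtracted twice and is added back once with sum[p - 1][q - 1].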
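Plugging in the example, with (p, q) = (1, 1) and (r, s) = (3, 3), the prefix sums needed are sum[3][3] = 66 (all elements from (0, 0) to (3, 3)), sum[3][0] = 0 + 4 + 6 + 7 = 17, sum[0][3] = 0 + 2 + 5 + 4 = 11, and sum[0][0] = 0:

$$
\begin{aligned}
\text{total} &= \text{sum}[3][3] - \text{sum}[3][0] - \text{sum}[0][3] + \text{sum}[0][0] \\
&= 66 - 17 - 11 + 0 = 38
\end{aligned}
$$

This matches the sum obtained by adding the nine elements of the submatrix directly.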
The algorithm can be implemented as follows in C++, Java, and Python:
C++
```cpp
#include <iostream>
#include <vector>
using namespace std;

// Preprocess the `M × N` matrix `mat` such that `sum[i][j]` stores the
// sum of elements in the matrix from (0, 0) to (i, j)
vector<vector<int>> preprocess(vector<vector<int>> const &mat)
{
    // base case: empty matrix
    if (mat.empty() || mat[0].empty()) {
        return {};
    }

    int M = mat.size();
    int N = mat[0].size();

    vector<vector<int>> sum(M, vector<int>(N));
    sum[0][0] = mat[0][0];

    // preprocess the first row
    for (int j = 1; j < N; j++) {
        sum[0][j] = mat[0][j] + sum[0][j - 1];
    }

    // preprocess the first column
    for (int i = 1; i < M; i++) {
        sum[i][0] = mat[i][0] + sum[i - 1][0];
    }

    // preprocess the rest of the matrix
    for (int i = 1; i < M; i++) {
        for (int j = 1; j < N; j++) {
            sum[i][j] = mat[i][j] + sum[i - 1][j] + sum[i][j - 1]
                        - sum[i - 1][j - 1];
        }
    }

    return sum;
}

// Calculate the sum of all elements in a submatrix in constant time,
// given the preprocessed matrix `sum`
int findSubmatrixSum(vector<vector<int>> const &sum, int p, int q, int r, int s)
{
    // base case
    if (sum.empty()) {
        return 0;
    }

    // `total` is `sum[r][s] - sum[r][q-1] - sum[p-1][s] + sum[p-1][q-1]`
    int total = sum[r][s];

    if (q - 1 >= 0) {
        total -= sum[r][q - 1];
    }

    if (p - 1 >= 0) {
        total -= sum[p - 1][s];
    }

    if (p - 1 >= 0 && q - 1 >= 0) {
        total += sum[p - 1][q - 1];
    }

    return total;
}

int main()
{
    vector<vector<int>> mat = {
        { 0, 2, 5, 4, 1 },
        { 4, 8, 2, 3, 7 },
        { 6, 3, 4, 6, 2 },
        { 7, 3, 1, 8, 3 },
        { 1, 5, 7, 9, 4 }
    };

    // preprocess the matrix once; every lookup after this is O(1)
    vector<vector<int>> sum = preprocess(mat);

    // (p, q) and (r, s) represent top-left and bottom-right
    // coordinates of the submatrix
    int p = 1, q = 1, r = 3, s = 3;

    // calculate the submatrix sum
    cout << findSubmatrixSum(sum, p, q, r, s);

    return 0;
}
```
Output:
38
Java
```java
class Main
{
    // Preprocess the `M × N` matrix `mat` such that `sum[i][j]` stores the
    // sum of elements in the matrix from (0, 0) to (i, j)
    public static int[][] preprocess(int[][] mat)
    {
        // base case: empty matrix
        if (mat == null || mat.length == 0 || mat[0].length == 0) {
            return new int[0][0];
        }

        int M = mat.length;
        int N = mat[0].length;

        int[][] sum = new int[M][N];
        sum[0][0] = mat[0][0];

        // preprocess the first row
        for (int j = 1; j < N; j++) {
            sum[0][j] = mat[0][j] + sum[0][j - 1];
        }

        // preprocess the first column
        for (int i = 1; i < M; i++) {
            sum[i][0] = mat[i][0] + sum[i - 1][0];
        }

        // preprocess the rest of the matrix
        for (int i = 1; i < M; i++) {
            for (int j = 1; j < N; j++) {
                sum[i][j] = mat[i][j] + sum[i - 1][j] + sum[i][j - 1]
                            - sum[i - 1][j - 1];
            }
        }

        return sum;
    }

    // Calculate the sum of all elements in a submatrix in constant time,
    // given the preprocessed matrix `sum`
    public static int findSubmatrixSum(int[][] sum, int p, int q, int r, int s)
    {
        // base case
        if (sum == null || sum.length == 0) {
            return 0;
        }

        // `total` is `sum[r][s] - sum[r][q-1] - sum[p-1][s] + sum[p-1][q-1]`
        int total = sum[r][s];

        if (q - 1 >= 0) {
            total -= sum[r][q - 1];
        }

        if (p - 1 >= 0) {
            total -= sum[p - 1][s];
        }

        if (p - 1 >= 0 && q - 1 >= 0) {
            total += sum[p - 1][q - 1];
        }

        return total;
    }

    public static void main(String[] args)
    {
        int[][] mat = {
            { 0, 2, 5, 4, 1 },
            { 4, 8, 2, 3, 7 },
            { 6, 3, 4, 6, 2 },
            { 7, 3, 1, 8, 3 },
            { 1, 5, 7, 9, 4 }
        };

        // preprocess the matrix once; every lookup after this is O(1)
        int[][] sum = preprocess(mat);

        // (p, q) and (r, s) represent top-left and bottom-right
        // coordinates of the submatrix
        int p = 1, q = 1, r = 3, s = 3;

        // calculate the submatrix sum
        System.out.print(findSubmatrixSum(sum, p, q, r, s));
    }
}
```
Output:
38
Python
```python
def preprocess(mat):
    # base case: empty matrix
    if not mat or not mat[0]:
        return []

    # `M × N` matrix
    (M, N) = (len(mat), len(mat[0]))

    # preprocess the matrix `mat` such that `sums[i][j]` stores
    # sum of elements in the matrix from (0, 0) to (i, j)
    sums = [[0] * N for _ in range(M)]
    sums[0][0] = mat[0][0]

    # preprocess the first row
    for j in range(1, N):
        sums[0][j] = mat[0][j] + sums[0][j - 1]

    # preprocess the first column
    for i in range(1, M):
        sums[i][0] = mat[i][0] + sums[i - 1][0]

    # preprocess the rest of the matrix
    for i in range(1, M):
        for j in range(1, N):
            sums[i][j] = mat[i][j] + sums[i - 1][j] + sums[i][j - 1] - sums[i - 1][j - 1]

    return sums


# Calculate the sum of all elements in a submatrix in constant time,
# given the preprocessed matrix `sums`
def findSubmatrixSum(sums, p, q, r, s):
    # base case
    if not sums:
        return 0

    # `total` is `sums[r][s] - sums[r][q-1] - sums[p-1][s] + sums[p-1][q-1]`
    total = sums[r][s]
    if q - 1 >= 0:
        total -= sums[r][q - 1]
    if p - 1 >= 0:
        total -= sums[p - 1][s]
    if p - 1 >= 0 and q - 1 >= 0:
        total += sums[p - 1][q - 1]

    return total


if __name__ == '__main__':

    mat = [
        [0, 2, 5, 4, 1],
        [4, 8, 2, 3, 7],
        [6, 3, 4, 6, 2],
        [7, 3, 1, 8, 3],
        [1, 5, 7, 9, 4]
    ]

    # preprocess the matrix once; every lookup after this is O(1)
    sums = preprocess(mat)

    # (p, q) and (r, s) represent top-left and bottom-right
    # coordinates of the submatrix
    p = q = 1
    r = s = 3

    # calculate the submatrix sum
    print(findSubmatrixSum(sums, p, q, r, s))
```
Output:
38
Preprocessing takes O(N²) time for an N × N matrix (O(M × N) for an M × N matrix), but once the matrix is preprocessed, each lookup takes constant time, however many lookups are made. In other words, if m lookup calls are made to the matrix, the naive solution, which adds up each submatrix from scratch, takes O(m × N²) time, while preprocessing once and then answering every lookup from sum[][] takes only O(m + N²) time.
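The "preprocess once, look up many times" pattern can be checked with a short experiment. The Python sketch below is illustrative only: the names `SubmatrixSum`, `query`, and `brute_force` are hypothetical, and it uses a common variant of the same technique in which the table gets an extra leading row and column of zeros, so table[i + 1][j + 1] plays the role of sum[i][j] and the boundary checks disappear. It builds the table once and then compares 1000 O(1) lookups against a brute-force sum on a random matrix.

```python
import random


class SubmatrixSum:
    """Preprocess once in O(M x N); answer each submatrix-sum query in O(1)."""

    def __init__(self, mat):
        M, N = len(mat), len(mat[0])
        # Padded table: table[i + 1][j + 1] = sum of mat[0..i][0..j],
        # so row 0 and column 0 are all zeros and no bounds checks are needed
        self.table = [[0] * (N + 1) for _ in range(M + 1)]
        for i in range(M):
            for j in range(N):
                self.table[i + 1][j + 1] = (mat[i][j] + self.table[i][j + 1]
                                            + self.table[i + 1][j] - self.table[i][j])

    def query(self, p, q, r, s):
        """Sum of the submatrix with top-left (p, q) and bottom-right (r, s)."""
        t = self.table
        return t[r + 1][s + 1] - t[p][s + 1] - t[r + 1][q] + t[p][q]


def brute_force(mat, p, q, r, s):
    return sum(mat[i][j] for i in range(p, r + 1) for j in range(q, s + 1))


random.seed(0)
mat = [[random.randint(-9, 9) for _ in range(6)] for _ in range(4)]
lookup = SubmatrixSum(mat)   # O(M x N) preprocessing, done once

for _ in range(1000):        # m = 1000 lookups, each O(1)
    p, r = sorted(random.randrange(4) for _ in range(2))
    q, s = sorted(random.randrange(6) for _ in range(2))
    assert lookup.query(p, q, r, s) == brute_force(mat, p, q, r, s)

print("all lookups match")
```

The padding costs one extra row and column of memory but removes the three conditional branches from every lookup.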
Exercise:
1. Given an M × N integer matrix, find the sum of every K × K submatrix.
2. Given an M × N integer matrix and a cell (i, j), find, in constant time, the sum of all matrix elements except those present in row i and column j of the matrix.
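If you want to check your answer to the second exercise, here is one possible approach, sketched in Python under the assumption that lookups may be preceded by an O(M × N) preprocessing step. The class name `ExcludeRowColSum` is hypothetical. It needs only the grand total plus per-row and per-column totals rather than the full prefix-sum table.

```python
class ExcludeRowColSum:
    """Hypothetical helper: sum of all elements outside row i and column j, in O(1)."""

    def __init__(self, mat):
        self.mat = mat
        self.row_sums = [sum(row) for row in mat]        # one pass over the matrix
        self.col_sums = [sum(col) for col in zip(*mat)]  # zip(*mat) yields columns
        self.total = sum(self.row_sums)

    def query(self, i, j):
        # Removing row i and column j subtracts mat[i][j] twice, so add it back once
        return self.total - self.row_sums[i] - self.col_sums[j] + self.mat[i][j]


mat = [
    [0, 2, 5, 4, 1],
    [4, 8, 2, 3, 7],
    [6, 3, 4, 6, 2],
    [7, 3, 1, 8, 3],
    [1, 5, 7, 9, 4],
]

print(ExcludeRowColSum(mat).query(1, 1))  # 68
```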