Find kth smallest value in a sorted matrix
Given a row-wise and column-wise sorted square matrix and a positive integer k, find the kth smallest number in the matrix.
For example,
mat = [
[-3, 1, 3],
[-2, 2, 4],
[1, 3, 5]
]
k = 6
Output: 3
Explanation: The elements of the matrix, in increasing order, are [-3, -2, 1, 1, 2, 3, 3, 4, 5]. The sixth smallest element is 3.
Input:
mat = [
[1, 3],
[2, 4]
]
k = 5
Output: None
Explanation: k is more than the number of elements in the matrix.
1. Using Min Heap
The idea is to build a min-heap from all the elements of the first row. Then, start a loop where, in each iteration, we remove the root from the min-heap and replace it with the next element from the same column of the matrix. After k pop operations have been performed on the min-heap, the last popped element is the kth smallest element.
The algorithm can be implemented as follows in C++, Java, and Python. Note that we can also build the min-heap from elements of the first column and replace the root with the next element from the same row of the matrix. The remaining logic remains the same.
C++
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
#include <iostream> #include <vector> #include <queue> #include <climits> using namespace std; // Data structure to store a heap node struct Tuple { int i, j, value; }; // Comparison object to be used to order the min-heap struct comp { bool operator()(const Tuple &lhs, const Tuple &rhs) const { return lhs.value > rhs.value; } }; // Function to return the kth smallest value in a sorted matrix int findkthSmallestElement(vector<vector<int>> const &mat, int k) { // invalid input if (mat.size() == 0 || k <= 0 ) { return INT_MIN; } // create an empty min-heap priority_queue<Tuple, vector<Tuple>, comp> minHeap; // insert all elements of the first row in the min-heap for (int j = 0; j < mat.size(); j++) { minHeap.push({ 0, j, mat[0][j] }); } // loop k times or until the heap is empty while (k-- && !minHeap.empty()) { // remove root from the min-heap Tuple minvalue = minHeap.top(); minHeap.pop(); // if k pop operations have been performed on the min-heap, // the last popped element contains the kth smallest element if (k == 0) { return minvalue.value; } // replace the root with the next element from the same column of the matrix if (minvalue.i != mat.size() - 1) { minHeap.push({ minvalue.i + 1, minvalue.j, mat[minvalue.i + 1][minvalue.j] }); } } // we reach here if k is more than the number of elements in the matrix return INT_MIN; } int main() { vector<vector<int>> mat = { {-3, 1, 3}, {-2, 2, 4}, {1, 3, 5} }; int k = 6; cout << findkthSmallestElement(mat, k) << endl; return 0; } |
Output:
3
Java
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import java.util.PriorityQueue;

// A class to store a heap node: the value found at row `i`, column `j`
class Tuple implements Comparable<Tuple>
{
    int i, j, value;

    public Tuple(int i, int j, int value)
    {
        this.i = i;
        this.j = j;
        this.value = value;
    }

    // Order nodes by value. Integer.compare is used instead of
    // `this.value - tuple.value`, which overflows for values of opposite sign
    // and large magnitude and would corrupt the heap order.
    @Override
    public int compareTo(Tuple tuple) {
        return Integer.compare(this.value, tuple.value);
    }
}

class Main
{
    // Function to return the kth smallest value (1-based) in a matrix whose rows
    // and columns are both sorted in non-decreasing order. Any rectangular matrix
    // works, not only square ones.
    //
    // Returns Integer.MIN_VALUE when the matrix is empty, k <= 0, or k is more
    // than the number of elements in the matrix.
    public static int findkthSmallestElement(int[][] mat, int k)
    {
        // invalid input
        if (mat.length == 0 || mat[0].length == 0 || k <= 0) {
            return Integer.MIN_VALUE;
        }

        int rows = mat.length;
        int cols = mat[0].length;

        // create an empty min-heap
        PriorityQueue<Tuple> minHeap = new PriorityQueue<>();

        // seed the heap with the first row. Only the first min(k, cols) columns
        // can hold the answer: mat[0][j] already has j elements <= it to its
        // left, and nothing below it in column j is smaller.
        for (int j = 0; j < Math.min(k, cols); j++) {
            minHeap.add(new Tuple(0, j, mat[0][j]));
        }

        // pop until the kth element comes out or the heap runs dry
        while (!minHeap.isEmpty())
        {
            // remove root from the min-heap
            Tuple minvalue = minHeap.poll();

            // the kth popped element is the kth smallest element
            if (--k == 0) {
                return minvalue.value;
            }

            // replace the root with the next element from the same column
            if (minvalue.i + 1 < rows) {
                minHeap.add(new Tuple(minvalue.i + 1, minvalue.j,
                        mat[minvalue.i + 1][minvalue.j]));
            }
        }

        // we reach here if k is more than the number of elements in the matrix
        return Integer.MIN_VALUE;
    }

    public static void main(String[] args)
    {
        int[][] mat =
        {
            {-3, 1, 3},
            {-2, 2, 4},
            {1, 3, 5}
        };

        int k = 6;
        System.out.println(findkthSmallestElement(mat, k));
    }
}
Output:
3
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from heapq import heappop, heappush


# A class to store a heap node
class Tuple:
    """A min-heap node holding the value found at row ``i``, column ``j``."""

    def __init__(self, i, j, value):
        self.i = i
        self.j = j
        self.value = value

    # Order nodes by value only, so that `heapq` treats the list as a min-heap
    def __lt__(self, other):
        return self.value < other.value


# Function to return the kth smallest value in a sorted matrix
def findkthSmallestElement(mat, k):
    """Return the kth smallest value (1-based) in a row- and column-wise sorted matrix.

    A min-heap is seeded with the first row; every popped node is replaced by
    the next element down the same column, so the kth pop yields the answer.
    Any rectangular matrix whose rows and columns are both sorted in
    non-decreasing order is accepted, not only square ones.

    Returns None when the matrix is empty, ``k <= 0``, or ``k`` is more than
    the number of elements in the matrix.
    """
    # invalid input
    if not mat or not mat[0] or k <= 0:
        return None

    rows, cols = len(mat), len(mat[0])

    # Seed the heap with the first row. Only the first min(k, cols) columns can
    # hold the answer: mat[0][j] already has j elements <= it to its left, and
    # nothing below it in column j is smaller.
    minHeap = []
    for j in range(min(k, cols)):
        heappush(minHeap, Tuple(0, j, mat[0][j]))

    # pop until the kth element comes out or the heap runs dry
    while minHeap:
        node = heappop(minHeap)
        k -= 1

        # the kth popped element is the kth smallest element
        if k == 0:
            return node.value

        # replace the root with the next element from the same column
        if node.i + 1 < rows:
            heappush(minHeap, Tuple(node.i + 1, node.j, mat[node.i + 1][node.j]))

    # we reach here if k is more than the number of elements in the matrix
    return None


if __name__ == '__main__':

    mat = [
        [-3, 1, 3],
        [-2, 2, 4],
        [1, 3, 5]
    ]

    k = 6
    print(findkthSmallestElement(mat, k))
Output:
3
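The time complexity of the above solution is $O((N + k)\log N)$ for an $N \times N$ matrix, which reaches $O(N^2 \log N)$ when $k = N^2$. This holds because there are at most $N + k$ heap operations, and each costs $O(\log N)$ on a heap that never holds more than $N$ elements. For the same reason, the heap requires $O(N)$ extra space.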
2. Using Binary Search
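We can avoid the heap's extra space by using binary search. Binary search typically works on a sorted linear data structure. We can still adapt it to a matrix that is sorted both row-wise and column-wise by searching over the range of values rather than over positions. We start with the search space [low, high], where low and high are initialized to the values at the top-left and bottom-right corners of the matrix, respectively. These are the smallest and largest elements.

At each iteration of the binary search loop, we compute the mid-value and count the elements in the matrix that are less than or equal to it. This count takes linear time if we start at the bottom-left corner. If the current element is greater than mid, we move up. Otherwise, every element above it in the same column is also less than or equal to mid, so we add them all to the count and move right.

If the count is less than k, the kth smallest element must be greater than mid, so we narrow the search space to [mid+1…high]. Otherwise, we narrow it to [low…mid-1]. The loop terminates as soon as low surpasses high. At that point, low holds the kth smallest value in the matrix, provided that k does not exceed the number of elements; this condition has to be checked separately.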
Following are the C++, Java, and Python implementations based on this idea:
C++
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
#include <iostream> #include <vector> using namespace std; // Function to count elements in the matrix that are less than or equal to the given value int findLessOrEqual(vector<vector<int>> const &mat, int val) { int n = mat.size(); // start at the bottom-left corner of the matrix int i = n-1, j = 0; int count = 0; // loop till (i, j) cross the matrix boundary while (i >= 0 && j < n) { // if the current element is more than the given value if (mat[i][j] > val) { i--; // move up (towards smaller values) } else { // if the current element is less than the specified value, // then all values above the current element must also be less count += (i + 1); j++; // move right (towards greater values) } } return count; } // Function to return the kth smallest value in a sorted matrix int findkthSmallestElement(vector<vector<int>> const &mat, int k) { int n = mat.size(); // invalid input if (n == 0 || k <= 0 ) { return INT_MIN; } // initialize low with the top-left element of the matrix int low = mat[0][0]; // initialize high with the bottom-right element of the matrix int high = mat[n-1][n-1]; // loop till the search space is exhausted while (low <= high) { // find the mid-value in the search space int mid = low + ((high - low) >> 1); // find the count of elements that is less than or equal to the mid element int count = findLessOrEqual(mat, mid); // if count is less than k, the kth smallest element exists in range [mid+1…high] if (count < k) { low = mid + 1; } // otherwise, kth smallest element exists in the range [low…mid-1] else { high = mid - 1; } } return low; } int main() { vector<vector<int>> mat = { {-3, 1, 3}, {-2, 2, 4}, {1, 3, 5} }; int k = 6; cout << findkthSmallestElement(mat, k) << endl; return 0; } |
Output:
3
Java
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
class Main
{
    // Function to count elements in the matrix that are less than or equal to val.
    // It walks a staircase from the bottom-left corner, taking at most
    // rows + cols steps. Works for any rectangular row- and column-wise sorted matrix.
    public static int findLessOrEqual(int[][] mat, int val)
    {
        if (mat.length == 0) {
            return 0;
        }

        // start at the bottom-left corner of the matrix
        int i = mat.length - 1, j = 0;
        int cols = mat[0].length;
        int count = 0;

        // loop till (i, j) cross the matrix boundary
        while (i >= 0 && j < cols)
        {
            // if the current element is more than the given value
            if (mat[i][j] > val) {
                i--;    // move up (towards smaller values)
            }
            else {
                // the current element is <= val, so every element above it in
                // the same column is <= val as well
                count += (i + 1);
                j++;    // move right (towards greater values)
            }
        }

        return count;
    }

    // Function to return the kth smallest value (1-based) in a row- and
    // column-wise sorted matrix by binary-searching over the range of values.
    //
    // Returns Integer.MIN_VALUE when the matrix is empty, k <= 0, or k is more
    // than the number of elements in the matrix.
    public static int findkthSmallestElement(int[][] mat, int k)
    {
        // invalid input
        if (mat.length == 0 || mat[0].length == 0 || k <= 0) {
            return Integer.MIN_VALUE;
        }

        int rows = mat.length;
        int cols = mat[0].length;

        // k is more than the number of elements; without this check the search
        // below would return a value one past the largest element
        if (k > (long) rows * cols) {
            return Integer.MIN_VALUE;
        }

        // Search bounds are 64-bit: `high - low` and `mid + 1` can overflow int
        // when the matrix spans a wide range of values.
        // initialize low with the top-left element of the matrix
        long low = mat[0][0];
        // initialize high with the bottom-right element of the matrix
        long high = mat[rows - 1][cols - 1];

        // loop till the search space is exhausted
        while (low <= high)
        {
            // find the mid-value in the search space (always within int range)
            long mid = low + (high - low) / 2;

            // find the count of elements that are less than or equal to mid
            int count = findLessOrEqual(mat, (int) mid);

            // if count is less than k, the kth smallest element exists in the
            // range [mid+1…high]
            if (count < k) {
                low = mid + 1;
            }
            // otherwise, kth smallest element exists in the range [low…mid-1]
            else {
                high = mid - 1;
            }
        }

        // since 1 <= k <= rows*cols, low has converged onto a matrix element
        return (int) low;
    }

    public static void main(String[] args)
    {
        int[][] mat =
        {
            {-3, 1, 3},
            {-2, 2, 4},
            {1, 3, 5}
        };

        int k = 6;
        System.out.println(findkthSmallestElement(mat, k));
    }
}
Output:
3
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# Function to count elements in the matrix that are less than or equal to the given value
def findLessOrEqual(mat, val):
    """Return how many elements of the sorted matrix ``mat`` are ``<= val``.

    Walks a staircase from the bottom-left corner, taking at most
    ``rows + cols`` steps. Works for any rectangular matrix whose rows and
    columns are both sorted in non-decreasing order.
    """
    if not mat:
        return 0

    # start at the bottom-left corner of the matrix
    i, j = len(mat) - 1, 0
    cols = len(mat[0])
    count = 0

    # loop till (i, j) cross the matrix boundary
    while i >= 0 and j < cols:
        if mat[i][j] > val:
            i -= 1      # move up (towards smaller values)
        else:
            # the current element is <= val, so every element above it in the
            # same column is <= val as well
            count += i + 1
            j += 1      # move right (towards greater values)

    return count


# Function to return the kth smallest value in a sorted matrix
def findkthSmallestElement(mat, k):
    """Return the kth smallest value (1-based) in a row- and column-wise sorted matrix.

    Binary-searches over the range of values ``[mat[0][0], mat[-1][-1]]``,
    using ``findLessOrEqual`` to count how many elements are ``<= mid``.

    Returns None when the matrix is empty, ``k <= 0``, or ``k`` is more than
    the number of elements in the matrix.
    """
    # invalid input
    if not mat or not mat[0] or k <= 0:
        return None

    rows, cols = len(mat), len(mat[0])

    # k is more than the number of elements; without this check the search
    # below would return a value one past the largest element
    if k > rows * cols:
        return None

    # initialize low and high with top-left and bottom-right elements of the matrix
    low, high = mat[0][0], mat[rows - 1][cols - 1]

    # loop till the search space is exhausted
    while low <= high:
        mid = (low + high) // 2

        # fewer than k elements are <= mid: the answer lies in [mid+1…high];
        # otherwise it lies in [low…mid-1] (or is mid itself, kept as high+1)
        if findLessOrEqual(mat, mid) < k:
            low = mid + 1
        else:
            high = mid - 1

    # since 1 <= k <= rows*cols, low has converged onto a matrix element
    return low


if __name__ == '__main__':

    mat = [
        [-3, 1, 3],
        [-2, 2, 4],
        [1, 3, 5]
    ]

    k = 6
    print(findkthSmallestElement(mat, k))
Output:
3
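The time complexity of the above solution is $O(N \log R)$ for an $N \times N$ matrix. Here $R = \mathrm{mat}[N-1][N-1] - \mathrm{mat}[0][0]$ is the range of values being searched. The binary search runs $O(\log R)$ iterations, and each one counts the elements $\le$ mid in $O(N)$ time. The solution requires only $O(1)$ extra space.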
Count negative elements present in the sorted matrix in linear time
Report all occurrences of an element in a row-wise and column-wise sorted matrix in linear time
Find the area of the largest rectangle of 1’s in a binary matrix
Thanks for reading.
To share your code in the comments, please use our online compiler that supports C, C++, Java, Python, JavaScript, C#, PHP, and many more popular programming languages.
Like us? Refer us to your friends and support our growth. Happy coding :)