Remove all nodes from a linked list that match a given key
Given a linked list, remove all nodes from it that match a given key.
For example,

Input:
Linked List: [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> null]
key: 4
Output: 2 -> 5 -> 7 -> 8 -> null

Input:
Linked List: [1 -> 1 -> 1 -> null]
key: 1
Output: null
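Put another way, the expected output is simply the input sequence with every occurrence of the key filtered out, in the original order. The minimal Python sketch below encodes that specification so that any implementation can be checked against it. The helper names `from_list`, `to_list` and `expected_output` are hypothetical and introduced only for illustration.

```python
from typing import List, Optional


class Node:
    """Minimal singly linked list node (assumed for this sketch)."""

    def __init__(self, data: int, next: Optional["Node"] = None):
        self.data = data
        self.next = next


def from_list(values: List[int]) -> Optional[Node]:
    """Build a linked list from a Python list (hypothetical helper)."""
    head = None
    for value in reversed(values):
        head = Node(value, head)
    return head


def to_list(head: Optional[Node]) -> List[int]:
    """Collect the node values into a Python list (hypothetical helper)."""
    values = []
    while head:
        values.append(head.data)
        head = head.next
    return values


def expected_output(values: List[int], key: int) -> List[int]:
    """Specification: every value except the key, in the original order."""
    return [v for v in values if v != key]


print(expected_output([4, 2, 4, 4, 5, 4, 7, 8], 4))  # [2, 5, 7, 8]
print(expected_output([1, 1, 1], 1))                 # []
print(to_list(from_list([4, 2, 4])))                 # [4, 2, 4]
```

With these helpers, a Python `removeKey(head, key)` can be checked with `to_list(removeKey(from_list(values), key)) == expected_output(values, key)`.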
1. Iterative Solution
The idea is to traverse the linked list and, at each node, remove the next node if its value matches the given key. After a removal we do not advance, because the node that has just moved into the "next" position must be checked as well; we advance only when the next node is kept. To handle inputs where the first node (or several leading nodes) match the key, we use a dummy node as the start of the result list. The dummy node is a temporary node whose next pointer initially points to the head of the linked list (in the C++ version it is allocated on the stack; in Java and Python it is an ordinary object). When we are done, the dummy node's next pointer points to the head of the resulting list.
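To see why the dummy node matters, the following sketch runs the same loop in Python on the input 1 -> 1 -> 1 and records each decision. The `trace` parameter and the name `remove_key_traced` are additions for illustration only. Every node is unlinked through the dummy's next pointer and `curr` never leaves the dummy, so the function returns an empty list without any special handling for the head.

```python
from typing import List, Optional


class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


def remove_key_traced(head: Optional[Node], key: int, trace: List[str]) -> Optional[Node]:
    """Dummy-node removal that logs every decision into `trace`."""
    dummy = Node()          # sentinel placed in front of the real head
    dummy.next = head
    curr = dummy
    while curr.next:
        if curr.next.data == key:
            trace.append(f"unlink {curr.next.data}")
            curr.next = curr.next.next  # bypass the match; do not advance
        else:
            trace.append(f"keep {curr.next.data}")
            curr = curr.next            # advance past a kept node
    return dummy.next                   # head of the result (possibly None)


steps: List[str] = []
print(remove_key_traced(Node(1, Node(1, Node(1))), 1, steps))  # None
print(steps)  # ['unlink 1', 'unlink 1', 'unlink 1']
```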
The following C++, Java, and Python programs demonstrate this approach:
C++
#include <iostream>
#include <vector>
using namespace std;

// A Linked List Node
struct Node
{
    int data;
    Node* next;
};

// Utility function to push a new node with the given data onto the list's front
void push(Node*& headRef, int data)
{
    Node* node = new Node;
    node->data = data;
    node->next = headRef;
    headRef = node;
}

// Utility function to print a linked list
void printList(Node* head)
{
    Node* curr = head;
    while (curr)
    {
        cout << curr->data << " -> ";
        curr = curr->next;
    }
    cout << "nullptr" << endl;
}

// Function to remove all nodes from a linked list that match a given key
Node* removeKey(Node* head, int key)
{
    // base case: empty list
    if (!head) {
        return nullptr;
    }

    // construct a dummy node whose next pointer points to the head of the linked list
    Node dummy;
    dummy.next = head;

    Node* curr = &dummy;
    while (curr->next)
    {
        // if the next node matches the given key, unlink it and free its memory
        if (curr->next->data == key)
        {
            Node* target = curr->next;
            curr->next = target->next;
            delete target;
        }
        // otherwise, move to the next node
        else {
            curr = curr->next;
        }
    }

    // the next pointer of the dummy node points to the head of the resulting list
    return dummy.next;
}

int main()
{
    vector<int> nums = { 4, 2, 4, 4, 5, 4, 7, 8 };
    int key = 4;

    // construct linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> nullptr]
    Node* head = nullptr;
    for (int i = nums.size() - 1; i >= 0; i--) {
        push(head, nums[i]);
    }

    head = removeKey(head, key);
    printList(head);

    return 0;
}
Output:
2 -> 5 -> 7 -> 8 -> nullptr
Java
// A Linked List Node
class Node
{
    int data;
    Node next;

    Node(int data, Node next)
    {
        this.data = data;
        this.next = next;
    }

    Node() {}
}

class Main
{
    // Helper function to print a given linked list
    public static void printList(Node head)
    {
        Node ptr = head;
        while (ptr != null)
        {
            System.out.print(ptr.data + " -> ");
            ptr = ptr.next;
        }

        System.out.println("null");
    }

    // Function to remove all nodes from a linked list that match a given key
    public static Node removeKey(Node head, int key)
    {
        // base case: empty list
        if (head == null) {
            return null;
        }

        // construct a dummy node whose next pointer points to the head of the linked list
        Node dummy = new Node();
        dummy.next = head;

        Node curr = dummy;
        while (curr.next != null)
        {
            // if the next node matches the given key, skip it
            if (curr.next.data == key) {
                curr.next = curr.next.next;
            }
            // otherwise, move to the next node
            else {
                curr = curr.next;
            }
        }

        // the next pointer of the dummy node points to the head of the resulting list
        return dummy.next;
    }

    public static void main(String[] args)
    {
        int[] nums = { 4, 2, 4, 4, 5, 4, 7, 8 };
        int key = 4;

        // construct linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> null]
        Node head = null;
        for (int i = nums.length - 1; i >= 0; i--) {
            head = new Node(nums[i], head);
        }

        head = removeKey(head, key);
        printList(head);
    }
}
Output:
2 -> 5 -> 7 -> 8 -> null
Python
# A Linked List Node
class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


# Helper function to print a given linked list
def printList(head):
    ptr = head
    while ptr:
        print(ptr.data, end=' -> ')
        ptr = ptr.next
    print('None')


# Function to remove all nodes from a linked list that match a given key
def removeKey(head, key):
    # base case: empty list
    if head is None:
        return None

    # construct a dummy node whose next pointer points to the head of the linked list
    dummy = Node()
    dummy.next = head

    curr = dummy
    while curr.next:
        # if the next node matches the given key, skip it
        if curr.next.data == key:
            curr.next = curr.next.next
        # otherwise, move to the next node
        else:
            curr = curr.next

    # the next pointer of the dummy node points to the head of the resulting list
    return dummy.next


if __name__ == '__main__':
    nums = [4, 2, 4, 4, 5, 4, 7, 8]
    key = 4

    # construct a linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> None]
    head = None
    for i in reversed(nums):
        head = Node(i, head)

    head = removeKey(head, key)
    printList(head)
Output:
2 -> 5 -> 7 -> 8 -> None
The above solution runs in O(n) time and uses O(1) extra space, where n is the length of the linked list.
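The O(n) bound follows from the fact that each pass of the `while` loop consumes exactly one node: it either unlinks the next node or steps past it. The minimal sketch below adds an iteration counter to a Python version of the loop to check that it runs exactly n times, however many nodes match. The names `build` and `remove_key_counted` are hypothetical.

```python
from typing import List, Optional, Tuple


class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


def build(values: List[int]) -> Optional[Node]:
    head = None
    for v in reversed(values):
        head = Node(v, head)
    return head


def remove_key_counted(head: Optional[Node], key: int) -> Tuple[Optional[Node], int]:
    """Dummy-node removal that also reports how many loop iterations ran."""
    dummy = Node(next=head)
    curr = dummy
    iterations = 0
    while curr.next:
        iterations += 1
        if curr.next.data == key:
            curr.next = curr.next.next  # one node consumed by unlinking it
        else:
            curr = curr.next            # one node consumed by stepping past it
    return dummy.next, iterations


for values, key in (([4, 2, 4, 4, 5, 4, 7, 8], 4), ([1, 1, 1], 1), ([], 4)):
    _, count = remove_key_counted(build(values), key)
    print(len(values), count)  # the two numbers are always equal
```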
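Here is a different approach that is structurally similar but does not use a dummy node. Instead, it first discards all matching nodes from the beginning of the linked list, so that the head is guaranteed not to match the key, and then processes the remaining nodes as before. Because the head may become null after this step, the main loop must check both the current node and its successor.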
C++
#include <iostream>
#include <vector>
using namespace std;

// A Linked List Node
struct Node
{
    int data;
    Node* next;
};

// Utility function to push a new node with the given data onto the list's front
void push(Node*& headRef, int data)
{
    Node* node = new Node;
    node->data = data;
    node->next = headRef;
    headRef = node;
}

// Utility function to print a linked list
void printList(Node* head)
{
    Node* curr = head;
    while (curr)
    {
        cout << curr->data << " -> ";
        curr = curr->next;
    }
    cout << "nullptr" << endl;
}

// Function to remove all nodes from a linked list that match a given key
Node* removeKey(Node* head, int key)
{
    // base case: empty list
    if (!head) {
        return nullptr;
    }

    // remove matching nodes at the beginning of the linked list
    while (head && head->data == key)
    {
        Node* target = head;
        head = head->next;
        delete target;
    }

    Node* curr = head;
    while (curr && curr->next)
    {
        // if the next node matches the given key, unlink it and free its memory
        if (curr->next->data == key)
        {
            Node* target = curr->next;
            curr->next = target->next;
            delete target;
        }
        // otherwise, move to the next node
        else {
            curr = curr->next;
        }
    }

    return head;
}

int main()
{
    vector<int> nums = { 4, 2, 4, 4, 5, 4, 7, 8 };
    int key = 4;

    // construct linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> nullptr]
    Node* head = nullptr;
    for (int i = nums.size() - 1; i >= 0; i--) {
        push(head, nums[i]);
    }

    head = removeKey(head, key);
    printList(head);

    return 0;
}
Output:
2 -> 5 -> 7 -> 8 -> nullptr
Java
// A Linked List Node
class Node
{
    int data;
    Node next;

    Node(int data, Node next)
    {
        this.data = data;
        this.next = next;
    }

    Node() {}
}

class Main
{
    // Helper function to print a given linked list
    public static void printList(Node head)
    {
        Node ptr = head;
        while (ptr != null)
        {
            System.out.print(ptr.data + " -> ");
            ptr = ptr.next;
        }

        System.out.println("null");
    }

    // Function to remove all nodes from a linked list that match a given key
    public static Node removeKey(Node head, int key)
    {
        // base case: empty list
        if (head == null) {
            return null;
        }

        // remove matching nodes at the beginning of the linked list
        while (head != null && head.data == key) {
            head = head.next;
        }

        Node curr = head;
        while (curr != null && curr.next != null)
        {
            // if the next node matches the given key, skip it
            if (curr.next.data == key) {
                curr.next = curr.next.next;
            }
            // otherwise, move to the next node
            else {
                curr = curr.next;
            }
        }

        return head;
    }

    public static void main(String[] args)
    {
        int[] nums = { 4, 2, 4, 4, 5, 4, 7, 8 };
        int key = 4;

        // construct linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> null]
        Node head = null;
        for (int i = nums.length - 1; i >= 0; i--) {
            head = new Node(nums[i], head);
        }

        head = removeKey(head, key);
        printList(head);
    }
}
Output:
2 -> 5 -> 7 -> 8 -> null
Python
# A Linked List Node
class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


# Helper function to print a given linked list
def printList(head):
    ptr = head
    while ptr:
        print(ptr.data, end=' -> ')
        ptr = ptr.next
    print('None')


# Function to remove all nodes from a linked list that match a given key
def removeKey(head, key):
    # base case: empty list
    if head is None:
        return None

    # remove matching nodes at the beginning of the linked list
    while head and head.data == key:
        head = head.next

    curr = head
    while curr and curr.next:
        # if the next node matches the given key, skip it
        if curr.next.data == key:
            curr.next = curr.next.next
        # otherwise, move to the next node
        else:
            curr = curr.next

    return head


if __name__ == '__main__':
    nums = [4, 2, 4, 4, 5, 4, 7, 8]
    key = 4

    # construct a linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> None]
    head = None
    for i in reversed(nums):
        head = Node(i, head)

    head = removeKey(head, key)
    printList(head)
Output:
2 -> 5 -> 7 -> 8 -> None
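Because the two iterative variants differ only in how they treat matching nodes at the head, one way to gain confidence that they agree is to compare both against a plain list filter on many small random inputs. The following Python sketch does that. The function names `remove_key_dummy`, `remove_key_skip_head`, `build` and `to_list`, as well as the seeded `random.Random` check, are choices made for this illustration.

```python
import random
from typing import List, Optional


class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


def build(values: List[int]) -> Optional[Node]:
    head = None
    for v in reversed(values):
        head = Node(v, head)
    return head


def to_list(head: Optional[Node]) -> List[int]:
    out = []
    while head:
        out.append(head.data)
        head = head.next
    return out


def remove_key_dummy(head: Optional[Node], key: int) -> Optional[Node]:
    """First variant: a dummy node sits in front of the head."""
    dummy = Node(next=head)
    curr = dummy
    while curr.next:
        if curr.next.data == key:
            curr.next = curr.next.next
        else:
            curr = curr.next
    return dummy.next


def remove_key_skip_head(head: Optional[Node], key: int) -> Optional[Node]:
    """Second variant: drop matching leading nodes, then filter the rest."""
    while head and head.data == key:
        head = head.next
    curr = head
    while curr and curr.next:
        if curr.next.data == key:
            curr.next = curr.next.next
        else:
            curr = curr.next
    return head


rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(1000):
    values = [rng.randint(0, 3) for _ in range(rng.randint(0, 8))]
    key = rng.randint(0, 3)
    expected = [v for v in values if v != key]
    # build a fresh list for each call, because both functions modify their input
    assert to_list(remove_key_dummy(build(values), key)) == expected
    assert to_list(remove_key_skip_head(build(values), key)) == expected
print("both variants match the filtered list")
```

Random small inputs with few distinct values exercise the interesting cases (empty lists, all-matching lists, runs of matches at the head and tail) far more often than hand-picked examples do.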
2. Recursive Solution
The iterative solution above can easily be converted into a recursive one. The idea is to recurse to the end of the linked list and, as the recursion unwinds, check whether the current node's value matches the given key. Since the rest of the list has already been processed at that point, we return the next node if the current node's data matches the key; otherwise, we return the current node. Each node's next pointer is reconnected to the processed remainder of the list as the recursion unwinds.
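The unwinding order is what makes this work: nothing is decided on the way down, and each node is examined only after everything behind it has already been cleaned up. The Python sketch below adds a hypothetical `order` list to the recursive function (named `remove_key_rec` here) to make that order visible.

```python
from typing import List, Optional


class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


def remove_key_rec(head: Optional[Node], key: int, order: List[int]) -> Optional[Node]:
    """Recursive removal; `order` records when each node is decided."""
    if head is None:
        return None
    # process the tail first, so every decision happens while unwinding
    head.next = remove_key_rec(head.next, key, order)
    order.append(head.data)
    return head.next if head.data == key else head


head = None
for v in reversed([4, 2, 4, 4, 5, 4, 7, 8]):
    head = Node(v, head)

order: List[int] = []
head = remove_key_rec(head, 4, order)
print(order)  # [8, 7, 4, 5, 4, 4, 2, 4]: tail first, head last
```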
Following is a simple recursive implementation in C++, Java, and Python:
C++
#include <iostream>
#include <vector>
using namespace std;

// A Linked List Node
struct Node
{
    int data;
    Node* next;
};

// Utility function to push a new node with the given data onto the list's front
void push(Node*& headRef, int data)
{
    Node* node = new Node;
    node->data = data;
    node->next = headRef;
    headRef = node;
}

// Utility function to print a linked list
void printList(Node* head)
{
    Node* curr = head;
    while (curr)
    {
        cout << curr->data << " -> ";
        curr = curr->next;
    }
    cout << "nullptr" << endl;
}

// Function to remove all nodes from a linked list that match a given key
Node* removeKey(Node* head, int key)
{
    // base case
    if (!head) {
        return head;
    }

    // recur for the remaining list and link it to the current node
    head->next = removeKey(head->next, key);

    // if the current node's data matches the given key, free it and return the next node;
    // otherwise, return the current node
    if (head->data == key)
    {
        Node* next = head->next;
        delete head;
        return next;
    }
    return head;
}

int main()
{
    vector<int> nums = { 4, 2, 4, 4, 5, 4, 7, 8 };
    int key = 4;

    // construct linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> nullptr]
    Node* head = nullptr;
    for (int i = nums.size() - 1; i >= 0; i--) {
        push(head, nums[i]);
    }

    head = removeKey(head, key);
    printList(head);

    return 0;
}
Output:
2 -> 5 -> 7 -> 8 -> nullptr
Java
// A Linked List Node
class Node
{
    int data;
    Node next;

    Node(int data, Node next)
    {
        this.data = data;
        this.next = next;
    }

    Node() {}
}

class Main
{
    // Helper function to print a given linked list
    public static void printList(Node head)
    {
        Node ptr = head;
        while (ptr != null)
        {
            System.out.print(ptr.data + " -> ");
            ptr = ptr.next;
        }

        System.out.println("null");
    }

    // Function to remove all nodes from a linked list that match a given key
    public static Node removeKey(Node head, int key)
    {
        // base case
        if (head == null) {
            return null;
        }

        // recur for the remaining list and link it to the current node
        head.next = removeKey(head.next, key);

        // return the next node if the current node's data matches the given key,
        // otherwise, return the current node
        return (head.data == key) ? head.next : head;
    }

    public static void main(String[] args)
    {
        int[] nums = { 4, 2, 4, 4, 5, 4, 7, 8 };
        int key = 4;

        // construct linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> null]
        Node head = null;
        for (int i = nums.length - 1; i >= 0; i--) {
            head = new Node(nums[i], head);
        }

        head = removeKey(head, key);
        printList(head);
    }
}
Output:
2 -> 5 -> 7 -> 8 -> null
Python
# A Linked List Node
class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


# Helper function to print a given linked list
def printList(head):
    ptr = head
    while ptr:
        print(ptr.data, end=' -> ')
        ptr = ptr.next
    print('None')


# Function to remove all nodes from a linked list that match a given key
def removeKey(head, key):
    # base case: empty list
    if head is None:
        return None

    # recur for the remaining list and link it to the current node
    head.next = removeKey(head.next, key)

    # return the next node if the current node's data matches the given key,
    # otherwise, return the current node
    return head.next if head.data == key else head


if __name__ == '__main__':
    nums = [4, 2, 4, 4, 5, 4, 7, 8]
    key = 4

    # construct a linked list [4 -> 2 -> 4 -> 4 -> 5 -> 4 -> 7 -> 8 -> None]
    head = None
    for i in reversed(nums):
        head = Node(i, head)

    head = removeKey(head, key)
    printList(head)
Output:
2 -> 5 -> 7 -> 8 -> None
The time complexity of the above solution is O(n), where n is the length of the linked list. However, the recursive version also needs O(n) call-stack space (one frame per node), so it can overflow the stack on long lists and is not recommended for production use.
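The stack cost is not just theoretical. CPython caps recursion depth (the default value reported by `sys.getrecursionlimit()` is 1000), so the recursive version fails on lists only a few thousand nodes long, while the iterative one handles them with constant stack depth. The minimal sketch below assumes a list twice as long as the current limit; the helper `recursion_fails` is hypothetical.

```python
import sys
from typing import List, Optional


class Node:
    def __init__(self, data=None, next=None):
        self.data = data
        self.next = next


def build(values: List[int]) -> Optional[Node]:
    head = None
    for v in reversed(values):
        head = Node(v, head)
    return head


def remove_key_recursive(head: Optional[Node], key: int) -> Optional[Node]:
    if head is None:
        return None
    head.next = remove_key_recursive(head.next, key)
    return head.next if head.data == key else head


def remove_key_iterative(head: Optional[Node], key: int) -> Optional[Node]:
    dummy = Node(next=head)
    curr = dummy
    while curr.next:
        if curr.next.data == key:
            curr.next = curr.next.next
        else:
            curr = curr.next
    return dummy.next


def recursion_fails(values: List[int], key: int) -> bool:
    """True if the recursive version exceeds Python's recursion limit."""
    try:
        remove_key_recursive(build(values), key)
    except RecursionError:
        return True
    return False


long_values = [i % 3 for i in range(2 * sys.getrecursionlimit())]
print(recursion_fails(long_values, 0))  # True
print(recursion_fails([4, 2, 4], 4))    # False

# the iterative version processes the same input without deep recursion
result = remove_key_iterative(build(long_values), 0)
```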