Find K Pairs with Smallest Sums

May 9, 2026
Medium

Problem Statement

You are given two integer arrays nums1 and nums2 sorted in non-decreasing order and an integer k.

Define a pair (u, v) as one element u from the first array and one element v from the second array. Return the k pairs (u1, v1), (u2, v2), …, (uk, vk) with the smallest sums.

Examples

  1. For arrays [1, 7, 11] and [2, 4, 6] with k = 3, the 3 pairs with the smallest sums are [[1, 2], [1, 4], [1, 6]].
  2. For arrays [1, 1, 2] and [1, 2, 3] with k = 2, the 2 pairs with the smallest sums are [[1, 1], [1, 1]].
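A brute-force reference makes the expected outputs easy to check: enumerate every pair, sort by sum, and keep the first k. This is a minimal sketch for small inputs only (it costs O(mn log mn) for arrays of lengths m and n); the function name `k_smallest_pairs_brute` is hypothetical and not part of the solution.

```python
from itertools import product
from typing import List


def k_smallest_pairs_brute(nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
    """Hypothetical reference: list all pairs, sort by sum, take the first k."""
    # sorted() is stable, so pairs with equal sums keep their product() order.
    pairs = sorted(product(nums1, nums2), key=lambda p: p[0] + p[1])
    return [list(p) for p in pairs[:k]]


print(k_smallest_pairs_brute([1, 7, 11], [2, 4, 6], 3))  # [[1, 2], [1, 4], [1, 6]]
print(k_smallest_pairs_brute([1, 1, 2], [1, 2, 3], 2))  # [[1, 1], [1, 1]]
```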

Solution

Points of Interest

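  1. The smallest-sum pair is always the index pair (0, 0), because both arrays are sorted in non-decreasing order.
  2. Suppose we have found the $m$ smallest pairs. Then the $(m + 1)^{th}$ smallest index pair is either (u + 1, v) or (u, v + 1) for some index pair (u, v) among those $m$; it need not be a neighbor of the $m^{th}$ pair itself. In example 1, the 4th smallest pair, index (1, 0) with sum 9, is a neighbor of (0, 0) rather than of the 3rd pair (0, 2). This is why every such neighbor is kept as a candidate in a heap.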
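The second point can be checked by brute force on small inputs. This sketch (the helper name `successor_property_holds` is hypothetical) ranks all index pairs by (sum, i, j), breaking ties by index so the order is deterministic, and confirms that the first pair is (0, 0) and every later pair is a right or down neighbor of some earlier one. It assumes both arrays are sorted; an unsorted input can fail the check.

```python
from itertools import product
from typing import List, Tuple


def successor_property_holds(nums1: List[int], nums2: List[int]) -> bool:
    """Hypothetical checker for 'each next pair neighbours an already-chosen pair'."""
    order: List[Tuple[int, int]] = sorted(
        product(range(len(nums1)), range(len(nums2))),
        key=lambda ij: (nums1[ij[0]] + nums2[ij[1]], ij[0], ij[1]),
    )
    if not order:
        return True
    if order[0] != (0, 0):
        return False  # only possible when an array is not sorted
    chosen = {order[0]}
    for i, j in order[1:]:
        # (i, j) must equal (u + 1, v) or (u, v + 1) for some chosen (u, v).
        if (i - 1, j) not in chosen and (i, j - 1) not in chosen:
            return False
        chosen.add((i, j))
    return True


print(successor_property_holds([1, 7, 11], [2, 4, 6]))  # True
print(successor_property_holds([5, 1], [0]))            # False (nums1 unsorted)
```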

Algo

We will use a min-heap as a priority queue of candidate index pairs, keyed by the sum of each pair.
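In Python, the min-heap comes from the standard `heapq` module, which keeps an ordinary list in heap order. Because tuples compare element by element, entries of the form (sum, i, j) pop in order of sum, with ties broken by i and then j. The entries in this small sketch are illustrative values, not taken from a real run.

```python
import heapq

# heapq maintains a binary min-heap inside a plain list.
heap = []
for entry in [(9, 1, 0), (3, 0, 0), (5, 0, 1), (5, 1, 0)]:
    heapq.heappush(heap, entry)

# Pop everything: smallest sum first, equal sums ordered by i, then j.
popped = [heapq.heappop(heap) for _ in range(len(heap))]
print(popped)  # [(3, 0, 0), (5, 0, 1), (5, 1, 0), (9, 1, 0)]
```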

  1. Initialize an empty results array first.
  2. Initialize a visited set of index pairs seen so far, so that no index pair is pushed onto the heap twice. Add (0, 0) to it, since that pair is pushed first.
  3. Next, initialize the min-heap with the single entry (nums1[0] + nums2[0], 0, 0); every entry has the form (sum of pair, index into nums1, index into nums2).
  4. Loop until results holds k pairs or the heap is empty.
    1. Pop the smallest entry from the heap, giving its sum and index pair (i, j).
    2. Append the pair [nums1[i], nums2[j]] to results.
    3. For each of (i + 1, j) and (i, j + 1) that lies within bounds and is not yet in visited, push it onto the heap and add it to visited.
  5. results is the desired output.
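To watch steps 1 to 5 in motion, this sketch runs them on example 1 and logs, after each iteration, the popped index pair, its sum, and the heap contents (sorted for readability). The function name `trace_k_smallest_pairs` and its log format are hypothetical, added only for illustration.

```python
import heapq
from typing import List, Tuple

Entry = Tuple[int, int, int]  # (sum, i, j)
LogRow = Tuple[Tuple[int, int], int, List[Entry]]


def trace_k_smallest_pairs(
    nums1: List[int], nums2: List[int], k: int
) -> Tuple[List[List[int]], List[LogRow]]:
    """Hypothetical tracer: the algorithm's steps plus a per-iteration log."""
    results: List[List[int]] = []                      # step 1
    log: List[LogRow] = []
    if not nums1 or not nums2:
        return results, log
    visited = {(0, 0)}                                 # step 2
    heap: List[Entry] = [(nums1[0] + nums2[0], 0, 0)]  # step 3
    while heap and len(results) < k:                   # step 4
        s, i, j = heapq.heappop(heap)                  # step 4.1
        results.append([nums1[i], nums2[j]])           # step 4.2
        for ni, nj in ((i + 1, j), (i, j + 1)):        # step 4.3
            if ni < len(nums1) and nj < len(nums2) and (ni, nj) not in visited:
                visited.add((ni, nj))
                heapq.heappush(heap, (nums1[ni] + nums2[nj], ni, nj))
        log.append(((i, j), s, sorted(heap)))
    return results, log                                # step 5


results, log = trace_k_smallest_pairs([1, 7, 11], [2, 4, 6], 3)
for popped, s, frontier in log:
    print(popped, s, frontier)
# (0, 0) 3 [(5, 0, 1), (9, 1, 0)]
# (0, 1) 5 [(7, 0, 2), (9, 1, 0), (11, 1, 1)]
# (0, 2) 7 [(9, 1, 0), (11, 1, 1), (13, 1, 2)]
print(results)  # [[1, 2], [1, 4], [1, 6]]
```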

Code

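import heapq
from typing import List


def kSmallestPairs(nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
    results = []
    if not nums1 or not nums2 or k <= 0:
        return results  # no pairs can be formed

    # (0, 0) is always the smallest pair, so seed the heap with it.
    visited = {(0, 0)}
    heap = [(nums1[0] + nums2[0], 0, 0)]  # entries are (sum, i, j)

    # Stop after k pairs, or earlier if every pair has already been emitted.
    while heap and len(results) < k:
        _, i, j = heapq.heappop(heap)
        results.append([nums1[i], nums2[j]])

        # Candidate: next element of nums1, same element of nums2.
        if i + 1 < len(nums1):
            tup = (i + 1, j)
            if tup not in visited:
                heapq.heappush(heap, (nums1[i + 1] + nums2[j], i + 1, j))
                visited.add(tup)

        # Candidate: same element of nums1, next element of nums2.
        if j + 1 < len(nums2):
            tup = (i, j + 1)
            if tup not in visited:
                heapq.heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
                visited.add(tup)

    return results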

Complexity

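$\textbf{Time Complexity} = O(k \log k)$: each of the at most k iterations performs one pop and at most two pushes on a heap that never holds more than k + 1 entries.

$\textbf{Space Complexity} = O(k)$: the heap holds at most k + 1 entries and visited holds at most 2k + 1 index pairs.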
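The bounds behind these complexities can be spot-checked empirically. This sketch reruns the heap loop on random sorted inputs and asserts that the heap never exceeds k + 1 entries and that visited never exceeds 2k + 1 index pairs. The helper name `max_heap_and_visited` and the random test sizes are assumptions for illustration; passing random trials supports the argument but is not a proof.

```python
import heapq
import random
from typing import List, Tuple


def max_heap_and_visited(nums1: List[int], nums2: List[int], k: int) -> Tuple[int, int]:
    """Hypothetical probe: peak heap size and final visited size for one run."""
    visited = {(0, 0)}
    heap = [(nums1[0] + nums2[0], 0, 0)]
    popped, peak = 0, len(heap)
    while heap and popped < k:
        _, i, j = heapq.heappop(heap)
        popped += 1
        for ni, nj in ((i + 1, j), (i, j + 1)):
            if ni < len(nums1) and nj < len(nums2) and (ni, nj) not in visited:
                visited.add((ni, nj))
                heapq.heappush(heap, (nums1[ni] + nums2[nj], ni, nj))
        # Sizes only grow during the pushes, so measuring here captures the peak.
        peak = max(peak, len(heap))
    return peak, len(visited)


random.seed(0)
for _ in range(200):
    nums1 = sorted(random.randint(-20, 20) for _ in range(random.randint(1, 8)))
    nums2 = sorted(random.randint(-20, 20) for _ in range(random.randint(1, 8)))
    k = random.randint(1, len(nums1) * len(nums2))
    peak, seen = max_heap_and_visited(nums1, nums2, k)
    # One pop and at most two pushes per iteration: the heap grows by at most
    # one entry per iteration and visited by at most two.
    assert peak <= k + 1 and seen <= 2 * k + 1
print("bounds hold")
```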