An interesting problem that I encountered recently: given an array of numbers, find 4 elements that sum to a particular value \(S\).

Easy enough, the way to solve it would be:

#!/usr/bin/env python

a = [0,8,12,13,16,23,26,30,37,39,55,59,61,62,63,70,89]
result = [(h,i,j,k) for h,ah in enumerate(a)
                    for i,ai in enumerate(a)
                    for j,aj in enumerate(a)
                    for k,ak in enumerate(a)
                    if ah+ai+aj+ak == 8+23+59+70 and h<i<j<k]
print(result)

Obviously, it works. But this is an \(O(N^4)\) algorithm for an array of \(N\) elements.
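For reference, the same brute force can be written more compactly with `itertools.combinations`, which only generates index tuples that are already in ascending order. This is just a sketch: the function name `four_sum_brute_force` is made up for illustration, and the running time is still \(O(N^4)\), only with a smaller constant.

```python
from itertools import combinations

def four_sum_brute_force(a, target):
    """Return every ascending index tuple (h, i, j, k) whose elements sum to target."""
    return [idx for idx in combinations(range(len(a)), 4)
            if sum(a[i] for i in idx) == target]

a = [0, 8, 12, 13, 16, 23, 26, 30, 37, 39, 55, 59, 61, 62, 63, 70, 89]
result = four_sum_brute_force(a, 8 + 23 + 59 + 70)
print(result)
```

The tuple of indices for 8, 23, 59 and 70, that is (1, 5, 11, 15), is among the results, possibly together with other quadruples that happen to have the same sum.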

So how can we improve this? There are two approaches to this problem: one is to attack it directly, as above; the other is to devise a dynamic programming algorithm. Due to time constraints, I did the former. But the latter would actually be the better approach, as explained below.

Given an array of \(N\) elements, find \(k\) elements that sum to \(S\). If the array is sorted and \(k=1\), the problem reduces to binary search.
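The \(k=1\) case can be sketched with Python's standard `bisect` module. The helper name `find_one` is hypothetical, and the list is assumed to be sorted in ascending order.

```python
from bisect import bisect_left

def find_one(a, target):
    """Return an index i with a[i] == target in the sorted list a, or -1 if absent."""
    i = bisect_left(a, target)          # leftmost position where target could be inserted
    return i if i < len(a) and a[i] == target else -1

a = [0, 8, 12, 13, 16, 23, 26, 30, 37, 39, 55, 59, 61, 62, 63, 70, 89]
print(find_one(a, 59), find_one(a, 60))
```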

If \(k=2\), still assuming the array is sorted, we try to mimic binary search and start with the two cursors (marked below) in the middle of the array.

[ ][ ][ ][ ][ ][x][y][ ][ ][ ][ ][ ]

Then, if the sum of the two elements is greater than \(S\), we must move some cursor to the left; otherwise, to the right. But since we can enforce the ordering that cursor \(x\) must be to the left of cursor \(y\), only cursor \(x\) can move to the left and only cursor \(y\) can move to the right. Assuming \(a[x]+a[y] > S\), we move cursor \(x\) to the midpoint of the left part as follows:

[ ][ ][x]( )( )( )(y)( )( )( )( )( )

Then we can binary search for cursor \(y\) over the parenthesized elements. If that fails to find a solution, we can, based on the sum \(a[x]+a[y]\) as depicted above, move \(x\) further to the next midpoint, just as in binary search.

The code is as follows:

#include <stdio.h>
#include <stdlib.h>

/* Sum of the elements in a that is pointed by the cursors */
int sum(int* a, int* cursor, int numCursor)
{
	int i;
	int s = 0;
	for (i = 0; i<numCursor; ++i) {
		s += a[cursor[i]];
	};
	return s;
}

/* Binary search from a[begin] to a[end], inclusive, for target */
int search(int* a, int begin, int end, int target)
{
	int i;
	/* Sanity check goes first */
	if (begin > end) return -255;
	/* Simple case: just one item */
	if (begin == end) return (a[begin] == target)? begin : -1;
	/* Recursive search */
	i = (begin+end)/2;
	if (a[i] == target) {
		return i;
	} else if (a[i] > target) {
		return search(a, begin, i-1, target);
	} else /* if (a[i] < target) */ {
		return search(a, i+1, end, target);
	}
}

/* Recursive binary search from a[begin] to a[end], inclusive, for k items sum up to target */
int cursorSearch(int* a, int begin, int end, int* cursor, int k, int target)
{ 
	int s, ret;
	int myVal, myMin, myMax, i;
	/* Sanity checks */
	if (begin > end) return -255;     /* array must be of positive length */
	if (end-begin < k-1) return -255; /* array length (end-begin+1) must not be shorter than cursor count */
	for (i=1; i<k; ++i) {           /* cursors must be adjacent */
		if (cursor[i] != cursor[i-1] + 1) {
			return -255;
		};
	};
	/* Terminal case: Only one cursor provided, reduced to simple binary search */
	if (k == 1) {
		ret = search(a, begin, end, target);
		if (ret >= begin && ret <= end) {
			cursor[0] = ret;
			return 0; /* Solution found */
		} else {
			return -1; /* Solution not found */
		}
	}
	/* General case: More than one cursor */
	s = sum(a, cursor, k);
	if (s == target) {
		return 0; /* Solution found */
	} else if (s > target) {
		/* Some cursor has to move down, but to keep the cursors in order,
		 * only cursor[0] can be moved */
		myMin = begin;
		myMax = cursor[0];
		/* Loop-based binary search */
		while (myMin <= myMax) {
			myVal = (myMin + myMax)/2;
			/* Fix cursor[0], and search for target minus a[cursor[0]] using the
			 * remaining cursors. The remaining cursors shall be invariant in this
			 * function if ret == 0. */
			ret = cursorSearch(a, myVal+1, end, cursor+1, k-1, target - a[myVal]);
			if (ret == 0) {
				cursor[0] = myVal;
				return 0;
			} else if (s - a[cursor[0]] + a[myVal] > target) {
				myMax = myVal - 1; /* Too large */
			} else /* if (s - a[cursor[0]] + a[myVal] < target) */ {
				myMin = myVal + 1; /* Too small */
			}
		};
		return -1; /* Solution not found */
	} else /* if (s < target) */ {
		/* Some cursor has to move up, but to keep the cursors in order,
		 * only cursor[k-1] can be moved */
		myMin = cursor[k-1];
		myMax = end;
		/* Loop-based binary search */
		while (myMin <= myMax) {
			myVal = (myMin + myMax)/2;
			/* Fix cursor[k-1], and search for target minus a[cursor[k-1]] using the
			 * remaining cursors. The remaining cursors shall be invariant in this
			 * function if ret == 0. */
			ret = cursorSearch(a, begin, myVal-1, cursor, k-1, target - a[myVal]);
			if (ret == 0) {
				cursor[k-1] = myVal;
				return 0;
			} else if (s - a[cursor[k-1]] + a[myVal] > target) {
				myMax = myVal - 1; /* Too large */
			} else /* if (s - a[cursor[k-1]] + a[myVal] < target) */ {
				myMin = myVal + 1; /* Too small */
			}
		};
		return -1; /* Solution not found */
	};
}

If the code worked, it would be an \(O((\log N)^k)\) algorithm to find a solution. But I found that it failed some of the unit tests. The problem is that if the solution is

[ ][ ][ ][ ][x][ ][ ][y][ ][ ][ ][ ]

but the cursors are at the positions

[ ][ ][ ][ ][ ][x][y][ ][ ][ ][ ][ ]

then the code above cannot find it. This means it cannot explore all the possible combinations to find a solution.

The pitfall is that, while I assumed this was a binary search, it is not! When I fix cursor \(x\) and look for a position of cursor \(y\) that gives a solution, different positions of \(y\) can make the sum larger or smaller than \(S\). If every choice of \(y\) yields a sum larger than \(S\), we know we must lower \(x\). But if the sum is sometimes larger and sometimes smaller, we have no clue where to move \(x\)! So the only way to explore all the possibilities is to binary search on just one cursor and exhaust the list for all the other cursors. This makes the algorithm \(O(N^{k-1}\log N)\).
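The fallback with that complexity, exhausting \(k-1\) cursors and binary searching only the last one, might look like the following sketch. It is not the code from this post: the name `k_sum_exhaustive` is made up, and the list is assumed to be sorted in ascending order.

```python
from bisect import bisect_left
from itertools import combinations

def k_sum_exhaustive(a, k, target):
    """Return k ascending indices into the sorted list a whose elements sum to target, or None."""
    n = len(a)
    if k < 1 or k > n:
        return None
    # Enumerate every ascending placement of the first k-1 cursors: O(N^(k-1)) choices.
    for head in combinations(range(n), k - 1):
        rest = target - sum(a[i] for i in head)
        lo = head[-1] + 1 if head else 0      # the last cursor must sit after the others
        j = bisect_left(a, rest, lo, n)       # binary search for the last cursor: O(log N)
        if j < n and a[j] == rest:
            return head + (j,)
    return None

a = [0, 8, 12, 13, 16, 23, 26, 30, 37, 39, 55, 59, 61, 62, 63, 70, 89]
print(k_sum_exhaustive(a, 4, 8 + 23 + 59 + 70))
```

Because every placement of the first \(k-1\) cursors is tried, no solution can be missed, which is exactly what the pure binary search version could not guarantee.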

In fact, binary search fails here because it expects to find an exact match. We can, however, carry out the binary search in a different way to find the solutions.

Let’s begin with \(k=2\) again, but with all the cursors at the front of the array:

[x][y][ ][ ][ ][ ][ ][ ][ ][ ][ ][ ]

This combination gives the smallest sum of two distinct elements in the array. If this smallest sum is still greater than \(S\), there is no solution. With this fact in mind, consider the combination

[x][ ][ ][ ][ ][ ][ ][ ][y][ ][ ][ ]

If the sum of this combination is greater than \(S\), we must move \(y\) to a lower position, since otherwise we cannot lower the sum to \(S\). But if it is smaller than \(S\), we can raise either \(y\) or \(x\).

Therefore, we have another algorithm:

  • Move cursor y from its initial position at the front of the array to the middle of the array, and onward in a binary search fashion, to look for the position at which it yields a sum that is just smaller than \(S\)
  • Then, from that position, move y backward toward its original location
  • At each step of the move, recursively solve the subproblem that involves the partial array below the position of y and the cursors before y, for a sum of \(S\) minus the element that y points to

Well said. The code for this is as follows:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Sum of the elements in a that is pointed by the cursors */
int sum(int* a, int* cursor, int numCursor)
{
	int i;
	int s = 0;
	for (i = 0; i<numCursor; ++i) {
		s += a[cursor[i]];
	};
	return s;
}

/* Binary search from a[begin] to a[end], inclusive, for target */
int search(int* a, int begin, int end, int target)
{
	int i;
	/* Sanity check goes first */
	if (begin > end) return -255;
	/* Simple case: just one item */
	if (begin == end) return (a[begin] == target)? begin : -1;
	/* Recursive search */
	i = (begin+end)/2;
	if (a[i] == target) {
		return i;
	} else if (a[i] > target) {
		return search(a, begin, i-1, target);
	} else /* if (a[i] < target) */ {
		return search(a, i+1, end, target);
	}
}

/* Recursive binary search from a[0] to a[n], inclusive, for k items sum up to target */
int recurSearch(int* a, int n, int* cursor, int k, int target)
{ 
	int s, ret = -1; /* default to "not found" so ret is never returned uninitialized */
	int myMin, myMax, i;
	char* bitmap;
	/* Sanity checks */
	if (n < 0) return -255;   /* array must be of positive length */
	if (n < k-1) return -255; /* array length (n+1) must not be shorter than cursor count */
	/* Terminal case: Only one cursor provided, reduced to simple binary search */
	if (k == 1) {
		ret = search(a, 0, n, target);
		if (ret >= 0 && ret <= n) {
			cursor[0] = ret;
			return 0; /* Solution found */
		} else {
			return -1; /* Solution not found */
		}
	}
	/* General case: More than one cursor */
	bitmap = (char*)malloc(n+1); /* create bitmap */
	memset(bitmap, 0, n+1); /* initialize bitmap; memset is declared in <string.h> */
	for (i=0; i<k; ++i) cursor[i]=i; /* initialize cursors */
	s = sum(a, cursor, k);
	if (s == target) {
		ret = 0; /* Too easy: Solution found */
	} else if (s > target) {
		ret = -1; /* Solution cannot be found in this sorted array */
	} else if (s < target && n > k-1) {
		/* Some cursor have to move up */
		s = sum(a, cursor, k-1);
		myMin = k;
		myMax = n;
		/* Loop-based binary search for the marginal position between too big and too small */
		while (myMin <= myMax) {
			i = (myMin + myMax)/2;
			/* The easy edge cases */
			if (a[i] > target - s) { /* too big to be a solution */
				myMax = i-1;
				continue;
			} else if (a[i] == target - s) { /* Solution found by just moving the last cursor */
				cursor[k-1] = i;
				ret = 0;
				break;
			};
			/* The case to recur */
			bitmap[i] = 1; /* Mark this position in the bitmap for future reference */
			ret = recurSearch(a, i-1, cursor, k-1, target-a[i]);
			if (ret == 0) { /* solution found in recursive search */
				cursor[k-1] = i;
				break;
			};
			myMin = i + 1; /* Next try have to move up the lowerbound */
		};
		/* Perform linear search if solution is not yet found */
		if (sum(a, cursor, k) != target) {
			printf("Linear search started on cursor %d\n",k-1); /* debug trace */
			for (; i >= k; --i) {
				if (bitmap[i] == 1) continue; /* Skip those already examined */
				ret = recurSearch(a, i-1, cursor, k-1, target-a[i]);
				if (ret == 0) {
					cursor[k-1] = i;
					break; /* solution found in recursive search */
				};
			};
		};
	};
	free(bitmap);
	return ret;
}

This code works. The full source that includes unit testing is here.
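As a compact illustration of the same strategy for the special case \(k=2\), here is a Python sketch. It is not a translation of the C code above: the name `two_sum_sorted` is made up, the list is assumed to be sorted in ascending order, and the bitmap bookkeeping is left out because with two cursors each position of y is visited only once.

```python
from bisect import bisect_left, bisect_right

def two_sum_sorted(a, target):
    """Return (x, y) with x < y and a[x] + a[y] == target in the sorted list a, or None."""
    n = len(a)
    if n < 2 or a[0] + a[1] > target:
        return None                    # even the smallest pair is too large
    # Binary search for the highest y such that a[0] + a[y] <= target.
    y_max = bisect_right(a, target - a[0], 1, n) - 1
    # Walk y backward; for each y, binary search x in a[0..y-1] for target - a[y].
    for y in range(y_max, 0, -1):
        want = target - a[y]
        x = bisect_left(a, want, 0, y)
        if x < y and a[x] == want:
            return (x, y)
    return None

a = [0, 8, 12, 13, 16, 23, 26, 30, 37, 39, 55, 59, 61, 62, 63, 70, 89]
print(two_sum_sorted(a, 8 + 23))
```

Positions of y above `y_max` are never examined, since there even the smallest possible partner `a[0]` already makes the sum exceed the target.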

Now consider its efficiency, assuming the array is already sorted so that sorting costs nothing. The binary search on a cursor takes \(O(\log h)\) time, followed by another \(O(m)\) time for the linear search, where \(h\) is the length of the subarray bounded by the following cursor and \(m\) is the length of the subarray found by the binary search. Although we do not scan through the whole array for each cursor, we come close to doing so. The average complexity would be \(O(N^{k-1}\log N/k!)\).

Unsorted array

If the array is unsorted and is never to be sorted, here is a method to find the solution using a hash table. Since we are looking for 4 elements that sum to \(S\), we can split them into 2+2 elements, like this:

#!/usr/bin/env python

a = [0,8,12,13,16,23,26,30,37,39,55,59,61,62,63,70,89]
S = 8+23+59+70
hash = dict( (ii+jj, (i,j)) for i,ii in enumerate(a) for j,jj in enumerate(a) if i<j )
result = [hash[k]+hash[S-k] for k in hash.keys()
             if k < S/2 and                  # search only for half of the list
                S-k in hash and              # the key condition to make the result
                hash[k][-1] < hash[S-k][0]]  # limit result to only ascending order
print(result)

The hash is computed in \(O(N^2)\) time, and the result is obtained by scanning the hash once, which is also \(O(N^2)\) time due to its size (assuming constant-time hash lookups). So this is an \(O(N^2)\) algorithm to find the 4 elements in an unsorted array. Similarly, for even \(k\), we can find \(k\) elements that add up to \(S\) in \(O(N^{k/2})\) time using this method.
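One caveat of the dictionary above is that it keeps only one index pair per sum and only accepts results where the first pair lies entirely before the second, so some solutions can be missed. A variant that stores every pair per sum and checks that the four indices are distinct could be sketched as follows. The name `four_sum_pairs` is hypothetical, and the worst case is slower than \(O(N^2)\) when many pairs share the same sum.

```python
from collections import defaultdict
from itertools import combinations

def four_sum_pairs(a, target):
    """Return 4 distinct indices (ascending) whose elements sum to target, or None."""
    pairs_by_sum = defaultdict(list)
    for i, j in combinations(range(len(a)), 2):      # all O(N^2) index pairs
        pairs_by_sum[a[i] + a[j]].append((i, j))
    for s, pairs in pairs_by_sum.items():
        # .get() avoids inserting new keys into the defaultdict while iterating over it
        for q in pairs_by_sum.get(target - s, ()):
            for p in pairs:
                if len({*p, *q}) == 4:               # the two pairs must not share an index
                    return tuple(sorted(p + q))
    return None

a = [70, 8, 59, 23, 1]   # deliberately unsorted
print(four_sum_pairs(a, 8 + 23 + 59 + 70))
```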