Thursday, 24 September 2015

Maximum sum submatrix

Problem


Given a matrix containing positive and negative integers, find the submatrix (a contiguous block of rows and columns) with the maximum sum.

Brute force

For an n x n matrix there are O(n^2) possible row ranges and O(n^2) possible column ranges, so there are O(n^4) submatrices in total. Adding up the elements of one submatrix is an O(n^2) operation, so the brute force algorithm has a complexity of O(n^6).
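To make the O(n^6) count concrete, here is a minimal brute force sketch. The class name BruteForceMaxSubMatrix and the method maxSubMatrixSum are illustrative names, not part of the original code, and the sketch returns only the sum, not the submatrix bounds. It is mainly useful as a slow reference to check a faster solution against.

```java
// Illustrative brute force sketch (hypothetical names): try every
// (top, bottom, left, right) and add up each submatrix directly.
class BruteForceMaxSubMatrix
{
 // Returns the maximum sum over all non-empty submatrices of arr.
 static int maxSubMatrixSum(int[][] arr)
 {
  int rows = arr.length;
  int cols = arr[0].length;
  int best = Integer.MIN_VALUE;
  // O(n^2) row ranges times O(n^2) column ranges
  for (int top = 0; top < rows; top++)
   for (int bottom = top; bottom < rows; bottom++)
    for (int left = 0; left < cols; left++)
     for (int right = left; right < cols; right++)
     {
      // O(n^2) work to sum the current submatrix
      int s = 0;
      for (int i = top; i <= bottom; i++)
       for (int j = left; j <= right; j++)
        s += arr[i][j];
      best = Math.max(best, s);
     }
  return best;
 }

 public static void main(String[] args)
 {
  int[][] arr =
  {
  { 1, -2, -7, 0 },
  { -6, 2, 9, 2 },
  { -4, -2, -1, 4 },
  { -1, -8, 0, -4 } };
  System.out.println("max sum: " + maxSubMatrixSum(arr)); // prints "max sum: 14"
 }
}
```

Because best starts at Integer.MIN_VALUE and every submatrix is non-empty, a matrix of only negative numbers returns its largest single element.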

Better Solution


We create a matrix of the same size that keeps the vertical (column-wise cumulative) sums of the original matrix: verticalSum[i][j] = arr[0][j] + arr[1][j] + ... + arr[i][j]. We then fix a row range, from rowStart to rowEnd with rowStart <= rowEnd. Both can be any of the n rows, so there are O(n^2) row ranges. The sum of column j restricted to a row range is verticalSum[rowEnd][j] - verticalSum[rowStart - 1][j] (just verticalSum[rowEnd][j] when rowStart is the first row), so it is available in O(1) time.

For each starting row we first take one row and move from left to right, then two rows and move from left to right, and so on. While moving, we maintain an array sum where sum[j] is the largest sum of a submatrix that covers the selected rows and ends at column j: if sum[j - 1] is positive we extend that submatrix by column j, otherwise we start a new one at column j. This is Kadane's maximum subarray algorithm applied to the column sums, and it takes O(n) time per row range. The total time complexity is therefore O(n^2) row ranges times O(n) columns, i.e. O(n^3), or O(rows^2 * cols) for a non-square matrix.
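The two building blocks described above can be sketched separately: the vertical prefix sums, and one left-to-right Kadane pass for a fixed row range. The names RowRangeKadane, verticalSums, columnSum and bestForRowRange are hypothetical and chosen only for this illustration; unlike the full program in the Code section, this sketch returns just the sum and does not track the submatrix bounds.

```java
// Illustrative sketch (hypothetical names): column-wise prefix sums plus
// Kadane's scan over the column sums of one row range.
class RowRangeKadane
{
 // verticalSum[i][j] = arr[0][j] + arr[1][j] + ... + arr[i][j]
 static int[][] verticalSums(int[][] arr)
 {
  int rows = arr.length, cols = arr[0].length;
  int[][] v = new int[rows][cols];
  for (int i = 0; i < rows; i++)
   for (int j = 0; j < cols; j++)
    v[i][j] = arr[i][j] + (i > 0 ? v[i - 1][j] : 0);
  return v;
 }

 // Sum of column j over rows rowStart..rowEnd in O(1) time.
 static int columnSum(int[][] v, int rowStart, int rowEnd, int j)
 {
  return v[rowEnd][j] - (rowStart > 0 ? v[rowStart - 1][j] : 0);
 }

 // Best sum of a submatrix covering exactly rows rowStart..rowEnd,
 // found with one left-to-right pass (Kadane) over the column sums.
 static int bestForRowRange(int[][] v, int rowStart, int rowEnd)
 {
  int cols = v[0].length;
  int running = columnSum(v, rowStart, rowEnd, 0);
  int best = running;
  for (int j = 1; j < cols; j++)
  {
   int c = columnSum(v, rowStart, rowEnd, j);
   // extend the strip ending at j - 1 only if its sum is positive
   running = running > 0 ? running + c : c;
   best = Math.max(best, running);
  }
  return best;
 }

 public static void main(String[] args)
 {
  int[][] arr =
  {
  { 1, -2, -7, 0 },
  { -6, 2, 9, 2 },
  { -4, -2, -1, 4 },
  { -1, -8, 0, -4 } };
  int[][] v = verticalSums(arr);
  // rows 1..2 give column sums -10, 0, 8, 6, so the best strip sums to 14
  System.out.println(bestForRowRange(v, 1, 2)); // prints 14
  // O(n^2) row ranges, each scanned in O(n): O(n^3) overall
  int best = Integer.MIN_VALUE;
  for (int top = 0; top < arr.length; top++)
   for (int bottom = top; bottom < arr.length; bottom++)
    best = Math.max(best, bestForRowRange(v, top, bottom));
  System.out.println(best); // prints 14
 }
}
```

Keeping the prefix sums in a separate matrix is what makes each column sum an O(1) subtraction instead of an O(n) loop over the selected rows.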


Code

public class LargestSumSubMatrix
{
 public static void main(String[] args)
 {
  int[][] arr =
  {
  { 1, -2, -7, 0 },
  { -6, 2, 9, 2 },
  { -4, -2, -1, 4 },
  { -1, -8, 0, -4 } };
  int[] leftRightTopBottom = new int[4];
  int maxsum = findMaximumSumSubMatrix(arr, leftRightTopBottom);
  System.out.println("max sum: " + maxsum);
  System.out.println("indices left right top bottom");
  for (int index : leftRightTopBottom)
   System.out.print(index + ",");
 }

 // Returns the maximum submatrix sum and stores its bounds in
 // leftRightTopBottom as {left column, right column, top row, bottom row}.
 private static int findMaximumSumSubMatrix(int[][] arr,
   int[] leftRightTopBottom)
 {
  leftRightTopBottom[0] = 0;
  leftRightTopBottom[1] = 0;
  leftRightTopBottom[2] = 0;
  leftRightTopBottom[3] = 0;
  int rows = arr.length;
  int cols = arr[0].length;
  // sum[j]: best sum of a strip over the current row range ending at column j
  int[] sum = new int[cols];
  // pos[j]: starting column of the strip whose sum is sum[j]
  int[] pos = new int[cols];
  int localMax;
  // the 1x1 submatrix at (0,0) is a valid starting candidate
  int maxSum = arr[0][0];
  // verticalSum[i][j] = arr[0][j] + arr[1][j] + ... + arr[i][j]
  int[][] verticalSum = new int[rows][cols];

  for (int iRow = 0; iRow < rows; iRow++)
  {
   for (int jCol = 0; jCol < cols; jCol++)
   {
    if (iRow == 0)
    {
     verticalSum[iRow][jCol] = arr[iRow][jCol];
    } else
    {
     verticalSum[iRow][jCol] = arr[iRow][jCol]
       + verticalSum[iRow - 1][jCol];
    }
   }
  }

  // iRow is the top row and k the bottom row of the current row range
  for (int iRow = 0; iRow < rows; iRow++)
  {
   for (int k = iRow; k < rows; k++)
   {
    // Kadane's algorithm over the column sums of rows iRow..k;
    // each column sum costs O(1) thanks to verticalSum
    localMax = 0;
    int tmp = 0;
    if (iRow > 0)
    {
     tmp = verticalSum[iRow - 1][0];
    }
    sum[0] = verticalSum[k][0] - tmp;
    pos[0] = 0;
    for (int j = 1; j < cols; j++)
    {
     tmp = 0;
     if (iRow > 0)
     {
      tmp = verticalSum[iRow - 1][j];
     }
     if (sum[j - 1] > 0)
     {
      // a positive strip ending at j - 1 is worth extending
      sum[j] = sum[j - 1] + verticalSum[k][j] - tmp;
      pos[j] = pos[j - 1];
     } else
     {
      // otherwise start a new strip at column j
      sum[j] = verticalSum[k][j] - tmp;
      pos[j] = j;
     }
     if (sum[j] > sum[localMax])
     {
      localMax = j;
     }
    }
    // record the best strip of this row range if it beats the global best
    if (sum[localMax] > maxSum)
    {
     maxSum = sum[localMax];
     leftRightTopBottom[0] = pos[localMax];
     leftRightTopBottom[1] = localMax;
     leftRightTopBottom[2] = iRow;
     leftRightTopBottom[3] = k;
    }
   }
  }
  return maxSum;
 }
}
