USACO CHN07 Problem 'baabo' Analysis

by Richard Peng

Since the pairings cannot intersect, this problem can be solved using DP whose state is the last pair of accordionist and banjoist (denoted A and B from here on). There are O(n^2) states, and each state has O(n^2) transitions, one for every possible previous pair. The transition function is:

Best(i,j)=max{i1<i,j1<j|Best(i1,j1)-SQR(sum(A_(i1+1)...A_(i-1)))-SQR(sum(B_(j1+1)...B_(j-1)))}+A_i*B_j

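where SQR is the square of a number; the cows strictly between two consecutive pairs are the ones left unpaired. If (i,j) is the first pair, every earlier cow is unpaired, so the max also includes the option -SQR(sum(A_1...A_(i-1)))-SQR(sum(B_1...B_(j-1))). The final answer is the maximum over all (i,j) of Best(i,j)-SQR(sum(A_(i+1)...A_n))-SQR(sum(B_(j+1)...B_n)), which charges for the cows after the last pair.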
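A direct transcription of the recurrence is handy as a reference when testing faster versions. This is a minimal sketch, not the original author's code: solveBrute is a hypothetical name, and it assumes 1-indexed cows, equal numbers of A and B cows, nonnegative integer values, and squared sums that fit in a long long.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// O(n^4) evaluation of the Best(i,j) recurrence (illustrative sketch).
// Assumes A.size() == B.size() >= 1, nonnegative integer values, and
// squared prefix sums that fit in a long long.
long long solveBrute(const std::vector<long long>& A, const std::vector<long long>& B) {
    assert(A.size() == B.size() && !A.empty());
    const int n = static_cast<int>(A.size());
    // SA[k] = A_1 + ... + A_k and SB[k] = B_1 + ... + B_k, with SA[0] = SB[0] = 0.
    std::vector<long long> SA(n + 1, 0), SB(n + 1, 0);
    for (int k = 1; k <= n; ++k) {
        SA[k] = SA[k - 1] + A[k - 1];
        SB[k] = SB[k - 1] + B[k - 1];
    }
    auto sq = [](long long x) { return x * x; };
    // best[i][j] = Best(i,j): best profit up to and including the pair (i,j).
    std::vector<std::vector<long long>> best(n + 1, std::vector<long long>(n + 1, 0));
    long long answer = LLONG_MIN;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            // (i,j) as the first pair: every earlier cow is unpaired.
            long long b = -sq(SA[i - 1]) - sq(SB[j - 1]);
            // Every possible previous pair (i1,j1).
            for (int i1 = 1; i1 < i; ++i1)
                for (int j1 = 1; j1 < j; ++j1)
                    b = std::max(b, best[i1][j1] - sq(SA[i - 1] - SA[i1]) - sq(SB[j - 1] - SB[j1]));
            best[i][j] = b + A[i - 1] * B[j - 1];
            // (i,j) as the last pair: every later cow is unpaired.
            answer = std::max(answer, best[i][j] - sq(SA[n] - SA[i]) - sq(SB[n] - SB[j]));
        }
    }
    return answer;
}
```

For example, with A={1,100} and B={100,1} the best choice is the single pair (A_2,B_1), worth 100*100-SQR(1)-SQR(1)=9998, and solveBrute returns that value.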

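This gives an O(n^4) solution, which we will optimize all the way down to O(n^2). The first observation is that, when going from the current pair back to the previous one, we never need to skip cows in both A and B. If cows were skipped on both sides, pairing one skipped A cow with one skipped B cow would add a positive product, and splitting a skipped run into two shorter runs cannot increase the squared penalty, so the total profit would increase. Each state therefore only needs O(n) transitions (either i1=i-1 or j1=j-1), which gives an O(n^3) solution, sufficient to get 60% of the points.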
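Restricting the transitions this way can be sketched as follows. This is an illustrative version, not the original author's code: solveCubic is a hypothetical name, and it assumes 1-indexed cows, equal numbers of A and B cows, nonnegative integer values, and squared sums that fit in a long long.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// O(n^3) DP sketch: the previous pair of (i,j) is either (i-1,j1) or (i1,j-1).
// Assumes A.size() == B.size() >= 1, nonnegative integer values, no overflow.
long long solveCubic(const std::vector<long long>& A, const std::vector<long long>& B) {
    assert(A.size() == B.size() && !A.empty());
    const int n = static_cast<int>(A.size());
    // SA[k], SB[k]: prefix sums of the first k cows, SA[0] = SB[0] = 0.
    std::vector<long long> SA(n + 1, 0), SB(n + 1, 0);
    for (int k = 1; k <= n; ++k) {
        SA[k] = SA[k - 1] + A[k - 1];
        SB[k] = SB[k - 1] + B[k - 1];
    }
    auto sq = [](long long x) { return x * x; };
    std::vector<std::vector<long long>> best(n + 1, std::vector<long long>(n + 1, 0));
    long long answer = LLONG_MIN;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            // No previous pair: every earlier cow is unpaired.
            long long b = -sq(SA[i - 1]) - sq(SB[j - 1]);
            // Previous pair (i-1,j1): no A cow skipped, B cows j1+1..j-1 skipped.
            if (i > 1)
                for (int j1 = 1; j1 < j; ++j1)
                    b = std::max(b, best[i - 1][j1] - sq(SB[j - 1] - SB[j1]));
            // Previous pair (i1,j-1): no B cow skipped, A cows i1+1..i-1 skipped.
            if (j > 1)
                for (int i1 = 1; i1 < i; ++i1)
                    b = std::max(b, best[i1][j - 1] - sq(SA[i - 1] - SA[i1]));
            best[i][j] = b + A[i - 1] * B[j - 1];
            // (i,j) as the last pair: every later cow is unpaired.
            answer = std::max(answer, best[i][j] - sq(SA[n] - SA[i]) - sq(SB[n] - SB[j]));
        }
    }
    return answer;
}
```

Each state now scans at most about 2n candidates instead of about n^2, which is where the O(n^3) bound comes from.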

Now we get down to O(n^2) by optimizing the state transitions using convexity. We first consider the case where no cows in A are skipped, that is, i1=i-1. We then need to find:

max{j1<j|Best(i-1,j1)-SQR(sum(B_(j1+1)...B_(j-1)))}+A_i*B_j

If we let SB denote the partial sums of the B array (SB_k=B_1+...+B_k, with SB_0=0), then sum(B_(j1+1)...B_(j-1))=SB_(j-1)-SB_(j1), and our transition formula becomes:

max{j1<j|Best(i-1,j1)-SQR(SB_(j-1)-SB_(j1))}+A_i*B_j
=max{j1<j|Best(i-1,j1)-SQR(SB_(j1))+2*SB_(j1)*SB_(j-1)}+A_i*B_j-SQR(SB_(j-1))

We can ignore the last two terms since they depend only on i and j. In the first term, each j1 contributes a point with X=SB_(j1) and Y=Best(i-1,j1)-SQR(SB_(j1)), and we want the maximum of Y+k*X, where the coefficient k=2*SB_(j-1) depends only on j.
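The rewrite can be checked numerically by evaluating the row transition both ways. This is a small sketch with hypothetical helper names (rowDirect, rowLinear); it assumes prevRow[j1] holds Best(i-1,j1) for j1>=1, SB holds the partial sums with SB[0]=0, and integer values small enough to avoid overflow.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Hypothetical helpers: prevRow[j1] holds Best(i-1,j1) for j1 >= 1 (index 0 unused)
// and SB[k] = B_1 + ... + B_k with SB[0] = 0. Both require j >= 2.

// Direct form: max over 1 <= j1 < j of Best(i-1,j1) - SQR(SB_(j-1) - SB_(j1)).
long long rowDirect(const std::vector<long long>& prevRow,
                    const std::vector<long long>& SB, int j) {
    assert(j >= 2);
    long long best = LLONG_MIN;
    for (int j1 = 1; j1 < j; ++j1) {
        const long long gap = SB[j - 1] - SB[j1];
        best = std::max(best, prevRow[j1] - gap * gap);
    }
    return best;
}

// Linear form: max over j1 of Y + k*X with X = SB_(j1), Y = Best(i-1,j1) - SQR(SB_(j1))
// and k = 2*SB_(j-1), followed by the j-only correction -SQR(SB_(j-1)).
long long rowLinear(const std::vector<long long>& prevRow,
                    const std::vector<long long>& SB, int j) {
    assert(j >= 2);
    const long long k = 2 * SB[j - 1];
    long long best = LLONG_MIN;
    for (int j1 = 1; j1 < j; ++j1) {
        const long long X = SB[j1];
        const long long Y = prevRow[j1] - X * X;
        best = std::max(best, Y + k * X);
    }
    return best - SB[j - 1] * SB[j - 1];
}
```

Only the inner maximum in rowLinear changes from one j to the next through k, which is what makes a hull of the (X,Y) points useful.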

Geometrically, this amounts to finding, among a set of points, the one through which a line of a given slope has the largest y-intercept. With some proof, it can be shown that only points on the upper convex hull matter. So we insert the points (SB_(j1),Best(i-1,j1)-SQR(SB_(j1))) into the convex hull; since SB never decreases as j1 grows, new points always arrive at the right end and the hull can be maintained with a stack. When we query, the slope 2*SB_(j-1) never decreases as j grows, so the hull point achieving the maximum only moves to the right. A pointer that only advances therefore answers each query in O(1) amortized time.
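The stack-based hull and the advancing pointer can be sketched as a small standalone structure. This is an illustrative version, not the original's hulladd/query functions (which keep many hulls in one global array of doubles); MonotoneHull, add and query are hypothetical names, and it assumes integer coordinates small enough that the cross products fit in 64 bits.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Upper convex hull answering "max of y + k*x over inserted points".
// Assumes x values are inserted in non-decreasing order, query slopes k arrive
// in non-decreasing order, and every product fits in a long long.
struct MonotoneHull {
    std::vector<long long> xs, ys;  // hull points, x strictly increasing
    std::size_t ptr = 0;            // current best index; only moves right

    // True if the middle point (x1,y1) is on or below the chord from (x0,y0) to (x2,y2).
    static bool redundant(long long x0, long long y0, long long x1, long long y1,
                          long long x2, long long y2) {
        return (x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0) >= 0;
    }

    void add(long long x, long long y) {
        assert(xs.empty() || x >= xs.back());
        // Equal x: only the higher point can ever give the maximum.
        if (!xs.empty() && x == xs.back()) {
            if (y <= ys.back()) return;
            xs.pop_back();
            ys.pop_back();
        }
        // Pop points that the new point makes redundant (stack-based hull).
        while (xs.size() >= 2 &&
               redundant(xs[xs.size() - 2], ys[ys.size() - 2], xs.back(), ys.back(), x, y)) {
            xs.pop_back();
            ys.pop_back();
        }
        xs.push_back(x);
        ys.push_back(y);
        // Keep the pointer in range; if its point was popped, it now refers to the new point.
        if (ptr >= xs.size()) ptr = xs.size() - 1;
    }

    long long query(long long k) {
        assert(!xs.empty());
        // Advance while the next hull point is strictly better for this slope.
        while (ptr + 1 < xs.size() && ys[ptr + 1] + k * xs[ptr + 1] > ys[ptr] + k * xs[ptr])
            ++ptr;
        return ys[ptr] + k * xs[ptr];
    }
};
```

For the row transition, one would call add(SB_(j1), Best(i-1,j1)-SQR(SB_(j1))) as j1 advances and query(2*SB_(j-1)) for each j, then subtract SQR(SB_(j-1)). Each point is pushed and popped at most once and the pointer only moves forward, which gives the amortized O(1) bound.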

The case where nothing in B is skipped (j1=j-1) is handled similarly, using the partial sums of A. The only catch is that if we process the states row by row (increasing i, then increasing j), we need to keep track of n+1 convex hulls: one for the previous row and one for each column.
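Written out, the symmetric case mirrors the derivation for i1=i-1. Here SA is notation introduced for this example: the partial sums of A, defined like SB (SA_k=A_1+...+A_k, SA_0=0).

```latex
\max_{i_1<i}\Bigl\{\mathrm{Best}(i_1,j-1)-\bigl(SA_{i-1}-SA_{i_1}\bigr)^2\Bigr\}+A_iB_j
=\max_{i_1<i}\Bigl\{\mathrm{Best}(i_1,j-1)-SA_{i_1}^2+2\,SA_{i_1}SA_{i-1}\Bigr\}+A_iB_j-SA_{i-1}^2
```

So the hull for column j-1 stores the points (SA_(i1),Best(i1,j-1)-SQR(SA_(i1))) and is queried with slope 2*SA_(i-1), which never decreases as i grows. In the code, which is 0-indexed with s1 as an inclusive prefix sum, this corresponds to hulladd(j+1,s1[i],bes[i][j]-sqr(s1[i])) once row i is finished and query(j,2*s1[i-1])-sqr(s1[i-1]) at state (i,j).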

This gives the desired O(n^2) algorithm. Code (by Richard Peng):

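#include <cstdio>
#include <cstring>
#include <algorithm>

#define MAXN 1200

int n;
double a[MAXN],b[MAXN],bes[MAXN][MAXN],ans;
// s1[i]=a[0]+...+a[i] and s2[j]=b[0]+...+b[j] (inclusive prefix sums)
double s1[MAXN],s2[MAXN];

double sqr(double x){return x*x;}

// hull[id] stores an upper convex hull as a stack of (x,y) points with increasing x.
// Hull 0 holds the points of the previous row; hull j (j>=1) holds the points of column j-1.
double hull[MAXN][MAXN][2];
// hullt[id] is the number of points on hull id; p[id] is its monotone query pointer.
int hullt[MAXN],p[MAXN];

void initialize(int id){
	hullt[id]=p[id]=0;
}

// Cross product of (b-a) and (c-a).
double crossp(double a[2],double b[2],double c[2]){
	return (b[0]-a[0])*(c[1]-a[1])-(c[0]-a[0])*(b[1]-a[1]);
}

// Pushes (x,y) onto hull id; x must be at least the x of every point already there.
void hulladd(int id,double x,double y){
	double point[2];
	point[0]=x;
	point[1]=y;
	// Same x as the top point: keep only the higher one.
	if((hullt[id]>0)&&(x==hull[id][hullt[id]-1][0])){
		if(y>hull[id][hullt[id]-1][1]) hullt[id]--;
		else return;
	}
	// Pop points that lie on or below the chord ending at the new point.
	while((hullt[id]>1)&&(crossp(point,hull[id][hullt[id]-1],hull[id][hullt[id]-2])<=0))
		hullt[id]--;
	hull[id][hullt[id]][0]=x;
	hull[id][hullt[id]][1]=y;
	// If the point under the query pointer was popped, move the pointer to the new point.
	// (Standard replacement for the removed GCC extension p[id]<?=hullt[id].)
	p[id]=std::min(p[id],hullt[id]);
	hullt[id]++;
}

// Returns the max of x*a+y over the points of hull id. Successive slopes a on one
// hull must not decrease, so the pointer only moves right (O(1) amortized).
double query(int id,double a){
	double tem,tem1;
	tem=hull[id][p[id]][0]*a+hull[id][p[id]][1];
	while((p[id]+1<hullt[id])&&((tem1=(hull[id][p[id]+1][0]*a+hull[id][p[id]+1][1]))>tem)){
		tem=tem1;
		p[id]++;
	}
	return tem;
}

int main(){
	int i,j;
	freopen("mkpairs.in","r",stdin);
	freopen("mkpairs.out","w",stdout);
	scanf("%d",&n);
	for(i=0;i<n;i++)	scanf("%lf",&a[i]);
	for(s1[0]=a[0],i=1;i<n;i++) s1[i]=s1[i-1]+a[i];
	for(i=0;i<n;i++)	scanf("%lf",&b[i]);
	for(s2[0]=b[0],i=1;i<n;i++) s2[i]=s2[i-1]+b[i];

	memset(bes,0,sizeof(bes));

	// Column hulls start empty.
	for(i=1;i<n;i++)
		initialize(i);
	// Pairing a[k] with b[k] for every k skips nothing, so the answer is >= 0 and 0 is a safe start.
	for(ans=i=0;i<n;i++){
		// Hull 0 is rebuilt from row i-1 while j sweeps across row i.
		initialize(0);
		for(j=0;j<n;j++){
			// (i,j) as the first pair: every earlier cow is unpaired.
			bes[i][j]=-((i==0)?0:sqr(s1[i-1]))-((j==0)?0:sqr(s2[j-1]));

			if(i>0){
				if(j>0){
					// Previous pair (i-1,j1) with j1<j: B cows j1+1..j-1 are skipped.
					bes[i][j]=std::max(bes[i][j],query(0,2*s2[j-1])-sqr(s2[j-1]));
					// Previous pair (i1,j-1) with i1<i: A cows i1+1..i-1 are skipped.
					bes[i][j]=std::max(bes[i][j],query(j,2*s1[i-1])-sqr(s1[i-1]));
				}
				// Make (i-1,j) available to the later states of row i.
				hulladd(0,s2[j],bes[i-1][j]-sqr(s2[j]));
			}

			bes[i][j]+=a[i]*b[j];
			// (i,j) as the last pair: every later cow is unpaired.
			ans=std::max(ans,bes[i][j]-sqr(s1[n-1]-s1[i])-sqr(s2[n-1]-s2[j]));
		}
		// Make (i,j) available to column hull j+1 for the later rows.
		for(j=0;j+1<n;j++){
			hulladd(j+1,s1[i],bes[i][j]-sqr(s1[i]));
		}
	}
	printf("%0.0lf\n",ans);
	return 0;
}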