USACO February 2006 Problem 'cell' Analysis

by Bruce Merry

This is really a brute-force problem, but the search needs some pruning or other speed-ups to avoid taking forever. There are many ways to achieve this. Here are two approaches:

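Bruce Merry's solution. It tries every placement of the key boundaries except the last one recursively. For the last boundary it builds a trie of the dictionary's key sequences once, then slides the boundary one letter at a time, re-inserting only the words that contain the letter being moved instead of rebuilding the trie: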

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>
#include <assert.h>

#define true 1
#define false 0
#define bool unsigned char

/* Capacity limits. */
enum {
    MAXN = 26,          /* keys */
    MAXL = 26,          /* letters in use, counted from 'A' */
    MAXD = 50000,       /* dictionary words */
    MAXW = 10,          /* NOTE(review): not referenced elsewhere in this file */
    MAXW0 = MAXL + 5,   /* bytes per stored word, including the terminator */
    MAXT = 25000        /* trie nodes; erase() never frees nodes, only
                           initialise() resets the node count T */
};

/* Input and output file names. */
const char *infile = "cell.in";
const char *outfile = "cell.out";

/* Problem size: N keys, the first L letters, D dictionary words. */
int N, L, D;
char Lchar;             /* NOTE(review): not referenced elsewhere in this file */

/* Dictionary words as read, the same words translated to key indices, and
   their lengths. */
char dict[MAXD][MAXW0];
char syms[MAXD][MAXW0];
int len[MAXD];

/* One list per letter of the words containing it: last[c] is the newest such
   word, prev[w][c] the one before w; -1 terminates a list. */
int last[MAXL];
int prev[MAXD][MAXL];

/* Search state. Key k holds letters first[k] .. first[k+1]-1; total is the
   number of uniquely identified words for the current boundaries, and
   best/bestfirst record the best layout seen so far. */
int total;
char first[MAXN + 1];
int best;
char bestfirst[MAXN + 1];

/* Trie over key sequences: trie[v][k] is the child of node v on key k
   (-1 if absent) and trie[v][N] counts words ending at v. T = nodes in use. */
int trie[MAXT][MAXN + 1];
int T;

/* Report malformed or missing input on stderr and stop the program. */
static void read_fail(const char *what) {
    fprintf(stderr, "%s: %s\n", infile, what);
    exit(EXIT_FAILURE);
}

/*
 * Read the problem from infile: N and L, then D, then D words.
 *
 * Besides storing each word in dict[] and its length in len[], this threads
 * every word onto one list per letter it contains: last[c] is the most
 * recently read word containing letter 'A'+c, and prev[w][c] is the word
 * before w on that list (-1 ends the list). A word is linked at most once
 * per letter, even if that letter repeats within the word.
 */
void readin() {
    FILE *f;
    int i, j, inx;
    char fmt[16];

    f = fopen(infile, "r");
    if (!f) read_fail("cannot open input file");
    if (fscanf(f, "%d %d", &N, &L) != 2) read_fail("expected N and L");
    assert(1 <= N && N <= MAXN);
    assert(1 <= L && L <= MAXL);

    memset(last, -1, sizeof(last));
    memset(prev, -1, sizeof(prev));

    if (fscanf(f, "%d", &D) != 1) read_fail("expected D");
    assert(1 <= D && D <= MAXD);

    /* Bound %s by the row size of dict so an over-long word cannot overflow it. */
    snprintf(fmt, sizeof fmt, "%%%ds", (int) sizeof(dict[0]) - 1);

    for (i = 0; i < D; i++) {
        if (fscanf(f, fmt, dict[i]) != 1) read_fail("missing dictionary word");
        len[i] = (int) strlen(dict[i]);
        for (j = 0; j < len[i]; j++) {
            inx = dict[i][j] - 'A';
            /* Only 'A'..'A'+L-1 are assigned to keys. Any other character could
               index last[]/prev[] out of range here, and would read an
               unassigned map[] entry when initialise() translates the word. */
            assert(0 <= inx && inx < L);
            if (last[inx] != i) {
                prev[i][inx] = last[inx];
                last[inx] = i;
            }
        }
    }
    fclose(f);
}

/*
 * Add one word, given as key indices word[0..len-1] (each expected in
 * [0, N)), to the trie. Missing nodes are appended at index T. Slot N of the
 * final node counts how many words share this key sequence.
 *
 * Returns the change in the number of uniquely identified words:
 *   +1 if the sequence was unused before (this word is now unique),
 *   -1 if exactly one word already used it (that word is no longer unique),
 *    0 if the sequence was already ambiguous.
 */
int insert(const char *word, int len) {
    int cur, nxt;
    int i;

    /* Follow the existing path as far as it goes. */
    cur = 0;
    for (i = 0; i < len; i++) {
        nxt = trie[cur][(int) word[i]];
        if (nxt == -1) break;
        cur = nxt;
    }
    /* Create fresh nodes for the remaining symbols. */
    for (; i < len; i++) {
        /* Check capacity before touching trie[T], so running out of nodes is
           caught before anything is written past the end of the array. */
        assert(T < MAXT);
        trie[cur][(int) word[i]] = T;
        memset(trie[T], -1, sizeof(trie[T]));
        trie[T][N] = 0;
        cur = T;
        T++;
    }
    switch (trie[cur][N]++) {
        case 0: return 1;
        case 1: return -1;
        default: return 0;
    }
}

/*
 * Remove one occurrence of the key sequence word[0..len-1] from the trie.
 * The sequence must already be present (asserted). Nodes are never freed;
 * only the word count kept in slot N of the final node is decremented.
 *
 * Returns how much the unique-word total should DROP by:
 *    1 if no word remains on this sequence (the removed word was unique),
 *   -1 if exactly one word remains (it has just become unique),
 *    0 otherwise.
 */
int erase(char *word, int len) {
    int node = 0;
    int pos;
    int remaining;

    for (pos = 0; pos < len; pos++) {
        node = trie[node][(int) word[pos]];
        assert(node != -1);
    }
    assert(trie[node][N] > 0);

    remaining = --trie[node][N];
    if (remaining == 0)
        return 1;
    return remaining == 1 ? -1 : 0;
}

/*
 * Rebuild everything from scratch for the boundaries in first[0..N]:
 * map each letter to the key holding it, translate every dictionary word into
 * key indices (syms[]), insert the results into an empty trie, and recompute
 * total, the number of uniquely identified words.
 */
void initialise() {
    int key, w, pos;
    int map[256];
    char letter;

    /* Key `key` covers letters first[key] .. first[key+1]-1. */
    for (key = 0; key < N; key++) {
        letter = first[key];
        while (letter < first[key + 1]) {
            map[(int) letter] = key;
            letter++;
        }
    }

    /* Empty trie: just the root, holding no words. */
    memset(trie[0], -1, sizeof(trie[0]));
    trie[0][N] = 0;
    T = 1;
    total = 0;

    for (w = 0; w < D; w++) {
        for (pos = 0; pos < len[w]; pos++)
            syms[w][pos] = map[(int) dict[w][pos]];
        total += insert(syms[w], len[w]);
    }
}

/*
 * Reassign letter ch to key `value`, keeping total current: every word that
 * contains ch (walked via the last/prev lists) has its old key sequence
 * removed from the trie, its ch positions rewritten in syms[], and is then
 * re-inserted.
 */
void adjust(char ch, char value) {
    const int idx = ch - 'A';
    int w = last[idx];

    while (w != -1) {
        int k;

        total -= erase(syms[w], len[w]);
        for (k = 0; k < len[w]; k++) {
            if (dict[w][k] == ch)
                syms[w][k] = value;
        }
        total += insert(syms[w], len[w]);

        w = prev[w][idx];
    }
}

/*
 * Enumerate the key boundaries first[level..N-1]; first[0] and first[N] are
 * fixed by solve(). Key k holds letters first[k] .. first[k+1]-1, and every
 * candidate that gets scored gives each key at least one letter.
 *
 * Boundaries before the last are tried exhaustively by recursion. For the
 * last one (level == N-1), the trie is built once with key level-1 empty.
 * The boundary then slides right one letter at a time, and adjust() moves
 * only the letter just passed from key `level` to key `level - 1`, so total
 * stays current without a rebuild. On ties the later candidate wins (>=).
 */
void recurse(int level) {
    const int stop = 'A' + L;   /* one past the last letter in use */

    if (level != N - 1) {
        for (first[level] = first[level - 1] + 1; first[level] < stop; first[level]++)
            recurse(level + 1);
        return;
    }

    first[level] = first[level - 1];
    initialise();
    while (++first[level] < stop) {
        /* Letter first[level]-1 has just moved from key `level` to key level-1. */
        adjust(first[level] - 1, level - 1);
        if (total >= best) {
            best = total;
            memcpy(bestfirst, first, sizeof(bestfirst));
        }
    }
}

/*
 * Fix the outer boundaries (key 0 starts at 'A'; first[N] is one past the
 * last letter) and search for the best layout, leaving the result in best
 * and bestfirst.
 */
void solve() {
    first[0] = 'A';
    first[N] = 'A' + L;

    if (N != 1) {
        best = -1;          /* any achievable count (>= 0) replaces this */
        recurse(1);
        return;
    }

    /* One key: the only layout puts every letter on it. */
    initialise();
    best = total;
    bestfirst[0] = 'A';
    bestfirst[1] = 'A' + L;
}

/*
 * Write the answer to outfile: the best count on the first line, then N lines
 * where line i lists the letters of key i (bestfirst[i] .. bestfirst[i+1]-1).
 * Exits with a message on stderr if the file cannot be opened or flushed.
 */
void writeout() {
    FILE *f;
    int i;
    char ch;

    f = fopen(outfile, "w");
    if (!f) {
        perror(outfile);
        exit(EXIT_FAILURE);
    }
    fprintf(f, "%d\n", best);
    for (i = 0; i < N; i++) {
        for (ch = bestfirst[i]; ch < bestfirst[i + 1]; ch++)
            fputc(ch, f);
        fputc('\n', f);
    }
    /* fclose flushes buffered output; a failure means the answer may be incomplete. */
    if (fclose(f) != 0) {
        perror(outfile);
        exit(EXIT_FAILURE);
    }
}

/* Read the input file, search every keypad layout, and write the answer. */
int main(int argc, char **argv) {
    (void) argc;    /* command-line arguments are not used */
    (void) argv;

    readin();
    solve();
    writeout();
    return 0;
}