Skip to content

USACO 2016 February Contest Gold Division - Fenced In

Problem link: here

Solution Author: Stefan Dascalescu

Problem Solution

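For this problem, we're given a rectangle that the fences divide into a grid of regions, and we want to remove fence segments so that every region is connected to every other region while minimizing the total length removed. If we think of each region as a vertex and add an edge between every pair of adjacent regions, weighted by the length of the fence segment that separates them, then the problem becomes one of finding a minimum spanning tree of this graph. The graph has \((n+1)(m+1) = O(n \cdot m)\) vertices, and since each vertex has at most 4 edges, it also has \(O(n \cdot m)\) edges, meaning that we can simply run our favorite minimum spanning tree algorithm to solve the problem. The code below uses Kruskal's algorithm with one shortcut: every segment lying in the same strip between two consecutive fences has the same length, so instead of sorting all edges individually it sorts the \(n + m + 2\) strips by width and processes each strip's segments together.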

Source code

The source code in C++ can be seen below.

#include <algorithm>
#include <fstream>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>
using namespace std;

// A (key, weight) pair that sorts by weight only. solve() stores one Item
// per strip: `ID` names the strip and `value` is the gap between the two
// coordinates bounding it.
struct Item {
    int ID;     // strip identifier
    int value;  // sort key (strip width)

    // Ordering ignores ID entirely; Items with equal value are equivalent.
    bool operator<(const Item &rhs) const {
        return this->value < rhs.value;
    }
};

// Disjoint-set union over the integers 0 .. n-1.
// `parent[x]` is x's parent in its tree (a root is its own parent) and
// `height[r]` is the union-by-rank bound maintained for each root r.
struct DSet {
    vector<int> parent, height;

    // Builds n singleton sets {0}, {1}, ..., {n-1}. Marked explicit so an
    // int is never implicitly converted into a DSet.
    explicit DSet(int n) : parent(n), height(n, 0) {
        iota(parent.begin(), parent.end(), 0);
    }

    // Returns the root of x's set, pointing every node on the path directly
    // at that root (path compression). Because unite() links by rank, tree
    // height stays logarithmic in n, which bounds the recursion depth.
    int find(int x) {
        if (parent[x] != x)
            parent[x] = find(parent[x]);
        return parent[x];
    }

    // Merges the sets containing a and b.
    // Returns true if they were distinct (a merge happened), false if they
    // were already in the same set.
    bool unite(int a, int b) {
        int ra = find(a);
        int rb = find(b);
        if (ra == rb)
            return false;
        // Make ra the root of the taller (or equal) tree, then hang rb under
        // it; only a tie can increase the surviving root's rank.
        if (height[ra] < height[rb])
            swap(ra, rb);
        parent[rb] = ra;
        if (height[ra] == height[rb])
            height[ra]++;
        return true;
    }
};

void solve() {
    ifstream fin("fencedin.in");
    int maxN, maxM, n, m;
    fin >> maxN >> maxM >> n >> m;

    vector<int> nList(n + 2), mList(m + 2);
    for (int i = 0; i < n; i++)
        fin >> nList[i];
    nList[n] = 0;
    nList[n + 1] = maxN;
    sort(nList.begin(), nList.end());

    for (int i = 0; i < m; i++)
        fin >> mList[i];
    mList[m] = 0;
    mList[m + 1] = maxM;
    sort(mList.begin(), mList.end());
    fin.close();

    vector<Item> allRC(n + m + 2);
    for (int i = 0; i <= n; i++)
        allRC[i] = {i, nList[i + 1] - nList[i]};
    for (int i = 0; i <= m; i++)
        allRC[n + 1 + i] = {n + 1 + i, mList[i + 1] - mList[i]};
    sort(allRC.begin(), allRC.end());

    DSet dj(n * (m + 1) + m * (n + 1));
    long long res = 0;
    int added = 0, index = 0;

    while (added < (n + 1) * (m + 1) - 1) {
        Item next = allRC[index];
        if (next.ID <= n) {
            int nVal = next.ID;
            for (int i = 0; i < m; i++) {
                if (dj.unite(nVal * (m + 1) + i, nVal * (m + 1) + i + 1)) {
                    res += next.value;
                    added++;
                }
            }
        } else {
            int mVal = next.ID - n - 1;
            for (int i = 0; i < n; i++) {
                if (dj.unite(i * (m + 1) + mVal, (i + 1) * (m + 1) + mVal)) {
                    res += next.value;
                    added++;
                }
            }
        }
        index++;
    }

    ofstream fout("fencedin.out");
    fout << res << "\n";
    fout.close();
}

// Program entry point: configures the standard streams and runs solve()
// exactly once. solve() itself reads fencedin.in and writes fencedin.out.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    for (int remaining = 1; remaining > 0; --remaining)
        solve();

    return 0;
}