【算法笔记】并查集

Author Avatar
source. 1月 12, 2019
  • 在其它设备中阅读本文章

并查集是一种树型的数据结构,用于处理一些不交集(Disjoint Sets)的合并查询问题。

基本思想

用集合中的某个元素来代表这个集合,该元素称为集合的代表元,也称为根节点,用 par[x] 存储元素 x 的代表元。par 是 parent 的前三个字母。

路径压缩

为了加快查找速度,在查找 x 的根节点时将其到根节点路径上的所有点的根节点都变成同一个,即对于路径上的所有 x,把它们的根节点 par[x] 赋值成相同的。这样,经过优化后,每次操作平均时间复杂度就变成了 $O(\alpha(n))$ ,即阿克曼(Ackerman)函数的反函数复杂度,比 $O(\log n)$ 还小,实际应用中可粗略看成常数。

代码

root()

root() 函数用于:查找元素 x 的根节点(即代表元) par[x],同时执行路径压缩。

很多地方写成 find(),但是为了与 C++ 中 find() 函数区分,写成

int root(x)
{
    if(x == par[x]) return x;
    return par[x] = root(par[x]);    //递归查找根节点,并压缩路径
}

也可写成一行代码:

int root(x)
{
    return x==par[x]?x:par[x]=root(par[x]);
}

unite()

unite() 函数用于:合并两个集合。

void unite(int x, int y)
{
    int fx = root(x);
    int fy = root(y);
    if(fx != fy) par[fy] = fx;
}

init()

init() 函数用于初始化并查集,将每个元素的根节点设为它本身,即表示当前一个元素为一个集合,没有和其他节点相连。

void init(int n)
{
    for(int i = 1; i <= n; i++) par[i] = i;
}

same()

same() 函数用于判断两个点是否在同个集合内。代码写起来也很简单,直接判断他们的根节点是否相同。

bool same(int x, int y)
{
    return root(x) == root(y);
}

整合模板

void init(int n)
{
    for(int i = 1; i <= n; i++) par[i] = i;
}

int root(x)
{
    return x==par[x]?x:par[x]=root(par[x]);
}

void unite(int x, int y)
{
    int fx = root(x);
    int fy = root(y);
    if(fx != fy) par[fy] = fx;
}

bool same(int x, int y)
{
    return root(x) == root(y);
}

例题:畅通工程

HDOJ 1232 - 畅通工程

就是求有几个连接分块,若有 n 个相互不连接的分块,那么只需要再修 n-1 条路,就能使得所有节点都连通。使用并查集,可以用 par[i] == i 这个条件来看有几个相互不连接的分块。

#include<iostream>
using namespace std;
const int MAXN = 1005;
int par[MAXN];

int root(int x)     //查找根节点,同时实现了路径压缩
{
    return x==par[x]?x:par[x]=root(par[x]);
}

void unite(int x, int y)
{
    int root_x = root(x);
    int root_y = root(y);
    if(root_x != root_y)
    {
        par[root_y] = root_x;
    }
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    int n, m;   //点数、边数
    while(cin >> n && n)
    {
        for(int i = 1; i <= n; i++) par[i] = i;
        cin >> m;
        int ans = -1, a, b;
        for(int i = 0; i < m; i++)
        {
            cin >> a >> b;
            unite(a, b);
        }
        for(int i = 1; i <= n; i++) if(i == par[i]) ans++;
        cout << ans << '\n';
    }
    return 0;
}

例题:畅通工程(MST)

HDOJ 1863 - 畅通工程

Kruskal 算法求最小生成树(MST)。

贪心策略:权值最小且加入后不构成环的边优先选择。

因为要对所有边进行排序,所以适合稀疏图

#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
int par[105];

struct Edge
{
    int u, v, w;
    bool operator < (const Edge& c) const {return w < c.w;}
};

int root(int x)     //查找根节点,同时实现了路径压缩
{
    return x==par[x]?x:par[x]=root(par[x]);
}

void unite(int x, int y)
{
    int root_x = root(x);
    int root_y = root(y);
    if(root_x != root_y) par[root_y] = root_x;
}

int Kruskal(vector<Edge> edge, int n, int m)
{
    sort(edge.begin(), edge.end());             //所有边按权值升序排序
    for(int i = 1; i <= m; i++) par[i] = i;     //初始化并查集,每个点孤立
    int ans = 0, rest = m-1;                    //m个点,MST应有m-1条边
    for(int i = 0; i < n; i++)
    {
        if(root(edge[i].u) != root(edge[i].v))  //若不构成环,则选用这条边
        {
            unite(edge[i].u, edge[i].v);
            ans += edge[i].w;
            rest--;
            if(!rest) return ans;               //凑够m-1条边,直接返回
        }
    }
    return -1;  //将所有边遍历完还没返回,说明原来就非连通图,找不到MST
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    int n, m, ans;
    while(cin >> n >> m)
    {
        if(!n) break;
        vector<Edge> edge(n);
        for(int i = 0; i < n; i++) cin >> edge[i].u >> edge[i].v >> edge[i].w;
        ans = Kruskal(edge, n, m);
        if(ans == -1) cout << "?\n";
        else cout << ans << '\n';
    }
    return 0;
}

例题:The Suspects

POJ 1611 - The Suspects

#include<iostream>
#include<stdio.h>
using namespace std;
typedef long long ll;
const int maxn = 3e4 + 5;

int n, m, k, par[maxn], fst, t, ans;

void init(int x)
{
    for(int i = 0; i <= x; i++) par[i] = i;
}

int find(int x)
{
    if(x == par[x]) return x;
    return par[x] = find(par[x]);
}

int unite(int x, int y)
{
    int fx = find(x);
    int fy = find(y);
    if(fx != fy) par[fx] = fy;
}

int main()
{
    while(~scanf("%d %d", &n, &m) && (n||m))
    {
        init(n);
        while(m--)
        {
            scanf("%d %d", &k, &fst);
            for(int i = 1; i < k; i++)
            {
                scanf("%d", &t);
                unite(t, fst);
            }
        }
        ans = 1;
        t = find(0);
        for(int i = 1; i <= n; i++)
        {
            if(t == find(i)) ans++;
        }
        printf("%d\n", ans);
    }
    return 0;
}

例题:Ubiquitous Religions

POJ 2524 - Ubiquitous Religions

就是看有几个分块。

#include<iostream>
#include<stdio.h>
using namespace std;
typedef long long ll;
const int maxn = 5e4 + 5;

int n, m, k, par[maxn], a, b, ans;

void init(int x)
{
    for(int i = 0; i <= x; i++) par[i] = i;
}

int find(int x)
{
    if(x == par[x]) return x;
    return par[x] = find(par[x]);
}

int unite(int x, int y)
{
    int fx = find(x);
    int fy = find(y);
    if(fx != fy) par[fx] = fy;
}

int main()
{
    int n, m, cas = 0;
    while(~scanf("%d %d", &n, &m), (n||m))
    {
        cas++;
        init(n);
        ans = 0;
        while(m--)
        {
            scanf("%d %d", &a, &b);
            unite(a, b);
        }
        for(int i = 1; i <= n; i++)
        {
            if(i == par[i]) ans++;
        }
        printf("Case %d: %d\n", cas, ans);
    }
    return 0;
}

例题:食物链

POJ 1182 - 食物链

《挑战程序设计竞赛》上的思路很巧妙:

  • 3*n 大小的数组
  • 不管 i 到底是哪种动物,都直接记录三种可能

具体地:

对于每只动物 i 创建3个元素 i-A, i-B, i-C,并用这 3*n 个元素建立并查集。这个并查集维护如下信息:

  • i-x 表示 i 属于种类 x
  • 并查集里的每一个组表示组内所有元素代表的情况都同时发生或不发生

因此,对于每一条信息,只需要按照下面进行操作就可以了:

  1. 若 x 和 y 属于同一种类,则合并 x-Ay-Ax-By-Bx-Cy-C
  2. 若 x 吃 y,则合并 x-Ay-Bx-By-Cx-Cy-A

对于每次操作,要首先判断是否为假话。

特别精妙的事是:对于每次判断是否为假的时候,理应判断 2*3 种情况,但由于每次 3 种情况都执行了 unite() 操作,故只需要考虑两种情况:

  • 对于 D = 1,也就是断言为同类,则要判断是否存在捕食或被捕食关系
  • 对于 D = 2,也就是断言为捕食,则要判断是否存在同类或被捕食关系
#include<iostream>
#include<stdio.h>
using namespace std;
typedef long long ll;
const int maxn = 5e4 + 5;

int par[maxn*3];

void init(int x)
{
    for(int i = 0; i <= x; i++) par[i] = i;
}

int find(int x)
{
    return x==par[x]?x:par[x]=find(par[x]);
}

int unite(int x, int y)
{
    int fx = find(x);
    int fy = find(y);
    if(fx != fy) par[fx] = fy;
}

bool same(int x, int y)
{
    /*
    捕食关系:
    same(x, y+n)|same(x+n, y+2*n)|same(x+2*n, y)成立代表x吃y
    反捕食关系:
    same(x, y+2*n)|same(x+n, y)|same(x+2*n, y+n)成立代表y吃x
    */
    return find(x) == find(y);
}

int main()
{
    int n, k, d, x, y, ans = 0;
    scanf("%d %d", &n, &k);
    init(3*n);
    while(k--)
    {
        scanf("%d %d %d", &d, &x, &y);
        if(x > n || y > n)
        {
            ans++;
            continue;
        }
        if(d == 1)
        {
            //不能为捕食关系或反捕食关系
            if(same(x, y+n) || same(x, y+2*n)) ans++;
            else
            {
                unite(x, y);
                unite(x+n, y+n);
                unite(x+2*n, y+2*n);
            }
        }
        else
        {
            //不能为同类关系或反捕食关系
            if(same(x, y) || same(x, y+2*n)) ans++;
            else
            {
                unite(x, y+n);
                unite(x+n, y+2*n);
                unite(x+2*n, y);
            }
        }
    }
    printf("%d\n", ans);
    return 0;
}

例题:Agri-Net

POJ 1258 -Agri-Net

MST 问题,使用并查集实现 Kruskal 算法。

注意跳出的条件,若有 n 个顶点,则树的边数为 n-1

#include<iostream>
#include<stdio.h>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn = 100 + 5;

int par[maxn*maxn], t, ans;

struct edge
{
    int u, v, cost;
    edge(int uu, int vv, int cc):u(uu),v(vv),cost(cc){}
    edge(){}
    bool operator < (const edge& c) const
    {
        return cost < c.cost;
    }
}e[maxn*maxn];

void init(int x)
{
    for(int i = 0; i <= x; i++) par[i] = i;
}

int find(int x)
{
    if(x == par[x]) return x;
    return par[x] = find(par[x]);
}

int unite(int x, int y)
{
    int fx = find(x);
    int fy = find(y);
    if(fx != fy) par[fx] = fy;
}

bool same(int x, int y)
{
    return find(x) == find(y);
}

int main()
{
    int n, tc;
    while(~scanf("%d", &n))
    {
        int cnt = 0;
        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < n; j++)
            {
                scanf("%d", &t);
                if(i < j) e[cnt++] = edge(i, j, t);
            }
        }
        sort(e, e+cnt);
        init(cnt);
        ans = 0;
        tc = 0;
        for(int i = 0; i < cnt; i++)
        {
            edge ce = e[i];
            if(!same(ce.u, ce.v))
            {
                unite(ce.u, ce.v);
                ans += ce.cost;
                tc++;
                if(tc == n-1) break;
            }
        }
        printf("%d\n", ans);
    }
    return 0;
}

参看

并查集详解(图文解说) - CSDN

POJ并查集的题目汇总 - CSDN

本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明来自 ComyDream
本文链接:http://comydream.github.io/2019/01/12/algorithm-union-find/