1 条题解

  • 0
    @ 2024-9-8 20:00:02

    并查集代码：

    #include <bits/stdc++.h>
    using namespace std;
    const int maxn = 2e5 + 5;  // capacity of each Ufs array (vertices are 1-indexed)
    int n;                     // vertex count; read by Ufs::init()
    
    // Disjoint-set union over vertices 1..n with path compression and union by
    // size. sz[r] is meaningful only while r is a root (find(r) == r).
    struct Ufs {
    	int f[maxn], sz[maxn];
    	// Makes every vertex 1..n its own singleton set. Uses the global n,
    	// which must already hold a value below maxn.
    	void init() {
    		for (int i = 1; i <= n; i++)
    			f[i] = i, sz[i] = 1;
    	}
    	// Returns the root of x's set, pointing every visited node straight at
    	// the root. Union by size keeps trees O(log n) deep, so the recursion
    	// stays shallow even for adversarial merge orders.
    	int find(int x) {
    		return x == f[x] ? x : f[x] = find(f[x]);
    	}
    	// Joins the sets containing x and y. The smaller tree is hung under the
    	// larger tree's root, and that root's size is updated. No-op if x and y
    	// are already in the same set.
    	void merge(int x, int y) {
    		int a = find(x), b = find(y);
    		if (a == b)
    			return;
    		if (sz[a] < sz[b])
    			std::swap(a, b);
    		f[b] = a;
    		sz[a] += sz[b];
    	}
    } tr[2];  // main() merges an edge labelled c into tr[c]
    
    // Reads n, then n - 1 edges "u v c" (c is 0 or 1), from stdin. It builds one
    // union-find forest per edge label and prints
    //   sum over components C of tr[0] and of tr[1] of |C| * (|C| - 1)
    //   + sum over vertices u of (cnt0(u) - 1) * (cnt1(u) - 1),
    // where cntc(u) is the size of u's component in tr[c].
    // NOTE(review): this looks like the "0-1 tree" count of ordered pairs (x, y)
    // whose path uses only 0-edges and then only 1-edges. Confirm against the
    // problem statement.
    // Exits with status 1, printing nothing, on malformed input.
    int main() {
    	if (scanf("%d", &n) != 1)
    		return 1;
    	// Ufs::init() writes f[1..n], so n must be strictly below the array size.
    	const int capacity = static_cast<int>(sizeof(tr[0].f) / sizeof(tr[0].f[0]));
    	if (n < 0 || n >= capacity)
    		return 1;
    	tr[0].init(), tr[1].init();
    	for (int i = 1; i < n; i++) {
    		int u, v, c;
    		if (scanf("%d%d%d", &u, &v, &c) != 3)
    			return 1;
    		// A label outside {0, 1} or an endpoint outside [1, n] would index
    		// tr[] or the f[]/sz[] arrays out of bounds.
    		if (c < 0 || c > 1 || u < 1 || u > n || v < 1 || v > n)
    			return 1;
    		tr[c].merge(u, v);
    	}
    	long long ans = 0;
    	for (int u = 1; u <= n; u++) {
    		// Each root contributes sz * (sz - 1) for its own component, so every
    		// component is counted exactly once per forest.
    		for (int i = 0; i < 2; i++)
    			if (tr[i].find(u) == u)
    				ans += static_cast<long long>(tr[i].sz[u]) * (tr[i].sz[u] - 1);
    		// Cross term at u: (other vertices in u's 0-component) times
    		// (other vertices in u's 1-component).
    		long long cnt0 = tr[0].sz[ tr[0].find(u) ];
    		long long cnt1 = tr[1].sz[ tr[1].find(u) ];
    		ans += (cnt0 - 1) * (cnt1 - 1);
    	}
    	printf("%lld\n", ans);
    	return 0;
    }
    
    
    • 1

    信息

    ID
    29
    时间
    1000ms
    内存
    256MiB
    难度
    10
    标签
    (无)
    递交数
    4
    已通过
    2
    上传者