3 条题解

  • 0
    @ 2025-3-20 20:34:33

    2025/3/20

    #include <bits/stdc++.h>
    using namespace std;
    const int maxn = 2e5 + 5;
    
    // Global state; node ids are 1..n.
    int n, Q;        // number of nodes, number of operations
    int a[maxn];     // a[u]: value of node u as read from input
    int dfn[maxn];   // dfn[u]: DFS order index of node u (1-based)
    int id[maxn];    // id[k]: node whose DFS order index is k (inverse of dfn)
    int idx;         // running DFS counter advanced by dfs2
    int sz[maxn];    // sz[u]: size of the subtree rooted at u
    int dep[maxn];   // dep[u]: depth of node u
    int son[maxn];   // son[u]: heavy child of node u (0 if none)
    int tp[maxn];    // tp[u]: top node of the heavy chain containing u
    int fa[maxn];    // fa[u]: parent of node u
    
    long long tr[maxn<<2];  // sum segment tree over DFS order positions
    vector<int> g[maxn];    // adjacency lists of the tree
    
    // Recompute an internal segment-tree node as the sum of its two children.
    void push_up(int u) {
    	const int lc = u << 1, rc = lc | 1;
    	tr[u] = tr[lc] + tr[rc];
    }
    // Build the segment tree over DFS positions [l, r]; leaf k holds a[id[k]].
    void build(int l, int r, int u) {
    	if (l != r) {
    		const int m = (l + r) / 2;
    		build(l, m, u << 1);
    		build(m + 1, r, u << 1 | 1);
    		push_up(u);
    	} else {
    		tr[u] = a[id[l]];
    	}
    }
    // Point assignment: store x at position p, then refresh ancestors on the way back up.
    void update(int p, long long x, int l, int r, int u) {
    	if (l == r) {
    		tr[u] = x;
    		return;
    	}
    	const int m = (l + r) / 2;
    	if (p <= m)
    		update(p, x, l, m, u << 1);
    	else
    		update(p, x, m + 1, r, u << 1 | 1);
    	push_up(u);
    }
    // Sum over positions [L, R] intersected with this node's range [l, r].
    long long query(int L, int R, int l, int r, int u) {
    	if (L <= l && r <= R)
    		return tr[u];
    	const int m = (l + r) / 2;
    	const long long lsum = (L <= m) ? query(L, R, l, m, u << 1) : 0LL;
    	const long long rsum = (R > m) ? query(L, R, m + 1, r, u << 1 | 1) : 0LL;
    	return lsum + rsum;
    }
    
    // First pass: record parent and depth, compute subtree sizes, and pick each
    // node's heavy child (the first child whose subtree is strictly the largest).
    void dfs1(int u, int p, int d) {
    	fa[u] = p;
    	dep[u] = d;
    	sz[u] = 1;
    	for (int v : g[u]) {
    		if (v == p)
    			continue;
    		dfs1(v, u, d + 1);
    		sz[u] += sz[v];
    		if (sz[son[u]] < sz[v])
    			son[u] = v;
    	}
    }
    
    // Second pass: number nodes heavy-child-first so each heavy chain occupies a
    // contiguous DFS range; `top` is the head of the chain that u belongs to.
    void dfs2(int u, int p, int top) {
    	dfn[u] = ++idx;
    	id[idx] = u;
    	tp[u] = top;
    	if (son[u] != 0)
    		dfs2(son[u], u, top);   // heavy child continues the current chain
    	for (int v : g[u]) {
    		if (v == p || v == son[u])
    			continue;
    		dfs2(v, u, v);          // each light child starts a new chain
    	}
    }
    
    long long cal(int x, int y) {
    	long long res = 0;
    	while (tp[x] != tp[y]) {
    		if (dep[tp[x]] < dep[tp[y]])
    			swap(x, y);
    		res += query(dfn[tp[x]], dfn[x], 1, n, 1);
    		x = fa[tp[x]];
    	}
    	if (dfn[x] > dfn[y])
    		swap(x, y);
    	res += query(dfn[x], dfn[y], 1, n, 1);
    	return res;
    }
    
    int main() {
    	scanf("%d%d", &n, &Q);
    	for (int i = 1; i <= n; i++)
    		scanf("%d", a+i);
    	for (int i = 1; i < n; i++) {
    		int u, v;
    		scanf("%d%d", &u, &v);
    		g[u].push_back(v);
    		g[v].push_back(u);
    	}
    	dfs1(1, -1, 0);
    	dfs2(1, -1, 1);
    	build(1, n, 1);
    	while (Q--) {
    		int op, x, y;
    		scanf("%d%d%d", &op, &x, &y);
    		if (op == 1) {
    			update(dfn[x], y, 1, n, 1);
    		}
    		else {
    			long long ans = cal(x, y);
    			printf("%lld\n", ans);
    		}
    	}
    	return 0;
    }
    
    • 0
      @ 2025-3-19 20:44:54

      2025/3/19

      #include <bits/stdc++.h>
      using namespace std;
      const int maxn = 2e5 + 5;
      
      int n, Q;          // number of nodes, number of operations
      int a[maxn];       // a[u]: value of node u read from input
      int dfn[maxn];     // dfn[u]: DFS order index of node u
      int idx;           // DFS counter advanced by dfs2
      int id[maxn];      // id[x]: node whose DFS order index is x
      int sz[maxn];      // sz[u]: size of the subtree rooted at u
      int dep[maxn];     // dep[u]: depth of node u
      int son[maxn];     // son[u]: heavy child of node u (0 if none)
      int tp[maxn];      // tp[u]: top node of the chain containing u
      int fa[maxn];      // fa[u]: parent of node u
      vector<int> g[maxn];  // adjacency lists of the tree
       
      // Post-order pass: sets fa/dep, accumulates subtree sizes, and records the
      // heavy child (first child whose subtree is strictly the largest).
      void dfs1(int u, int p, int d) {
      	dep[u] = d;
      	fa[u] = p;
      	sz[u] = 1;
      	for (size_t i = 0; i < g[u].size(); ++i) {
      		const int v = g[u][i];
      		if (v == p) continue;
      		dfs1(v, u, d + 1);
      		sz[u] += sz[v];
      		if (sz[v] > sz[son[u]]) son[u] = v;
      	}
      }
      
      // Pre-order pass: numbers nodes heavy-child-first so every heavy chain is a
      // contiguous dfn range; `top` is the head of u's chain.
      void dfs2(int u, int p, int top) {
      	const int order = ++idx;
      	dfn[u] = order;
      	id[order] = u;
      	tp[u] = top;
      	if (son[u])
      		dfs2(son[u], u, top);
      	for (const int v : g[u])
      		if (v != p && v != son[u])
      			dfs2(v, u, v);
      }
      
      // Sum segment tree over DFS order positions 1..n.
      long long tr[maxn<<2];
      // Internal node value = left child + right child.
      void push_up(int u) {
      	tr[u] = tr[2 * u] + tr[2 * u + 1];
      }
      // Recursively build the sum tree for positions [l, r]; leaf l gets a[id[l]].
      void build(int l, int r, int u) {
      	if (l == r) {
      		tr[u] = a[id[l]];
      	} else {
      		const int mid = (l + r) / 2;
      		build(l, mid, 2 * u);
      		build(mid + 1, r, 2 * u + 1);
      		push_up(u);
      	}
      }
      // Assign x to position p and recompute the sums along the path to the root.
      void update(int p, long long x, int l, int r, int u) {
      	if (l != r) {
      		const int mid = (l + r) / 2;
      		if (p > mid)
      			update(p, x, mid + 1, r, 2 * u + 1);
      		else
      			update(p, x, l, mid, 2 * u);
      		push_up(u);
      		return;
      	}
      	tr[u] = x;
      }
      // Range-sum query: sum of positions in [L, R] intersected with [l, r].
      long long query(int L, int R, int l, int r, int u) {
      	if (l >= L && R >= r)
      		return tr[u];
      	const int mid = (l + r) / 2;
      	long long total = 0;
      	if (mid >= L)
      		total += query(L, R, l, mid, 2 * u);
      	if (mid < R)
      		total += query(L, R, mid + 1, r, 2 * u + 1);
      	return total;
      }
      
      // Path sum from x to y (inclusive) via heavy-light decomposition.
      long long cal(int x, int y) {
      	long long ans = 0;
      	while (tp[x] != tp[y]) {
      		// lift the endpoint whose chain head is deeper, one whole chain at a time
      		if (dep[tp[y]] > dep[tp[x]])
      			swap(x, y);
      		const int head = tp[x];
      		ans += query(dfn[head], dfn[x], 1, n, 1);
      		x = fa[head];
      	}
      	assert(tp[x] == tp[y]);
      	// now on one chain: query the dfn interval between the endpoints
      	const int from = dfn[x] < dfn[y] ? dfn[x] : dfn[y];
      	const int to = dfn[x] < dfn[y] ? dfn[y] : dfn[x];
      	ans += query(from, to, 1, n, 1);
      	return ans;
      }
      
      int main() {
      	scanf("%d%d", &n, &Q);
      	for (int i = 1; i <= n; i++)
      		scanf("%d", a+i);
      	for (int i = 1, u, v; i < n; i++) {
      		scanf("%d%d", &u, &v);
      		g[u].push_back(v);
      		g[v].push_back(u);
      	}
      	dfs1(1, -1, 0);
      	dfs2(1, -1, 1);
      	build(1, n, 1);
      	while (Q--) {
      		int op, x, y;
      		scanf("%d%d%d", &op, &x, &y);
      		if (op == 1) {
      			update(dfn[x], y, 1, n, 1);
      		} 
      		else {
      			long long ans = cal(x, y);
      			printf("%lld\n", ans);
      		}
      	}
      	return 0;
      }
      
      
      • 0
        @ 2023-11-8 20:20:23
        #include <bits/stdc++.h>
        using namespace std;
        const int maxn = 1e6 + 5;
        int n, m;        // node count, operation count
        int rt;          // NOTE(review): appears unused; segment-tree functions take a parameter of the same name
        int a[maxn];     // a[u]: value of node u read from input
        int dfn[maxn];   // dfn[u]: DFS order index of u
        int rev[maxn];   // rev[k]: node whose DFS order index is k
        int sz[maxn];    // sz[u]: size of the subtree rooted at u
        int cnt;         // running DFS counter advanced by dfs2
        int fa[maxn];    // fa[u]: parent of u
        int dep[maxn];   // dep[u]: depth of u (main roots the tree at 1 with depth 1)
        int top[maxn];   // top[u]: head node of the heavy chain containing u
        int son[maxn];   // son[u]: heavy child of u (0 if none)
        vector<int> g[maxn];  // adjacency lists of the tree
        
        // First pass of heavy-light decomposition, done iteratively so a deep
        // (e.g. path-shaped) tree with up to maxn (~1e6) nodes cannot overflow the
        // call stack. Sets fa[] and dep[] for u's subtree (u's parent is p and
        // dep[u] = dep[p] + 1), computes sz[], and sets son[x] to the first child
        // (in adjacency order) whose subtree is strictly the largest.
        // Produces exactly the same arrays as the recursive formulation.
        void dfs1(int u, int p) {
        	vector<int> order;        // nodes in pre-order: parents before children
        	vector<int> stk(1, u);
        	fa[u] = p;
        	dep[u] = dep[p] + 1;
        	while (!stk.empty()) {
        		int x = stk.back();
        		stk.pop_back();
        		order.push_back(x);
        		sz[x] = 1;
        		for (auto v : g[x])
        			if (v != fa[x]) {
        				fa[v] = x;
        				dep[v] = dep[x] + 1;
        				stk.push_back(v);
        			}
        	}
        	// Reverse pre-order handles every child before its parent, so each
        	// child's size is final when the parent accumulates it.
        	for (int i = (int)order.size() - 1; i >= 0; i--) {
        		int x = order[i];
        		for (auto v : g[x])
        			if (v != fa[x]) {
        				sz[x] += sz[v];
        				if (sz[v] > sz[son[x]])
        					son[x] = v;
        			}
        	}
        }
        
        // Second pass of heavy-light decomposition, done iteratively so a deep
        // (e.g. path-shaped) tree cannot overflow the call stack. Assigns dfn[] and
        // rev[] in pre-order with the heavy child visited first, so every heavy
        // chain is a contiguous dfn range; top[x] is the head of x's chain and
        // tp is the head for u itself. Visit order matches the recursive version.
        void dfs2(int u, int tp) {
        	vector<pair<int, int> > stk(1, make_pair(u, tp));  // (node, chain head)
        	while (!stk.empty()) {
        		int x = stk.back().first, head = stk.back().second;
        		stk.pop_back();
        		top[x] = head;
        		dfn[x] = ++cnt;
        		rev[cnt] = x;
        		// Push light children in reverse adjacency order, then the heavy child,
        		// so they are popped heavy-first and then in adjacency order.
        		for (int i = (int)g[x].size() - 1; i >= 0; i--) {
        			int v = g[x][i];
        			if (v != fa[x] && v != son[x])
        				stk.push_back(make_pair(v, v));   // light child starts a new chain
        		}
        		if (son[x])
        			stk.push_back(make_pair(son[x], head));  // heavy child continues the chain
        	}
        }
        
        // 线段树部分
        long long tr[maxn<<2];
        #define lson l, mid, rt<<1
        #define rson mid+1, r, rt<<1|1
        void push_up(int rt) {
        	tr[rt] = tr[rt<<1] + tr[rt<<1|1];
        }
        // Build over positions [l, r]: leaf l takes the value of node rev[l].
        void build(int l, int r, int rt) {
        	if (l != r) {
        		int mid = (l + r) / 2;
        		build(l, mid, rt << 1);
        		build(mid + 1, r, rt << 1 | 1);
        		push_up(rt);
        		return;
        	}
        	tr[rt] = a[rev[l]];
        }
        // Point assignment: set DFS position p to x, then refresh ancestors.
        // x is long long to match the element type of tr (an int parameter would
        // silently truncate wider values); existing int arguments still convert.
        void update(int p, long long x, int l, int r, int rt) {
        	if (l == r) {
        		tr[rt] = x;
        		return;
        	}
        	int mid = (l + r) / 2;
        	if (p <= mid)
        		update(p, x, l, mid, rt << 1);
        	else
        		update(p, x, mid + 1, r, rt << 1 | 1);
        	push_up(rt);
        }
        // Sum of positions in [L, R] restricted to this node's range [l, r].
        long long query(int L, int R, int l, int r, int rt) {
        	if (L <= l && r <= R)
        		return tr[rt];
        	int mid = (l + r) / 2;
        	const long long fromLeft = (L <= mid) ? query(L, R, l, mid, rt << 1) : 0LL;
        	const long long fromRight = (R > mid) ? query(L, R, mid + 1, r, rt << 1 | 1) : 0LL;
        	return fromLeft + fromRight;
        }
        
        // Sum of node values on the path u..v (inclusive) using the heavy chains.
        long long cal(int u, int v) {
        	long long sum = 0;
        	for (; top[u] != top[v]; u = fa[top[u]]) {
        		// always climb from the endpoint whose chain head is deeper
        		if (dep[top[u]] < dep[top[v]])
        			swap(u, v);
        		sum += query(dfn[top[u]], dfn[u], 1, n, 1);
        	}
        	// same chain: the shallower endpoint comes first in DFS order
        	if (dep[u] > dep[v])
        		swap(u, v);
        	return sum + query(dfn[u], dfn[v], 1, n, 1);
        }
        
        int main() {
        	scanf("%d%d", &n, &m);
        	for (int i = 1; i <= n; i++)
        		scanf("%d", a+i);
        	for (int i = 1; i < n; i++) {
        		int u, v;
        		scanf("%d%d", &u, &v);
        		g[u].push_back(v);
        		g[v].push_back(u);
        	}
        	dfs1(1, 0);
        	dfs2(1, 1);
        	build(1, n, 1);
        	while (m--) {
        		int op, x, y;
        		scanf("%d%d%d", &op, &x, &y);
        		if (op == 1)
        			update(dfn[x], y, 1, n, 1);
        		else 
        			printf("%lld\n", cal(x, y));
        	}
        	return 0;
        }
        
        • 1

        信息

        ID
        38
        时间
        2000ms
        内存
        512MiB
        难度
        7
        标签
        递交数
        24
        已通过
        8
        上传者