1 条题解

  • 0
    @ 2023-11-8 21:13:41
    #include<bits/stdc++.h>
    // Every `int` below this line is really `long long`; this is why the
    // entry point is spelled `signed main` and all scanf/printf use %lld.
    #define int long long
    // Argument packs for recursing into the left / right segment-tree child.
    // They expand using the caller's local names l, r, mid and rt.
    #define lson l,mid,rt<<1
    #define rson mid+1,r,rt<<1|1
    using namespace std;
    // Capacity of all per-node arrays (node ids are 1-based).
    const int N=2e5+5;
    // Undirected adjacency lists: main() stores each edge in both directions.
    vector<int>g[N];
    // a[i]: initial value of node i; n: number of nodes; m: number of operations.
    int a[N];int n,m;
    
    // Per-node data produced by dfs1:
    //   sz[u]  - number of nodes in the subtree rooted at u
    //   dep[u] - depth of u (dep[father] + 1)
    //   fa[u]  - parent of u as passed to dfs1
    //   son[u] - child of u with the largest subtree, 0 if u has no child
    int sz[N];
    int dep[N];
    int fa[N];
    int son[N];
    // First traversal: records parent, depth and subtree size of every node
    // reachable from u, and picks the heavy child of each node.
    //   u      - node being visited
    //   father - node we arrived from; neighbours equal to it are skipped
    void dfs1(int u,int father)
    {
    	fa[u]=father;
    	dep[u]=dep[father]+1;
    	sz[u]+=1;
    	for(const auto child:g[u])
    	{
    		if(child==father)
    			continue;
    		dfs1(child,u);
    		sz[u]+=sz[child];
    		// Strict comparison: on equal sizes the earlier child stays heavy.
    		if(sz[son[u]]<sz[child])
    			son[u]=child;
    	}
    }
    // Numbering produced by dfs2:
    //   dfn[u] - 1-based visit index of u
    //   rev[i] - node whose visit index is i (inverse of dfn)
    //   cnt    - number of nodes visited so far
    //   top[u] - first node of the chain that u was assigned to
    int dfn[N];
    int rev[N];
    int cnt;
    int top[N];
    // Second traversal: assigns visit indices and chain heads. The heavy child
    // (son[u]) is visited first and inherits the chain head tp; every other
    // child starts a new chain headed by itself.
    //   u  - node being visited
    //   tp - head of the chain u belongs to
    void dfs2(int u,int tp)
    {
    	const int order=++cnt;
    	top[u]=tp;
    	dfn[u]=order;
    	rev[order]=u;
    	const int heavy=son[u];
    	if(heavy!=0)
    		dfs2(heavy,tp);
    	for(const auto nb:g[u])
    	{
    		if(nb==fa[u]||nb==heavy)
    			continue;
    		dfs2(nb,nb);
    	}
    }
    
    // Segment tree storage, indexed by tree node rt (children are rt<<1 and rt<<1|1):
    //   lazy[rt]   - addition applied to rt's aggregates but not yet to its children
    //   tr_sum[rt] - sum over rt's range
    //   tr_max[rt] - maximum over rt's range
    //   tr_min[rt] - minimum over rt's range
    int lazy[N<<2];
    int tr_sum[N<<2];
    int tr_max[N<<2];
    int tr_min[N<<2];
    // Moves the pending addition of node rt (covering [l,r]) onto both children
    // and clears it on rt. Does nothing when there is no pending addition.
    void push_down(int l,int r,int rt)
    {
    	const int pending=lazy[rt];
    	if(pending==0)
    		return;
    	const int mid=(l+r)>>1;
    	// Left child covers [l,mid], right child covers [mid+1,r].
    	const int kids[2]={rt<<1,rt<<1|1};
    	const int lens[2]={mid-l+1,r-mid};
    	for(int i=0;i<2;i++)
    	{
    		const int c=kids[i];
    		lazy[c]+=pending;
    		tr_sum[c]+=lens[i]*pending;
    		tr_max[c]+=pending;
    		tr_min[c]+=pending;
    	}
    	lazy[rt]=0;
    }
    // Recomputes the sum / max / min of node rt from its two children.
    void push_up(int rt)
    {
    	const int lc=rt<<1;
    	const int rc=lc|1;
    	tr_sum[rt]=tr_sum[lc]+tr_sum[rc];
    	tr_max[rt]=tr_max[lc]<tr_max[rc]?tr_max[rc]:tr_max[lc];
    	tr_min[rt]=tr_min[rc]<tr_min[lc]?tr_min[rc]:tr_min[lc];
    }
    // Builds the segment tree node rt for positions [l,r]. Position i holds
    // the value of the node with visit index i, i.e. a[rev[i]].
    void build(int l,int r,int rt)
    {
    	if(l!=r)
    	{
    		const int mid=(l+r)>>1;
    		build(l,mid,rt<<1);
    		build(mid+1,r,rt<<1|1);
    		push_up(rt);
    		return;
    	}
    	// Leaf: all three aggregates equal the single stored value.
    	const int value=a[rev[l]];
    	tr_sum[rt]=value;
    	tr_max[rt]=value;
    	tr_min[rt]=value;
    }
    // Adds x to every position in [l1,r1].
    //   l, r, rt - range covered by the current tree node and its index
    //   l1, r1   - target range of the update
    //   x        - amount to add
    void update(int l,int r,int rt,int l1,int r1,int x)
    {
    	const bool covered=(l>=l1)&&(r<=r1);
    	if(!covered)
    	{
    		// Partial overlap: flush the pending addition before descending.
    		push_down(l,r,rt);
    		const int mid=(l+r)>>1;
    		if(mid>=l1)
    			update(l,mid,rt<<1,l1,r1,x);
    		if(mid<r1)
    			update(mid+1,r,rt<<1|1,l1,r1,x);
    		push_up(rt);
    		return;
    	}
    	// Full cover: update this node's aggregates and defer the children.
    	const int len=r-l+1;
    	lazy[rt]+=x;
    	tr_min[rt]+=x;
    	tr_max[rt]+=x;
    	tr_sum[rt]+=len*x;
    }
    // Returns the sum of positions in [l1,r1] within the node rt covering [l,r].
    int query_sum(int l,int r,int rt,int l1,int r1)
    {
    	if(l>=l1&&r<=r1)
    		return tr_sum[rt];
    	// Children must see any pending addition before they are read.
    	push_down(l,r,rt);
    	const int mid=(l+r)>>1;
    	const int left_part=(l1<=mid)?query_sum(l,mid,rt<<1,l1,r1):0;
    	const int right_part=(r1>mid)?query_sum(mid+1,r,rt<<1|1,l1,r1):0;
    	return left_part+right_part;
    }
    // Returns the maximum of positions in [l1,r1] within the node rt covering [l,r].
    int query_max(int l,int r,int rt,int l1,int r1)
    {
    	if(l>=l1&&r<=r1)
    		return tr_max[rt];
    	push_down(l,r,rt);
    	const int mid=(l+r)>>1;
    	// Starting value for the running maximum; it is returned unchanged
    	// if neither half overlaps the query.
    	int best=-1e18;
    	if(l1<=mid)
    	{
    		const int cand=query_max(l,mid,rt<<1,l1,r1);
    		if(best<cand)
    			best=cand;
    	}
    	if(r1>mid)
    	{
    		const int cand=query_max(mid+1,r,rt<<1|1,l1,r1);
    		if(best<cand)
    			best=cand;
    	}
    	return best;
    }
    // Returns the minimum of positions in [l1,r1] within the node rt covering [l,r].
    int query_min(int l,int r,int rt,int l1,int r1)
    {
    	if(l>=l1&&r<=r1)
    		return tr_min[rt];
    	push_down(l,r,rt);
    	const int mid=(l+r)>>1;
    	// Starting value for the running minimum; it is returned unchanged
    	// if neither half overlaps the query.
    	int best=1e18;
    	if(l1<=mid)
    	{
    		const int cand=query_min(l,mid,rt<<1,l1,r1);
    		if(cand<best)
    			best=cand;
    	}
    	if(r1>mid)
    	{
    		const int cand=query_min(mid+1,r,rt<<1|1,l1,r1);
    		if(cand<best)
    			best=cand;
    	}
    	return best;
    }
    
    // Adds k to every node on the tree path between u and v.
    void cal_update(int u,int v,int k)
    {
    	// While the endpoints sit on different chains, process the chain whose
    	// head is deeper, then jump to the parent of that head.
    	for(;top[u]!=top[v];u=fa[top[u]])
    	{
    		if(dep[top[u]]<dep[top[v]])
    			std::swap(u,v);
    		update(1,n,1,dfn[top[u]],dfn[u],k);
    	}
    	// Same chain: the range runs from the shallower node to the deeper one.
    	const bool u_deeper=dep[u]>dep[v];
    	const int lo=u_deeper?dfn[v]:dfn[u];
    	const int hi=u_deeper?dfn[u]:dfn[v];
    	update(1,n,1,lo,hi,k);
    }
    // Returns the sum of the values on the tree path between u and v.
    int cal_sum(int u,int v)
    {
    	int total=0;
    	while(true)
    	{
    		if(top[u]==top[v])
    			break;
    		// Always lift the endpoint whose chain head is deeper.
    		if(dep[top[v]]>dep[top[u]])
    			std::swap(u,v);
    		const int head=top[u];
    		total+=query_sum(1,n,1,dfn[head],dfn[u]);
    		u=fa[head];
    	}
    	// Both endpoints share a chain; order them shallower-first.
    	if(dep[v]<dep[u])
    		std::swap(u,v);
    	total+=query_sum(1,n,1,dfn[u],dfn[v]);
    	return total;
    }
    // Returns the maximum value on the tree path between u and v.
    int cal_max(int u,int v)
    {
    	int best=-1e18;
    	for(;top[u]!=top[v];u=fa[top[u]])
    	{
    		// Always lift the endpoint whose chain head is deeper.
    		if(dep[top[u]]<dep[top[v]])
    			std::swap(u,v);
    		const int part=query_max(1,n,1,dfn[top[u]],dfn[u]);
    		if(best<part)
    			best=part;
    	}
    	// Both endpoints share a chain; order them shallower-first.
    	if(dep[v]<dep[u])
    		std::swap(u,v);
    	const int last=query_max(1,n,1,dfn[u],dfn[v]);
    	return best<last?last:best;
    }
    // Returns the minimum value on the tree path between u and v.
    int cal_min(int u,int v)
    {
    	int best=1e18;
    	while(true)
    	{
    		if(top[u]==top[v])
    			break;
    		// Always lift the endpoint whose chain head is deeper.
    		if(dep[top[v]]>dep[top[u]])
    			std::swap(u,v);
    		const int head=top[u];
    		const int part=query_min(1,n,1,dfn[head],dfn[u]);
    		if(part<best)
    			best=part;
    		u=fa[head];
    	}
    	// Both endpoints share a chain; order them shallower-first.
    	if(dep[v]<dep[u])
    		std::swap(u,v);
    	const int last=query_min(1,n,1,dfn[u],dfn[v]);
    	return last<best?last:best;
    }
    signed main()
    {
    	
    	scanf("%lld%lld",&n,&m);
    	for(int i=1;i<=n;i++)
    		scanf("%lld",&a[i]);
    	for(int i=1;i<n;i++)
    	{
    		int u,v;
    		scanf("%lld%lld",&u,&v);
    		g[u].push_back(v);
    		g[v].push_back(u);
    	}
    	dfs1(1,0);
    	dfs2(1,1);
    	
    	build(1,n,1);
    	while(m--)
    	{
    		int op,x,y;
    		scanf("%lld%lld%lld",&op,&x,&y);
    		if(op==1)
    		{
    			int k;
    			scanf("%lld",&k);
    			cal_update(x,y,k);
    		}
    		if(op==2)
    			printf("%lld\n",cal_max(x,y));
    		if(op==3)
    			printf("%lld\n",cal_min(x,y));
    		if(op==4)
    			printf("%lld\n",cal_sum(x,y));
    	}
    	return 0;
    }
    
    • 1

    信息

    ID
    39
    时间
    1000ms
    内存
    256MiB
    难度
    9
    标签
    递交数
    14
    已通过
    4
    上传者