Problem

You are given $N$ edge-weighted trees $T_1, T_2, \ldots, T_N$ having $M_1, M_2, \ldots, M_N$ nodes, respectively. You also have an array $A$ of $N-1$ positive integers.

You need to connect these $N$ trees using $N-1$ edges such that each of these edges has weight equal to $A_i$ for some $i \in [1, N-1]$, and each $A_i$ must be used exactly once. Note that after adding these edges, a single large tree $T$ will form, consisting of $M_1 + M_2 + \ldots + M_N$ nodes. Find the minimum possible sum of distances between each pair of nodes in the final tree $T$.

Since this value can be large, output it modulo $998244353$.

Input Format

• The first line of input will contain a single integer $N$, denoting the number of trees.
• Then, you are given the $N$ trees. The description of each tree is as follows:
• The first line contains the integer $M_i$ $(1 \le i \le N)$, the number of nodes in the $i$-th tree.
• The next $M_i - 1$ lines describe the edges of the $i$-th tree. The $j$-th of these $M_i - 1$ lines contains $3$ space-separated integers $u_{ij}$, $v_{ij}$ and $w_{ij}$, describing that there is an edge of weight $w_{ij}$ connecting the nodes $u_{ij}$ and $v_{ij}$ in tree $T_i$.
• The final line contains $N-1$ space-separated integers, $A_1, A_2, \ldots, A_{N-1}$.

Output Format

Output a single line containing the minimum possible sum of distances between each pair of nodes in the tree $T$, modulo $998244353$.

Constraints

• $2 \leq N \leq 10^4$
• $1 \leq M_i \leq 5 \cdot 10^4$
• $1 \leq u_{ij}, v_{ij} \leq M_i$
• $1 \leq w_{ij} \leq 10^8$
• $1 \leq A_i \leq 10^8$
• The sum $M_1 + M_2 + \ldots + M_N$ does not exceed $10^5$.

• Subtask 1 (30 points): $N = 2$
• Subtask 2 (70 points): Original constraints.

Sample 1:

Input

2
3
1 2 1
2 3 3
4
1 2 2
1 3 1
1 4 3
1

Output

72

Explanation:

An optimal solution for the given test case looks as follows (the original figure is not included here):

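Tree of Trees — CodeChef Solution in C++

Note: the C++ code below does not implement the problem stated above. It reads a rooted tree (per-node integer values and labeled parent→child edges) and records a hash of every subtree. Then, for each of Q query trees, it prints YES if the query tree's hash matches one of those subtree hashes and NO otherwise.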

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pb push_back
#define mp make_pair
#define pii pair<ll,ll>
#define fs first
#define sc second
// Capacity of the per-node global arrays; node indices read from input must stay below it.
// (Previously defined twice with the same value; one definition is enough.)
#define mx 200000
// vc[v]   : children of node v, in input order.
// cost[v] : cost[v][i] is the 0/1 label of the edge v -> vc[v][i].
vector<int>vc[mx],cost[mx];
// deg[v] is set to 1 once v has appeared as a child; the root is the node left at 0.
bool deg[mx];
// aged[v]: per-node integer read from input; dfs() uses it as a row index into age[][].
int aged[mx];
// Wipes the per-node state of nodes 1..n (child list, edge labels and the
// "has a parent" flag) so the next tree can be read into the same globals.
void clr(int n)
{
    int node = n;
    while (node >= 1)
    {
        cost[node].clear();
        vc[node].clear();
        deg[node] = 0;
        --node;
    }
}
// Returns the smallest node index in 1..n that never appeared as a child
// (deg[i] == 0), i.e. the root of the tree that was just read.
// Terminates the program if every node has a parent (input is not a rooted tree).
int find(int n)
{
    for(int i=1; i<=n; i++)
    {
        if(!deg[i])
            return i;
    }
    assert(!"find: no root (every node has a parent)");
    // assert() compiles away under NDEBUG; without this call control would flow
    // off the end of a non-void function, which is undefined behavior.
    abort();
}
// Two-lane hash parameters read by dfs(): multipliers base/base2 and moduli
// mod/mod2/mod3, one value per lane. base3 is not referenced in the code shown.
// NOTE(review): both mod[] values are even (not prime) — confirm this is intended.
ll base[2] = {103, 10007};
ll mod[2] = {784568360, 820925358};
ll base2[2] = {137, 277};
ll mod2[2] = {1000000007, 1000000009};
ll base3[2] = {701, 41};
ll mod3[2] = {998244353, 1012924417};
// Random salts, filled with rand() in main():
//   day[lane][edge label][k], age[lane][node value][k], brac[lane][k].
ll day[2][2][2];
ll age[2][120][2];
ll brac[2][2];
// Hashes of every subtree of the reference tree (filled by dfs(root, 1)).
set<pii> s;
// Computes a two-lane hash of the subtree rooted at v. The hash does not depend
// on the order of children: child contributions are combined by modular addition.
// It uses the global tree (vc, cost, aged) and the salts/moduli declared above.
//   v : root of the subtree to hash.
//   x : when true, the hash of every node in the subtree (v included) is
//       inserted into the global set s.
// Returns the (lane 0, lane 1) hash of v.
//
// The traversal uses an explicit stack instead of recursion, so a path-shaped
// tree with up to mx nodes cannot overflow the call stack. The arithmetic per
// node is the same as the recursive formulation.
pii dfs(int v,bool x)
{
    // One frame per node on the current root-to-node path.
    struct Frame
    {
        int node;     // node being processed
        size_t next;  // index in vc[node] of the next child to visit / fold
        ll sum[2];    // running per-lane sum of folded child hashes (< mod2[j])
    };

    vector<Frame> stk;
    stk.push_back(Frame{v, 0, {0, 0}});
    pii result(0, 0);

    while(!stk.empty())
    {
        Frame &top = stk.back();
        if(top.next < vc[top.node].size())
        {
            // Descend into the next child. `top` may be invalidated by push_back,
            // so it is not used again in this iteration.
            int child = vc[top.node][top.next];
            stk.push_back(Frame{child, 0, {0, 0}});
            continue;
        }

        // All children folded in: finish this node's hash.
        ll h[2] = {top.sum[0], top.sum[1]};
        // age[] has 120 rows per lane; NOTE(review): aged values outside [0, 120)
        // would index out of bounds — confirm the input range.
        int d = aged[top.node];
        for(int j=0; j<2; j++)
        {
            h[j] *= base2[j];
            h[j] %= mod2[j];
            h[j] += (brac[j][0]+brac[j][1])%mod[j];
            h[j] %= mod[j];
            h[j] += (age[j][d][0]+age[j][d][1])%mod3[j];
            h[j] %= mod3[j];
        }
        pii node_hash = make_pair(h[0], h[1]);
        if(x)
            s.insert(node_hash);

        stk.pop_back();
        if(stk.empty())
        {
            result = node_hash;
            break;
        }

        // Fold this node's hash into its parent's running sum, salted by the
        // label of the parent -> child edge.
        Frame &parent = stk.back();
        int c = cost[parent.node][parent.next];
        ll val[2] = {node_hash.first, node_hash.second};
        for(int j=0; j<2; j++)
        {
            val[j] *= base[j];
            val[j] %= mod[j];
            val[j] += (day[j][c][0]+day[j][c][1])%mod2[j];
            val[j] %= mod2[j];
            parent.sum[j] += val[j];
            if(parent.sum[j] >= mod2[j])
                parent.sum[j] %= mod2[j];
        }
        parent.next++;
    }
    return result;
}
// Reads one rooted tree into the global arrays for nodes 1..n:
//   - n integers into aged[1..n];
//   - n-1 edge lines "parent child label". The child is appended to vc[parent]
//     and flagged in deg[]. The edge cost is 1 if the label starts with 'E', else 0.
// The label is read into a std::string, so a label of any length is consumed
// without overflowing a fixed buffer. The default synced stdio makes mixing
// scanf and cin safe here.
static void read_tree(int n)
{
    for(int i=1; i<=n; i++)
        scanf("%d",&aged[i]);
    string label;
    for(int i=1; i<n; i++)
    {
        int parent, child;
        scanf("%d%d",&parent,&child);
        cin>>label;
        vc[parent].push_back(child);
        deg[child]=1;
        cost[parent].push_back(label[0]=='E' ? 1 : 0);
    }
}

// Reads a reference tree and records the hash of each of its subtrees. Then,
// for each of Q query trees, prints YES if the query tree's hash matches one
// of those subtree hashes and NO otherwise.
int main()
{
    srand(time(NULL));
    // Seed every row of age[][] (the node value indexes it in dfs), not only the
    // first 101. Otherwise node values 101..119 would all share an all-zero salt.
    const int age_rows = sizeof(age[0]) / sizeof(age[0][0]);
    for(int i=0; i<age_rows; i++)
    {
        for(int j=0; j<=1; j++)
        {
            age[j][i][0]=rand();
            age[j][i][1]=rand();
        }
    }
    for(int i=0; i<2; i++)
    {
        brac[i][0]=rand();
        brac[i][1]=rand();
    }
    for(int i=0; i<2; i++)
        for(int j=0; j<2; j++)
            for(int k=0; k<2; k++)
                day[i][j][k]=rand();

    int n;
    cin>>n;
    read_tree(n);
    dfs(find(n),1);

    int Q;
    cin>>Q;
    while(Q--)
    {
        scanf("%d",&n);
        clr(n);
        read_tree(n);
        pii rs=dfs(find(n),0);
        if(s.find(rs)!=s.end())printf("YES\n");
        else printf("NO\n");
    }
}


