In this post, we will solve the HackerRank Tree Pruning problem.

A tree, $t$, has $n$ vertices numbered from $1$ to $n$ and is rooted at vertex $1$. Each vertex $i$ has an integer weight, $w_i$, associated with it, and $t$'s total weight is the sum of the weights of its nodes. A single remove operation removes the subtree rooted at some arbitrary vertex $u$ from tree $t$.

Given $t$, perform up to $k$ remove operations so that the total weight of the remaining vertices in $t$ is maximal. Then print $t$'s maximal total weight on a new line.

Note: If $t$'s total weight is already maximal, you may opt to remove $0$ nodes.

Input Format

The first line contains two space-separated integers, $n$ and $k$, respectively. The second line contains $n$ space-separated integers describing the respective weights for each node in the tree, where the $i$-th integer is the weight of the $i$-th vertex. Each of the $n-1$ subsequent lines contains a pair of space-separated integers, $u$ and $v$, describing an edge connecting vertex $u$ to vertex $v$.

Output Format

Print a single integer denoting the largest total weight of $t$'s remaining vertices.

Sample Input

5 2
1 1 -1 -1 -1
1 2
2 3
4 1
4 5

Sample Output

2

Explanation

We perform 2 remove operations:
1. Remove the subtree rooted at node 3. Losing this subtree's $-1$ weight increases the tree's total weight by $1$.
2. Remove the subtree rooted at node 4. Losing this subtree's $-2$ weight increases the tree's total weight by $2$.

The sum of our remaining positively-weighted nodes is $1+1=2$, so we print $2$ on a new line.
HackerRank Tree Pruning Problem Solution

Tree Pruning C Solution

#include <stdio.h>
#include <stdlib.h>

typedef struct _node{
    int x;
    struct _node *next;
} node;

void insert_edge(int x,int y);
void dfs(int x);
long long max(long long x,long long y);

/* a: weight of each vertex; b: the same weights listed in DFS preorder;
 * size[i]: number of vertices in the subtree of the i-th preorder vertex. */
int a[100000],b[100000],size[100000],trace[100000]={0},NN=0;
/* dp[i][j]: best total weight of the first i preorder vertices using at most j removals. */
long long dp[100001][201];
node *table[100000]={0};

int main(){
    int N,K,x,y,i,j;
    long long sum;
    scanf("%d%d",&N,&K);
    for(i=0;i<N;i++)
        scanf("%d",a+i);
    for(i=0;i<N-1;i++){
        scanf("%d%d",&x,&y);
        insert_edge(x-1,y-1);
    }
    dfs(0);
    for(i=0;i<=K;i++)
        dp[0][i]=0;
    /* Baseline for every j: remove nothing, i.e. plain preorder prefix sums. */
    for(i=1,sum=0;i<=N;i++){
        sum+=b[i-1];
        for(j=0;j<=K;j++)
            dp[i][j]=sum;
    }
    for(i=1;i<=N;i++)
        for(j=0;j<=K;j++){
            /* Remove the subtree rooted at preorder vertex i-1: it occupies
             * positions i-1 .. i+size[i-1]-2, so jump straight past it. */
            if(j!=K)
                dp[i+size[i-1]-1][j+1]=max(dp[i+size[i-1]-1][j+1],dp[i-1][j]);
            /* Keep preorder vertex i-1. */
            dp[i][j]=max(dp[i][j],dp[i-1][j]+b[i-1]);
        }
    printf("%lld\n",dp[N][K]);
    return 0;
}

/* Adds the undirected edge x-y to the adjacency lists. */
void insert_edge(int x,int y){
    node *t;
    t=(node*)malloc(sizeof(node));
    t->x=y;
    t->next=table[x];
    table[x]=t;
    t=(node*)malloc(sizeof(node));
    t->x=x;
    t->next=table[y];
    table[y]=t;
    return;
}

/* Lays the tree out in preorder starting from vertex x. b[i] receives the
 * weight of the i-th visited vertex, and size[i] the vertex count of its
 * subtree, so that subtree occupies b[i] .. b[i+size[i]-1].
 * An explicit stack is used instead of recursion, so a path-shaped tree of
 * 1e5 vertices cannot overflow the call stack. */
void dfs(int x){
    static int stk[100000],vert[100000],parent_pos[100000];
    node *t;
    int top=0,start=NN,i,v,pos;
    stk[top++]=x;
    trace[x]=1;
    parent_pos[x]=-1;
    while(top>0){
        v=stk[--top];
        pos=NN++;
        b[pos]=a[v];
        vert[pos]=v;
        size[pos]=1;
        for(t=table[v];t;t=t->next)
            if(!trace[t->x]){
                trace[t->x]=1;
                parent_pos[t->x]=pos;
                stk[top++]=t->x;
            }
    }
    /* Every vertex appears after its parent in preorder, so a reverse sweep
     * folds each finished subtree size into its parent's. */
    for(i=NN-1;i>start;i--)
        size[parent_pos[vert[i]]]+=size[i];
    return;
}

long long max(long long x,long long y){
    return (x>y)?x:y;
}

Tree Pruning C++ Solution

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
LL INF = 1ll << 60;

vector<int> eulerCircuit;
vector<vector<int>> edges;
// w: vertex weights; f / l: 1-based positions of the first / last occurrence
// of each vertex in eulerCircuit.
int w[100010], f[100010], l[100010];

// Builds the Euler tour starting at u (parent p). u is appended on entry and
// again after each child's tour. The explicit stack avoids deep recursion on
// path-shaped trees.
void dfs(const int u, const int p) {
    struct Frame { int v, parent; size_t next; };
    vector<Frame> stk;
    stk.push_back({u, p, 0});
    eulerCircuit.push_back(u);
    while (!stk.empty()) {
        Frame &top = stk.back();
        if (top.next == edges[top.v].size()) {
            stk.pop_back();
            if (!stk.empty()) eulerCircuit.push_back(stk.back().v);
            continue;
        }
        const int e = edges[top.v][top.next++];
        if (e == top.parent) continue;
        const int from = top.v;  // copy before push_back can invalidate `top`
        stk.push_back({e, from, 0});
        eulerCircuit.push_back(e);
    }
} // dfs

// Best total weight collectable from Euler-tour position u to the end of the
// tour, using at most rem removals.
// Filled bottom-up: every transition only moves forward in the tour. This
// replaces the recursive memo, which could recurse ~2n deep and used -1 as
// its "unset" marker even though -1 is also a legitimate result.
LL solve(const int u, const int rem) {
    if (rem < 0) return -INF;
    const int m = (int)eulerCircuit.size();
    const int cols = rem + 1;
    vector<LL> dp((size_t)m * cols, 0);  // row m-1 (end of tour) stays 0
    for (int pos = m - 2; pos >= u; pos--) {
        const int cur = eulerCircuit[pos], nxt = eulerCircuit[pos + 1];
        LL *row = &dp[(size_t)pos * cols];
        const LL *down = &dp[(size_t)(pos + 1) * cols];
        if (f[cur] < f[nxt]) {
            // Stepping down into nxt for the first time: either keep nxt, or
            // remove its whole subtree and resume where the tour returns to cur.
            // That position is pos + l[nxt] - f[nxt] + 2, which equals l[nxt]
            // because f[nxt] == pos + 2 here.
            const LL *skip = &dp[(size_t)(pos + l[nxt] - f[nxt] + 2) * cols];
            for (int r = 0; r < cols; r++) {
                row[r] = down[r] + w[nxt];
                if (r > 0) row[r] = max(row[r], skip[r - 1]);
            }
        } else {
            // Stepping back up to an already-counted vertex adds nothing.
            for (int r = 0; r < cols; r++) row[r] = down[r];
        }
    }
    return dp[(size_t)u * cols + rem];
} // solve

int main() {
    int n, k, n1, n2;
    scanf("%d %d", &n, &k);
    for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
    edges.assign(n + 5, vector<int>());
    for (int i = 0; i < n - 1; i++) {
        scanf("%d %d", &n1, &n2);
        edges[n1].push_back(n2);
        edges[n2].push_back(n1);
    } // for
    // Virtual root 0 (weight 0) above vertex 1, so removing vertex 1's subtree
    // is just another "skip a child" move.
    edges[0].push_back(1);
    dfs(0, -1);
    for (int i = 0; i < (int)eulerCircuit.size(); i++) {
        l[eulerCircuit[i]] = i + 1;
        if (!f[eulerCircuit[i]]) f[eulerCircuit[i]] = i + 1;
    } // for
    printf("%lld\n", solve(0, k));
    return 0;
} // main

Tree Pruning C Sharp Solution

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

class Solution{
    static void Main(){
        new Sol().Solve();
    }
}

class Sol{
    public void Solve(){
        Queue<int> Q=new Queue<int>();
        depth[root]=0;
        parent[root]=-1;
        Q.Enqueue(root);
        long minf=(long)-1e18;
        // Determine each vertex's depth and parent with a BFS (N is large, so recursive DFS is avoided).
        while(Q.Count>0){
            int x=Q.Dequeue();
            foreach(int l in E[x]){
                if(l!=parent[x]){
                    depth[l]=depth[x]+1;
                    parent[l]=x;
                    Q.Enqueue(l);
                }
            }
        }
        int maxdepth=depth.Max();
        List<int>[] Dep=new List<int>[maxdepth+1];
        for(int i=0;i<=maxdepth;i++)Dep[i]=new List<int>();
        for(int i=0;i<N;i++){
            Dep[depth[i]].Add(i);
        }
        // Total weight of every subtree, accumulated from the deepest level upward.
        long[] sum=new long[N];
        for(int d=maxdepth;d>=0;d--){
            foreach(int v in Dep[d]){
                sum[v]=weight[v];
                foreach(int l in E[v]){
                    if(l!=parent[v]){
                        sum[v]+=sum[l];
                    }
                }
            }
        }
        // dp[v][j]: best weight change inside v's subtree using exactly j removals
        // (0, or minus the removed weight); minf marks an unreachable count.
        long[][] dp=new long[N][];
        for(int i=0;i<N;i++){
            dp[i]=new long[K+1];
            for(int j=0;j<=K;j++)dp[i][j]=minf;
        }
        // Process levels deepest first, so each child is final before it is
        // merged into its parent (tree knapsack).
        for(int d=maxdepth;d>=0;d--){
            foreach(int now in Dep[d]){
                dp[now][0]=0;
                foreach(int k in E[now]){
                    if(k==parent[now])continue;
                    // Reachable counts form a prefix, so only loop up to the last
                    // reachable index on each side. Scanning all (K+1)^2 pairs
                    // per edge costs O(N*K^2).
                    int topNow=LastReachable(dp[now],minf);
                    int topChild=LastReachable(dp[k],minf);
                    // Descending ii keeps this an in-place 0/1 merge.
                    for(int ii=topNow;ii>=0;ii--){
                        if(dp[now][ii]==minf)continue;
                        for(int jj=topChild;jj>=0;jj--){
                            if(dp[k][jj]==minf)continue;
                            if(ii+jj>K)continue;
                            dp[now][ii+jj]=Math.Max(dp[now][ii+jj],dp[now][ii]+dp[k][jj]);
                        }
                    }
                }
                // One more option: remove now's whole subtree with a single operation.
                if(K>=1)dp[now][1]=Math.Max(-sum[now],dp[now][1]);
            }
        }
        long Max=(long)-1e18;
        for(int i=0;i<=K;i++){
            Max=Math.Max(Max,dp[root][i]);
        }
        Console.WriteLine(sum[root]+Max);
    }

    // Index of the last entry of row that is not minf (index 0 is always reachable).
    static int LastReachable(long[] row,long minf){
        int t=row.Length-1;
        while(t>0&&row[t]==minf)t--;
        return t;
    }

    int N;
    int K;
    int root;
    List<int>[] E;
    int[] parent;
    int[] depth;
    long[] weight;

    public Sol(){
        var d=ria();
        N=d[0];K=d[1];
        weight=rla();
        E=new List<int>[N];
        for(int i=0;i<N;i++)E[i]=new List<int>();
        parent=new int[N];
        root=0;
        depth=new int[N];
        for(int i=0;i<N-1;i++){
            var dd=ria();
            E[dd[0]-1].Add(dd[1]-1);
            E[dd[1]-1].Add(dd[0]-1);
        }
    }

    static int[] ria(){return Array.ConvertAll(ReadTokens(),e=>int.Parse(e));}
    static long[] rla(){return Array.ConvertAll(ReadTokens(),e=>long.Parse(e));}
    // Splits a line on whitespace, ignoring empty tokens.
    // A plain Split(' ') yields "" for doubled or trailing spaces, and int.Parse("") throws.
    static string[] ReadTokens(){return Console.ReadLine().Split(new[]{' ','\t','\r'},StringSplitOptions.RemoveEmptyEntries);}
}

Tree Pruning Java Solution

import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) {
        // go() recurses once per tree level. A path-shaped tree of 1e5 vertices
        // overflows the default thread stack, so the solver runs on a thread
        // with an explicitly large (256 MB) stack.
        Thread solver = new Thread(null, () -> {
            FastScanner in = new FastScanner(System.in);
            PrintWriter out = new PrintWriter(System.out);
            new Main().run(in, out);
            out.close();
        }, "solver", 1L << 28);
        solver.start();
        try {
            solver.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    int n;
    int K;
    List<Integer>[] adj;
    int[] w;

    @SuppressWarnings("unchecked")
    void run(FastScanner in, PrintWriter out) {
        n = in.nextInt();
        K = in.nextInt();
        w = new int[n];
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n; i++) w[i] = in.nextInt();
        for (int i = 1; i < n; i++) {
            int u = in.nextInt()-1;
            int v = in.nextInt()-1;
            adj[u].add(v);
            adj[v].add(u);
        }
        long[] dp = go(0, -1);
        long max = Long.MIN_VALUE;
        for (int k = 0; k <= K; k++) {
            max = Math.max(max, dp[k]);
        }
        out.println(max);
    }

    // dp for u's subtree (parent p): dp[k] is the largest weight that can remain
    // after exactly k removals, or Long.MIN_VALUE if k removals are impossible.
    long[] go(int u, int p) {
        long[][] dp = new long[2][K+1];
        for (long[] d : dp) Arrays.fill(d, Long.MIN_VALUE);
        int flip = 0;
        dp[0][0] = w[u];
        for (int v : adj[u]) {
            if (v == p) continue;
            long[] childDp = go(v, u);
            Arrays.fill(dp[flip^1], Long.MIN_VALUE);
            // Reachable counts form a prefix, so the first MIN_VALUE ends each scan.
            for (int k = 0; k <= K && dp[flip][k] != Long.MIN_VALUE; k++) {
                for (int pk = 0; pk+k <= K && childDp[pk] != Long.MIN_VALUE; pk++) {
                    dp[flip^1][pk+k] = Math.max(dp[flip^1][pk+k], dp[flip][k] + childDp[pk]);
                }
            }
            flip = flip^1;
        }
        // Removing u's whole subtree costs one operation and leaves weight 0.
        if (K >= 1) dp[flip][1] = Math.max(dp[flip][1], 0);
        return dp[flip];
    }

    static class FastScanner {
        BufferedReader br;
        StringTokenizer st;

        public FastScanner(InputStream in) {
            br = new BufferedReader(new InputStreamReader(in));
            st = null;
        }

        String next() {
            while (st == null || !st.hasMoreElements()) {
                try {
                    st = new StringTokenizer(br.readLine());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        long nextLong() {
            return Long.parseLong(next());
        }
    }
}

Tree Pruning JavaScript Solution

// NOTE: the previously published listing was truncated (it stopped inside
// treePruning and never showed dfsTraversal). This version completes it with
// the same subtree-knapsack DP used by the other solutions.

function processData(input) {
    var inputArray = input.split("\n");
    var params = inputArray[0].trim().split(/\s+/).map(Number);
    var N = params[0];
    var K = params[1];
    var weights = inputArray[1].trim().split(/\s+/).map(Number);
    var sum = 0;
    for (var i = 0; i < N; i++) {
        sum += weights[i];
    }
    // Adjacency lists keyed by 1-based vertex number.
    var tree = {};
    for (var i = 2; i < 2 + N - 1; i++) {
        var edge = inputArray[i].trim().split(/\s+/).map(Number);
        (tree[edge[0]] = tree[edge[0]] || []).push(edge[1]);
        (tree[edge[1]] = tree[edge[1]] || []).push(edge[0]);
    }
    console.log(treePruning(tree, N, K, sum, weights));
}

// Merges two "best weight with exactly j removals" lists (both dense
// prefixes) into one, capped at K removals.
function mergeCounts(a, b, K) {
    var len = Math.min(K, (a.length - 1) + (b.length - 1)) + 1;
    var out = new Array(len).fill(-Infinity);
    for (var i = 0; i < a.length; i++) {
        for (var j = 0; j < b.length && i + j < len; j++) {
            if (a[i] + b[j] > out[i + j]) out[i + j] = a[i] + b[j];
        }
    }
    return out;
}

// Returns the largest total weight left after at most K subtree removals.
// The tree is rooted at vertex 1 and traversed iteratively, so deep trees
// cannot overflow the call stack. (`sum` is accepted for interface
// compatibility but not needed.)
function treePruning(tree, N, K, sum, weights) {
    var parent = new Array(N + 1).fill(0);
    var order = [];
    var stack = [1];
    parent[1] = -1;
    while (stack.length > 0) {
        var v = stack.pop();
        order.push(v);
        var adj = tree[v] || [];
        for (var i = 0; i < adj.length; i++) {
            if (adj[i] !== parent[v]) {
                parent[adj[i]] = v;
                stack.push(adj[i]);
            }
        }
    }
    // dp[v][j]: best weight kept in v's subtree with exactly j removals.
    var dp = new Array(N + 1);
    for (var v = 1; v <= N; v++) dp[v] = [weights[v - 1]];
    // Reverse preorder: every child is finished before its parent.
    for (var idx = order.length - 1; idx >= 0; idx--) {
        var u = order[idx];
        var cur = dp[u];
        // Option: remove u's whole subtree in one operation (weight 0).
        if (K >= 1) {
            if (cur.length === 1) cur.push(0);
            else cur[1] = Math.max(cur[1], 0);
        }
        if (parent[u] === -1) {
            return Math.max.apply(null, cur);
        }
        dp[parent[u]] = mergeCounts(dp[parent[u]], cur, K);
        dp[u] = null;
    }
}

process.stdin.resume();
process.stdin.setEncoding("ascii");
var _input = "";
process.stdin.on("data", function (input) {
    _input += input;
});
process.stdin.on("end", function () {
    processData(_input);
});

Tree Pruning Python Solution

#!/bin/python3

import os
import sys

#
# Complete the treePrunning function below.
#
from collections import defaultdict

# Marker for "this number of removals is unreachable". It is chosen far below
# any weight sum expected for HackerRank-sized inputs (presumably |sum| <= 1e10).
NEG_INF = -(10 ** 15)


def dfs(x, f, g, k, weights):
    """Knapsack DP over the subtree rooted at x, whose parent is f.

    Returns a list dp of length k+1. dp[j] is the largest total weight that can
    remain in x's subtree (x itself kept) after exactly j removals, or NEG_INF
    if j removals are impossible.

    Uses an explicit stack: a recursive version exceeds Python's default
    recursion limit (about 1000) on deep, path-shaped trees.
    """
    parent = {x: f}
    order = [x]
    stack = [x]
    while stack:
        u = stack.pop()
        for v in g[u]:
            if v != parent[u]:
                parent[v] = u
                order.append(v)
                stack.append(v)

    done = {}
    # Reverse discovery order: every child is finished before its parent.
    for u in reversed(order):
        dpc = [NEG_INF] * (k + 1)
        dpc[0] = weights[u]
        for n in g[u]:
            if n == parent[u]:
                continue
            dpn = done.pop(n)
            dptmp = [NEG_INF] * (k + 1)
            for i in range(k + 1):
                if dpc[i] == NEG_INF:
                    break
                for j in range(0, k - i + 1):
                    if dpn[j] == NEG_INF:
                        break
                    dptmp[i + j] = max(dptmp[i + j], dpc[i] + dpn[j])
                # Alternatively remove n's whole subtree with one operation.
                if i + 1 <= k:
                    dptmp[i + 1] = max(dptmp[i + 1], dpc[i])
            dpc = dptmp
        done[u] = dpc
    return done[x]


def treePrunning(k, weights, edges):
    g = defaultdict(list)
    for u, v in edges:
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)
    dpn = dfs(0, -1, g, k, weights)
    best = max(dpn)
    # Removing the root's subtree (one operation) leaves weight 0; that needs k >= 1.
    return max(best, 0) if k >= 1 else best


if __name__ == '__main__':
    fptr = open(os.environ['OUTPUT_PATH'], 'w')

    nk = input().split()

    n = int(nk[0])

    k = int(nk[1])

    weights = list(map(int, input().rstrip().split()))

    tree = []

    for _ in range(n-1):
        tree.append(list(map(int, input().rstrip().split())))

    result = treePrunning(k, weights, tree)

    fptr.write(str(result) + '\n')

    fptr.close()