HackerRank Tree Pruning Problem Solution
In this post, we will solve HackerRank Tree Pruning Problem Solution.
A tree, t, has n vertices numbered from 1 to n and is rooted at vertex 1. Each vertex i has an integer weight, $w_i$, associated with it, and t’s total weight is the sum of the weights of its nodes. A single remove operation removes the subtree rooted at some arbitrary vertex u from tree t.
Given t, perform up to k remove operations so that the total weight of the remaining
vertices in t is maximal. Then print t’s maximal total weight on a new line. Note: If t’s total weight is already maximal, you may opt to remove 0 nodes.
Input Format
The first line contains two space-separated integers, n and k, respectively.
The second line contains n space-separated integers describing the respective weights for each node in the tree, where the i-th integer is the weight of the i-th vertex. Each of the n − 1 subsequent lines contains a pair of space-separated integers, u and v, describing an edge connecting vertex u to vertex v.
Output Format
Print a single integer denoting the largest total weight of t’s remaining vertices.
Sample Input
5 2
1 1 -1 -1 -1
1 2
2 3
4 1
4 5
Sample Output
2
Explanation
We perform 2 remove operations:
- Remove the subtree rooted at node 3. Losing this subtree’s -1 weight increases the tree’s total weight by 1.
- Remove the subtree rooted at node 4. Losing this subtree’s -2 weight increases the tree’s total weight by 2.
The sum of our remaining positively-weighted nodes is 1+1=2, so we print 2 on a new
line.
Tree Pruning C Solution
#include <stdio.h>
#include <stdlib.h>
/* Adjacency-list cell: x is the neighbouring vertex (0-based; main passes
   x-1 / y-1 to insert_edge) and next is the following cell in the same
   vertex's list. */
typedef struct _node{
int x;
struct _node *next;
} node;
void insert_edge(int x,int y);
void dfs(int x);
long long max(long long x,long long y);
/* a[v]: weight of vertex v (0-based).
   b[p]: the same weights re-ordered by DFS preorder position p (filled by dfs).
   size[p]: number of preorder slots occupied by the subtree whose root sits
            at position p.
   trace[v]: visited flag for dfs; NN: next free preorder position.
   NOTE(review): array bounds assume n <= 100000 — confirm against the limits. */
int a[100000],b[100000],size[100000],trace[100000]={0},NN=0;
/* dp[i][j]: best kept weight among the first i preorder positions using at
   most j removals; the 201 columns mean K must not exceed 200. */
long long dp[100001][201];
/* table[v]: head of vertex v's adjacency list (NULL when empty). */
node *table[100000]={0};
/* Reads the tree, lays its vertices out in DFS preorder (dfs fills b and
   size), runs a DP over preorder positions and prints the best remaining
   weight followed by a newline.
   dp[i][j]: best kept weight among the first i preorder positions using at
   most j removals.  Removing the subtree rooted at position i-1 jumps
   straight to position (i-1)+size[i-1], the first slot after that subtree. */
int main(){
    int N,K,x,y,i,j;
    long long sum;
    scanf("%d%d",&N,&K);
    for(i=0;i<N;i++)
        scanf("%d",a+i);
    for(i=0;i<N-1;i++){
        scanf("%d%d",&x,&y);
        insert_edge(x-1,y-1);   /* input vertices are 1-based */
    }
    dfs(0);
    /* Baseline with no removals: every dp[i][j] starts as the plain prefix
       sum of the preorder weights. */
    for(i=0;i<=K;i++)
        dp[0][i]=0;
    for(i=1,sum=0;i<=N;i++){
        sum+=b[i-1];
        for(j=0;j<=K;j++)
            dp[i][j]=sum;
    }
    /* Forward transitions in increasing i: every jump lands at an index >= i
       (size >= 1), so dp[i-1][*] is final before it is read. */
    for(i=1;i<=N;i++)
        for(j=0;j<=K;j++){
            /* Remove the subtree rooted at preorder position i-1. */
            if(j!=K)
                dp[i+size[i-1]-1][j+1]=max(dp[i+size[i-1]-1][j+1],dp[i-1][j]);
            /* Keep preorder position i-1. */
            dp[i][j]=max(dp[i][j],dp[i-1][j]+b[i-1]);
        }
    /* The output format asks for the answer on its own line. */
    printf("%lld\n",dp[N][K]);
    return 0;
}
/* Records the undirected edge x--y: each endpoint is pushed onto the front
   of the other endpoint's adjacency list (first y into x's list, then x
   into y's list). */
void insert_edge(int x,int y){
    int ends[2];
    int k;
    ends[0]=x;
    ends[1]=y;
    for(k=0;k<2;k++){
        node *cell=(node*)malloc(sizeof(node));
        cell->x=ends[1-k];
        cell->next=table[ends[k]];
        table[ends[k]]=cell;
    }
}
/* Depth-first walk from vertex x: gives x the next preorder slot, copies its
   weight into b[slot], descends into every neighbour not yet visited, and
   finally stores in size[slot] how many slots the whole subtree used. */
void dfs(int x){
    int slot=NN++;
    node *it=table[x];
    trace[x]=1;
    b[slot]=a[x];
    while(it!=NULL){
        if(!trace[it->x])
            dfs(it->x);
        it=it->next;
    }
    size[slot]=NN-slot;
}
/* Returns the larger of two long long values (y when they are equal). */
long long max(long long x,long long y){
    if(x>y)
        return x;
    return y;
}
Tree Pruning C++ Solution
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// Returned for states with a negative removal budget; assumed to be far below
// any reachable total weight.
const LL INF = 1ll << 60;
// Euler tour of the tree rooted at vertex 0: a vertex is recorded when it is
// entered and again after each of its children has been fully visited.
std::vector<int> eulerCircuit;
// Adjacency lists indexed by vertex id; filled by main().
std::vector<std::vector<int>> edges;
// w[v]: weight of vertex v.  f[v] / l[v]: 1-based positions of the first /
// last occurrence of v in eulerCircuit (0 until assigned).
int w[100010], f[100010], l[100010];
// DP[pos][r]: best total weight collectable from tour position pos to the end
// of the tour with at most r removals; written by solve().
LL DP[200010][201];

/// Appends the Euler tour of the subtree rooted at u (entered from parent p)
/// to eulerCircuit: u first, then for each child in adjacency-list order the
/// child's tour followed by u again.
/// Uses an explicit stack so a path-shaped tree with ~1e5 vertices cannot
/// overflow the call stack.
void dfs(const int u, const int p) {
  struct Frame {
    int vertex;
    int parent;
    std::size_t next;  // index of the next neighbour to examine
  };
  std::vector<Frame> stack;
  stack.push_back({u, p, 0});
  eulerCircuit.push_back(u);
  while (!stack.empty()) {
    Frame& top = stack.back();
    const std::vector<int>& adj = edges[top.vertex];
    if (top.next < adj.size()) {
      const int child = adj[top.next++];
      if (child == top.parent)
        continue;
      const int from = top.vertex;
      eulerCircuit.push_back(child);
      stack.push_back({child, from, 0});  // invalidates `top`
      continue;
    }
    stack.pop_back();
    // Returning to the parent records it once more in the tour.
    if (!stack.empty())
      eulerCircuit.push_back(stack.back().vertex);
  }
}

/// Best total weight obtainable from Euler-tour position u to the end of the
/// tour with at most rem subtree removals; -INF when rem < 0.
///
/// A step pos -> pos+1 that descends into a child c either keeps c (adding
/// w[c]) or removes c's whole subtree with one operation, resuming at the
/// position just after c's last occurrence.  The table is filled bottom-up
/// from the end of the tour instead of by memoised recursion, so path-shaped
/// trees cannot overflow the stack and no sentinel value is needed to mark
/// uncomputed states (a sentinel of -1 collides with real totals).
///
/// Requires dfs() and the f/l tables to be set up, and 0 <= rem <= 200.
/// Overwrites DP rows u..end for columns 0..rem.
LL solve(const int u, const int rem) {
  if (rem < 0)
    return -INF;
  const int last = static_cast<int>(eulerCircuit.size()) - 1;
  if (u >= last)
    return 0;  // end of the tour: nothing left to collect
  for (int r = 0; r <= rem; ++r)
    DP[last][r] = 0;
  for (int pos = last - 1; pos >= u; --pos) {
    const int cur = eulerCircuit[pos];
    const int nxt = eulerCircuit[pos + 1];
    // nxt was first seen after cur, so this step enters child nxt.
    const bool descends = f[cur] < f[nxt];
    const LL gain = descends ? w[nxt] : 0;
    // Tour position right after nxt's last occurrence (only used if descending).
    const int resume = pos + l[nxt] - f[nxt] + 2;
    for (int r = 0; r <= rem; ++r) {
      LL best = DP[pos + 1][r] + gain;
      if (descends && r > 0)
        best = std::max(best, DP[resume][r - 1]);
      DP[pos][r] = best;
    }
  }
  return DP[u][rem];
}
// Reads n, k, the vertex weights and the n-1 edges; hangs the real root 1
// below a virtual vertex 0, builds the Euler tour with its first/last
// occurrence tables, and prints the best remaining weight for k removals.
int main() {
  int n, k;
  scanf("%d %d", &n, &k);
  for (int v = 1; v <= n; ++v)
    scanf("%d", &w[v]);
  edges.assign(n + 5, vector<int>());
  for (int e = 1; e < n; ++e) {
    int a, b;
    scanf("%d %d", &a, &b);
    edges[a].push_back(b);
    edges[b].push_back(a);
  }
  // Virtual vertex 0 lets the tour treat removal of root 1 like any child.
  edges[0].push_back(1);
  dfs(0, -1);
  const int tourLen = static_cast<int>(eulerCircuit.size());
  for (int pos = 0; pos < tourLen; ++pos) {
    const int v = eulerCircuit[pos];
    l[v] = pos + 1;
    if (f[v] == 0)
      f[v] = pos + 1;
  }
  memset(DP, -1, sizeof(DP));
  printf("%lld\n", solve(0, k));
  return 0;
}
Tree Pruning C Sharp Solution
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
class Solution{
    // Program entry point: constructs the solver (which reads stdin) and runs it.
    static void Main(){
        var solver=new Sol();
        solver.Solve();
    }
}
class Sol{
    // Computes the largest total weight that can remain after at most K
    // subtree removals and prints it. The tree is read by the constructor.
    public void Solve(){
        Queue<int> Q=new Queue<int>();
        depth[root]=0;
        parent[root]=-1;
        Q.Enqueue(root);
        // Marks removal counts that cannot be reached inside a subtree.
        long minf=(long)-1e18;
        // Determine depth and parent of every vertex with a BFS
        // (N is large, so a recursive DFS is avoided).
        while(Q.Count>0){
            int x=Q.Dequeue();
            foreach(int l in E[x]){
                if(l!=parent[x]){
                    depth[l]=depth[x]+1;
                    parent[l]=x;
                    Q.Enqueue(l);
                }
            }
        }
        int maxdepth=depth.Max();
        List<int>[] Dep=new List<int>[maxdepth+1];
        for(int i=0;i<=maxdepth;i++)Dep[i]=new List<int>();
        for(int i=0;i<N;i++){
            Dep[depth[i]].Add(i);
        }
        // sum[v]: total weight of the subtree rooted at v, built deepest level first.
        long[] sum=new long[N];
        for(int d=maxdepth;d>=0;d--){
            foreach(int v in Dep[d]){
                sum[v]=weight[v];
                foreach(int l in E[v]){
                    if(l!=parent[v])sum[v]+=sum[l];
                }
            }
        }
        // dp[v][j]: largest weight gain from removing exactly j disjoint
        // subtrees inside v's subtree (minf when j is unreachable).
        long[][] dp=new long[N][];
        for(int i=0;i<N;i++){
            dp[i]=new long[K+1];
            for(int j=0;j<=K;j++)dp[i][j]=minf;
        }
        // Push tables from the deepest level towards the root.
        for(int d=maxdepth;d>=0;d--){
            foreach(int now in Dep[d]){
                dp[now][0]=0;
                foreach(int k in E[now]){
                    if(k==parent[now])continue;
                    // In-place 0/1 knapsack merge: ii descends and every write
                    // targets ii+jj > ii, so dp[now][ii] still holds its
                    // pre-child value when it is read.
                    for(int ii=K;ii>=0;ii--){
                        if(dp[now][ii]==minf)continue;
                        // A child's reachable counts form a prefix 0..m, so stop at
                        // the first gap instead of scanning all K entries (which is
                        // O(K^2) per child on wide trees). jj=0 adds dp[k][0]=0,
                        // a no-op, so it is skipped.
                        for(int jj=1;ii+jj<=K&&dp[k][jj]!=minf;jj++){
                            dp[now][ii+jj]=Math.Max(dp[now][ii+jj],dp[now][ii]+dp[k][jj]);
                        }
                    }
                }
                // Removing now's whole subtree costs one operation; with K==0
                // there is no column 1 and the option does not exist.
                if(K>=1)dp[now][1]=Math.Max(-sum[now],dp[now][1]);
            }
        }
        long Max=(long)-1e18;
        for(int i=0;i<=K;i++){
            Max=Math.Max(Max,dp[root][i]);
        }
        Console.WriteLine(sum[root]+Max);
    }
    int N;
    int K;
    int root;
    List<int>[] E;
    int[] parent;
    int[] depth;
    long[] weight;
    // Reads "N K", the N weights and the N-1 edges (1-based in the input,
    // stored 0-based in E). Vertex 0 is the root.
    public Sol(){
        var d=ria();
        N=d[0];K=d[1];
        weight=rla();
        E=new List<int>[N];
        for(int i=0;i<N;i++)E[i]=new List<int>();
        parent=new int[N];
        root=0;
        depth=new int[N];
        for(int i=0;i<N-1;i++){
            var dd=ria();
            E[dd[0]-1].Add(dd[1]-1);
            E[dd[1]-1].Add(dd[0]-1);
        }
    }
    // RemoveEmptyEntries tolerates repeated or trailing blanks, which would
    // otherwise reach int.Parse/long.Parse as empty strings and throw.
    static readonly char[] Blanks={' ','\t','\r'};
    static int[] ria(){return Array.ConvertAll(Console.ReadLine().Split(Blanks,StringSplitOptions.RemoveEmptyEntries),e=>int.Parse(e));}
    static long[] rla(){return Array.ConvertAll(Console.ReadLine().Split(Blanks,StringSplitOptions.RemoveEmptyEntries),e=>long.Parse(e));}
}
Tree Pruning Java Solution
import java.util.*;
import java.io.*;
public class Main {
    /** Requested stack size for the solver thread; go() recurses once per tree level. */
    private static final long SOLVER_STACK_BYTES = 1L << 28;

    public static void main(String[] args) throws IOException {
        // go() is recursive, and a path-shaped tree with ~1e5 vertices overflows
        // the default thread stack, so the solver runs on a thread created with a
        // larger one.
        Thread solver = new Thread(null, new Runnable() {
            @Override
            public void run() {
                FastScanner in = new FastScanner(System.in);
                PrintWriter out = new PrintWriter(System.out);
                new Main().run(in, out);
                out.close();
            }
        }, "solver", SOLVER_STACK_BYTES);
        solver.start();
        try {
            solver.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    int n;
    int K;
    List<Integer>[] adj;
    int[] w;

    /** Reads the tree (1-based edges) and prints the best weight left after at most K removals. */
    void run(FastScanner in, PrintWriter out) {
        n = in.nextInt();
        K = in.nextInt();
        w = new int[n];
        adj = new List[n];
        for (int i = 0; i < n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n; i++) w[i] = in.nextInt();
        for (int i = 1; i < n; i++) {
            int u = in.nextInt()-1;
            int v = in.nextInt()-1;
            adj[u].add(v);
            adj[v].add(u);
        }
        long[] dp = go(0, -1);
        long max = Long.MIN_VALUE;
        for (int k = 0; k <= K; k++) {
            max = Math.max(max, dp[k]);
        }
        out.println(max);
    }

    /**
     * Returns dp for the subtree of u (entered from p): dp[k] is the largest
     * weight the subtree can keep after exactly k removals, or Long.MIN_VALUE
     * when k removals are impossible. dp[1] also covers removing u's own
     * subtree, which keeps weight 0.
     */
    long[] go(int u, int p) {
        // Two rolling tables: dp[flip] is current, dp[flip^1] receives the merge.
        long[][] dp = new long[2][K+1];
        for (long[] d : dp) Arrays.fill(d, Long.MIN_VALUE);
        int flip = 0;
        dp[0][0] = w[u];
        for (int v : adj[u]) {
            if (v == p) continue;
            long[] childDp = go(v, u);
            Arrays.fill(dp[flip^1], Long.MIN_VALUE);
            // Reachable counts form a prefix, so both scans stop at the first MIN_VALUE.
            for (int k = 0; k <= K && dp[flip][k] != Long.MIN_VALUE; k++) {
                for (int pk = 0; pk+k <= K && childDp[pk] != Long.MIN_VALUE; pk++) {
                    dp[flip^1][pk+k] = Math.max(dp[flip^1][pk+k],
                            dp[flip][k] + childDp[pk]);
                }
            }
            flip = flip^1;
        }
        // Removing u's whole subtree costs one operation; impossible when K == 0.
        if (K >= 1) dp[flip][1] = Math.max(dp[flip][1], 0);
        return dp[flip];
    }

    /** Whitespace-token reader over a buffered input stream. */
    static class FastScanner {
        BufferedReader br;
        StringTokenizer st;
        public FastScanner(InputStream in) {
            br = new BufferedReader(new InputStreamReader(in));
            st = null;
        }
        String next() {
            while (st == null || !st.hasMoreElements()) {
                try {
                    st = new StringTokenizer(br.readLine());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }
        int nextInt() {
            return Integer.parseInt(next());
        }
        long nextLong() {
            return Long.parseLong(next());
        }
    }
}
Tree Pruning JavaScript Solution
// Builds a numrows x numcols array of arrays with every cell set to `initial`.
Array.matrix = function (numrows, numcols, initial) {
  var grid = [];
  for (var r = 0; r < numrows; r++) {
    var row = [];
    for (var c = 0; c < numcols; c++) {
      row.push(initial);
    }
    grid.push(row);
  }
  return grid;
};
// Sets arr[0..size-1] to `val` in place; other entries are left untouched.
function fillArray(arr, size, val) {
  var i = 0;
  while (i < size) {
    arr[i++] = val;
  }
}
/**
 * Parses the whole problem input (n, k, n weights, n-1 edges) and prints the
 * answer computed by treePruning.
 *
 * Parsing is token-based, so CRLF line endings and repeated or trailing
 * blanks cannot turn into bogus zero values (as splitting on a single space
 * did). Every vertex gets an adjacency entry, so a one-vertex tree with no
 * edge lines is still traversable.
 */
function processData(input) {
  var tokens = input.trim().split(/\s+/).map(Number);
  var pos = 0;
  var N = tokens[pos++];
  var K = tokens[pos++];
  var weights = tokens.slice(pos, pos + N);
  pos += N;

  var sum = 0;
  for (var i = 0; i < N; i++) {
    sum += weights[i];
  }

  // Adjacency lists keyed by 1-based vertex id.
  var tree = {};
  for (var v = 1; v <= N; v++) {
    tree[v] = [];
  }
  for (var e = 0; e < N - 1; e++) {
    var a = tokens[pos++];
    var b = tokens[pos++];
    tree[a].push(b);
    tree[b].push(a);
  }
  console.log(treePruning(tree, N, K, sum, weights));
}
/**
 * Returns the largest total weight left after at most K subtree removals.
 * `sum` is the weight of the whole tree; the result is `sum` minus the most
 * negative total weight that K disjoint subtree removals can take away.
 */
function treePruning(tree, N, K, sum, weights) {
  // Preorder layout filled by dfsTraversal:
  //   order[p]  - vertex at preorder position p
  //   subW[p]   - total weight of that vertex's subtree
  //   subLen[p] - vertex count of that subtree, so p + subLen[p] is the
  //               first position after it
  var order = [];
  var subW = [];
  var subLen = [];
  fillArray(order, N, 0);
  fillArray(subLen, N, 0);
  fillArray(subW, N, 0);
  dfsTraversal(1, order, subW, subLen, 0, weights, tree, [0]);

  // prev[p]: most negative weight removable from positions p..N-1 with the
  // previous round's budget (all zeros for a budget of 0 removals).
  var prev = [];
  var cur = [];
  fillArray(prev, N, 0);
  fillArray(cur, N, 0);
  var round = 1;
  while (round <= K) {
    for (var p = N - 1; p >= 0; p--) {
      var skip = (p === N - 1) ? 0 : cur[p + 1];
      var best = skip;
      // Removing a subtree only helps when its total weight is negative.
      if (subW[p] < 0) {
        var after = p + subLen[p];
        var removed = subW[p] + (after >= N ? 0 : prev[after]);
        best = Math.min(removed, skip);
      }
      cur[p] = best;
    }
    // Every cur[p] is rewritten before it is read, so buffers can just swap.
    var spare = prev;
    prev = cur;
    cur = spare;
    round++;
  }
  return sum - prev[0];
}
/**
 * Preorder traversal of the tree rooted at `node`, entered from `visited`
 * (the parent vertex id, or 0 for the root).
 * Each vertex gets the next preorder index idx[0]++, and the function records
 *   A[index] - the vertex id,
 *   W[index] - total weight of its subtree (weights[v - 1] is v's weight),
 *   S[index] - number of vertices in its subtree.
 * Returns {len, subtreeWeight} for `node` itself.
 *
 * Uses an explicit stack: on a path-shaped tree with ~1e5 vertices, the
 * recursive form exceeds the JavaScript call-stack limit. Vertices without an
 * adjacency entry are treated as leaves instead of crashing.
 */
function dfsTraversal(node, A, W, S, visited, weights, tree, idx) {
  var stack = [];
  var enter = function (vertex, parent) {
    var index = idx[0]++;
    A[index] = vertex;
    stack.push({vertex: vertex, parent: parent, index: index, next: 0,
                len: 1, total: weights[vertex - 1]});
  };
  enter(node, visited);
  var result = null;
  while (stack.length > 0) {
    var frame = stack[stack.length - 1];
    var neighbours = tree[frame.vertex] || [];
    if (frame.next < neighbours.length) {
      var child = neighbours[frame.next++];
      if (child != frame.parent) {
        enter(child, frame.vertex);
      }
      continue;
    }
    // All children done: finalise this vertex and fold it into its parent.
    stack.pop();
    W[frame.index] = frame.total;
    S[frame.index] = frame.len;
    result = {len: frame.len, subtreeWeight: frame.total};
    if (stack.length > 0) {
      var up = stack[stack.length - 1];
      up.len += frame.len;
      up.total += frame.total;
    }
  }
  return result;
}
// Buffer all of stdin, then solve once the stream ends.
process.stdin.resume();
process.stdin.setEncoding("ascii");
// Declared explicitly: an implicit global assignment throws in strict mode.
var _input = "";
process.stdin.on("data", function (input) {
  _input += input;
});
process.stdin.on("end", function () {
  processData(_input);
});
Tree Pruning Python Solution
#!/bin/python3
import os
import sys
#
# Complete the treePrunning function below.
#
from collections import defaultdict
# Sentinel marking unreachable removal counts in the DP tables; it only has to
# be below every real subtree total (assumed |total| <= n * max|w|).
INF = -(1e15)


def _merge_child(dpc, dpn, k):
    """Fold one child's table ``dpn`` into its parent's table ``dpc``.

    ``dpc[i]`` / ``dpn[j]`` hold the best kept weight with exactly ``i`` / ``j``
    removals (``INF`` if unreachable). Reachable counts form a prefix, so each
    scan stops at the first ``INF``. Returns a new table of length ``k + 1``.
    """
    merged = [INF] * (k + 1)
    for i, kept in enumerate(dpc):
        if kept == INF:
            break
        for j in range(k - i + 1):
            if dpn[j] == INF:
                break
            merged[i + j] = max(merged[i + j], kept + dpn[j])
        # Alternatively spend one removal on the child's whole subtree.
        if i + 1 <= k:
            merged[i + 1] = max(merged[i + 1], kept)
    return merged


def dfs(x, f, g, k, weights):
    """Return the removal DP table for the subtree of ``x`` entered from ``f``.

    ``result[i]`` is the largest weight the subtree keeps when exactly ``i``
    subtrees strictly below ``x`` are removed (``x`` itself is kept); ``INF``
    marks counts that cannot be reached.

    The traversal is iterative, so path-shaped trees with ~1e5 vertices do not
    hit Python's recursion limit.
    """
    # Discover vertices parent-before-child; reversing gives a post-order.
    parent = {x: f}
    order = [x]
    stack = [x]
    while stack:
        u = stack.pop()
        for v in g[u]:
            if v != parent[u]:
                parent[v] = u
                order.append(v)
                stack.append(v)

    tables = {}
    for u in reversed(order):
        dpc = [INF] * (k + 1)
        dpc[0] = weights[u]
        for v in g[u]:
            if v != parent[u]:
                # pop() frees each child's table once it has been merged.
                dpc = _merge_child(dpc, tables.pop(v), k)
        tables[u] = dpc
    return tables[x]


def treePrunning(k, weights, edges):
    """Return the largest total weight left after at most ``k`` subtree removals.

    ``weights`` is 0-based; ``edges`` holds 1-based ``(u, v)`` pairs; vertex 1
    is the root.
    """
    g = defaultdict(list)
    for u, v in edges:
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)
    dpn = dfs(0, -1, g, k, weights)
    best = max(dpn)
    # Removing the root's own subtree (leaving weight 0) costs one removal,
    # so it is only an option when k >= 1.
    return max(best, 0) if k >= 1 else best
if __name__ == '__main__':
    # OUTPUT_PATH names the destination file; `with` closes it even if
    # reading the input or solving raises.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        nk = input().split()
        n = int(nk[0])
        k = int(nk[1])
        weights = list(map(int, input().rstrip().split()))
        tree = []
        for _ in range(n - 1):
            tree.append(list(map(int, input().rstrip().split())))
        result = treePrunning(k, weights, tree)
        fptr.write(str(result) + '\n')