In this post, we will solve the HackerRank Tree Splitting problem.
You are given a tree whose vertices are numbered from 1 to n, and you need to process m queries. Each query is a vertex number, encoded as described below.
Queries are encoded in the following way: let $m_j$ be the $j$-th query and $ans_j$ be the answer to the $j$-th query, where $1 \le j \le m$ and $ans_0$ is always $0$. Then the vertex for query $j$ is $v_j = ans_{j-1} \oplus m_j$. It is guaranteed that $v_j$ is between $1$ and $n$ and has not been removed before.
Note: $\oplus$ is the bitwise XOR operator.
For each query, first decode the vertex v and then perform the following:
- Print the size of the connected component containing v.
- Remove vertex v and all edges connected to v.
Input Format
The first line contains a single integer, $n$, denoting the number of vertices in the tree. Each line $i$ of the $n-1$ subsequent lines (where $1 \le i \le n-1$) contains 2 space-separated integers, $u_i$ and $v_i$, the endpoints of edge $i$. The next line contains a single integer, $m$, denoting the number of queries. Each line $j$ of the $m$ subsequent lines contains a single integer, the encoded vertex number $m_j$.
Output Format
For each query, print the size of the corresponding connected component on a new line.
Sample Input 0
3
1 2
1 3
3
1
1
2
Sample Output 0
3
1
1
Sample Input 1
4
1 2
1 3
1 4
4
3
6
2
6
Sample Output 1
4
3
2
1

Tree Splitting C Solution
#include <stdlib.h>
#include <stdio.h>
/* One connected component; `count` is the number of vertices currently in it. */
typedef struct Set {
    int count;
} Set;

/*
 * A vertex of the rooted tree built by fill_children().
 * main() also reuses this struct for adjacency lists: nodes[v]->first_child
 * heads a list of neighbour entries chained through `next`.
 */
typedef struct node {
    int number;                 /* vertex label (1-based) */
    struct node *parent;        /* parent in the rooted tree; 0 for a component root */
    struct node *next;          /* next sibling in the parent's child list */
    struct node *prev;          /* previous sibling in the parent's child list */
    struct node *first_child;   /* head of the child list */
    Set *set;                   /* component this vertex currently belongs to */
} node;
/* Debug helper: print the label of every direct child of n, one per line. */
void print_children(node * n){
    for (node *c = n->first_child; c != 0; c = c->next)
        printf("%d\n", c->number);
}
/*
 * Prepend c to n's child list.
 * c->prev is never written, and c->next is written only when the list was
 * non-empty; every caller in this file passes a freshly calloc'd node.
 */
void add_child(node * n, node * c){
    node *old_head = n->first_child;
    n->first_child = c;
    if (old_head == 0)
        return;
    old_head->prev = c;
    c->next = old_head;
}
/*
 * Recursively build the rooted tree below `root` from the adjacency lists.
 *   nodes        - adjacency heads indexed by label; nodes[v]->first_child
 *                  chains v's neighbour entries (nodes[v] may be 0).
 *   result_nodes - rooted-tree vertices indexed by label. A non-zero entry
 *                  means the vertex is already placed, which is how the
 *                  parent is skipped.
 * Each new vertex is calloc'd, linked to `root` via parent/add_child, then
 * expanded. Recursion depth equals the depth of the resulting tree.
 */
void fill_children(node * root, node ** nodes, node ** result_nodes){
    node *adj_head = nodes[root->number];
    if (!adj_head)
        return;
    for (node *nb = adj_head->first_child; nb; nb = nb->next) {
        if (result_nodes[nb->number])
            continue;
        node *kid = calloc(1, sizeof(node));
        kid->number = nb->number;
        kid->parent = root;
        result_nodes[kid->number] = kid;
        add_child(root, kid);
        fill_children(kid, nodes, result_nodes);
    }
}
/*
 * Assign every vertex of the subtree rooted at `root` to `set` and add the
 * subtree size to set->count. When `set` is 0, a fresh zeroed Set is
 * allocated first.
 *
 * The subtree is walked iteratively through first_child / next / parent
 * links, so a long path cannot overflow the call stack. The walk never climbs
 * above `root`, so root's own parent and sibling links are never followed.
 * It relies on every vertex strictly below `root` having a correct `parent`
 * link (set by fill_children).
 */
void compute_below(node * root, Set * set) {
    if (set == 0) {
        /* sizeof *set (the struct), not sizeof(set) (a pointer). */
        set = calloc(1, sizeof *set);
    }
    node *cur = root;
    for (;;) {
        cur->set = set;
        set->count++;
        if (cur->first_child) {
            cur = cur->first_child;
            continue;
        }
        /* Climb until a vertex that still has a sibling to visit, stopping at root. */
        while (cur != root && cur->next == 0)
            cur = cur->parent;
        if (cur == root)
            break;
        cur = cur->next;
    }
}
/*
 * Delete `item` from the forest and split its component.
 *
 * Every child of `item` becomes the root of its own component (its parent
 * link is cleared).
 *  - If `item` has a parent, every child subtree is relabelled into a freshly
 *    allocated Set (compute_below(child, 0)). The count left in item's Set is
 *    then the size of the part still attached to the parent.
 *  - If `item` is a component root, the first child keeps item's Set and only
 *    the other children are relabelled. The count left in the shared Set then
 *    equals the first child's subtree size (0 when item has no children).
 * item->set->count is reduced by 1 (item itself) plus the sizes of all
 * relabelled subtrees. Finally `item` is unlinked from its parent's child list.
 *
 * NOTE(review): relabelling walks whole subtrees, so the total cost over all
 * queries can be quadratic in n (e.g. repeatedly removing the second vertex
 * of a long path). Removed nodes and replaced Sets are never freed.
 */
void remove_node(node * item) {
// (leftover from an earlier draft: subtract_below(item, item->below+1);)
int everyChild = item->parent != 0; /* non-root: no child may keep item's Set */
node * child = item->first_child;
int childCount = 0;
int toRemove = 1; /* item itself */
while (child) {
childCount++;
if (everyChild || childCount > 1) {
compute_below(child, 0); /* move this child subtree into a new Set */
toRemove += child->set->count;
}
child->parent = 0; /* child is now a component root */
child = child->next;
}
item->set->count -= toRemove;
node * parent = item->parent;
if(parent){
/* unlink item from the parent's doubly linked child list */
if(parent->first_child == item){
parent->first_child = item->next;
}
if(item->next){
item->next->prev = item->prev;
}
if(item->prev){
item->prev->next = item->next;
}
}
}
int main(int argc, char **argv){
int n;
scanf("%d\n", &n);
int i = 0;
node ** nodes = calloc(n+1, sizeof(node *));
for(i = 0; i < n-1; i++){
int a,b;
scanf("%d %d\n", &a, &b);
node * node_a = nodes[a];
if(node_a == 0) {
node_a = calloc(1, sizeof(node));
node_a->number = a;
nodes[a] = node_a;
}
node * x = calloc(1, sizeof(node));
x->number = b;
add_child(node_a,x);
node * node_b = nodes[b];
if(node_b == 0){
node_b = calloc(1, sizeof(node));
node_b->number = b;
nodes[b] = node_b;
}
x = calloc(1, sizeof(node));
x->number = a;
add_child(node_b, x);
}
node * root = calloc(1, sizeof(node));
root->number = 1;
node ** result_nodes = calloc(n+1, sizeof(node *));
result_nodes[1] = root;
fill_children(root, nodes, result_nodes);
compute_below(result_nodes[1], 0);
int ans = 0;
int num_queries;
scanf("%d\n", &num_queries);
for(i = 0; i < num_queries; i++){
int m;
scanf("%d\n", &m);
int q = m^ans;
node * n = result_nodes[q];
ans = n->set->count;
printf("%d\n", ans);
remove_node(n);
}
return 0;
}
Tree Splitting C++ Solution
#include<cstdio>
#include<ctime>
#include<cstdlib>
#include<vector>
using namespace std;
// Global state.
//   N, M : vertex and query counts
//   nr   : number of entries dfs() has written to v[]
//   v    : Euler tour, +x on entering vertex x and -x on leaving it
//   ap   : visited flags for dfs()
//   t    : t[x] is the DFS parent of x
//   INF  : priority that forces split()'s dummy node to the treap root
//   V    : not used; main() keeps its own running answer
int INF, N, M, nr, V, v[400009], ap[200009], t[200009];
// scos[x] is set once vertex x has been removed by a query
bool scos[200009];
// adjacency lists of the input tree
vector<int> muchii[200009];

// Implicit-treap node: K = Euler-tour key (+x / -x), P = heap priority,
// nr = subtree size, l / r = children, t = parent (0 when unset).
// adresa1[x] / adresa2[x] point at the nodes holding +x / -x (see build()),
// nil is the shared empty sentinel, and R is the treap root built in main().
struct nod{
    int K, P, nr;
    nod *l, *r, *t;
    nod(int K, int P, int nr, nod *l, nod *r)
        : K(K), P(P), nr(nr), l(l), r(r), t(0) {}
} *adresa1[200009], *adresa2[200009], *nil, *R;
// Random treap priority in [1, 2^30], assembled from two 15-bit rand() draws
// so it does not depend on RAND_MAX being large. It always stays below
// INF = 2^30 + 7.
int Rand(){
    return (rand() % 32768) * 32768 + rand() % 32768 + 1;
}
// Recompute n's subtree size and point its real (non-sentinel) children back at n.
void reup(nod *&n){
    nod *left = n->l, *right = n->r;
    if (left != nil) left->t = n;
    if (right != nil) right->t = n;
    n->nr = 1 + left->nr + right->nr;
}
// Lift n's LEFT child above n (a right rotation in textbook terms).
// The lifted node inherits n's parent link and is stored back into n.
void rot_left(nod *&n){
    nod *lifted = n->l;
    n->l = lifted->r;
    lifted->r = n;
    lifted->t = n->t;
    reup(n);
    reup(lifted);
    n = lifted;
}
// Lift n's RIGHT child above n (a left rotation in textbook terms).
// The lifted node inherits n's parent link and is stored back into n.
void rot_right(nod *&n){
    nod *lifted = n->r;
    n->r = lifted->l;
    lifted->l = n;
    lifted->t = n->t;
    reup(n);
    reup(lifted);
    n = lifted;
}
// Restore the max-heap order on priorities at n by lifting a child that
// outranks n. The left child is checked first; at most one rotation is done.
void balance(nod *&n){
    if (n->l->P > n->P) {
        rot_left(n);
        return;
    }
    if (n->r->P > n->P)
        rot_right(n);
}
// Insert a new node (key K, priority P) so that exactly Poz nodes precede it
// in the in-order sequence of the treap rooted at n (0-based position).
// Sizes are refreshed and the heap order is repaired on the way back up.
void Insert(nod *&n, int Poz, int K, int P){
    if (n == nil) {
        n = new nod(K, P, 1, nil, nil);
        return;
    }
    const int leftSize = n->l->nr;
    if (Poz <= leftSize)
        Insert(n->l, Poz, K, P);
    else
        Insert(n->r, Poz - leftSize - 1, K, P);
    reup(n);
    balance(n);
}
// Remove the node at 1-based in-order position Poz from the treap rooted at n.
// The target is rotated down, each time lifting its child with the larger
// priority (ties lift the right child; nil has P = 0), until it has no
// children. It is then deleted and replaced by nil. reup() refreshes sizes
// and parent links on the way back up.
// NOTE(review): there is no bounds check; this assumes 1 <= Poz <= n->nr.
void Erase(nod *&n, int Poz){
if(n->l->nr >= Poz)
Erase(n->l, Poz); // target lies in the left subtree
else if(n->l->nr + 1 < Poz)
Erase(n->r, Poz - n->l->nr - 1); // target lies in the right subtree
else{
// n is the target
if(n->l == nil && n->r == nil){
delete n;
n = nil;
}
else{
if(n->l->P > n->r->P)
rot_left (n);
else
rot_right (n);
Erase(n, Poz); // the target is now one level lower; its position is unchanged
}
}
if(n != nil) reup (n);
}
// Concatenate treaps Rl and Rr (all of Rl before all of Rr) into R.
// A temporary priority-0 node is placed between them and then erased; the
// rotations performed by Erase() merge the two sides.
void join(nod *&R, nod *&Rl, nod *&Rr){
    nod *glue = new nod(0, 0, Rl->nr + Rr->nr + 1, Rl, Rr);
    R = glue;
    Erase(R, Rl->nr + 1);
}
// Split treap R after its first Poz nodes: Rl receives the first Poz nodes
// and Rr the rest, both with their parent link cleared; R is left as nil.
// It works by inserting a dummy with priority INF, which rotates up to the root.
void split(nod *&R, nod *&Rl, nod *&Rr, int Poz){
    Insert(R, Poz, 0, INF);
    Rl = R->l;
    Rr = R->r;
    Rl->t = 0;
    Rr->t = 0;
    delete R;
    R = nil;
}
// "Not a root": true when n has a parent pointer AND that parent links back
// to n as its left or right child. A t pointer without a matching back-link
// is treated as a root (presumably a defensive check).
bool nu_radacina(nod *&n){
    nod *up = n->t;
    return up != 0 && (up->l == n || up->r == n);
}
// Store in R the root of the treap containing n, found by following parent
// links while nu_radacina() says the current node is not a root.
void det_root(nod *n, nod *&R){
    while (nu_radacina(n))
        n = n->t;
    R = n;
}
// 1-based in-order position of n inside its treap. Start from n's left-subtree
// size plus one, then, for every ancestor reached from its right side, add
// that ancestor's left-subtree size plus one.
int Get_Pos(nod *n){
    int pos = n->l->nr + 1;
    for (; nu_radacina(n); n = n->t) {
        nod *up = n->t;
        if (up->l != n)
            pos += up->l->nr + 1;
    }
    return pos;
}
void dfs(int nod){
ap[nod] = 1;
v[++nr] = nod;
for(vector<int>::iterator it = muchii[nod].begin(); it != muchii[nod].end(); it++)
if (ap[*it] == 0){
dfs(*it);
t[*it] = nod;
}
v[++nr] = -nod;
}
// Walk the whole treap and record which node holds each tour entry:
// adresa1[x] for the +x entry and adresa2[x] for the -x entry.
void build(nod *&n){
    if (n == nil)
        return;
    const int key = n->K;
    if (key >= 0)
        adresa1[key] = n;
    else
        adresa2[-key] = n;
    build(n->l);
    build(n->r);
}
// Debug helper (not called by main): print the treap's keys in order, each followed by a space.
void afis(nod *&n){
    if (n != nil) {
        afis(n->l);
        printf("%d ", n->K);
        afis(n->r);
    }
}
void Del(int a, int b){
if(t[b] == a){
int aux = a;
a = b;
b = aux;
}
nod *R, *R1, *R2, *R3, *R4;
det_root(adresa1[a], R);
int p1, p2;
p1 = Get_Pos(adresa1[a]);
p2 = Get_Pos(adresa2[a]);
split(R, R4, R3, p2);
split(R4, R1, R2, p1 - 1);
R2->t = 0;
join(R, R1, R3);
}
int main(){
srand(time(0));
INF = (1 << 30) + 7;
scanf("%d", &N);
for(int i = 1; i < N; i++){
int X, Y;
scanf("%d %d", &X, &Y);
muchii[X].push_back(Y);
muchii[Y].push_back(X);
}
dfs(1);
nil = new nod(0, 0, 0, 0, 0);
R = nil;
for(int i = 1; i <= nr; i++)
Insert (R, i - 1, v[i], Rand());
build(R);
int V = 0;
scanf("%d", &M);
for(int i = 1; i <= M; i++){
int x;
scanf("%d", &x), x ^= V;
nod *R;
det_root(adresa1[x], R);
printf("%d\n", R->nr / 2), scos[x] = 1, V = R->nr / 2;
for(auto it = muchii[x].begin (); it != muchii[x].end (); it ++)
if (!scos[*it]) Del(x, *it);
}
return 0;
}
Tree Splitting Java Solution
import java.io.*;
import java.util.*;
public class Solution {
static long x = 1;
// Xorshift random number generators
static long marsagliaXor32() {
x ^= x << 13;
x ^= x >> 17;
return x ^= x << 5;
}
static class Node {
int size = 1;
long pri = marsagliaXor32();
Node l = null;
Node r = null;
Node p = null;
Node mconcat() {
this.size = size(l) + 1 + size(r);
if (l != null) {
l.p = this;
}
if (r != null) {
r.p = this;
}
return this;
}
}
static int size(Node x) {
return x != null ? x.size : 0;
}
static Node root(Node x) {
while (x.p != null) {
x = x.p;
}
return x;
}
static long orderOf(Node x) {
long r = size(x.l);
while (x.p != null) {
if (x.p.r == x) {
r += size(x.p.l) + 1;
}
x = x.p;
}
return r;
}
static Node join(Node x, Node y) {
if (x == null) return y;
if (y == null) return x;
if (x.pri < y.pri) {
x.r = join(x.r, y);
return x.mconcat();
} else {
y.l = join(x, y.l);
return y.mconcat();
}
}
static long[] dep;
static List<Integer>[] es;
static Node[] pre;
static Node[] post;
static Node tr = null;
static class NodeDfs {
int u;
int p;
boolean start = true;
public NodeDfs(int u, int p) {
this.u = u;
this.p = p;
}
}
static void dfs(int u, int p) {
Deque<NodeDfs> queue = new LinkedList<>();
queue.add(new NodeDfs(u, p));
while (!queue.isEmpty()) {
NodeDfs node = queue.peek();
if (node.start) {
pre[node.u] = new Node();
tr = join(tr, pre[node.u]);
for (int v: es[node.u]) {
if (v != node.p) {
dep[v] = dep[node.u] + 1;
queue.push(new NodeDfs(v, node.u));
}
}
node.start = false;
} else {
post[node.u] = new Node();
tr = join(tr, post[node.u]);
queue.remove();
}
}
}
static Node[] split(Node x, long k, Node l, Node r) {
if (x == null) {
l = r = null;
} else {
long c = size(x.l) + 1;
if (k < c) {
Node[] res = split(x.l, k, l, x.l);
l = res[0];
x.l = res[1];
r = x;
} else {
Node[] res = split(x.r, k - c, x.r, r);
x.r = res[0];
r = res[1];
l = x;
}
x.mconcat();
x.p = null;
}
return new Node[] {l , r};
}
static void cut(int u, int v) {
if (dep[v] < dep[u]) {
int t = v;
v = u;
u = t;
}
long il = orderOf(pre[v]);
long ir = orderOf(post[v])+1;
Node y = root(pre[v]);
Node z = null;
Node[] res = split(y, ir, y, z);
y = res[0];
z = res[1];
Node x = null;
res = split(y, il, x, y);
x = res[0];
join(x, z);
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
dep = new long[n];
es = new List[n];
pre = new Node[n];
post = new Node[n];
for (int i = 0; i < n; i++) {
es[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken())-1;
int v = Integer.parseInt(st.nextToken())-1;
es[u].add(v);
es[v].add(u);
}
dfs(0, -1);
st = new StringTokenizer(br.readLine());
int queriesCount = Integer.parseInt(st.nextToken());
int result = 0;
for (int i = 0; i < queriesCount; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
u = (result ^ u) - 1;
result = size(root(pre[u])) / 2;
bw.write(String.valueOf(result));
if (i != queriesCount - 1) {
bw.write("\n");
for (int v: es[u]) {
cut(u, v);
}
}
}
bw.newLine();
bw.close();
br.close();
}
}