Tuesday, September 27, 2016

Segment Tree Sum with Lazy Update

/**
 * Range-sum segment tree over an {@code int} array that supports adding a
 * constant to every element of an index range, with lazy propagation.
 *
 * <p>The tree is stored heap-style in {@link #stArr}: the children of slot
 * {@code i} live in slots {@code 2*i+1} and {@code 2*i+2}. All state is static,
 * so only one tree exists at a time; {@link #build(int[])} replaces it.
 *
 * <p>Sums are kept in {@code int}; overflow is not detected.
 */
public class SegmentTreeSum {
    /** Values the current tree was built from (sample data until {@link #build(int[])} is called). */
    private static int[] arr = {1, 9, 2, 3, 5, 8, 90, 1, 2, 88};

    /** Heap-ordered tree nodes; slots that do not correspond to a node stay {@code null}. */
    private static Node[] stArr;

    /**
     * {@code lazy[i]} is a per-element increment owed to node {@code i} and its
     * whole subtree that has not yet been added to {@code stArr[i].val}.
     */
    private static int[] lazy;

    /** One tree node: the inclusive index range it covers and the sum over that range. */
    private static final class Node {
        final int l;
        final int r;
        int val;

        Node(int l, int r) {
            this.l = l;
            this.r = r;
        }
    }

    /**
     * Builds a fresh tree over a copy of {@code values}, discarding any previous
     * tree and all pending updates.
     *
     * @param values the elements to index; may be empty
     * @throws IllegalArgumentException if {@code values} is {@code null}
     * @throws ArithmeticException if the tree would need more slots than an array can hold
     */
    static void build(int[] values) {
        if (values == null) {
            throw new IllegalArgumentException("values must not be null");
        }
        arr = values.clone();

        // Smallest power of two >= length, computed with integers so that no
        // floating-point log/pow rounding can influence the allocation size.
        long leaves = 1;
        while (leaves < arr.length) {
            leaves <<= 1;
        }
        int slots = arr.length == 0 ? 0 : Math.toIntExact(2 * leaves - 1);

        stArr = new Node[slots];
        lazy = new int[slots];
        construct(0, 0, arr.length - 1);
    }

    /**
     * Adds {@code diff} to every element whose index lies in {@code [low, high]}.
     *
     * @param root slot to start from; pass {@code 0} so that pending increments
     *             on the path from the top are pushed down before being used
     * @param low  inclusive lower index of the range
     * @param high inclusive upper index of the range
     * @param diff amount added to each element; may be negative
     * @throws IllegalStateException if no tree has been built yet
     */
    static void update(int root, int low, int high, int diff) {
        Node n = nodeAt(root);
        if (n == null) {
            return;
        }
        // Settle this node first (even if the range misses it) so that the
        // parent can recompute its sum from up-to-date child values.
        applyPending(root, n);

        // no intersection
        if (low > n.r || high < n.l) {
            return;
        }

        int left = 2 * root + 1;
        int right = 2 * root + 2;

        // n is contained between low & high
        if (n.l >= low && n.r <= high) {
            n.val += (n.r - n.l + 1) * diff;
            // Accumulate rather than overwrite: an earlier increment may still be pending.
            defer(left, diff);
            defer(right, diff);
            return;
        }

        // partial overlap: both children exist because n is not a leaf
        update(left, low, high, diff);
        update(right, low, high, diff);
        n.val = valueAt(left) + valueAt(right);
    }

    /**
     * Returns the sum of the elements whose index lies in {@code [low, high]}.
     *
     * @param root slot to start from; pass {@code 0} (see {@link #update})
     * @param low  inclusive lower index of the range
     * @param high inclusive upper index of the range
     * @return the range sum, or {@code 0} when the range touches no element
     * @throws IllegalStateException if no tree has been built yet
     */
    static int query(int root, int low, int high) {
        Node n = nodeAt(root);
        if (n == null) {
            return 0;
        }
        applyPending(root, n);

        // no intersection
        if (low > n.r || high < n.l) {
            return 0;
        }
        // n is contained between low & high
        if (n.l >= low && n.r <= high) {
            return n.val;
        }
        // partial overlap
        return query(2 * root + 1, low, high) + query(2 * root + 2, low, high);
    }

    /** Adds the increment pending on {@code index} to its sum and hands it on to the two children. */
    private static void applyPending(int index, Node n) {
        int pending = lazy[index];
        if (pending == 0) {
            return;
        }
        lazy[index] = 0;
        n.val += pending * (n.r - n.l + 1);
        defer(2 * index + 1, pending);
        defer(2 * index + 2, pending);
    }

    /** Records {@code diff} as owed to the subtree at {@code index}; unused slots are skipped. */
    private static void defer(int index, int diff) {
        if (nodeAt(index) != null) {
            lazy[index] += diff;
        }
    }

    /** Returns the node stored at {@code index}, or {@code null} for an out-of-range or unused slot. */
    private static Node nodeAt(int index) {
        if (stArr == null) {
            throw new IllegalStateException("segment tree has not been built");
        }
        return (index >= 0 && index < stArr.length) ? stArr[index] : null;
    }

    /** Returns the stored sum of the node at {@code index}, or {@code 0} if there is no such node. */
    private static int valueAt(int index) {
        Node n = nodeAt(index);
        return n == null ? 0 : n.val;
    }

    /**
     * Recursively creates the node for {@code arr[low..high]} in slot {@code root}
     * and fills in its sum.
     */
    private static void construct(int root, int low, int high) {
        if (high < low) {
            return;
        }
        if (root >= stArr.length) {
            return;
        }
        Node n = new Node(low, high);
        stArr[root] = n;
        if (low == high) {
            n.val = arr[low];
            return;
        }
        int mid = low + (high - low) / 2;
        construct(2 * root + 1, low, mid);
        construct(2 * root + 2, mid + 1, high);
        n.val = valueAt(2 * root + 1) + valueAt(2 * root + 2);
    }

    /** Demo: prints the sum of indices 0..4 before and after three range increments. */
    public static void main(String[] args) {
        build(arr);
        System.out.println(query(0, 0, 4));
        update(0, 0, 9, 1);
        update(0, 0, 4, 1);
        update(0, 1, 1, 1);
        System.out.println(query(0, 0, 4));
    }
}

/**
 * Plain mutable holder for two {@code int} bounds and an {@code int} value.
 *
 * <p>NOTE(review): SegmentTreeSum in this file works with a node type that has
 * the same three fields and the same two-argument constructor; this class
 * looks like the intended node type for that tree — confirm.
 */
class NodeS {
    /** First constructor argument; presumably the inclusive lower index of the covered range. */
    int l;

    /** Second constructor argument; presumably the inclusive upper index of the covered range. */
    int r;

    /** Value attached to the node; left at the default {@code 0} by the constructor. */
    int val;

    /**
     * Creates a holder with the given bounds.
     *
     * @param lower stored in {@link #l}
     * @param upper stored in {@link #r}
     */
    public NodeS(int lower, int upper) {
        l = lower;
        r = upper;
    }
}

No comments:

Blog Archive