/**
 * Segment tree over {@link #arr} supporting range-sum queries and range-add updates with lazy
 * propagation.
 *
 * <p>The tree is stored implicitly in {@link #stArr}: the children of index {@code i} are
 * {@code 2i+1} and {@code 2i+2}; slots that correspond to no segment stay {@code null}.
 *
 * <p>NOTE(review): {@code Node} is not declared in this class; as used below it must expose
 * mutable {@code int} fields {@code l}, {@code r}, {@code val} and an {@code (int l, int r)}
 * constructor.
 */
public class SegmentTreeSum {

    private static int[] arr = {1, 9, 2, 3, 5, 8, 90, 1, 2, 88};

    /** Tree nodes; {@code stArr[i]} covers the inclusive index range {@code [l, r]} of arr. */
    private static Node[] stArr;

    /**
     * Pending per-element additions. {@code lazy[i]} has ALREADY been applied to
     * {@code stArr[i].val}, but has not yet been pushed to node {@code i}'s descendants.
     */
    private static int[] lazy;

    /**
     * Adds {@code diff} to every element whose index lies in {@code [low, high]}, restricted to
     * the subtree rooted at {@code root}. Fully covered nodes are updated in O(1) and their
     * descendants are deferred via {@link #lazy}.
     *
     * @param root tree index of the subtree root
     * @param low  inclusive left bound of the update range
     * @param high inclusive right bound of the update range
     * @param diff amount added to each element (may be negative)
     */
    private static void update(int root, int low, int high, int diff) {
        if (root >= stArr.length || stArr[root] == null) {
            return;
        }
        Node n = stArr[root];

        // no intersection
        if (low > n.r || high < n.l) {
            return;
        }

        // n is contained between low & high: update n now, defer its descendants
        if (n.l >= low && n.r <= high) {
            applyToNode(root, diff);
            return;
        }

        // partial overlap: make children current before recursing into them
        pushDown(root);
        int left = 2 * root + 1;
        int right = 2 * root + 2;
        update(left, low, high, diff);
        update(right, low, high, diff);
        // children are now exact, so the parent is simply their sum
        n.val = valueAt(left) + valueAt(right);
    }

    /**
     * Adds {@code diff} to every element covered by node {@code idx}: the node's sum grows by
     * {@code diff * length} and the addition is recorded as pending for its descendants.
     * Does nothing for an out-of-range index or an empty slot.
     */
    private static void applyToNode(int idx, int diff) {
        if (idx >= stArr.length || stArr[idx] == null) {
            return;
        }
        Node n = stArr[idx];
        n.val += (n.r - n.l + 1) * diff;
        // accumulate (not overwrite) so earlier pending updates are not lost
        lazy[idx] += diff;
    }

    /** Pushes node {@code idx}'s pending addition (positive or negative) to its two children. */
    private static void pushDown(int idx) {
        int pending = lazy[idx];
        if (pending == 0) {
            return;
        }
        applyToNode(2 * idx + 1, pending);
        applyToNode(2 * idx + 2, pending);
        lazy[idx] = 0;
    }

    /** Returns the sum stored at tree index {@code idx}, or 0 if there is no node there. */
    private static int valueAt(int idx) {
        return (idx < stArr.length && stArr[idx] != null) ? stArr[idx].val : 0;
    }

    /**
     * Returns the sum of the elements whose index lies in {@code [low, high]}, restricted to the
     * subtree rooted at {@code root}.
     *
     * @param root tree index of the subtree root
     * @param low  inclusive left bound of the query range
     * @param high inclusive right bound of the query range
     * @return the range sum, or 0 when the subtree does not intersect the range
     */
    private static int query(int root, int low, int high) {
        if (root >= stArr.length || stArr[root] == null) {
            return 0;
        }
        Node n = stArr[root];

        // no intersection
        if (low > n.r || high < n.l) {
            return 0;
        }

        // n is contained between low & high: its value already includes all pending updates
        if (n.l >= low && n.r <= high) {
            return n.val;
        }

        // partial overlap: children must reflect this node's pending additions first
        pushDown(root);
        return query(2 * root + 1, low, high) + query(2 * root + 2, low, high);
    }

    /**
     * Builds the subtree for {@code arr[low..high]} at tree index {@code root}, splitting at the
     * midpoint and storing each node's sum.
     */
    private static void construct(int root, int low, int high) {
        if (high < low) {
            return;
        }
        if (root >= stArr.length) {
            return;
        }

        Node n = new Node(low, high);
        stArr[root] = n;

        if (low == high) {
            n.val = arr[low];
            return;
        }

        // overflow-safe midpoint
        int mid = low + (high - low) / 2;
        construct(2 * root + 1, low, mid);
        construct(2 * root + 2, mid + 1, high);

        n.val = valueAt(2 * root + 1) + valueAt(2 * root + 2);
    }

    /**
     * Returns the array length needed for a midpoint-split segment tree over {@code n} leaves:
     * {@code 2 * p - 1}, where {@code p} is the smallest power of two {@code >= max(n, 1)}.
     * Integer arithmetic avoids floating-point {@code log} rounding.
     */
    private static int treeSize(int n) {
        int leaves = 1;
        while (leaves < n) {
            leaves <<= 1;
        }
        return 2 * leaves - 1;
    }

    /** Demo: builds the tree, prints sum(arr[0..4]), applies some range adds, and prints it again. */
    public static void main(String[] args) {
        stArr = new Node[treeSize(arr.length)];
        lazy = new int[stArr.length];

        construct(0, 0, arr.length - 1);

        System.out.println(query(0, 0, 4));

        update(0, 0, 9, 1);
        update(0, 0, 4, 1);
        update(0, 1, 1, 1);

        System.out.println(query(0, 0, 4));
    }
}

/**
 * Mutable holder for an inclusive index range {@code [l, r]} plus an {@code int} value
 * ({@code val}, initially 0, assigned by callers). It has the same shape as the {@code Node}
 * type used by {@code SegmentTreeSum}.
 *
 * <p>NOTE(review): nothing in this file references {@code NodeS}; confirm it is still needed.
 */
class NodeS {

    /** Inclusive left bound of the range. */
    int l;

    /** Inclusive right bound of the range. */
    int r;

    /** Value associated with the range; left at the default 0 by the constructor. */
    int val;

    /**
     * Creates a holder for the inclusive range {@code [l, r]}.
     *
     * @param l inclusive left bound
     * @param r inclusive right bound
     */
    public NodeS(int l, int r) {
        this.r = r;
        this.l = l;
    }
}