Range Sum Query - Mutable

307. Range Sum Query - Mutable

Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.

The update(i, val) function modifies nums by updating the element at index i to val.
Example:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9  
update(1, 2)  
sumRange(0, 2) -> 8  

Note:
The array is only modifiable by the update function.
You may assume that calls to update and sumRange are distributed evenly.

/**
 * One node of the segment tree used by {@code NumArray}.
 *
 * <p>A node covers the inclusive index range {@code [start, end]} and caches the sum of the
 * elements in that range. Leaf nodes ({@code start == end}) keep {@code null} children.
 */
class SegmentTreeNode {
    /** First index (inclusive) covered by this node. */
    public int start;

    /** Last index (inclusive) covered by this node. */
    public int end;

    /** Cached sum of the elements in {@code [start, end]}; 0 until the owner fills it in. */
    public int sum;

    /** Child covering {@code [start, middle]}, or {@code null} for a leaf. */
    public SegmentTreeNode left;

    /** Child covering {@code [middle + 1, end]}, or {@code null} for a leaf. */
    public SegmentTreeNode right;

    /**
     * Creates a node for the inclusive range {@code [start, end]} with a zero sum and no
     * children.
     *
     * @param start first index covered by the node
     * @param end last index covered by the node
     */
    public SegmentTreeNode(int start, int end) {
        this.start = start;
        this.end = end;
        // sum, left and right intentionally keep Java's field defaults (0 and null).
    }
}

/**
 * Mutable range-sum structure backed by a recursive segment tree.
 *
 * <p>The tree is built once from the input array. {@link #update(int, int)} and
 * {@link #sumRange(int, int)} each descend the tree, halving the covered range at every level.
 * Sums are stored as {@code int} and therefore wrap on overflow.
 */
public class NumArray {
    /** Root of the segment tree; stays {@code null} when the input is {@code null} or empty. */
    SegmentTreeNode root;

    /**
     * Builds the segment tree over {@code nums}.
     *
     * @param nums initial values; a {@code null} or empty array yields an empty structure whose
     *     {@code sumRange} always returns 0 and whose {@code update} does nothing
     */
    public NumArray(int[] nums) {
        if (nums == null || nums.length == 0) {
            return;
        }

        root = buildTree(nums, 0, nums.length - 1);
    }

    /**
     * Recursively builds the subtree covering {@code nums[start..end]} (inclusive).
     *
     * @param nums source values
     * @param start first index of the range
     * @param end last index of the range
     * @return the subtree root, or {@code null} when the range is empty ({@code start > end})
     */
    SegmentTreeNode buildTree(int[] nums, int start, int end) {
        if (start > end) {
            return null;
        }

        SegmentTreeNode node = new SegmentTreeNode(start, end);

        if (start == end) {
            // Leaf: holds exactly one element and has no children.
            node.sum = nums[start];
        } else {
            // Overflow-safe midpoint; both halves are non-empty because start < end.
            int middle = start + (end - start) / 2;

            node.left = buildTree(nums, start, middle);
            node.right = buildTree(nums, middle + 1, end);
            node.sum = node.left.sum + node.right.sum;
        }

        return node;
    }

    /**
     * Sets element {@code i} to {@code val} within the subtree rooted at {@code node} and
     * refreshes the cached sums on the path back up.
     *
     * @param node subtree root; {@code null} is ignored
     * @param i index to overwrite; indices outside the node's range are ignored
     * @param val new value
     */
    void update(SegmentTreeNode node, int i, int val) {
        if (node == null) {
            return;
        }

        if (i < node.start || i > node.end) {
            return;
        }

        if (node.start == i && node.end == i) {
            node.sum = val;
            return;
        }

        // Descend into the half that contains i, then recompute this node from its children.
        int middle = node.start + (node.end - node.start) / 2;
        if (i <= middle) {
            update(node.left, i, val);
        } else {
            update(node.right, i, val);
        }
        node.sum = node.left.sum + node.right.sum;
    }

    /**
     * Sets element {@code i} to {@code val}.
     *
     * <p>Declared {@code public} so it is reachable like {@link #sumRange(int, int)}; the two
     * methods together form the class's API. An index outside the array is silently ignored.
     *
     * @param i index to overwrite
     * @param val new value
     */
    public void update(int i, int val) {
        update(root, i, val);
    }

    /**
     * Returns the sum of the elements of {@code node}'s range that also lie in
     * {@code [start, end]}.
     *
     * @param node subtree root; {@code null} contributes 0
     * @param start first index of the query range (inclusive)
     * @param end last index of the query range (inclusive)
     * @return the sum over the intersection, or 0 when it is empty
     */
    int query(SegmentTreeNode node, int start, int end) {
        if (node == null) {
            return 0;
        }

        // Node range lies entirely inside the query: use the cached sum.
        if (start <= node.start && end >= node.end) {
            return node.sum;
        }

        int middle = node.start + (node.end - node.start) / 2;

        if (end <= middle) {
            return query(node.left, start, end);
        } else if (start > middle) {
            return query(node.right, start, end);
        } else {
            // Query straddles the midpoint: split it at middle / middle + 1.
            return query(node.left, start, middle) + query(node.right, middle + 1, end);
        }
    }

    /**
     * Returns {@code nums[i] + ... + nums[j]} (inclusive).
     *
     * <p>Indices outside {@code [0, n - 1]} are effectively clamped to the array bounds, and
     * {@code i > j} yields 0.
     *
     * @param i first index of the range
     * @param j last index of the range
     * @return the range sum
     */
    public int sumRange(int i, int j) {
        return query(root, i, j);
    }
}

// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.update(1, 10);
// numArray.sumRange(1, 2);

Hope this helps,
Michael