Codechef Link

This is a problem from the Codechef SnackDown Elimination Round 2017. During the contest, I read the statement and immediately started coding a plain segment tree, thinking it would work, only to realize about 10 lines in that I couldn't merge the left and right segments. A few days ago, I decided to give the problem another try.

My first observation is that, for each l, if we can find the maximum r such that the pair (l, r) is permissible (i.e. a[l] <= a[l] ^ a[l + 1] <= ... <= a[l] ^ a[l + 1] ^ ... ^ a[r]), then (l, k) is permissible iff l <= k <= r: any prefix of a non-decreasing chain is non-decreasing, and any longer chain contains the first decreasing step. Let maxDst[l] be this maximum r. Then the answer for the segment (l, r) is sum { min(maxDst[k], r) - k + 1 | l <= k <= r }. This query can be answered in O(log^2 n) on a segment tree with a merge-split treap in each node: in each of the O(log n) nodes covering the segment, split the treap at maxDst[k] = r, then add up sum { maxDst[k] - k + 1 } for the left side and sum { r - k + 1 } for the right side. However, since no updates are needed, we can just use a merge-sort tree, keeping each node's maxDst values in sorted order together with prefix sums of maxDst[k] - k + 1 and suffix sums of -k + 1. (The latter is summed with r * [NUMBER OF ELEMENTS] during the query.)
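To sanity-check the observation and the counting formula, here is a small brute-force sketch. The names bruteMaxDst, countDirect, countByFormula and MsNode are hypothetical, arrays are 1-based with a[0] unused, and MsNode models a single merge-sort-tree node built by sorting instead of merging children, so it only illustrates the per-node prefix/suffix-sum trick, not the full tree.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helpers; a is 1-based (a[0] unused).

// maxDst[l]: largest r such that a[l], a[l]^a[l+1], ..., a[l]^...^a[r]
// is non-decreasing (brute force, O(n^2) in the worst case).
std::vector<int> bruteMaxDst(const std::vector<int>& a) {
    int n = (int)a.size() - 1;
    std::vector<int> maxDst(n + 1, 0);
    for (int l = 1; l <= n; l++) {
        int cur = a[l], r = l;
        while (r < n && cur <= (cur ^ a[r + 1])) {
            cur ^= a[r + 1];
            r++;
        }
        maxDst[l] = r;
    }
    return maxDst;
}

// Counts permissible pairs (i, j) with l <= i <= j <= r directly.
int64_t countDirect(const std::vector<int>& a, int l, int r) {
    int64_t total = 0;
    for (int i = l; i <= r; i++) {
        int cur = a[i];
        total++;  // (i, i) is always permissible
        for (int j = i + 1; j <= r; j++) {
            int next = cur ^ a[j];
            if (next < cur) break;  // (i, j) fails, and so does every longer pair
            cur = next;
            total++;
        }
    }
    return total;
}

// The same count via sum { min(maxDst[k], r) - k + 1 | l <= k <= r }.
int64_t countByFormula(const std::vector<int>& maxDst, int l, int r) {
    int64_t total = 0;
    for (int k = l; k <= r; k++) total += std::min(maxDst[k], r) - k + 1;
    return total;
}

// One merge-sort-tree node covering indices [lo, hi]: maxDst values sorted,
// prefix sums of maxDst[k] - k + 1 and suffix sums of k in that sorted order.
struct MsNode {
    std::vector<int> keys;
    std::vector<int64_t> pre, suf;
    MsNode(const std::vector<int>& maxDst, int lo, int hi) {
        std::vector<std::pair<int, int>> items;  // (maxDst[k], k)
        for (int k = lo; k <= hi; k++) items.push_back({maxDst[k], k});
        std::sort(items.begin(), items.end());
        int cnt = (int)items.size();
        pre.assign(cnt + 1, 0);
        suf.assign(cnt + 1, 0);
        for (int i = 0; i < cnt; i++) {
            keys.push_back(items[i].first);
            pre[i + 1] = pre[i] + items[i].first - items[i].second + 1;
        }
        for (int i = cnt - 1; i >= 0; i--) suf[i] = suf[i + 1] + items[i].second;
    }
    // This node's share of countByFormula for a query ending at r (assumes hi <= r).
    int64_t query(int r) const {
        int idx = (int)(std::lower_bound(keys.begin(), keys.end(), r) - keys.begin());
        // entries before idx have maxDst < r; the rest contribute r - k + 1
        return pre[idx] + (int64_t)(keys.size() - idx) * (r + 1) - suf[idx];
    }
};
```

For example, with a = {_, 1, 2, 3} the sketch gives maxDst = {2, 2, 3}, and countDirect, countByFormula and MsNode(maxDst, 1, 3).query(3) all return 4 for the segment (1, 3).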

The problem then reduces to computing the maxDst array. When l == 1, we have to track the relations between adjacent elements of a[1], a[1] ^ a[2], ..., a[1] ^ a[2] ^ ... ^ a[n], and when l == 2, the sequence becomes a[2], a[2] ^ a[3], .... Note that we can get the latter from the former by dropping the first element and xoring all the remaining elements with a[1]. In other words, we have to track the relations between adjacent elements while supporting an operation that xors every element with a given integer.
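A quick way to convince yourself of the drop-and-xor step is to compare two views of the same sequence. In this sketch (hypothetical helper names, 1-based a with a[0] unused), the sequence for l is built directly and also read off the prefix xors, where every element is pre[k] ^ pre[l - 1]; only that single mask changes as l advances.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (a is 1-based, a[0] unused): the sequence
// a[l], a[l]^a[l+1], ..., a[l]^...^a[n], built directly.
std::vector<int> sequenceFrom(const std::vector<int>& a, int l) {
    std::vector<int> seq;
    int cur = 0;
    for (int i = l; i < (int)a.size(); i++) {
        cur ^= a[i];
        seq.push_back(cur);
    }
    return seq;
}

// The same sequence read off the prefix xors pre[k] = a[1]^...^a[k]:
// its elements are pre[k] ^ pre[l - 1] for k >= l, so moving from l to l + 1
// drops one element and changes a single global xor mask by a[l].
std::vector<int> sequenceFromPrefix(const std::vector<int>& pre, int l) {
    std::vector<int> seq;
    for (int k = l; k < (int)pre.size(); k++) seq.push_back(pre[k] ^ pre[l - 1]);
    return seq;
}
```

For a = {_, 5, 3, 6, 1}, sequenceFrom(a, 1) is {5, 6, 0, 1}; dropping the 5 and xoring the rest with a[1] = 5 gives {3, 5, 4}, which is sequenceFrom(a, 2).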

Looking into how xor works, we can observe that, for x != y, xoring both with k reverses their relation (x <= y becomes x ^ k > y ^ k, and vice versa) iff k has a 1 at the most significant bit in which x and y differ; if x == y, the relation never changes. Thus, we can categorize the relations into 32 sets (one for equal pairs, plus one for each of the 31 bits a non-negative int can have set) and turn this into a model where we have a sequence of 0s and 1s, standing for <= and > respectively. For each xor operation, we flip every set whose corresponding bit is 1 in the xored integer, which keeps the relations up to date. Then, if we can query the sum of a segment, we can tell whether the sequence is non-decreasing by checking whether the sum is zero, and thus compute maxDst via binary search.
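The flip rule can be checked exhaustively on small values. In this sketch, relationSet and flipsRelation are hypothetical names; relationSet uses the same numbering as the sets above (0 for equal pairs, b + 1 when the most significant differing bit is b).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: the set a relation between x and y belongs to.
// 0 if x == y, otherwise 1 + index of the most significant differing bit.
int relationSet(uint32_t x, uint32_t y) {
    uint32_t d = x ^ y;
    if (d == 0) return 0;
    int bit = 31;
    while (((d >> bit) & 1u) == 0) bit--;
    return bit + 1;
}

// True iff xoring both x and y with k changes the outcome of x <= y,
// i.e. x != y and k has a 1 at their most significant differing bit.
bool flipsRelation(uint32_t x, uint32_t y, uint32_t k) {
    int s = relationSet(x, y);
    return s != 0 && ((k >> (s - 1)) & 1u) != 0;
}
```

For example, 5 (101) and 6 (110) first differ at bit 1, so they fall into set 2: xoring with 2 gives 7 > 4 and flips the relation, while xoring with 4 gives 1 < 2 and leaves it intact. Set 32 (bit 31) can only arise for values with the sign bit set, which non-negative ints never have.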

To support these operations, the first thing that came to mind was segment trees. However, since the flipped positions are not contiguous, we have to use one segment tree per set, querying and modifying each of them separately. Unfortunately, this approach is too slow and gets TLE. On further inspection, the only modifications are whole-set flips, which only touch the root of each tree; none of the other nodes are ever modified. As such, we can switch to prefix sums and maintain a boolean array recording whether each set is flipped. If a set is flipped, the query returns N - ans, where ans is the original count of > relations and N is the total number of the set's elements in the segment, which can be obtained from another prefix sum array.
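The prefix-sum replacement for the per-set segment trees can be sketched as a small structure. FlippableSets is a hypothetical name; it takes, for each relation i (1-based), its set and whether it is currently >, and counts the > relations in l..r after any number of O(1) whole-set flips. Each query walks over all sets once, which is where the log C factor comes from.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch: each set keeps two prefix arrays (count of ">" and
// count of members) plus a flag toggled in O(1) when the whole set is xored.
struct FlippableSets {
    int numSets;
    std::vector<std::vector<int>> greater, members;  // [set][i], prefix sums
    std::vector<bool> flipped;

    // sets[i] and isGreater[i] describe relation i for 1 <= i <= n (index 0 unused).
    FlippableSets(int numSets_, const std::vector<int>& sets, const std::vector<int>& isGreater)
        : numSets(numSets_), flipped(numSets_, false) {
        int n = (int)sets.size() - 1;
        greater.assign(numSets, std::vector<int>(n + 1, 0));
        members.assign(numSets, std::vector<int>(n + 1, 0));
        for (int s = 0; s < numSets; s++) {
            for (int i = 1; i <= n; i++) {
                greater[s][i] = greater[s][i - 1] + (sets[i] == s && isGreater[i]);
                members[s][i] = members[s][i - 1] + (sets[i] == s);
            }
        }
    }

    // Xor-induced flip of an entire set: no per-element work.
    void flip(int s) { flipped[s] = !flipped[s]; }

    // Number of ">" relations among relations l..r, taking flips into account.
    int countGreater(int l, int r) const {
        int total = 0;
        for (int s = 0; s < numSets; s++) {
            int g = greater[s][r] - greater[s][l - 1];
            int m = members[s][r] - members[s][l - 1];
            total += flipped[s] ? m - g : g;  // flipped: N - ans
        }
        return total;
    }
};
```

In the actual problem, set 0 (equal neighbours) is never flipped, and a segment is non-decreasing exactly when countGreater returns 0.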

To recap, the procedure is as follows:

  • Create the array of prefix xors a[1], a[1] ^ a[2], ..., a[1] ^ a[2] ^ ... ^ a[n] and record the relation between every adjacent pair of its elements. (O(n))
  • Place each relation into one of the 32 sets, according to the most significant bit in which the two adjacent prefix xors differ, and build two prefix sums per set (one counting the > relations, the other counting all members). (O(n))
  • Scan from left to right. For each l, binary search the minimum r such that relations l..r have a nonzero sum (querying the 32 sets separately); maxDst[l] is then this r, or n if no such r exists. Then, apply the xor with a[l], i.e. flip the sets whose corresponding bit in a[l] is 1. (O(n * log n * log C))
  • Build a merge-sort tree over maxDst. In each node, keep the maxDst values in sorted order, along with prefix sums of maxDst[k] - k + 1 and suffix sums of -k + 1 over that order (the code stores suffix sums of k and subtracts them, which is equivalent). (O(n * log n))
  • For each query (l, r), decompose the segment into nodes of the merge-sort tree. In each node, binary search the sorted maxDst values for r, then add the prefix sum of the elements with maxDst < r, plus the suffix sum of the remaining elements together with r * [NUMBER OF ELEMENTS ON THE RIGHT]. (O(q * log^2 n), since each query visits O(log n) nodes and performs one O(log n) binary search in each.)

Total time complexity: O(n * log n * log C + q * log^2 n)
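Steps 1 to 3 of the recap can be condensed into one self-contained function for experimenting. This is a sketch, not the full solution below: maxDstByFlips is a hypothetical name, it keeps the prefix counts in vectors instead of a global array, and it assumes 1-based input (a[0] unused) with non-negative values below 2^31.

```cpp
#include <cassert>
#include <vector>

// Set 0 holds equal neighbours; set b + 1 holds neighbours whose most
// significant differing bit is b, so 32 sets cover non-negative ints.
std::vector<int> maxDstByFlips(const std::vector<int>& a) {
    const int SETS = 32;
    int n = (int)a.size() - 1;
    std::vector<int> pre(n + 1, 0);  // pre[i] = a[1] ^ ... ^ a[i]
    for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] ^ a[i];

    // greater[s][i] / members[s][i]: prefix counts over relations 1..i,
    // where relation i compares pre[i] with pre[i + 1].
    std::vector<std::vector<int>> greater(SETS, std::vector<int>(n + 1, 0));
    std::vector<std::vector<int>> members(SETS, std::vector<int>(n + 1, 0));
    for (int i = 1; i <= n; i++) {
        for (int s = 0; s < SETS; s++) {
            greater[s][i] = greater[s][i - 1];
            members[s][i] = members[s][i - 1];
        }
        if (i == n) continue;  // relation n does not exist
        int d = pre[i] ^ pre[i + 1], s = 0;
        for (int b = 30; b >= 0; b--) {
            if ((d >> b) & 1) { s = b + 1; break; }
        }
        members[s][i]++;
        if (pre[i] > pre[i + 1]) greater[s][i]++;
    }

    std::vector<bool> flipped(SETS, false);
    // True iff relations l..r are all "<=" after the flips applied so far.
    auto allNonDecreasing = [&](int l, int r) {
        for (int s = 0; s < SETS; s++) {
            int g = greater[s][r] - greater[s][l - 1];
            int m = members[s][r] - members[s][l - 1];
            if ((flipped[s] ? m - g : g) != 0) return false;
        }
        return true;
    };

    std::vector<int> maxDst(n + 1, 0);
    for (int l = 1; l <= n; l++) {
        // Smallest r in [l, n] such that relations l..r contain a ">" (n if none).
        int lo = l, hi = n;
        while (lo < hi) {
            int mid = (lo + hi) / 2;
            if (allNonDecreasing(l, mid)) lo = mid + 1;
            else hi = mid;
        }
        maxDst[l] = lo;
        // Xor every element with a[l] before moving on to l + 1.
        for (int b = 0; b < 31; b++) {
            if ((a[l] >> b) & 1) flipped[b + 1] = !flipped[b + 1];
        }
    }
    return maxDst;
}
```

For a = {_, 1, 2, 3} this returns maxDst = {2, 2, 3} for l = 1..3, and on small random arrays it can be compared against a plain O(n^2) scan.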

Code is as follows:

#include <bits/stdc++.h>
using namespace std;

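const int MAXN = 4E5 + 10;

// sum[s][i][0]: prefix count of ">" relations of set s among relations 1..i
// sum[s][i][1]: prefix count of all relations of set s among relations 1..i
// Relation i compares xorred[i] with xorred[i + 1]. Set 0 holds equal pairs;
// set b + 1 holds pairs whose most significant differing bit is b.
int sum[32][MAXN][2], arr[MAXN], xorred[MAXN], maxDst[MAXN], n, queryR;
bool isRev[32] = { false };  // isRev[s]: set s is currently flipped
vector<int> typeToIdx[32];   // relation indices per set (filled but never read)
// Merge-sort tree nodes: pairs (maxDst[k], maxDst[k] - k + 1) sorted by maxDst,
// prefix sums of the second component and suffix sums of k.
vector<int64_t> msPreSum[MAXN * 4], msSufSum[MAXN * 4];
vector<pair<int, int> > msTree[MAXN * 4];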

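// Returns true iff relations l..r are all "<=" under the current flips.
bool queryMulti(int l, int r) {
	for (int i = 0; i < 32; i++) {
		int t = sum[i][r][0] - sum[i][l - 1][0];  // ">" count before flipping
		if (isRev[i]) {
			// flipped set: it is clean only if every member was ">" originally
			if (sum[i][r][1] - sum[i][l - 1][1] != t) return false;
		} else {
			if (t != 0) return false;
		}
	}
	return true;
}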

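// Xors every element with x by flipping each set whose bit is 1 in x.
// x is a non-negative int, so only bits 0..30 can be set (keeps isRev in bounds).
void modifyMulti(int x) {
	for (int i = 0; i < 31; i++) {
		if ((x >> i) & 1) isRev[i + 1] = !isRev[i + 1];
	}
}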

// Builds the merge-sort tree over maxDst[l..r] rooted at node id.
void buildMs(int id, int l, int r) {
	if (l == r) {
		msTree[id].push_back({ maxDst[l], maxDst[l] - l + 1 });
		msPreSum[id] = { 0, maxDst[l] - l + 1 };
		msSufSum[id] = { l, 0 };
	} else {
		int m = (l + r) >> 1;
		buildMs(id << 1, l, m); buildMs(id << 1 | 1, m + 1, r);
		msTree[id].resize(msTree[id << 1].size() + msTree[id << 1 | 1].size());
		merge(msTree[id << 1].begin(), msTree[id << 1].end(),
			msTree[id << 1 | 1].begin(), msTree[id << 1 | 1].end(),
			msTree[id].begin());
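		// msPreSum[id][i]: sum of maxDst[k] - k + 1 over the first i sorted entries
		msPreSum[id].resize(msTree[id].size() + 1);
		for (int i = 1; i <= (int)msTree[id].size(); i++) {
			msPreSum[id][i] = msPreSum[id][i - 1] + msTree[id][i - 1].second;
		}
		// msSufSum[id][i]: sum of k over sorted entries i..end
		// (first - second + 1 recovers k from the stored pair)
		msSufSum[id].resize(msTree[id].size() + 1);
		for (int i = msTree[id].size(); i >= 1; i--) {
			msSufSum[id][i - 1] = msSufSum[id][i] +
				(msTree[id][i - 1].first - msTree[id][i - 1].second + 1);
		}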
	}
}

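// Returns sum { min(maxDst[k], queryR) - k + 1 | qL <= k <= qR }.
int64_t queryMs(int id, int l, int r, int qL, int qR) {
	if (qL <= l && r <= qR) {
		// idx: first entry with maxDst >= queryR (second components are >= 1)
		int idx = lower_bound(msTree[id].begin(), msTree[id].end(),
			make_pair(queryR, 0)) - msTree[id].begin();
		// entries before idx add maxDst[k] - k + 1; the rest add queryR - k + 1
		return msPreSum[id][idx] +
			(int64_t)(msTree[id].size() - idx) * (queryR + 1) - msSufSum[id][idx];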
	} else {
		int m = (l + r) >> 1;
		int64_t t = 0;
		if (qL <= m) t += queryMs(id << 1, l, m, qL, min(m, qR));
		if (m < qR) t += queryMs(id << 1 | 1, m + 1, r, max(qL, m + 1), qR);
		return t;
	}
}

int32_t main() {
#ifdef OJ_DEBUG
	freopen("out2", "r", stdin);
#endif
	std::cin.tie(0);
	std::ios_base::sync_with_stdio(0);
	cin >> n;
	for (int i = 1; i <= n; i++) cin >> arr[i];
	for (int i = 1; i <= n; i++) xorred[i] = xorred[i - 1] ^ arr[i];
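	// Classify relation i (xorred[i] vs xorred[i + 1]) by the most significant bit
	// in which they differ; for non-negative ints, bits 0..30 suffice.
	for (int i = 1; i < n; i++) {
		int j;
		for (j = 30; j >= 0; j--) {
			if (((xorred[i] ^ xorred[i + 1]) >> j) & 1) {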
				typeToIdx[j + 1].push_back(i);
				if (xorred[i] > xorred[i + 1]) sum[j + 1][i][0]++;
				sum[j + 1][i][1]++;
				break;
			}
		}
		if (j < 0) {  // equal neighbours: set 0, never flipped
			typeToIdx[0].push_back(i);
			sum[0][i][1]++;
		}
	}
	// Turn the per-relation counts into prefix sums.
	for (int i = 0; i < 32; i++) {
		for (int j = 1; j <= n; j++) {
			sum[i][j][0] += sum[i][j - 1][0];
			sum[i][j][1] += sum[i][j - 1][1];
		}
	}
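	// For each l = i, binary search the first relation in i..n that is ">" under the
	// current flips (n if none), then xor every element with arr[i] for the next l.
	for (int i = 1; i < n; i++) {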
		int l = i, r = n;
		while (l != r) {
			int m = (l + r) >> 1;
			if (queryMulti(i, m)) l = m + 1;
			else r = m;
		}
		maxDst[i] = l;
		modifyMulti(arr[i]);
	}
	maxDst[n] = n;
	buildMs(1, 1, n);
	int q;
	int64_t lastAns = 0;
	cin >> q;
	while (q--) {
		// Queries are encoded with the previous answer (lastAns).
		int l, r; cin >> l >> r;
		l = (l + lastAns) % n + 1;
		queryR = r = (r + lastAns) % n + 1;
		lastAns = queryMs(1, 1, n, l, r);
		cout << lastAns << '\n';
	}
}