问题:给定一个长度为 n (1 < n < 1000000) 的只含有0
和1
的字符串,求其中0
和1
数量相等的子字符串个数(不计空串)?子字符串指的是原串中下标为i
到j
(0 <= i <= j <= n-1)的字符组成的字符串。
解法一(枚举)
我们很容易想到,要使得0
和1
个数相等,这个字符串的长度一定为偶数。只需要枚举原字符串中长度为偶数的子串,然后检验是否满足条件即可。
int count_substring_slow(char *str) {
unsigned int n = strlen(str);
unsigned int ans = 0;
for (int len = 2; len <= n; ++len) {
for (int start = 0; start <= n - len; ++start) {
int ones = 0, zeros = 0;
for (int i = start; i < start + len; ++i) {
if (str[i] == '0') ++zeros;
else if (str[i] == '1') ++ones;
}
if (ones == zeros) ++ans;
}
}
return ans;
}
然而,枚举的时间复杂度过高,达到了\(O(n^3)\)级别,我们要想办法优化。
解法二(前缀和)
解法一时间复杂度高的其中一个原因是在统计子串字符数量时进行了过多重复的运算。举个例子,假如我们需要统计[i, j]的 0 和 1 的数量的时候其实可以利用[i, j – 1]的 0 和 1 的数量来进行增量计算。但是如果我们用 ones[i][j]
和 zeros[i][j]
来存储中间结果,不仅十分消耗空间,甚至可能超出了栈内存空间,导致运行时错误。
我们可以想到,可以直接将 0 和 1 求和,只要一个长度为 n 的字符串求得的结果为 n/2 即可,只需要一个数组 sum[i][j]
就可以了。但是在面对 1000000
这个数据量的时候,sum
数组仍然会超出限制。数据量限制了我们必须用一维数组。这时候,前缀和就派上了用场。
前缀和,顾名思义,就是记录一个数组前 n 项和的数组。要求得任意一个下标区间为 [i, j] (i > 0)的子数组的和,只需要计算 sum[j] - sum[i - 1]
即可。这样,我们可以先扫描一遍数组,计算出前缀和,然后计算长度为偶数的子区间的和是否等于 n / 2 即可。
int count_substring_sum(char *str) {
unsigned int n = strlen(str);
unsigned int ans = 0;
int tmp_sum = 0;
int sum[n]; // VLA
for (int i = 0; i < n; ++i) {
tmp_sum += (str[i] - '0');
sum[i] = tmp_sum;
}
for (int len = 2; len <= n; ++len) {
for (int start = 0; start <= n - len; ++start) {
if (start == 0 && sum[len - 1] == len / 2) {
++ans;
}
else if (sum[start + len - 1] - sum[start - 1] == len / 2) {
++ans;
}
}
}
return ans;
}
这种解法相较于直接枚举,少了一层循环,但是时间复杂度仍然有O(n^2),对于百万级别仍然是不可接受的,需要继续改进。
解法三
为了方便计算,我们可以定义 0 的价值为 -1,1 的价值为 1,只要我们计算得到的子区间的和为 0 即可。然而这仅仅只能少做了除法,没有根本上改变算法的复杂度。那么,我们还能从前缀和发现什么规律呢?
仔细观察发现,如果前缀和 sum[i] == sum[j],那么说明在(i, j]这一段的子数组区间和为 0。这样,我们只需要找到有多少对这样的 i 和 j 即可。如果我们采用扫描两次的方法,仍然效率不高,我们可以直接统计前缀和为 k 的数量 nk,那么我们只需要对所有这样的 k 所组成的组合数 Ck^2 进行求和,然后特别地再加上前缀和为 0 的数量,就是我们所要的答案。
由于 k 有可能小于 0,但最小只会是 -n,所以我们直接将 k 加上 n 进行哈希,就可以将其限制在数组范围内了。
#define score(ch) (ch == '1' ? 1 : -1)
int count_substring_fast(char *str) {
unsigned int n = strlen(str);
unsigned int ans = 0;
int *counter = new int[2 * n + 1](); // initialize
int tmp_sum = 0;
for (int i = 0; i < n; ++i) {
tmp_sum += score(str[i]);
if (tmp_sum == 0) ++ans;
counter[n + tmp_sum]++;
}
for (int i = 0; i < 2 * n + 1; ++i) {
if (counter[i]) ans += ((counter[i]) * (counter[i] - 1) / 2);
}
delete[] counter;
return ans;
}
算法则优化到了O(n)的线性时间复杂度级别。