Goolge Kickstart 2018 round G problem A writeup
by bugnofree
Publish → 2018-11-29 Update → 2019-01-05

小插曲

自己实在是太菜了, 这道题陆陆续续看了好几天, 从 2018-11-05 看到 2018-11-07,到 8 号的时候提交通过, 本来想立即在博客上写一个题解报告, 但是中间有个小插曲.

在此之前, 我写博客都是在自己定义的 HTML 模板上来写文章,直接写 HTML 的一个好处是, 你可以最大程度的定义你想要的页面,但是事情一旦走到了极端, 你将会丧失掉其他的良好特性, 比如说易读性和修改的方便性.

我当然也尝试过 markdown, 但是不得不说, markdown 语法还是略显复杂,而且写出的文本不同的 markdonw 解析器解析出来的效果不同,其很大的一个根源在于 markdown 没有一个通用的准则.此外用 markdown 写出来的东西, 如果直接用文本来看,一点也不简洁, 看得人很眼花缭乱.

为此, 我从 2018-11-12 号开始了一个称之为  pigger  的项目, 用 golang 写的一个简单文本渲染器.这篇文章就是用 pigger 生成的, pigger 目前暂未开源, 当我认为其可以稳定使用了, 将会开源出来.这个小插曲先到此为止.

Problem A. Product Triplets

题目链接为 https://code.google.com/codejam/...,先来看一下题

这个题的意思是说给一组数, 计算这组数中有多少个三元组 (x, y, z) 使得 x * y = z.直白来讲就是说任取三个数, 只要这三个数中, 其中两个数的乘积等于另外一个数,那么这三个数就是满足条件的一个组合.这里就是让计算这样的组合有多少个.

初步分析

首先要仔细看边界条件, 每个数均为正整数, 范围为 [0, 2 * 10^5].因此两个数的最大的乘积就是 4 * 10^10. 因此我们要注意 int 型是否足够,int 占用 4 个字节, 所以其范围为  [-(2^31 - 1 + 1), 2^31 - 1] => [-2147483648, +2147483647] .也就是说在 -20 亿到 +20 亿之间, 而上面的最大乘积为 40 亿, 因此用 int 是要溢出的.

所以这种能直接从题目要求明显推理出来的东西, 不要再往这个坑里掉.

那么问题应该如何解决? 从直观的角度来看, 我们需要从给定的一组数中不重复的选择 3 个数(指的是位置不重复, 不是说数值不重复),然后判断是否满足给定条件. 从 N 个数中选 3 个数一共需要 N * (N - 1) * (N - 2) / 6 次,时间复杂度是 O(n^3). 粗略估计一下, 7000 ^3 = 343,000,000,000, 三千多亿次计算,反馈到实际的电脑上需要多少时间? 我们也粗略的算一下, 假设电脑的主频是 2.9GHz,也就是说 1 秒钟内有 29 亿个时钟周期, 假设每条指令的执行大概需要 4 个时钟周期,那么就是说 1 秒差不多执行 29 / 4 = 7.25 亿条指令. 而且一次计算往往是多条指令,大致可以看汇编后的指令数, 我们假设是 50 条指令, 那么需要 3430 * 50 / 7.25 / 3600 = 6.6 小时.才 7000 个数据, 就需要好几个小时, 不能忍, Google 给了两个测试样例, 小型数据样例跑一会儿是能通过的.但是大的测试样例, 就跑不动了.

核心思想

现在我们要来解决这个时间 bug, 为什么时间会耗费这么多, 任意选择三个数到底哪里不对?我们考虑一次计算中的三个数 (x, y, z), x * y = z 是我们的目标组合,思考一下这样的等式有什么样可能的组合.

如果能成功的得到以上三种情形, 基本上这个问题就解决了 80%.然而接下来依然需要仔细分析边界条件. 以下我使用 mp[i] 表示数值 i 重复出现的次数, 我们来看一看各种神坑.

各种神坑

第一种情况里面,  任意的两个数为 0 , 那么我们只需要从给的一组数中计算出 0 的个数,然后计算三元组数目. 可是这个里面有 bug, 任意两个数为 0, 也就是说第三个数可以为 0 也可以不为 0.如果两个为 0, 另一个为 r ≠0, 这种情况容易计算:  mp[0] * (mp[0] - 1) * mp[r] .如果另一个也为 0, 那么就必须另外计算:  mp[0] * (mp[0] - 1) * (mp[0] - 2) .不能都套用一个计算方式.

第二种情况里面,  一个数为 1, 其余两个数相等且不为 0 , 同样类似的方式,假设另外两个数均为 r ≠ 0, 计算三元组数目,似乎应该是这样子的: mp[1] * mp[r] * (mp[r] - 1) , 但是和第一种情况里面的类似, 这里也是有 bug 的,如果其余两个数的值都是 1, 那么正确的计算式子为  mp[1] * (mp[1] - 1) * (mp[1] - 2) .

当处理完前两种情况, 就只剩下了第三种情况, 如果当前要处理的数为 r,我们使用开平方的方法从 2 开始到 sqrt(r) 为止, 判断 r 是否为质数.如果 r 为质数, 那么对于 r, 在这种情况下, 没有满足的三元组.否则如果 r 为合数, 那么我们检测每一个因数, 假设为 k,检测 [r/k 是否存, 如果存在, 计算三元组数目.在实际编码的过程中, 我们需要处理 sqrt(r), C++ 里面返回的是一个浮点数,我们需要把它转换为整数, 这里就涉及到是向上取整, 还是向下取整.使用开平方判断一个数是否为质数是这样的思想:

如果一个整数 r(≥2) 是合数, 那么必然存在两个正整数p(≥2), q(≥2) 使得 p * q = r,
此时必然有 p ≤sqrt(r), q ≥sqrt(r).

我们考虑等于号时的情形, 此时 p = q = sqrt(r), 对于其他情形,p 和 q 必然分居  浮点数  sqrt(r) 的两侧. 注意这里我强调了浮点数,因此, 浮点数其下的整数必须是要进行检测的, 如果我们使用向上取整就得到了一个右不可达边界,这个右不可达边界是我自己取的一个名字, 类似于 STL 模板里面的  X.end()  方法,表示最右的一个不可访问的边界, 这个东西在 C 陷阱与缺陷里面有提到过类似的概念.因此这里我采用向上取整, 得到 右不可达边界 , 就完美解决了 sqrt(r) 该取那一侧的问题.

最后不要忘了处理 p = q = sqrt(r) 的情况.

坑踩完了吗? 没有!

最后这个坑严格来说不算是算法上的坑, 只能说是编程语言相关的.我说过用 mp[i] 表示数字 i 出现的次数,键最大是 2*10^5, 值是该数出现的次数,最多是 7000, 用 int 肯定能放下的啦, 于是随手写了一个  map<int, int> mp .一个 bug 就嘿嘿嘿的留在了这里面. 在 C/C++ 里面 int * int 始终是 int,而在计算 mp[i] * mp[j] * mp[k] 的时候, 三个 int 相乘, 结果依旧是 int,是有可能溢出的, 所以要换成  map<int, long long> mp .

编码实现

#include <map>
#include <set>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cmath>

using namespace std;

long long solve(vector<int> &num) {
    set<int> s;
    // the mp[i] must at least be long long since int * int = int but not long long
    map<int, long long> mp;
    for(auto it = num.begin(); it != num.end(); ++it) {
        // the time complexity of map.find() is O(logn)
        if(mp.find(*it) == mp.end()) mp[*it] = 0;
        mp[*it] += 1;
        // insert and sort
        s.insert(*it);
    }

    long long cnt = 0;
    for(auto it = s.begin(); it != s.end(); ++it) {
        int x = *it;

        // mp[0] * (mp[0] - 1) / 2 * (num.size() - 2); WRONG!
        //
        // for (0, 0, 0)
        if(x == 0) {
            cnt += mp[0] * (mp[0] - 1) * (mp[0] - 2) / 6;
            continue;
        // for (0, 0, x != 0)
        } else {
            cnt += mp[0] * (mp[0] - 1) / 2 * mp[x];
        }
        // for (1, 1, 1)
        if(x == 1) {
            cnt += mp[1] * (mp[1] - 1) * (mp[1] - 2) / 6;
            continue;
        // for (1, x != 1, x != 1)
        } else {
            cnt += mp[x] * (mp[x] - 1) / 2 * mp[1];
        }

        // Way 1: for (x, y, z)
        /*
         * int i = 2;
         * for(; i * i < x; ++i) {
         *     if(x % i == 0 && mp.count(i) && mp.count(x / i)) {
         *         cnt += mp[x] * mp[i] * mp [x / i];
         *     }
         * }
         * if(x == i * i) cnt += mp[i] * (mp[i] - 1) * mp[x] / 2;
         */

        // Way 2: for (x, y, z)
        // The loop count of 'i * i < k' is different with 'i < floor(sqrt(k))' in some case
        // Lets say k = 10 , the i in the former will has 2, 3, but the latter only has a 2
        int r = ceil(sqrt(x * 1.0)); // sqrt: O(sqrt(m))
        for(int i = 2; i < r; ++i) {
            if(x % i == 0 && mp.find(i) != mp.end() && mp.find(x / i) != mp.end()) {
                cnt += mp[i] * mp[x / i] * mp[x];
            }
        }
        if(r * r == x) {
            cnt += mp[r] * (mp[r] - 1) * mp[x] / 2;
        }
    }
    return cnt;
}

int main(int argc, char *argv[]) {
    // char input[] = "A-small-practice.in";
    char input[] = "A-large-practice.in";
    freopen(input, "r", stdin);

    int casecnt; cin >> casecnt;
    for(int i = 0; i < casecnt; ++i) {
        int numcnt; cin >> numcnt;
        vector<int> num;
        for(int j = 0; j < numcnt; ++j) {
            int ele; cin >> ele; num.push_back(ele);
        }
        long long satcnt = solve(num);
        cout << "Case #" << i + 1 << ": " << satcnt << endl;
    }
    return 0;
}

时间复杂度

假设 n 个数, 其最大值为 m, 那么时间复杂度为  O(n*sqrt(m)) .