考虑下数据范围:N = 5e10
所以a,b,c,d每个数都一定<= sqrt(N) < 2300
最简单的想法是三层for循环枚举a,b,c,然后计算出d,但是在AcWing上这一题的数据加强后会超时
1 //超时 2 #include <bits/stdc++.h> 3 using namespace std; 4 int main() { 5 int n; 6 cin >> n; 7 for (int a = 0; a * a <= n; a++) { 8 for (int b = a; a * a + b * b <= n; b++) { 9 for (int c = b; a * a + b * b + c * c <= n; c++) { 10 int t = n - a * a - b * b - c * c; 11 int d = sqrt(t); 12 if (d * d == t) { 13 cout << a << " " << b << " " << c << " " << d << endl; 14 return 0; 15 } 16 } 17 } 18 } 19 return 0; 20 }
按照2300的数据范围看时间复杂度,最多只能两层for循环枚举两个数
考虑用空间换时间
二分做法AC代码,时间复杂度n ^ 2 log n
1 #include <bits/stdc++.h> 2 using namespace std; 3 const int N = 2500010; 4 struct Sum { //用结构体存c ^ 2 + d ^ 2 和 c 和 d这三个数 5 int s, c, d; 6 } sum[N]; 7 bool cmp(Sum s1, Sum s2) { 8 if (s1.s != s2.s) { 9 return s1.s < s2.s; 10 } 11 if (s1.c != s2.c) { 12 return s1.c < s2.c; 13 } 14 return s1.d < s2.d; 15 } 16 int main() { 17 int n, m; //m表示所有组合的个数 18 cin >> n; 19 for (int c = 0; c * c <= n; c++) { //存下c和d的组合 20 for (int d = c; c * c + d * d <= n; d++) { 21 Sum t; 22 t.s = c * c + d * d; 23 t.c = c; 24 t.d = d; 25 sum[m++] = t; 26 //sum[m++] = {c * c + d * d, c, d}; //c++11的新特性 27 } 28 } 29 sort (sum, sum + m, cmp); 30 for (int a = 0; a * a <= n; a++) { 31 for (int b = a; a * a + b * b <= n; b++) { 32 int t = n - a * a - b * b; //差值 33 //找到 >= 这个差值的最小的一个数 34 int l = 0, r = m - 1; 35 while (l < r) { 36 int mid = l + r >> 1; 37 if (sum[mid].s >= t) { 38 r = mid; 39 } else { 40 l = mid + 1; 41 } 42 } 43 if (sum[l].s == t) { 44 cout << a << " " << b << " " << sum[l].c << " " << sum[l].d; 45 return 0; 46 } 47 } 48 } 49 return 0; 50 }
然后是unordered_map哈希做法,时间复杂度n ^ 2,结果正确但是超时了,不能AC,不过可以熟悉下unordered_map的用法
1 #include <bits/stdc++.h> 2 using namespace std; 3 typedef pair<int, int> PII; 4 unordered_map<int, PII> S; 5 int main() { 6 int n; 7 cin >> n; 8 for (int c = 0; c * c <= n; c++) { 9 for (int d = c; c * c + d * d <= n; d++) { 10 int t = c * c + d * d; 11 if (S.count(t) == 0) { //如果没有t的话 12 //S[t] = {c, d}; 13 S[t] = make_pair(c, d); 14 } 15 } 16 } 17 for (int a = 0; a * a <= n; a++) { 18 for (int b = a; a * a + b * b <= n; b++) { 19 int t = n - a * a - b * b; 20 if (S.count(t)) { 21 cout << a << " " << b << " " << S[t].first << " " << S[t].second << endl; 22 return 0; 23 } 24 } 25 } 26 return 0; 27 }