File Coverage

aks.c
Criterion Covered Total %
statement 55 164 33.5
branch 34 152 22.3
condition n/a
subroutine n/a
pod n/a
total 89 316 28.1


line stmt bran cond sub pod time code
1             #include
2             #include
3             #include
4             #include
5             #include
6              
7             /* The AKS primality algorithm for native integers.
8             *
9             * There are three versions here:
10             * V6 The v6 algorithm from the latest AKS paper.
11             * BORNEMANN Improvements from Bernstein, Voloch, and a clever r/s
12             * selection from Folkmar Bornemann. Similar to Bornemann's
13             * 2003 Pari/GP implementation
14             * BERN41 My implementation of theorem 4.1 from Bernstein's 2003 paper.
15             *
16             * Each one is orders of magnitude faster than the previous, and by default
17             * we use Bernstein 4.1 as it is by far the fastest.
18             *
19             * Note that AKS is very, very slow compared to other methods. It is, however,
20             * polynomial in log(N), and log-log performance graphs show nice straight
21             * lines for both implementations. However APR-CL and ECPP both start out
22             * much faster and the slope will be less for any sizes of N that we're
23             * interested in.
24             *
25             * For native 64-bit integers this is purely a coding exercise, as BPSW is
26             * a million times faster and gives proven results.
27             *
28             *
29             * When n < 2^(wordbits/2)-1, we can do a straightforward intermediate:
30             * r = (r + a * b) % n
31             * If n is larger, then these are replaced with:
32             * r = addmod( r, mulmod(a, b, n), n)
33             * which is a lot more work, but keeps us correct.
34             *
35             * Software that does polynomial convolutions followed by a modulo can be
36             * very fast, but will fail when n >= (2^wordbits)/r.
37             *
38             * This is all much easier in GMP.
39             *
40             * Copyright 2012-2016, Dana Jacobsen.
41             */
42              
43             #define SQRTN_SHORTCUT 1
44              
45             #define IMPL_V6 0 /* From the primality_v6 paper */
46             #define IMPL_BORNEMANN 0 /* From Bornemann's 2002 implementation */
47             #define IMPL_BERN41 1 /* From Bernstein's early 2003 paper */
48              
49             #include "ptypes.h"
50             #include "aks.h"
51             #define FUNC_isqrt 1
52             #define FUNC_gcd_ui 1
53             #include "util.h"
54             #include "cache.h"
55             #include "mulmod.h"
56             #include "factor.h"
57              
58             #if IMPL_BORNEMANN || IMPL_BERN41
59             /* We could use lgamma, but it isn't in MSVC and not in pre-C99. The only
60             * sure way to find if it is available is test compilation (ala autoconf).
61             * Instead, we'll just use our own implementation.
62             * See http://mrob.com/pub/ries/lanczos-gamma.html for alternates. */
63             static double lanczos_coef[8+1] =
64             { 0.99999999999980993, 676.5203681218851, -1259.1392167224028,
65             771.32342877765313, -176.61502916214059, 12.507343278686905,
66             -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7 };
67             static double log_sqrt_two_pi = 0.91893853320467274178;
68 552           static double log_gamma(double x)
69             {
70 552           double base = x + 7 + 0.5;
71 552           double sum = 0;
72             int i;
73 4968 100         for (i = 8; i >= 1; i--)
74 4416           sum += lanczos_coef[i] / (x + (double)i);
75 552           sum += lanczos_coef[0];
76 552           sum = log_sqrt_two_pi + log(sum/x) + ( (x+0.5)*log(base) - base );
77 552           return sum;
78             }
79             #undef lgamma
80             #define lgamma(x) log_gamma(x)
81             #endif
82              
83             #if IMPL_BERN41
84 184           static double log_binomial(UV n, UV k)
85             {
86 184           return log_gamma(n+1) - log_gamma(k+1) - log_gamma(n-k+1);
87             }
88 46           static double log_bern41_binomial(UV r, UV d, UV i, UV j, UV s)
89             {
90 92           return log_binomial( 2*s, i)
91 46           + log_binomial( d, i)
92 46           + log_binomial( 2*s-i, j)
93 46           + log_binomial( r-2-d, j);
94             }
95 46           static int bern41_acceptable(UV n, UV r, UV s)
96             {
97 46           double scmp = ceil(sqrt( (r-1)/3.0 )) * log(n);
98 46           UV d = (UV) (0.5 * (r-1));
99 46           UV i = (UV) (0.475 * (r-1));
100 46           UV j = i;
101 46 50         if (d > r-2) d = r-2;
102 46 50         if (i > d) i = d;
103 46 50         if (j > (r-2-d)) j = r-2-d;
104 46           return (log_bern41_binomial(r,d,i,j,s) >= scmp);
105             }
106             #endif
107              
108             #if 0
109             /* Naive znorder. Works well if limit is small. Note arguments. */
110             static UV order(UV r, UV n, UV limit) {
111             UV j;
112             UV t = 1;
113             for (j = 1; j <= limit; j++) {
114             t = mulmod(t, n, r);
115             if (t == 1)
116             break;
117             }
118             return j;
119             }
120             static void poly_print(UV* poly, UV r)
121             {
122             int i;
123             for (i = r-1; i >= 1; i--) {
124             if (poly[i] != 0)
125             printf("%lux^%d + ", poly[i], i);
126             }
127             if (poly[0] != 0) printf("%lu", poly[0]);
128             printf("\n");
129             }
130             #endif
131              
132 0           static void poly_mod_mul(UV* px, UV* py, UV* res, UV r, UV mod)
133             {
134             UV degpx, degpy;
135             UV i, j, pxi, pyj, rindex;
136              
137             /* Determine max degree of px and py */
138 0 0         for (degpx = r-1; degpx > 0 && !px[degpx]; degpx--) ; /* */
    0          
139 0 0         for (degpy = r-1; degpy > 0 && !py[degpy]; degpy--) ; /* */
    0          
140             /* We can sum at least j values at once */
141 0 0         j = (mod >= HALF_WORD) ? 0 : (UV_MAX / ((mod-1)*(mod-1)));
142              
143 0 0         if (j >= degpx || j >= degpy) {
    0          
144             /* res will be written completely, so no need to set */
145 0 0         for (rindex = 0; rindex < r; rindex++) {
146 0           UV sum = 0;
147 0           j = rindex;
148 0 0         for (i = 0; i <= degpx; i++) {
149 0 0         if (j <= degpy)
150 0           sum += px[i] * py[j];
151 0 0         j = (j == 0) ? r-1 : j-1;
152             }
153 0           res[rindex] = sum % mod;
154             }
155             } else {
156 0           memset(res, 0, r * sizeof(UV)); /* Zero result accumulator */
157 0 0         for (i = 0; i <= degpx; i++) {
158 0           pxi = px[i];
159 0 0         if (pxi == 0) continue;
160 0 0         if (mod < HALF_WORD) {
161 0 0         for (j = 0; j <= degpy; j++) {
162 0           pyj = py[j];
163 0 0         rindex = i+j; if (rindex >= r) rindex -= r;
164 0           res[rindex] = (res[rindex] + (pxi*pyj) ) % mod;
165             }
166             } else {
167 0 0         for (j = 0; j <= degpy; j++) {
168 0           pyj = py[j];
169 0 0         rindex = i+j; if (rindex >= r) rindex -= r;
170 0           res[rindex] = muladdmod(pxi, pyj, res[rindex], mod);
171             }
172             }
173             }
174             }
175 0           memcpy(px, res, r * sizeof(UV)); /* put result in px */
176 0           }
177 0           static void poly_mod_sqr(UV* px, UV* res, UV r, UV mod)
178             {
179             UV c, d, s, sum, rindex, maxpx;
180 0           UV degree = r-1;
181 0           int native_sqr = (mod > isqrt(UV_MAX/(2*r))) ? 0 : 1;
182              
183 0           memset(res, 0, r * sizeof(UV)); /* zero out sums */
184             /* Discover index of last non-zero value in px */
185 0 0         for (s = degree; s > 0; s--)
186 0 0         if (px[s] != 0)
187 0           break;
188 0           maxpx = s;
189             /* 1D convolution */
190 0 0         for (d = 0; d <= 2*degree; d++) {
191             UV *pp1, *pp2, *ppend;
192 0 0         UV s_beg = (d <= degree) ? 0 : d-degree;
193 0           UV s_end = ((d/2) <= maxpx) ? d/2 : maxpx;
194 0 0         if (s_end < s_beg) continue;
195 0           sum = 0;
196 0           pp1 = px + s_beg;
197 0           pp2 = px + d - s_beg;
198 0           ppend = px + s_end;
199 0 0         if (native_sqr) {
200 0 0         while (pp1 < ppend)
201 0           sum += 2 * *pp1++ * *pp2--;
202             /* Special treatment for last point */
203 0           c = px[s_end];
204 0 0         sum += (s_end*2 == d) ? c*c : 2*c*px[d-s_end];
205 0 0         rindex = (d < r) ? d : d-r; /* d % r */
206 0           res[rindex] = (res[rindex] + sum) % mod;
207             #if HAVE_UINT128
208             } else {
209 0           uint128_t max = ((uint128_t)1 << 127) - 1;
210 0           uint128_t c128, sum128 = 0;
211              
212 0 0         while (pp1 < ppend) {
213 0           c128 = ((uint128_t)*pp1++) * ((uint128_t)*pp2--);
214 0 0         if (c128 > max) c128 %= mod;
215 0           c128 <<= 1;
216 0 0         if (c128 > max) c128 %= mod;
217 0           sum128 += c128;
218 0 0         if (sum128 > max) sum128 %= mod;
219             }
220 0           c128 = px[s_end];
221 0 0         if (s_end*2 == d) {
222 0           c128 *= c128;
223             } else {
224 0           c128 *= px[d-s_end];
225 0 0         if (c128 > max) c128 %= mod;
226 0           c128 <<= 1;
227             }
228 0 0         if (c128 > max) c128 %= mod;
229 0           sum128 += c128;
230 0 0         if (sum128 > max) sum128 %= mod;
231 0 0         rindex = (d < r) ? d : d-r; /* d % r */
232 0           res[rindex] = ((uint128_t)res[rindex] + sum128) % mod;
233             #else
234             } else {
235             while (pp1 < ppend) {
236             UV p1 = *pp1++;
237             UV p2 = *pp2--;
238             sum = addmod(sum, mulmod(2, mulmod(p1, p2, mod), mod), mod);
239             }
240             c = px[s_end];
241             if (s_end*2 == d)
242             sum = addmod(sum, sqrmod(c, mod), mod);
243             else
244             sum = addmod(sum, mulmod(2, mulmod(c, px[d-s_end], mod), mod), mod);
245             rindex = (d < r) ? d : d-r; /* d % r */
246             res[rindex] = addmod(res[rindex], sum, mod);
247             #endif
248             }
249             }
250 0           memcpy(px, res, r * sizeof(UV)); /* put result in px */
251 0           }
252              
253 0           static UV* poly_mod_pow(UV* pn, UV power, UV r, UV mod)
254             {
255             UV *res, *temp;
256              
257 0 0         Newz(0, res, r, UV);
258 0 0         New(0, temp, r, UV);
259 0           res[0] = 1;
260              
261 0 0         while (power) {
262 0 0         if (power & 1) poly_mod_mul(res, pn, temp, r, mod);
263 0           power >>= 1;
264 0 0         if (power) poly_mod_sqr(pn, temp, r, mod);
265             }
266 0           Safefree(temp);
267 0           return res;
268             }
269              
270 0           static int test_anr(UV a, UV n, UV r)
271             {
272             UV* pn;
273             UV* res;
274             UV i;
275 0           int retval = 1;
276              
277 0 0         Newz(0, pn, r, UV);
278 0           a %= r;
279 0           pn[0] = a;
280 0           pn[1] = 1;
281 0           res = poly_mod_pow(pn, n, r, n);
282 0           res[n % r] = addmod(res[n % r], n - 1, n);
283 0           res[0] = addmod(res[0], n - a, n);
284              
285 0 0         for (i = 0; i < r; i++)
286 0 0         if (res[i] != 0)
287 0           retval = 0;
288 0           Safefree(res);
289 0           Safefree(pn);
290 0           return retval;
291             }
292              
293             /*
294             * Avanzi and Mihǎilescu, 2007
295             * http://www.uni-math.gwdg.de/preda/mihailescu-papers/ouraks3.pdf
296             * "As a consequence, one cannot expect the present variants of AKS to
297             * compete with the earlier primality proving methods like ECPP and
298             * cyclotomy." - conclusion regarding memory consumption
299             */
300 7           int is_aks_prime(UV n)
301             {
302 7           UV r, s, a, starta = 1;
303             int verbose;
304              
305 7 100         if (n < 2)
306 2           return 0;
307 5 100         if (n == 2)
308 1           return 1;
309              
310 4 50         if (is_power(n, 0))
311 0           return 0;
312              
313 4 50         if (n > 11 && ( !(n%2) || !(n%3) || !(n%5) || !(n%7) || !(n%11) )) return 0;
    50          
    50          
    50          
    50          
    50          
314             /* if (!is_prob_prime(n)) return 0; */
315              
316 4           verbose = _XS_get_verbose();
317             #if IMPL_V6
318             {
319             UV sqrtn = isqrt(n);
320             double log2n = log(n) / log(2); /* C99 has a log2() function */
321             UV limit = (UV) floor(log2n * log2n);
322              
323             if (verbose) { printf("# aks limit is %lu\n", (unsigned long) limit); }
324              
325             for (r = 2; r < n; r++) {
326             if ((n % r) == 0)
327             return 0;
328             #if SQRTN_SHORTCUT
329             if (r > sqrtn)
330             return 1;
331             #endif
332             if (znorder(n, r) > limit)
333             break;
334             }
335              
336             if (r >= n)
337             return 1;
338              
339             s = (UV) floor(sqrt(r-1) * log2n);
340             }
341             #endif
342             #if IMPL_BORNEMANN
343             {
344             UV fac[MPU_MAX_FACTORS+1];
345             UV slim;
346             double c1, c2, x;
347             double const t = 48;
348             double const t1 = (1.0/((t+1)*log(t+1)-t*log(t)));
349             double const dlogn = log(n);
350             r = next_prime( (UV) (t1*t1 * dlogn*dlogn) );
351             while (!is_primitive_root(n,r,1))
352             r = next_prime(r);
353              
354             slim = (UV) (2*t*(r-1));
355             c1 = lgamma(r-1);
356             c2 = dlogn * floor(sqrt(r));
357             { /* Binary search for first s in [1,slim] where x >= 0 */
358             UV i = 1;
359             UV j = slim;
360             while (i < j) {
361             s = i + (j-i)/2;
362             x = (lgamma(r-1+s) - c1 - lgamma(s+1)) / c2 - 1.0;
363             if (x < 0) i = s+1;
364             else j = s;
365             }
366             s = i-1;
367             }
368             s = (s+3) >> 1;
369             /* Bornemann checks factors up to (s-1)^2, we check to max(r,s) */
370             /* slim = (s-1)*(s-1); */
371             slim = (r > s) ? r : s;
372             if (verbose > 1) printf("# aks trial to %lu\n", slim);
373             if (trial_factor(n, fac, 2, slim) > 1)
374             return 0;
375             if (slim >= HALF_WORD || (slim*slim) >= n)
376             return 1;
377             }
378             #endif
379             #if IMPL_BERN41
380             {
381             UV slim, fac[MPU_MAX_FACTORS+1];
382 4           double const log2n = log(n) / log(2);
383             /* Tuning: Initial 'r' selection. Search limit for 's'. */
384 4 50         double const r0 = ((log2n > 32) ? 0.010 : 0.003) * log2n * log2n;
385 4 50         UV const rmult = (log2n > 32) ? 6 : 30;
386              
387 4 100         r = next_prime(r0 < 2 ? 2 : (UV)r0); /* r must be at least 3 */
388 20 100         while ( !is_primitive_root(n,r,1) || !bern41_acceptable(n,r,rmult*(r-1)) )
    100          
389 16           r = next_prime(r);
390              
391             { /* Binary search for first s in [1,slim] where conditions met */
392 4           UV bi = 1;
393 4           UV bj = rmult * (r-1);
394 38 100         while (bi < bj) {
395 34           s = bi + (bj-bi)/2;
396 34 100         if (!bern41_acceptable(n, r, s)) bi = s+1;
397 23           else bj = s;
398             }
399 4           s = bj;
400 4 50         if (!bern41_acceptable(n, r, s)) croak("AKS: bad s selected");
401             /* S goes from 2 to s+1 */
402 4           starta = 2;
403 4           s = s+1;
404             }
405             /* Check divisibility to s * (s-1) to cover both gcd conditions */
406 4           slim = s * (s-1);
407 4 50         if (verbose > 1) printf("# aks trial to %lu\n", (unsigned long)slim);
408 4 100         if (trial_factor(n, fac, 2, slim) > 1)
409 4           return 0;
410 2 50         if (slim >= HALF_WORD || (slim*slim) >= n)
    50          
411 2           return 1;
412             /* Check b^(n-1) = 1 mod n for b in [2..s] */
413 0 0         for (a = 2; a <= s; a++) {
414 0 0         if (powmod(a, n-1, n) != 1)
415 0           return 0;
416             }
417             }
418             #endif
419              
420 0 0         if (verbose) { printf("# aks r = %lu s = %lu\n", (unsigned long) r, (unsigned long) s); }
421              
422             /* Almost every composite will get recognized by the first test.
423             * However, we need to run 's' tests to have the result proven for all n
424             * based on the theorems we have available at this time. */
425 0 0         for (a = starta; a <= s; a++) {
426 0 0         if (! test_anr(a, n, r) )
427 0           return 0;
428 0 0         if (verbose>1) { printf("."); fflush(stdout); }
429             }
430 0 0         if (verbose>1) { printf("\n"); }
431 0           return 1;
432             }