Author: William A. Stein
##################################################
# ent.py -- Elementary Number Theory
# (c) William Stein, 2004
##################################################
5
6
7
8
9from random import randrange
10from math import log, sqrt
11
12
13
14
15##################################################
16## Greatest Common Divisors
17##################################################
18
def gcd(a, b):
    """
    Returns the greatest common divisor of a and b,
    computed with Euclid's algorithm.  The result is
    always nonnegative, and gcd(0, 0) is 0.
    Input:
        a -- an integer
        b -- an integer
    Output:
        an integer, the gcd of a and b
    Examples:
    >>> gcd(97,100)
    1
    >>> gcd(-12, 18)
    6
    >>> print(gcd(97 * 10**15, 19**20 * 97**2))
    97
    """
    # The gcd ignores signs, so run Euclid's algorithm on |a| and |b|.
    a, b = abs(a), abs(b)
    while b:
        # gcd(a, b) == gcd(b, a mod b); stop when the remainder is 0.
        a, b = b, a % b
    return a
40
41
42
43##################################################
44## Enumerating Primes
45##################################################
46
def primes(n):
    """
    Returns a list of the primes up to n, computed
    using the Sieve of Eratosthenes.
    Input:
        n -- a positive integer
    Output:
        list -- a list of the primes up to n (empty if n <= 1)
    Examples:
    >>> primes(10)
    [2, 3, 5, 7]
    >>> primes(45)
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43]
    """
    if n <= 1:
        return []
    found = [2]
    # After recording 2, only the odd numbers 3, 5, 7, ... <= n remain candidates.
    remaining = list(range(3, n + 1, 2))
    limit = sqrt(n)
    # While the smallest survivor is <= sqrt(n) it is prime: record it and
    # strike out all of its multiples.
    while remaining and remaining[0] <= limit:
        p = remaining[0]
        found.append(p)
        remaining = [c for c in remaining if c % p != 0]
    # Survivors above sqrt(n) have no prime factor <= sqrt(n), so they are prime.
    return found + remaining
70
71
72
73
74##################################################
75## Integer Factorization
76##################################################
77
def trial_division(n, bound=None):
    """
    Return the smallest prime divisor <= bound of the
    positive integer n, or n if there is no such prime.
    If the optional argument bound is omitted, then bound=n.
    Input:
        n -- a positive integer
        bound - (optional) a positive integer
    Output:
        int -- a prime p<=bound that divides n, or n if
               there is no such prime.
    Examples:
    >>> trial_division(15)
    3
    >>> trial_division(91)
    7
    >>> trial_division(11)
    11
    >>> trial_division(387833, 300)
    387833
    >>> # 300 is not big enough to split off a
    >>> # factor, but 400 is.
    >>> trial_division(387833, 400)
    389
    >>> trial_division(15, 2)   # 3 divides 15, but 3 > bound
    15
    """
    if n == 1:
        return 1
    if bound is None:
        bound = n
    # Try 2, 3 and 5 directly; like every other candidate they must be <= bound.
    for p in [2, 3, 5]:
        if p > bound:
            return n
        if n % p == 0:
            return p
    # Every remaining prime is coprime to 30.  Starting from m = 7 at i = 1,
    # these gaps step through 7, 11, 13, 17, 19, 23, 29, 31, 37, ...
    dif = [6, 4, 2, 4, 2, 4, 6, 2]
    m = 7
    i = 1
    # Stop past the bound, or once m*m > n (then n has no divisor <= sqrt(n)
    # left to find, so n itself is returned).
    while m <= bound and m * m <= n:
        if n % m == 0:
            return m
        m += dif[i % 8]
        i += 1
    return n
115
def factor(n):
    """
    Returns the factorization of the integer n as
    a sorted list of tuples (p,e), where each prime p
    is found with trial_division and e is its exponent.
    The sign of n is ignored, and -1, 0 and 1 all give [].
    Input:
        n -- an integer
    Output:
        list -- factorization of n
    Examples:
    >>> factor(500)
    [(2, 2), (5, 3)]
    >>> factor(-20)
    [(2, 2), (5, 1)]
    >>> factor(1)
    []
    >>> factor(2004)
    [(2, 2), (3, 1), (167, 1)]
    """
    if n in [-1, 0, 1]:
        return []
    if n < 0:
        n = -n
    F = []
    while n != 1:
        p = trial_division(n)     # smallest prime factor of what is left of n
        e = 0
        while n % p == 0:
            e += 1
            # Floor division keeps n an exact integer; true division would
            # turn it into a float and lose precision for large n.
            n //= p
        F.append((p, e))
    F.sort()
    return F
147
def is_squarefree(n):
    """
    Returns True if and only if n is not divisible by the square of an integer > 1.
    Input:
        n -- an integer
    Output:
        bool -- False when n is 0 (every square divides 0); otherwise
                True exactly when no prime in factor(n) has exponent > 1.
    """
    if n == 0:
        return False
    return not any(e > 1 for _, e in factor(n))
158
159
160##################################################
161## Linear Equations Modulo $n$
162##################################################
163
def xgcd(a, b):
    """
    Returns g, x, y such that g = x*a + y*b = gcd(a,b).
    Input:
        a -- an integer
        b -- an integer
    Output:
        g -- an integer, the (nonnegative) gcd of a and b
        x -- an integer
        y -- an integer
    Examples:
    >>> xgcd(2,3)
    (1, -1, 1)
    >>> xgcd(10, 12)
    (2, -1, 1)
    >>> xgcd(100, 2004)
    (4, -20, 1)
    >>> xgcd(0, -5)
    (5, 0, -1)
    """
    if a == 0 and b == 0:
        return (0, 0, 1)
    # Use explicit integer signs: b/abs(b) is a float under true division.
    if a == 0:
        return (abs(b), 0, 1 if b > 0 else -1)
    if b == 0:
        return (abs(a), 1 if a > 0 else -1, 0)
    x_sign = -1 if a < 0 else 1
    y_sign = -1 if b < 0 else 1
    a, b = abs(a), abs(b)
    # Invariant: a == x*|a0| + y*|b0| and b == r*|a0| + s*|b0|.
    x, y, r, s = 1, 0, 0, 1
    while b != 0:
        q, c = divmod(a, b)       # exact integer quotient and remainder
        a, b, r, s, x, y = b, c, x - q*r, y - q*s, r, s
    return (a, x*x_sign, y*y_sign)
196
def inversemod(a, n):
    """
    Returns the inverse of a modulo n, normalized to
    lie between 0 and n-1.  If a is not coprime to n,
    raise an exception (this will be useful later for
    the elliptic curve factorization method).
    Input:
        a -- an integer coprime to n
        n -- a positive integer
    Output:
        an integer between 0 and n-1.
    Raises:
        ZeroDivisionError -- if gcd(a, n) != 1.  The exception's
            args are (a, n), so e.args[0] is the non-invertible value.
    Examples:
    >>> inversemod(1,1)
    0
    >>> inversemod(2,5)
    3
    >>> inversemod(5,8)
    5
    >>> inversemod(37,100)
    73
    """
    g, x, _ = xgcd(a, n)
    if g != 1:
        # Same exception and args as the old "raise E, (a, n)" form, but valid
        # syntax in Python 3 as well.
        raise ZeroDivisionError(a, n)
    return x % n
223
def solve_linear(a,b,n):
    """
    If the equation ax = b (mod n) has a solution, return a
    solution normalized to lie between 0 and n-1, otherwise
    returns None.
    Input:
        a -- an integer
        b -- an integer
        n -- an integer
    Output:
        an integer or None
    Examples:
    >>> solve_linear(4, 2, 10)
    8
    >>> solve_linear(2, 1, 4) == None
    True
    """
    # g = gcd(a, n) and c*a = g (mod n).
    g, c, _ = xgcd(a, n)
    # ax = b (mod n) is solvable exactly when g divides b.
    if b % g != 0:
        return None
    # b // g is exact here; floor division keeps it an integer (true
    # division would give a float and lose precision for large b).
    return ((b // g) * c) % n
244
def crt(a, b, m, n):
    """
    Return the unique integer between 0 and m*n - 1
    that reduces to a modulo m and b modulo n, where
    the integers m and n are coprime.
    Input:
        a, b, m, n -- integers, with m and n coprime
    Output:
        int -- an integer between 0 and m*n - 1.
    Examples:
    >>> crt(1, 2, 3, 4)     # 10 = 1 (mod 3) and 10 = 2 (mod 4)
    10
    >>> crt(4, 5, 10, 3)
    14
    >>> crt(-1, -1, 100, 101)
    10099
    """
    # c*m = 1 (mod n), so a + (b-a)*c*m is a (mod m) and b (mod n).
    g, c, _ = xgcd(m, n)
    assert g == 1, "m and n must be coprime."
    return (a + (b - a)*c*m) % (m*n)
265
266
267##################################################
268## Computation of Powers
269##################################################
270
def powermod(a, m, n):
    """
    The m-th power of a modulo n, computed by repeated
    squaring over the binary digits of m.
    Input:
        a -- an integer
        m -- a nonnegative integer
        n -- a positive integer
    Output:
        int -- an integer between 0 and n-1
    Examples:
    >>> powermod(2,25,30)
    2
    >>> powermod(19,12345,100)
    99
    """
    assert m >= 0, "m must be nonnegative."
    assert n >= 1, "n must be positive."
    ans = 1
    apow = a % n                 # a**(2**i) mod n for the current bit i of m
    while m != 0:
        if m % 2 != 0:           # bit i of m is set: fold a**(2**i) into ans
            ans = (ans * apow) % n
        apow = (apow * apow) % n
        # Floor division: m must stay an exact integer (m /= 2 would make it
        # a float under true division and break the loop).
        m //= 2
    # When m == 0 the loop never runs, so reduce the initial 1 (matters for n == 1).
    return ans % n
296
297
298##################################################
299## Finding a Primitive Root
300##################################################
301
def primitive_root(p):
    """
    Returns the first primitive root modulo the prime p.
    (If p is not prime, the return value of this function
    is not meaningful.)
    Input:
        p -- an integer that is assumed prime
    Output:
        int -- a primitive root modulo p
    Examples:
    >>> primitive_root(7)
    3
    >>> primitive_root(389)
    2
    >>> primitive_root(5881)
    31
    """
    if p == 2:
        return 1
    # a generates (Z/pZ)^* iff a**((p-1)/q) != 1 (mod p) for every prime q | p-1.
    F = factor(p - 1)
    a = 2
    while a < p:
        generates = True
        for q, _ in F:
            # q divides p-1, so floor division gives the exact exponent.
            if powermod(a, (p - 1) // q, p) == 1:
                generates = False
                break
        if generates:
            return a
        a += 1
    assert False, "p must be prime."
331
332
333##################################################
334## Determining Whether a Number is Prime
335##################################################
336
def is_pseudoprime(n, bases=(2, 3, 5, 7)):
    """
    Returns True if n is a pseudoprime to the given bases,
    in the sense that n>1 and b**(n-1) = 1 (mod n) for each
    element b of bases, with b not a multiple of n, and
    False otherwise.  The sign of n is ignored.
    Input:
        n -- an integer
        bases -- a sequence of integers (default (2, 3, 5, 7))
    Output:
        bool
    Examples:
    >>> is_pseudoprime(91)
    False
    >>> is_pseudoprime(97)
    True
    >>> is_pseudoprime(1)
    False
    >>> is_pseudoprime(-2)
    True
    >>> s = [x for x in range(10000) if is_pseudoprime(x)]
    >>> t = primes(10000)
    >>> s == t
    True
    >>> is_pseudoprime(29341) # first non-prime pseudoprime
    True
    >>> factor(29341)
    [(13, 1), (37, 1), (61, 1)]
    """
    # The default is an immutable tuple so no call can alter it for later calls.
    if n < 0:
        n = -n
    if n <= 1:
        return False
    for b in bases:
        # Bases that are multiples of n carry no information and are skipped.
        if b % n != 0 and powermod(b, n - 1, n) != 1:
            return False
    return True
372
373
def miller_rabin(n, num_trials=4):
    """
    True if n is likely prime, and False if n
    is definitely not prime.  Increasing num_trials
    increases the probability of correctness.
    (One can prove that the probability that this
    function returns True when it should return
    False is at most (1/4)**num_trials.)
    Input:
        n -- an integer (its sign is ignored)
        num_trials -- the number of trials with the
                      primality test.
    Output:
        bool -- whether or not n is probably prime.
    Examples:
    >>> miller_rabin(91)
    False                         #rand
    >>> miller_rabin(97)
    True                          #rand
    >>> miller_rabin(561)         # Carmichael number
    False                         #rand
    >>> s = [x for x in range(1000) if miller_rabin(x)]
    >>> s == primes(1000)
    True                          #rand
    """
    if n < 0:
        n = -n
    if n in [2, 3]:
        return True
    if n <= 4:
        return False
    m = n - 1
    k = 0
    while m % 2 == 0:
        k += 1
        m //= 2                   # floor division keeps m an exact integer
    # Now n - 1 = (2**k) * m with m odd.
    for _ in range(num_trials):
        a = randrange(2, n - 1)
        apow = powermod(a, m, n)
        if apow == 1 or apow == n - 1:
            continue              # a is not a witness; try the next base
        # n passes only if some a**(2**r * m), 1 <= r < k, equals -1 mod n.
        # Reaching 1 first (a nontrivial square root of 1) or never reaching
        # -1 proves n composite.
        for _ in range(k - 1):
            apow = (apow * apow) % n
            if apow == n - 1:
                break
        else:
            return False
    return True
424
425
426##################################################
427## The Diffie-Hellman Key Exchange
428##################################################
429
def random_prime(num_digits, is_prime = miller_rabin):
    """
    Returns a random (probable) prime with exactly num_digits digits.
    Input:
        num_digits -- a positive integer
        is_prime -- (optional argument)
                    a function of one argument n that
                    returns either True if n is (probably)
                    prime and False otherwise.
    Output:
        int -- an odd integer n with 10**(num_digits-1) <= n < 10**num_digits
    Examples:
    >>> random_prime(10)
    8599796717              #rand
    >>> random_prime(40)
    1311696770583281776596904119734399028761  #rand
    """
    lower = 10**(num_digits - 1)
    upper = 10**num_digits
    while True:
        n = randrange(lower, upper)
        if n % 2 == 0:
            n += 1
        # Scan odd numbers upward, but never past the num_digits-digit range;
        # previously the scan could return a prime with one digit too many.
        while n < upper and not is_prime(n):
            n += 2
        if n < upper:
            return n
        # Ran off the top of the range: restart from a fresh random point.
451
def dh_init(p):
    """
    Generates this user's half of a Diffie-Hellman exchange:
    a random secret exponent n and the public value 2^n (mod p).
    Input:
        p -- an integer that is prime
    Output:
        int -- the secret n, chosen uniformly with 2 <= n <= p-1
        int -- 2^n (mod p), to be sent to the other user
    Examples:
    >>> p = random_prime(20)
    >>> dh_init(p)
    (15299007531923218813, 4715333264598442112)   #rand
    """
    secret = randrange(2, p)
    public_value = powermod(2, secret, p)
    return secret, public_value
468
def dh_secret(p, n, mpow):
    """
    Computes the shared Diffie-Hellman secret key (mpow)^n mod p.
    Input:
        p -- an integer that is prime
        n -- an integer: the secret output by dh_init for this user
        mpow -- an integer: the public value output by dh_init
                for the other user
    Output:
        int -- the shared secret key.
    Examples:
    >>> p = random_prime(20)
    >>> n, npow = dh_init(p)
    >>> m, mpow = dh_init(p)
    >>> dh_secret(p, n, mpow) == dh_secret(p, m, npow)
    True
    """
    # (2^m)^n and (2^n)^m agree mod p, so both users derive the same key.
    shared_key = powermod(mpow, n, p)
    return shared_key
488
489
490
491
492
493
494##################################################
495## Encoding Strings as Lists of Integers
496##################################################
497
def str_to_numlist(s, bound):
    """
    Returns a sequence of integers between 0 and bound-1
    that encodes the string s.   Randomization is included,
    so the same string is very likely to encode differently
    each time this function is called.
    Each integer holds block_len base-256 digits (least
    significant first): `salt` random nonzero digits followed
    by up to block_len - salt character codes.
    Input:
        s -- a string
        bound -- an integer >= 256
    Output:
        list -- encoding of s as a list of integers
    Examples:
    >>> str_to_numlist("Run!", 1000)
    [82, 117, 110, 33]               #rand
    >>> str_to_numlist("TOP SECRET", 10**20)
    [4995371940984439512, 92656709616492]   #rand
    """
    assert bound >= 256, "bound must be at least 256."
    # Number of base-256 digits per integer.
    block_len = int(log(bound) / log(256))
    # Number of leading random digits per integer (0 when block_len == 1).
    salt = min(int(block_len / 8) + 1, block_len - 1)
    numbers = []
    pos = 0
    while pos < len(s):
        value = 0
        place = 1
        for digit in range(block_len):
            if digit < salt:
                value += randrange(1, 256) * place
            elif pos < len(s):
                value += ord(s[pos]) * place
                pos += 1
            else:
                break                 # string exhausted; higher digits stay 0
            place *= 256
        numbers.append(value)
    return numbers
531
def numlist_to_str(v, bound):
    """
    Returns the string that the sequence v of
    integers encodes (the inverse of str_to_numlist).
    Each integer is read as base-256 digits, least significant
    first; the first `salt` digits and any zero digits are skipped.
    Input:
        v -- list of integers between 0 and bound-1
        bound -- an integer >= 256
    Output:
        str -- decoding of v as a string
    Examples:
    >>> print(numlist_to_str([82, 117, 110, 33], 1000))
    Run!
    >>> x = str_to_numlist("TOP SECRET MESSAGE", 10**20)
    >>> print(numlist_to_str(x, 10**20))
    TOP SECRET MESSAGE
    """
    assert bound >= 256, "bound must be at least 256."
    n = int(log(bound) / log(256))
    salt = min(int(n / 8) + 1, n - 1)
    chars = []
    for x in v:
        for j in range(n):
            # divmod keeps x an exact integer (x /= 256 becomes a float under
            # true division and then chr() fails or digits are corrupted).
            x, y = divmod(x, 256)
            if y > 0 and j >= salt:
                chars.append(chr(y))
    # join is linear; repeated string += can be quadratic.
    return "".join(chars)
559
560
561##################################################
562## The RSA Cryptosystem
563##################################################
564
def rsa_init(p, q):
    """
    Returns defining parameters (e, d, n) for the RSA
    cryptosystem defined by primes p and q.  The
    primes p and q may be computed using the
    random_prime functions.
    Input:
        p -- a prime integer
        q -- a prime integer
    Output:
        Let m be (p-1)*(q-1).
        e -- the encryption key: the smallest integer >= 3
             that is coprime to m (deterministic, not random)
        d -- the inverse of e modulo m, normalized to lie
             between 0 and m-1 (as returned by inversemod)
        n -- the product p*q.
    Examples:
    >>> p = random_prime(20); q = random_prime(20)
    >>> e, d, n = rsa_init(p, q)
    >>> e
    5                                           #rand
    >>> n == p*q
    True
    >>> (e*d) % ((p-1)*(q-1))
    1
    """
    m = (p - 1) * (q - 1)
    e = 3
    while gcd(e, m) != 1:
        e += 1
    d = inversemod(e, m)
    return e, d, p * q
598
def rsa_encrypt(plain_text, e, n):
    """
    Encrypt plain_text using the encryption
    exponent e and modulus n.
    The text is first encoded with str_to_numlist(plain_text, n);
    each resulting integer x is replaced by x^e mod n.
    Input:
        plain_text -- arbitrary string
        e -- an integer, the encryption exponent
        n -- an integer, the modulus
    Output:
        list -- the encrypted cipher text, a list of integers
    Examples:
    >>> e = 1413636032234706267861856804566528506075
    >>> n = 2109029637390047474920932660992586706589
    >>> rsa_encrypt("Run Nikita!", e, n)
    [78151883112572478169375308975376279129]    #rand
    """
    blocks = str_to_numlist(plain_text, n)
    cipher = []
    for block in blocks:
        cipher.append(powermod(block, e, n))
    return cipher
619
def rsa_decrypt(cipher, d, n):
    """
    Decrypt the cipher text using the decryption
    exponent d and modulus n.
    Each integer c is mapped to c^d mod n and the results are
    decoded with numlist_to_str.
    Input:
        cipher -- list of integers output by rsa_encrypt
        d -- an integer, the decryption exponent
        n -- an integer, the modulus
    Output:
        str -- the unencrypted plain text
    Examples:
    >>> d = 938164637865370078346033914094246201579
    >>> n = 2109029637390047474920932660992586706589
    >>> msg1 = [1071099761433836971832061585353925961069]
    >>> msg2 = [1336506586627416245118258421225335020977]
    >>> rsa_decrypt(msg1, d, n)
    'Run Nikita!'
    >>> rsa_decrypt(msg2, d, n)
    'Run Nikita!'
    """
    return numlist_to_str([powermod(c, d, n) for c in cipher], n)
641
642
643##################################################
644## Computing the Legendre Symbol
645##################################################
646
def legendre(a, p):
    """
    Returns the Legendre symbol a over p, where
    p is an odd prime, using Euler's criterion
    a^((p-1)/2) (mod p).
    Input:
        a -- an integer
        p -- an odd prime (primality not checked)
    Output:
        int: -1 if a is not a square mod p,
              0 if gcd(a,p) is not 1
              1 if a is a square mod p.
    Examples:
    >>> legendre(2, 5)
    -1
    >>> legendre(3, 3)
    0
    >>> legendre(7, 2003)
    -1
    """
    assert p % 2 == 1, "p must be an odd prime."
    # p is odd, so (p-1) // 2 is the exact integer exponent.
    b = powermod(a, (p - 1) // 2, p)
    if b == 1:
        return 1
    if b == p - 1:
        return -1
    return 0
671
672
673##################################################
## In this section we implement an algorithm for computing square roots modulo a prime
675##################################################
676
def sqrtmod(a, p):
    """
    Returns a square root of a modulo p.
    Input:
        a -- an integer that is a perfect
             square modulo p (this is checked)
        p -- a prime
    Output:
        int -- a square root of a, as an integer
               between 0 and p-1 (0 when a = 0 mod p).
    Examples:
    >>> sqrtmod(4, 5)              # p == 1 (mod 4)
    3              #rand
    >>> sqrtmod(13, 23)            # p == 3 (mod 4)
    6              #rand
    >>> sqrtmod(997, 7304723089)   # p == 1 (mod 4)
    761044645      #rand
    """
    a %= p
    if p == 2:
        return a
    if a == 0:
        return 0                   # 0 is a square; legendre(0, p) would be 0
    assert legendre(a, p) == 1, "a must be a square mod p."
    if p % 4 == 3:
        return powermod(a, (p + 1) // 4, p)

    # Arithmetic in R = (Z/pZ)[t]/(t^2 - a); an element is a pair (x0, x1)
    # standing for x0 + x1*t.
    def mul(x, y):
        return ((x[0]*y[0] + a*y[1]*x[1]) % p,
                (x[0]*y[1] + x[1]*y[0]) % p)

    def power(x, n):             # repeated squaring in R
        ans = (1, 0)
        xpow = x
        while n != 0:
            if n % 2 != 0:
                ans = mul(ans, xpow)
            xpow = mul(xpow, xpow)
            n //= 2               # floor division keeps n an exact integer
        return ans

    while True:
        z = randrange(2, p)
        u, v = power((1, z), (p - 1) // 2)
        if v != 0:
            vinv = inversemod(v, p)
            for x in [-u*vinv, (1 - u)*vinv, (-1 - u)*vinv]:
                if (x*x) % p == a:
                    return x % p
            assert False, "Bug in sqrtmod."
720
721
722##################################################
723## Continued Fractions
724##################################################
725
def convergents(v):
    """
    Returns the partial convergents of the continued
    fraction v, using the recurrences
    p_k = a_k*p_{k-1} + p_{k-2},  q_k = a_k*q_{k-1} + q_{k-2}.
    Input:
        v -- list of integers [a0, a1, a2, ..., am]
    Output:
        list -- list [(p0,q0), (p1,q1), ...]
                of pairs (pm,qm) such that the mth
                convergent of v is pm/qm.
    Examples:
    >>> convergents([1, 2])
    [(1, 1), (3, 2)]
    >>> convergents([3, 7, 15, 1, 292])
    [(3, 1), (22, 7), (333, 106), (355, 113), (103993, 33102)]
    """
    # Seed values (p_{-2}, q_{-2}) = (0, 1) and (p_{-1}, q_{-1}) = (1, 0).
    older, newer = (0, 1), (1, 0)
    result = []
    for a in v:
        current = (a*newer[0] + older[0], a*newer[1] + older[1])
        result.append(current)
        older, newer = newer, current
    return result
749
def contfrac_rat(numer, denom):
    """
    Returns the continued fraction of the rational
    number numer/denom.
    Input:
        numer -- an integer
        denom -- a positive integer coprime to numer
    Output:
        list -- the continued fraction [a0, a1, ..., am]
                of the rational number numer/denom.
    Examples:
    >>> contfrac_rat(3, 2)
    [1, 2]
    >>> contfrac_rat(103993, 33102)
    [3, 7, 15, 1, 292]
    >>> contfrac_rat(-3, 2)
    [-2, 2]
    """
    assert denom > 0, "denom must be positive"
    a, b = numer, denom
    v = []
    while b != 0:
        # divmod floors (needed for negative numer) and stays an exact
        # integer; a/b would be a float under true division.
        q, r = divmod(a, b)
        v.append(q)
        a, b = b, r
    return v
773
def contfrac_float(x):
    """
    Returns the continued fraction of the floating
    point number x, computed using the continued
    fraction procedure, and the sequence of partial
    convergents.  The procedure stops as soon as a
    convergent equals x exactly as a float, or the
    fractional part becomes exactly 0.
    Input:
        x -- a floating point number (decimal)
    Output:
        list -- the continued fraction [a0, a1, ...]
                obtained by applying the continued
                fraction procedure to x to the
                precision of this computer.
        list -- the list [(p0,q0), (p1,q1), ...]
                of pairs (pm,qm) such that the mth
                convergent of continued fraction
                is pm/qm.
    Examples:
    >>> v, w = contfrac_float(3.14159); print(v)
    [3, 7, 15, 1, 25, 1, 7, 4]
    >>> v, w = contfrac_float(2.718); print(v)
    [2, 1, 2, 1, 1, 4, 1, 12]
    >>> contfrac_float(0.3)
    ([0, 3, 2, 1], [(0, 1), (1, 3), (2, 7), (3, 10)])
    """
    from math import floor
    v = []
    w = [(0, 1), (1, 0)]          # seed values for the convergent recurrence
    start = x
    while True:
        # The procedure takes the floor; int() truncates toward zero and would
        # give a wrong a0 (and negative partial quotients) for negative x.
        a = int(floor(x))
        v.append(a)
        n = len(v) - 1
        pn = a * w[n + 1][0] + w[n][0]
        qn = a * w[n + 1][1] + w[n][1]
        w.append((pn, qn))
        x -= a
        # Stop when the convergent reproduces the input exactly, or when the
        # remainder is exactly 0 (otherwise 1/x below would raise).
        if x == 0 or float(pn) / float(qn) == start:
            return v, w[2:]       # drop the two seed values
        x = 1 / x
814
def sum_of_two_squares(p):
    """
    Uses continued fractions to efficiently compute
    a representation of the prime p as a sum of
    two squares.   The prime p must be 1 modulo 4.
    Input:
        p -- a prime congruent 1 modulo 4.
    Output:
        integers a, b such that p is a*a + b*b
    Examples:
    >>> sum_of_two_squares(5)
    (1, 2)
    >>> sum_of_two_squares(389)
    (10, 17)
    >>> a, b = sum_of_two_squares(86295641057493119033)
    >>> print(a)
    789006548
    >>> print(b)
    9255976973
    """
    assert p % 4 == 1, "p must be 1 modulo 4"
    r = sqrtmod(-1, p)                 # r*r = -1 (mod p)
    v = contfrac_rat(-r, p)
    for a, b in convergents(v):
        c = r*b + p*a
        # Accept the first convergent with |c| <= sqrt(p).  Testing c*c <= p
        # is exact; int(sqrt(p)) goes through a float and can be off by one
        # for large p.
        if c*c <= p:
            return (abs(b), abs(c))
    assert False, "Bug in sum_of_two_squares."
840
841
842##################################################
## Elliptic Curve Arithmetic
844##################################################
845
def ellcurve_add(E, P1, P2):
    """
    Returns the sum of P1 and P2 on the elliptic
    curve E.
    Input:
         E -- an elliptic curve over Z/pZ, given by a
              triple of integers (a, b, p), with p odd.
         P1 -- a pair of integers (x, y) or the
              string "Identity".
         P2 -- same type as P1
    Output:
         R -- same type as P1
    Raises:
         ZeroDivisionError -- from inversemod, when a slope denominator
              is not invertible modulo p (possible only if p is not prime).
    Examples:
    >>> E = (1, 0, 7)   # y**2 = x**3 + x over Z/7Z
    >>> P1 = (1, 3); P2 = (3, 3)
    >>> ellcurve_add(E, P1, P2)
    (3, 4)
    >>> ellcurve_add(E, P1, (1, 4))
    'Identity'
    >>> ellcurve_add(E, "Identity", P2)
    (3, 3)
    """
    a, b, p = E
    assert p > 2, "p must be odd."
    if P1 == "Identity":
        return P2
    if P2 == "Identity":
        return P1
    x1, y1 = P1
    x2, y2 = P2
    x1 %= p; y1 %= p; x2 %= p; y2 %= p
    # P2 == -P1 (this also covers doubling a point with y == 0).
    if x1 == x2 and (y1 + y2) % p == 0:
        return "Identity"
    # Compare the reduced coordinates: comparing the raw tuples treated e.g.
    # (1, 3) and (8, 3) mod 7 as different points and then divided by 0.
    if (x1, y1) == (x2, y2):
        lam = (3*x1**2 + a) * inversemod(2*y1, p)      # tangent slope
    else:
        lam = (y1 - y2) * inversemod(x1 - x2, p)       # chord slope
    x3 = lam**2 - x1 - x2
    y3 = -lam*x3 - y1 + lam*x1
    return (x3 % p, y3 % p)
883
def ellcurve_mul(E, m, P):
    """
    Returns the multiple m*P of the point P on
    the elliptic curve E, by repeated doubling.
    Input:
        E -- an elliptic curve over Z/pZ, given by a
             triple (a, b, p).
        m -- a nonnegative integer
        P -- a pair of integers (x, y) or the
             string "Identity"
    Output:
        A pair of integers or the string "Identity".
    Raises:
        ZeroDivisionError -- propagated from ellcurve_add.
    Examples:
    >>> E = (1, 0, 7)
    >>> P = (1, 3)
    >>> ellcurve_mul(E, 5, P)
    (1, 3)
    >>> ellcurve_mul(E, 9999, P)
    (1, 4)
    """
    assert m >= 0, "m must be nonnegative."
    power = P                  # (2**i)*P for the current bit i of m
    mP = "Identity"
    while m != 0:
        if m % 2 != 0:
            mP = ellcurve_add(E, mP, power)
        power = ellcurve_add(E, power, power)
        m //= 2                # floor division keeps m an exact integer
    return mP
912
913
914##################################################
## Integer Factorization: Pollard's p-1 and Elliptic Curve Methods
916##################################################
917
def lcm_to(B):
    """
    Returns the least common multiple of all
    integers up to B.
    Input:
        B -- an integer
    Output:
        an integer (1 when B < 2)
    Examples:
    >>> lcm_to(5)
    60
    >>> lcm_to(20)
    232792560
    >>> print(lcm_to(100))
    69720375229712477164533808935312303556800
    >>> lcm_to(243) % 243       # the full power 3**5 is included
    0
    """
    ans = 1
    for p in primes(B):
        # The lcm contains exactly the largest power of p that is <= B.  Find
        # it with exact integer arithmetic: a float ratio such as
        # log(243)/log(3) can round just below the true exponent.
        pk = p
        while pk * p <= B:
            pk *= p
        ans *= pk
    return ans
939
def pollard(N, m):
    """
    Use Pollard's (p-1)-method to try to find a
    nontrivial divisor of N: for the bases 2 and 3 in turn,
    compute gcd(base^m - 1, N) and return it if it is neither
    1 nor N.
    Input:
        N -- a positive integer
        m -- a positive integer, the least common
             multiple of the integers up to some
             bound, computed using lcm_to.
    Output:
        int -- a divisor of N (N itself if none was found)
    Examples:
    >>> pollard(5917, lcm_to(5))
    61
    >>> pollard(779167, lcm_to(5))
    779167
    >>> print(pollard(779167, lcm_to(15)))
    2003
    >>> pollard(187, lcm_to(15))
    11
    """
    for base in (2, 3):
        g = gcd(powermod(base, m, N) - 1, N)
        if g not in (1, N):
            return g
    return N
972
def randcurve(p):
    """
    Construct a somewhat random elliptic curve
    y^2 = x^3 + a*x + 1 over Z/pZ together with the point
    (0, 1), which lies on every curve of this shape.  The
    coefficient a is redrawn until 4*a^3 + 27 is coprime to p.
    Input:
        p -- a positive integer > 2
    Output:
        tuple -- a triple E = (a, 1, p)
        P -- the tuple (0, 1) on E
    Examples:
    >>> p = random_prime(20)
    >>> E, P = randcurve(p)
    >>> print(P)
    (0, 1)
    """
    assert p > 2, "p must be > 2."
    while True:
        a = randrange(p)
        if gcd(4*a**3 + 27, p) == 1:
            return (a, 1, p), (0, 1)
996
def elliptic_curve_method(N, m, tries=5):
    """
    Use the elliptic curve method to try to find a
    nontrivial divisor of N.  For each random curve, computing
    m*P modulo N fails with ZeroDivisionError exactly when some
    denominator shares a factor with N; that factor is returned
    if it is nontrivial.
    Input:
        N -- a positive integer
        m -- a positive integer, the least common
             multiple of the integers up to some
             bound, computed using lcm_to.
        tries -- a positive integer, the number of
             different elliptic curves to try
    Output:
        int -- a divisor of N (N itself if none was found)
    Examples:
    >>> elliptic_curve_method(5959, lcm_to(20))
    59       #rand
    >>> elliptic_curve_method(10007*20011, lcm_to(100))
    10007    #rand
    """
    for _ in range(tries):
        E, P = randcurve(N)
        try:
            ellcurve_mul(E, m, P)
        except ZeroDivisionError as err:
            # inversemod raises ZeroDivisionError(a, n); args[0] is the
            # denominator that was not invertible modulo N.
            g = gcd(err.args[0], N)
            # Only a proper divisor helps.  The old test "g != 1 or g != N"
            # was always true; a trivial g now moves on to the next curve.
            if g != 1 and g != N:
                return g
    return N
1031
1032
1033##################################################
1034## ElGamal Elliptic Curve Cryptosystem
1035##################################################
1036
def elgamal_init(p):
    """
    Constructs an ElGamal cryptosystem over Z/pZ, by
    choosing a random elliptic curve E over Z/pZ, a
    point B in E(Z/pZ), and a random integer n.  This
    function returns the public key as a 3-tuple
    (E, B, n*B) and the private key as the pair (E, n).
    Input:
        p -- a prime number
    Output:
        tuple -- the public key as a 3-tuple
                 (E, B, n*B), where E = (a, b, p) is an
                 elliptic curve over Z/pZ, B = (x, y) is
                 a point on E, and n*B = (x',y') (or the
                 string "Identity") is the sum of B with
                 itself n times.
        tuple -- the private key, the pair (E, n), where n is
                 random with 2 <= n <= p-1.
    Examples:
    >>> p = random_prime(20)
    >>> public, private = elgamal_init(p)
    >>> print(public[1])
    (0, 1)
    >>> public[0] == private[0]
    True
    """
    E, B = randcurve(p)
    n = randrange(2, p)
    nB = ellcurve_mul(E, n, B)
    return (E, B, nB), (E, n)
1070
def elgamal_encrypt(plain_text, public_key):
    """
    Encrypt a message using the ElGamal cryptosystem
    with given public_key = (E, B, n*B).
    Input:
       plain_text -- a string
       public_key -- a triple (E, B, n*B), as output
                     by elgamal_init.
    Output:
       list -- a list of pairs of points on E that
               represent the encrypted message; each
               pair is (r*B, P + r*(n*B)) for a fresh
               random r, where P is a point encoding
               one block of the message.
    Examples:
    >>> public, private = elgamal_init(random_prime(20))
    >>> elgamal_encrypt("RUN", public)
    [((6004308617723068486L, 15578511190582849677L), \ #rand
     (7064405129585539806L, 8318592816457841619L))]    #rand
    """
    E, B, nB = public_key
    a, b, p = E
    assert p > 10000, "p must be at least 10000."
    # Each message block m (m < p//1000) is embedded as x = 1000*m.
    # The three low decimal digits leave room to search for an
    # x-coordinate on E.  Decryption recovers m by integer division of
    # the x-coordinate by 1000, so fewer than 1000 steps are assumed.
    v = [1000*x for x in str_to_numlist(plain_text, p//1000)]   # (1)
    cipher = []
    for x in v:
        # Step x until legendre() reports x^3 + a*x + b is a square mod p.
        while not legendre(x**3+a*x+b, p)==1:        # (2)
            x = (x+1)%p
        y = sqrtmod(x**3+a*x+b, p)                   # (3)
        P = (x,y)
        r = randrange(1,p)
        # ElGamal pair: (r*B, P + r*(n*B)).
        encrypted = (ellcurve_mul(E, r, B),
                     ellcurve_add(E, P, ellcurve_mul(E, r, nB)))
        cipher.append(encrypted)
    return cipher
1104
def elgamal_decrypt(cipher_text, private_key):
    """
    Decrypt a message using the ElGamal cryptosystem
    with given private_key = (E, n).
    Input:
        cipher_text -- list of pairs of points on E output
                       by elgamal_encrypt.
        private_key -- the pair (E, n), as output by
                       elgamal_init.
    Output:
        str -- the unencrypted plain text
    Examples:
    >>> public, private = elgamal_init(random_prime(20))
    >>> v = elgamal_encrypt("TOP SECRET MESSAGE!", public)
    >>> print elgamal_decrypt(v, private)
    TOP SECRET MESSAGE!
    """
    E, n = private_key
    p = E[2]
    plain = []
    for rB, P_plus_rnB in cipher_text:
        # n*(r*B) equals r*(n*B), the mask in P + r*(n*B).
        nrB = ellcurve_mul(E, n, rB)
        minus_nrB = (nrB[0], -nrB[1])      # -(x, y) = (x, -y)
        P = ellcurve_add(E, minus_nrB, P_plus_rnB)
        # Undo the x = 1000*m embedding of the message block.
        plain.append(P[0]//1000)
    return numlist_to_str(plain, p//1000)
1129
1130
1131##################################################
1132## Associativity of the Group Law
1133##################################################
1134
1135# The variable order is x1, x2, x3, y1, y2, y3, a, b
class Poly:                                     # (1)
    """
    Sparse polynomial in the eight variables
    x1, x2, x3, y1, y2, y3, a, b (in that order).

    The attribute v is a dict mapping a monomial, given as a
    tuple of eight exponents, to its coefficient.  Indexing
    accepts any length-8 sequence of exponents.
    """
    def __init__(self, d):                      # (2)
        # d may be a dict or a list of (monomial, coefficient)
        # pairs; it is copied, so the caller's object is not shared.
        self.v = dict(d)
    def __cmp__(self, other):                   # (3)
        # Python 2 comparison hook.  Only equality is meaningful:
        # 0 means equal, -1 means "different" (there is no order).
        self.normalize(); other.normalize()     # (4)
        if self.v == other.v: return 0
        return -1
    def __eq__(self, other):
        # Rich-comparison form of __cmp__, so == and != behave the
        # same where __cmp__ is not consulted (Python 3).
        return self.__cmp__(other) == 0
    def __ne__(self, other):
        return not self.__eq__(other)

    def __add__(self, other):                   # (5)
        w = Poly(self.v)
        for m in other.monomials():
            w[m] += other[m]
        return w
    def __sub__(self, other):
        w = Poly(self.v)
        for m in other.monomials():
            w[m] -= other[m]
        return w
    def __mul__(self, other):
        if len(self.v) == 0 or len(other.v) == 0:
            return Poly([])
        r = Poly([])
        for m1 in self.monomials():
            for m2 in other.monomials():
                # Multiplying monomials adds their exponent vectors.
                z = [m1[i] + m2[i] for i in range(8)]
                r[z] += self[m1]*other[m2]
        return r
    def __neg__(self):
        v = {}
        for m in self.v.keys():
            v[m] = -self.v[m]
        return Poly(v)
    def __div__(self, other):
        # Polynomials are not divided; the quotient is a Frac.
        return Frac(self, other)
    __truediv__ = __div__                       # "/" under Python 3

    def __getitem__(self, m):                   # (6)
        # Reading an absent monomial stores and returns 0, which
        # makes "p[m] += c" work for new monomials.
        return self.v.setdefault(tuple(m), 0)
    def __setitem__(self, m, c):
        self.v[tuple(m)] = c
    def __delitem__(self, m):
        del self.v[tuple(m)]

    def monomials(self):                        # (7)
        # A list copy of the keys, so callers may add or delete
        # entries while iterating over the result.
        return list(self.v.keys())
    def normalize(self):                        # (8)
        # In place: delete zero coefficients and rewrite every
        # y_i^2 (i = 1, 2, 3) as x_i^3 + a*x_i + b, repeating until
        # no monomial contains y_i^2.
        while True:
            finished = True
            for m in self.monomials():
                c = self[m]
                if c == 0:
                    del self[m]
                    continue
                for i in range(3):
                    if m[3+i] >= 2:
                        finished = False
                        del self[m]
                        nx0 = list(m); nx0[3+i] -= 2
                        nx0[7] += 1                         # b
                        nx1 = list(m); nx1[3+i] -= 2
                        nx1[i] += 1; nx1[6] += 1            # a*x_i
                        nx3 = list(m); nx3[3+i] -= 2
                        nx3[i] += 3                         # x_i^3
                        self[nx0] += c
                        self[nx1] += c
                        self[nx3] += c
                        # m is gone now; any other y_j^2 it held is
                        # reduced from the new monomials on a later pass.
                        break
                # end for
            # end for
            if finished: return
        # end while
1208
# The constant polynomial 1: every one of the eight exponents is zero.
one = Poly({(0,) * 8: 1})                       # (9)
1210
class Frac:                                     # (10)
    """
    A formal quotient num/denom of two Poly objects.

    Arithmetic follows the usual rules for fractions and never
    cancels common factors.  Two Fracs are equal when the cross
    products num1*denom2 and denom1*num2 are equal Polys.
    """
    def __init__(self, num, denom=one):
        # The default denominator is the module-level constant
        # polynomial one.
        self.num = num; self.denom = denom
    def __cmp__(self, other):                   # (11)
        # Python 2 comparison hook; only equality is meaningful.
        if self.num * other.denom == self.denom * other.num:
            return 0
        return -1
    def __eq__(self, other):
        # Rich-comparison form of __cmp__, so == and != behave the
        # same where __cmp__ is not consulted (Python 3).
        return self.__cmp__(other) == 0
    def __ne__(self, other):
        return not self.__eq__(other)

    def __add__(self, other):                   # (12)
        return Frac(self.num*other.denom + \
                    self.denom*other.num,
                    self.denom*other.denom)
    def __sub__(self, other):
        return Frac(self.num*other.denom - \
                    self.denom*other.num,
                    self.denom*other.denom)
    def __mul__(self, other):
        return Frac(self.num*other.num, \
                    self.denom*other.denom)
    def __div__(self, other):
        return Frac(self.num*other.denom, \
                    self.denom*other.num)
    __truediv__ = __div__                       # "/" under Python 3
    def __neg__(self):
        return Frac(-self.num,self.denom)
1235
def var(i):                                     # (14)
    """
    Return variable number i (0 = x1, ..., 7 = b) as a Frac
    whose numerator is that single monomial with coefficient 1
    and whose denominator is the default.
    """
    exponents = [0] * 8
    exponents[i] = 1
    monomial = tuple(exponents)
    return Frac(Poly({monomial: 1}))
1239
def prove_associative():                        # (15)
    """
    Print whether (P1+P2)+P3 and P1+(P2+P3) have the same
    x-coordinate for three generic points on y^2 = x^3 + a*x + b.
    Sums of distinct points use the chord (secant) formula.  The
    comparison is done on fractions of polynomials, with the curve
    equations applied by Poly normalization.
    """
    # Generic coordinates of P1, P2, P3.  The curve constants a and b
    # (variables 6 and 7) only enter through Poly.normalize.
    x1, x2, x3, y1, y2, y3 = [var(k) for k in range(6)]

    def square(f):
        return f * f

    def chord_sum(xp, yp, xq, yq):
        # Sum of two distinct points via the line through them.
        slope = (yp - yq) / (xp - xq)
        intercept = yp - slope * xp
        xr = square(slope) - xp - xq
        return xr, -slope * xr - intercept

    x4, y4 = chord_sum(x1, y1, x2, y2)     # P4 = P1 + P2
    x5, y5 = chord_sum(x2, y2, x3, y3)     # P5 = P2 + P3

    # x(P4 + P3) == x(P1 + P5), with both denominators cleared.
    left = square(x1 - x5) * \
        (square(y3 - y4) - (x3 + x4) * square(x3 - x4))
    right = square(x3 - x4) * \
        (square(y1 - y5) - (x1 + x5) * square(x1 - x5))
    print("Associative?")
    print(left == right)                        # (17)
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274##########################################################
# Worked examples that belong to no specific function.   #
1276##########################################################
1277
def examples():
    # Doctest-only container.  Its docstring collects worked examples
    # that belong to no other function, and doctest runs them along
    # with every other docstring in this module.  Do not edit the
    # docstring casually: it is executed as a test.  The expected
    # outputs assume Python 2 (integer "/", long literals such as 10L,
    # print statements).  Outputs marked #rand come from random
    # choices and will differ between runs.
    """
    >>> from ent import *
    >>> 7/5
    1
    >>> -2/3
    -1
    >>> 1.0/3
    0.33333333333333331
    >>> float(2)/3
    0.66666666666666663
    >>> 100**2
    10000
    >>> 10**20
    100000000000000000000L
    >>> range(10)            # range(n) is from 0 to n-1
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> range(3,10)          # range(a,b) is from a to b-1
    [3, 4, 5, 6, 7, 8, 9]
    >>> [x**2 for x in range(10)]
    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
    >>> [x**2 for x in range(10) if x%4 == 1]
    [1, 25, 81]
    >>> [1,2,3] + [5,6,7]    # concatenation
    [1, 2, 3, 5, 6, 7]
    >>> len([1,2,3,4,5])     # length of a list
    5
    >>> x = [4,7,10,'gcd']   # mixing types is fine
    >>> x[0]                 # 0-based indexing
    4
    >>> x[3]
    'gcd'
    >>> x[3] = 'lagrange'    # assignment
    >>> x.append("fermat")   # append to end of list
    >>> x
    [4, 7, 10, 'lagrange', 'fermat']
    >>> del x[3]             # delete entry 3 from list
    >>> x
    [4, 7, 10, 'fermat']
    >>> v = primes(10000)
    >>> len(v)    # this is pi(10000)
    1229
    >>> len([x for x in v if x < 1000])   # pi(1000)
    168
    >>> len([x for x in v if x < 5000])   # pi(5000)
    669
    >>> x=(1, 2, 3)       # creation
    >>> x[1]
    2
    >>> (1, 2, 3) + (4, 5, 6)  # concatenation
    (1, 2, 3, 4, 5, 6)
    >>> (a, b) = (1, 2)        # assignment assigns to each member
    >>> print a, b
    1 2
    >>> for (c, d) in [(1,2), (5,6)]:
    ...     print c, d
    1 2
    5 6
    >>> x = 1, 2          # parentheses optional in creation
    >>> x
    (1, 2)
    >>> c, d = x          # parentheses also optional
    >>> print c, d
    1 2
    >>> P = [p for p in range(200000) if is_pseudoprime(p)]
    >>> Q = primes(200000)
    >>> R = [x for x in P if not (x in Q)]; print R
    [29341, 46657, 75361, 115921, 162401]
    >>> [n for n in R if is_pseudoprime(n,[2,3,5,7,11,13])]
    [162401]
    >>> factor(162401)
    [(17, 1), (41, 1), (233, 1)]
    >>> p = random_prime(50)
    >>> p
    13537669335668960267902317758600526039222634416221L #rand
    >>> n, npow = dh_init(p)
    >>> n
    8520467863827253595224582066095474547602956490963L  #rand
    >>> npow
    3206478875002439975737792666147199399141965887602L  #rand
    >>> m, mpow = dh_init(p)
    >>> m
    3533715181946048754332697897996834077726943413544L  #rand
    >>> mpow
    3465862701820513569217254081716392362462604355024L  #rand
    >>> dh_secret(p, n, mpow)
    12931853037327712933053975672241775629043437267478L #rand
    >>> dh_secret(p, m, npow)
    12931853037327712933053975672241775629043437267478L #rand
    >>> prove_associative()
    Associative?
    True
    >>> len(primes(10000))
    1229
    >>> 10000/log(10000)
    1085.73620476
    >>> powermod(3,45,100)
    43
    >>> inversemod(37, 112)
    109
    >>> powermod(102, 70, 113)
    98
    >>> powermod(99, 109, 113)
    60
    >>> P = primes(1000)
    >>> Q = [p for p in P if primitive_root(p) == 2]
    >>> print len(Q), len(P)
    67 168
    >>> P = primes(50000)
    >>> Q = [primitive_root(p) for p in P]
    >>> Q.index(37)
    3893
    >>> P[3893]
    36721
    >>> for n in range(97):
    ...     if powermod(5,n,97)==3: print n
    70
    >>> factor(5352381469067)
    [(141307, 1), (37877681L, 1)]
    >>> d=inversemod(4240501142039, (141307-1)*(37877681-1))
    >>> d
    5195621988839L
    >>> convergents([-3,1,1,1,1,3])
    [(-3, 1), (-2, 1), (-5, 2), (-7, 3), \
              (-12, 5), (-43, 18)]
    >>> convergents([0,2,4,1,8,2])
    [(0, 1), (1, 2), (4, 9), (5, 11), \
              (44, 97), (93, 205)]
    >>> import math
    >>> e = math.exp(1)
    >>> v, convs = contfrac_float(e)
    >>> [(a,b) for a, b in convs if \
           abs(e - a*1.0/b) < 1/(math.sqrt(5)*b**2)]
    [(3, 1), (19, 7), (193, 71), (2721, 1001),\
     (49171, 18089), (1084483, 398959),\
     (28245729, 10391023), (325368125, 119696244)]
    >>> factor(12345)
    [(3, 1), (5, 1), (823, 1)]
    >>> factor(729)
    [(3, 6)]
    >>> factor(5809961789)
    [(5809961789L, 1)]
    >>> 5809961789 % 4
    1L
    >>> sum_of_two_squares(5809961789)
    (51542L, 56155L)
    >>> N = [60 + s for s in range(-15,16)]
    >>> def is_powersmooth(B, x):
    ...     for p, e in factor(x):
    ...         if p**e > B: return False
    ...     return True
    >>> Ns = [x for x in N if is_powersmooth(20, x)]
    >>> print len(Ns), len(N), len(Ns)*1.0/len(N)
    14 31 0.451612903226
    >>> P = [x for x in range(10**12, 10**12+1000)\
             if miller_rabin(x)]
    >>> Ps = [x for x in P if \
             is_powersmooth(10000, x-1)]
    >>> print len(Ps), len(P), len(Ps)*1.0/len(P)
    2 37 0.0540540540541

    """
1440
1441
if __name__ == '__main__':
    # Run every doctest in this module's docstrings.
    import sys
    import doctest
    this_module = sys.modules[__name__]
    doctest.testmod(m=this_module)
1445