Shared: www / ent / ent.py (Open in CoCalc)
Author: William A. Stein
1
##################################################
2
# ent.py -- Element Number Theory
3
# (c) William Stein, 2004
4
##################################################
5
6
7
8
9
from random import randrange
10
from math import log, sqrt
11
12
13
14
15
##################################################
16
## Greatest Common Divisors
17
##################################################
18
19
def gcd(a, b):                                       # (1)
    """
    Returns the greatest common divisor of a and b.
    Input:
        a -- an integer
        b -- an integer
    Output:
        an integer, the gcd of a and b (never negative;
        gcd(0, 0) is 0)
    Examples:
    >>> gcd(97,100)
    1
    >>> gcd(97 * 10**15, 19**20 * 97**2) == 97      # (2)
    True
    """
    # The gcd ignores signs, so work with absolute values throughout.
    a, b = abs(a), abs(b)
    # Euclid: gcd(a, b) == gcd(b, a mod b); stop once the remainder is 0.
    while b:
        a, b = b, a % b
    return a
40
41
42
43
##################################################
44
## Enumerating Primes
45
##################################################
46
47
def primes(n):
    """
    Returns a list of the primes up to n, computed
    using the Sieve of Eratosthenes.
    Input:
        n -- a positive integer
    Output:
        list -- a list of the primes up to n (empty if n <= 1)
    Examples:
    >>> primes(10)
    [2, 3, 5, 7]
    >>> primes(45)
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43]
    """
    if n <= 1:
        return []
    # 2 is the only even prime, so only odd candidates 3, 5, ..., n are sieved.
    candidates = list(range(3, n + 1, 2))
    found = [2]
    limit = sqrt(n)
    # The smallest survivor is always prime: record it and strike out its
    # multiples.  Once it exceeds sqrt(n), every remaining survivor is prime.
    while candidates and candidates[0] <= limit:
        q = candidates[0]
        found.append(q)
        candidates = [c for c in candidates if c % q]
    return found + candidates
70
71
72
73
74
##################################################
75
## Integer Factorization
76
##################################################
77
78
def trial_division(n, bound=None):
    """
    Return the smallest prime divisor <= bound of the
    positive integer n, or n if there is no such prime.
    If the optional argument bound is omitted, then bound=n.
    Input:
        n -- a positive integer
        bound - (optional) a positive integer
    Output:
        int -- a prime p<=bound that divides n, or n if
               there is no such prime.
    Examples:
    >>> trial_division(15)
    3
    >>> trial_division(91)
    7
    >>> trial_division(11)
    11
    >>> trial_division(387833, 300)
    387833
    >>> # 300 is not big enough to split off a
    >>> # factor, but 400 is.
    >>> trial_division(387833, 400)
    389
    >>> trial_division(15, 2)    # 3 divides 15 but exceeds the bound
    15
    """
    if n == 1: return 1
    if bound is None: bound = n
    # Small primes first; like the wheel loop below, they respect the bound.
    for p in (2, 3, 5):
        if p > bound: return n
        if n % p == 0: return p
    # Wheel modulo 30: from 7 on, only candidates coprime to 2, 3 and 5
    # are tried (residues 1, 7, 11, 13, 17, 19, 23, 29 mod 30); dif holds
    # the gaps between consecutive candidates.
    dif = [6, 4, 2, 4, 2, 4, 6, 2]
    m = 7; i = 1
    # A composite n has a prime factor <= sqrt(n), hence the m*m <= n test.
    while m <= bound and m*m <= n:
        if n % m == 0:
            return m
        m += dif[i % 8]
        i += 1
    return n
115
116
def factor(n):
    """
    Returns the factorization of the integer n as
    a sorted list of tuples (p,e), where each prime p
    is found by trial_division and e is its exponent.
    The sign of n is ignored; n in {-1, 0, 1} gives [].
    Input:
        n -- an integer
    Output:
        list -- factorization of n
    Examples:
    >>> factor(500)
    [(2, 2), (5, 3)]
    >>> factor(-20)
    [(2, 2), (5, 1)]
    >>> factor(1)
    []
    >>> factor(2004)
    [(2, 2), (3, 1), (167, 1)]
    """
    if n in [-1, 0, 1]: return []
    if n < 0: n = -n
    F = []
    while n != 1:
        p = trial_division(n)       # smallest prime factor of n
        e = 1
        # Floor division keeps n an exact integer; true division would
        # turn it into a float (Python 3) and lose precision for large n.
        n //= p
        while n % p == 0:
            e += 1
            n //= p
        F.append((p, e))
    F.sort()
    return F
147
148
def is_squarefree(n):
    """
    Returns True if and only if n is not divisible by the square of an
    integer > 1.  Every square divides 0, so 0 is not squarefree.
    """
    if n == 0:
        return False
    # n is squarefree exactly when every prime occurs to the first power.
    return all(e <= 1 for _, e in factor(n))
158
159
160
##################################################
161
## Linear Equations Modulo $n$
162
##################################################
163
164
def xgcd(a, b):
    """
    Returns g, x, y such that g = x*a + y*b = gcd(a,b).
    Input:
        a -- an integer
        b -- an integer
    Output:
        g -- an integer, the gcd of a and b (never negative)
        x -- an integer
        y -- an integer
    Examples:
    >>> xgcd(2,3)
    (1, -1, 1)
    >>> xgcd(10, 12)
    (2, -1, 1)
    >>> g, x, y = xgcd(100, 2004)
    >>> (g, x, y)
    (4, -20, 1)
    >>> x*100 + y*2004
    4
    """
    if a == 0 and b == 0: return (0, 0, 1)
    # abs(b) divides b exactly, so floor division yields the sign (+1/-1)
    # as an int rather than a float.
    if a == 0: return (abs(b), 0, b // abs(b))
    if b == 0: return (abs(a), a // abs(a), 0)
    x_sign = 1; y_sign = 1
    if a < 0:
        a = -a; x_sign = -1
    if b < 0:
        b = -b; y_sign = -1
    # With A, B the absolute values of the inputs, the loop keeps
    # a == x*A + y*B and b == r*A + s*B.
    x = 1; y = 0; r = 0; s = 1
    while b != 0:
        # Exact integer quotient; a/b would be a float under Python 3.
        q, c = divmod(a, b)
        (a, b, r, s, x, y) = (b, c, x - q*r, y - q*s, r, s)
    return (a, x*x_sign, y*y_sign)
196
197
def inversemod(a, n):
    """
    Returns the inverse of a modulo n, normalized to
    lie between 0 and n-1.  If a is not coprime to n,
    raise an exception (this is used by the elliptic
    curve factorization method).
    Input:
        a -- an integer coprime to n
        n -- a positive integer
    Output:
        an integer between 0 and n-1.
    Raises:
        ZeroDivisionError -- with args (a, n) when gcd(a, n) != 1;
                             elliptic_curve_method takes the gcd of
                             the first argument with N.
    Examples:
    >>> inversemod(1,1)
    0
    >>> inversemod(2,5)
    3
    >>> inversemod(5,8)
    5
    >>> inversemod(37,100)
    73
    """
    g, x, y = xgcd(a, n)
    if g != 1:
        # Portable form of the Python 2 ``raise ZeroDivisionError, (a,n)``;
        # the exception's args are (a, n) in both spellings.
        raise ZeroDivisionError(a, n)
    # x*a + y*n == 1, so x is an inverse of a; reduce it into [0, n-1].
    return x % n
223
224
def solve_linear(a,b,n):
    """
    If the equation ax = b (mod n) has a solution, return a
    solution normalized to lie between 0 and n-1, otherwise
    returns None.
    Input:
        a -- an integer
        b -- an integer
        n -- an integer
    Output:
        an integer or None
    Examples:
    >>> solve_linear(4, 2, 10)
    8
    >>> solve_linear(2, 1, 4) == None
    True
    """
    g, c, _ = xgcd(a, n)                 # c*a == g (mod n)
    # A solution exists exactly when g = gcd(a, n) divides b.
    if b % g != 0: return None
    # g divides b, so floor division is exact and stays an integer;
    # then a*((b//g)*c) == (b//g)*g == b (mod n).
    return ((b // g) * c) % n
244
245
def crt(a, b, m, n):
    """
    Return the unique integer between 0 and m*n - 1
    that reduces to a modulo m and b modulo n, where
    the integers m and n are coprime.
    Input:
        a, b, m, n -- integers, with m and n coprime
    Output:
        int -- an integer between 0 and m*n - 1.
    Examples:
    >>> crt(1, 2, 3, 4)
    10
    >>> crt(4, 5, 10, 3)
    14
    >>> crt(-1, -1, 100, 101)
    10099
    """
    g, c, _ = xgcd(m, n)
    assert g == 1, "m and n must be coprime."
    # c*m == 1 (mod n), so the value below is a (mod m) and
    # a + (b - a) == b (mod n).
    return (a + (b - a)*c*m) % (m*n)
265
266
267
##################################################
268
## Computation of Powers
269
##################################################
270
271
def powermod(a, m, n):
    """
    The m-th power of a modulo n, computed by binary
    (square-and-multiply) exponentiation.
    Input:
        a -- an integer
        m -- a nonnegative integer
        n -- a positive integer
    Output:
        int -- an integer between 0 and n-1
    Examples:
    >>> powermod(2,25,30)
    2
    >>> powermod(19,12345,100)
    99
    """
    assert m >= 0, "m must be nonnegative."              # (1)
    assert n >= 1, "n must be positive."                 # (2)
    ans = 1
    apow = a
    # Invariant: ans * apow**m is congruent to the requested power mod n.
    while m != 0:
        if m % 2 != 0:
            ans = (ans * apow) % n                       # (3)
        apow = (apow * apow) % n                         # (4)
        # Floor division keeps m an integer; m /= 2 would make it a float
        # under Python 3 and the loop would no longer track the bits of m.
        m //= 2
    return ans % n
296
297
298
##################################################
299
## Finding a Primitive Root
300
##################################################
301
302
def primitive_root(p):
    """
    Returns the first (smallest) primitive root modulo the prime p.
    (If p is not prime, the return value of this function
    is not meaningful.)
    Input:
        p -- an integer that is assumed prime
    Output:
        int -- a primitive root modulo p
    Examples:
    >>> primitive_root(7)
    3
    >>> primitive_root(389)
    2
    >>> primitive_root(5881)
    31
    """
    if p == 2: return 1
    F = factor(p-1)
    a = 2
    while a < p:
        # a generates (Z/pZ)^* iff a**((p-1)/q) != 1 (mod p) for every
        # prime q dividing p-1; floor division keeps the exponent an int.
        if all(powermod(a, (p - 1) // q, p) != 1 for q, _ in F):
            return a
        a += 1
    assert False, "p must be prime."
331
332
333
##################################################
334
## Determining Whether a Number is Prime
335
##################################################
336
337
def is_pseudoprime(n, bases=(2, 3, 5, 7)):
    """
    Returns True if n is a pseudoprime to the given bases,
    in the sense that n>1 and b**(n-1) = 1 (mod n) for each
    element b of bases, with b not a multiple of n, and
    False otherwise.  The sign of n is ignored.
    Input:
        n -- an integer
        bases -- a sequence of integers (default (2, 3, 5, 7))
    Output:
        bool
    Examples:
    >>> is_pseudoprime(91)
    False
    >>> is_pseudoprime(97)
    True
    >>> is_pseudoprime(1)
    False
    >>> is_pseudoprime(-2)
    True
    >>> s = [x for x in range(10000) if is_pseudoprime(x)]
    >>> t = primes(10000)
    >>> s == t
    True
    >>> is_pseudoprime(29341)  # first non-prime pseudoprime
    True
    >>> factor(29341)
    [(13, 1), (37, 1), (61, 1)]
    """
    # The default is an immutable tuple, so no caller can alter the
    # shared default value (a list default would be shared across calls).
    if n < 0: n = -n
    if n <= 1: return False
    for b in bases:
        # Multiples of n are skipped: b**(n-1) would be 0, not 1, mod n.
        if b % n != 0 and powermod(b, n-1, n) != 1:
            return False
    return True
372
373
374
def miller_rabin(n, num_trials=4):
    """
    True if n is likely prime, and False if n
    is definitely not prime.  Increasing num_trials
    increases the probability of correctness.
    (One can prove that the probability that this
    function returns True when it should return
    False is at most (1/4)**num_trials.)
    The sign of n is ignored.
    Input:
        n -- an integer
        num_trials -- the number of trials with the
                      primality test.
    Output:
        bool -- whether or not n is probably prime.
    Examples:
    >>> miller_rabin(91)
    False                         #rand
    >>> miller_rabin(97)
    True                          #rand
    >>> s = [x for x in range(1000) if miller_rabin(x)]
    >>> t = primes(1000)
    >>> s == t
    True                          #rand
    """
    if n < 0: n = -n
    if n in [2,3]: return True
    if n <= 4: return False
    # Every even n > 2 is composite.  Without this check n - 1 would be
    # odd, k would be 0 and the test would degenerate to a Fermat test.
    if n % 2 == 0: return False
    m = n - 1
    k = 0
    while m % 2 == 0:
        k += 1
        m //= 2
    # Now n - 1 = (2**k) * m with m odd and k >= 1.
    for _ in range(num_trials):
        a = randrange(2, n-1)                            # (1)
        apow = powermod(a, m, n)
        if apow in [1, n-1]:
            continue                  # this base says nothing; try another
        # Square up to k-1 more times looking for -1 (mod n).  Reaching 1
        # first (a nontrivial square root of 1), or never seeing -1, means
        # a is a witness that n is composite.
        for _ in range(k-1):                             # (2)
            apow = (apow * apow) % n
            if apow == n-1:
                break                                    # (3)
            if apow == 1:
                return False
        else:
            return False
    return True
424
425
426
##################################################
427
## The Diffie-Hellman Key Exchange
428
##################################################
429
430
def random_prime(num_digits, is_prime = miller_rabin):
    """
    Returns a random (probable) prime with exactly num_digits digits.
    Input:
        num_digits -- a positive integer
        is_prime -- (optional argument)
                    a function of one argument n that
                    returns either True if n is (probably)
                    prime and False otherwise.
    Output:
        int -- an integer
    Examples:
    >>> random_prime(10)
    8599796717L                                     #rand
    >>> random_prime(40)
    1311696770583281776596904119734399028761L       #rand
    """
    lower = 10**(num_digits-1)
    upper = 10**num_digits
    while True:
        n = randrange(lower, upper)
        if n % 2 == 0: n += 1
        # Walk upward through odd numbers, but never past num_digits digits.
        # If the top of the range is reached, restart from a fresh random
        # point rather than returning a number that is one digit too long.
        while n < upper and not is_prime(n):
            n += 2
        if n < upper:
            return n
451
452
def dh_init(p):
    """
    Generates and returns a random integer n with 2 <= n < p,
    together with the power 2^n (mod p).
    Input:
        p -- an integer that is prime
    Output:
        int -- a positive integer < p, a secret
        int -- 2^n (mod p), send to other user
    Examples:
    >>> p = random_prime(20)
    >>> dh_init(p)
    (15299007531923218813L, 4715333264598442112L)     #rand
    """
    secret = randrange(2, p)            # private exponent, kept by this user
    public = powermod(2, secret, p)     # value published to the other user
    return secret, public
468
469
def dh_secret(p, n, mpow):
    """
    Computes the shared Diffie-Hellman secret key.
    Input:
        p -- an integer that is prime
        n -- an integer: output by dh_init for this user
        mpow -- an integer: output by dh_init for other user
    Output:
        int -- the shared secret key, mpow**n mod p.
    Examples:
    >>> p = random_prime(20)
    >>> n, npow = dh_init(p)
    >>> m, mpow = dh_init(p)
    >>> dh_secret(p, n, mpow)
    15695503407570180188L   #rand
    >>> dh_secret(p, m, npow)
    15695503407570180188L   #rand
    """
    # (2^m)^n == (2^n)^m (mod p), so both users arrive at the same value.
    shared_key = powermod(mpow, n, p)
    return shared_key
488
489
490
491
492
493
494
##################################################
495
## Encoding Strings as Lists of Integers
496
##################################################
497
498
def str_to_numlist(s, bound):
    """
    Returns a sequence of integers between 0 and bound-1
    that encodes the string s.  Randomization is included,
    so the same string is very likely to encode differently
    each time this function is called.
    Input:
        s -- a string
        bound -- an integer >= 256
    Output:
        list -- encoding of s as a list of integers
    Examples:
    >>> str_to_numlist("Run!", 1000)
    [82, 117, 110, 33]                                  #rand
    >>> str_to_numlist("TOP SECRET", 10**20)
    [4995371940984439512L, 92656709616492L]             #rand
    """
    assert bound >= 256, "bound must be at least 256."
    width = int(log(bound) / log(256))          # base-256 digits per integer
    salt = min(width // 8 + 1, width - 1)       # low-order random digits
    length = len(s)
    pos = 0
    out = []
    while pos < length:
        value = 0
        place = 1                               # 256**slot
        for slot in range(width):
            if slot >= salt:
                # Remaining digits carry characters; stop early at the end of s.
                if pos >= length:
                    break
                value += ord(s[pos]) * place
                pos += 1
            else:
                # Random nonzero salt digits make equal inputs encode differently.
                value += randrange(1, 256) * place
            place *= 256
        out.append(value)
    return out
531
532
def numlist_to_str(v, bound):
    """
    Returns the string that the sequence v of
    integers encodes (inverse of str_to_numlist).
    Input:
        v -- list of integers between 0 and bound-1
        bound -- an integer >= 256
    Output:
        str -- decoding of v as a string
    Examples:
    >>> numlist_to_str([82, 117, 110, 33], 1000)
    'Run!'
    >>> x = str_to_numlist("TOP SECRET MESSAGE", 10**20)
    >>> numlist_to_str(x, 10**20)
    'TOP SECRET MESSAGE'
    """
    assert bound >= 256, "bound must be at least 256."
    # Same digit width and salt size as str_to_numlist computes.
    n = int(log(bound) / log(256))
    salt = min(int(n/8) + 1, n-1)
    chars = []
    for x in v:
        for j in range(n):
            y = x % 256
            # Skip the random salt digits and the zero padding digits.
            if y > 0 and j >= salt:
                chars.append(chr(y))
            # Floor division keeps x an integer; x /= 256 would make it a
            # float under Python 3 and corrupt (or crash) the decoding.
            x //= 256
    return "".join(chars)
559
560
561
##################################################
562
## The RSA Cryptosystem
563
##################################################
564
565
def rsa_init(p, q):
    """
    Returns defining parameters (e, d, n) for the RSA
    cryptosystem defined by primes p and q.  The
    primes p and q may be computed using the
    random_prime function.
    Input:
        p -- a prime integer
        q -- a prime integer
    Output:
        Let m be (p-1)*(q-1).
        e -- the encryption exponent: the smallest integer
             >= 3 that is coprime to m (chosen deterministically,
             not at random)
        d -- the inverse of e modulo m, normalized to lie
             between 0 and m-1
        n -- the product p*q.
    Examples:
    >>> p = random_prime(20); q = random_prime(20)
    >>> e, d, n = rsa_init(p, q)
    >>> e
    5                                               #rand
    >>> d
    787663591619054108576589014764921103213L        #rand
    >>> n
    984579489523817635784646068716489554359L        #rand
    """
    m = (p-1)*(q-1)
    # Search upward from 3 for an exponent that is invertible mod m.
    e = 3
    while gcd(e, m) != 1: e += 1
    d = inversemod(e, m)
    return e, d, p*q
598
599
def rsa_encrypt(plain_text, e, n):
    """
    Encrypt plain_text using the encryption
    exponent e and modulus n.
    Input:
        plain_text -- arbitrary string
        e -- an integer, the encryption exponent
        n -- an integer, the modulus (at least 256, as
             required by str_to_numlist)
    Output:
        list -- the cipher text: one integer between 0 and n-1
                for each integer produced by str_to_numlist
    Examples:
    >>> e = 1413636032234706267861856804566528506075
    >>> n = 2109029637390047474920932660992586706589
    >>> rsa_encrypt("Run Nikita!", e, n)
    [78151883112572478169375308975376279129L]       #rand
    >>> rsa_encrypt("Run Nikita!", e, n)
    [1136438061748322881798487546474756875373L]     #rand
    """
    # Encode the text as salted integers below n, then raise each to the e-th power.
    plain = str_to_numlist(plain_text, n)
    return [powermod(x, e, n) for x in plain]
619
620
def rsa_decrypt(cipher, d, n):
    """
    Decrypt the cipher text using the decryption
    exponent d and modulus n.
    Input:
        cipher -- list of integers output
                  by rsa_encrypt
        d -- an integer, the decryption exponent
        n -- an integer, the modulus
    Output:
        str -- the unencrypted plain text
    Examples:
    >>> d = 938164637865370078346033914094246201579
    >>> n = 2109029637390047474920932660992586706589
    >>> msg1 = [1071099761433836971832061585353925961069]
    >>> msg2 = [1336506586627416245118258421225335020977]
    >>> rsa_decrypt(msg1, d, n)
    'Run Nikita!'
    >>> rsa_decrypt(msg2, d, n)
    'Run Nikita!'
    """
    # Undo the exponentiation, then decode the integers back into text.
    plain = [powermod(x, d, n) for x in cipher]
    return numlist_to_str(plain, n)
641
642
643
##################################################
644
## Computing the Legendre Symbol
645
##################################################
646
647
def legendre(a, p):
    """
    Returns the Legendre symbol a over p, where
    p is an odd prime, computed with Euler's criterion
    a**((p-1)/2) (mod p).
    Input:
        a -- an integer
        p -- an odd prime (primality not checked)
    Output:
        int: -1 if a is not a square mod p,
              0 if p divides a,
              1 if a is a nonzero square mod p.
    Examples:
    >>> legendre(2, 5)
    -1
    >>> legendre(3, 3)
    0
    >>> legendre(7, 2003)
    -1
    """
    assert p % 2 == 1, "p must be an odd prime."
    # p is odd, so (p-1)//2 is exact and stays an integer exponent.
    b = powermod(a, (p - 1) // 2, p)
    if b == 1: return 1
    elif b == p-1: return -1
    return 0
671
672
673
##################################################
674
## Computing Square Roots Modulo p
675
##################################################
676
677
def sqrtmod(a, p):
    """
    Returns a square root of a modulo p.
    Input:
        a -- an integer that is a perfect
             square modulo p (this is checked)
        p -- a prime
    Output:
        int -- a square root of a, as an integer
               between 0 and p-1.
    Examples:
    >>> sqrtmod(4, 5)              # p == 1 (mod 4)
    3                              #rand
    >>> sqrtmod(13, 23)            # p == 3 (mod 4)
    6                              #rand
    >>> sqrtmod(997, 7304723089)   # p == 1 (mod 4)
    761044645L                     #rand
    """
    a %= p
    if p == 2: return a
    # 0 is its own square root; legendre(0, p) is 0, so handle it here.
    if a == 0: return 0
    assert legendre(a, p) == 1, "a must be a square mod p."
    if p % 4 == 3: return powermod(a, (p + 1) // 4, p)

    # Otherwise work in R = (Z/pZ)[t]/(t**2 - a); the pair (x0, x1)
    # stands for x0 + x1*t.
    def mul(x, y):         # multiplication in R                 # (1)
        return ((x[0]*y[0] + a*y[1]*x[1]) % p,
                (x[0]*y[1] + x[1]*y[0]) % p)

    def power(x, n):       # exponentiation in R (square-and-multiply)  # (2)
        ans = (1, 0)
        xpow = x
        while n != 0:
            if n % 2 != 0: ans = mul(ans, xpow)
            xpow = mul(xpow, xpow)
            n //= 2        # floor division keeps n an integer
        return ans

    # Raise 1 + z*t to the (p-1)/2 power for random z, then test the
    # candidate roots derived from it until one squares to a.
    while True:
        z = randrange(2, p)
        u, v = power((1, z), (p - 1) // 2)
        if v != 0:
            vinv = inversemod(v, p)
            for x in [-u*vinv, (1 - u)*vinv, (-1 - u)*vinv]:
                if (x*x) % p == a: return x % p
720
721
722
##################################################
723
## Continued Fractions
724
##################################################
725
726
def convergents(v):
    """
    Returns the partial convergents of the continued
    fraction v.
    Input:
        v -- list of integers [a0, a1, a2, ..., am]
    Output:
        list -- list [(p0,q0), (p1,q1), ...]
                of pairs (pm,qm) such that the mth
                convergent of v is pm/qm.
    Examples:
    >>> convergents([1, 2])
    [(1, 1), (3, 2)]
    >>> convergents([3, 7, 15, 1, 292])
    [(3, 1), (22, 7), (333, 106), (355, 113), (103993, 33102)]
    """
    # Seeds: p_{-2}/q_{-2} = 0/1 and p_{-1}/q_{-1} = 1/0.
    p_prev, p_cur = 0, 1
    q_prev, q_cur = 1, 0
    result = []
    for a in v:
        # p_n = a_n*p_{n-1} + p_{n-2}, and likewise for q_n.
        p_prev, p_cur = p_cur, a * p_cur + p_prev
        q_prev, q_cur = q_cur, a * q_cur + q_prev
        result.append((p_cur, q_cur))
    return result
749
750
def contfrac_rat(numer, denom):
    """
    Returns the continued fraction of the rational
    number numer/denom.
    Input:
        numer -- an integer
        denom -- a positive integer coprime to numer
    Output:
        list -- the continued fraction [a0, a1, ..., am]
                of the rational number numer/denom.
    Examples:
    >>> contfrac_rat(3, 2)
    [1, 2]
    >>> contfrac_rat(103993, 33102)
    [3, 7, 15, 1, 292]
    """
    assert denom > 0, "denom must be positive"
    a = numer; b = denom
    v = []
    # Euclid's algorithm; the partial quotients are the floors a // b.
    # divmod keeps them exact integers (a/b would be a float in Python 3).
    while b != 0:
        q, r = divmod(a, b)
        v.append(q)
        a, b = b, r
    return v
773
774
def contfrac_float(x):
    """
    Returns the continued fraction of the floating
    point number x, computed using the continued
    fraction procedure, and the sequence of partial
    convergents.
    Input:
        x -- a floating point number (decimal)
    Output:
        list -- the continued fraction [a0, a1, ...]
                obtained by applying the continued
                fraction procedure to x to the
                precision of this computer.
        list -- the list [(p0,q0), (p1,q1), ...]
                of pairs (pm,qm) such that the mth
                convergent of continued fraction
                is pm/qm.
    Examples:
    >>> v, w = contfrac_float(3.14159); v
    [3, 7, 15, 1, 25, 1, 7, 4]
    >>> v, w = contfrac_float(2.718); v
    [2, 1, 2, 1, 1, 4, 1, 12]
    >>> contfrac_float(0.3)
    ([0, 3, 2, 1], [(0, 1), (1, 3), (2, 7), (3, 10)])
    """
    terms = []
    conv = [(0, 1), (1, 0)]        # seed convergents p_{-2}/q_{-2}, p_{-1}/q_{-1}
    target = x
    while True:
        a = int(x)
        terms.append(a)
        p_two_back, q_two_back = conv[-2]
        p_one_back, q_one_back = conv[-1]
        pn = a * p_one_back + p_two_back
        qn = a * q_one_back + q_two_back
        conv.append((pn, qn))
        x -= a
        # Stop once the convergent reproduces x exactly in floating point.
        if target - float(pn) / float(qn) == 0:
            return terms, conv[2:]
        x = 1 / x
814
815
def sum_of_two_squares(p):
    """
    Uses continued fractions to efficiently compute
    a representation of the prime p as a sum of
    two squares.  The prime p must be 1 modulo 4.
    Input:
        p -- a prime congruent 1 modulo 4.
    Output:
        integers a, b such that p is a*a + b*b
    Examples:
    >>> sum_of_two_squares(5)
    (1, 2)
    >>> sum_of_two_squares(389)
    (10, 17)
    >>> sum_of_two_squares(86295641057493119033)
    (789006548L, 9255976973L)
    """
    assert p % 4 == 1, "p must be 1 modulo 4"
    r = sqrtmod(-1, p)                # r*r == -1 (mod p)              # (1)
    v = contfrac_rat(-r, p)                                            # (2)
    for a, b in convergents(v):                                        # (3)
        c = r*b + p*a                                                  # (4)
        # Accept the first |c| < sqrt(p).  Comparing c*c < p is exact;
        # the former float bound int(sqrt(p)) can be off by one for large
        # p because of floating-point rounding.
        if c*c < p: return (abs(b), abs(c))
    assert False, "Bug in sum_of_two_squares."                         # (5)
840
841
842
##################################################
843
## Elliptic Curve Arithmetic
844
##################################################
845
846
def ellcurve_add(E, P1, P2):
    """
    Returns the sum of P1 and P2 on the elliptic
    curve E.
    Input:
        E -- an elliptic curve over Z/pZ, given by a
             triple of integers (a, b, p), with p odd.
        P1 -- a pair of integers (x, y) or the
              string "Identity".
        P2 -- same type as P1
    Output:
        R -- same type as P1
    Raises:
        ZeroDivisionError -- from inversemod when a slope denominator is
                             not invertible mod p (only possible for
                             composite p, as in elliptic_curve_method).
    Examples:
    >>> E = (1, 0, 7)   # y**2 = x**3 + x over Z/7Z
    >>> P1 = (1, 3); P2 = (3, 3)
    >>> ellcurve_add(E, P1, P2)
    (3, 4)
    >>> ellcurve_add(E, P1, (1, 4))
    'Identity'
    >>> ellcurve_add(E, "Identity", P2)
    (3, 3)
    """
    a, b, p = E
    assert p > 2, "p must be odd."
    if P1 == "Identity": return P2
    if P2 == "Identity": return P1
    x1, y1 = P1; x2, y2 = P2
    x1 %= p; y1 %= p; x2 %= p; y2 %= p
    # P2 == -P1 (this also covers doubling a point with y == 0).
    if x1 == x2 and (y1 + y2) % p == 0: return "Identity"
    # Compare the reduced coordinates: the raw input tuples can differ by
    # multiples of p while naming the same point.
    if (x1, y1) == (x2, y2):
        lam = (3*x1*x1 + a) * inversemod(2*y1, p)      # tangent slope
    else:
        lam = (y1 - y2) * inversemod(x1 - x2, p)       # chord slope
    lam %= p
    x3 = lam*lam - x1 - x2
    y3 = -lam*x3 - y1 + lam*x1
    return (x3 % p, y3 % p)
883
884
def ellcurve_mul(E, m, P):
    """
    Returns the multiple m*P of the point P on
    the elliptic curve E, by repeated doubling.
    Input:
        E -- an elliptic curve over Z/pZ, given by a
             triple (a, b, p).
        m -- a nonnegative integer
        P -- a pair of integers (x, y) or the
             string "Identity"
    Output:
        A pair of integers or the string "Identity".
    Examples:
    >>> E = (1, 0, 7)
    >>> P = (1, 3)
    >>> ellcurve_mul(E, 5, P)
    (1, 3)
    >>> ellcurve_mul(E, 9999, P)
    (1, 4)
    """
    assert m >= 0, "m must be nonnegative."
    power = P                  # P * 2**i at the i-th step
    mP = "Identity"
    while m != 0:
        if m % 2 != 0: mP = ellcurve_add(E, mP, power)
        power = ellcurve_add(E, power, power)
        # Floor division keeps m an integer (m /= 2 is a float in Python 3).
        m //= 2
    return mP
912
913
914
##################################################
915
## Integer Factorization: Pollard's (p-1) Method and the Elliptic Curve Method
916
##################################################
917
918
def lcm_to(B):
    """
    Returns the least common multiple of all
    integers up to B (1 when B < 2).
    Input:
        B -- an integer
    Output:
        an integer
    Examples:
    >>> lcm_to(5)
    60
    >>> lcm_to(20)
    232792560
    >>> lcm_to(100)
    69720375229712477164533808935312303556800L
    """
    ans = 1
    for p in primes(B):
        # Largest power of p that is <= B, computed exactly with integers.
        # The float estimate int(log(B)/log(p)) can come out just below an
        # integer when B is a power of p, dropping a factor of p.
        pk = p
        while pk * p <= B:
            pk *= p
        ans *= pk
    return ans
939
940
def pollard(N, m):
    """
    Use Pollard's (p-1)-method to try to find a
    nontrivial divisor of N.
    Input:
        N -- a positive integer
        m -- a positive integer, the least common
             multiple of the integers up to some
             bound, computed using lcm_to.
    Output:
        int -- an integer divisor of N (N itself if
               no proper divisor was found)
    Examples:
    >>> pollard(5917, lcm_to(5))
    61
    >>> pollard(779167, lcm_to(5))
    779167
    >>> pollard(779167, lcm_to(15))
    2003L
    >>> pollard(187, lcm_to(15))
    11
    >>> n = random_prime(5)*random_prime(5)*random_prime(5)
    >>> pollard(n, lcm_to(100))
    315873129119929L     #rand
    >>> pollard(n, lcm_to(1000))
    3672986071L          #rand
    """
    for base in (2, 3):
        # If p-1 divides m for a prime p | N, then p divides base**m - 1.
        g = gcd(powermod(base, m, N) - 1, N)
        if g not in (1, N):
            return g
    return N
972
973
def randcurve(p):
    """
    Construct a somewhat random elliptic curve
    over Z/pZ and a random point on that curve.
    Input:
        p -- a positive integer
    Output:
        tuple -- a triple E = (a, b, p), always with b = 1
        P -- the tuple (0, 1), which lies on E
    Examples:
    >>> p = random_prime(20); p
    17758176404715800329L                                 #rand
    >>> E, P = randcurve(p)
    >>> print(E)
    (15299007531923218813L, 1, 17758176404715800329L)     #rand
    >>> print(P)
    (0, 1)
    """
    assert p > 2, "p must be > 2."
    while True:
        a = randrange(p)
        # y^2 = x^3 + a*x + 1 has discriminant -16*(4*a**3 + 27); require
        # 4*a**3 + 27 to be invertible mod p.
        if gcd(4*a**3 + 27, p) == 1:
            break
    # (0, 1) satisfies 1**2 == 0**3 + a*0 + 1 for every a.
    return (a, 1, p), (0, 1)
996
997
def elliptic_curve_method(N, m, tries=5):
    """
    Use the elliptic curve method to try to find a
    nontrivial divisor of N.
    Input:
        N -- a positive integer
        m -- a positive integer, the least common
             multiple of the integers up to some
             bound, computed using lcm_to.
        tries -- a positive integer, the number of
                 different elliptic curves to try
    Output:
        int -- a divisor of N (N itself if no proper
               divisor was found)
    Examples:
    >>> elliptic_curve_method(5959, lcm_to(20))
    59L                       #rand
    >>> elliptic_curve_method(10007*20011, lcm_to(100))
    10007L                    #rand
    >>> p = random_prime(9); q = random_prime(9)
    >>> n = p*q; n
    117775675640754751L       #rand
    >>> elliptic_curve_method(n, lcm_to(100))
    117775675640754751L       #rand
    >>> elliptic_curve_method(n, lcm_to(500))
    117775675640754751L       #rand
    """
    for _ in range(tries):                               # (1)
        E, P = randcurve(N)                              # (2)
        try:                                             # (3)
            ellcurve_mul(E, m, P)                        # (4)
        except ZeroDivisionError as x:                   # (5)
            # inversemod raised with args (value, N): the value it could
            # not invert shares a factor with N.
            g = gcd(x.args[0], N)                        # (6)
            # Only a proper divisor counts; otherwise try another curve.
            # (The former "g != 1 or g != N" was always true.)
            if g != 1 and g != N: return g               # (7)
    return N
1031
1032
1033
##################################################
1034
## ElGamal Elliptic Curve Cryptosystem
1035
##################################################
1036
1037
def elgamal_init(p):
    """
    Constructs an ElGamal cryptosystem over Z/pZ, by
    choosing a random elliptic curve E over Z/pZ, a
    point B in E(Z/pZ), and a random integer n.  This
    function returns the public key as a 3-tuple
    (E, B, n*B) and the private key as the pair (E, n).
    Input:
        p -- a prime number
    Output:
        tuple -- the public key as a 3-tuple
                 (E, B, n*B), where E = (a, b, p) is an
                 elliptic curve over Z/pZ, B = (x, y) is
                 a point on E, and n*B = (x',y') is
                 the sum of B with itself n times.
        tuple -- the private key, which is the pair (E, n)
    Examples:
    >>> p = random_prime(20); p
    17758176404715800329L    #rand
    >>> public, private = elgamal_init(p)
    >>> print("E = %s" % (public[0],))
    E = (15299007531923218813L, 1, 17758176404715800329L)   #rand
    >>> print("B = %s" % (public[1],))
    B = (0, 1)
    >>> print("nB = %s" % (public[2],))
    nB = (5619048157825840473L, 151469105238517573L)        #rand
    >>> print("n = %s" % (private[1],))
    n = 12608319787599446459                               #rand
    """
    E, B = randcurve(p)
    n = randrange(2, p)               # the secret multiplier
    nB = ellcurve_mul(E, n, B)
    return (E, B, nB), (E, n)
1070
1071
def elgamal_encrypt(plain_text, public_key):
    """
    Encrypt a message using the ElGamal cryptosystem
    with given public_key = (E, B, n*B).
    Input:
        plain_text -- a string
        public_key -- a triple (E, B, n*B), as output
                      by elgamal_init.  The prime p of E must
                      be at least 256000 so that p//1000 is a
                      valid str_to_numlist bound (>= 256).
    Output:
        list -- a list of pairs of points on E that
                represent the encrypted message
    Examples:
    >>> public, private = elgamal_init(random_prime(20))
    >>> elgamal_encrypt("RUN", public)
    [((6004308617723068486L, 15578511190582849677L),
      (7064405129585539806L, 8318592816457841619L))]      #rand
    """
    E, B, nB = public_key
    a, b, p = E
    # str_to_numlist asserts its bound is >= 256, i.e. p // 1000 >= 256.
    assert p >= 256000, "p must be at least 256000."
    # Scale each chunk by 1000, leaving room to bump x until it is the
    # x-coordinate of a point on E; decryption floor-divides by 1000.
    v = [1000*x for x in str_to_numlist(plain_text, p // 1000)]      # (1)
    cipher = []
    for x in v:
        while not legendre(x**3 + a*x + b, p) == 1:                  # (2)
            x = (x + 1) % p
        y = sqrtmod(x**3 + a*x + b, p)                               # (3)
        P = (x, y)
        r = randrange(1, p)
        encrypted = (ellcurve_mul(E, r, B),
                     ellcurve_add(E, P, ellcurve_mul(E, r, nB)))
        cipher.append(encrypted)
    return cipher
1104
1105
def elgamal_decrypt(cipher_text, private_key):
    """
    Decrypt a message encrypted with the ElGamal cryptosystem,
    using the private key (E, n) output by elgamal_init.
    Input:
        cipher_text -- list of pairs of points on E output
                       by elgamal_encrypt.
        private_key -- the pair (E, n) output by elgamal_init.
    Output:
        str -- the unencrypted plain text
    Examples:
    >>> public, private = elgamal_init(random_prime(20))
    >>> v = elgamal_encrypt("TOP SECRET MESSAGE!", public)
    >>> elgamal_decrypt(v, private)
    'TOP SECRET MESSAGE!'
    """
    E, n = private_key
    p = E[2]
    plain = []
    for rB, P_plus_rnB in cipher_text:
        nrB = ellcurve_mul(E, n, rB)
        if nrB == "Identity":
            # Subtracting the identity leaves the point unchanged (and the
            # string "Identity" cannot be negated coordinate-wise).
            P = P_plus_rnB
        else:
            minus_nrB = (nrB[0], -nrB[1])
            P = ellcurve_add(E, minus_nrB, P_plus_rnB)
        # Undo the scaling by 1000 applied during encryption.
        plain.append(P[0] // 1000)
    return numlist_to_str(plain, p // 1000)
1129
1130
1131
##################################################
1132
## Associativity of the Group Law
1133
##################################################
1134
1135
# The variable order is x1, x2, x3, y1, y2, y3, a, b
class Poly:                                               # (1)
    """
    Sparse polynomial in the 8 variables listed above, stored in self.v as
    a dict mapping exponent tuples (length 8) to coefficients.  Comparison
    first normalizes, reducing every y_i**2 with y_i**2 = x_i**3 + a*x_i + b.
    """
    def __init__(self, d):                                # (2)
        self.v = dict(d)

    def __eq__(self, other):
        # Compare normal forms.  Note: this normalizes (mutates) both
        # operands, as the original __cmp__ did.
        self.normalize(); other.normalize()               # (4)
        return self.v == other.v

    def __ne__(self, other):
        return not self.__eq__(other)

    def __cmp__(self, other):                             # (3)
        # Python 2 cmp() support; only equality (0) is meaningful.
        if self.__eq__(other): return 0
        return -1

    def __add__(self, other):                             # (5)
        w = Poly(self.v)
        for m in other.monomials():
            w[m] += other[m]
        return w

    def __sub__(self, other):
        w = Poly(self.v)
        for m in other.monomials():
            w[m] -= other[m]
        return w

    def __mul__(self, other):
        if len(self.v) == 0 or len(other.v) == 0:
            return Poly([])
        r = Poly([])
        for m1 in self.monomials():
            for m2 in other.monomials():
                # Multiplying monomials adds exponent vectors.
                z = [m1[i] + m2[i] for i in range(8)]
                r[z] += self[m1]*other[m2]
        return r

    def __neg__(self):
        v = {}
        for m in self.v.keys():
            v[m] = -self.v[m]
        return Poly(v)

    def __div__(self, other):
        return Frac(self, other)
    __truediv__ = __div__        # Python 3 spelling of the "/" operator

    def __getitem__(self, m):                             # (6)
        # Missing monomials read as (and are inserted with) coefficient 0.
        m = tuple(m)
        if m not in self.v: self.v[m] = 0
        return self.v[m]

    def __setitem__(self, m, c):
        self.v[tuple(m)] = c

    def __delitem__(self, m):
        del self.v[tuple(m)]

    def monomials(self):                                  # (7)
        # A list copy, so callers may add/remove monomials while iterating
        # (a live dict view would raise RuntimeError under Python 3).
        return list(self.v.keys())

    def normalize(self):                                  # (8)
        # Repeatedly drop zero terms and rewrite y_i**2 as
        # b + a*x_i + x_i**3 until no y_i has exponent >= 2.
        while True:
            finished = True
            for m in self.monomials():
                if self[m] == 0:
                    del self[m]
                    continue
                for i in range(3):
                    if m[3+i] >= 2:
                        finished = False
                        nx0 = list(m); nx0[3+i] -= 2
                        nx0[7] += 1                       # ... * b
                        nx1 = list(m); nx1[3+i] -= 2
                        nx1[i] += 1; nx1[6] += 1          # ... * a * x_i
                        nx3 = list(m); nx3[3+i] -= 2
                        nx3[i] += 3                       # ... * x_i**3
                        c = self[m]
                        del self[m]
                        self[nx0] += c
                        self[nx1] += c
                        self[nx3] += c
                        # m is gone; any other y_j**2 in it is reduced in the
                        # new monomials on the next pass.
                        break
            if finished: return
1208
1209
one = Poly({(0,) * 8: 1})    # constant polynomial 1; default Frac denominator
1210
1211
class Frac:                                               # (10)
    """
    Formal quotient num/denom of two Poly objects, used to run the
    elliptic-curve addition formulas symbolically.  The default
    denominator is the module-level constant polynomial `one`.
    """
    def __init__(self, num, denom=one):
        self.num = num; self.denom = denom

    def __eq__(self, other):
        # a/b == c/d  iff  a*d == b*c (cross-multiplication).
        return self.num * other.denom == self.denom * other.num

    def __ne__(self, other):
        return not self.__eq__(other)

    def __cmp__(self, other):                             # (11)
        # Python 2 cmp() support; only equality (0) is meaningful.
        if self.__eq__(other):
            return 0
        return -1

    def __add__(self, other):                             # (12)
        return Frac(self.num*other.denom +
                    self.denom*other.num,
                    self.denom*other.denom)

    def __sub__(self, other):
        return Frac(self.num*other.denom -
                    self.denom*other.num,
                    self.denom*other.denom)

    def __mul__(self, other):
        return Frac(self.num*other.num,
                    self.denom*other.denom)

    def __div__(self, other):
        return Frac(self.num*other.denom,
                    self.denom*other.num)
    __truediv__ = __div__        # Python 3 spelling of the "/" operator

    def __neg__(self):
        return Frac(-self.num, self.denom)
1235
1236
def var(i): # (14)
    """
    Return the i-th variable, in the order x1, x2, x3, y1, y2, y3, a, b,
    as a Frac over the default denominator; e.g. var(3) is y1.
    """
    exponents = [0] * 8
    exponents[i] = 1            # exponent 1 in slot i, 0 everywhere else
    monomial = tuple(exponents)
    return Frac(Poly({monomial: 1}))
1239
1240
def prove_associative(): # (15)
    """
    Symbolically check the associative law for the chord addition
    formulas, comparing x-coordinates only.

    With Pi = (xi, yi), P4 = P1 + P2 and P5 = P2 + P3, the x-coordinates
    of P4 + P3 and P1 + P5 are
        ((y3-y4)^2 - (x3+x4)(x3-x4)^2) / (x3-x4)^2   and
        ((y1-y5)^2 - (x1+x5)(x1-x5)^2) / (x1-x5)^2.
    s1 and s2 are each numerator multiplied by the other denominator, so
    s1 == s2 exactly when the two x-coordinates agree.  The curve
    relation y_i^2 = x_i^3 + a*x_i + b enters only through
    Poly.normalize(), which the equality test applies.

    Prints "Associative?" followed by the result of the comparison.
    Returns None.
    """
    x1 = var(0); x2 = var(1); x3 = var(2)
    y1 = var(3); y2 = var(4); y3 = var(5)
    # a = var(6) and b = var(7) are never used explicitly here.

    # P4 = P1 + P2 = (x4, y4)
    lambda12 = (y1 - y2)/(x1 - x2)
    x4 = lambda12*lambda12 - x1 - x2
    nu12 = y1 - lambda12*x1
    y4 = -lambda12*x4 - nu12
    # P5 = P2 + P3 = (x5, y5)
    lambda23 = (y2 - y3)/(x2 - x3)
    x5 = lambda23*lambda23 - x2 - x3
    nu23 = y2 - lambda23*x2
    y5 = -lambda23*x5 - nu23
    s1 = (x1 - x5)*(x1 - x5)*((y3 - y4)*(y3 - y4)
                              - (x3 + x4)*(x3 - x4)*(x3 - x4))
    s2 = (x3 - x4)*(x3 - x4)*((y1 - y5)*(y1 - y5)
                              - (x1 + x5)*(x1 - x5)*(x1 - x5))
    # Single-argument print(...) prints the same on Python 2 and 3.
    print("Associative?")
    print(s1 == s2) # (17)
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
##########################################################
1275
# The following are all the examples not in functions. #
1276
##########################################################
1277
1278
def examples():
    """
    Examples from the text that are not attached to any single function.
    They are doctests, executed when this module is run as a script, and
    assume Python 2 semantics (integer division, print statements, long
    integers displayed with a trailing L).

    NOTE(review): outputs marked #rand come from calls that presumably
    make random choices (random_prime, dh_init, dh_secret), so they will
    not reproduce; plain doctest does not skip them.  Float reprs such as
    0.33333333333333331 also depend on the Python 2 release used.

    >>> from ent import *
    >>> 7/5
    1
    >>> -2/3
    -1
    >>> 1.0/3
    0.33333333333333331
    >>> float(2)/3
    0.66666666666666663
    >>> 100**2
    10000
    >>> 10**20
    100000000000000000000L
    >>> range(10)           # range(n) is from 0 to n-1
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> range(3,10)         # range(a,b) is from a to b-1
    [3, 4, 5, 6, 7, 8, 9]
    >>> [x**2 for x in range(10)]
    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
    >>> [x**2 for x in range(10) if x%4 == 1]
    [1, 25, 81]
    >>> [1,2,3] + [5,6,7]   # concatenation
    [1, 2, 3, 5, 6, 7]
    >>> len([1,2,3,4,5])    # length of a list
    5
    >>> x = [4,7,10,'gcd']  # mixing types is fine
    >>> x[0]                # 0-based indexing
    4
    >>> x[3]
    'gcd'
    >>> x[3] = 'lagrange'   # assignment
    >>> x.append("fermat")  # append to end of list
    >>> x
    [4, 7, 10, 'lagrange', 'fermat']
    >>> del x[3]            # delete entry 3 from list
    >>> x
    [4, 7, 10, 'fermat']
    >>> v = primes(10000)
    >>> len(v)    # this is pi(10000)
    1229
    >>> len([x for x in v if x < 1000])   # pi(1000)
    168
    >>> len([x for x in v if x < 5000])   # pi(5000)
    669
    >>> x=(1, 2, 3)        # creation
    >>> x[1]
    2
    >>> (1, 2, 3) + (4, 5, 6)   # concatenation
    (1, 2, 3, 4, 5, 6)
    >>> (a, b) = (1, 2)    # assignment assigns to each member
    >>> print a, b
    1 2
    >>> for (c, d) in [(1,2), (5,6)]:
    ...     print c, d
    1 2
    5 6
    >>> x = 1, 2           # parentheses optional in creation
    >>> x
    (1, 2)
    >>> c, d = x           # parentheses also optional
    >>> print c, d
    1 2
    >>> P = [p for p in range(200000) if is_pseudoprime(p)]
    >>> Q = primes(200000)
    >>> R = [x for x in P if not (x in Q)]; print R
    [29341, 46657, 75361, 115921, 162401]
    >>> [n for n in R if is_pseudoprime(n,[2,3,5,7,11,13])]
    [162401]
    >>> factor(162401)
    [(17, 1), (41, 1), (233, 1)]
    >>> p = random_prime(50)
    >>> p
    13537669335668960267902317758600526039222634416221L   #rand
    >>> n, npow = dh_init(p)
    >>> n
    8520467863827253595224582066095474547602956490963L    #rand
    >>> npow
    3206478875002439975737792666147199399141965887602L    #rand
    >>> m, mpow = dh_init(p)
    >>> m
    3533715181946048754332697897996834077726943413544L    #rand
    >>> mpow
    3465862701820513569217254081716392362462604355024L    #rand
    >>> dh_secret(p, n, mpow)
    12931853037327712933053975672241775629043437267478L   #rand
    >>> dh_secret(p, m, npow)
    12931853037327712933053975672241775629043437267478L   #rand
    >>> prove_associative()
    Associative?
    True
    >>> len(primes(10000))
    1229
    >>> 10000/log(10000)
    1085.73620476
    >>> powermod(3,45,100)
    43
    >>> inversemod(37, 112)
    109
    >>> powermod(102, 70, 113)
    98
    >>> powermod(99, 109, 113)
    60
    >>> P = primes(1000)
    >>> Q = [p for p in P if primitive_root(p) == 2]
    >>> print len(Q), len(P)
    67 168
    >>> P = primes(50000)
    >>> Q = [primitive_root(p) for p in P]
    >>> Q.index(37)
    3893
    >>> P[3893]
    36721
    >>> for n in range(97):
    ...     if powermod(5,n,97)==3: print n
    70
    >>> factor(5352381469067)
    [(141307, 1), (37877681L, 1)]
    >>> d=inversemod(4240501142039, (141307-1)*(37877681-1))
    >>> d
    5195621988839L
    >>> convergents([-3,1,1,1,1,3])  # doctest: +NORMALIZE_WHITESPACE
    [(-3, 1), (-2, 1), (-5, 2), (-7, 3),
     (-12, 5), (-43, 18)]
    >>> convergents([0,2,4,1,8,2])  # doctest: +NORMALIZE_WHITESPACE
    [(0, 1), (1, 2), (4, 9), (5, 11),
     (44, 97), (93, 205)]
    >>> import math
    >>> e = math.exp(1)
    >>> v, convs = contfrac_float(e)
    >>> [(a,b) for a, b in convs if
    ...      abs(e - a*1.0/b) < 1/(math.sqrt(5)*b**2)]  # doctest: +NORMALIZE_WHITESPACE
    [(3, 1), (19, 7), (193, 71), (2721, 1001),
     (49171, 18089), (1084483, 398959),
     (28245729, 10391023), (325368125, 119696244)]
    >>> factor(12345)
    [(3, 1), (5, 1), (823, 1)]
    >>> factor(729)
    [(3, 6)]
    >>> factor(5809961789)
    [(5809961789L, 1)]
    >>> 5809961789 % 4
    1L
    >>> sum_of_two_squares(5809961789)
    (51542L, 56155L)
    >>> N = [60 + s for s in range(-15,16)]
    >>> def is_powersmooth(B, x):
    ...     for p, e in factor(x):
    ...         if p**e > B: return False
    ...     return True
    >>> Ns = [x for x in N if is_powersmooth(20, x)]
    >>> print len(Ns), len(N), len(Ns)*1.0/len(N)
    14 31 0.451612903226
    >>> P = [x for x in range(10**12, 10**12+1000)
    ...      if miller_rabin(x)]
    >>> Ps = [x for x in P if
    ...       is_powersmooth(10000, x-1)]
    >>> print len(Ps), len(P), len(Ps)*1.0/len(P)
    2 37 0.0540540540541

    """
1440
1441
1442
if __name__ == '__main__':
    # Run as a script: execute every doctest in this module,
    # including the ones collected in examples().
    import sys
    import doctest
    this_module = sys.modules[__name__]
    doctest.testmod(this_module)
1445