#include <stdio.h>
#include <assert.h>
#include <stdlib.h>
#include <time.h>
/* Pseudo-random int in the closed range [low, high] built from rand().
 * Taking rand() % span can be slightly biased towards small values, and
 * `low` is expanded twice, so pass only side-effect-free arguments. */
#define ran(low,high) ((rand()%((high)-(low)+1))+(low))
/* Element count of an array object; gives a wrong answer on a pointer
 * (e.g. an array parameter that has decayed). */
#define NELEMS(x) ((sizeof(x)) / (sizeof((x)[0])))
/*
 * Fill primes[0..num-1] with the first `num` primes in ascending order,
 * using trial division by the primes already found.
 *
 * primes: output array with room for at least `num` ints.
 * num:    how many primes to generate; nothing is written when num <= 0.
 */
void makePrimes(int primes[], int num)
{
    /* Guard: without it the unconditional store below overflows an
     * empty buffer. */
    if (num <= 0) return;

    primes[0] = 2;
    /* After 2, only odd candidates can be prime. */
    for (int p = 3, cnt = 1; cnt < num; p += 2)
    {
        for (int k = 0; ; ++k)
        {
            div_t dt = div(p, primes[k]);
            if (dt.rem == 0) break;   /* composite: divisible by primes[k] */
            /* quot <= primes[k] implies p < primes[k] * (primes[k] + 1), so
             * every prime up to sqrt(p) has been tried: p is prime. */
            if (dt.quot <= primes[k]) { primes[cnt++] = p; break; }
        }
    }
}
/*
 * Pick two distinct random primes p and q.
 *
 * On the first call (detected by both cached indices still being 0) a
 * static table of primes is built. The call then records the index of the
 * first prime above 10000 and the index of the first prime above 40000,
 * and seeds rand() from the clock. p and q are drawn from table entries
 * whose indices lie in that closed index range and differ from each other.
 */
void rand_pq(int& p, int& q)
{
    static int table[1024*8];
    static int lo, hi;

    if (lo == 0 && hi == 0)
    {
        makePrimes(table, NELEMS(table));
        for (int idx = 0; idx < (int)NELEMS(table) && hi == 0; ++idx)
        {
            if (lo == 0)
            {
                if (table[idx] > 10000) lo = idx;
            }
            else if (table[idx] > 40000)
            {
                hi = idx;
            }
        }
        srand((unsigned)time(NULL));
    }

    int first = ran(lo,hi);
    int second;
    do
    {
        second = ran(lo,hi);
    } while (second == first);   /* force two different primes */

    p = table[first];
    q = table[second];
}
/*
 * Extended Euclidean algorithm, iterative form.
 *
 * Returns g = gcd(a, b) for non-negative a, b and stores coefficients x, y
 * with a*x + b*y == g. When b == 0 the result is g = a, x = 1, y = 0.
 */
int xgcd(int a, int b, int& x, int& y)
{
    int r_prev = a, r_cur = b;   /* remainder sequence */
    int s_prev = 1, s_cur = 0;   /* running coefficient of a */
    int t_prev = 0, t_cur = 1;   /* running coefficient of b */

    while (r_cur != 0)
    {
        const int quot = r_prev / r_cur;
        int next;

        next = r_prev % r_cur;          r_prev = r_cur; r_cur = next;
        next = s_prev - quot * s_cur;   s_prev = s_cur; s_cur = next;
        next = t_prev - quot * t_cur;   t_prev = t_cur; t_cur = next;
    }

    x = s_prev;
    y = t_prev;
    return r_prev;
}
/*
 * Choose RSA exponents for primes p and q.
 *
 * phi = (p-1)*(q-1). e is the first odd value starting at 37 with
 * gcd(e, phi) == 1. d is the inverse of e modulo phi (the Bezout
 * coefficient of e returned by xgcd), moved into the range [0, phi).
 */
void rand_ed(int p, int q, int& e, int &d)
{
    const int phi = (p - 1) * (q - 1);
    int inv, other;

    e = 37;
    while (xgcd(e, phi, inv, other) != 1)
        e += 2;

    /* xgcd's coefficient may be negative; fold it into [0, phi). */
    inv %= phi;
    if (inv < 0)
        inv += phi;
    d = inv;
}
/*
 * Return (u * v) mod z without losing the high bits of the product.
 *
 * The product is formed in unsigned long long (at least 64 bits), so it is
 * exact for any pair of 32-bit unsigned operands. Operands need not be
 * reduced below z first, and u == 0 is handled. z must be non-zero.
 *
 * This replaces a hand-rolled 32x32->64 multiply plus shift-subtract
 * division. That version divided by zero in its overflow probe when u == 0,
 * and it computed a wrong remainder whenever the high word of u*v was >= z.
 * NOTE(review): like the rest of this file, this assumes unsigned is 32 bits.
 */
unsigned mul_mod(unsigned u, unsigned v, unsigned z)
{
    const unsigned long long product = (unsigned long long)u * v;
    return (unsigned)(product % z);
}
/*
 * Return a^p mod n by right-to-left square-and-multiply. Every
 * multiplication goes through mul_mod. n must be non-zero.
 *
 * Edge cases:
 *   - n == 1 gives 0 (every value is congruent to 0 mod 1);
 *   - p == 0 gives 1 (the loop runs only while exponent bits remain,
 *     so nothing is multiplied in);
 *   - a base that is a multiple of n gives 0 for p > 0.
 */
unsigned pow_mod(unsigned a, unsigned p, unsigned n)
{
    if (n == 1) return 0;

    unsigned base = a % n;                  /* keep mul_mod operands below n */
    if (base == 0) return p == 0 ? 1u : 0u; /* 0^0 taken as 1, 0^p == 0 */

    unsigned result = 1;
    while (p > 0)
    {
        if (p & 1u)
            result = mul_mod(result, base, n);
        p >>= 1;
        if (p > 0)                          /* skip the final, unused squaring */
            base = mul_mod(base, base, n);
    }
    return result;
}
/*
 * Generate an RSA key. Two distinct random primes come from rand_pq;
 * rand_ed derives the exponents e and d from them; n is their product.
 * The outputs are written in the order e, d (by rand_ed), then n.
 */
void rsa_rand(int& e, int& d, int& n)
{
    int prime_a, prime_b;

    rand_pq(prime_a, prime_b);
    rand_ed(prime_a, prime_b, e, d);

    const int modulus = prime_a * prime_b;
    n = modulus;
}
/*
 * RSA primitive: returns m^ed mod n. main uses it both to encrypt
 * (ed = e) and to decrypt (ed = d). Requires 0 < m < n, which is checked
 * with assert.
 */
int rsa_encryp(int m, int ed, int n)
{
    assert(m > 0 && m < n);
    const unsigned value = pow_mod((unsigned)m, (unsigned)ed, (unsigned)n);
    return (int)value;
}
/*
 * Interactive demo. Generates a key, prints it, then repeatedly reads an
 * integer M and prints it before encryption, after encryption with e, and
 * after decryption with d. The loop ends on unreadable input or when M is
 * outside (0, n).
 */
int main(void)
{
    int e, d, n;

    printf("RSA加密系统.\n\n");
    rsa_rand(e, d, n);
    printf("秘钥: e = %d, d = %d, n = %d\n\n", e, d, n);

    for (;;)
    {
        int msg;
        printf("\n输入数据M(0<M<n): ");
        if (scanf("%d", &msg) != 1 || msg <= 0 || msg >= n)
            break;

        printf("加密前: %d\n", msg);
        const int cipher = rsa_encryp(msg, e, n);
        printf("加密后: %d\n", cipher);
        const int plain = rsa_encryp(cipher, d, n);
        printf("解密后: %d\n", plain);
    }
    return 0;
}