summaryrefslogtreecommitdiffstats
path: root/rand.c
diff options
context:
space:
mode:
authorKaz Kylheku <kaz@kylheku.com>2011-12-23 21:57:33 -0800
committerKaz Kylheku <kaz@kylheku.com>2011-12-23 21:57:33 -0800
commitfb2f85cff8dab72c101c48f48c69ffc6ef35204d (patch)
treeffffed3ed085c9363f5795434346f2591e8850d2 /rand.c
parentf7a72e6e10d5f69ff9bb55ebf480e99e65727e6e (diff)
downloadtxr-fb2f85cff8dab72c101c48f48c69ffc6ef35204d.tar.gz
txr-fb2f85cff8dab72c101c48f48c69ffc6ef35204d.tar.bz2
txr-fb2f85cff8dab72c101c48f48c69ffc6ef35204d.zip
* arith.c (highest_bit): Changing to external linkage.
* arith.h (highest_bit): Declared. * rand.c (random): Rewrote using different algorithm which ensures even distribution, and avoids doing a bignum mod operation.
Diffstat (limited to 'rand.c')
-rw-r--r--rand.c67
1 files changed, 48 insertions, 19 deletions
diff --git a/rand.c b/rand.c
index c48ff191..073aa62e 100644
--- a/rand.c
+++ b/rand.c
@@ -167,41 +167,70 @@ val random(val state, val modulus)
random_state_s);
if (bignump(modulus)) {
mp_int *m = mp(modulus);
- int digits = USED(m);
- int bits = digits * MP_DIGIT_BIT;
- int bits_needed = bits + 32;
- int rands_needed = (bits_needed + 32 - 1) / 32;
+ int bits = mp_count_bits(m);
+ int rands_needed = (bits + 32 - 1) / 32;
+ int msb_rand_bits = bits % 32;
+ rand32_t msb_rand_mask = ((rand32_t) -1) >> (32 - msb_rand_bits);
val out = make_bignum();
mp_int *om = mp(out);
- int i, err;
- for (i = 0; i < rands_needed; i++) {
- rand32_t rnd = rand32(r);
+ for (;;) {
+ int i;
+ for (i = 0; i < rands_needed; i++) {
+ rand32_t rnd = rand32(r);
#if MP_DIGIT_SIZE >= 4
- if (i > 0)
- mp_mul_2d(om, 32, om);
- mp_add_d(om, rnd, om);
+ if (i > 0)
+ mp_mul_2d(om, 32, om);
+ else
+ rnd &= msb_rand_mask;
+ mp_add_d(om, rnd, om);
#else
- if (i > 0)
+ if (i > 0)
+ mp_mul_2d(om, 16, om);
+ else
+ rnd &= msb_rand_mask;
+ mp_add_d(om, rnd & 0xFFFF, om);
mp_mul_2d(om, 16, om);
- mp_add_d(om, rnd & 0xFFFF, om);
- mp_mul_2d(om, 16, om);
- mp_add_d(om, rnd >> 16, om);
+ mp_add_d(om, rnd >> 16, om);
#endif
+ }
+ if (mp_cmp(om, m) != MP_LT) {
+ mp_zero(om);
+ continue;
+ }
+ break;
}
- err = mp_mod(om, m, om);
- if (err != MP_OKAY)
- goto invalid;
+
return normalize(out);
} else if (fixnump(modulus)) {
cnum m = c_num(modulus);
+ int bits = highest_bit(m);
+#if SIZEOF_PTR >= 8
+ int rands_needed = (bits + 32 - 1) / 32;
+#endif
+ int msb_rand_bits = bits % 32;
+ rand32_t msb_rand_mask = ((rand32_t) -1) >> (32 - msb_rand_bits);
if (m <= 0)
goto invalid;
+ for (;;) {
+ cnum out = 0;
+ int i;
+
#if SIZEOF_PTR >= 8
- return num(((((cnum) rand32(r) & 0x7FFFFFFF) << 32) | rand32(r)) % m);
+ for (i = 0; i < rands_needed; i++) {
+ rand32_t rnd = rand32(r);
+ out <<= 32;
+ if (i == 0)
+ rnd &= msb_rand_mask;
+ out |= rnd;
+ }
#else
- return num(rand32(r) % m);
+ out = rand32(r) & msb_rand_mask;
#endif
+ if (out >= m)
+ continue;
+ return num(out);
+ }
}
invalid:
uw_throwf(numeric_error_s, lit("random: invalid modulus ~s"),