summaryrefslogtreecommitdiffstats
path: root/arith.c
diff options
context:
space:
mode:
authorKaz Kylheku <kaz@kylheku.com>2019-03-25 06:41:10 -0700
committerKaz Kylheku <kaz@kylheku.com>2019-03-25 06:41:10 -0700
commitc292315b720f9b5ba54206b71eeb98e2ee078f8f (patch)
tree7a6ef4c3a79e2648409829d06e5c1ed84d964072 /arith.c
parent789f20b3a472e41817d3977fed2203c2971058a8 (diff)
downloadtxr-c292315b720f9b5ba54206b71eeb98e2ee078f8f.tar.gz
txr-c292315b720f9b5ba54206b71eeb98e2ee078f8f.tar.bz2
txr-c292315b720f9b5ba54206b71eeb98e2ee078f8f.zip
Move numeric functions from lib.c to arith.c
* arith.c, lib.c (num, c_num, c_fixnum, bad_float, flo, c_flo, fixnump, bignump, integerp, floatp, numberp nary_op, nary_simple_op, plusv, minusv, mulv, divv, logandv, logiorv, gtv, ltv, gev, lev, numeqv, numneqv, sumv, prod, exptv, gcdv, lcmv): Function definitions moved from lib.c to arith.c. (nary_op_keyfun, unary_num, unary_arith, unary_int, sumv, prodv, rexpt, abso_self): Static functions moved from libmoved from lib.c to arith.c. (max): New macro in arith.c. (arith_init): Registrations moved from eval.c. * eval.c (eval_init): Registrations moved to arith.c
Diffstat (limited to 'arith.c')
-rw-r--r--arith.c445
1 files changed, 445 insertions, 0 deletions
diff --git a/arith.c b/arith.c
index 26a794c2..05814e37 100644
--- a/arith.c
+++ b/arith.c
@@ -27,6 +27,7 @@
#include <stdio.h>
#include <stdlib.h>
+#include <stddef.h>
#include <string.h>
#include <wctype.h>
#include <stdarg.h>
@@ -49,6 +50,8 @@
#include "txr.h"
#include "arith.h"
+#define max(a, b) ((a) > (b) ? (a) : (b))
+
#define TAG_PAIR(A, B) ((A) << TAG_SHIFT | (B))
#define NOOP(A, B)
#define CNUM_BIT ((int) sizeof (cnum) * CHAR_BIT)
@@ -3416,6 +3419,417 @@ static val flo_set_round_mode(val mode)
#endif
+val num(cnum n)
+{
+ return (n >= NUM_MIN && n <= NUM_MAX) ? num_fast(n) : bignum(n);
+}
+
+cnum c_num(val n)
+{
+ switch (type(n)) {
+ case CHR: case NUM:
+ return coerce(cnum, n) >> TAG_SHIFT;
+ case BGNUM:
+ if (mp_in_intptr_range(mp(n))) {
+ int_ptr_t out;
+ mp_get_intptr(mp(n), &out);
+ return out;
+ }
+ uw_throwf(error_s, lit("~s is out of allowed range [~s, ~s]"),
+ n, num(INT_PTR_MIN), num(INT_PTR_MAX), nao);
+ default:
+ type_mismatch(lit("~s is not an integer"), n, nao);
+ }
+}
+
+cnum c_fixnum(val num, val self)
+{
+ switch (type(num)) {
+ case CHR: case NUM:
+ return coerce(cnum, num) >> TAG_SHIFT;
+ default:
+ type_mismatch(lit("~a: ~s is not fixnum integer or character"),
+ self, num, nao);
+ }
+}
+
+#if HAVE_FPCLASSIFY
+INLINE int bad_float(double d)
+{
+ switch (fpclassify(d)) {
+ case FP_ZERO:
+ case FP_NORMAL:
+ case FP_SUBNORMAL:
+ return 0;
+ default:
+ return 1;
+ }
+}
+#else
+#define bad_float(d) (0)
+#endif
+
+val flo(double n)
+{
+ if (bad_float(n)) {
+ uw_throw(numeric_error_s, lit("out-of-range floating-point result"));
+ } else {
+ val obj = make_obj();
+ obj->fl.type = FLNUM;
+ obj->fl.n = n;
+ return obj;
+ }
+}
+
+double c_flo(val num, val self)
+{
+ type_check(self, num, FLNUM);
+ return num->fl.n;
+}
+
+val fixnump(val num)
+{
+ return (is_num(num)) ? t : nil;
+}
+
+val bignump(val num)
+{
+ return (type(num) == BGNUM) ? t : nil;
+}
+
+val integerp(val num)
+{
+ switch (tag(num)) {
+ case TAG_NUM:
+ return t;
+ case TAG_PTR:
+ if (num == nil)
+ return nil;
+ if (num->t.type == BGNUM)
+ return t;
+ /* fallthrough */
+ default:
+ return nil;
+ }
+}
+
+val floatp(val num)
+{
+ return (type(num) == FLNUM) ? t : nil;
+}
+
+val numberp(val num)
+{
+ switch (tag(num)) {
+ case TAG_NUM:
+ return t;
+ case TAG_PTR:
+ if (num == nil)
+ return nil;
+ if (num->t.type == BGNUM || num->t.type == FLNUM)
+ return t;
+ /* fallthrough */
+ default:
+ return nil;
+ }
+}
+
+val nary_op(val self, val (*bfun)(val, val),
+ val (*ufun)(val self, val),
+ struct args *args, val emptyval)
+{
+ val acc, next;
+ cnum index = 0;
+
+ if (!args_more(args, index))
+ return emptyval;
+
+ acc = args_get(args, &index);
+
+ if (!args_more(args, index))
+ return ufun(self, acc);
+
+ do {
+ next = args_get(args, &index);
+ acc = bfun(acc, next);
+ } while (args_more(args, index));
+
+ return acc;
+}
+
+static val nary_op_keyfun(val self, val (*bfun)(val, val),
+ val (*ufun)(val self, val),
+ struct args *args, val emptyval,
+ val keyfun)
+{
+ val acc, next;
+ cnum index = 0;
+
+ if (!args_more(args, index))
+ return emptyval;
+
+ acc = funcall1(keyfun, args_get(args, &index));
+
+ if (!args_more(args, index))
+ return ufun(self, acc);
+
+ do {
+ next = funcall1(keyfun, args_get(args, &index));
+ acc = bfun(acc, next);
+ } while (args_more(args, index));
+
+ return acc;
+}
+
+
+val nary_simple_op(val self, val (*bfun)(val, val),
+ struct args *args, val firstval)
+{
+ val acc = firstval, next;
+ cnum index = 0;
+
+ while (args_more(args, index)) {
+ next = args_get(args, &index);
+ acc = bfun(acc, next);
+ }
+
+ return acc;
+}
+
+static val unary_num(val self, val arg)
+{
+ if (!numberp(arg))
+ uw_throwf(error_s, lit("~a: ~s isn't a number"), self, arg, nao);
+ return arg;
+}
+
+static val unary_arith(val self, val arg)
+{
+ switch (type(arg)) {
+ case NUM:
+ case CHR:
+ case BGNUM:
+ case FLNUM:
+ return arg;
+ default:
+ uw_throwf(error_s, lit("~a: invalid argument ~s"), self, arg, nao);
+ }
+}
+
+static val unary_int(val self, val arg)
+{
+ if (!integerp(arg))
+ uw_throwf(error_s, lit("~a: ~s isn't an integer"), self, arg, nao);
+ return arg;
+}
+
+val plusv(struct args *nlist)
+{
+ return nary_op(lit("+"), plus, unary_arith, nlist, zero);
+}
+
+val minusv(val minuend, struct args *nlist)
+{
+ val acc = minuend, next;
+ cnum index = 0;
+
+ if (!args_more(nlist, index))
+ return neg(acc);
+
+ do {
+ next = args_get(nlist, &index);
+ acc = minus(acc, next);
+ } while (args_more(nlist, index));
+
+ return acc;
+}
+
+val mulv(struct args *nlist)
+{
+ return nary_op(lit("*"), mul, unary_num, nlist, one);
+}
+
+val divv(val dividend, struct args *nlist)
+{
+ val acc = dividend, next;
+ cnum index = 0;
+
+ if (!args_more(nlist, index))
+ return divi(one, acc);
+
+ do {
+ next = args_get(nlist, &index);
+ acc = divi(acc, next);
+ } while (args_more(nlist, index));
+
+ return acc;
+}
+
+val logandv(struct args *nlist)
+{
+ return nary_op(lit("logand"), logand, unary_int, nlist, negone);
+}
+
+val logiorv(struct args *nlist)
+{
+ return nary_op(lit("logior"), logior, unary_int, nlist, zero);
+}
+
+val gtv(val first, struct args *rest)
+{
+ cnum index = 0;
+
+ while (args_more(rest, index)) {
+ val elem = args_get(rest, &index);
+ if (!gt(first, elem))
+ return nil;
+ first = elem;
+ }
+
+ if (index == 0)
+ (void) unary_arith(lit(">"), first);
+
+ return t;
+}
+
+val ltv(val first, struct args *rest)
+{
+ cnum index = 0;
+
+ while (args_more(rest, index)) {
+ val elem = args_get(rest, &index);
+ if (!lt(first, elem))
+ return nil;
+ first = elem;
+ }
+
+ if (index == 0)
+ (void) unary_arith(lit("<"), first);
+
+ return t;
+}
+
+val gev(val first, struct args *rest)
+{
+ cnum index = 0;
+
+ while (args_more(rest, index)) {
+ val elem = args_get(rest, &index);
+ if (!ge(first, elem))
+ return nil;
+ first = elem;
+ }
+
+ if (index == 0)
+ (void) unary_arith(lit(">="), first);
+
+ return t;
+}
+
+val lev(val first, struct args *rest)
+{
+ cnum index = 0;
+
+ while (args_more(rest, index)) {
+ val elem = args_get(rest, &index);
+ if (!le(first, elem))
+ return nil;
+ first = elem;
+ }
+
+ if (index == 0)
+ (void) unary_arith(lit("<="), first);
+
+ return t;
+}
+
+val numeqv(val first, struct args *rest)
+{
+ cnum index = 0;
+
+ while (args_more(rest, index)) {
+ val elem = args_get(rest, &index);
+ if (!numeq(first, elem))
+ return nil;
+ first = elem;
+ }
+
+ if (index == 0)
+ (void) unary_arith(lit("="), first);
+
+ return t;
+}
+
+val numneqv(struct args *args)
+{
+ val i, j;
+ val list = args_get_list(args);
+
+ if (list && !cdr(list)) {
+ (void) unary_arith(lit("/="), car(list));
+ return t;
+ }
+
+ for (i = list; i; i = cdr(i))
+ for (j = cdr(i); j; j = cdr(j))
+ if (numeq(car(i), car(j)))
+ return nil;
+
+ return t;
+}
+
+static val sumv(struct args *nlist, val keyfun)
+{
+ return nary_op_keyfun(lit("+"), plus, unary_arith, nlist, zero, keyfun);
+}
+
+val sum(val seq, val keyfun)
+{
+ args_decl_list(args, ARGS_MIN, tolist(seq));
+ return if3(missingp(keyfun), plusv(args), sumv(args, keyfun));
+}
+
+static val prodv(struct args *nlist, val keyfun)
+{
+ return nary_op_keyfun(lit("*"), mul, unary_num, nlist, one, keyfun);
+}
+
+val prod(val seq, val keyfun)
+{
+ args_decl_list(args, ARGS_MIN, tolist(seq));
+ return if3(missingp(keyfun), mulv(args), prodv(args, keyfun));
+}
+
+static val rexpt(val right, val left)
+{
+ return expt(left, right);
+}
+
+val exptv(struct args *nlist)
+{
+ cnum nargs = args_count(nlist);
+ args_decl(rnlist, max(ARGS_MIN, nargs));
+ args_copy_reverse(rnlist, nlist, nargs);
+ return nary_op(lit("expt"), rexpt, unary_num, rnlist, one);
+}
+
+static val abso_self(val self, val arg)
+{
+ (void) self;
+ return abso(arg);
+}
+
+val gcdv(struct args *nlist)
+{
+ return nary_op(lit("gcd"), gcd, abso_self, nlist, zero);
+}
+
+val lcmv(struct args *nlist)
+{
+ return nary_op(lit("lcm"), lcm, abso_self, nlist, zero);
+}
+
+
void arith_init(void)
{
log2_init();
@@ -3457,6 +3871,24 @@ void arith_init(void)
reg_fun(intern(lit("abs"), user_package), func_n1(abso));
reg_fun(intern(lit("trunc"), user_package), func_n2o(trunc, 1));
reg_fun(intern(lit("mod"), user_package), func_n2(mod));
+ reg_fun(intern(lit("zerop"), user_package), func_n1(zerop));
+ reg_fun(intern(lit("nzerop"), user_package), func_n1(nzerop));
+ reg_fun(intern(lit("plusp"), user_package), func_n1(plusp));
+ reg_fun(intern(lit("minusp"), user_package), func_n1(minusp));
+ reg_fun(intern(lit("evenp"), user_package), func_n1(evenp));
+ reg_fun(intern(lit("oddp"), user_package), func_n1(oddp));
+ reg_fun(intern(lit("succ"), user_package), func_n1(succ));
+ reg_fun(intern(lit("ssucc"), user_package), func_n1(ssucc));
+ reg_fun(intern(lit("sssucc"), user_package), func_n1(sssucc));
+ reg_fun(intern(lit("pred"), user_package), func_n1(pred));
+ reg_fun(intern(lit("ppred"), user_package), func_n1(ppred));
+ reg_fun(intern(lit("pppred"), user_package), func_n1(pppred));
+ reg_fun(intern(lit(">"), user_package), func_n1v(gtv));
+ reg_fun(intern(lit("<"), user_package), func_n1v(ltv));
+ reg_fun(intern(lit(">="), user_package), func_n1v(gev));
+ reg_fun(intern(lit("<="), user_package), func_n1v(lev));
+ reg_fun(intern(lit("="), user_package), func_n1v(numeqv));
+ reg_fun(intern(lit("/="), user_package), func_n0v(numneqv));
reg_fun(intern(lit("wrap"), user_package), func_n3(wrap));
reg_fun(intern(lit("wrap*"), user_package), func_n3(wrap_star));
reg_fun(intern(lit("/"), user_package), func_n1v(divv));
@@ -3485,6 +3917,19 @@ void arith_init(void)
reg_fun(intern(lit("log2"), user_package), func_n1(logtwo));
reg_fun(intern(lit("exp"), user_package), func_n1(expo));
reg_fun(intern(lit("sqrt"), user_package), func_n1(sqroot));
+ reg_fun(intern(lit("logand"), user_package), func_n0v(logandv));
+ reg_fun(intern(lit("logior"), user_package), func_n0v(logiorv));
+ reg_fun(intern(lit("logxor"), user_package),
+ func_n2(if3(opt_compat && opt_compat <= 202, logxor_old, logxor)));
+ reg_fun(intern(lit("logtest"), user_package), func_n2(logtest));
+ reg_fun(intern(lit("lognot"), user_package), func_n2o(lognot, 1));
+ reg_fun(intern(lit("logtrunc"), user_package), func_n2(logtrunc));
+ reg_fun(intern(lit("sign-extend"), user_package), func_n2(sign_extend));
+ reg_fun(intern(lit("ash"), user_package), func_n2(ash));
+ reg_fun(intern(lit("bit"), user_package), func_n2(bit));
+ reg_fun(intern(lit("mask"), user_package), func_n0v(maskv));
+ reg_fun(intern(lit("width"), user_package), func_n1(width));
+ reg_fun(intern(lit("logcount"), user_package), func_n1(logcount));
reg_fun(intern(lit("cum-norm-dist"), user_package), func_n1(cum_norm_dist));
reg_fun(intern(lit("inv-cum-norm"), user_package), func_n1(inv_cum_norm));
reg_fun(intern(lit("n-choose-k"), user_package), func_n2(n_choose_k));