summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--eval.h1
-rw-r--r--hash.c5
-rw-r--r--lib.c99
-rw-r--r--lib.h7
-rw-r--r--struct.c50
-rw-r--r--txr.161
6 files changed, 189 insertions, 34 deletions
diff --git a/eval.h b/eval.h
index 33d20816..05f9061f 100644
--- a/eval.h
+++ b/eval.h
@@ -27,6 +27,7 @@
extern val dwim_s, lambda_s, vector_lit_s, vec_list_s, list_s;
extern val hash_lit_s, hash_construct_s, struct_lit_s, qref_s;
extern val eval_error_s;
+extern val eq_s, eql_s, equal_s;
extern val last_form_evaled, last_form_expanded;
noreturn val eval_error(val form, val fmt, ...);
diff --git a/hash.c b/hash.c
index 77de4dea..e844d5f3 100644
--- a/hash.c
+++ b/hash.c
@@ -173,6 +173,11 @@ static cnum equal_hash(val obj)
case FLNUM:
return hash_double(obj->fl.n);
case COBJ:
+ if (obj->co.ops->equalsub) {
+ val sub = obj->co.ops->equalsub(obj);
+ if (sub)
+ return equal_hash(sub);
+ }
return obj->co.ops->hash(obj) & NUM_MAX;
case RNG:
return (equal_hash(obj->rn.from)
diff --git a/lib.c b/lib.c
index 22df9bc7..0dc1b543 100644
--- a/lib.c
+++ b/lib.c
@@ -2066,16 +2066,16 @@ val equal(val left, val right)
case NIL:
case CHR:
case NUM:
- return nil;
+ break;
case CONS:
case LCONS:
- if ((type(right) == CONS || type(right) == LCONS) &&
- equal(car(left), car(right)) &&
- equal(cdr(left), cdr(right)))
+ if (type(right) == CONS || type(right) == LCONS)
{
- return t;
+ if (equal(car(left), car(right)) && equal(cdr(left), cdr(right)))
+ return t;
+ return nil;
}
- return nil;
+ break;
case LIT:
switch (type(right)) {
case LIT:
@@ -2085,10 +2085,12 @@ val equal(val left, val right)
case LSTR:
lazy_str_force(right);
return equal(left, right->ls.prefix);
- default:
+ case COBJ:
break;
+ default:
+ return nil;
}
- return nil;
+ break;
case STR:
switch (type(right)) {
case LIT:
@@ -2098,14 +2100,16 @@ val equal(val left, val right)
case LSTR:
lazy_str_force(right);
return equal(left, right->ls.prefix);
- default:
+ case COBJ:
break;
+ default:
+ return nil;
}
- return nil;
+ break;
case SYM:
case PKG:
case ENV:
- return right == left ? t : nil;
+ break;
case FUN:
if (type(right) == FUN &&
left->f.functype == right->f.functype &&
@@ -2129,7 +2133,7 @@ val equal(val left, val right)
}
return nil;
}
- return nil;
+ break;
case VEC:
if (type(right) == VEC) {
cnum i, length;
@@ -2142,7 +2146,7 @@ val equal(val left, val right)
}
return t;
}
- return nil;
+ break;
case LSTR:
switch (type(right)) {
case LIT:
@@ -2150,31 +2154,57 @@ val equal(val left, val right)
case LSTR:
lazy_str_force(left);
return equal(left->ls.prefix, right);
- default:
+ case COBJ:
break;
+ default:
+ return nil;
}
return nil;
case BGNUM:
- if (type(right) == BGNUM && mp_cmp(mp(left), mp(right)) == MP_EQ)
- return t;
- return nil;
+ if (type(right) == BGNUM) {
+ if (mp_cmp(mp(left), mp(right)) == MP_EQ)
+ return t;
+ return nil;
+ }
+ break;
case FLNUM:
- if (type(right) == FLNUM && left->fl.n == right->fl.n)
- return t;
- return nil;
+ if (type(right) == FLNUM) {
+ if (left->fl.n == right->fl.n)
+ return t;
+ return nil;
+ }
+ break;
case RNG:
- if (type(right) == RNG &&
- equal(from(left), from(right)) &&
- equal(to(left), to(right)))
- return t;
- return nil;
+ if (type(right) == RNG) {
+ if (equal(from(left), from(right)) &&
+ equal(to(left), to(right)))
+ return t;
+ return nil;
+ }
+ break;
case COBJ:
+ if (left->co.ops->equalsub) {
+ val lsub = left->co.ops->equalsub(left);
+ if (lsub)
+ return equal(lsub, right);
+ }
+
if (type(right) == COBJ && left->co.ops == right->co.ops)
return left->co.ops->equal(left, right);
+
return nil;
}
- internal_error("unhandled case in equal function");
+ if (type(right) != COBJ)
+ return nil;
+
+ if (right->co.ops->equalsub) {
+ val rsub = right->co.ops->equalsub(right);
+ if (rsub)
+ return equal(left, rsub);
+ }
+
+ return nil;
}
alloc_bytes_t malloc_bytes;
@@ -3769,9 +3799,26 @@ val less(val left, val right)
if (left == right)
return nil;
+tail:
l_type = type(left);
r_type = type(right);
+ if (l_type == COBJ && left->co.ops->equalsub) {
+ val lsub = left->co.ops->equalsub(left);
+ if (lsub) {
+ left = lsub;
+ goto tail;
+ }
+ }
+
+ if (r_type == COBJ && right->co.ops->equalsub) {
+ val rsub = right->co.ops->equalsub(right);
+ if (rsub) {
+ right = rsub;
+ goto tail;
+ }
+ }
+
switch (less_tab[l_type][r_type]) {
case less_false:
return nil;
diff --git a/lib.h b/lib.h
index 3b10ce6d..7b33c784 100644
--- a/lib.h
+++ b/lib.h
@@ -219,10 +219,15 @@ struct cobj_ops {
void (*destroy)(val self);
void (*mark)(val self);
cnum (*hash)(val self);
+ val (*equalsub)(val self);
};
#define cobj_ops_init(equal, print, destroy, mark, hash) \
- { equal, print, destroy, mark, hash }
+ { equal, print, destroy, mark, hash, 0 }
+
+#define cobj_ops_init_ex(equal, print, destroy, mark, hash, \
+ equalsub) \
+ { equal, print, destroy, mark, hash, equalsub }
/* Default operations for above structure.
* Default equal is eq
diff --git a/struct.c b/struct.c
index d4521314..5ccd68b7 100644
--- a/struct.c
+++ b/struct.c
@@ -55,6 +55,7 @@ struct struct_type {
cnum id;
cnum nslots;
cnum nstslots;
+ cnum eqmslot;
val super;
struct struct_type *super_handle;
val slots;
@@ -232,7 +233,7 @@ val make_struct_type(val name, val super,
st->name = name;
st->id = c_num(id);
- st->nslots = st->nstslots = 0;
+ st->nslots = st->nstslots = st->eqmslot = 0;
st->slots = all_slots;
st->super = super;
st->stslot = 0;
@@ -746,8 +747,11 @@ val static_slot_set(val stype, val sym, val newval)
if (symbolp(sym)) {
loc ptr = lookup_static_slot(stype, st, sym);
- if (!nullocp(ptr))
+ if (!nullocp(ptr)) {
+ if (st->eqmslot == -1)
+ st->eqmslot = 0;
return set(ptr, newval);
+ }
}
no_such_slot(self, stype, sym);
@@ -764,6 +768,9 @@ val static_slot_ensure(val stype, val sym, val newval, val no_error_p)
no_error_p = default_bool_arg(no_error_p);
+ if (st->eqmslot == -1)
+ st->eqmslot = 0;
+
{
loc ptr = lookup_static_slot(stype, st, sym);
if (!nullocp(ptr))
@@ -1004,11 +1011,44 @@ static cnum struct_inst_hash(val obj)
return out;
}
+static val get_equal_method(val stype, struct struct_type *st)
+{
+ if (st->eqmslot == -1) {
+ return nil;
+ } else if (st->eqmslot) {
+ return st->stslot[st->eqmslot];
+ } else {
+ loc ptr = lookup_static_slot(stype, st, equal_s);
+ if (!nullocp(ptr)) {
+ st->eqmslot = valptr(ptr) - st->stslot;
+ return deref(ptr);
+ }
+ st->eqmslot = -1;
+ return nil;
+ }
+}
+
+static val struct_inst_equalsub(val obj)
+{
+ struct struct_inst *si = coerce(struct struct_inst *, obj->co.handle);
+ struct struct_type *st = coerce(struct struct_type *, si->type->co.handle);
+ val equal_method = get_equal_method(obj, st);
+ if (equal_method) {
+ val sub = funcall1(equal_method, obj);
+ if (nilp(sub)) {
+ uw_throwf(error_s, lit("equal method on type ~s returned nil"),
+ si->type, nao);
+ }
+ return sub;
+ }
+ return nil;
+}
+
static_def(struct cobj_ops struct_type_ops =
cobj_ops_init(eq, struct_type_print, struct_type_destroy,
struct_type_mark, cobj_hash_op))
static_def(struct cobj_ops struct_inst_ops =
- cobj_ops_init(struct_inst_equal, struct_inst_print,
- cobj_destroy_free_op, struct_inst_mark,
- struct_inst_hash))
+ cobj_ops_init_ex(struct_inst_equal, struct_inst_print,
+ cobj_destroy_free_op, struct_inst_mark,
+ struct_inst_hash, struct_inst_equalsub))
diff --git a/txr.1 b/txr.1
index d6b128ef..067e9b85 100644
--- a/txr.1
+++ b/txr.1
@@ -13607,8 +13607,10 @@ The
.code equal
function is less strict still than
.codn eql .
-In general, it recurses into some
-kinds of aggregate objects to perform a structural equivalence check.
+In general, it recurses into some kinds of aggregate objects to perform a
+structural equivalence check. For struct types, it also supports customization
+via equality substitution. See the Equality Substitution section under
+Structures.
Firstly, if
.meta left-obj
@@ -13830,6 +13832,15 @@ correct behavior when
is regarded as an empty list, since the empty list is lexicographically prior to
a nonempty list.
+If either argument is a structure for which the
+.code equal
+method is defined, the method is invoked on that argument, and the
+value returned is used in place of that argument for performing
+the comparison. Structures with no
+.code equal
+method cannot participate in a comparison, resulting in an error.
+See the Equality Substitution section under Structures.
+
Finally, if either of the arguments has a type other than the above discussed
types, the situation is an error.
@@ -18201,6 +18212,47 @@ is evaluated only once:
(mapcar s list) <--> (mapcar (meth s lambda) list)
.cble
+.NP* Equality Substitution
+
+Normally, two struct values are not considered the same under the
+.code equal
+function unless they are the same objects.
+
+However, if a method named
+.code equal
+is defined for a structure type, via a static slot, then instances of
+that structure type support
+.IR "equality substitution" .
+
+The
+.code equal
+method must take exactly one argument: the structure object.
+Moreover, the method must never return
+.codn nil .
+
+When a struct which supports equality substitution is compared using
+.codn equal ,
+.codn less
+or
+.codn greater ,
+its
+.code equal
+method is invoked, and the return value is used in place of that
+structure for the purposes of the comparison.
+
+The same applies when an struct is hashed using the
+.code hash-equal
+function, or implicitly by an
+.code :equal-hash
+hash tables.
+
+Note: if an
+.code equal
+method is defined or redefined with different semantics for a struct
+type whose instances have already been inserted as keys in an
+.code :equal-based
+hash table, searches for those keys will not work reliably.
+
.coNP Macro @ defstruct
.synb
.mets (defstruct >> { name | >> ( name << arg *)} < super
@@ -29473,6 +29525,11 @@ each produce the same integer hash value. In all other
circumstances, the hash values of two distinct objects are unrelated, and
may or may not be the same.
+Object of struct type may support custom hashing by way of defining
+an equality substitution via an
+.code equal
+method. See the Equality Substitution section under Structures.
+
.coNP Functions @, hash_keys @, hash_values @ hash_pairs and @ hash_alist
.synb
.mets (hash-keys << hash )