Skip to content

Commit

Permalink
Add frozen minhash
Browse files Browse the repository at this point in the history
  • Loading branch information
fluhus committed Aug 18, 2024
1 parent d7785c9 commit 10db2dc
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 3 deletions.
10 changes: 9 additions & 1 deletion heaps/heaps.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
// This package provides better run speeds than the standard [heap] package.
package heaps

import "golang.org/x/exp/constraints"
import (
"golang.org/x/exp/constraints"
"golang.org/x/exp/slices"
)

// Heap is a generic heap.
type Heap[T any] struct {
Expand Down Expand Up @@ -145,3 +148,8 @@ func (h *Heap[T]) Fix(i int) {
}
}
}

// Clip trims the heap's backing slice so that its capacity equals its length,
// releasing any spare capacity.
func (h *Heap[T]) Clip() {
	clipped := slices.Clip(h.a)
	h.a = clipped
}
34 changes: 32 additions & 2 deletions minhash/minhash.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func New[T constraints.Integer](k int) *MinHash[T] {
// already exist, and there are fewer than k elements smaller than x.
// Returns true if x was added and false if not.
func (mh *MinHash[T]) Push(x T) bool {
if mh.frozen() {
panic("called Push on a frozen MinHash")
}
mh.n++
if mh.h.Len() == mh.k && x >= mh.h.Head() {
// x is too large.
Expand Down Expand Up @@ -67,6 +70,9 @@ func (mh *MinHash[T]) N() int {

// View returns the values held by this MinHash.
//
// For a regular instance the result is the underlying slice itself.
// For a frozen instance the result is a clone of the underlying slice.
func (mh *MinHash[T]) View() []T {
	if !mh.frozen() {
		return mh.h.View()
	}
	return slices.Clone(mh.h.View())
}

Expand Down Expand Up @@ -101,10 +107,10 @@ func (mh *MinHash[T]) UnmarshalJSON(b []byte) error {
// in min-hash terms.
func (mh *MinHash[T]) intersect(other *MinHash[T]) (int, int) {
a, b := mh.View(), other.View()
if !slices.IsSortedFunc(a, snm.CompareReverse) {
if !mh.frozen() && !slices.IsSortedFunc(a, snm.CompareReverse) {
panic("receiver is not sorted")
}
if !slices.IsSortedFunc(b, snm.CompareReverse) {
if !other.frozen() && !slices.IsSortedFunc(b, snm.CompareReverse) {
panic("other is not sorted")
}
intersection := 0
Expand Down Expand Up @@ -147,5 +153,29 @@ func (mh *MinHash[T]) SoftJaccard(other *MinHash[T]) float64 {
// Sort orders the collection's values in place using snm.CompareReverse,
// making it ready for Jaccard calculation.
// The collection is still valid after calling Sort.
//
// Sort panics when called on a frozen instance.
func (mh *MinHash[T]) Sort() {
	const frozenMsg = "called Sort on a frozen MinHash (frozen instances are already sorted)"
	if mh.frozen() {
		panic(frozenMsg)
	}
	vals := mh.h.View()
	slices.SortFunc(vals, snm.CompareReverse)
}

// Frozen returns an immutable version of this instance.
// The original instance is unchanged.
//
// The values are copied into a new heap whose spare capacity is clipped
// and whose contents are sorted with snm.CompareReverse.
// Frozen instances are sorted, take up less memory
// and calculate Jaccard faster.
// Calls to View are slower because the data is cloned.
func (mh *MinHash[T]) Frozen() *MinHash[T] {
	fh := heaps.Max[T]()
	fh.PushSlice(mh.View())
	fh.Clip()
	slices.SortFunc(fh.View(), snm.CompareReverse)
	// The nil second field is what marks the new instance as frozen.
	return &MinHash[T]{fh, nil, mh.k, mh.n}
}

// frozen reports whether this MinHash is frozen.
// An instance is considered frozen when its s field is nil.
func (mh *MinHash[T]) frozen() bool {
return mh.s == nil
}
47 changes: 47 additions & 0 deletions minhash/minhash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,53 @@ func FuzzCollection(f *testing.F) {
})
}

// TestFrozen checks that a frozen instance exposes the same values as the
// instance it was created from, and that Jaccard on frozen instances matches
// Jaccard on the corresponding regular instances.
func TestFrozen(t *testing.T) {
	mh := New[int](3)
	mh.Push(27872)
	mh.Push(16978)
	mh.Push(28696)
	mh.Sort()

	fr := mh.Frozen()
	if !slices.Equal(mh.View(), fr.View()) {
		t.Fatalf("View()=%v, want %v", fr.View(), mh.View())
	}

	// Populate mh2 itself. Pushing into mh here would leave mh2 empty and
	// make the Jaccard comparison below meaningless.
	mh2 := New[int](3)
	mh2.Push(27872)
	mh2.Push(16978)
	mh2.Push(28697)
	mh2.Sort()

	want := mh.Jaccard(mh2)
	got := fr.Jaccard(mh2.Frozen())
	if got != want {
		t.Fatalf("Jaccard=%v, want %v", got, want)
	}
}

// TestFrozen_modifySort checks that calling Sort on a frozen instance panics.
func TestFrozen_modifySort(t *testing.T) {
	src := New[int](1)
	src.Push(27872)
	fr := src.Frozen()

	// Swallow the expected panic. If Sort returns normally, the Fatalf
	// below fails the test.
	defer func() { _ = recover() }()

	fr.Sort()
	t.Fatalf(".Frozen().Sort() succeeded, want panic")
}

// TestFrozen_modifyPush checks that calling Push on a frozen instance panics.
func TestFrozen_modifyPush(t *testing.T) {
	mh := New[int](1)
	mh.Push(27872)
	mh = mh.Frozen()
	// Swallow the expected panic. If Push returns normally, the Fatalf
	// below fails the test.
	defer func() {
		recover()
	}()
	mh.Push(123)
	// The message names Push, the method actually under test.
	t.Fatalf(".Frozen().Push() succeeded, want panic")
}

func BenchmarkPush(b *testing.B) {
nums := rand.Perm(b.N)
mh := New[int](b.N)
Expand Down

0 comments on commit 10db2dc

Please sign in to comment.