Skip to content

Commit

Permalink
Merge pull request #174 from aznt00th/fix/hexagonalDistanceOne
Browse files Browse the repository at this point in the history
fix: Make adjacent nodes in hex have dist 1
  • Loading branch information
JustGlowing authored Nov 24, 2023
2 parents d594686 + 0c59791 commit 6d1f8ef
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions minisom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
power, exp, zeros, ones, arange, outer, meshgrid, dot,
logical_and, mean, cov, argsort, linspace, transpose,
einsum, prod, nan, sqrt, hstack, diff, argmin, multiply,
nanmean, nansum, tile, array_equal)
nanmean, nansum, tile, array_equal, isclose)
from numpy.linalg import norm
from collections import defaultdict, Counter
from warnings import warn
Expand Down Expand Up @@ -88,6 +88,8 @@ def asymptotic_decay(learning_rate, t, max_iter):


class MiniSom(object):
Y_HEX_CONV_FACTOR = (3.0 / 2.0) / sqrt(3)

def __init__(self, x, y, input_len, sigma=1.0, learning_rate=0.5,
decay_function=asymptotic_decay,
neighborhood_function='gaussian', topology='rectangular',
Expand Down Expand Up @@ -183,6 +185,7 @@ def euclidean(x, w):
self._yy = self._yy.astype(float)
if topology == 'hexagonal':
self._xx[::-2] -= 0.5
self._yy *= self.Y_HEX_CONV_FACTOR
if neighborhood_function in ['triangle']:
warn('triangle neighborhood function does not ' +
'take in account hexagonal topology')
Expand Down Expand Up @@ -569,9 +572,8 @@ def _topographic_error_hexagonal(self, data):
self._get_euclidean_coordinates_from_index(bmu[1])]
for bmu in b2mu_inds]
b2mu_coords = array(b2mu_coords)
b2mu_neighbors = [(bmu1 >= bmu2-1) & ((bmu1 <= bmu2+1))
b2mu_neighbors = [isclose(1, norm(bmu1 - bmu2))
for bmu1, bmu2 in b2mu_coords]
b2mu_neighbors = [neighbors.prod() for neighbors in b2mu_neighbors]
te = 1 - mean(b2mu_neighbors)
return te

Expand Down Expand Up @@ -641,6 +643,15 @@ def setUp(self):
self.som._weights = zeros((5, 5, 1)) # fake weights
self.som._weights[2, 3] = 5.0
self.som._weights[1, 1] = 2.0
self.hex_som = MiniSom(5, 5, 1, topology='hexagonal')
for i in range(5):
for j in range(5):
# checking weights normalization
assert_almost_equal(1.0, linalg.norm(
self.hex_som._weights[i, j]))
self.hex_som._weights = zeros((5, 5, 1)) # fake weights
self.hex_som._weights[2, 3] = 5.0
self.hex_som._weights[1, 1] = 2.0

def test_decay_function(self):
assert self.som._decay_function(1., 2., 3.) == 1./(1.+2./(3./2))
Expand Down Expand Up @@ -765,20 +776,24 @@ def test_topographic_error(self):
assert self.som.topographic_error([[5]]) == 0.0
assert self.som.topographic_error([[15]]) == 1.0

self.som.topology = 'hexagonal'
# 10 will have bmu_1 in (0, 4) and bmu_2 in (1, 3)
# which are in the same neighborhood on a hexagonal grid
self.som._weights[0, 4] = 10.0
self.som._weights[1, 3] = 9.0
def test_hexagonal_topographic_error(self):
self.hex_som._weights[2, 4] = 6.0
# # 15 will have bmu_1 in (4, 4) and bmu_2 in (0, 0)
# # which are not in the same neighborhood
self.hex_som._weights[4, 4] = 15.0
self.hex_som._weights[0, 0] = 14.
self.hex_som._weights[0, 4] = 10.0
self.hex_som._weights[1, 3] = 9.0
# 3 will have bmu_1 in (2, 0) and bmu_2 in (1, 1)
# which are in the same neighborhood on a hexagonal grid
self.som._weights[2, 0] = 3.0
assert self.som.topographic_error([[10]]) == 0.0
assert self.som.topographic_error([[3]]) == 0.0
self.hex_som._weights[2, 0] = 3.0
assert self.hex_som.topographic_error([[10]]) == 0.0
# (2,0) and (1,1) are not neighbours in hex,
# the neigbours of (2,0) are: (1,0), (2,1) and (3,0)
assert self.hex_som.topographic_error([[3]]) == 1.0
# True for both hexagonal and rectangular grids
assert self.som.topographic_error([[5]]) == 0.0
assert self.som.topographic_error([[15]]) == 1.0
self.som.topology = 'rectangular'
assert self.hex_som.topographic_error([[5]]) == 0.0
assert self.hex_som.topographic_error([[15]]) == 1.0

def test_quantization(self):
q = self.som.quantization(array([[4], [2]]))
Expand Down

0 comments on commit 6d1f8ef

Please sign in to comment.