Skip to content

Commit

Permalink
Neuralized K-Means
Browse files Browse the repository at this point in the history
Add KMeans layer

- receives cluster centroids and
- computes cluster assignments for given inputs

Neuralized K-Means

- documentation in numpydoc format
- pylint + flake8 stuff
- KMeansCanonizer
- NeuralizedKMeans layer
- LogMeanExpPool layer
- Distance layer
- Distance type
  • Loading branch information
jacobkauffmann committed Aug 17, 2023
1 parent c30e3cc commit cbb350e
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 1 deletion.
101 changes: 100 additions & 1 deletion src/zennit/canonizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
'''Functions to produce a canonical form of models fit for LRP'''
from abc import ABCMeta, abstractmethod

import copy
import torch

from .core import collect_leaves
from .types import Linear, BatchNorm, ConvolutionTranspose
from .types import Linear, BatchNorm, ConvolutionTranspose, Distance
from .layer import NeuralizedKMeans, LogMeanExpPool


class Canonizer(metaclass=ABCMeta):
Expand Down Expand Up @@ -329,3 +331,100 @@ def register(self):

def remove(self):
    '''Remove this Canonizer.

    Nothing to do for a CompositeCanonizer, so this method deliberately consists of its docstring only.
    '''


class KMeansCanonizer(Canonizer):
    '''Canonizer for k-means.

    This canonizer replaces a :py:obj:`Distance` layer with power 2 with a :py:obj:`NeuralizedKMeans` layer followed by
    a :py:obj:`LogMeanExpPool`.

    Parameters
    ----------
    beta : float
        Stiffness of the :py:obj:`LogMeanExpPool` layer. Should be smaller than 0 in order to approximate the min
        function. Default is -1.

    Examples
    --------
    >>> from sklearn.cluster import KMeans
    >>> centroids = KMeans(n_clusters=10).fit(X).cluster_centers_
    >>> model = torch.nn.Sequential(Distance(torch.from_numpy(centroids).float(), power=2))
    >>> cluster_assignment = model(x).argmin()
    >>> canonizer = KMeansCanonizer(beta=-1.)
    >>> with Gradient(model, canonizer=[canonizer]) as attributor:
    ...     output, attribution = attributor(x, torch.eye(len(centroids))[[cluster_assignment]])
    '''
    def __init__(self, beta=-1.):
        # the Distance module handled by this instance, and the deep copy which `remove` puts back in place
        self.distance = None
        self.distance_unchanged = None
        self.beta = beta
        # location of the replaced module: `getattr(parent_module, child_name)`
        self.parent_module = None
        self.child_name = None

    def apply(self, root_module):
        '''Apply this canonizer recursively on all applicable modules.

        Iterates over all modules of the root module and applies this canonizer to all :py:obj:`Distance` layers with
        power 2. For each such layer, a fresh copy of this canonizer is created and registered.

        Parameters
        ----------
        root_module : :py:obj:`torch.nn.Module`
            Root module containing a :py:obj:`Distance` layer with power 2 as a submodule.

        Returns
        -------
        list of :py:obj:`KMeansCanonizer`
            One registered canonizer instance per replaced :py:obj:`Distance` layer.
        '''
        instances = []

        # Snapshot the module tree first: `register` replaces children of the very modules being traversed, so the
        # lazy generator returned by `named_modules` should not be consumed while the tree is modified.
        for full_name, module in list(root_module.named_modules()):
            if not (isinstance(module, Distance) and module.power == 2):
                continue

            instance = self.copy()

            # Resolve the parent by walking the dotted path one attribute at a time. A single `getattr` with a
            # dotted name (e.g. 'features.head') fails for layers nested more than one level deep.
            *parent_path, child_name = full_name.split('.')
            parent_module = root_module
            for name in parent_path:
                parent_module = getattr(parent_module, name)

            instance.parent_module = parent_module
            instance.child_name = child_name

            instance.register(module)
            instances.append(instance)

        return instances

    def register(self, distance_module):
        '''Register the :py:obj:`Distance` layer and replace it with a :py:obj:`NeuralizedKMeans` layer followed by a
        :py:obj:`LogMeanExpPool` layer.

        Computes :math:`w_{ck} = 2(\\mathbf{\\mu}_c - \\mathbf{\\mu}_k)` and :math:`b_{ck} = \\|\\mathbf{\\mu}_k\\|^2 -
        \\|\\mathbf{\\mu}_c\\|^2`. Weights are stored in a tensor :math:`W \\in \\mathbb{R}^{K \\times (K - 1)
        \\times D}` and biases in a matrix :math:`b \\in \\mathbb{R}^{K \\times (K - 1)}`.

        A :py:obj:`NeuralizedKMeans` layer is created with these weights and biases. The :py:obj:`LogMeanExpPool` layer
        is created with the beta value supplied to the constructor. `parent_module` and `child_name` must be set
        before calling this method, as is done by :py:meth:`apply`.

        Parameters
        ----------
        distance_module : :py:obj:`Distance`
            Distance layer to replace.
        '''
        self.distance = distance_module
        self.distance_unchanged = copy.deepcopy(self.distance)

        centroids = self.distance.centroids
        n_clusters, n_dims = centroids.shape

        # boolean mask which drops the diagonal entries (c == k), leaving K - 1 competitors per cluster
        mask = ~torch.eye(n_clusters, dtype=torch.bool, device=centroids.device)

        # weight[c, k] = 2 * (mu_c - mu_k)
        weight = 2 * (centroids[:, None, :] - centroids[None, :, :])
        weight = weight[mask].reshape(n_clusters, n_clusters - 1, n_dims)

        # bias[c, k] = ||mu_k||^2 - ||mu_c||^2; squared norms are computed directly rather than as norm ** 2
        squared_norms = (centroids ** 2).sum(dim=-1)
        bias = (squared_norms[None, :] - squared_norms[:, None])[mask].reshape(n_clusters, n_clusters - 1)

        setattr(
            self.parent_module,
            self.child_name,
            torch.nn.Sequential(NeuralizedKMeans(weight, bias), LogMeanExpPool(self.beta)),
        )

    def remove(self):
        '''Revert the changes introduced by this canonizer by putting the stored copy of the :py:obj:`Distance` layer
        back in place of the neuralized layers.'''
        setattr(self.parent_module, self.child_name, self.distance_unchanged)

    def copy(self):
        '''Create a new, unregistered :py:obj:`KMeansCanonizer` with the same beta.'''
        return KMeansCanonizer(self.beta)
129 changes: 129 additions & 0 deletions src/zennit/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,132 @@ def __init__(self, dim=-1):
def forward(self, input):
    '''Reduce `input` by summation over the dimension stored in `self.dim`.'''
    summed = input.sum(dim=self.dim)
    return summed


class Distance(torch.nn.Module):
    '''Pairwise distances between input points and a fixed set of centroids.

    The layer stores the centroids as a parameter and, when called, returns the distance of every input point to
    every centroid, raised to a configurable power.

    Parameters
    ----------
    centroids : :py:obj:`torch.Tensor`
        shape (K, D) tensor of centroids
    power : float
        power to raise the distance to

    Examples
    --------
    >>> centroids = torch.randn(10, 2)
    >>> distance = Distance(centroids)
    >>> x = torch.randn(100, 2)
    >>> distance(x)
    '''
    def __init__(self, centroids, power=2):
        super().__init__()
        self.power = power
        self.centroids = torch.nn.Parameter(centroids)

    def forward(self, input):
        '''Return the distances between `input` and `self.centroids`, raised to the power `self.power`.

        Parameters
        ----------
        input : :py:obj:`torch.Tensor`
            shape (N, D) tensor of points

        Returns
        -------
        :py:obj:`torch.Tensor`
            shape (N, K) tensor of distances
        '''
        pairwise = torch.cdist(input, self.centroids)
        return pairwise.pow(self.power)


class NeuralizedKMeans(torch.nn.Module):
    '''Layer producing the k-means discriminants for a batch of points.

    The computation is a contraction of the input with a rank-3 weight tensor over the feature dimension, followed by
    the addition of a bias.

    Parameters
    ----------
    weight : :py:obj:`torch.Tensor`
        shape (K, K-1, D) tensor of weights
    bias : :py:obj:`torch.Tensor`
        shape (K, K-1) tensor of biases

    Examples
    --------
    >>> weight = torch.randn(10, 9, 2)
    >>> bias = torch.randn(10, 9)
    >>> neuralized_kmeans = NeuralizedKMeans(weight, bias)
    '''
    def __init__(self, weight, bias):
        super().__init__()
        self.bias = torch.nn.Parameter(bias)
        self.weight = torch.nn.Parameter(weight)

    def forward(self, x):
        '''Contract `x` with `self.weight` over the feature dimension and add `self.bias`.

        Parameters
        ----------
        x : :py:obj:`torch.Tensor`
            shape (N, D) tensor of points

        Returns
        -------
        :py:obj:`torch.Tensor`
            shape (N, K, K-1) tensor of k-means discriminants
        '''
        # n: sample, k: cluster, j: competitor, d: feature
        projected = torch.einsum('nd,kjd->nkj', x, self.weight)
        return projected + self.bias


class LogMeanExpPool(torch.nn.Module):
    '''Computes a log-mean-exp pool along an axis.

    LogMeanExpPool computes :math:`\\frac{1}{\\beta} \\log \\frac{1}{N} \\sum_{i=1}^N \\exp(\\beta x_i)`

    Parameters
    ----------
    beta : float
        stiffness of the pool. Positive values make the pool more like a max pool, negative values make the pool
        more like a min pool. A value of 0 yields the arithmetic mean, which is the limit of the expression above for
        :math:`\\beta \\to 0`. Default value is 1.
    dim : int
        dimension over which to pool. Default value is -1.

    Examples
    --------
    >>> x = torch.randn(10, 2)
    >>> pool = LogMeanExpPool()
    >>> pool(x)
    '''
    def __init__(self, beta=1., dim=-1):
        super().__init__()
        self.dim = dim
        self.beta = beta

    def forward(self, input):
        '''Computes the LogMeanExpPool of `input`.

        If the input has shape (N1, N2, ..., Nk) and `self.dim` is `j`, then the output has shape
        (N1, N2, ..., Nj-1, Nj+1, ..., Nk).

        Parameters
        ----------
        input : :py:obj:`torch.Tensor`
            the input tensor

        Returns
        -------
        :py:obj:`torch.Tensor`
            the LogMeanExpPool of `input`
        '''
        if self.beta == 0:
            # the general expression evaluates to 0 / 0 here; its limit for beta -> 0 is the arithmetic mean
            return input.mean(dim=self.dim)

        n_pooled = input.shape[self.dim]
        # log(N) turns the log-sum-exp into a log-mean-exp
        log_count = torch.log(torch.tensor(n_pooled, dtype=input.dtype, device=input.device))
        return (torch.logsumexp(self.beta * input, dim=self.dim) - log_count) / self.beta
9 changes: 9 additions & 0 deletions src/zennit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
'''Type definitions for convenience.'''
import torch

from .layer import Distance as DistanceLayer


class SubclassMeta(type):
'''Meta class to bundle multiple subclasses.'''
Expand Down Expand Up @@ -124,3 +126,10 @@ class Activation(metaclass=SubclassMeta):
torch.nn.modules.activation.Tanhshrink,
torch.nn.modules.activation.Threshold,
)


class Distance(metaclass=SubclassMeta):
    '''Abstract base class bundling the distance modules.'''
    __subclass__ = (DistanceLayer,)

0 comments on commit cbb350e

Please sign in to comment.