Tutorial: Neuralized K-Means
- Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data
- I tried to adhere to guidelines
- That means: random data, random weights
- Code for real data and real weights in comments
- Runs on Colab, did not test Binder
- also adds the reference to docs/source/tutorial/index.rst
jacobkauffmann committed Aug 18, 2023
1 parent f3b6ba2 commit 2bf6f52
Showing 2 changed files with 313 additions and 0 deletions.
312 changes: 312 additions & 0 deletions docs/source/tutorial/deep-kmeans.ipynb
@@ -0,0 +1,312 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a0661ff4-9f41-405c-8453-f009c31e6a0e",
"metadata": {},
"source": [
"## Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3aef718-d2a0-4f30-9b91-b53f5b288299",
"metadata": {},
"outputs": [],
"source": [
"# for colab folks\n",
"# %pip install zennit"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa6d0ce7-ea3d-46e5-a8d7-e9a8b31d9239",
"metadata": {},
"outputs": [],
"source": [
"# Basic boilerplate code\n",
"from torchvision import datasets, transforms\n",
"from torchvision.models import vgg16\n",
"import torch\n",
"import numpy as np\n",
"\n",
"transform_img = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])\n",
"transform_norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n",
"\n",
"transform = transforms.Compose([\n",
" transform_img,\n",
" transforms.ToTensor(),\n",
" transform_norm\n",
"])"
]
},
{
"cell_type": "markdown",
"id": "d73397bd-14a2-48ee-8c42-46d6b5104115",
"metadata": {},
"source": [
"### Real data and weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5b258b8-c670-473f-858e-2f8464863e29",
"metadata": {},
"outputs": [],
"source": [
"# uncomment this cell for an example on real data and real weights\n",
"### Data loading\n",
"# from torch.utils.data import SubsetRandomSampler, DataLoader\n",
"\n",
"# # Attention: the next row downloads a dataset into the current folder!\n",
"# dataset = datasets.Caltech101(root='.', transform=transform, download=True)\n",
"\n",
"# categories = ['cougar_body', 'Leopards', 'wild_cat']\n",
"\n",
"# all_indices = []\n",
"# for category in categories:\n",
"# category_idx = dataset.categories.index(category)\n",
"# category_indices = [i for i, label in enumerate(dataset.y) if label == category_idx]\n",
"\n",
"# num_samples = min(7, len(category_indices))\n",
"\n",
"# selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n",
"# all_indices.extend(selected_indices)\n",
"\n",
"# sampler = SubsetRandomSampler(all_indices)\n",
"# loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n",
"\n",
"# # If this line throws a shape error, just run this cell again (some images in Caltech101 are grayscale)\n",
"# images, labels = next(iter(loader))\n",
"\n",
"### Feature extractor\n",
"# features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']"
]
},
{
"cell_type": "markdown",
"id": "be3a0e0d-afa0-4af1-b8c2-a3f6525dcb03",
"metadata": {},
"source": [
"### Random data and weights for online preview"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "736252a7-4445-408c-a946-163ea44903da",
"metadata": {},
"outputs": [],
"source": [
"# for zennit contribution guidelines\n",
"# some random data and weights\n",
"images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0,1,2])\n",
"features = vgg16(weights=None).eval()._modules['features']"
]
},
{
"cell_type": "markdown",
"id": "e7f02b4d-1da8-44ea-a887-6413d150b355",
"metadata": {},
"source": [
"### The fun begins here\n",
"\n",
"We construct a feature map $\\phi$ from image space to feature space.\n",
"Here, we sum over spatial locations in feature space to get more or less translation invariance in pixel space."
]
},
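{
"cell_type": "markdown",
"id": "3f9d2c1e-7b4a-4c8d-9e0f-1a2b3c4d5e6f",
"metadata": {},
"source": [
"In formulas (our notation, matching the code below): with $F(x) \\in \\mathbb{R}^{C \\times H \\times W}$ denoting the output of the VGG16 feature extractor, the map is\n",
"$$\\phi(x) = \\sum_{i=1}^{H} \\sum_{j=1}^{W} F(x)_{:,i,j} \\in \\mathbb{R}^{C}.$$"
]
},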
{
"cell_type": "code",
"execution_count": null,
"id": "eef79eae-f9c7-4b77-8d7c-5edff8e84aeb",
"metadata": {},
"outputs": [],
"source": [
"from zennit.layer import Sum\n",
"\n",
"phi = torch.nn.Sequential(\n",
" features,\n",
" Sum((2,3))\n",
")\n",
"\n",
"Z = phi(images).detach()"
]
},
{
"cell_type": "markdown",
"id": "97b43d41-322a-483c-8506-93e3fa0a852d",
"metadata": {},
"source": [
"Use simple `scikit-learn.KMeans` on the features:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87c058d4-a3e4-4d29-af50-a7f2235e78c3",
"metadata": {},
"outputs": [],
"source": [
"# initialize on class means\n",
"# because we have very few data points here\n",
"centroids = np.stack([Z[labels == y].mean(0) for y in labels.unique()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "309e1158-de08-4493-af07-32a592622a94",
"metadata": {},
"outputs": [],
"source": [
"### uncomment for real fun\n",
"# from sklearn.cluster import KMeans\n",
"# standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n",
"# centroids = standard_kmeans.cluster_centers_"
]
},
{
"cell_type": "markdown",
"id": "5d65f068-b651-4f87-81d4-54508b71c841",
"metadata": {},
"source": [
"Now build a deep clustering model that takes images as input and predicts the k-means assignments\n",
"\n",
"We also apply a little scaling trick that makes heatmaps nicer, but usually does not change the cluster assignments."
]
},
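{
"cell_type": "markdown",
"id": "8a1f4b2c-5d6e-4f70-8a9b-0c1d2e3f4a5b",
"metadata": {},
"source": [
"In formulas (our notation, matching the code below): each centroid is rescaled by its relative norm,\n",
"$$s_c = \\frac{\\|\\mu_c\\|}{\\frac{1}{K}\\sum_k \\|\\mu_k\\|},$$\n",
"so that all rescaled centroids $\\mu_c / s_c$ share the same norm. The model then computes the distances $d_c(x) = \\|\\phi(x) - \\mu_c / s_c\\|$ and assigns $x$ to $\\hat{y} = \\arg\\min_c d_c(x)$."
]
},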
{
"cell_type": "code",
"execution_count": null,
"id": "ce2dbb2a-8a97-488d-9f88-25426881ee10",
"metadata": {},
"outputs": [],
"source": [
"from zennit.layer import Distance\n",
"\n",
"# it's not necessary, just looks a bit nicer\n",
"s = ((centroids**2).sum(-1, keepdims=True)**.5)\n",
"s = s / s.mean()\n",
"\n",
"model = torch.nn.Sequential(\n",
" phi,\n",
" Distance(torch.from_numpy(centroids / s).float())\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f177bbce-fe8f-46b8-b7a9-b9bfb9048145",
"metadata": {},
"source": [
"### Enter zennit."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06892de9-0add-448d-8b76-0f6ea3a0ccd7",
"metadata": {},
"outputs": [],
"source": [
"# import zennit\n",
"from zennit.attribution import Gradient\n",
"from zennit.composites import EpsilonGammaBox\n",
"from zennit.image import imgify\n",
"from zennit.torchvision import VGGCanonizer\n",
"from zennit.canonizers import KMeansCanonizer\n",
"from zennit.composites import LayerMapComposite, MixedComposite\n",
"from zennit.layer import NeuralizedKMeans\n",
"from zennit.rules import ZPlus, Gamma\n",
"\n",
"def data2img(x):\n",
" return (x.squeeze().permute(1,2,0) * torch.tensor([0.229, 0.224, 0.225])) + torch.tensor([0.485, 0.456, 0.406])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aac5b8af-61cc-400b-a0fc-b036148104ad",
"metadata": {},
"outputs": [],
"source": [
"# compute cluster assignments and check if they are equal\n",
"# without the scaling trick above, the are definitely equal (trust me)\n",
"ypred = model(images).argmin(1)\n",
"assert (ypred.numpy() == standard_kmeans.predict(Z)).all()"
]
},
{
"cell_type": "markdown",
"id": "47e38917-b4ee-499f-ba9e-55cce7cb8163",
"metadata": {},
"source": [
"### Everything is ready.\n",
"\n",
"You can play around with the `beta` parameter in `KMeansCanonizer` and the `gamma` parameter in `Gamma`.\n",
"\n",
"`beta` is a contrast parameter. Keep `beta < 0`.\n",
"Small negative `beta` can be seen as *one-vs-all* explanation whereas large negative `beta` is more like *one-vs-nearest-competitor*.\n",
"\n",
"The `gamma` parameter controls the contribution of negative weights. Keep `gamma >= 0`.\n",
"In practice, small (positive) `gamma` can result in entirely negative heatmaps. Think of thousand negative weights and a single positive weight. The positive weight could be enough to win the k-means assignment in feature space, but it's lost after a few layers because the graph is flooded with negative contributions.\n",
"\n",
"If you are trying to explain contribution to another cluster (say, $x$ is assigned to cluster $1$, but you want to see if there is some evidence for cluster $2$ in the image), then definitely cramp up `gamma` or even use `ZPlus` instead of `Gamma`."
]
},
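{
"cell_type": "markdown",
"id": "6c2e9f1a-3b4d-4a5e-8f60-7a8b9c0d1e2f",
"metadata": {},
"source": [
"For reference, a sketch of the underlying neuralization (following Kauffmann et al., up to the exact parametrization used in the code): a point $x$ is assigned to cluster $c$ iff $\\|x - \\mu_c\\|^2 < \\|x - \\mu_k\\|^2$ for all $k \\neq c$. Expanding the squares turns each comparison into a linear discriminant\n",
"$$h_{ck}(x) = w_{ck}^\\top x + b_{ck}, \\qquad w_{ck} = 2(\\mu_c - \\mu_k), \\quad b_{ck} = \\|\\mu_k\\|^2 - \\|\\mu_c\\|^2,$$\n",
"and the evidence for cluster $c$ is the min over competitors, $f_c(x) = \\min_{k \\neq c} h_{ck}(x)$, which is positive iff $x$ belongs to cluster $c$. The `beta` parameter replaces the hard min by a soft min of the form $\\frac{1}{\\beta} \\log \\frac{1}{K-1} \\sum_{k \\neq c} e^{\\beta h_{ck}(x)}$ with $\\beta < 0$: as $\\beta \\to 0^-$ it approaches the average over all competitors (*one-vs-all*), and as $\\beta \\to -\\infty$ it recovers the hard min (*one-vs-nearest-competitor*)."
]
},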
{
"cell_type": "code",
"execution_count": null,
"id": "aa0f7ca6-3e73-4254-ba31-26a6de28e690",
"metadata": {},
"outputs": [],
"source": [
"canonizer = KMeansCanonizer(beta=-1e-12)\n",
"\n",
"low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n",
"\n",
"composite = MixedComposite([\n",
" EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n",
" LayerMapComposite([(NeuralizedKMeans, Gamma(gamma=1.))])\n",
"])\n",
"\n",
"with Gradient(model=model, composite=composite) as attributor:\n",
" for c in range(len(centroids)):\n",
" print(\"Cluster %d\"%c)\n",
" cluster_members = (ypred == c).nonzero()[:,0]\n",
" for i in cluster_members:\n",
" img = images[i].unsqueeze(0)\n",
" target = torch.eye(len(centroids))[[c]]\n",
" output, attribution = attributor(img, target)\n",
" relevance = attribution[0].sum(0)\n",
"\n",
" heatmap = np.array(imgify(relevance, symmetric=True, cmap='seismic').convert('RGB'))\n",
" display(imgify(np.stack([data2img(img).numpy(), heatmap]), grid=(1,2)))"
]
}
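,
{
"cell_type": "markdown",
"id": "d4e5f6a7-b8c9-4d0e-af12-3b4c5d6e7f80",
"metadata": {},
"source": [
"As a quick (hypothetical) variation on the cell above: the same attribution loop with `ZPlus` in place of `Gamma` on the neuralized k-means layer, which keeps only positive contributions. This is the setting suggested above for probing evidence for a cluster an image is *not* necessarily assigned to (here, cluster `0` for every image)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5f6a7b8-c9d0-4e1f-8023-4c5d6e7f8091",
"metadata": {},
"outputs": [],
"source": [
"composite_zplus = MixedComposite([\n",
"    EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n",
"    LayerMapComposite([(NeuralizedKMeans, ZPlus())])\n",
"])\n",
"\n",
"with Gradient(model=model, composite=composite_zplus) as attributor:\n",
"    # probe cluster 0 for every image, regardless of its assignment\n",
"    target = torch.eye(len(centroids))[[0]]\n",
"    for i in range(len(images)):\n",
"        img = images[i].unsqueeze(0)\n",
"        output, attribution = attributor(img, target)\n",
"        relevance = attribution[0].sum(0)\n",
"        heatmap = np.array(imgify(relevance, symmetric=True, cmap='seismic').convert('RGB'))\n",
"        display(imgify(np.stack([data2img(img).numpy(), heatmap]), grid=(1,2)))"
]
}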
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions docs/source/tutorial/index.rst
@@ -6,6 +6,7 @@
:maxdepth: 1

image-classification-vgg-resnet
deep-kmeans
..
image-segmentation-with-unet
text-classification-with-tbd
