Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neuralized K-Means #197

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
303 changes: 303 additions & 0 deletions docs/source/tutorial/deep-kmeans.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a0661ff4-9f41-405c-8453-f009c31e6a0e",
"metadata": {},
"source": [
"## Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3aef718-d2a0-4f30-9b91-b53f5b288299",
"metadata": {},
"outputs": [],
"source": [
"dummy = True\n",
"# for colab folks\n",
"# %pip install zennit\n",
"# dummy = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa6d0ce7-ea3d-46e5-a8d7-e9a8b31d9239",
"metadata": {},
"outputs": [],
"source": [
"# Basic boilerplate code\n",
"from torchvision import datasets, transforms\n",
"from torchvision.models import vgg16\n",
"import torch\n",
"import numpy as np\n",
"\n",
"transform_img = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])\n",
"transform_norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n",
"\n",
"transform = transforms.Compose([\n",
" transform_img,\n",
" transforms.ToTensor(),\n",
" transform_norm\n",
"])"
]
},
{
"cell_type": "markdown",
"id": "d73397bd-14a2-48ee-8c42-46d6b5104115",
"metadata": {},
"source": [
"### Data and weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5b258b8-c670-473f-858e-2f8464863e29",
"metadata": {},
"outputs": [],
"source": [
"## Data loading\n",
"if dummy:\n",
" images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0,1,2])\n",
" features = vgg16(weights=None).eval()._modules['features']\n",
"else:\n",
" from torch.utils.data import SubsetRandomSampler, DataLoader\n",
"\n",
" # Attention: the next row downloads a dataset into the current folder!\n",
" dataset = datasets.Caltech101(root='.', transform=transform, download=True)\n",
"\n",
" categories = ['cougar_body', 'Leopards', 'wild_cat']\n",
"\n",
" all_indices = []\n",
" for category in categories:\n",
" category_idx = dataset.categories.index(category)\n",
" category_indices = [i for i, label in enumerate(dataset.y) if label == category_idx]\n",
"\n",
" num_samples = min(7, len(category_indices))\n",
"\n",
" selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n",
" all_indices.extend(selected_indices)\n",
"\n",
" sampler = SubsetRandomSampler(all_indices)\n",
" loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n",
"\n",
" try:\n",
" images, labels = next(iter(loader))\n",
" except Exception as e:\n",
" print(f\"Exception: {e}\\nSimply run the cell again.\")\n",
"\n",
" ## Feature extractor\n",
" features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']"
]
},
{
"cell_type": "markdown",
"id": "e7f02b4d-1da8-44ea-a887-6413d150b355",
"metadata": {},
"source": [
"### The fun begins here\n",
"\n",
"We construct a feature map $\\phi$ from image space to feature space.\n",
"Here, we sum over spatial locations in feature space to get more or less translation invariance in pixel space."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eef79eae-f9c7-4b77-8d7c-5edff8e84aeb",
"metadata": {},
"outputs": [],
"source": [
"from zennit.layer import Sum\n",
"\n",
"phi = torch.nn.Sequential(\n",
" features,\n",
" Sum((2,3))\n",
")\n",
"\n",
"Z = phi(images).detach()"
]
},
{
"cell_type": "markdown",
"id": "97b43d41-322a-483c-8506-93e3fa0a852d",
"metadata": {},
"source": [
"Use simple `scikit-learn.KMeans` on the features:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87c058d4-a3e4-4d29-af50-a7f2235e78c3",
"metadata": {},
"outputs": [],
"source": [
"# initialize on class means\n",
"# because we have very few data points here\n",
"centroids = np.stack([Z[labels == y].mean(0) for y in labels.unique()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "309e1158-de08-4493-af07-32a592622a94",
"metadata": {},
"outputs": [],
"source": [
"if not dummy:\n",
" from sklearn.cluster import KMeans\n",
" standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n",
" centroids = standard_kmeans.cluster_centers_"
]
},
{
"cell_type": "markdown",
"id": "5d65f068-b651-4f87-81d4-54508b71c841",
"metadata": {},
"source": [
"Now build a deep clustering model that takes images as input and predicts the k-means assignments\n",
"\n",
"We also apply a little scaling trick that makes heatmaps nicer, but usually does not change the cluster assignments."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce2dbb2a-8a97-488d-9f88-25426881ee10",
"metadata": {},
"outputs": [],
"source": [
"from zennit.layer import PairwiseCentroidDistance\n",
"\n",
"# it's not necessary, just looks a bit nicer\n",
"s = ((centroids**2).sum(-1, keepdims=True)**.5)\n",
"s = s / s.mean()\n",
"\n",
"model = torch.nn.Sequential(\n",
" phi,\n",
" PairwiseCentroidDistance(torch.from_numpy(centroids / s).float())\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f177bbce-fe8f-46b8-b7a9-b9bfb9048145",
"metadata": {},
"source": [
"### Enter zennit."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06892de9-0add-448d-8b76-0f6ea3a0ccd7",
"metadata": {},
"outputs": [],
"source": [
"# import zennit\n",
"from zennit.attribution import Gradient\n",
"from zennit.composites import EpsilonGammaBox\n",
"from zennit.image import imgify\n",
"from zennit.torchvision import VGGCanonizer\n",
"from zennit.canonizers import KMeansCanonizer\n",
"from zennit.composites import LayerMapComposite, MixedComposite\n",
"from zennit.layer import NeuralizedKMeans, MinPool1d\n",
"from zennit.rules import ZPlus, Gamma, MinTakesMost1d\n",
"\n",
"def data2img(x):\n",
" return (x.squeeze().permute(1,2,0) * torch.tensor([0.229, 0.224, 0.225])) + torch.tensor([0.485, 0.456, 0.406])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aac5b8af-61cc-400b-a0fc-b036148104ad",
"metadata": {},
"outputs": [],
"source": [
"# compute cluster assignments and check if they are equal\n",
"# without the scaling trick above, the are definitely equal (trust me)\n",
"ypred = model(images).argmin(1)\n",
"# assert (ypred.numpy() == standard_kmeans.predict(Z)).all()"
]
},
{
"cell_type": "markdown",
"id": "47e38917-b4ee-499f-ba9e-55cce7cb8163",
"metadata": {},
"source": [
"### Everything is ready.\n",
"\n",
"You can play around with the `beta` parameter in `MinTakesMost1d` and the `gamma` parameter in `Gamma`.\n",
"\n",
"`beta` is a contrast parameter. Keep `beta < 0`.\n",
"Small negative `beta` can be seen as *one-vs-all* explanation whereas large negative `beta` is more like *one-vs-nearest-competitor*.\n",
"\n",
"The `gamma` parameter controls the contribution of negative weights. Keep `gamma >= 0`.\n",
"In practice, small (positive) `gamma` can result in entirely negative heatmaps. Think of thousand negative weights and a single positive weight. The positive weight could be enough to win the k-means assignment in feature space, but it's lost after a few layers because the graph is flooded with negative contributions.\n",
"\n",
"If you are trying to explain contribution to another cluster (say, $x$ is assigned to cluster $1$, but you want to see if there is some evidence for cluster $2$ in the image), then definitely cramp up `gamma` or even use `ZPlus` instead of `Gamma`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa0f7ca6-3e73-4254-ba31-26a6de28e690",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"canonizer = KMeansCanonizer()\n",
"\n",
"low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n",
"\n",
"composite = MixedComposite([\n",
" EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n",
" LayerMapComposite([\n",
" (NeuralizedKMeans, Gamma(gamma=.0)),\n",
" (MinPool1d, MinTakesMost1d(beta=1e-6))\n",
" ])\n",
"])\n",
"\n",
"with Gradient(model=model, composite=composite) as attributor:\n",
" for c in range(len(centroids)):\n",
" print(\"Cluster %d\"%c)\n",
" cluster_members = (ypred == c).nonzero()[:,0]\n",
" for i in cluster_members:\n",
" img = images[i].unsqueeze(0)\n",
" target = torch.eye(len(centroids))[[c]]\n",
" output, attribution = attributor(img, target)\n",
" relevance = attribution[0].sum(0)\n",
"\n",
" heatmap = np.array(imgify(relevance, symmetric=True, cmap='seismic').convert('RGB'))\n",
" display(imgify(np.stack([data2img(img).numpy(), heatmap]), grid=(1,2)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions docs/source/tutorial/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:maxdepth: 1

image-classification-vgg-resnet
deep-kmeans
..
image-segmentation-with-unet
text-classification-with-tbd
Expand Down
Loading
Loading