Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[desktop] Support rejected faces in clusters synced with remote #4051

Merged
merged 14 commits into from
Nov 18, 2024
41 changes: 35 additions & 6 deletions web/packages/new/photos/components/gallery/PeopleHeader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,23 @@ const suggestionsDialogReducer: React.Reducer<
// original assigned state.
updates.delete(item.id);
} else {
updates.set(item.id, value);
const update = (() => {
switch (value) {
case true:
// true corresponds to update "assign".
return "assign";
case false:
// false maps to different updates for suggestions
// vs choices.
return item.assigned === undefined
? "rejectSuggestion"
: "rejectSavedChoice";
case undefined:
// undefined means reset.
return "reset";
}
})();
updates.set(item.id, update);
}
return { ...state, updates };
}
Expand Down Expand Up @@ -764,7 +780,7 @@ const SuggestionOrChoiceList: React.FC<SuggestionOrChoiceListProps> = ({
</Stack>
{!item.fixed && (
<ToggleButtonGroup
value={fromItemValue(item, updates)}
value={itemValueFromUpdate(item, updates)}
exclusive
onChange={(_, v) => onUpdateItem(item, toItemValue(v))}
>
Expand All @@ -781,12 +797,25 @@ const SuggestionOrChoiceList: React.FC<SuggestionOrChoiceListProps> = ({
</List>
);

const fromItemValue = (item: SCItem, updates: PersonSuggestionUpdates) => {
const itemValueFromUpdate = (
item: SCItem,
updates: PersonSuggestionUpdates,
) => {
// Use the in-memory state if available. For choices, fallback to their
// original state.
const resolved = updates.has(item.id)
? updates.get(item.id)
: item.assigned;
const resolveUpdate = () => {
switch (updates.get(item.id)) {
case "assign":
return true;
case "rejectSavedChoice":
return false;
case "rejectSuggestion":
return false;
default:
return undefined;
}
};
const resolved = updates.has(item.id) ? resolveUpdate() : item.assigned;
return resolved ? "yes" : resolved === false ? "no" : undefined;
};

Expand Down
79 changes: 60 additions & 19 deletions web/packages/new/photos/services/ml/cluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import log from "@/base/log";
import type { EnteFile } from "@/media/file";
import { ensure } from "@/utils/ensure";
import { wait } from "@/utils/promise";
import { savedCGroups, updateOrCreateUserEntities } from "../user-entity";
import {
pullUserEntities,
savedCGroups,
updateOrCreateUserEntities,
} from "../user-entity";
import { savedFaceClusters, saveFaceClusters } from "./db";
import {
faceDirection,
Expand Down Expand Up @@ -97,14 +101,26 @@ export const _clusterFaces = async (

const sortedCGroups = cgroups.sort((a, b) => b.updatedAt - a.updatedAt);

// Extract the remote clusters.
clusters = clusters.concat(
// See: [Note: strict mode migration]
//
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
sortedCGroups.map((cg) => cg.data.assigned).flat(),
);
// Fill in clusters from remote cgroups, and also construct rejected lookup.
const rejectedClusterIDsForFaceID = new Map<string, Set<string>>();
for (const cgroup of sortedCGroups) {
if (cgroup.data.rejectedFaceIDs.length == 0) {
clusters = clusters.concat(cgroup.data.assigned);
} else {
const rejectedFaceIDs = new Set(cgroup.data.rejectedFaceIDs);
clusters = clusters.concat(
cgroup.data.assigned.map((cluster) => ({
...cluster,
faces: cluster.faces.filter((f) => !rejectedFaceIDs.has(f)),
})),
);
for (const faceID of rejectedFaceIDs) {
const s = rejectedClusterIDsForFaceID.get(faceID) ?? new Set();
cgroup.data.assigned.forEach(({ id }) => s.add(id));
rejectedClusterIDsForFaceID.set(faceID, s);
}
}
}

// Add on the clusters we have available locally.
clusters = clusters.concat(await savedFaceClusters());
Expand All @@ -130,10 +146,15 @@ export const _clusterFaces = async (
}
}

// IDs of the clusters which were modified. We use this information to
// determine which cgroups need to be updated on remote.
const modifiedClusterIDs = new Set<string>();

const state = {
faceIDToClusterID,
faceIDToClusterIndex,
clusters,
modifiedClusterIDs,
};

// Process the faces in batches, but keep an overlap between batches to
Expand All @@ -147,6 +168,7 @@ export const _clusterFaces = async (
await clusterBatchLinear(
faces.slice(offset, offset + batchSize),
state,
rejectedClusterIDsForFaceID,
({ completed }) =>
onProgress({ completed: offset + completed, total }),
);
Expand All @@ -155,7 +177,7 @@ export const _clusterFaces = async (
const t = `(${Date.now() - startTime} ms)`;
log.info(`Refreshed ${clusters.length} clusters from ${total} faces ${t}`);

return clusters;
return { clusters, modifiedClusterIDs };
};

/**
Expand Down Expand Up @@ -235,11 +257,13 @@ interface ClusteringState {
faceIDToClusterID: Map<string, string>;
faceIDToClusterIndex: Map<string, number>;
clusters: FaceCluster[];
modifiedClusterIDs: Set<string>;
}

const clusterBatchLinear = async (
batch: ClusterFace[],
state: ClusteringState,
rejectedClusterIDsForFaceID: Map<string, Set<string>>,
onProgress: (progress: ClusteringProgress) => void,
) => {
const [clusteredFaces, unclusteredFaces] = batch.reduce<
Expand Down Expand Up @@ -274,6 +298,8 @@ const clusterBatchLinear = async (
// If the face is already part of a cluster, then skip it.
if (state.faceIDToClusterID.has(fi.faceID)) continue;

const rejectedClusters = rejectedClusterIDsForFaceID.get(fi.faceID);

// Find the nearest neighbour among the previous faces in this batch.
let nnIndex: number | undefined;
let nnCosineSimilarity = 0;
Expand All @@ -286,11 +312,24 @@ const clusterBatchLinear = async (
// The vectors are already normalized, so we can directly use their
// dot product as their cosine similarity.
const csim = dotProduct(fi.embedding, fj.embedding);
if (csim <= nnCosineSimilarity) continue;

const threshold = fj.isBadFace ? 0.84 : 0.76;
if (csim > nnCosineSimilarity && csim >= threshold) {
nnIndex = j;
nnCosineSimilarity = csim;
if (csim < threshold) continue;

// Don't add the face back to a cluster it has been rejected from.
if (rejectedClusters) {
const cjx = state.faceIDToClusterIndex.get(fj.faceID);
if (cjx !== undefined) {
const cj = ensure(state.clusters[cjx]);
if (rejectedClusters.has(cj.id)) {
continue;
}
}
}

nnIndex = j;
nnCosineSimilarity = csim;
}

if (nnIndex !== undefined) {
Expand All @@ -304,6 +343,7 @@ const clusterBatchLinear = async (
state.faceIDToClusterID.set(fi.faceID, nnCluster.id);
state.faceIDToClusterIndex.set(fi.faceID, nnClusterIndex);
nnCluster.faces.push(fi.faceID);
state.modifiedClusterIDs.add(nnCluster.id);
} else {
// No neighbour within the threshold. Create a new cluster.
const clusterID = newClusterID();
Expand All @@ -313,6 +353,7 @@ const clusterBatchLinear = async (
state.faceIDToClusterID.set(fi.faceID, cluster.id);
state.faceIDToClusterIndex.set(fi.faceID, clusterIndex);
state.clusters.push(cluster);
state.modifiedClusterIDs.add(cluster.id);
}
}
};
Expand All @@ -326,6 +367,7 @@ const clusterBatchLinear = async (
*/
export const reconcileClusters = async (
clusters: FaceCluster[],
modifiedClusterIDs: Set<string>,
masterKey: Uint8Array,
) => {
// Index clusters by their ID for fast lookup.
Expand All @@ -337,12 +379,8 @@ export const reconcileClusters = async (
// Find the cgroups that have changed since we started.
const changedCGroups = cgroups
.map((cgroup) => {
for (const oldCluster of cgroup.data.assigned) {
// The clustering algorithm does not remove any existing faces, it
// can only add new ones to the cluster. So we can use the count as
// an indication if something changed.
const newCluster = ensure(clusterByID.get(oldCluster.id));
if (oldCluster.faces.length != newCluster.faces.length) {
for (const cluster of cgroup.data.assigned) {
if (modifiedClusterIDs.has(cluster.id)) {
return {
...cgroup,
data: {
Expand Down Expand Up @@ -375,4 +413,7 @@ export const reconcileClusters = async (
await saveFaceClusters(
clusters.filter(({ id }) => !isRemoteClusterID.has(id)),
);

// Refresh our local state if we'd updated remote.
if (changedCGroups.length) await pullUserEntities("cgroup", masterKey);
};
27 changes: 10 additions & 17 deletions web/packages/new/photos/services/ml/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,10 @@ export const addCGroup = async (name: string, cluster: FaceCluster) => {
/**
* Add a new cluster to an existing named person.
*
* If this cluster contains any faces that had previously been marked as not
* belonging to the person, then they will be removed from the rejected list and
* will get reassociated to the person.
*
* @param cgroup The existing cgroup underlying the person. This is the (remote)
* user entity that will get updated.
*
Expand All @@ -717,28 +721,17 @@ export const addCGroup = async (name: string, cluster: FaceCluster) => {
export const addClusterToCGroup = async (
cgroup: CGroup,
cluster: FaceCluster,
) =>
updateAssignedClustersForCGroup(
cgroup,
cgroup.data.assigned.concat([cluster]),
) => {
const clusterFaceIDs = new Set(cluster.faces);
const assigned = cgroup.data.assigned.concat([cluster]);
const rejectedFaceIDs = cgroup.data.rejectedFaceIDs.filter(
(id) => !clusterFaceIDs.has(id),
);

/**
* Update the clusters assigned to an existing named person.
*
* @param cgroup The existing cgroup underlying the person. This is the (remote)
* user entity that will get updated.
*
* @param cluster The new value of the face clusters assigned to this person.
*/
export const updateAssignedClustersForCGroup = async (
cgroup: CGroup,
assigned: FaceCluster[],
) => {
const masterKey = await masterKeyFromSession();
await updateOrCreateUserEntities(
"cgroup",
[{ ...cgroup, data: { ...cgroup.data, assigned } }],
[{ ...cgroup, data: { ...cgroup.data, assigned, rejectedFaceIDs } }],
masterKey,
);
return mlSync();
Expand Down
Loading