semantic_dedup.py (651 lines, 25.1 KB), forked from NVIDIA/NeMo-Curator
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import time
from dataclasses import dataclass
from typing import List, Optional, Union
import cudf
import cupy as cp
import dask.bag as db
import dask.dataframe as dd
import dask_cudf
import numpy as np
import torch
import torch.nn as nn
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel
from cuml.dask.cluster import KMeans
from torch.nn import functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from nemo_curator.classifiers.base import _get_suggest_memory_for_classifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import SemDedupConfig
from nemo_curator.utils.distributed_utils import (
performance_report_if_with_ts_suffix,
write_to_disk,
)
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
from nemo_curator.utils.semdedup_utils import (
assign_and_sort_clusters,
extract_dedup_data,
get_semantic_matches_per_cluster,
)
# Embedding Creation Module
@dataclass
class EmbeddingConfig:
model_name_or_path: str
max_seq_length: Optional[int] = None
def __post_init__(self):
self.max_seq_length = AutoTokenizer.from_pretrained(
self.model_name_or_path
).model_max_length
# Guard against the HF bug
# which sets max_seq_length to max(int) for some models
if self.max_seq_length > 1e5:
self.max_seq_length = AutoConfig.from_pretrained(
self.model_name_or_path
).max_position_embeddings
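# Illustrative sketch (not part of the original module): how EmbeddingConfig
# resolves max_seq_length. The model name below is an assumption chosen only
# for the example.
def _example_embedding_config():  # pragma: no cover
    config = EmbeddingConfig(
        model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"
    )
    # __post_init__ reads model_max_length from the tokenizer; when a tokenizer
    # reports an implausibly large value (> 1e5), it falls back to the model's
    # max_position_embeddings instead.
    return config.max_seq_length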
class EmbeddingPytorchModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = AutoModel.from_pretrained(
config.model_name_or_path, config=self.config, force_download=False
)
def feature(self, input_ids, attention_mask):
with torch.autocast(device_type=input_ids.device.type):
embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)
return embeddings
@torch.no_grad()
def forward(self, batch):
feature = self.feature(batch["input_ids"], batch["attention_mask"])
return self._mean_pooling(feature, batch["attention_mask"])
def _mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
return F.normalize(sum_embeddings / sum_mask, dim=1)
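# Illustrative sketch (not part of the original module) of what _mean_pooling
# computes: a mask-weighted average over token embeddings followed by L2
# normalization, shown here on toy tensors instead of a real model output.
def _example_mean_pooling():  # pragma: no cover
    token_embeddings = torch.randn(2, 4, 8)  # (batch, seq_len, hidden_dim)
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return F.normalize(summed / counts, dim=1)  # shape (2, 8), unit-norm rows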
class EmbeddingCrossFitModel(HFModel):
def __init__(
self,
config: EmbeddingConfig,
max_mem_gb: Optional[int] = None,
):
self.config = config
if max_mem_gb is None:
max_mem_gb = _get_suggest_memory_for_classifier()
super().__init__(self.config.model_name_or_path, max_mem_gb=max_mem_gb)
def load_model(self, device="cuda"):
model = EmbeddingPytorchModel(self.config)
model = model.to(device)
model.eval()
return model
def max_seq_length(self):
return self.config.max_seq_length
def load_config(self):
return AutoConfig.from_pretrained(self.config.model_name_or_path)
def load_tokenizer(self):
return AutoTokenizer.from_pretrained(self.config.model_name_or_path)
class EmbeddingCreator:
def __init__(
self,
embedding_model_name_or_path: str,
embedding_batch_size: int,
embedding_output_dir: str,
embedding_max_mem_gb: Optional[int] = None,
input_column: str = "text",
embedding_column: str = "embeddings",
write_embeddings_to_disk: bool = True,
write_to_filename: bool = False,
logger: Union[logging.Logger, str] = "./",
profile_dir: Optional[str] = None,
):
"""
Initializes an EmbeddingCreator for generating embeddings using the specified model configurations.
Args:
embedding_model_name_or_path (str): The path or identifier for the model used to generate embeddings.
embedding_batch_size (int): Number of samples to process in each batch.
embedding_output_dir (str): Directory path where embeddings will be saved.
embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
If None, it defaults to the available GPU memory minus 4 GB.
input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
embedding_column (str): Column name where the generated embeddings are stored, defaults to "embeddings".
write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
    We recommend setting this to False when you have a delayed pipeline, since it avoids the
    parquet round trip; note that setting it to False can lead to more memory overhead.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
logger (Union[logging.Logger, str]): Logger object or path to store logs, defaults to "./".
profile_dir (str): If specified, the directory to which Dask profile reports are written. Default is None.
Attributes:
embeddings_config (EmbeddingConfig): Configuration for embeddings.
batch_size (int): Batch size for embedding generation.
logger (logging.Logger): Logger instance for the class.
embedding_output_dir (str): Output directory for embeddings.
input_column (str): Input column for data processing.
model (EmbeddingCrossFitModel): Model instance for embedding generation.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
"""
self.embeddings_config = EmbeddingConfig(
model_name_or_path=embedding_model_name_or_path,
)
self.batch_size = embedding_batch_size
self.logger = self._setup_logger(logger)
self.embedding_output_dir = embedding_output_dir
self.input_column = input_column
self.embedding_column = embedding_column
self.model = EmbeddingCrossFitModel(
self.embeddings_config, max_mem_gb=embedding_max_mem_gb
)
self.write_embeddings_to_disk = write_embeddings_to_disk
self.write_to_filename = write_to_filename
self.profile_dir = profile_dir
def _setup_logger(self, logger):
if isinstance(logger, str):
return create_logger(
rank=0,
name="compute-embeddings",
log_file=os.path.join(logger, "compute_embeddings.log"),
log_level=logging.INFO,
stdout=True,
)
else:
return logger
def create_embeddings(
self, ddf: dask_cudf.DataFrame, input_column="text"
) -> dask_cudf.DataFrame:
pipe = op.Sequential(
op.Tokenizer(
self.model,
cols=[input_column],
tokenizer_type="sentencepiece",
max_length=self.embeddings_config.max_seq_length,
),
op.Predictor(
self.model,
sorted_data_loader=True,
batch_size=self.batch_size,
pred_output_col=self.embedding_column,
),
keep_cols=ddf.columns.tolist(),
)
return pipe(ddf)
def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
t0 = time.time()
if self.write_embeddings_to_disk:
with performance_report_if_with_ts_suffix(
self.profile_dir, "embedding-creator"
):
embedding_ddf = self.create_embeddings(dataset.df, self.input_column)
write_to_disk(
embedding_ddf,
self.embedding_output_dir,
write_to_filename=self.write_to_filename,
output_type="parquet",
)
ddf = DocumentDataset(
dask_cudf.read_parquet(
self.embedding_output_dir, blocksize="2GB", aggregate_files=True
)
)
else:
    # write_embeddings_to_disk is False: build the embeddings lazily and keep
    # them in memory instead of round-tripping through parquet.
    embedding_ddf = self.create_embeddings(dataset.df, self.input_column)
    ddf = DocumentDataset(embedding_ddf)
self.logger.info(
f"Time taken for Creating Embeddings : {time.time() - t0}"
+ (
f" and output written at {self.embedding_output_dir}"
if self.write_embeddings_to_disk
else ""
)
)
return ddf
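# Illustrative usage sketch for EmbeddingCreator (not part of the original
# module). The model name, batch size, and paths are placeholder assumptions,
# and a GPU-backed Dask cluster is assumed to be running already.
def _example_embedding_creator():  # pragma: no cover
    dataset = DocumentDataset.read_parquet("./input_data", backend="cudf")
    creator = EmbeddingCreator(
        embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
        embedding_batch_size=128,
        embedding_output_dir="./cache/embeddings",
        input_column="text",
    )
    # Returns a DocumentDataset with an "embeddings" column holding one vector
    # per document; with write_embeddings_to_disk=True (the default) the
    # embeddings are also written to embedding_output_dir as parquet.
    return creator(dataset)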
### Clustering Module
def get_embedding_ar(df: "cudf.DataFrame", embedding_col: str) -> cp.ndarray:
return df[embedding_col].list.leaves.values.reshape(len(df), -1)
def add_dist_to_cents(
df: "cudf.DataFrame", embedding_col: str, centroids: cp.ndarray
) -> "cudf.DataFrame":
embed_array = get_embedding_ar(df, embedding_col)
centroids_ar = centroids[df["nearest_cent"].values]
dist_to_cents = cp.sqrt(np.sum((embed_array - centroids_ar) ** 2, axis=1))
df["dist_to_cent"] = dist_to_cents
return df
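# Illustrative sketch (not part of the original module) of the Euclidean
# distance that add_dist_to_cents computes, using plain CuPy arrays instead of
# a cuDF DataFrame with a list column.
def _example_dist_to_centroid():  # pragma: no cover
    embeddings = cp.asarray([[1.0, 0.0], [0.0, 1.0]])   # one row per document
    centroids = cp.asarray([[1.0, 0.0], [0.5, 0.5]])    # kmeans.cluster_centers_
    nearest_cent = cp.asarray([0, 1])                    # assigned cluster per document
    assigned = centroids[nearest_cent]
    return cp.sqrt(cp.sum((embeddings - assigned) ** 2, axis=1))  # [0.0, ~0.707]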
class ClusteringModel:
def __init__(
self,
id_column: str,
max_iter: int,
n_clusters: int,
clustering_output_dir: str,
embedding_col: str = "embeddings",
sim_metric: str = "cosine",
which_to_keep: str = "hard",
sort_clusters: bool = True,
kmeans_with_cos_dist: bool = False,
partition_size: str = "2gb",
logger: Union[logging.Logger, str] = "./",
profile_dir: Optional[str] = None,
):
"""
Initializes the ClusteringModel with the provided settings for semantic clustering, as used in semantic deduplication.
Args:
id_column (str): Column name used as the identifier in the dataset.
max_iter (int): Maximum number of iterations for the clustering algorithm.
n_clusters (int): The number of clusters to form.
clustering_output_dir (str): Directory path where clustering results will be saved.
embedding_col (str): Column name where the embeddings are stored.
sim_metric (str): Similarity metric to use for clustering, default is "cosine".
which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
sort_clusters (bool): Whether to sort clusters, default is True.
kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False.
partition_size (str): The partition size to which the embeddings are repartitioned before running KMeans; default is "2gb".
logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./".
profile_dir (str): If specified, the directory to which Dask profile reports are written. Default is None.
This constructor sets up the parameters required for clustering operations.
"""
self.id_col = id_column
self.max_iter = max_iter
self.n_clusters = n_clusters
self.clustering_output_dir = clustering_output_dir
self.embedding_col = embedding_col
self.sim_metric = sim_metric
self.keep_hard = which_to_keep == "hard"
self.kmeans_with_cos_dist = kmeans_with_cos_dist
self.partition_size = partition_size
self.sort_clusters = sort_clusters
self.logger = self._setup_logger(logger)
self.profile_dir = profile_dir
if not os.path.exists(self.clustering_output_dir):
expand_outdir_and_mkdir(self.clustering_output_dir)
else:
self.logger.warning(
f"Clustering output directory {self.clustering_output_dir} already exists and will be overwritten"
)
def _setup_logger(self, logger):
if isinstance(logger, str):
return create_logger(
rank=0,
name="SemanticClusterLevelDedup",
log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"),
log_level=logging.INFO,
stdout=True,
)
else:
return logger
def __call__(self, embeddings_dataset: DocumentDataset):
embeddings_df = embeddings_dataset.df
if self.embedding_col not in embeddings_df.columns:
raise ValueError(
f"Expected embedding column '{self.embedding_col}'"
f" to be in dataset. Only found columns {embeddings_df.columns}"
)
with performance_report_if_with_ts_suffix(self.profile_dir, "clustering-model"):
embeddings_df = embeddings_df[[self.id_col, self.embedding_col]]
embeddings_df = embeddings_df.to_backend("pandas").persist()
embeddings_df = embeddings_df.repartition(
partition_size=self.partition_size
)
embeddings_df = embeddings_df.to_backend("cudf")
cupy_darr = embeddings_df.map_partitions(
get_embedding_ar, self.embedding_col, meta=cp.ndarray([1, 1])
)
cupy_darr.compute_chunk_sizes()
t0 = time.time()
kmeans = KMeans(n_clusters=self.n_clusters, max_iter=self.max_iter)
self.logger.info("KMeans starting fit")
kmeans.fit(cupy_darr)
self.logger.info("KMeans fit complete")
self.logger.info(f"Time taken for KMeans Fit: {time.time() - t0}")
self.logger.info(
"Computing nearest centroids + distance to centers using kmeans.predict"
)
t0 = time.time()
nearest_cents = kmeans.predict(cupy_darr)
self.logger.info(f"Time taken for KMeans Predict: {time.time() - t0}")
t0 = time.time()
embeddings_df["nearest_cent"] = nearest_cents.astype(np.int32)
del nearest_cents
meta_df = embeddings_df._meta.copy()
meta_df["dist_to_cent"] = cp.zeros(1)
embeddings_df = embeddings_df.map_partitions(
add_dist_to_cents,
embedding_col=self.embedding_col,
centroids=kmeans.cluster_centers_,
meta=meta_df,
)
embeddings_df = embeddings_df.reset_index(drop=True)
centroids = kmeans.cluster_centers_
kmeans_centroids_file = os.path.join(
self.clustering_output_dir, "kmeans_centroids.npy"
)
np.save(kmeans_centroids_file, centroids)
self.logger.info("Saving centroids complete")
del kmeans, cupy_darr, centroids
clustering_output_dir = os.path.join(
self.clustering_output_dir, "embs_by_nearest_center"
)
if os.path.exists(clustering_output_dir):
self.logger.warning(
f"Output directory {clustering_output_dir} already exists and will be overwritten"
)
shutil.rmtree(clustering_output_dir)
embeddings_df.to_parquet(
clustering_output_dir,
index=False,
partition_on="nearest_cent",
)
self.logger.info(
f"Time taken for Assigning distance to each embedding : {time.time() - t0} "
f"and output written at {clustering_output_dir}"
)
del embeddings_df
if self.sort_clusters:
assign_and_sort_clusters(
id_col=self.id_col,
kmeans_centroids_file=kmeans_centroids_file,
nearest_cent_dir=clustering_output_dir,
output_sorted_clusters_dir=os.path.join(
self.clustering_output_dir, "sorted"
),
embedding_col=self.embedding_col,
sim_metric=self.sim_metric,
keep_hard=self.keep_hard,
kmeans_with_cos_dist=self.kmeans_with_cos_dist,
cluster_ids=range(self.n_clusters),
logger=self.logger,
profile_dir=self.profile_dir,
)
fps = [
os.path.join(clustering_output_dir, file_name)
for file_name in os.listdir(clustering_output_dir)
]
embeddings_df = dd.from_map(cudf.read_parquet, fps)
return DocumentDataset(embeddings_df)
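# Illustrative usage sketch for ClusteringModel (not part of the original
# module). The cluster count and output directory are placeholder assumptions;
# the input is the embeddings DocumentDataset produced by EmbeddingCreator.
def _example_clustering_model(embeddings_dataset: DocumentDataset):  # pragma: no cover
    clustering_model = ClusteringModel(
        id_column="id",
        max_iter=100,
        n_clusters=1000,
        clustering_output_dir="./cache/clustering_results",
    )
    # Fits KMeans on the embeddings, writes per-cluster parquet files under
    # "embs_by_nearest_center", and (with sort_clusters=True, the default)
    # sorted clusters under "sorted" inside clustering_output_dir.
    return clustering_model(embeddings_dataset)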
class SemanticClusterLevelDedup:
def __init__(
self,
n_clusters: int,
emb_by_clust_dir: str,
sorted_clusters_dir: str,
id_column: str,
id_column_type: str,
which_to_keep: str,
output_dir: str,
embedding_col: str = "embeddings",
logger: Union[logging.Logger, str] = "./",
profile_dir: Optional[str] = None,
) -> None:
"""
Initialize the SemanticClusterLevelDedup class.
Args:
n_clusters (int): Number of clusters.
emb_by_clust_dir (str): Directory containing embeddings by cluster.
sorted_clusters_dir (str): Directory containing sorted clusters.
id_column (str): Column name for IDs.
id_column_type (str): Data type of the ID column.
which_to_keep (str): Strategy for which duplicate to keep.
output_dir (str): Directory to save output files.
embedding_col (str): Column where the embeddings are stored.
logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
profile_dir (str): If specified, the directory to which Dask profile reports are written. Default is None.
"""
self.n_clusters = n_clusters
self.emb_by_clust_dir = emb_by_clust_dir
self.sorted_clusters_dir = sorted_clusters_dir
self.id_col = id_column
self.id_col_type = id_column_type
self.which_to_keep = which_to_keep
self.output_dir = output_dir
self.semdedup_pruning_tables_dir = os.path.join(
output_dir, "semdedup_pruning_tables"
)
self.computed_semantic_match_dfs = False
self.embedding_col = embedding_col
self.logger = self._setup_logger(logger)
self.profile_dir = profile_dir
def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger:
"""
Set up the logger.
Args:
logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
Returns:
logging.Logger: Configured logger.
"""
if isinstance(logger, str):
return create_logger(
rank=0,
name="SemanticClusterLevelDedup",
log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"),
log_level=logging.INFO,
stdout=True,
)
else:
return logger
def compute_semantic_match_dfs(
self, eps_list: Optional[List[float]] = None
) -> None:
"""
Compute semantic match dataframes for clusters.
Args:
eps_list (Optional[List[float]]): List of epsilon values for clustering.
"""
if eps_list is None:
eps_list1 = [1.0e-2, 1.0e-3, 1.0e-4, 1.0e-5, 1.0e-6]
eps_list2 = [0.1 + x * 0.005 for x in range(34)]
eps_list = eps_list1 + eps_list2
if os.path.exists(self.semdedup_pruning_tables_dir):
self.logger.info(
f"Removing existing directory {self.semdedup_pruning_tables_dir}"
)
shutil.rmtree(self.semdedup_pruning_tables_dir)
expand_outdir_and_mkdir(self.semdedup_pruning_tables_dir)
t0 = time.time()
with performance_report_if_with_ts_suffix(
self.profile_dir, "semantic-match-compute"
):
tasks = db.from_sequence(
list(range(self.n_clusters)), npartitions=self.n_clusters
).map(
lambda cluster_id: get_semantic_matches_per_cluster(
cluster_id=cluster_id,
emb_by_clust_dir=self.emb_by_clust_dir,
sorted_clusters_dir=self.sorted_clusters_dir,
id_col=self.id_col,
id_col_type=self.id_col_type,
eps_list=eps_list,
output_dir=self.semdedup_pruning_tables_dir,
embedding_col=self.embedding_col,
which_to_keep=self.which_to_keep,
)
)
tasks.compute()
self.logger.info(
f"Time taken for Computing Semantic Matches : {time.time() - t0}"
)
self.computed_semantic_match_dfs = True
def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset:
"""
Extract deduplicated data based on epsilon value.
Args:
eps_to_extract (float): Epsilon threshold for extracting deduplicated data.
Returns:
DocumentDataset: Dataset containing deduplicated documents.
"""
if not self.computed_semantic_match_dfs:
raise ValueError(
"Run compute_semantic_match_dfs before calling extract_dedup_data"
)
output_summary_file = os.path.join(
self.output_dir, f"dedup_summary_{eps_to_extract}.csv"
)
output_parquet_path = os.path.join(
self.output_dir, f"unique_ids_{eps_to_extract}.parquet"
)
extract_dedup_data(
eps=eps_to_extract,
n_clusters=self.n_clusters,
id_col=self.id_col,
id_col_type=self.id_col_type,
sorted_clusters_dir=self.sorted_clusters_dir,
semdedup_pruning_tables_dir=self.semdedup_pruning_tables_dir,
output_summary_file=output_summary_file,
output_parquet_path=output_parquet_path,
logger=self.logger,
profile_dir=self.profile_dir,
)
fps = [
os.path.join(output_parquet_path, file_name)
for file_name in os.listdir(output_parquet_path)
]
return DocumentDataset.read_parquet(fps, backend="cudf")
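# Illustrative usage sketch for SemanticClusterLevelDedup (not part of the
# original module). The directories mirror the layout written by
# ClusteringModel above; the epsilon value is a placeholder assumption.
def _example_cluster_level_dedup():  # pragma: no cover
    cluster_dedup = SemanticClusterLevelDedup(
        n_clusters=1000,
        emb_by_clust_dir="./cache/clustering_results/embs_by_nearest_center",
        sorted_clusters_dir="./cache/clustering_results/sorted",
        id_column="id",
        id_column_type="int",
        which_to_keep="hard",
        output_dir="./cache/clustering_results",
    )
    # compute_semantic_match_dfs must be called before extract_dedup_data.
    cluster_dedup.compute_semantic_match_dfs()
    return cluster_dedup.extract_dedup_data(eps_to_extract=0.07)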
class SemDedup:
def __init__(
self,
config: SemDedupConfig,
input_column: str = "text",
id_column: str = "id",
id_column_type: str = "int",
logger: Union[logging.Logger, str] = "./",
) -> None:
"""
Initialize the SemDedup class.
Args:
config (SemDedupConfig): Configuration for SemDedup.
logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
"""
self.config = config
self.logger = logger
cache_dir = config.cache_dir
self.embedding_creator = EmbeddingCreator(
embedding_model_name_or_path=config.embedding_model_name_or_path,
embedding_batch_size=config.embedding_batch_size,
input_column=input_column,
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
logger=logger,
profile_dir=self.config.profile_dir,
)
self.clustering_model = ClusteringModel(
id_column=id_column,
max_iter=config.max_iter,
n_clusters=config.n_clusters,
clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc),
logger=logger,
profile_dir=self.config.profile_dir,
)
self.semantic_cluster_dedup = SemanticClusterLevelDedup(
n_clusters=config.n_clusters,
emb_by_clust_dir=os.path.join(
cache_dir, config.clustering_save_loc, "embs_by_nearest_center"
),
sorted_clusters_dir=os.path.join(
cache_dir, config.clustering_save_loc, "sorted"
),
id_column=id_column,
id_column_type=id_column_type,
which_to_keep=config.which_to_keep,
output_dir=os.path.join(cache_dir, config.clustering_save_loc),
logger=logger,
profile_dir=self.config.profile_dir,
)
self.eps_thresholds = config.eps_thresholds
self.eps_to_extract = config.eps_to_extract
def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
"""
Execute the SemDedup process.
Args:
dataset (DocumentDataset): Input dataset for deduplication.
Returns:
DocumentDataset: Deduplicated dataset.
"""
embeddings_dataset = self.embedding_creator(dataset)
self.clustering_model(embeddings_dataset)
self.semantic_cluster_dedup.compute_semantic_match_dfs(self.eps_thresholds)
return self.semantic_cluster_dedup.extract_dedup_data(
eps_to_extract=self.eps_to_extract
)
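# Illustrative end-to-end sketch (not part of the original module). It assumes
# the remaining SemDedupConfig fields have usable defaults and that a GPU
# cluster is available; paths and column names are placeholder assumptions.
def _example_semdedup_pipeline():  # pragma: no cover
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster

    client = Client(LocalCUDACluster())  # GPU-backed Dask cluster
    config = SemDedupConfig(cache_dir="./semdedup_cache")
    dataset = DocumentDataset.read_parquet("./input_data", backend="cudf")
    sem_dedup = SemDedup(config=config, input_column="text", id_column="id")
    # Runs embedding creation, clustering, and cluster-level dedup, and returns
    # a DocumentDataset of the document ids to keep at config.eps_to_extract.
    deduplicated_ids = sem_dedup(dataset)
    client.close()
    return deduplicated_ids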