Add Amsgrad #137

Draft: wants to merge 25 commits into base: main

Changes from all commits (25 commits)
Empty file modified: bench/dlrm_s_criteo_kaggle.sh (mode changed 100755 → 100644)

dlrm_s_pytorch.py: 105 changes (88 additions, 17 deletions)
@@ -93,8 +93,10 @@

from torch.optim.lr_scheduler import _LRScheduler


exc = getattr(builtins, "IOError", "FileNotFoundError")


class LRPolicyScheduler(_LRScheduler):
def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
self.num_warmup_steps = num_warmup_steps
@@ -180,19 +182,19 @@ def create_emb(self, m, ln):
# construct embedding operator
if self.qr_flag and n > self.qr_threshold:
EE = QREmbeddingBag(n, m, self.qr_collisions,
operation=self.qr_operation, mode="sum", sparse=True)
operation=self.qr_operation, mode="sum", sparse=self.sparse)
elif self.md_flag:
base = max(m)
_m = m[i] if n > self.md_threshold else base
EE = PrEmbeddingBag(n, _m, base)
EE = PrEmbeddingBag(n, _m, base, sparse=self.sparse)
# use np initialization as below for consistency...
W = np.random.uniform(
low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, _m)
).astype(np.float32)
EE.embs.weight.data = torch.tensor(W, requires_grad=True)

else:
EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
EE = nn.EmbeddingBag(n, m, mode="sum", sparse=self.sparse)

# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
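For context, a minimal standalone sketch (assuming a recent PyTorch release; not part of this diff) of why the hard-coded sparse=True above has to become self.sparse: dense Adam, including its AMSGrAD variant, rejects the sparse gradients that sparse embedding bags produce.

import torch
import torch.nn as nn

# sparse=True makes the embedding weight gradient a sparse tensor...
emb = nn.EmbeddingBag(10, 4, mode="sum", sparse=True)
out = emb(torch.tensor([1, 2, 3]), torch.tensor([0]))  # one bag of three indices
out.sum().backward()                                   # emb.weight.grad is sparse
opt = torch.optim.Adam(emb.parameters(), lr=0.01, amsgrad=True)
try:
    opt.step()
except RuntimeError as err:
    print(err)  # ...which dense Adam/AMSGrad refuses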
@@ -229,6 +231,8 @@ def __init__(
qr_threshold=200,
md_flag=False,
md_threshold=200,
emb_assignments=None,
sparse=True
):
super(DLRM_Net, self).__init__()

@@ -242,13 +246,15 @@ def __init__(

# save arguments
self.ndevices = ndevices
self.emb_assignments = emb_assignments
self.output_d = 0
self.parallel_model_batch_size = -1
self.parallel_model_is_not_prepared = True
self.arch_interaction_op = arch_interaction_op
self.arch_interaction_itself = arch_interaction_itself
self.sync_dense_params = sync_dense_params
self.loss_threshold = loss_threshold
self.sparse = sparse
# create variables for QR embedding if applicable
self.qr_flag = qr_flag
if self.qr_flag:
@@ -366,6 +372,25 @@ def sequential_forward(self, dense_x, lS_o, lS_i):

return z


def distribute_embs(self, ndevices):
# distribute embeddings (model parallelism)
t_list = []
for k, emb in enumerate(self.emb_l):
d = torch.device(
"cuda:" + str(self.emb_assignments[k]))
emb.to(d)
t_list.append(emb.to(d))
self.emb_l = nn.ModuleList(t_list)


def distribute_model(self, device_ids, batch_size):
# replicate mlp (data parallelism)
self.bot_l_replicas = replicate(self.bot_l, device_ids)
self.top_l_replicas = replicate(self.top_l, device_ids)
self.parallel_model_batch_size = batch_size


def parallel_forward(self, dense_x, lS_o, lS_i):
### prepare model (overwrite) ###
# WARNING: # of devices must be >= batch size in parallel_forward call
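As a standalone sketch of the placement pattern distribute_embs implements (hypothetical table sizes and assignment list; not part of this diff): each table is moved to the device assigned to it, instead of the previous fixed k % ndevices round-robin.

import torch
import torch.nn as nn

# Hypothetical tables and a hypothetical assignment list
# (e.g. produced by the balancing heuristic added later in this PR).
emb_l = nn.ModuleList([nn.EmbeddingBag(n, 16, mode="sum") for n in (1000, 50, 2000)])
assignments = [0, 1, 0]

if torch.cuda.device_count() >= 2:
    emb_l = nn.ModuleList(
        [emb.to(torch.device("cuda:" + str(dev))) for emb, dev in zip(emb_l, assignments)]
    )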
@@ -379,18 +404,11 @@ def parallel_forward(self, dense_x, lS_o, lS_i):

if self.parallel_model_is_not_prepared or self.sync_dense_params:
# replicate mlp (data parallelism)
self.bot_l_replicas = replicate(self.bot_l, device_ids)
self.top_l_replicas = replicate(self.top_l, device_ids)
self.parallel_model_batch_size = batch_size
self.distribute_model(device_ids, batch_size)

if self.parallel_model_is_not_prepared:
# distribute embeddings (model parallelism)
t_list = []
for k, emb in enumerate(self.emb_l):
d = torch.device("cuda:" + str(k % ndevices))
emb.to(d)
t_list.append(emb.to(d))
self.emb_l = nn.ModuleList(t_list)
self.distribute_embs(ndevices)
self.parallel_model_is_not_prepared = False

### prepare input (overwrite) ###
@@ -404,7 +422,8 @@ def parallel_forward(self, dense_x, lS_o, lS_i):
t_list = []
i_list = []
for k, _ in enumerate(self.emb_l):
d = torch.device("cuda:" + str(k % ndevices))
dev_id = self.emb_assignments
d = torch.device("cuda:" + str(self.emb_assignments[k]))
t_list.append(lS_o[k].to(d))
i_list.append(lS_i[k].to(d))
lS_o = t_list
@@ -529,6 +548,7 @@ def dash_separated_floats(value):
parser.add_argument("--qr-threshold", type=int, default=200)
parser.add_argument("--qr-operation", type=str, default="mult")
parser.add_argument("--qr-collisions", type=int, default=4)
parser.add_argument("--use-emb-distrib-heuristic", action='store_true', default=False)
# activations and loss
parser.add_argument("--activation-function", type=str, default="relu")
parser.add_argument("--loss-function", type=str, default="mse") # or bce or wbce
@@ -560,6 +580,7 @@ def dash_separated_floats(value):
The Terabyte dataset can be multiprocessed in an environment \
with more than 24 CPU cores and at least 1 TB of memory.")
# training
parser.add_argument("--solver", type=str, default="sgd")
parser.add_argument("--mini-batch-size", type=int, default=1)
parser.add_argument("--nepochs", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=0.01)
@@ -573,6 +594,7 @@ def dash_separated_floats(value):
# gpu
parser.add_argument("--use-gpu", action="store_true", default=False)
# debugging and profiling
parser.add_argument("--print-num-emb-params", action="store_true", default=False)
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
parser.add_argument("--test-mini-batch-size", type=int, default=-1)
@@ -598,6 +620,7 @@ def dash_separated_floats(value):
parser.add_argument("--lr-num-decay-steps", type=int, default=0)
args = parser.parse_args()


if args.mlperf_logging:
print('command line args: ', json.dumps(vars(args)))

@@ -722,7 +745,17 @@ def dash_separated_floats(value):
d0=m_spa,
round_dim=args.md_round_dims
).tolist()
print(m_spa)

if args.print_num_emb_params:
num_params = int(sum(torch.tensor(ln_emb) * torch.tensor(m_spa)))
if isinstance(m_spa, list):
_m_spa = torch.tensor(m_spa)
has_proj = _m_spa < max(_m_spa)
num_params += int(torch.sum(has_proj*_m_spa)*max(m_spa))
print(f"Num of params in embedding layer {num_params}")
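A quick sanity check of that count with made-up sizes: tables with rows (100, 200, 300) and mixed dims (8, 16, 16) hold 100*8 + 200*16 + 300*16 = 8800 embedding parameters, and the reduced 8-dim table adds an 8x16 projection, giving 8928 in total.

import torch

ln_emb, m_spa = [100, 200, 300], [8, 16, 16]                          # made-up sizes
num_params = int(sum(torch.tensor(ln_emb) * torch.tensor(m_spa)))     # 8800
_m_spa = torch.tensor(m_spa)
has_proj = _m_spa < max(_m_spa)                                       # only the 8-dim table projects
num_params += int(torch.sum(has_proj * _m_spa) * max(m_spa))          # + 8 * 16 = 128
print(num_params)                                                     # 8928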


# test prints (model arch)
if args.debug_mode:
print("model arch:")
@@ -757,7 +790,7 @@ def dash_separated_floats(value):

print("data (inputs and targets):")
for j, (X, lS_o, lS_i, T) in enumerate(train_ld):
# early exit if nbatches was set by the user and has been exceeded
# early exit if nbatches was set by the user and has been exceeded
if nbatches > 0 and j >= nbatches:
break

@@ -777,6 +810,33 @@ def dash_separated_floats(value):

ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1

if args.use_emb_distrib_heuristic:

def emb_distrib_heuristic(Rows, Dims, ndevices):
#inputs: 2 parallel lists (Rows, Dims) and an int (ndevices)
#Rows--list: i-th entry is # of rows in i-th emb table (int)
#Dims--list: i-th entry is dim of rows in i-th table (int)
#output: a list of balanced table assignments to device
num_params = torch.tensor(Rows) * torch.tensor(Dims)
cur_load = torch.zeros(ndevices)
assignments = [0]*len(Rows)
val, idx = torch.sort(num_params, descending=True)
for i,v in enumerate(val):
a = torch.argmin(cur_load)
assignments[idx[i]] = int(a)
cur_load[a] += v
return assignments

_m_spa = m_spa if isinstance(
m_spa, list) else [m_spa]*len(ln_emb)

assignments = emb_distrib_heuristic(ln_emb,_m_spa,ndevices)

else:
assignments = [k % ndevices for k in range(len(ln_emb))]
print(assignments)
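As a standalone illustration with made-up table sizes (not part of this diff), the greedy heuristic sorts tables by parameter count and repeatedly gives the largest remaining table to the least-loaded device:

import torch

Rows = [1_000_000, 500_000, 10_000, 2_000_000]   # hypothetical rows per table
Dims = [16, 16, 16, 16]                          # hypothetical dims per table
ndevices = 2

num_params = torch.tensor(Rows) * torch.tensor(Dims)
cur_load = torch.zeros(ndevices)
assignments = [0] * len(Rows)
val, idx = torch.sort(num_params, descending=True)
for i, v in enumerate(val):
    a = torch.argmin(cur_load)        # least-loaded device so far
    assignments[idx[i]] = int(a)
    cur_load[a] += v

print(assignments)   # [1, 1, 1, 0]: the 2M-row table sits alone on device 0
print(cur_load)      # roughly balanced parameter counts per device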


### construct the neural network specified above ###
# WARNING: to obtain exactly the same initialization for
# the weights we need to start from the same random seed.
@@ -799,6 +859,8 @@ def dash_separated_floats(value):
qr_threshold=args.qr_threshold,
md_flag=args.md_flag,
md_threshold=args.md_threshold,
emb_assignments=assignments,
sparse=False if args.solver == 'amsgrad' else True
)
# test prints
if args.debug_mode:
@@ -828,10 +890,20 @@ def dash_separated_floats(value):

if not args.inference_only:
# specify the optimizer algorithm
optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate)
if args.solver == 'sgd':
optimizer = torch.optim.SGD(
dlrm.parameters(), lr=args.learning_rate)
elif args.solver == 'amsgrad':
optimizer = torch.optim.Adam(
dlrm.parameters(), lr=args.learning_rate, amsgrad=True)
else:
raise ValueError(
f'Solver {args.solver} is not supported. Select sgd or amsgrad')

lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
args.lr_num_decay_steps)
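For reference, a minimal sketch (bias correction omitted; not the PR's or PyTorch's actual implementation) of what amsgrad=True changes relative to plain Adam: the denominator uses the running maximum of the second-moment estimate, v_hat_t = max(v_hat_{t-1}, v_t), so the effective step size cannot grow back when v_t shrinks.

import torch

def amsgrad_step(p, grad, state, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    # state = (exp_avg, exp_avg_sq, max_exp_avg_sq), each initialized to torch.zeros_like(p)
    exp_avg, exp_avg_sq, max_exp_avg_sq = state
    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])                 # m_t
    exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])    # v_t
    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)         # v_hat_t
    p.data.addcdiv_(exp_avg, max_exp_avg_sq.sqrt().add_(eps), value=-lr)  # p -= lr * m_t / (sqrt(v_hat_t) + eps)

With the new flag, this path would be selected with something like: python dlrm_s_pytorch.py --solver amsgrad ... (remaining arguments elided).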


### main loop ###
def time_wrap(use_gpu):
if use_gpu:
@@ -1049,8 +1121,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
str_run_type = "inference" if args.inference_only else "training"
print(
"Finished {} it {}/{} of epoch {}, {:.2f} ms/it, ".format(
str_run_type, j + 1, nbatches, k, gT
)
str_run_type, j + 1, nbatches, k, gT)
+ "loss {:.6f}, accuracy {:3.3f} %".format(gL, gA * 100)
)
# Uncomment the line below to print out the total time with overhead
tricks/md_embedding_bag.py: 6 changes (3 additions, 3 deletions)
@@ -57,14 +57,14 @@ def alpha_power_rule(n, alpha, d0=None, B=None):


def pow_2_round(dims):
return 2 ** torch.round(torch.log2(dims.type(torch.float)))
return (2 ** torch.round(torch.log2(dims.type(torch.float)))).long()


class PrEmbeddingBag(nn.Module):
def __init__(self, num_embeddings, embedding_dim, base_dim):
def __init__(self, num_embeddings, embedding_dim, base_dim, sparse=False):
super(PrEmbeddingBag, self).__init__()
self.embs = nn.EmbeddingBag(
num_embeddings, embedding_dim, mode="sum", sparse=True)
num_embeddings, embedding_dim, mode="sum", sparse=sparse)
torch.nn.init.xavier_uniform_(self.embs.weight)
if embedding_dim < base_dim:
self.proj = nn.Linear(embedding_dim, base_dim, bias=False)
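A small standalone illustration of the .long() cast in pow_2_round (made-up dims; not part of this diff): the power-of-two rounding otherwise returns float values, which downstream code cannot use directly as integer embedding dimensions.

import torch

dims = torch.tensor([12, 50, 100])
rounded = 2 ** torch.round(torch.log2(dims.type(torch.float)))
print(rounded)          # tensor([ 16.,  64., 128.])  float dtype before the fix
print(rounded.long())   # tensor([ 16,  64, 128])     integer dims, as the fix returns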