Skip to content

Commit

Permalink
Add none to aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Nov 23, 2024
1 parent e32a8ae commit 759aaf6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions keras/src/backend/common/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def __init__(
"cannot contain character `/`. "
f"Received: name={name}"
)
if aggregation not in ("mean", "sum", "only_first_replica"):
if aggregation not in ("none", "mean", "sum", "only_first_replica"):
raise ValueError(
"Invalid valid for argument `aggregation`. Expected "
"one of {'mean', 'sum', 'only_first_replica'}. "
"one of {'none', 'mean', 'sum', 'only_first_replica'}. "
f"Received: aggregation={aggregation}"
)
self.name = name
Expand Down
3 changes: 2 additions & 1 deletion keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,12 @@ def _write_object_proto(self, proto, options):

def _map_aggregation(self, aggregation):
mapping = {
"none": tf.VariableAggregation.NONE,
"sum": tf.VariableAggregation.SUM,
"mean": tf.VariableAggregation.MEAN,
"only_first_replica": tf.VariableAggregation.ONLY_FIRST_REPLICA,
}
return mapping.get(aggregation, tf.VariableAggregation.NONE)
return mapping[aggregation]


def convert_to_tensor(x, dtype=None, sparse=None):
Expand Down
2 changes: 2 additions & 0 deletions keras/src/optimizers/loss_scale_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ def build(self, var_list):
shape=(),
dtype="int",
initializer=initializers.Zeros(),
aggregation="none",
name="step_counter",
)
self.dynamic_scale = self.add_variable(
shape=(),
dtype="float32",
initializer=initializers.Constant(self.initial_scale),
aggregation="none",
name="dynamic_scale",
)
self.inner_optimizer.build(var_list)
Expand Down

0 comments on commit 759aaf6

Please sign in to comment.