Skip to content

Commit

Permalink
Merge pull request #1180 from Wanglongzhi2001/master
Browse files Browse the repository at this point in the history
fix: fix EarlyStopping
  • Loading branch information
Oceania2018 authored Sep 21, 2023
2 parents 725ec1e + f809f6e commit 3811e4e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 22 deletions.
6 changes: 6 additions & 0 deletions src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,11 @@ public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string?

/// <summary>
/// Element-wise addition of two <see cref="NDArray"/> operands, dispatched
/// through <c>math_ops.add</c> (broadcasting semantics follow the backend op).
/// </summary>
[AutoNumPy]
public static NDArray add(NDArray x, NDArray y) => new NDArray(math_ops.add(x, y));

/// <summary>
/// Element-wise comparison <c>x &gt; y</c> via <c>tf.greater</c>; returns a
/// boolean <see cref="NDArray"/>. NumPy-style counterpart of <c>np.greater</c>.
/// </summary>
[AutoNumPy]
public static NDArray greater(NDArray x, NDArray y) => new NDArray(tf.greater(x, y));

/// <summary>
/// Element-wise comparison <c>x &lt; y</c> via <c>tf.less</c>; returns a
/// boolean <see cref="NDArray"/>. NumPy-style counterpart of <c>np.less</c>.
/// </summary>
[AutoNumPy]
public static NDArray less(NDArray x, NDArray y) => new NDArray(tf.less(x, y));
}
}
64 changes: 42 additions & 22 deletions src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ public class EarlyStopping: ICallback
string _monitor;
string _mode;
bool _restore_best_weights;
List<IVariableV1>? _best_weights;
List<NDArray>? _best_weights;
CallbackParams _parameters;
Func<NDArray, NDArray, NDArray> _monitor_op;

public Dictionary<string, List<float>>? history { get; set; }
// user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model
public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0,
Expand All @@ -38,17 +40,49 @@ public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", floa
_min_delta = Math.Abs(min_delta);
_restore_best_weights = restore_best_weights;
_mode = mode;
if (mode != "auto" && mode != "min" && mode != "max")

if (_mode != "auto" && _mode != "min" && _mode != "max")
{
Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode.");
_mode = "auto";
}

if (_mode == "min")
{
_monitor_op = np.less;
}
else if (_mode == "max")
{
_monitor_op = np.greater;
}
else
{
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
{
_monitor_op = np.greater;
}
else
{
_monitor_op = np.less;
}
}

if (_monitor_op == np.greater)
{
Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode);
_min_delta *= 1;
}
else
{
_min_delta *= -1;
}
}
/// <summary>
/// Resets the early-stopping bookkeeping at the start of a training run so
/// a reused callback instance does not carry state from a previous fit.
/// </summary>
public void on_train_begin()
{
    _wait = 0;
    _stopped_epoch = 0;
    // Seed the best-seen value with the worst possible value for the chosen
    // direction: +Inf when lower is better (np.less), -Inf when higher is
    // better (np.greater). NOTE(review): the merged diff also carried a stale
    // trailing `_best = (float)np.Inf;` which clobbered this direction-aware
    // seed and broke "max"-mode monitors (e.g. accuracy); it is removed here.
    _best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf;
    _best_weights = null;
    _best_epoch = 0;
}

public void on_epoch_begin(int epoch)
Expand All @@ -74,7 +108,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
// Restore the weights after first epoch if no progress is ever made.
if (_restore_best_weights && _best_weights == null)
{
_best_weights = _parameters.Model.Weights;
_best_weights = _parameters.Model.get_weights();
}
_wait += 1;

Expand All @@ -83,7 +117,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
_best = current;
_best_epoch = epoch;
if (_restore_best_weights)
_best_weights = _parameters.Model.TrainableWeights;
_best_weights = _parameters.Model.get_weights();
// Only restart wait if we beat both the baseline and our previous best.
if (_baseline == 0f || _is_improvement(current, _baseline))
_wait = 0;
Expand All @@ -99,7 +133,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
{
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
}
_parameters.Model.Weights = _best_weights;
_parameters.Model.set_weights(_best_weights);
}
}
}
Expand Down Expand Up @@ -131,21 +165,7 @@ float get_monitor_value(Dictionary<string, float> logs)
}
/// <summary>
/// Returns true when <paramref name="monitor_value"/> improves on
/// <paramref name="reference_value"/> in the direction selected in the
/// constructor: <c>np.less</c> for "min"-style monitors (loss),
/// <c>np.greater</c> for "max"-style monitors (accuracy/auc).
/// </summary>
/// <remarks>
/// <c>_min_delta</c> was already given the correct sign in the constructor
/// (+ for greater, - for less), so subtracting it from the candidate value
/// tightens the comparison by the required minimum change in both modes.
/// The pre-refactor inline mode/monitor branching has been removed: it was
/// dead code after this return, and its <c>&gt;=</c> comparison disagreed
/// with the strict <c>np.greater</c> used by <c>_monitor_op</c>.
/// </remarks>
public bool _is_improvement(float monitor_value, float reference_value)
{
    return _monitor_op(monitor_value - _min_delta, reference_value);
}

public void on_test_end(Dictionary<string, float> logs)
Expand Down

0 comments on commit 3811e4e

Please sign in to comment.