Skip to content

Commit

Permalink
Merge pull request #1189 from Wanglongzhi2001/master
Browse files Browse the repository at this point in the history
feat: add the implementation of class_weight in model.fit
  • Loading branch information
Oceania2018 authored Oct 6, 2023
2 parents f16902d + a1c64ef commit 43c3705
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 10 deletions.
70 changes: 69 additions & 1 deletion src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using Tensorflow.Keras.Utils;
using Tensorflow.Util;
using Tensorflow.Framework;

namespace Tensorflow.Keras.Engine.DataAdapters
{
Expand All @@ -24,6 +26,7 @@ public class DataHandler
long _steps_per_execution_value;
int _initial_epoch => args.InitialEpoch;
int _epochs => args.Epochs;
NDArray _sample_weight => args.SampleWeight;
IVariableV1 _steps_per_execution;

public DataHandler(DataHandlerArgs args)
Expand Down Expand Up @@ -75,10 +78,75 @@ public DataHandler(DataHandlerArgs args)
}

_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
_step_increment = _steps_per_execution_value - 1;
_insufficient_data = false;
_configure_dataset_and_inferred_steps(args.X, args.ClassWeight);
}

void _configure_dataset_and_inferred_steps(Tensors x, Dictionary<int, float> class_weight)
{
if (_dataset == null)
{
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
}

if (class_weight != null)
{
_dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
}
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
}


Func<Tensors, Tensors> _make_class_weight_map_fn(Dictionary<int, float> class_weight)
{
var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
var expected_class_ids = range(class_ids[0], class_ids[class_ids.Count - 1] + 1);
if (!class_ids.SequenceEqual(expected_class_ids))
{
throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less "+
$"than the number of classes, found {class_weight}");
}

var class_weight_list = new List<float>();
foreach (var class_id in class_ids)
{
class_weight_list.Add(class_weight[class_id]);
}
var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());

Func<Tensors, Tensors> _class_weight_map_fn = (Tensors data) =>
{
var x = data[0];
var y = data[1];
var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);

if (y.shape.rank > 2)
{
throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
}

var y_classes = smart_module.smart_cond(
y.shape.rank == 2 && y.shape[1] > 1,
() => math_ops.argmax(y, dimension: 1),
() => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));

var cw = array_ops.gather(class_weight_tensor, y_classes);
if (sw != null)
{
cw = tf.cast(cw, sw.dtype);
cw *= sw;
}
else
{
sw = cw;
}
return new Tensors { x, y, sw };
};

return _class_weight_map_fn;
}

long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
Expand Down
13 changes: 11 additions & 2 deletions src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,20 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
}


Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x,y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}

Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
{
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}
Expand Down
11 changes: 4 additions & 7 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ public ICallback fit(NDArray x, NDArray y,
((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}

// TODO(Wanglongzhi2001)
if (class_weight != null)
{
throw new NotImplementedException("class_weight is not implemented");
}

var data_handler = new DataHandler(new DataHandlerArgs
{
X = x,
Expand All @@ -78,6 +72,7 @@ public ICallback fit(NDArray x, NDArray y,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Expand Down Expand Up @@ -126,11 +121,12 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
{
X = new Tensors(x.ToArray()),
Y = y,
SampleWeight = sample_weight,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Expand Down Expand Up @@ -174,6 +170,7 @@ public History fit(IDatasetV2 dataset,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Expand Down

0 comments on commit 43c3705

Please sign in to comment.