Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add the implementation of sample_weight in model.fit #1187

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
Expand All @@ -16,5 +17,7 @@ public class DataAdapterArgs: IKerasConfig
public int Worker { get; set; }
public bool UseMultiprocessing { get; set; }
public IModel Model { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
Expand All @@ -18,5 +19,7 @@ public class DataHandlerArgs: IKerasConfig
public bool UseMultiprocessing { get; set; } = false;
public IModel Model { get; set; }
public IVariableV1 StepsPerExecution { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}
11 changes: 9 additions & 2 deletions src/TensorFlowNET.Core/Keras/Engine/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine;

Expand All @@ -22,8 +23,10 @@ ICallback fit(NDArray x, NDArray y,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
Expand All @@ -35,8 +38,10 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
Expand All @@ -63,6 +68,8 @@ void load_weights(string filepath,
Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,

int steps = -1,
int max_queue_size = 10,
int workers = 1,
Expand Down
66 changes: 66 additions & 0 deletions src/TensorFlowNET.Core/Util/Data.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using Tensorflow.NumPy;

namespace Tensorflow.Util
{
/// <summary>
/// ValidationDataPack is used to pass validation data to the fit method.
/// It can receive a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
/// </summary>
public class ValidationDataPack
{
    public NDArray val_x;
    public NDArray val_y;
    // Optional per-sample validation weights; stays null when the caller supplies none.
    public NDArray val_sample_weight = null;

    public ValidationDataPack((NDArray, NDArray) validation_data)
    {
        this.val_x = validation_data.Item1;
        this.val_y = validation_data.Item2;
    }

    public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
    {
        this.val_x = validation_data.Item1;
        this.val_y = validation_data.Item2;
        this.val_sample_weight = validation_data.Item3;
    }

    public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
    {
        // NOTE(review): only the first input array is kept — multi-input validation data
        // is not fully supported here; confirm against callers before relying on it.
        // First() avoids materializing the whole sequence the way ToArray()[0] did.
        this.val_x = validation_data.Item1.First();
        this.val_y = validation_data.Item2;
    }

    public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
    {
        this.val_x = validation_data.Item1.First();
        this.val_y = validation_data.Item2;
        this.val_sample_weight = validation_data.Item3;
    }

    public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
        => new ValidationDataPack(validation_data);

    public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
        => new ValidationDataPack(validation_data);

    public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
        => new ValidationDataPack(validation_data);

    public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
        => new ValidationDataPack(validation_data);

    // Two-element deconstruction: drops the sample weight on purpose.
    public void Deconstruct(out NDArray val_x, out NDArray val_y)
    {
        val_x = this.val_x;
        val_y = this.val_y;
    }

    // Three-element deconstruction: val_sample_weight may be null when none was provided.
    public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
    {
        val_x = this.val_x;
        val_y = this.val_y;
        val_sample_weight = this.val_sample_weight;
    }
}
}
59 changes: 59 additions & 0 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine.DataAdapters
{
Expand Down Expand Up @@ -34,9 +35,67 @@ public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y)
return (x, y);
}

/// <summary>
/// Expands any rank-1 tensor in x, y and sample_weight to rank 2 by appending a
/// trailing axis, so downstream loss/metric code can assume at least 2-D tensors.
/// </summary>
/// <param name="x">Input tensors.</param>
/// <param name="y">Target tensors.</param>
/// <param name="sample_weight">Optional per-sample weights; may be null.</param>
/// <returns>The (possibly expanded) x, y and sample_weight.</returns>
public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
{
    for (int i = 0; i < x.Length; i++)
    {
        if (x[i].shape.ndim == 1)
            x[i] = array_ops.expand_dims(x[i], axis: -1);
    }
    for (int i = 0; i < y.Length; i++)
    {
        if (y[i].shape.ndim == 1)
            y[i] = array_ops.expand_dims(y[i], axis: -1);
    }
    // sample_weight is optional (callers such as test_step pass null);
    // guard before iterating to avoid a NullReferenceException.
    if (sample_weight != null)
    {
        for (int i = 0; i < sample_weight.Length; i++)
        {
            if (sample_weight[i].shape.ndim == 1)
                sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
        }
    }
    return (x, y, sample_weight);
}

/// <summary>
/// Whether a fresh dataset iterator should be created for each epoch.
/// Defaults to true; adapters that can rewind their dataset override this to return false.
/// </summary>
public virtual bool ShouldRecreateIterator() => true;

/// <summary>
/// Splits `(x, y, sample_weight)` into a training portion and a ValidationDataPack.
/// The last `validation_split` fraction of samples (along axis 0) is reserved for validation.
/// </summary>
/// <param name="x_y_sample_weight">Inputs, targets and optional (nullable) sample weights.</param>
/// <param name="validation_split">Fraction of samples to hold out, in [0, 1).</param>
public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
    var (x, y, sample_weight) = x_y_sample_weight;
    int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
    var train_x = x[new Slice(0, train_count)];
    var train_y = y[new Slice(0, train_count)];
    ValidationDataPack validation_data;
    if (sample_weight == null)
    {
        validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
    }
    else
    {
        // Split the weights the same way as x and y.
        validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
        sample_weight = sample_weight[new Slice(0, train_count)];
    }
    return ((train_x, train_y, sample_weight), validation_data);
}

/// <summary>
/// Multi-input variant of train_validation_split: each NDArray in `x` is sliced the same way.
/// The last `validation_split` fraction of samples (along axis 0 of y) is reserved for validation.
/// </summary>
/// <param name="x_y_sample_weight">Input arrays, targets and optional (nullable) sample weights.</param>
/// <param name="validation_split">Fraction of samples to hold out, in [0, 1).</param>
public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
    var x = x_y_sample_weight.Item1;
    var y = x_y_sample_weight.Item2;
    var sample_weight = x_y_sample_weight.Item3;
    int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
    // Note: Select is lazily evaluated; train_count is captured by the lambdas.
    var train_x = x.Select(input => input[new Slice(0, train_count)] as NDArray);
    var train_y = y[new Slice(0, train_count)];
    var val_x = x.Select(input => input[new Slice(train_count)] as NDArray);
    var val_y = y[new Slice(train_count)];
    ValidationDataPack validation_data;
    // Guard against null sample_weight, matching the single-input overload
    // (the previous code sliced sample_weight unconditionally and threw on null).
    if (sample_weight != null)
    {
        validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
        sample_weight = sample_weight[new Slice(0, train_count)];
    }
    else
    {
        validation_data = (val_x, val_y);
    }
    return ((train_x, train_y, sample_weight), validation_data);
}
}
}
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Engine.DataAdapters
{
Expand All @@ -28,6 +29,7 @@ public class DataHandler
public DataHandler(DataHandlerArgs args)
{
this.args = args;

if (args.StepsPerExecution == null)
{
_steps_per_execution = tf.Variable(1L);
Expand All @@ -48,6 +50,7 @@ public DataHandler(DataHandlerArgs args)
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
SampleWeight = args.SampleWeight,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
Expand Down
2 changes: 2 additions & 0 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public interface IDataAdapter
IDatasetV2 GetDataset();
int GetSize();
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
(Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);

bool ShouldRecreateIterator();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class TensorLikeDataAdapter : DataAdapter, IDataAdapter
public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();
Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
num_samples = (int)args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
Expand All @@ -37,6 +37,8 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
inputs.AddRange(args.X);
if (args.Y != null)
inputs.AddRange(args.Y);
if (sample_weight_tensor != null)
inputs.Add(sample_weight_tensor);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}
Expand Down Expand Up @@ -94,8 +96,9 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)

public override bool ShouldRecreateIterator() => false;

void _process_tensorlike()
// Converts the sample-weight NDArray into a Tensor so it can be appended
// to the dataset's input tensors alongside X and Y.
Tensor _process_tensorlike(NDArray sample_weights)
    => tf.convert_to_tensor(sample_weights);
}
}
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Engine/LossesContainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ public LossesContainer(ILossFunc losses, string[] output_names = null)
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
public Tensor Call(Tensor y_true, Tensor y_pred)
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (!_built)
Build(y_pred);
var loss_value = _losses.Call(y_true, y_pred);
var loss_value = _losses.Call(y_true, y_pred, sample_weight:sample_weight);
var loss_metric_value = loss_value;
var batch_dim = array_ops.shape(y_true)[0];

Expand Down
19 changes: 14 additions & 5 deletions src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public partial class Model
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
Expand All @@ -51,6 +52,7 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Expand Down Expand Up @@ -140,7 +142,8 @@ Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callba
// Pulls the next batch from the iterator and runs one evaluation step,
// dispatching on whether the batch carries a sample-weight tensor (length 3)
// or just (x, y) (length 2). Increments the test step counter afterwards.
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
    var data = iterator.next();
    Dictionary<string, float> outputs;
    if (data.Length == 2)
        outputs = test_step(data_handler, data[0], data[1]);
    else
        outputs = test_step(data_handler, data[0], data[1], data[2]);
    tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
    return outputs;
}
Expand All @@ -149,17 +152,23 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
var outputs = data.Length == 2 ?
test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
test_step(
data_handler,
new Tensors(data.Take(x_size).ToArray()),
new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}


/// <summary>
/// Runs a single evaluation step: expands rank-1 tensors, does a forward pass with
/// training disabled, computes the (optionally sample-weighted) loss and updates metrics.
/// </summary>
/// <returns>Current metric results keyed by metric name.</returns>
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
{
    // Dispatch to the 2-arg Expand1d when there is no sample weight: the 3-arg
    // overload iterates sample_weight and would throw NullReferenceException on null.
    if (sample_weight == null)
        (x, y) = data_handler.DataAdapter.Expand1d(x, y);
    else
        (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
    var y_pred = Apply(x, training: false);
    var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
    compiled_metrics.update_state(y, y_pred);
    return metrics.Select(m => (m.Name, m.result())).ToDictionary(m => m.Item1, m => (float)m.Item2);
}
Expand Down
Loading
Loading