Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RaggedTensor #1186

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1139,5 +1139,18 @@ public static Tensor placeholder(TF_DataType dtype, Shape shape = null, string n
var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
return _op.output;
}

public static int get_positive_axis(int axis, int ndims=-100, string axis_name="axis", string ndims_name= "ndims")
{
if(ndims != -100)
{
if (axis >= 0 && axis < ndims) return axis;
else if (-ndims <= axis && axis < 0) return axis + ndims;
else throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}");

} else if(axis < 0) throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known.");
return axis;
}

}
}
33 changes: 33 additions & 0 deletions src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,38 @@ public static implicit operator RaggedTensor(Tensor tensor)
{
return tensor.Tag as RaggedTensor;
}
public Tensor nrows(TF_DataType out_type, string name = null)
{
tf_with(ops.name_scope(name, "RaggedNRows"), scope =>
{
return math_ops.cast(this._row_partition.nrows(), dtype: out_type);
});
return null;
}
public RaggedTensor row_lengths(int axis=-1, string name=null)
{
if (axis == 0) return this._row_partition.nrows();
if (axis == 1) return this._row_partition.row_lengths();
var values = (RaggedTensor)this._values;
axis = array_ops.get_positive_axis(
axis, this.shape.rank, ndims_name: "rank(this)");
if (axis == 0) return this.nrows(this._row_partition.GetDataType());
else if (axis == 1)
{
var splits = this._row_partition.row_splits;
return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)];

}
else if (this._values is RaggedTensor)
{
return values.row_lengths(axis - 1);
}
else
{
var shape = array_ops.shape(values, out_type: this._row_partition.GetDataType());
return array_ops.ones(shape[new Slice(stop:axis - 1)], this._row_partition.GetDataType()) *
shape[axis - 1];
}
}
}
}
55 changes: 55 additions & 0 deletions src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/

using Serilog.Debugging;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
//using System.ComponentModel.DataAnnotations;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Framework;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow
Expand Down Expand Up @@ -99,5 +104,55 @@ public static RowPartition from_row_splits(Tensor row_splits,
return new RowPartition(row_splits);
});
}

public static RowPartition from_row_lengths(Tensor row_lengths,
bool validate=true,
TF_DataType dtype = TF_DataType.TF_INT32,
TF_DataType dtype_hint= TF_DataType.TF_INT32)
{
row_lengths = _convert_row_partition(
row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype);
Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1));
Tensor row_splits = array_ops.concat(new Tensor[] { tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)), row_limits }, axis:0);
return new RowPartition(row_splits: row_splits, row_lengths: row_lengths);
}

public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype,
TF_DataType dtype_hint= TF_DataType.TF_INT64)
{
if (partition is NDArray && partition.GetDataType() == np.int32) partition = ops.convert_to_tensor(partition, name: name);
if (partition.GetDataType() != np.int32 && partition.GetDataType() != np.int64) throw new ValueError($"{name} must have dtype int32 or int64");
return partition;
}

public Tensor nrows()
{
/*Returns the number of rows created by this `RowPartition*/
if (this._nrows != null) return this._nrows;
var nsplits = tensor_shape.dimension_at_index(this._row_splits.shape, 0);
if (nsplits == null) return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1;
else return constant_op.constant(nsplits.value - 1, dtype: this.row_splits.dtype);
}

public Tensor row_lengths()
{

if (this._row_splits != null)
{
int nrows_plus_one = tensor_shape.dimension_value(this._row_splits.shape[0]);
return tf.constant(nrows_plus_one - 1);

}
if (this._row_lengths != null)
{
var nrows = tensor_shape.dimension_value(this._row_lengths.shape[0]);
return tf.constant(nrows);
}
if(this._nrows != null)
{
return tensor_util.constant_value(this._nrows);
}
return tf.constant(-1);
}
}
}
26 changes: 26 additions & 0 deletions test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.ManagedAPI
{
public class RaggedTensorTest :EagerModeTestBase
{
[TestMethod]
public void Test_from_row_lengths()
{
var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64));
var rp = RowPartition.from_row_lengths(row_lengths, validate: false);
var rp_row_lengths = rp.row_lengths();
var rp_nrows = rp.nrows();
Assert.IsTrue(rp_nrows.ToArray<long>()[0] == rp.nrows().ToArray<long>()[0]);

}
}
}
Loading