Skip to content

Commit

Permalink
Merge pull request #1186 from Beacontownfc/ragged
Browse files Browse the repository at this point in the history
Improve RaggedTensor
  • Loading branch information
Oceania2018 authored Sep 29, 2023
2 parents eb4c1f4 + 02bfb9a commit 15763df
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1139,5 +1139,18 @@ public static Tensor placeholder(TF_DataType dtype, Shape shape = null, string n
var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
return _op.output;
}

/// <summary>
/// Validates an axis argument and converts it to its non-negative form.
/// </summary>
/// <param name="axis">Axis index. Negative values count from the end and are only accepted when <paramref name="ndims"/> is known.</param>
/// <param name="ndims">Number of dimensions the axis applies to. Any negative value (including the default <c>-100</c>) means the rank is not statically known.</param>
/// <param name="axis_name">Name used for the axis in error messages.</param>
/// <param name="ndims_name">Name used for the rank in error messages.</param>
/// <returns><paramref name="axis"/> itself when it is non-negative and valid, otherwise <c>axis + ndims</c>.</returns>
/// <exception cref="ValueError">
/// Thrown when <paramref name="axis"/> lies outside <c>[-ndims, ndims)</c> for a known rank,
/// or when <paramref name="axis"/> is negative and the rank is unknown.
/// </exception>
public static int get_positive_axis(int axis, int ndims=-100, string axis_name="axis", string ndims_name= "ndims")
{
    // A negative rank can never describe a real tensor, so every negative value is
    // treated as "unknown" rather than only the -100 default. Previously such values
    // reached the bounds check below, where no axis could ever pass.
    // NOTE(review): callers pass Shape.rank here; presumably it reports -1 for an
    // unknown rank - confirm against Shape.
    if (ndims < 0)
    {
        if (axis < 0)
        {
            throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known.");
        }
        return axis;
    }

    if (axis >= 0 && axis < ndims)
    {
        return axis;
    }

    if (axis >= -ndims && axis < 0)
    {
        // Wrap a negative index around to its positive equivalent.
        return axis + ndims;
    }

    throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}");
}

}
}
33 changes: 33 additions & 0 deletions src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,38 @@ public static implicit operator RaggedTensor(Tensor tensor)
{
return tensor.Tag as RaggedTensor;
}
/// <summary>
/// Returns the row count reported by this tensor's row partition, cast to <paramref name="out_type"/>.
/// </summary>
/// <param name="out_type">Data type of the returned tensor.</param>
/// <param name="name">Optional name scope for the created ops; "RaggedNRows" is the default scope name.</param>
/// <returns>The result of casting <c>_row_partition.nrows()</c> to <paramref name="out_type"/>.</returns>
public Tensor nrows(TF_DataType out_type, string name = null)
{
    // The lambda's value must be propagated to the caller: the previous version
    // discarded it and always returned null.
    return tf_with(ops.name_scope(name, "RaggedNRows"), scope =>
    {
        return math_ops.cast(this._row_partition.nrows(), dtype: out_type);
    });
}
/// <summary>
/// Returns row-length information for the requested <paramref name="axis"/> of this ragged tensor.
/// </summary>
/// <param name="axis">Dimension to inspect. Values other than 0 and 1 are normalized against
/// <c>this.shape.rank</c> through <c>array_ops.get_positive_axis</c>, so negative values count from the end.</param>
/// <param name="name">Not referenced by this implementation.</param>
/// <returns>
/// The value computed for the selected axis, delivered through the <c>RaggedTensor</c> return type.
/// NOTE(review): most branches produce a value from Tensor-returning calls; if the Tensor-to-RaggedTensor
/// conversion in this class is the one that yields <c>tensor.Tag as RaggedTensor</c>, the result can be
/// null for ordinary tensors - confirm this is intended.
/// </returns>
public RaggedTensor row_lengths(int axis=-1, string name=null)
{
// Fast paths for literal axis values 0 and 1: delegate directly to the row partition.
if (axis == 0) return this._row_partition.nrows();
if (axis == 1) return this._row_partition.row_lengths();
// Converted up front; used by both of the inner-dimension branches below.
var values = (RaggedTensor)this._values;
// Normalize the remaining (possibly negative) axis values against this tensor's rank.
axis = array_ops.get_positive_axis(
axis, this.shape.rank, ndims_name: "rank(this)");
// From here on axis can only be 0 or 1 if a negative axis was normalized to that value.
if (axis == 0) return this.nrows(this._row_partition.GetDataType());
else if (axis == 1)
{
// Difference between each split and the one before it, i.e. splits[1:] - splits[:-1].
var splits = this._row_partition.row_splits;
return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)];

}
else if (this._values is RaggedTensor)
{
// Nested values: recurse one level down, shifting the axis accordingly.
return values.row_lengths(axis - 1);
}
else
{
// Otherwise: take the shape of the values and build a tensor of ones shaped like
// shape[:axis - 1], scaled by the size found at shape[axis - 1].
var shape = array_ops.shape(values, out_type: this._row_partition.GetDataType());
return array_ops.ones(shape[new Slice(stop:axis - 1)], this._row_partition.GetDataType()) *
shape[axis - 1];
}
}
}
}
55 changes: 55 additions & 0 deletions src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/

using Serilog.Debugging;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
//using System.ComponentModel.DataAnnotations;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Framework;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow
Expand Down Expand Up @@ -99,5 +104,55 @@ public static RowPartition from_row_splits(Tensor row_splits,
return new RowPartition(row_splits);
});
}

/// <summary>
/// Creates a <see cref="RowPartition"/> from a tensor holding the length of each row.
/// </summary>
/// <param name="row_lengths">Row lengths; must be int32 or int64, as enforced by <c>_convert_row_partition</c>.</param>
/// <param name="validate">Currently not used by this implementation.</param>
/// <param name="dtype">Forwarded to <c>_convert_row_partition</c>.</param>
/// <param name="dtype_hint">Forwarded to <c>_convert_row_partition</c>.</param>
/// <returns>A partition built from the derived row splits together with the given row lengths.</returns>
public static RowPartition from_row_lengths(Tensor row_lengths,
    bool validate=true,
    TF_DataType dtype = TF_DataType.TF_INT32,
    TF_DataType dtype_hint= TF_DataType.TF_INT32)
{
    row_lengths = _convert_row_partition(
        row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype);

    // The cumulative sum of the lengths gives the end offset of every row.
    Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1));

    // The leading zero must share the dtype of the limits; a fixed int64 zero made
    // the concat mix dtypes whenever int32 row lengths were supplied.
    Tensor leading_zero = math_ops.cast(tf.constant(new int[] { 0 }), dtype: row_lengths.dtype);

    Tensor row_splits = array_ops.concat(new Tensor[] { leading_zero, row_limits }, axis:0);
    return new RowPartition(row_splits: row_splits, row_lengths: row_lengths);
}

/// <summary>
/// Prepares a row-partitioning tensor and checks that it has an integer dtype.
/// </summary>
/// <param name="partition">The partitioning values to check.</param>
/// <param name="name">Name used for the converted tensor and in the error message.</param>
/// <param name="dtype">Not referenced by this implementation.</param>
/// <param name="dtype_hint">Not referenced by this implementation.</param>
/// <returns>The (possibly converted) partition tensor.</returns>
/// <exception cref="ValueError">Thrown when the partition's dtype is neither int32 nor int64.</exception>
public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype,
    TF_DataType dtype_hint= TF_DataType.TF_INT64)
{
    // Only int32 NDArray inputs are passed through convert_to_tensor.
    bool isInt32Array = partition is NDArray && partition.GetDataType() == np.int32;
    if (isInt32Array)
    {
        partition = ops.convert_to_tensor(partition, name: name);
    }

    var partitionType = partition.GetDataType();
    bool isIntegerType = partitionType == np.int32 || partitionType == np.int64;
    if (!isIntegerType)
    {
        throw new ValueError($"{name} must have dtype int32 or int64");
    }

    return partition;
}

/// <summary>
/// Returns the number of rows created by this <c>RowPartition</c>.
/// </summary>
/// <returns>
/// The stored <c>_nrows</c> tensor when one is set; otherwise a tensor equal to the size of
/// the first dimension of the row splits minus one, in the dtype of the row splits.
/// </returns>
public Tensor nrows()
{
    // A stored row count takes precedence over anything derived from the splits.
    if (this._nrows != null)
    {
        return this._nrows;
    }

    var splitCount = tensor_shape.dimension_at_index(this._row_splits.shape, 0);
    if (splitCount == null)
    {
        // No dimension object available: read the size from the shape op instead.
        var splitsShape = array_ops.shape(this._row_splits, out_type: this.row_splits.dtype);
        return splitsShape[0] - 1;
    }

    // There is always one more split than there are rows.
    return constant_op.constant(splitCount.value - 1, dtype: this.row_splits.dtype);
}

/// <summary>
/// Returns the length of every row in this partition.
/// </summary>
/// <returns>
/// The stored row-lengths tensor when available; otherwise the difference between
/// consecutive row splits (<c>splits[1:] - splits[:-1]</c>).
/// </returns>
/// <exception cref="ValueError">Thrown when neither row lengths nor row splits are available.</exception>
public Tensor row_lengths()
{
    // The previous body returned a scalar row count (it mirrored a static-nrows
    // computation) and never returned the stored lengths, so callers asking for
    // per-row lengths received the wrong value.
    if (this._row_lengths != null)
    {
        return this._row_lengths;
    }

    if (this._row_splits == null)
    {
        throw new ValueError("row_lengths requires either row_lengths or row_splits to be set.");
    }

    // Each row spans [splits[i], splits[i + 1]), so its length is the difference
    // between neighbouring splits.
    var splits = this._row_splits;
    return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)];
}
}
}
26 changes: 26 additions & 0 deletions test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.ManagedAPI
{
/// <summary>
/// Tests for building a <c>RowPartition</c> from row lengths.
/// </summary>
// [TestClass] is required for MSTest to discover the [TestMethod] members below.
[TestClass]
public class RaggedTensorTest :EagerModeTestBase
{
    /// <summary>
    /// A partition built from five row lengths must report five rows.
    /// </summary>
    [TestMethod]
    public void Test_from_row_lengths()
    {
        var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64));
        var rp = RowPartition.from_row_lengths(row_lengths, validate: false);

        var rp_row_lengths = rp.row_lengths();
        var rp_nrows = rp.nrows();

        // The old assertion compared nrows() with itself and could never fail;
        // check against the known number of input rows instead.
        Assert.IsNotNull(rp_row_lengths);
        Assert.AreEqual(5L, rp_nrows.ToArray<long>()[0]);
    }
}
}

0 comments on commit 15763df

Please sign in to comment.