Skip to content

Commit

Permalink
Merge pull request #1257 from novikov-alexander/alnovi/generic-cast
Browse files Browse the repository at this point in the history
fix: More generic array cast
  • Loading branch information
Oceania2018 authored Jun 22, 2024
2 parents 0392027 + def5774 commit 7fb73cd
Showing 1 changed file with 59 additions and 29 deletions.
88 changes: 59 additions & 29 deletions src/TensorFlowNET.Core/Tensors/tensor_util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public static NDArray MakeNdarray(TensorProto tensor)

T[] ExpandArrayToSize<T>(IList<T> src)
{
if(src.Count == 0)
if (src.Count == 0)
{
return new T[0];
}
Expand All @@ -77,7 +77,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
var first_elem = src[0];
var last_elem = src[src.Count - 1];
T[] res = new T[num_elements];
for(long i = 0; i < num_elements; i++)
for (long i = 0; i < num_elements; i++)
{
if (i < pre) res[i] = first_elem;
else if (i >= num_elements - after) res[i] = last_elem;
Expand Down Expand Up @@ -121,7 +121,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
$"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.");
}

if(values.size == 0)
if (values.size == 0)
{
return np.zeros(shape, tensor_dtype);
}
Expand All @@ -135,23 +135,47 @@ T[] ExpandArrayToSize<T>(IList<T> src)
TF_DataType.TF_QINT32
};

private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter)
private static Array ConvertArray<TOut>(Array inputArray, Func<object, TOut> converter)
{
var rows = inputArray.GetLength(0);
var cols = inputArray.GetLength(1);
var outputArray = new TOut[rows, cols];
if (inputArray == null)
throw new ArgumentNullException(nameof(inputArray));

for (var i = 0; i < rows; i++)
var elementType = typeof(TOut);
var lengths = new int[inputArray.Rank];
for (var i = 0; i < inputArray.Rank; i++)
{
for (var j = 0; j < cols; j++)
{
outputArray[i, j] = converter(inputArray[i, j]);
}
lengths[i] = inputArray.GetLength(i);
}

var outputArray = Array.CreateInstance(elementType, lengths);

FillArray(inputArray, outputArray, converter, new int[inputArray.Rank], 0);

return outputArray;
}

private static void FillArray<TIn, TOut>(Array inputArray, Array outputArray, Func<TIn, TOut> converter, int[] indices, int dimension)
{
if (dimension == inputArray.Rank - 1)
{
for (int i = 0; i < inputArray.GetLength(dimension); i++)
{
indices[dimension] = i;
var inputValue = (TIn)inputArray.GetValue(indices);
var convertedValue = converter(inputValue);
outputArray.SetValue(convertedValue, indices);
}
}
else
{
for (int i = 0; i < inputArray.GetLength(dimension); i++)
{
indices[dimension] = i;
FillArray(inputArray, outputArray, converter, indices, dimension + 1);
}
}
}

/// <summary>
/// Create a TensorProto, invoked in graph mode
/// </summary>
Expand All @@ -171,24 +195,30 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
var origin_dtype = values.GetDataType();
if (dtype == TF_DataType.DtInvalid)
dtype = origin_dtype;
else if(origin_dtype != dtype)
else if (origin_dtype != dtype)
{
var new_system_dtype = dtype.as_system_dtype();

values = values switch

if (dtype != TF_DataType.TF_STRING && dtype != TF_DataType.TF_VARIANT && dtype != TF_DataType.TF_RESOURCE)
{
if (values is Array arrayValues)
{
values = dtype switch
{
TF_DataType.TF_INT32 => ConvertArray(arrayValues, Convert.ToInt32),
TF_DataType.TF_FLOAT => ConvertArray(arrayValues, Convert.ToSingle),
TF_DataType.TF_DOUBLE => ConvertArray(arrayValues, Convert.ToDouble),
_ => values,
};
} else
{
values = Convert.ChangeType(values, new_system_dtype);
}

} else
{
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
long[] longValues => values,
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
float[] floatValues => values,
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
float[,] float2DValues => values,
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
double[] doubleValues => values,
double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle),
double[,] double2DValues => values,
_ => Convert.ChangeType(values, new_system_dtype),
};

}
dtype = values.GetDataType();
}

Expand Down Expand Up @@ -306,7 +336,7 @@ bool hasattr(Graph property, string attr)

if (tensor is EagerTensor eagerTensor)
{
if(tensor.dtype == tf.int64)
if (tensor.dtype == tf.int64)
return new Shape(tensor.ToArray<long>());
else
return new Shape(tensor.ToArray<int>());
Expand Down Expand Up @@ -481,7 +511,7 @@ bool hasattr(Graph property, string attr)
var d_ = new int[value.size];
foreach (var (index, d) in enumerate(value.ToArray<int>()))
d_[index] = d >= 0 ? d : -1;

ret = ret.merge_with(new Shape(d_));
}
return ret;
Expand Down

0 comments on commit 7fb73cd

Please sign in to comment.