diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 873579e4..6e5024ef 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -67,7 +67,7 @@ public static NDArray MakeNdarray(TensorProto tensor) T[] ExpandArrayToSize(IList src) { - if(src.Count == 0) + if (src.Count == 0) { return new T[0]; } @@ -77,7 +77,7 @@ T[] ExpandArrayToSize(IList src) var first_elem = src[0]; var last_elem = src[src.Count - 1]; T[] res = new T[num_elements]; - for(long i = 0; i < num_elements; i++) + for (long i = 0; i < num_elements; i++) { if (i < pre) res[i] = first_elem; else if (i >= num_elements - after) res[i] = last_elem; @@ -121,7 +121,7 @@ T[] ExpandArrayToSize(IList src) $"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes."); } - if(values.size == 0) + if (values.size == 0) { return np.zeros(shape, tensor_dtype); } @@ -135,23 +135,47 @@ T[] ExpandArrayToSize(IList src) TF_DataType.TF_QINT32 }; - private static TOut[,] ConvertArray2D(TIn[,] inputArray, Func converter) + private static Array ConvertArray(Array inputArray, Func converter) { - var rows = inputArray.GetLength(0); - var cols = inputArray.GetLength(1); - var outputArray = new TOut[rows, cols]; + if (inputArray == null) + throw new ArgumentNullException(nameof(inputArray)); - for (var i = 0; i < rows; i++) + var elementType = typeof(TOut); + var lengths = new int[inputArray.Rank]; + for (var i = 0; i < inputArray.Rank; i++) { - for (var j = 0; j < cols; j++) - { - outputArray[i, j] = converter(inputArray[i, j]); - } + lengths[i] = inputArray.GetLength(i); } + var outputArray = Array.CreateInstance(elementType, lengths); + + FillArray(inputArray, outputArray, converter, new int[inputArray.Rank], 0); + return outputArray; } + private static void FillArray(Array inputArray, Array outputArray, Func converter, int[] indices, int dimension) + { + if (dimension == inputArray.Rank - 1) + { + for (int i = 0; i < inputArray.GetLength(dimension); i++) + { + indices[dimension] = i; + var inputValue = (TIn)inputArray.GetValue(indices); + var convertedValue = converter(inputValue); + outputArray.SetValue(convertedValue, indices); + } + } + else + { + for (int i = 0; i < inputArray.GetLength(dimension); i++) + { + indices[dimension] = i; + FillArray(inputArray, outputArray, converter, indices, dimension + 1); + } + } + } + /// /// Create a TensorProto, invoked in graph mode /// @@ -171,24 +195,30 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T var origin_dtype = values.GetDataType(); if (dtype == TF_DataType.DtInvalid) dtype = origin_dtype; - else if(origin_dtype != dtype) + else if (origin_dtype != dtype) { var new_system_dtype = dtype.as_system_dtype(); - - values = values switch + + if (dtype != TF_DataType.TF_STRING && dtype != TF_DataType.TF_VARIANT && dtype != TF_DataType.TF_RESOURCE) + { + if (values is Array arrayValues) + { + values = dtype switch + { + TF_DataType.TF_INT32 => ConvertArray(arrayValues, Convert.ToInt32), + TF_DataType.TF_FLOAT => ConvertArray(arrayValues, Convert.ToSingle), + TF_DataType.TF_DOUBLE => ConvertArray(arrayValues, Convert.ToDouble), + _ => values, + }; + } else + { + values = Convert.ChangeType(values, new_system_dtype); + } + + } else { - long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), - long[] longValues => values, - float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), - float[] floatValues => values, - float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), - float[,] float2DValues => values, - double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), - double[] doubleValues => values, - double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle), - double[,] double2DValues => values, - _ => Convert.ChangeType(values, new_system_dtype), - }; + + } dtype = values.GetDataType(); } @@ -306,7 +336,7 @@ bool hasattr(Graph property, string attr) if (tensor is EagerTensor eagerTensor) { - if(tensor.dtype == tf.int64) + if (tensor.dtype == tf.int64) return new Shape(tensor.ToArray()); else return new Shape(tensor.ToArray()); @@ -481,7 +511,7 @@ bool hasattr(Graph property, string attr) var d_ = new int[value.size]; foreach (var (index, d) in enumerate(value.ToArray())) d_[index] = d >= 0 ? d : -1; - + ret = ret.merge_with(new Shape(d_)); } return ret;