Skip to content

Commit

Permalink
Merge pull request #1169 from lingbai-kong/ndarrayload
Browse files Browse the repository at this point in the history
add: loading pickled npy file for imdb dataset loader
  • Loading branch information
Oceania2018 authored Sep 9, 2023
2 parents 70d681c + f57a6fe commit 179c3f0
Show file tree
Hide file tree
Showing 14 changed files with 546 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System.Linq;
using System.Text;
using Tensorflow.Util;
using Razorvine.Pickle;
using Tensorflow.NumPy.Pickle;
using static Tensorflow.Binding;

namespace Tensorflow.NumPy
Expand Down Expand Up @@ -97,6 +99,13 @@ Array ReadValueMatrix(BinaryReader reader, Array matrix, int bytes, Type type, i
return matrix;
}

Array ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape)
{
Stream stream = reader.BaseStream;
var unpickler = new Unpickler();
return (MultiArrayPickleWarpper)unpickler.load(stream);
}

public (NDArray, NDArray) meshgrid<T>(T[] array, bool copy = true, bool sparse = false)
{
var tensors = array_ops.meshgrid(array, copy: copy, sparse: sparse);
Expand Down
16 changes: 12 additions & 4 deletions src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ public Array LoadMatrix(Stream stream)
Array matrix = Array.CreateInstance(type, shape);

//if (type == typeof(String))
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
return ReadValueMatrix(reader, matrix, bytes, type, shape);
//return ReadStringMatrix(reader, matrix, bytes, type, shape);

if (type == typeof(Object))
return ReadObjectMatrix(reader, matrix, shape);
else
{
return ReadValueMatrix(reader, matrix, bytes, type, shape);
}
}
}

Expand All @@ -37,7 +43,7 @@ public T Load<T>(Stream stream)
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{
// if (typeof(T).IsArray && (typeof(T).GetElementType().IsArray || typeof(T).GetElementType() == typeof(string)))
// return LoadJagged(stream) as T;
// return LoadJagged(stream) as T;
return LoadMatrix(stream) as T;
}

Expand Down Expand Up @@ -93,7 +99,7 @@ bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape
Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
{
isLittleEndian = IsLittleEndian(dtype);
bytes = Int32.Parse(dtype.Substring(2));
bytes = dtype.Length > 2 ? Int32.Parse(dtype.Substring(2)) : 0;

string typeCode = dtype.Substring(1);

Expand Down Expand Up @@ -121,6 +127,8 @@ Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
return typeof(Double);
if (typeCode.StartsWith("S"))
return typeof(String);
if (typeCode.StartsWith("O"))
return typeof(Object);

throw new NotSupportedException();
}
Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ public class RandomizedImpl
public NDArray permutation(NDArray x) => new NDArray(random_ops.random_shuffle(x));

[AutoNumPy]
public void shuffle(NDArray x)
public void shuffle(NDArray x, int? seed = null)
{
var y = random_ops.random_shuffle(x);
var y = random_ops.random_shuffle(x, seed);
Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize);
}

Expand Down
1 change: 1 addition & 0 deletions src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public class NDArrayConverter
public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged
=> nd.dtype switch
{
TF_DataType.TF_BOOL => Scalar<T>(*(bool*)nd.data),
TF_DataType.TF_UINT8 => Scalar<T>(*(byte*)nd.data),
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),
Expand Down
20 changes: 20 additions & 0 deletions src/TensorFlowNET.Core/NumPy/Pickle/DTypePickleWarpper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.NumPy.Pickle
{
public class DTypePickleWarpper
{
TF_DataType dtype { get; set; }
public DTypePickleWarpper(TF_DataType dtype)
{
this.dtype = dtype;
}
public void __setstate__(object[] args) { }
public static implicit operator TF_DataType(DTypePickleWarpper dTypeWarpper)
{
return dTypeWarpper.dtype;
}
}
}
52 changes: 52 additions & 0 deletions src/TensorFlowNET.Core/NumPy/Pickle/DtypeConstructor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Razorvine.Pickle;

namespace Tensorflow.NumPy.Pickle
{
/// <summary>
///
/// </summary>
[SuppressMessage("ReSharper", "InconsistentNaming")]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
class DtypeConstructor : IObjectConstructor
{
public object construct(object[] args)
{
var typeCode = (string)args[0];
TF_DataType dtype;
if (typeCode == "b1")
dtype = np.@bool;
else if (typeCode == "i1")
dtype = np.@byte;
else if (typeCode == "i2")
dtype = np.int16;
else if (typeCode == "i4")
dtype = np.int32;
else if (typeCode == "i8")
dtype = np.int64;
else if (typeCode == "u1")
dtype = np.ubyte;
else if (typeCode == "u2")
dtype = np.uint16;
else if (typeCode == "u4")
dtype = np.uint32;
else if (typeCode == "u8")
dtype = np.uint64;
else if (typeCode == "f4")
dtype = np.float32;
else if (typeCode == "f8")
dtype = np.float64;
else if (typeCode.StartsWith("S"))
dtype = np.@string;
else if (typeCode.StartsWith("O"))
dtype = np.@object;
else
throw new NotSupportedException();
return new DTypePickleWarpper(dtype);
}
}
}
53 changes: 53 additions & 0 deletions src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Razorvine.Pickle;
using Razorvine.Pickle.Objects;

namespace Tensorflow.NumPy.Pickle
{
/// <summary>
/// Creates multiarrays of objects. Returns a primitive type multiarray such as int[][] if
/// the objects are ints, etc.
/// </summary>
[SuppressMessage("ReSharper", "InconsistentNaming")]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")]
public class MultiArrayConstructor : IObjectConstructor
{
public object construct(object[] args)
{
if (args.Length != 3)
throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments.");

var types = (ClassDictConstructor)args[0];
if (types.module != "numpy" || types.name != "ndarray")
throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray");

var arg1 = (object[])args[1];
var dims = new int[arg1.Length];
for (var i = 0; i < arg1.Length; i++)
{
dims[i] = (int)arg1[i];
}
var shape = new Shape(dims);

TF_DataType dtype;
string identifier;
if (args[2].GetType() == typeof(string))
identifier = (string)args[2];
else
identifier = Encoding.UTF8.GetString((byte[])args[2]);
switch (identifier)
{
case "u": dtype = np.uint32; break;
case "c": dtype = np.complex_; break;
case "f": dtype = np.float32; break;
case "b": dtype = np.@bool; break;
default: throw new NotImplementedException($"Unsupported data type: {args[2]}");
}
return new MultiArrayPickleWarpper(shape, dtype);
}
}
}
119 changes: 119 additions & 0 deletions src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
using Newtonsoft.Json.Linq;
using Serilog.Debugging;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.NumPy.Pickle
{
public class MultiArrayPickleWarpper
{
public Shape reconstructedShape { get; set; }
public TF_DataType reconstructedDType { get; set; }
public NDArray reconstructedNDArray { get; set; }
public Array reconstructedMultiArray { get; set; }
public MultiArrayPickleWarpper(Shape shape, TF_DataType dtype)
{
reconstructedShape = shape;
reconstructedDType = dtype;
}
public void __setstate__(object[] args)
{
if (args.Length != 5)
throw new InvalidArgumentError($"Invalid number of arguments in NDArray.__setstate__. Expected five arguments. Given {args.Length} arguments.");

var version = (int)args[0]; // version

var arg1 = (object[])args[1];
var dims = new int[arg1.Length];
for (var i = 0; i < arg1.Length; i++)
{
dims[i] = (int)arg1[i];
}
var _ShapeLike = new Shape(dims); // shape

TF_DataType _DType_co = (DTypePickleWarpper)args[2]; // DType

var F_continuous = (bool)args[3]; // F-continuous
if (F_continuous)
throw new InvalidArgumentError("Fortran Continuous memory layout is not supported. Please use C-continuous layout or check the data format.");

var data = args[4]; // Data
/*
* If we ever need another pickle format, increment the version
* number. But we should still be able to handle the old versions.
*/
if (version < 0 || version > 4)
throw new ValueError($"can't handle version {version} of numpy.dtype pickle");

// TODO: Implement the missing details and checks from the official Numpy C code here.
// https://github.com/numpy/numpy/blob/2f0bd6e86a77e4401d0384d9a75edf9470c5deb6/numpy/core/src/multiarray/descriptor.c#L2761

if (data.GetType() == typeof(ArrayList))
{
Reconstruct((ArrayList)data);
}
else
throw new NotImplementedException("");
}
private void Reconstruct(ArrayList arrayList)
{
int ndim = 1;
var subArrayList = arrayList;
while (subArrayList.Count > 0 && subArrayList[0] != null && subArrayList[0].GetType() == typeof(ArrayList))
{
subArrayList = (ArrayList)subArrayList[0];
ndim += 1;
}
var type = subArrayList[0].GetType();
if (type == typeof(int))
{
if (ndim == 1)
{
int[] list = (int[])arrayList.ToArray(typeof(int));
Shape shape = new Shape(new int[] { arrayList.Count });
reconstructedMultiArray = list;
reconstructedNDArray = new NDArray(list, shape);
}
if (ndim == 2)
{
int secondDim = 0;
foreach (ArrayList subArray in arrayList)
{
secondDim = subArray.Count > secondDim ? subArray.Count : secondDim;
}
int[,] list = new int[arrayList.Count, secondDim];
for (int i = 0; i < arrayList.Count; i++)
{
var subArray = (ArrayList?)arrayList[i];
if (subArray == null)
throw new NullReferenceException("");
for (int j = 0; j < subArray.Count; j++)
{
var element = subArray[j];
if (element == null)
throw new NoNullAllowedException("the element of ArrayList cannot be null.");
list[i, j] = (int)element;
}
}
Shape shape = new Shape(new int[] { arrayList.Count, secondDim });
reconstructedMultiArray = list;
reconstructedNDArray = new NDArray(list, shape);
}
if (ndim > 2)
throw new NotImplementedException("can't handle ArrayList with more than two dimensions.");
}
else
throw new NotImplementedException("");
}
public static implicit operator Array(MultiArrayPickleWarpper arrayWarpper)
{
return arrayWarpper.reconstructedMultiArray;
}
public static implicit operator NDArray(MultiArrayPickleWarpper arrayWarpper)
{
return arrayWarpper.reconstructedNDArray;
}
}
}
4 changes: 3 additions & 1 deletion src/TensorFlowNET.Core/Numpy/Numpy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public partial class np
public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE;
public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX;
public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64;
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128;
public static readonly TF_DataType @string = TF_DataType.TF_STRING;
public static readonly TF_DataType @object = TF_DataType.TF_VARIANT;
#endregion

public static double nan => double.NaN;
Expand Down
1 change: 1 addition & 0 deletions src/TensorFlowNET.Core/Tensorflow.Binding.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.255" />
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
<PackageReference Include="Razorvine.Pickle" Version="1.4.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>

Expand Down
6 changes: 6 additions & 0 deletions src/TensorFlowNET.Core/tensorflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/

using Razorvine.Pickle;
using Serilog;
using Serilog.Core;
using System.Reflection;
Expand All @@ -22,6 +23,7 @@ limitations under the License.
using Tensorflow.Eager;
using Tensorflow.Gradients;
using Tensorflow.Keras;
using Tensorflow.NumPy.Pickle;

namespace Tensorflow
{
Expand Down Expand Up @@ -98,6 +100,10 @@ public tensorflow()
"please visit https://github.com/SciSharp/TensorFlow.NET. If it still not work after installing the backend, please submit an " +
"issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}

// register numpy reconstructor for pickle
Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor());
Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor());
}

public string VERSION => c_api.StringPiece(c_api.TF_Version());
Expand Down
Loading

0 comments on commit 179c3f0

Please sign in to comment.