-
Notifications
You must be signed in to change notification settings - Fork 525
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1169 from lingbai-kong/ndarrayload
add: loading pickled npy file for imdb dataset loader
- Loading branch information
Showing
14 changed files
with
546 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
|
||
namespace Tensorflow.NumPy.Pickle | ||
{ | ||
public class DTypePickleWarpper | ||
{ | ||
TF_DataType dtype { get; set; } | ||
public DTypePickleWarpper(TF_DataType dtype) | ||
{ | ||
this.dtype = dtype; | ||
} | ||
public void __setstate__(object[] args) { } | ||
public static implicit operator TF_DataType(DTypePickleWarpper dTypeWarpper) | ||
{ | ||
return dTypeWarpper.dtype; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Diagnostics.CodeAnalysis; | ||
using System.Text; | ||
using Razorvine.Pickle; | ||
|
||
namespace Tensorflow.NumPy.Pickle | ||
{ | ||
/// <summary> | ||
/// | ||
/// </summary> | ||
[SuppressMessage("ReSharper", "InconsistentNaming")] | ||
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] | ||
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")] | ||
class DtypeConstructor : IObjectConstructor | ||
{ | ||
public object construct(object[] args) | ||
{ | ||
var typeCode = (string)args[0]; | ||
TF_DataType dtype; | ||
if (typeCode == "b1") | ||
dtype = np.@bool; | ||
else if (typeCode == "i1") | ||
dtype = np.@byte; | ||
else if (typeCode == "i2") | ||
dtype = np.int16; | ||
else if (typeCode == "i4") | ||
dtype = np.int32; | ||
else if (typeCode == "i8") | ||
dtype = np.int64; | ||
else if (typeCode == "u1") | ||
dtype = np.ubyte; | ||
else if (typeCode == "u2") | ||
dtype = np.uint16; | ||
else if (typeCode == "u4") | ||
dtype = np.uint32; | ||
else if (typeCode == "u8") | ||
dtype = np.uint64; | ||
else if (typeCode == "f4") | ||
dtype = np.float32; | ||
else if (typeCode == "f8") | ||
dtype = np.float64; | ||
else if (typeCode.StartsWith("S")) | ||
dtype = np.@string; | ||
else if (typeCode.StartsWith("O")) | ||
dtype = np.@object; | ||
else | ||
throw new NotSupportedException(); | ||
return new DTypePickleWarpper(dtype); | ||
} | ||
} | ||
} |
53 changes: 53 additions & 0 deletions
53
src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Diagnostics.CodeAnalysis; | ||
using System.Text; | ||
using Razorvine.Pickle; | ||
using Razorvine.Pickle.Objects; | ||
|
||
namespace Tensorflow.NumPy.Pickle | ||
{ | ||
/// <summary> | ||
/// Creates multiarrays of objects. Returns a primitive type multiarray such as int[][] if | ||
/// the objects are ints, etc. | ||
/// </summary> | ||
[SuppressMessage("ReSharper", "InconsistentNaming")] | ||
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] | ||
[SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")] | ||
public class MultiArrayConstructor : IObjectConstructor | ||
{ | ||
public object construct(object[] args) | ||
{ | ||
if (args.Length != 3) | ||
throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments."); | ||
|
||
var types = (ClassDictConstructor)args[0]; | ||
if (types.module != "numpy" || types.name != "ndarray") | ||
throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray"); | ||
|
||
var arg1 = (object[])args[1]; | ||
var dims = new int[arg1.Length]; | ||
for (var i = 0; i < arg1.Length; i++) | ||
{ | ||
dims[i] = (int)arg1[i]; | ||
} | ||
var shape = new Shape(dims); | ||
|
||
TF_DataType dtype; | ||
string identifier; | ||
if (args[2].GetType() == typeof(string)) | ||
identifier = (string)args[2]; | ||
else | ||
identifier = Encoding.UTF8.GetString((byte[])args[2]); | ||
switch (identifier) | ||
{ | ||
case "u": dtype = np.uint32; break; | ||
case "c": dtype = np.complex_; break; | ||
case "f": dtype = np.float32; break; | ||
case "b": dtype = np.@bool; break; | ||
default: throw new NotImplementedException($"Unsupported data type: {args[2]}"); | ||
} | ||
return new MultiArrayPickleWarpper(shape, dtype); | ||
} | ||
} | ||
} |
119 changes: 119 additions & 0 deletions
119
src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
using Newtonsoft.Json.Linq; | ||
using Serilog.Debugging; | ||
using System; | ||
using System.Collections; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
|
||
namespace Tensorflow.NumPy.Pickle | ||
{ | ||
public class MultiArrayPickleWarpper | ||
{ | ||
public Shape reconstructedShape { get; set; } | ||
public TF_DataType reconstructedDType { get; set; } | ||
public NDArray reconstructedNDArray { get; set; } | ||
public Array reconstructedMultiArray { get; set; } | ||
public MultiArrayPickleWarpper(Shape shape, TF_DataType dtype) | ||
{ | ||
reconstructedShape = shape; | ||
reconstructedDType = dtype; | ||
} | ||
public void __setstate__(object[] args) | ||
{ | ||
if (args.Length != 5) | ||
throw new InvalidArgumentError($"Invalid number of arguments in NDArray.__setstate__. Expected five arguments. Given {args.Length} arguments."); | ||
|
||
var version = (int)args[0]; // version | ||
|
||
var arg1 = (object[])args[1]; | ||
var dims = new int[arg1.Length]; | ||
for (var i = 0; i < arg1.Length; i++) | ||
{ | ||
dims[i] = (int)arg1[i]; | ||
} | ||
var _ShapeLike = new Shape(dims); // shape | ||
|
||
TF_DataType _DType_co = (DTypePickleWarpper)args[2]; // DType | ||
|
||
var F_continuous = (bool)args[3]; // F-continuous | ||
if (F_continuous) | ||
throw new InvalidArgumentError("Fortran Continuous memory layout is not supported. Please use C-continuous layout or check the data format."); | ||
|
||
var data = args[4]; // Data | ||
/* | ||
* If we ever need another pickle format, increment the version | ||
* number. But we should still be able to handle the old versions. | ||
*/ | ||
if (version < 0 || version > 4) | ||
throw new ValueError($"can't handle version {version} of numpy.dtype pickle"); | ||
|
||
// TODO: Implement the missing details and checks from the official Numpy C code here. | ||
// https://github.com/numpy/numpy/blob/2f0bd6e86a77e4401d0384d9a75edf9470c5deb6/numpy/core/src/multiarray/descriptor.c#L2761 | ||
|
||
if (data.GetType() == typeof(ArrayList)) | ||
{ | ||
Reconstruct((ArrayList)data); | ||
} | ||
else | ||
throw new NotImplementedException(""); | ||
} | ||
private void Reconstruct(ArrayList arrayList) | ||
{ | ||
int ndim = 1; | ||
var subArrayList = arrayList; | ||
while (subArrayList.Count > 0 && subArrayList[0] != null && subArrayList[0].GetType() == typeof(ArrayList)) | ||
{ | ||
subArrayList = (ArrayList)subArrayList[0]; | ||
ndim += 1; | ||
} | ||
var type = subArrayList[0].GetType(); | ||
if (type == typeof(int)) | ||
{ | ||
if (ndim == 1) | ||
{ | ||
int[] list = (int[])arrayList.ToArray(typeof(int)); | ||
Shape shape = new Shape(new int[] { arrayList.Count }); | ||
reconstructedMultiArray = list; | ||
reconstructedNDArray = new NDArray(list, shape); | ||
} | ||
if (ndim == 2) | ||
{ | ||
int secondDim = 0; | ||
foreach (ArrayList subArray in arrayList) | ||
{ | ||
secondDim = subArray.Count > secondDim ? subArray.Count : secondDim; | ||
} | ||
int[,] list = new int[arrayList.Count, secondDim]; | ||
for (int i = 0; i < arrayList.Count; i++) | ||
{ | ||
var subArray = (ArrayList?)arrayList[i]; | ||
if (subArray == null) | ||
throw new NullReferenceException(""); | ||
for (int j = 0; j < subArray.Count; j++) | ||
{ | ||
var element = subArray[j]; | ||
if (element == null) | ||
throw new NoNullAllowedException("the element of ArrayList cannot be null."); | ||
list[i, j] = (int)element; | ||
} | ||
} | ||
Shape shape = new Shape(new int[] { arrayList.Count, secondDim }); | ||
reconstructedMultiArray = list; | ||
reconstructedNDArray = new NDArray(list, shape); | ||
} | ||
if (ndim > 2) | ||
throw new NotImplementedException("can't handle ArrayList with more than two dimensions."); | ||
} | ||
else | ||
throw new NotImplementedException(""); | ||
} | ||
public static implicit operator Array(MultiArrayPickleWarpper arrayWarpper) | ||
{ | ||
return arrayWarpper.reconstructedMultiArray; | ||
} | ||
public static implicit operator NDArray(MultiArrayPickleWarpper arrayWarpper) | ||
{ | ||
return arrayWarpper.reconstructedNDArray; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.