diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index dedd8a9be..21d7279ca 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -768,11 +768,43 @@ private void MarshalRefClass(Class @class) { if (Context.Parameter.IsIndirect) { - Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))"); - Context.Before.WriteLineIndent( - $@"throw new global::System.ArgumentNullException(""{ - Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");"); - Context.Return.Write(paramInstance); + Method cctor = @class.HasNonTrivialCopyConstructor ? @class.Methods.First(c => c.IsCopyConstructor) : null; + if (cctor != null && cctor.IsGenerated) + { + Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))"); + Context.Before.WriteLineIndent( + $@"throw new global::System.ArgumentNullException(""{ + Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");"); + + var nativeClass = typePrinter.PrintNative(@class); + + var cctorName = CSharpSources.GetFunctionNativeIdentifier(Context.Context, cctor); + + var defaultValue = ""; + var TypePrinter = new CSharpTypePrinter(Context.Context); + var ExpressionPrinter = new CSharpExpressionPrinter(TypePrinter); + if (cctor.Parameters.Count > 1) + defaultValue = $", {ExpressionPrinter.VisitParameter(cctor.Parameters.Last())}"; + + Context.Before.WriteLine($"byte* __{Context.Parameter.Name}Memory = stackalloc byte[sizeof({nativeClass})];"); + Context.Before.WriteLine($"__IntPtr __{Context.Parameter.Name}Ptr = (__IntPtr)__{Context.Parameter.Name}Memory;"); + Context.Before.WriteLine($"{nativeClass}.{cctorName}(__{Context.Parameter.Name}Ptr, {Context.Parameter.Name}.__Instance{defaultValue});"); + Context.Return.Write($"__{Context.Parameter.Name}Ptr"); + + if (Context.Context.ParserOptions.IsItaniumLikeAbi && @class.HasNonTrivialDestructor) + { + Method dtor = @class.Destructors.FirstOrDefault(); + if (dtor != null) + { + // todo: virtual destructors? + Context.Cleanup.WriteLine($"{nativeClass}.dtor(__{Context.Parameter.Name}Ptr);"); + } + } + } + else + { + Context.Return.Write(paramInstance); + } } else { diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 30145f003..a102a87c8 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -3464,6 +3464,12 @@ public static string GetFunctionIdentifier(Function function) public string GetFunctionNativeIdentifier(Function function, bool isForDelegate = false) + { + return GetFunctionNativeIdentifier(Context, function, isForDelegate); + } + + public static string GetFunctionNativeIdentifier(BindingContext context, Function function, + bool isForDelegate = false) { var identifier = new StringBuilder(); @@ -3494,12 +3500,12 @@ public string GetFunctionNativeIdentifier(Function function, identifier.Append(Helpers.GetSuffixFor(specialization)); var internalParams = function.GatherInternalParams( - Context.ParserOptions.IsItaniumLikeAbi); + context.ParserOptions.IsItaniumLikeAbi); var overloads = function.Namespace.GetOverloads(function) .Where(f => (!f.Ignore || (f.OriginalFunction != null && !f.OriginalFunction.Ignore)) && (isForDelegate || internalParams.SequenceEqual( - f.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi), + f.GatherInternalParams(context.ParserOptions.IsItaniumLikeAbi), new MarshallingParamComparer()))).ToList(); var index = -1; if (overloads.Count > 1) diff --git a/tests/dotnet/CSharp/CSharp.Tests.cs b/tests/dotnet/CSharp/CSharp.Tests.cs index d35eb6ff7..01a911b53 100644 --- a/tests/dotnet/CSharp/CSharp.Tests.cs +++ b/tests/dotnet/CSharp/CSharp.Tests.cs @@ -1995,4 +1995,19 @@ public void TestPointerToClass() Assert.IsTrue(CSharp.CSharp.PointerToClass.IsDefaultInstance); Assert.IsTrue(CSharp.CSharp.PointerToClass.IsValid); } + + [Test] + public void TestCallByValueCopyConstructor() + { + using (var s = new CallByValueCopyConstructor()) + { + s.A = 500; + CSharp.CSharp.CallByValueCopyConstructorFunction(s); + Assert.That(s.A, Is.EqualTo(500)); + } + + Assert.That(CallByValueCopyConstructor.ConstructorCalls, Is.EqualTo(1)); + Assert.That(CallByValueCopyConstructor.CopyConstructorCalls, Is.EqualTo(1)); + Assert.That(CallByValueCopyConstructor.DestructorCalls, Is.EqualTo(2)); + } } diff --git a/tests/dotnet/CSharp/CSharp.cpp b/tests/dotnet/CSharp/CSharp.cpp index cffbff4bd..71e35935e 100644 --- a/tests/dotnet/CSharp/CSharp.cpp +++ b/tests/dotnet/CSharp/CSharp.cpp @@ -1791,3 +1791,29 @@ bool PointerTester::IsValid() } PointerTester* PointerToClass = &internalPointerTesterInstance; + +int CallByValueCopyConstructor::constructorCalls = 0; +int CallByValueCopyConstructor::destructorCalls = 0; +int CallByValueCopyConstructor::copyConstructorCalls = 0; + +CallByValueCopyConstructor::CallByValueCopyConstructor() +{ + a = 0; + constructorCalls++; +} + +CallByValueCopyConstructor::CallByValueCopyConstructor(const CallByValueCopyConstructor& other) +{ + a = other.a; + copyConstructorCalls++; +} + +CallByValueCopyConstructor::~CallByValueCopyConstructor() +{ + destructorCalls++; +} + +void CallByValueCopyConstructorFunction(CallByValueCopyConstructor s) +{ + s.a = 99999; +} diff --git a/tests/dotnet/CSharp/CSharp.h b/tests/dotnet/CSharp/CSharp.h index 504dd7e2c..15df8463a 100644 --- a/tests/dotnet/CSharp/CSharp.h +++ b/tests/dotnet/CSharp/CSharp.h @@ -1603,3 +1603,16 @@ class DLL_API PointerTester }; DLL_API extern PointerTester* PointerToClass; + +struct DLL_API CallByValueCopyConstructor { + int a; + static int constructorCalls; + static int destructorCalls; + static int copyConstructorCalls; + + CallByValueCopyConstructor(); + ~CallByValueCopyConstructor(); + CallByValueCopyConstructor(const CallByValueCopyConstructor& other); +}; + +DLL_API void CallByValueCopyConstructorFunction(CallByValueCopyConstructor s);