From d79158aaa69db55a8d7e99f3fd046bd67c5422ef Mon Sep 17 00:00:00 2001 From: "Simon G." Date: Mon, 25 Nov 2024 12:57:25 +0100 Subject: [PATCH] - allow async onCreate method by introducing async resolve --- LightweightIocContainer.sln.DotSettings | 11 ++- .../Factories/TypedFactory.cs | 11 ++- .../GenericMethodCaller.cs | 28 ++++++ .../Interfaces/IIocResolver.cs | 15 +++ .../Registrations/Fluent/IOnCreate.cs | 3 + LightweightIocContainer/IocContainer.cs | 86 ++++++++++++++++- .../LightweightIocContainer.xml | 50 ++++++++++ .../Registrations/RegistrationBase.cs | 2 +- .../Registrations/TypedRegistration.cs | 24 +++++ LightweightIocContainer/TypeExtension.cs | 10 ++ .../AsyncFactoryTest.cs | 96 +++++++++++++++++++ Test.LightweightIocContainer/OnCreateTest.cs | 12 +++ 12 files changed, 339 insertions(+), 9 deletions(-) create mode 100644 Test.LightweightIocContainer/AsyncFactoryTest.cs diff --git a/LightweightIocContainer.sln.DotSettings b/LightweightIocContainer.sln.DotSettings index ebbf1ca..fb8d842 100644 --- a/LightweightIocContainer.sln.DotSettings +++ b/LightweightIocContainer.sln.DotSettings @@ -4,16 +4,21 @@ True False 200 - Author: $USER_NAME$ -Created: $CREATED_YEAR$-$CREATED_MONTH$-$CREATED_DAY$ -Copyright(c) $CREATED_YEAR$ SimonG. All Rights Reserved. + Author: ${User.Name} +Created: ${File.CreatedYear}-${File.CreatedMonth}-${File.CreatedDay} +Copyright(c) ${File.CreatedYear} SimonG. All Rights Reserved. <Policy Inspect="True" Prefix="" Suffix="" Style="AaBb" /> <Policy Inspect="True" Prefix="" Suffix="" Style="AA_BB" /> <Policy Inspect="True" Prefix="_" Suffix="" Style="aaBb" /> + <Policy><Descriptor Staticness="Static" AccessRightKinds="Private" Description="Static readonly fields (private)"><ElementKinds><Kind Name="READONLY_FIELD" /></ElementKinds></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="_" Suffix="" Style="aaBb" /></Policy> + <Policy><Descriptor Staticness="Any" AccessRightKinds="Private" Description="Constant fields (private)"><ElementKinds><Kind Name="CONSTANT_FIELD" /></ElementKinds></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB" /></Policy> + <Policy><Descriptor Staticness="Any" AccessRightKinds="Any" Description="Methods"><ElementKinds><Kind Name="METHOD" /></ElementKinds></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AaBb" /></Policy> + True True True True True + True True True True \ No newline at end of file diff --git a/LightweightIocContainer/Factories/TypedFactory.cs b/LightweightIocContainer/Factories/TypedFactory.cs index 8a63b01..9f88148 100644 --- a/LightweightIocContainer/Factories/TypedFactory.cs +++ b/LightweightIocContainer/Factories/TypedFactory.cs @@ -51,7 +51,7 @@ public class TypedFactory : TypedFactoryBase, ITypedFactory< FieldBuilder helperFieldBuilder = typeBuilder.DefineField("_helper", typeof(FactoryHelper), FieldAttributes.Private | FieldAttributes.InitOnly); //add ctor - ConstructorBuilder constructorBuilder = typeBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.HasThis, new[] {typeof(IocContainer), typeof(FactoryHelper)}); + ConstructorBuilder constructorBuilder = typeBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.HasThis, [typeof(IocContainer), typeof(FactoryHelper)]); ILGenerator constructorGenerator = constructorBuilder.GetILGenerator(); constructorGenerator.Emit(OpCodes.Ldarg_0); constructorGenerator.Emit(OpCodes.Ldarg_1); @@ -129,13 +129,18 @@ public class TypedFactory : TypedFactoryBase, ITypedFactory< generator.EmitCall(OpCodes.Call, emptyArray, null); } - generator.EmitCall(OpCodes.Call, typeof(FactoryHelper).GetMethod(nameof(FactoryHelper.ConvertPassedNull), new[] { typeof(MethodBase), typeof(object?[]) })!, null); + generator.EmitCall(OpCodes.Call, typeof(FactoryHelper).GetMethod(nameof(FactoryHelper.ConvertPassedNull), [typeof(MethodBase), typeof(object?[])])!, null); generator.Emit(OpCodes.Stloc_1); generator.Emit(OpCodes.Ldarg_0); generator.Emit(OpCodes.Ldfld, containerFieldBuilder); generator.Emit(OpCodes.Ldloc_1); - generator.EmitCall(OpCodes.Call, typeof(IocContainer).GetMethod(nameof(IocContainer.FactoryResolve), new[] { typeof(object?[]) })!.MakeGenericMethod(createMethod.ReturnType), null); + Type? asyncReturnType = createMethod.ReturnType.GetAsyncReturnType(); + if (asyncReturnType is not null) + generator.EmitCall(OpCodes.Call, typeof(IocContainer).GetMethod(nameof(IocContainer.FactoryResolveAsync), [typeof(object?[])])!.MakeGenericMethod(asyncReturnType), null); + else + generator.EmitCall(OpCodes.Call, typeof(IocContainer).GetMethod(nameof(IocContainer.FactoryResolve), [typeof(object?[])])!.MakeGenericMethod(createMethod.ReturnType), null); + generator.Emit(OpCodes.Castclass, createMethod.ReturnType); generator.Emit(OpCodes.Ret); } diff --git a/LightweightIocContainer/GenericMethodCaller.cs b/LightweightIocContainer/GenericMethodCaller.cs index dd400dd..1295c0a 100644 --- a/LightweightIocContainer/GenericMethodCaller.cs +++ b/LightweightIocContainer/GenericMethodCaller.cs @@ -40,6 +40,31 @@ internal static class GenericMethodCaller throw ex.GetBaseException(); } } + + public static async Task CallAsync(object caller, string functionName, Type genericParameter, BindingFlags bindingFlags, params object?[] parameters) + { + MethodInfo? method = caller.GetType().GetMethod(functionName, bindingFlags); + MethodInfo? genericMethod = method?.MakeGenericMethod(genericParameter); + + if (genericMethod == null) + throw new GenericMethodNotFoundException(functionName); + + try //exceptions thrown by methods called with invoke are wrapped into another exception, the exception thrown by the invoked method can be returned by `Exception.GetBaseException()` + { + object? result = genericMethod.Invoke(caller, parameters); + if (result is null) + return null; + + if (result is Task task) + return await task; + + return result; + } + catch (Exception ex) + { + throw ex.GetBaseException(); + } + } /// /// Call a private generic method without generic type parameters @@ -53,4 +78,7 @@ internal static class GenericMethodCaller /// Any thrown after invoking the generic method public static object? CallPrivate(object caller, string functionName, Type genericParameter, params object?[] parameters) => Call(caller, functionName, genericParameter, BindingFlags.NonPublic | BindingFlags.Instance, parameters); + + public static async Task CallPrivateAsync(object caller, string functionName, Type genericParameter, params object?[] parameters) => + await CallAsync(caller, functionName, genericParameter, BindingFlags.NonPublic | BindingFlags.Instance, parameters); } \ No newline at end of file diff --git a/LightweightIocContainer/Interfaces/IIocResolver.cs b/LightweightIocContainer/Interfaces/IIocResolver.cs index 5126f10..248de26 100644 --- a/LightweightIocContainer/Interfaces/IIocResolver.cs +++ b/LightweightIocContainer/Interfaces/IIocResolver.cs @@ -15,6 +15,13 @@ public interface IIocResolver : IDisposable /// The given /// An instance of the given T Resolve(); + + /// + /// Gets an instance of the given + /// + /// The given + /// An instance of the given + Task ResolveAsync(); /// /// Gets an instance of the given @@ -23,4 +30,12 @@ public interface IIocResolver : IDisposable /// The constructor arguments /// An instance of the given T Resolve(params object[] arguments); + + /// + /// Gets an instance of the given + /// + /// The given + /// The constructor arguments + /// An instance of the given + Task ResolveAsync(params object[] arguments); } \ No newline at end of file diff --git a/LightweightIocContainer/Interfaces/Registrations/Fluent/IOnCreate.cs b/LightweightIocContainer/Interfaces/Registrations/Fluent/IOnCreate.cs index 4051348..ed728d3 100644 --- a/LightweightIocContainer/Interfaces/Registrations/Fluent/IOnCreate.cs +++ b/LightweightIocContainer/Interfaces/Registrations/Fluent/IOnCreate.cs @@ -16,6 +16,7 @@ public interface IOnCreate /// Can be set in the by calling /// internal Action? OnCreateAction { get; } + internal Func? OnCreateActionAsync { get; } } /// @@ -31,4 +32,6 @@ public interface IOnCreate : IOnCreate where TImple /// The /// The current instance of this ITypedRegistration OnCreate(Action action); + + ITypedRegistration OnCreateAsync(Func action); } \ No newline at end of file diff --git a/LightweightIocContainer/IocContainer.cs b/LightweightIocContainer/IocContainer.cs index b4ac730..b3a53ed 100644 --- a/LightweightIocContainer/IocContainer.cs +++ b/LightweightIocContainer/IocContainer.cs @@ -104,6 +104,13 @@ public class IocContainer : IIocContainer, IIocResolver /// An instance of the given public virtual T Resolve() => ResolveInternal(null); + /// + /// Gets an instance of the given + /// + /// The given + /// An instance of the given + public Task ResolveAsync() => ResolveInternalAsync(null); + /// /// Gets an instance of the given /// @@ -112,6 +119,14 @@ public class IocContainer : IIocContainer, IIocResolver /// An instance of the given public T Resolve(params object[] arguments) => ResolveInternal(arguments); + /// + /// Gets an instance of the given + /// + /// The given + /// The constructor arguments + /// An instance of the given + public Task ResolveAsync(params object[] arguments) => ResolveInstanceAsync(arguments); + /// /// Gets an instance of the given for a factory /// @@ -119,6 +134,14 @@ public class IocContainer : IIocContainer, IIocResolver /// The constructor arguments /// An instance of the given public T FactoryResolve(params object?[] arguments) => ResolveInternal(arguments, null, true); + + /// + /// Gets an instance of the given for a factory + /// + /// The given + /// The constructor arguments + /// An instance of the given + public Task FactoryResolveAsync(params object?[] arguments) => ResolveInternalAsync(arguments, null, true); /// /// Gets an instance of a given registered @@ -139,6 +162,18 @@ public class IocContainer : IIocContainer, IIocResolver throw new Exception("Resolve Error"); } + + private async Task ResolveInternalAsync(object?[]? arguments, List? resolveStack = null, bool isFactoryResolve = false) + { + (bool success, object resolvedObject, Exception? exception) = TryResolve(arguments, resolveStack, isFactoryResolve); + if (success) + return await ResolveInstanceAsync(resolvedObject); + + if (exception is not null) + throw exception; + + throw new Exception("Resolve Error"); + } /// /// Tries to resolve the given with the given arguments @@ -254,7 +289,7 @@ public class IocContainer : IIocContainer, IIocResolver if (toBeResolvedPlaceholder.Parameters == null) return CreateInstance(toBeResolvedPlaceholder.ResolvedRegistration, null); - List parameters = new(); + List parameters = []; foreach (object? parameter in toBeResolvedPlaceholder.Parameters) { if (parameter != null) @@ -271,6 +306,32 @@ public class IocContainer : IIocContainer, IIocResolver return CreateInstance(toBeResolvedPlaceholder.ResolvedRegistration, parameters.ToArray()); } + private async Task ResolvePlaceholderAsync(InternalToBeResolvedPlaceholder toBeResolvedPlaceholder) + { + object? existingInstance = TryGetExistingInstance(toBeResolvedPlaceholder.ResolvedRegistration, toBeResolvedPlaceholder.Parameters); + if (existingInstance is T instance) + return instance; + + if (toBeResolvedPlaceholder.Parameters == null) + return await CreateInstanceAsync(toBeResolvedPlaceholder.ResolvedRegistration, null); + + List parameters = []; + foreach (object? parameter in toBeResolvedPlaceholder.Parameters) + { + if (parameter != null) + { + Type type = parameter is IInternalToBeResolvedPlaceholder internalToBeResolvedPlaceholder ? + internalToBeResolvedPlaceholder.ResolvedType : parameter.GetType(); + + parameters.Add(await ResolveInstanceNonGenericAsync(type, parameter)); + } + else + parameters.Add(parameter); + } + + return await CreateInstanceAsync(toBeResolvedPlaceholder.ResolvedRegistration, parameters.ToArray()); + } + /// /// Resolve the given object instance /// @@ -286,6 +347,15 @@ public class IocContainer : IIocContainer, IIocResolver InternalFactoryMethodPlaceholder factoryMethodPlaceholder => CreateInstance(factoryMethodPlaceholder.SingleTypeRegistration, null), _ => throw new InternalResolveException("Resolve returned wrong type.") }; + + private async Task ResolveInstanceAsync(object resolvedObject) => + resolvedObject switch + { + T instance => instance, + InternalToBeResolvedPlaceholder toBeResolvedPlaceholder => await ResolvePlaceholderAsync(toBeResolvedPlaceholder), + InternalFactoryMethodPlaceholder factoryMethodPlaceholder => await CreateInstanceAsync(factoryMethodPlaceholder.SingleTypeRegistration, null), + _ => throw new InternalResolveException("Resolve returned wrong type.") + }; /// /// Resolve the given object instance without generic arguments @@ -296,6 +366,9 @@ public class IocContainer : IIocContainer, IIocResolver /// Resolve returned wrong type private object? ResolveInstanceNonGeneric(Type type, object resolvedObject) => GenericMethodCaller.CallPrivate(this, nameof(ResolveInstance), type, resolvedObject); + + private Task ResolveInstanceNonGenericAsync(Type type, object resolvedObject) => + GenericMethodCaller.CallPrivateAsync(this, nameof(ResolveInstance), type, resolvedObject); /// /// Creates an instance of a given @@ -326,7 +399,16 @@ public class IocContainer : IIocContainer, IIocResolver _singletons.Add((GetType(registration), instance)); if (registration is IOnCreate onCreateRegistration) - onCreateRegistration.OnCreateAction?.Invoke(instance); //TODO: Allow async OnCreateAction? + onCreateRegistration.OnCreateAction?.Invoke(instance); + + return instance; + } + + private async Task CreateInstanceAsync(IRegistration registration, object?[]? arguments) + { + T instance = CreateInstance(registration, arguments); + if (registration is IOnCreate { OnCreateActionAsync: not null } onCreateRegistration) + await onCreateRegistration.OnCreateActionAsync.Invoke(instance); return instance; } diff --git a/LightweightIocContainer/LightweightIocContainer.xml b/LightweightIocContainer/LightweightIocContainer.xml index c5d2b02..c90c1b9 100644 --- a/LightweightIocContainer/LightweightIocContainer.xml +++ b/LightweightIocContainer/LightweightIocContainer.xml @@ -551,6 +551,13 @@ The given An instance of the given + + + Gets an instance of the given + + The given + An instance of the given + Gets an instance of the given @@ -559,6 +566,14 @@ The constructor arguments An instance of the given + + + Gets an instance of the given + + The given + The constructor arguments + An instance of the given + An that installs all s for its given @@ -975,6 +990,13 @@ The given An instance of the given + + + Gets an instance of the given + + The given + An instance of the given + Gets an instance of the given @@ -983,6 +1005,14 @@ The constructor arguments An instance of the given + + + Gets an instance of the given + + The given + The constructor arguments + An instance of the given + Gets an instance of the given for a factory @@ -991,6 +1021,14 @@ The constructor arguments An instance of the given + + + Gets an instance of the given for a factory + + The given + The constructor arguments + An instance of the given + Gets an instance of a given registered @@ -1856,12 +1894,24 @@ Can be set in the by calling + + + This is invoked when an instance of this type is created. + Can be set in the by calling + + This is invoked when an instance of this type is created. Can be set in the by calling + + + This is invoked when an instance of this type is created. + Can be set in the by calling + + Pass an that will be invoked when an instance of this type is created diff --git a/LightweightIocContainer/Registrations/RegistrationBase.cs b/LightweightIocContainer/Registrations/RegistrationBase.cs index 38364e0..e131689 100644 --- a/LightweightIocContainer/Registrations/RegistrationBase.cs +++ b/LightweightIocContainer/Registrations/RegistrationBase.cs @@ -176,7 +176,7 @@ internal abstract class RegistrationBase : IRegistrationBase, IWithFactoryIntern if (Factory == null) return; - if (Factory.CreateMethods.All(c => c.ReturnType != InterfaceType)) + if (Factory.CreateMethods.All(c => c.ReturnType != InterfaceType) && Factory.CreateMethods.All(c => c.ReturnType.GetAsyncReturnType() != InterfaceType)) throw new InvalidFactoryRegistrationException($"No create method that can create {InterfaceType}."); } diff --git a/LightweightIocContainer/Registrations/TypedRegistration.cs b/LightweightIocContainer/Registrations/TypedRegistration.cs index 4ec2fcd..ae5f32a 100644 --- a/LightweightIocContainer/Registrations/TypedRegistration.cs +++ b/LightweightIocContainer/Registrations/TypedRegistration.cs @@ -36,11 +36,23 @@ internal class TypedRegistration : RegistrationBase /// private Action? OnCreateAction { get; set; } + /// + /// This is invoked when an instance of this type is created. + /// Can be set in the by calling + /// + private Func? OnCreateActionAsync { get; set; } + /// /// This is invoked when an instance of this type is created. /// Can be set in the by calling /// Action? IOnCreate.OnCreateAction => OnCreateAction; + + /// + /// This is invoked when an instance of this type is created. + /// Can be set in the by calling + /// + Func? IOnCreate.OnCreateActionAsync => OnCreateActionAsync; /// /// Pass an that will be invoked when an instance of this type is created @@ -53,6 +65,12 @@ internal class TypedRegistration : RegistrationBase return this; } + public ITypedRegistration OnCreateAsync(Func action) + { + OnCreateActionAsync = a => action((TImplementation?) a); + return this; + } + /// /// Validate the for the and /// @@ -71,6 +89,12 @@ internal class TypedRegistration : RegistrationBase if (OnCreateAction != null && typedRegistration.OnCreateAction == null) return false; + + if (OnCreateActionAsync == null && typedRegistration.OnCreateActionAsync != null) + return false; + + if (OnCreateActionAsync != null && typedRegistration.OnCreateActionAsync == null) + return false; return ImplementationType == typedRegistration.ImplementationType; } diff --git a/LightweightIocContainer/TypeExtension.cs b/LightweightIocContainer/TypeExtension.cs index 2f3e411..46347be 100644 --- a/LightweightIocContainer/TypeExtension.cs +++ b/LightweightIocContainer/TypeExtension.cs @@ -12,4 +12,14 @@ internal static class TypeExtension /// The given /// The default value for the given public static object? GetDefault(this Type type) => type.IsValueType ? Activator.CreateInstance(type) : null; + public static Type? GetAsyncReturnType(this Type type) + { + if (!type.IsGenericType) + return null; + + if (type.GetGenericTypeDefinition() != typeof(Task<>)) + return null; + + return type.GenericTypeArguments[0]; + } } \ No newline at end of file diff --git a/Test.LightweightIocContainer/AsyncFactoryTest.cs b/Test.LightweightIocContainer/AsyncFactoryTest.cs new file mode 100644 index 0000000..c68aa82 --- /dev/null +++ b/Test.LightweightIocContainer/AsyncFactoryTest.cs @@ -0,0 +1,96 @@ +// Author: simon.gockner +// Created: 2024-11-25 +// Copyright(c) 2024 SimonG. All Rights Reserved. + +using LightweightIocContainer; +using NUnit.Framework; + +namespace Test.LightweightIocContainer; + +[TestFixture] +public class AsyncFactoryTest +{ + public interface ITest + { + bool IsInitialized { get; } + Task Initialize(); + } + + public class Test : ITest + { + public bool IsInitialized { get; private set; } + + public virtual async Task Initialize() + { + await Task.Delay(200); + IsInitialized = true; + } + } + + public class MultitonTest(int id) : Test + { + public int Id { get; } = id; + public override async Task Initialize() + { + if (IsInitialized) + throw new Exception(); + + await base.Initialize(); + } + } + + public interface ITestFactory + { + Task Create(); + } + + public interface IMultitonTestFactory + { + Task Create(int id); + } + + [Test] + public async Task TestAsyncFactoryResolve() + { + IocContainer container = new(); + container.Register(r => r.Add().WithFactory()); + + ITestFactory testFactory = container.Resolve(); + ITest test = await testFactory.Create(); + + Assert.IsInstanceOf(test); + } + + [Test] + public async Task TestAsyncFactoryResolveOnCreateCalled() + { + IocContainer container = new(); + container.Register(r => r.Add().OnCreateAsync(t => t.Initialize()).WithFactory()); + + ITestFactory testFactory = container.Resolve(); + ITest test = await testFactory.Create(); + + Assert.IsInstanceOf(test); + Assert.That(test.IsInitialized, Is.True); + } + + [Test] + public async Task TestAsyncMultitonFactoryResolveOnCreateCalledCorrectly() + { + IocContainer container = new(); + container.Register(r => r.AddMultiton().OnCreateAsync(t => t.Initialize()).WithFactory()); + + IMultitonTestFactory testFactory = container.Resolve(); + ITest test1 = await testFactory.Create(1); + ITest test2 = await testFactory.Create(2); + ITest anotherTest1 = await testFactory.Create(1); + + Assert.IsInstanceOf(test1); + Assert.That(test1.IsInitialized, Is.True); + + Assert.IsInstanceOf(test2); + Assert.That(test2.IsInitialized, Is.True); + + Assert.AreSame(test1, anotherTest1); + } +} \ No newline at end of file diff --git a/Test.LightweightIocContainer/OnCreateTest.cs b/Test.LightweightIocContainer/OnCreateTest.cs index 144b4a3..4fad5b3 100644 --- a/Test.LightweightIocContainer/OnCreateTest.cs +++ b/Test.LightweightIocContainer/OnCreateTest.cs @@ -21,6 +21,7 @@ public class OnCreateTest private class Test : ITest { public void DoSomething() => throw new Exception(); + public Task InitializeAsync() => throw new Exception(); } @@ -34,4 +35,15 @@ public class OnCreateTest Assert.Throws(() => testRegistration.OnCreateAction!(test)); } + + [Test] + public void TestOnCreateAsync() + { + RegistrationFactory registrationFactory = new(Substitute.For()); + ITypedRegistration testRegistration = registrationFactory.Register(Lifestyle.Transient).OnCreateAsync(t => t.InitializeAsync()); + + Test test = new(); + + Assert.Throws(() => testRegistration.OnCreateActionAsync!(test)); + } } \ No newline at end of file