diff --git a/LightweightIocContainer/Factories/TypedFactory.cs b/LightweightIocContainer/Factories/TypedFactory.cs index b0c534b..8a63b01 100644 --- a/LightweightIocContainer/Factories/TypedFactory.cs +++ b/LightweightIocContainer/Factories/TypedFactory.cs @@ -75,6 +75,26 @@ public class TypedFactory : TypedFactoryBase, ITypedFactory< MethodBuilder methodBuilder = typeBuilder.DefineMethod(createMethod.Name, MethodAttributes.Public | MethodAttributes.Virtual, createMethod.ReturnType, (from arg in args select arg.ParameterType).ToArray()); + + if (createMethod.IsGenericMethod) + { + Type[] genericArguments = createMethod.GetGenericMethodDefinition().GetGenericArguments(); + string[] genericArgumentNames = genericArguments.Select(a => a.Name).ToArray(); + + GenericTypeParameterBuilder[] genericParameters = methodBuilder.DefineGenericParameters(genericArgumentNames); + + foreach (GenericTypeParameterBuilder genericParameter in genericParameters) + { + Type genericType = genericArguments.First(a => a.Name == genericParameter.Name); + genericParameter.SetGenericParameterAttributes(genericType.GenericParameterAttributes); + + genericParameter.SetInterfaceConstraints(genericType.GetGenericParameterConstraints().Where(c => c.IsInterface).ToArray()); + Type? typeConstraint = genericType.GetGenericParameterConstraints().FirstOrDefault(c => !c.IsInterface); + if (typeConstraint is not null) + genericParameter.SetBaseTypeConstraint(typeConstraint); + } + } + typeBuilder.DefineMethodOverride(methodBuilder, createMethod); ILGenerator generator = methodBuilder.GetILGenerator(); diff --git a/LightweightIocContainer/Registrations/OpenGenericRegistration.cs b/LightweightIocContainer/Registrations/OpenGenericRegistration.cs index d2800ba..37c85cb 100644 --- a/LightweightIocContainer/Registrations/OpenGenericRegistration.cs +++ b/LightweightIocContainer/Registrations/OpenGenericRegistration.cs @@ -42,7 +42,16 @@ internal class OpenGenericRegistration : RegistrationBase, IOpenGenericRegistrat base.Validate(); } + + protected override void ValidateFactory() + { + if (Factory == null) + return; + if (Factory.CreateMethods.All(c => c.ReturnType.Name != InterfaceType.Name)) + throw new InvalidFactoryRegistrationException($"No create method that can create {InterfaceType}."); + } + public override bool Equals(object? obj) => obj is OpenGenericRegistration openGenericRegistration && base.Equals(obj) && ImplementationType == openGenericRegistration.ImplementationType; diff --git a/Test.LightweightIocContainer/OpenGenericRegistrationTest.cs b/Test.LightweightIocContainer/OpenGenericRegistrationTest.cs index c9dcac0..c42ccbc 100644 --- a/Test.LightweightIocContainer/OpenGenericRegistrationTest.cs +++ b/Test.LightweightIocContainer/OpenGenericRegistrationTest.cs @@ -14,19 +14,36 @@ namespace Test.LightweightIocContainer; public class OpenGenericRegistrationTest { private IocContainer _iocContainer; + + [UsedImplicitly] + public interface IConstraint + { + + } + + [UsedImplicitly] + public class Constraint : IConstraint + { + + } [UsedImplicitly] [SuppressMessage("ReSharper", "UnusedTypeParameter")] - public interface ITest + public interface ITest where T : IConstraint, new() { } [UsedImplicitly] - public class Test : ITest + public class Test : ITest where T : IConstraint, new() { } + + public interface ITestFactory + { + ITest Create() where T : IConstraint, new(); + } [SetUp] public void SetUp() => _iocContainer = new IocContainer(); @@ -39,7 +56,7 @@ public class OpenGenericRegistrationTest { _iocContainer.Register(r => r.AddOpenGenerics(typeof(ITest<>), typeof(Test<>))); - ITest test = _iocContainer.Resolve>(); + ITest test = _iocContainer.Resolve>(); Assert.NotNull(test); } @@ -48,10 +65,10 @@ public class OpenGenericRegistrationTest { _iocContainer.Register(r => r.AddOpenGenerics(typeof(ITest<>), typeof(Test<>), Lifestyle.Singleton)); - ITest test = _iocContainer.Resolve>(); + ITest test = _iocContainer.Resolve>(); Assert.NotNull(test); - ITest secondTest = _iocContainer.Resolve>(); + ITest secondTest = _iocContainer.Resolve>(); Assert.NotNull(secondTest); Assert.AreEqual(test, secondTest); @@ -65,4 +82,13 @@ public class OpenGenericRegistrationTest [Test] public void TestRegisterNonOpenGenericTypeWithOpenGenericsFunctionThrowsException() => Assert.Throws(() => _iocContainer.Register(r => r.AddOpenGenerics(typeof(int), typeof(int)))); + + [Test] + public void TestRegisterFactoryOfOpenGenericType() + { + _iocContainer.Register(r => r.AddOpenGenerics(typeof(ITest<>), typeof(Test<>)).WithFactory()); + ITestFactory testFactory = _iocContainer.Resolve(); + ITest test = testFactory.Create(); + Assert.IsInstanceOf>(test); + } } \ No newline at end of file