diff --git a/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs b/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs index 4598053..09216c5 100644 --- a/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs +++ b/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs @@ -178,20 +178,10 @@ public class FactoryGenerator : IIncrementalGenerator if (member is not IMethodSymbol method) continue; - if (!method.ReturnsVoid) - { - if (method.ReturnType is INamedTypeSymbol { IsGenericType: true } namedTypeSymbol) - { - namespaces.AddRange(namedTypeSymbol.TypeArguments.Select(GetNamespaceOfType)); - - if (method.ReturnType.Name != "Task") - namespaces.Add(GetNamespaceOfType(method.ReturnType)); - } - else - namespaces.Add(GetNamespaceOfType(method.ReturnType)); - } + if (!method.ReturnsVoid) + namespaces.AddRange(GetNamespacesOfType(method.ReturnType)); - namespaces.AddRange(method.Parameters.Select(p => GetNamespaceOfType(p.Type))); + namespaces.AddRange(method.Parameters.SelectMany(p => GetNamespacesOfType(p.Type))); } foreach (string @namespace in namespaces.Distinct().OfType().OrderBy(n => n)) @@ -270,9 +260,25 @@ public class FactoryGenerator : IIncrementalGenerator } private string? GetNamespaceOfType(ITypeSymbol s) => s.ContainingNamespace.IsGlobalNamespace ? null : s.ContainingNamespace.ToString(); + private List GetNamespacesOfType(ITypeSymbol typeSymbol) + { + List namespaces = []; + if (typeSymbol is INamedTypeSymbol { IsGenericType: true } namedTypeSymbol) + { + namespaces.AddRange(namedTypeSymbol.TypeArguments.SelectMany(GetNamespacesOfType)); + + if (typeSymbol.Name != "Task") + namespaces.Add(GetNamespaceOfType(typeSymbol)); + } + else + namespaces.Add(GetNamespaceOfType(typeSymbol)); + + return namespaces; + } + private IEnumerable GetNamespacesOfTypes(ImmutableArray types) => types.OfType() - .Select(GetNamespaceOfType) + .SelectMany(GetNamespacesOfType) .OfType() .Distinct();