diff --git a/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs b/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs index a0101bc..85c3214 100644 --- a/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs +++ b/LightweightIocContainer.FactoryGenerator/FactoryGenerator.cs @@ -169,6 +169,32 @@ public class FactoryGenerator : IIncrementalGenerator stringBuilder.AppendLine(); stringBuilder.AppendLine("using LightweightIocContainer;"); + + ImmutableArray members = typeSymbol.GetMembers(); + + List namespaces = []; + foreach (ISymbol? member in members) + { + if (member is not IMethodSymbol method) + continue; + + if (!method.ReturnsVoid) + { + if (method.ReturnType.Name == "Task") + { + if (method.ReturnType is INamedTypeSymbol { IsGenericType: true } namedTypeSymbol) + namespaces.AddRange(namedTypeSymbol.TypeArguments.Select(GetNamespaceOfType)); + } + else + namespaces.Add(GetNamespaceOfType(method.ReturnType)); + } + + namespaces.AddRange(method.Parameters.Select(p => GetNamespaceOfType(p.Type))); + } + + foreach (string @namespace in namespaces.Distinct().OfType().OrderBy(n => n)) + stringBuilder.AppendLine($"using {@namespace};"); + stringBuilder.AppendLine(); if (typeNamespace is not null) @@ -180,7 +206,6 @@ public class FactoryGenerator : IIncrementalGenerator stringBuilder.AppendLine($"public class Generated{typeName}(IocContainer container) : {typeName}"); stringBuilder.AppendLine("{"); - ImmutableArray members = typeSymbol.GetMembers(); foreach (ISymbol? member in members) { if (member is not IMethodSymbol method) @@ -268,9 +293,10 @@ public class FactoryGenerator : IIncrementalGenerator return stringBuilder.ToString(); } + private string? GetNamespaceOfType(ITypeSymbol s) => s.ContainingNamespace.IsGlobalNamespace ? null : s.ContainingNamespace.ToString(); private IEnumerable GetNamespacesOfTypes(ImmutableArray types) => types.OfType() - .Select(s => s.ContainingNamespace.IsGlobalNamespace ? null : s.ContainingNamespace.ToString()) + .Select(GetNamespaceOfType) .OfType() .Distinct();