Skip to content

Commit

Permalink
feat(lib): Improve handling of types with generic arguments #123
Browse files Browse the repository at this point in the history
  • Loading branch information
PerfectlyNormal committed Jan 13, 2025
1 parent 36202a1 commit 0f22d42
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 22 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

### Added

- Better support for generic types (#123)

## [0.16.0] - 2024-12-17

### Added
Expand Down
32 changes: 32 additions & 0 deletions TypeContractor.Tests/TypeScript/TypeScriptConverterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,26 @@ public void Handles_Nullable_Records_Inside_Other_Records()
second.IsNullable.Should().BeTrue();
}

[Fact]
public void Handles_Generics()
{
var result = Sut.Convert(typeof(ResponseWithOverrides));

result.Should().NotBeNull();
result.Properties.Should().NotBeNull();
result.Properties!.Should().HaveCount(2);
result.Properties!.First().DestinationType.Should().Be("Overridable<string>");
result.Properties!.Last().DestinationType.Should().Be("Overridable<boolean>");

Sut.CustomMappedTypes.Should().ContainSingle();
var overridableType = Sut.CustomMappedTypes.First().Value;
overridableType.Properties.Should().HaveCount(2);
overridableType.Properties!.First().DestinationName.Should().Be("value");
overridableType.Properties!.First().DestinationType.Should().Be("T");
overridableType.Properties!.Last().DestinationName.Should().Be("isOverridden");
overridableType.Properties!.Last().DestinationType.Should().Be("boolean");
}

#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
private record TopLevelRecord(string Name, SecondStoryRecord? SecondStoryRecord);
private record SecondStoryRecord(string Description, SomeOtherDeeplyNestedRecord? SomeOtherDeeplyNestedRecord);
Expand Down Expand Up @@ -466,6 +486,18 @@ private class TimeOnlyResponse
public TimeOnly MeetingTime { get; set; }
}

private class Overridable<T>
{
public T? Value { get; set; }
public bool IsOverridden { get; set; }
}

private class ResponseWithOverrides
{
public Overridable<string> Name { get; set; }
public Overridable<bool?> SomeBool { get; set; }
}

#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.

private MetadataLoadContext BuildMetadataLoadContext()
Expand Down
50 changes: 50 additions & 0 deletions TypeContractor.Tests/TypeScript/TypeScriptWriterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,42 @@ public void Can_Write_Simple_Types()
.And.Contain("someObject: any;");
}

[Fact]
public void Can_Write_Generic_Types()
{
// Arrange
var outputTypes = BuildOutputTypes(typeof(ResponseWithOverrides));

// Act
var responseResult = Sut.Write(outputTypes.First(x => x.Name == "ResponseWithOverrides"), outputTypes, true);
var overrideResult = Sut.Write(outputTypes.First(x => x.Name == "Overridable"), outputTypes, true);

// Assert
var responseFile = File.ReadAllLines(responseResult).Select(x => x.TrimStart());
responseFile.Should()
.NotBeEmpty()
.And.Contain("import { Overridable, OverridableSchema } from './Overridable';")

.And.Contain("export interface ResponseWithOverrides {")
.And.Contain("name: Overridable<string>;")
.And.Contain("someBool: Overridable<boolean>;")

.And.Contain("export const ResponseWithOverridesSchema = z.object({")
.And.Contain("name: OverridableSchema,")
.And.Contain("someBool: OverridableSchema,");

var overrideFile = File.ReadAllLines(overrideResult).Select(x => x.TrimStart());
overrideFile.Should()
.NotBeEmpty()
.And.Contain("export interface Overridable<T> {")
.And.Contain("value?: T;")
.And.Contain("isOverridden: boolean;")

.And.Contain("export const OverridableSchema = z.object({")
.And.Contain("value: z.any().nullable(),")
.And.Contain("isOverridden: z.boolean(),");
}

[Fact]
public void Handles_Dictionary_With_Complex_Values()
{
Expand Down Expand Up @@ -483,3 +519,17 @@ public void Dispose()
_outputDirectory.Delete(true);
}
}

#region Test input
public class Overridable<T>
{
public T? Value { get; set; }
public bool IsOverridden { get; set; }
}

public class ResponseWithOverrides
{
public Overridable<string> Name { get; set; }

Check warning on line 532 in TypeContractor.Tests/TypeScript/TypeScriptWriterTests.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable property 'Name' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the property as nullable.

Check warning on line 532 in TypeContractor.Tests/TypeScript/TypeScriptWriterTests.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable property 'Name' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the property as nullable.
public Overridable<bool?> SomeBool { get; set; }

Check warning on line 533 in TypeContractor.Tests/TypeScript/TypeScriptWriterTests.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable property 'SomeBool' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the property as nullable.

Check warning on line 533 in TypeContractor.Tests/TypeScript/TypeScriptWriterTests.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable property 'SomeBool' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the property as nullable.
}
#endregion
25 changes: 23 additions & 2 deletions TypeContractor/Output/DestinationType.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
namespace TypeContractor.Output;

public record DestinationType(string TypeName, string? FullName, string ImportType, bool IsBuiltin, bool IsArray, bool IsReadonly, bool IsNullable, Type? InnerType)
public record DestinationType(
string TypeName,
string? FullName,
string ImportType,
bool IsBuiltin,
bool IsArray,
bool IsReadonly,
bool IsNullable,
bool IsGeneric,
ICollection<DestinationType> GenericTypeArguments,
Type? SourceType,
Type? InnerType)
{
public DestinationType(string typeName, string? fullName, bool isBuiltin, bool isArray, bool isReadonly, bool isNullable, Type? innerType, string? importType = null) : this(typeName, fullName, importType ?? typeName, isBuiltin, isArray, isReadonly, isNullable, innerType)
public DestinationType(string typeName,
string? fullName,
bool isBuiltin,
bool isArray,
bool isReadonly,
bool isNullable,
bool isGeneric,
ICollection<DestinationType> genericTypeArguments,
Type? innerType,
Type? sourceType,
string? importType = null) : this(typeName, fullName, importType ?? typeName, isBuiltin, isArray, isReadonly, isNullable, isGeneric, genericTypeArguments, sourceType, innerType)
{
}

Expand Down
19 changes: 18 additions & 1 deletion TypeContractor/Output/OutputProperty.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
namespace TypeContractor.Output;

public class OutputProperty(string sourceName, Type sourceType, Type? innerSourceType, string destinationName, string destinationType, string importType, bool isBuiltin, bool isArray, bool isNullable, bool isReadonly)
public class OutputProperty(
string sourceName,
Type sourceType,
Type? innerSourceType,
string destinationName,
string destinationType,
string importType,
bool isBuiltin,
bool isArray,
bool isNullable,
bool isReadonly,
bool isGeneric,
ICollection<DestinationType> genericTypeArguments)
{
public string SourceName { get; set; } = sourceName;
public Type SourceType { get; set; } = sourceType;
Expand All @@ -12,6 +24,8 @@ public class OutputProperty(string sourceName, Type sourceType, Type? innerSourc
public bool IsArray { get; set; } = isArray;
public bool IsNullable { get; set; } = isNullable;
public bool IsReadonly { get; set; } = isReadonly;
public bool IsGeneric { get; set; } = isGeneric;
public ICollection<DestinationType> GenericTypeArguments { get; } = genericTypeArguments;
public ObsoleteInfo? Obsolete { get; set; }

/// <summary>
Expand All @@ -37,6 +51,8 @@ public override bool Equals(object? obj)
IsArray == property.IsArray &&
IsNullable == property.IsNullable &&
IsReadonly == property.IsReadonly &&
IsGeneric == property.IsGeneric &&
GenericTypeArguments.SequenceEqual(property.GenericTypeArguments) &&
EqualityComparer<ObsoleteInfo?>.Default.Equals(Obsolete, property.Obsolete);
}

Expand All @@ -53,6 +69,7 @@ public override int GetHashCode()
hash.Add(IsArray);
hash.Add(IsNullable);
hash.Add(IsReadonly);
hash.Add(IsGeneric);
hash.Add(Obsolete);
return hash.ToHashCode();
}
Expand Down
13 changes: 11 additions & 2 deletions TypeContractor/Output/OutputType.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
using System.Globalization;
using System.Globalization;
using System.Text;

namespace TypeContractor.Output;

public record OutputType(string Name, string FullName, string FileName, ContractedType ContractedType, bool IsEnum, ICollection<OutputProperty>? Properties, ICollection<OutputEnumMember>? EnumMembers)
public record OutputType(
string Name,
string FullName,
string FileName,
ContractedType ContractedType,
bool IsEnum,
bool IsGeneric,
ICollection<DestinationType> GenericTypeArguments,
ICollection<OutputProperty>? Properties,
ICollection<OutputEnumMember>? EnumMembers)
{
public override string ToString()
{
Expand Down
61 changes: 48 additions & 13 deletions TypeContractor/TypeScript/TypeScriptConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@ public OutputType Convert(Type type, ContractedType? contractedType = null)
{
ArgumentNullException.ThrowIfNull(type);

var typeName = type.Name.Split('`').First();

return new(
type.Name,
typeName,
type.FullName!,
CasingHelpers.ToCasing(type.Name.Replace("_", ""), configuration.Casing),
contractedType ?? ContractedType.FromName(type.FullName!, type, configuration),
CasingHelpers.ToCasing(typeName.Replace("_", ""), configuration.Casing),
contractedType ?? ContractedType.FromName(type.FullName ?? typeName, type, configuration),
type.IsEnum,
type.IsGenericType,
type.IsGenericType ? ((TypeInfo)type).GenericTypeParameters.Select(x => GetDestinationType(x, [], false, TypeChecks.IsNullable(x))).ToList() : [],
type.IsEnum ? null : GetProperties(type).Distinct().ToList(),
type.IsEnum ? GetEnumProperties(type) : null
);
Expand Down Expand Up @@ -71,7 +75,19 @@ private List<OutputProperty> GetProperties(Type type)

var destinationName = GetDestinationName(property.Name);
var destinationType = GetDestinationType(property.PropertyType, property.CustomAttributes, isReadonly, TypeChecks.IsNullable(property.PropertyType));
var outputProperty = new OutputProperty(property.Name, property.PropertyType, destinationType.InnerType, destinationName, destinationType.TypeName, destinationType.ImportType, destinationType.IsBuiltin, destinationType.IsArray, TypeChecks.IsNullable(property), destinationType.IsReadonly);
var outputProperty = new OutputProperty(
property.Name,
property.PropertyType,
destinationType.InnerType,
destinationName,
destinationType.TypeName,
destinationType.ImportType,
destinationType.IsBuiltin,
destinationType.IsArray,
TypeChecks.IsNullable(property),
destinationType.IsReadonly,
destinationType.IsGeneric,
destinationType.GenericTypeArguments);

var obsolete = property.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == "System.ObsoleteAttribute");
outputProperty.Obsolete = obsolete is not null ? new ObsoleteInfo((string?)obsolete.ConstructorArguments.FirstOrDefault().Value) : null;
Expand All @@ -94,11 +110,14 @@ private List<OutputProperty> GetProperties(Type type)

public DestinationType GetDestinationType(in Type sourceType, IEnumerable<CustomAttributeData> customAttributes, bool isReadonly, bool isNullable)
{
if (configuration.TypeMaps.TryGetValue(sourceType.FullName!, out var destType))
return new DestinationType(destType.Replace("[]", string.Empty), sourceType.FullName, true, destType.Contains("[]"), isReadonly, isNullable || TypeChecks.IsNullable(sourceType), null);
if (!sourceType.IsGenericParameter && configuration.TypeMaps.TryGetValue(sourceType.FullName!, out var destType))
return new DestinationType(destType.Replace("[]", string.Empty), sourceType.FullName, true, destType.Contains("[]"), isReadonly, isNullable || TypeChecks.IsNullable(sourceType), false, [], null, sourceType);

if (CustomMappedTypes.TryGetValue(sourceType, out var customType))
return new DestinationType(customType.Name, customType.FullName, false, false, isReadonly, TypeChecks.IsNullable(sourceType), null);
return new DestinationType(customType.Name, customType.FullName, false, false, isReadonly, TypeChecks.IsNullable(sourceType), customType.IsGeneric, customType.GenericTypeArguments, null, customType.ContractedType.Type);

if (sourceType.IsGenericTypeParameter)
return new DestinationType(sourceType.Name, null, true, false, false, isNullable, true, [], null, sourceType, "");

if (TypeChecks.ImplementsIDictionary(sourceType))
{
Expand All @@ -108,15 +127,15 @@ public DestinationType GetDestinationType(in Type sourceType, IEnumerable<Custom

var isBuiltin = keyType.IsBuiltin && valueDestinationType.IsBuiltin;

return new DestinationType($"{{ [key: {keyType.TypeName}]: {valueDestinationType.FullTypeName} }}", valueDestinationType.FullName, isBuiltin, false, isReadonly, valueDestinationType.IsNullable, valueType, valueDestinationType.ImportType);
return new DestinationType($"{{ [key: {keyType.TypeName}]: {valueDestinationType.FullTypeName} }}", valueDestinationType.FullName, isBuiltin, false, isReadonly, valueDestinationType.IsNullable, valueDestinationType.IsGeneric, valueDestinationType.GenericTypeArguments, valueType, valueDestinationType.SourceType, valueDestinationType.ImportType);
}

if (TypeChecks.ImplementsIEnumerable(sourceType))
{
var innerType = TypeChecks.GetGenericType(sourceType);

var (TypeName, FullName, _, IsBuiltin, _, IsReadonly, IsNullable, _) = GetDestinationType(innerType, customAttributes, isReadonly, isNullable);
return new DestinationType(TypeName, FullName, IsBuiltin, true, IsReadonly, IsNullable, innerType);
var (TypeName, FullName, _, IsBuiltin, _, IsReadonly, IsNullable, IsGeneric, _, _, _) = GetDestinationType(innerType, customAttributes, isReadonly, isNullable);
return new DestinationType(TypeName, FullName, IsBuiltin, true, IsReadonly, IsNullable, IsGeneric, [], innerType, sourceType);
}

if (TypeChecks.IsValueTuple(sourceType))
Expand All @@ -128,21 +147,37 @@ public DestinationType GetDestinationType(in Type sourceType, IEnumerable<Custom
var argumentList = argumentDestinationTypes.Select((arg, idx) => $"item{idx + 1}: {arg.FullTypeName}");
var typeName = $"{{ {string.Join(", ", argumentList)} }}";

return new DestinationType(typeName, sourceType.FullName, isBuiltin, false, isReadonly, false, null);
return new DestinationType(typeName, sourceType.FullName, isBuiltin, false, isReadonly, false, false, [], null, sourceType);
}

if (TypeChecks.IsNullable(sourceType))
{
return GetDestinationType(sourceType.GenericTypeArguments.First(), customAttributes, isReadonly, true);
}

if (sourceType.IsGenericType && sourceType.GenericTypeArguments.Length > 0)
{
var genericType = sourceType.GetGenericTypeDefinition();
var genericOutputType = Convert(genericType);
CustomMappedTypes.TryAdd(genericType, genericOutputType);

var genericArguments = sourceType.GenericTypeArguments
.Select(x => GetDestinationType(x, customAttributes, isReadonly, TypeChecks.IsNullable(x)))
.ToList();

var importType = genericOutputType.Name.Split('`').First();
var typeName = importType + $"<{string.Join(", ", genericArguments.Select(x => x.TypeName))}>";

return new DestinationType(typeName, genericOutputType.FullName, false, false, isReadonly, isNullable, true, genericArguments, null, genericOutputType.ContractedType.Type, importType);
}

if (customAttributes.Any(x => x.AttributeType.FullName == "System.Runtime.CompilerServices.DynamicAttribute"))
return new DestinationType(DestinationTypes.Dynamic, null, true, false, isReadonly, true, null);
return new DestinationType(DestinationTypes.Dynamic, null, true, false, isReadonly, true, false, [], null, null);

// FIXME: Check if this is one of our types?
var outputType = Convert(sourceType);
CustomMappedTypes.Add(sourceType, outputType);
return new DestinationType(outputType.Name, outputType.FullName, false, false, isReadonly, isNullable || TypeChecks.IsNullable(sourceType), null);
return new DestinationType(outputType.Name, outputType.FullName, false, false, isReadonly, isNullable || TypeChecks.IsNullable(sourceType), outputType.IsGeneric, outputType.GenericTypeArguments, null, sourceType);

// throw new ArgumentException($"Unexpected type: {sourceType}");
}
Expand Down
29 changes: 27 additions & 2 deletions TypeContractor/TypeScript/TypeScriptWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,25 @@ private void BuildHeader()

private void BuildImports(OutputType type, IEnumerable<OutputType> allTypes, bool buildZodSchema)
{
var properties = type.Properties ?? Enumerable.Empty<OutputProperty>();
var properties = type.Properties ?? [];
var imports = properties
.Where(p => !p.IsBuiltin)
.DistinctBy(p => p.InnerSourceType ?? p.SourceType)
.ToList();

foreach (var property in properties)
{
if (!property.IsGeneric) continue;
if (property.GenericTypeArguments.Count == 0) continue;

foreach (var genArg in property.GenericTypeArguments)
{
if (genArg.IsBuiltin) continue;
if (genArg.InnerType is null && genArg.SourceType is null) continue;
imports.Add(new OutputProperty(genArg.TypeName, (genArg.InnerType ?? genArg.SourceType)!, null, "", genArg.TypeName, genArg.ImportType, false, genArg.IsArray, genArg.IsNullable, genArg.IsReadonly, genArg.IsGeneric, genArg.GenericTypeArguments));
}
}

if (buildZodSchema)
_builder.AppendLine(ZodSchemaWriter.LibraryImport);

Expand Down Expand Up @@ -117,7 +130,14 @@ private void BuildExport(OutputType type)
}
else
{
_builder.AppendLine($"export interface {type.Name} {{");
var genericPropertyTypes = type.IsGeneric
? type.GenericTypeArguments ?? []
: [];
var genericTypeArguments = genericPropertyTypes.Count > 0
? $"<{string.Join(", ", genericPropertyTypes.Select(x => x.TypeName))}>"
: "";

_builder.AppendLine($"export interface {type.Name}{genericTypeArguments} {{");
}

// Body
Expand Down Expand Up @@ -168,6 +188,11 @@ private static List<OutputType> GetImportedTypes(IEnumerable<OutputType> allType
return allTypes.Where(x => x.FullName == keyType.FullName || x.FullName == valueType.FullName).ToList();
}

if (import.IsGeneric && import.GenericTypeArguments.Count > 0)
return allTypes
.Where(x => x.FullName == $"{sourceType.Namespace}.{sourceType.Name}")
.ToList();

return allTypes.Where(x => x.FullName == sourceType.FullName).ToList();
}
}
Expand Down
Loading

0 comments on commit 0f22d42

Please sign in to comment.