diff --git a/packages/renderers-dart/src/getRenderMapVisitor.ts b/packages/renderers-dart/src/getRenderMapVisitor.ts index 0db204fb0..6a06ec7cb 100644 --- a/packages/renderers-dart/src/getRenderMapVisitor.ts +++ b/packages/renderers-dart/src/getRenderMapVisitor.ts @@ -1,14 +1,14 @@ import { + camelCase, getAllAccounts, getAllDefinedTypes, getAllInstructionsWithSubs, getAllPdas, getAllPrograms, pascalCase, - snakeCase, resolveNestedTypeNode, + snakeCase, structTypeNodeFromInstructionArgumentNodes, - camelCase, } from '@codama/nodes'; import { RenderMap } from '@codama/renderers-core'; import { @@ -25,7 +25,7 @@ import { } from '@codama/visitors-core'; import { join } from 'path'; -import { getTypeManifestVisitor, TypeManifest } from './getTypeManifestVisitor'; +import TypeManifest, { getTypeManifestVisitor } from './getTypeManifestVisitor'; import { ImportMap } from './ImportMap'; import { extractDiscriminatorBytes, getImportFromFactory, LinkOverrides, render } from './utils'; @@ -35,30 +35,72 @@ export type GetRenderMapOptions = { renderParentInstructions?: boolean; }; -function extractFieldsFromTypeManifest(typeManifest: TypeManifest): { name: string, type: string, field: string }[] { +// ' final Int64List i64_array;\n' + +// ' final Int8List /* length: 2 */ fixed_i8;\n' + +function extractFieldsFromTypeManifest(typeManifest: TypeManifest): { + baseType: string + field: string, + name: string, + nesting: number, + type: string, +}[] { return typeManifest.type .split('\n') .map((line) => { + // That handles lines like: final Uint8List fieldName; and extracts the type and name in order to be used from borsh readers/writers const match = line.trim().match(/^final\s+([\w<>, ?]+)\s+(\w+);$/); if (match && match[2] !== 'discriminator') { + const isOptional = /\?$/.test(match[1]); + const rawType = match[1].replace(/\?$/, '').trim(); + + // Count nesting depth of List<> + let nesting = 0; + let inner = rawType; + const listRegex = /^List<(.+)>$/; + + while (listRegex.test(inner)) { + nesting += 1; + inner = inner.replace(listRegex, '$1').trim(); + } + return { + baseType: inner, + field: line, name: match[2], + nesting, + optional: isOptional, type: match[1].replace(/\?$/, ''), - field: line, }; } return null; }) - .filter((entry): entry is { name: string; type: string; field: string } => entry !== null); + .filter((entry): entry is { + baseType: string; + field: string; + name: string; + nesting: number; + optional: boolean; + type: string; + } => entry !== null); +} + +// Given a type like List or Option or Set, it returns SomeStruct +function getBaseType(type: string): string { + const match = type.match(/^(?:List|Set|Option)<([\w\d_]+)>$/); + return match ? match[1] : type; +} + +/** + * Returns a set of all struct type names defined in the given program node. + * Used for distinguishing user-defined struct types during code generation. + */ +function getAllDefinedTypesInNode(programNode: any): Set { + if (!programNode) return new Set(); + return new Set(getAllDefinedTypes(programNode).map(typeNode => pascalCase(typeNode.name))); } export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor { const linkables = new LinkableDictionary(); const stack = new NodeStack(); @@ -88,12 +130,22 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< visitAccount(node) { const typeManifest = visit(node, typeManifestVisitor) as TypeManifest; const { imports } = typeManifest; - + imports.add('dartTypedData', new Set(['Uint8List', 'ByteData'])); imports.add('package:collection/collection.dart', new Set(['ListEquality'])); imports.add('package:solana/dto.dart', new Set(['AccountResult', 'BinaryAccountData'])); imports.add('package:solana/solana.dart', new Set(['RpcClient', 'Ed25519HDPublicKey'])); imports.add('../shared.dart', new Set(['BinaryReader', 'BinaryWriter', 'AccountNotFoundError'])); + + // Find the current program node for this account using the stack path. + // This allows us to get all user-defined struct types for codegen and type checks. + const programNode = findProgramNodeFromPath(stack.getPath('accountNode')); + const structTypeNames = getAllDefinedTypesInNode(programNode); + + const fields = extractFieldsFromTypeManifest(typeManifest).map(field => ({ + ...field, + isStruct: structTypeNames.has(getBaseType(field.type)) + })); return new RenderMap().add( `accounts/${snakeCase(node.name)}.dart`, @@ -102,12 +154,12 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< ...node, discriminator: extractDiscriminatorBytes(node) }, + fields: fields, + getBaseType, // <-- Pass the function that is used to find if i have a collection imports: imports .remove(`generatedAccounts::${pascalCase(node.name)}`, [pascalCase(node.name)]) .toString(dependencyMap), typeManifest, - fields: extractFieldsFromTypeManifest(typeManifest), - }), ); }, @@ -116,10 +168,24 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< const typeManifest = visit(node, typeManifestVisitor) as TypeManifest; const imports = new ImportMap().mergeWithManifest(typeManifest); + imports.add('../shared.dart', new Set(['BinaryReader', 'BinaryWriter', 'AccountNotFoundError'])); + imports.add('dartTypedData', new Set(['Uint8List', 'ByteData'])); + + // This allows us to later distinguish between Object Data types and other types like string, int, arrays, etc + // So that when we generate the serialization logic we know which types needs special handling + const programNode = findProgramNodeFromPath(stack.getPath('definedTypeNode')); + const structTypeNames = getAllDefinedTypesInNode(programNode); + + const fields = extractFieldsFromTypeManifest(typeManifest).map(field => ({ + ...field, + isStruct: structTypeNames.has(getBaseType(field.type)) + })); + return new RenderMap().add( `types/${snakeCase(node.name)}.dart`, renderTemplate('definedTypesPage.njk', { definedType: node, + fields: fields, imports: imports.remove(`generatedTypes::${pascalCase(node.name)}`, [pascalCase(node.name)]).toString(dependencyMap), typeManifest, }), @@ -161,21 +227,21 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< imports.mergeWith(argManifest.imports); const rt = resolveNestedTypeNode(a.type); return { - name: camelCase(a.name), dartType: argManifest.type, + name: camelCase(a.name), resolvedType: rt, }; }); const context = { + args, imports: importsString, instruction: { ...node, discriminator: extractDiscriminatorBytes(node), }, - args, - typeManifest: typeManifest || { nestedStructs: [] }, program: { name: pascalCase(programNode.name || '') }, + typeManifest: typeManifest || { nestedStructs: [] }, }; return new RenderMap().add( @@ -237,12 +303,12 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< definedTypesToExport.length > 0; const ctx = { - programsToExport, - pdasToExport, accountsToExport, - instructionsToExport, definedTypesToExport, hasAnythingToExport, + instructionsToExport, + pdasToExport, + programsToExport, root: node, }; @@ -282,4 +348,4 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor< (v) => recordNodeStackVisitor(v, stack), (v) => recordLinkablesOnFirstVisitVisitor(v, linkables), ); -} +} \ No newline at end of file diff --git a/packages/renderers-dart/src/getTypeManifestVisitor.ts b/packages/renderers-dart/src/getTypeManifestVisitor.ts index 5477b570d..050087246 100644 --- a/packages/renderers-dart/src/getTypeManifestVisitor.ts +++ b/packages/renderers-dart/src/getTypeManifestVisitor.ts @@ -1,24 +1,27 @@ import { - REGISTERED_TYPE_NODE_KINDS, - REGISTERED_VALUE_NODE_KINDS, + ArrayTypeNode, isNode, + parseDocs, pascalCase, + REGISTERED_TYPE_NODE_KINDS, + REGISTERED_VALUE_NODE_KINDS, snakeCase, structTypeNodeFromInstructionArgumentNodes, - parseDocs } from '@codama/nodes'; import { extendVisitor, mergeVisitor, pipe, visit } from '@codama/visitors-core'; -import { dartDocblock } from './utils'; +import { getDartTypedArrayType } from './fragments/dartTypedArray'; import { ImportMap } from './ImportMap'; +import { dartDocblock } from './utils'; - -export type TypeManifest = { +type TypeManifest = { imports: ImportMap; - type: string; nestedStructs: string[]; + type: string; }; +export default TypeManifest; + export type GetImportFromFunction = (node: any) => string; export type TypeManifestOptions = { @@ -27,6 +30,10 @@ export type TypeManifestOptions = { parentName?: string | null; }; +export const structManifestMap: Record = { + +}; + export function getTypeManifestVisitor(options: TypeManifestOptions) { const { getImportFrom } = options; let parentName: string | null = options.parentName ?? null; @@ -35,11 +42,11 @@ export function getTypeManifestVisitor(options: TypeManifestOptions) { return pipe( mergeVisitor( - (): TypeManifest => ({ imports: new ImportMap(), type: '', nestedStructs: [] }), + (): TypeManifest => ({ imports: new ImportMap(), nestedStructs: [], type: '' }), (_: any, values: any[]) => ({ imports: new ImportMap().mergeWith(...values.map((v: any) => v.imports)), - type: values.map((v: any) => v.type).join('\n'), nestedStructs: values.flatMap((v: any) => v.nestedStructs), + type: values.map((v: any) => v.type).join('\n'), }), { keys: [ @@ -54,46 +61,58 @@ export function getTypeManifestVisitor(options: TypeManifestOptions) { ), (v: any) => extendVisitor(v, { - visitAccount(account: any, { self }: { self: any }) { + visitAccount(account: any, { self }: { self: any }): TypeManifest { parentName = pascalCase(account.name); const manifest = visit(account.data, self) as TypeManifest; parentName = null; return manifest; }, - visitArrayType(arrayType: any, { self }: { self: any }) { - const childManifest = visit(arrayType.item, self) as TypeManifest; - - if (isNode(arrayType.count, 'fixedCountNode')) { - return { - ...childManifest, - type: `List<${childManifest.type}> /* length: ${arrayType.count.value} */`, - }; + visitArrayType(arrayType: ArrayTypeNode , { self }: { self: any}): TypeManifest { + /* + ArrayTypeNode structure: + https://github.com/codama-idl/codama/blob/main/packages/nodes/docs/typeNodes/ArrayTypeNode.md + */ + // eslint-disable-next-line @typescript-eslint/no-unnecessary-type-assertion, @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-argument + const childManifest = visit(arrayType.item, self) as TypeManifest; // Item + + // console.log('===========Array Type ', arrayType); + const typedArrayManifest = getDartTypedArrayType(arrayType.item, childManifest); + if (typedArrayManifest) { + if (isNode(arrayType.count, 'fixedCountNode')) { + // Fixed-size typed array handler + return { + ...typedArrayManifest, + // type: `${typedArrayManifest.type} /* length: ${arrayType.count.value} */`, + type: `${typedArrayManifest.type}`, + }; + } + return typedArrayManifest; } return { ...childManifest, type: `List<${childManifest.type}>`, - }; + } }, - visitBooleanType(_booleanType: any) { + visitBooleanType(_booleanType: any): TypeManifest { return { imports: new ImportMap(), - type: 'bool', nestedStructs: [], + type: 'bool', }; }, - visitBytesType() { + visitBytesType(): TypeManifest { return { imports: new ImportMap().add('dart:typed_data', ['Uint8List']), - type: 'Uint8List', nestedStructs: [], + type: 'Uint8List', }; }, - visitDefinedType(definedType: any, { self }: { self: any }) { + visitDefinedType(definedType: any, { self }: { self: any }): TypeManifest { parentName = pascalCase(definedType.name); const manifest = visit(definedType.type, self) as TypeManifest; parentName = null; @@ -101,16 +120,18 @@ export function getTypeManifestVisitor(options: TypeManifestOptions) { const renderedType = isNode(definedType.type, ['enumTypeNode', 'structTypeNode']) ? manifest.type : `typedef ${pascalCase(definedType.name)} = ${manifest.type};`; - - return { ...manifest, type: renderedType }; + + return { ...manifest, type: renderedType}; }, - visitDefinedTypeLink(node: any) { - const pascal = pascalCase(node.name); + visitDefinedTypeLink(node: any): TypeManifest { + const snake_case = snakeCase(node.name); // This is the correct way to name files in Dart + const pascal_case = pascalCase(node.name); // This is the correct way to name types in Dart + // Example: types/simple_struct.dart -> SimpleStruct (is the actual type name) const importFrom = getImportFrom(node); return { - imports: new ImportMap().add(`${importFrom}::${pascal}`, [pascal]), - type: pascal, + imports: new ImportMap().add(`../${importFrom}/${snake_case}.dart`, [snake_case]), + type: pascal_case, nestedStructs: [], }; }, @@ -120,26 +141,26 @@ export function getTypeManifestVisitor(options: TypeManifestOptions) { return { imports: new ImportMap(), nestedStructs: [], - type: `${name},`, + type: `class ${name} extends ${parentName} {}` }; }, visitEnumStructVariantType(enumStructVariantType: any, { self }: { self: any }) { const name = pascalCase(enumStructVariantType.name); const originalParentName = parentName; - - // Use a default name if no parent name is available - const variantParentName = originalParentName || 'AnonymousEnum'; - + inlineStruct = true; - parentName = pascalCase(variantParentName) + name; + // Sets the name of the parent to the variant name only + parentName = name; const typeManifest = visit(enumStructVariantType.struct, self) as TypeManifest; inlineStruct = false; + // Set the Parent name back to the original parent name(Enum class name) parentName = originalParentName; - + return { ...typeManifest, - type: `${name} ${typeManifest.type},`, + type: `class ${name} extends ${parentName} + ${typeManifest.type}`, }; }, @@ -154,9 +175,17 @@ export function getTypeManifestVisitor(options: TypeManifestOptions) { const childManifest = visit(enumTupleVariantType.tuple, self) as TypeManifest; parentName = originalParentName; + const tupleTypes = childManifest.type.replace(/[()]/g, '').split(',').map(s => s.trim()); + const fields = tupleTypes.map((type, i) => `final ${type} value${i};`).join('\n'); + const constructorArgs = tupleTypes.map((_, i) => `this.value${i}`).join(', '); + return { ...childManifest, - type: `${name}${childManifest.type},`, + type: `class ${name} extends ${parentName} { + ${fields} + + ${name}(${constructorArgs}); + }`, }; }, @@ -165,9 +194,7 @@ export function getTypeManifestVisitor(options: TypeManifestOptions) { // Use a default name if no parent name is available const enumName = originalParentName || 'AnonymousEnum'; - const variants = enumType.variants.map((variant: any) => - visit(variant, self) as TypeManifest - ); + const variants = enumType.variants.map((variant: any) => visit(variant, self) as TypeManifest); const variantNames = variants.map((variant: any) => variant.type).join('\n'); const mergedManifest = { imports: new ImportMap().mergeWith(...variants.map((v: any) => v.imports)), @@ -176,9 +203,8 @@ export function getTypeManifestVisitor(options: TypeManifestOptions) { return { ...mergedManifest, - type: `enum ${pascalCase(enumName)} { -${variantNames} -}`, + type: `abstract class ${pascalCase(enumName)} {} + ${variantNames}`, }; }, @@ -219,8 +245,8 @@ ${variantNames} case 'i32': return { imports: new ImportMap(), - type: 'int', nestedStructs: [], + type: 'int', }; case 'u64': case 'i64': @@ -228,14 +254,14 @@ ${variantNames} case 'i128': return { imports: new ImportMap(), - type: 'BigInt', nestedStructs: [], + type: 'BigInt', }; case 'shortU16': return { imports: new ImportMap(), - type: 'int', nestedStructs: [], + type: 'int', }; default: throw new Error(`Unknown number format: ${numberType.format}`); @@ -254,8 +280,8 @@ ${variantNames} visitPublicKeyType() { return { imports: new ImportMap().add('package:solana/solana.dart', ['Ed25519HDPublicKey']), - type: 'Ed25519HDPublicKey', nestedStructs: [], + type: 'Ed25519HDPublicKey', }; }, @@ -275,8 +301,9 @@ ${variantNames} visitStringType() { return { imports: new ImportMap(), - type: 'String', nestedStructs: [], + type: 'String', + }; }, @@ -284,7 +311,6 @@ ${variantNames} const originalParentName = parentName; const originalInlineStruct = inlineStruct; const originalNestedStruct = nestedStruct; - const fieldParentName = originalParentName || 'AnonymousStruct'; parentName = pascalCase(fieldParentName) + pascalCase(structFieldType.name); @@ -302,10 +328,10 @@ ${variantNames} return { ...fieldManifest, - type: fieldManifest.type, - nestedStructs: fieldManifest.nestedStructs, + field: `${docblock} final ${fieldManifest.type} ${fieldName};`, imports: fieldManifest.imports, - field: `${docblock} final ${fieldManifest.type} ${fieldName};`, + nestedStructs: fieldManifest.nestedStructs, + type: fieldManifest.type, }; }, @@ -313,7 +339,10 @@ ${variantNames} const originalParentName = parentName; // Use a default name if no parent name is available const structName = originalParentName || 'AnonymousStruct'; - + + // In Dart, every variable must be initialized, either via constructor or with a default value. + // eslint-disable-next-line @typescript-eslint/no-unsafe-call + const classConstrutor = ` ${pascalCase(structName)}({\n${structType.fields.map((field: any) => ` required this.${snakeCase(field.name)},`).join('\n')}\n });\n`; const fields = structType.fields.map((field: any) => visit(field, self) as TypeManifest ); @@ -326,11 +355,14 @@ ${variantNames} if (nestedStruct) { return { ...mergedManifest, + isStruct: true, nestedStructs: [ ...mergedManifest.nestedStructs, `class ${pascalCase(structName)} { -${fieldTypes} -}`, + ${fieldTypes} + + ${classConstrutor} + }`, ], type: pascalCase(structName), }; @@ -339,15 +371,20 @@ ${fieldTypes} if (inlineStruct) { return { ...mergedManifest, type: `{ -${fieldTypes} -}` }; + ${fieldTypes} + + ${classConstrutor} + }` + }; } return { ...mergedManifest, type: `class ${pascalCase(structName)} { -${fieldTypes} -}`, + ${fieldTypes} + + ${classConstrutor} + }`, }; }, @@ -368,8 +405,8 @@ ${fieldTypes} visitDateTimeType() { return { imports: new ImportMap(), - type: 'DateTime', nestedStructs: [], + type: 'DateTime', }; }, }), diff --git a/packages/renderers-dart/src/utils/linkOverrides.ts b/packages/renderers-dart/src/utils/linkOverrides.ts index cbdf63025..c6b6152d6 100644 --- a/packages/renderers-dart/src/utils/linkOverrides.ts +++ b/packages/renderers-dart/src/utils/linkOverrides.ts @@ -43,7 +43,7 @@ export function getImportFromFactory(overrides: LinkOverrides): GetImportFromFun case 'accountLinkNode': return linkOverrides.accounts[node.name] ?? (fallback ?? 'generated_accounts'); case 'definedTypeLinkNode': - return linkOverrides.definedTypes[node.name] ?? (fallback ?? 'generated_types'); + return linkOverrides.definedTypes[node.name] ?? (fallback ?? 'types'); case 'instructionLinkNode': return linkOverrides.instructions[node.name] ?? (fallback ?? 'generated_instructions'); case 'pdaLinkNode': @@ -67,4 +67,4 @@ export function getImportFromFactory(overrides: LinkOverrides): GetImportFromFun }); } }; -} +} \ No newline at end of file