diff --git a/README.md b/README.md index 31b3100..581d564 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,9 @@ pnpm install @limechain/renderers-dart "args": [ "generated", { - "libraryName": "myProject", "deleteFolderBeforeRendering": true, - "formatCode": true + "formatCode": true, + "generateBorsh": true } ] } @@ -53,15 +53,13 @@ pnpm install @limechain/renderers-dart } ``` -An object can be passed as a second argument to further configure the renderer. See the [Options](#options) section below for more details. - ### 3.2. Run code generation ```sh pnpm codama run dart ``` -### 3.3. Run Dart Borsh code generation +### 3.3. (Only if `generateBorsh` is set to `false`) Run Dart Borsh code generation ```sh cd generated @@ -102,11 +100,10 @@ lib/ The `renderVisitor` accepts the following options. -| Name | Type | Default | Description | -| ----------------------------- | -------- | ------- | --------------------------------------------------------------- | -| `libraryName` | `string` | `'lib'` | The name of the generated Dart library. | -| `outputDirectory` | `string` | `'lib'` | The directory where generated files will be placed. | -| `deleteFolderBeforeRendering` | `bool` | `true` | Flag for deleting the output folder before generating it again. | -| `formatCode` | `bool` | `true` | Flag for formatting the Dart code after generation | +| Name | Type | Default | Description | +| ----------------------------- | ------ | ------- | --------------------------------------------------------------- | +| `deleteFolderBeforeRendering` | `bool` | `true` | Flag for deleting the output folder before generating it again. | +| `formatCode` | `bool` | `true` | Flag for formatting the Dart code after generation | +| `generateBorsh` | `bool` | `true` | Flag for running Borsh code generation after rendering |
diff --git a/src/fragments/accountPage.ts b/src/fragments/accountPage.ts index b28bf00..2cc21cc 100644 --- a/src/fragments/accountPage.ts +++ b/src/fragments/accountPage.ts @@ -22,12 +22,9 @@ export function getAccountPageFragment( return getStructAccountFragment(node, scope, className); } - const typeInfo = getTypeInfo(dataTypeNode, scope.nameApi); - const borshAnnotation = getBorshAnnotation(dataTypeNode, scope.nameApi); - const allImports = [ - 'package:borsh_annotation_extended/borsh_annotation_extended.dart', - ...typeInfo.imports, - ]; + const typeInfo = getTypeInfo(dataTypeNode, scope.nameApi, programNode.definedTypes); + const borshAnnotation = getBorshAnnotation(dataTypeNode, scope.nameApi, programNode.definedTypes); + const allImports = ['package:borsh_annotation_extended/borsh_annotation_extended.dart', ...typeInfo.imports]; const content = `part '${node.name}.g.dart'; @@ -68,21 +65,22 @@ function getStructAccountFragment( const dataTypeNode = resolveNestedTypeNode(node.data); const fields = isNode(dataTypeNode, 'structTypeNode') ? dataTypeNode.fields : []; + const programNode = findProgramNodeFromPath(scope.accountPath); + const programDefinedTypes = programNode?.definedTypes || []; + + const allImports = new Set(['package:borsh_annotation_extended/borsh_annotation_extended.dart']); const factoryParams = fields .map(field => { - const typeInfo = getTypeInfo(field.type, scope.nameApi); - const borshAnnotation = getBorshAnnotation(field.type, scope.nameApi); + const typeInfo = getTypeInfo(field.type, scope.nameApi, programDefinedTypes); + const borshAnnotation = getBorshAnnotation(field.type, scope.nameApi, programDefinedTypes); const fieldName = scope.nameApi.accountField(field.name); + + typeInfo.imports.forEach(imp => allImports.add(imp)); + return ` ${borshAnnotation} required ${typeInfo.dartType} ${fieldName},`; }) .join('\n'); - const allImports = new Set(['package:borsh_annotation_extended/borsh_annotation_extended.dart']); - fields.forEach(field => { - const typeInfo = getTypeInfo(field.type, scope.nameApi); - typeInfo.imports.forEach(imp => allImports.add(imp)); - }); - const content = `part '${node.name}.g.dart'; @BorshSerializable() diff --git a/src/fragments/enumType.ts b/src/fragments/enumType.ts new file mode 100644 index 0000000..fb5e040 --- /dev/null +++ b/src/fragments/enumType.ts @@ -0,0 +1,336 @@ +import { + camelCase, + DefinedTypeNode, + EnumStructVariantTypeNode, + EnumTupleVariantTypeNode, + EnumVariantTypeNode, + resolveNestedTypeNode, + StructFieldTypeNode, + TypeNode, +} from '@codama/nodes'; + +import { createFragment, Fragment, getBorshAnnotation, getTypeInfo, NameApi, RenderScope } from '../utils'; + +function collectVariantImports( + variant: EnumVariantTypeNode, + nameApi: NameApi, + scope: Pick, + allImports: Set, +): void { + if (variant.kind === 'enumStructVariantTypeNode') { + const resolvedStruct = resolveNestedTypeNode(variant.struct); + resolvedStruct.fields?.forEach((field: StructFieldTypeNode) => { + const typeInfo = getTypeInfo(field.type, nameApi, scope.definedTypes); + typeInfo.imports.forEach(imp => allImports.add(imp)); + }); + } else if (variant.kind === 'enumTupleVariantTypeNode') { + const resolvedTuple = resolveNestedTypeNode(variant.tuple); + resolvedTuple.items?.forEach((item: TypeNode) => { + const typeInfo = getTypeInfo(item, nameApi, scope.definedTypes); + typeInfo.imports.forEach(imp => allImports.add(imp)); + }); + } +} + +export function getEnumVariantFragment( + scope: Pick & { + variant: EnumVariantTypeNode; + variantName: string; + }, +): Fragment { + const { variantName, variant, nameApi } = scope; + + const allImports = new Set(['package:borsh_annotation_extended/borsh_annotation_extended.dart']); + collectVariantImports(variant, nameApi, scope, allImports); + + let content = ''; + + switch (variant.kind) { + case 'enumEmptyVariantTypeNode': + content = generateEmptyVariantAsStruct(variantName); + break; + case 'enumTupleVariantTypeNode': + content = generateTupleVariantAsStruct(variantName, variant, nameApi, scope); + break; + case 'enumStructVariantTypeNode': + content = generateStructVariantAsStruct(variantName, variant, nameApi, scope); + break; + } + + return createFragment(content, Array.from(allImports)); +} + +export function getEnumMainFragment( + scope: Pick & { + name: string; + node: DefinedTypeNode; + }, +): Fragment { + const { name, node, nameApi } = scope; + + if (node.type.kind !== 'enumTypeNode') { + throw new Error(`Expected enumTypeNode but got ${node.type.kind}`); + } + + const enumTypeNode = node.type; + const className = nameApi.definedType(camelCase(name)); + const variants = enumTypeNode.variants || []; + + const allImports = new Set(['package:borsh_annotation_extended/borsh_annotation_extended.dart']); + + variants.forEach(variant => { + const variantFileName = camelCase(variant.name); + allImports.add(`./${variantFileName}.dart`); + + collectVariantImports(variant, nameApi, scope, allImports); + }); + + const factoryConstructors = variants + .map((variant, index) => { + const variantName = nameApi.definedType(camelCase(variant.name)); + const methodName = camelCase(variant.name); + + switch (variant.kind) { + case 'enumEmptyVariantTypeNode': + return ` factory ${className}.${methodName}() { + return ${className}._(${variantName}(), ${index}); + }`; + case 'enumTupleVariantTypeNode': { + const resolvedTuple = resolveNestedTypeNode(variant.tuple); + if (!resolvedTuple.items) { + return ` factory ${className}.${methodName}() { + return ${className}._(${variantName}(), ${index}); + }`; + } + const tupleParams = resolvedTuple.items + .map((item: TypeNode, index: number) => { + const typeInfo = getTypeInfo(item, nameApi, scope.definedTypes); + return `${typeInfo.dartType} field${index}`; + }) + .join(', '); + const tupleArgs = resolvedTuple.items + .map((_: TypeNode, index: number) => `field${index}: field${index}`) + .join(', '); + return ` factory ${className}.${methodName}(${tupleParams}) { + return ${className}._(${variantName}(${tupleArgs}), ${index}); + }`; + } + case 'enumStructVariantTypeNode': { + const resolvedStruct = resolveNestedTypeNode(variant.struct); + if (!resolvedStruct.fields) { + return ` factory ${className}.${methodName}() { + return ${className}._(${variantName}(), ${index}); + }`; + } + const structParams = resolvedStruct.fields + .map((field: StructFieldTypeNode) => { + const typeInfo = getTypeInfo(field.type, nameApi, scope.definedTypes); + const fieldName = nameApi.accountField(field.name); + return `required ${typeInfo.dartType} ${fieldName}`; + }) + .join(', '); + const structArgs = resolvedStruct.fields + .map((field: StructFieldTypeNode) => { + const fieldName = nameApi.accountField(field.name); + return `${fieldName}: ${fieldName}`; + }) + .join(', '); + return ` factory ${className}.${methodName}({${structParams}}) { + return ${className}._(${variantName}(${structArgs}), ${index}); + }`; + } + default: + return ''; + } + }) + .filter(Boolean) + .join('\n\n'); + + const variantAnnotations = variants + .map(variant => { + const variantName = nameApi.definedType(camelCase(variant.name)); + return `${variantName}: B${variantName}()`; + }) + .join(', '); + + const content = `// Main enum class with factory constructors for each variant +class ${className} { + final dynamic variant; + final int discriminant; + + const ${className}._(this.variant, this.discriminant); + +${factoryConstructors} + + @override + String toString() { + return variant.toString(); + } + + Uint8List toBorsh() { + final writer = BinaryWriter(); + writer.writeU8(discriminant); + + final Uint8List variantBytes = variant.toBorsh(); + for (final byte in variantBytes) { + writer.writeU8(byte); + } + return writer.toArray(); + } +} + +// Borsh annotation generator for use in other structs +// Usage: @BEnum<${className}>({${variantAnnotations}}) required ${className} myEnum,`; + + return createFragment(content, Array.from(allImports)); +} + +function generateEmptyVariantAsStruct(variantName: string): string { + return `class ${variantName} { + const ${variantName}(); + + static ${variantName} fromBorsh(Uint8List _data) { + return ${variantName}(); + } + + Uint8List toBorsh() { + // Empty variant has no data to serialize + return Uint8List(0); + } + + @override + String toString([int indent = 0]) { + return '${variantName}()'; + } +} + +class B${variantName} implements BType<${variantName}> { + const B${variantName}(); + + @override + void write(BinaryWriter writer, ${variantName} value) { + // Correct - unit variant has no data to write + } + + @override + ${variantName} read(BinaryReader reader) { + // Correct - unit variant has no data to read + return ${variantName}(); + } +}`; +} + +function generateTupleVariantAsStruct( + variantName: string, + variant: EnumTupleVariantTypeNode, + nameApi: NameApi, + scope: Pick, +): string { + const resolvedTuple = resolveNestedTypeNode(variant.tuple); + if (!resolvedTuple.items) { + return generateEmptyVariantAsStruct(variantName); + } + + const params = resolvedTuple.items + .map((item: TypeNode, index: number) => { + const typeInfo = getTypeInfo(item, nameApi, scope.definedTypes); + const borshAnnotation = getBorshAnnotation(item, nameApi, scope.definedTypes); + return ` ${borshAnnotation} required ${typeInfo.dartType} field${index},`; + }) + .join('\n'); + + const toStringFields = resolvedTuple.items + .map((_: TypeNode, index: number) => `buffer.writeln(' field${index}: $field${index}');`) + .join('\n '); + + return `part '${camelCase(variantName)}.g.dart'; + +@BorshSerializable() +class ${variantName} with _$${variantName} { + factory ${variantName}({ +${params} + }) = _${variantName}; + + + const ${variantName}._(); + + static ${variantName} fromBorsh(Uint8List data) { + return _$${variantName}FromBorsh(data); + } + + + @override + String toString([int indent = 0]) { + final buffer = StringBuffer(); + buffer.writeln('${variantName}('); + ${toStringFields} + buffer.write(')'); + return buffer.toString(); + } +}`; +} + +function generateStructVariantAsStruct( + variantName: string, + variant: EnumStructVariantTypeNode, + nameApi: NameApi, + scope: Pick, +): string { + const resolvedStruct = resolveNestedTypeNode(variant.struct); + if (!resolvedStruct.fields) { + return generateEmptyVariantAsStruct(variantName); + } + + const params = resolvedStruct.fields + .map((field: StructFieldTypeNode) => { + const typeInfo = getTypeInfo(field.type, nameApi, scope.definedTypes); + const borshAnnotation = getBorshAnnotation(field.type, nameApi, scope.definedTypes); + const fieldName = nameApi.accountField(field.name); + return ` ${borshAnnotation} required ${typeInfo.dartType} ${fieldName},`; + }) + .join('\n'); + + const toStringFields = resolvedStruct.fields + .map((field: StructFieldTypeNode) => { + const fieldName = nameApi.accountField(field.name); + return `buffer.writeln(' ${fieldName}: $${fieldName}');`; + }) + .join('\n '); + + return `part '${camelCase(variantName)}.g.dart'; + +@BorshSerializable() +class ${variantName} with _$${variantName} { + factory ${variantName}({ +${params} + }) = _${variantName}; + + + const ${variantName}._(); + + static ${variantName} fromBorsh(Uint8List data) { + return _$${variantName}FromBorsh(data); + } + + + @override + String toString([int indent = 0]) { + final buffer = StringBuffer(); + buffer.writeln('${variantName}('); + ${toStringFields} + buffer.write(')'); + return buffer.toString(); + } +}`; +} + +export function getBEnumAnnotation(enumName: string, variants: EnumVariantTypeNode[], nameApi: NameApi): string { + const variantAnnotations = variants + .map(variant => { + const variantName = nameApi.definedType(camelCase(variant.name)); + return `${variantName}: B${variantName}()`; + }) + .join(', '); + + return `@BEnum<${enumName}>({${variantAnnotations}})`; +} diff --git a/src/fragments/index.ts b/src/fragments/index.ts index 98c0a18..98c47f0 100644 --- a/src/fragments/index.ts +++ b/src/fragments/index.ts @@ -1,5 +1,6 @@ export * from './accountPage'; export * from './structType'; +export * from './enumType'; export * from './instructionPage'; export * from './instructionData'; export * from './instructionFunction'; diff --git a/src/fragments/instructionData.ts b/src/fragments/instructionData.ts index f2b08a4..04d6ca8 100644 --- a/src/fragments/instructionData.ts +++ b/src/fragments/instructionData.ts @@ -1,5 +1,5 @@ import { InstructionNode, isNode, structTypeNodeFromInstructionArgumentNodes } from '@codama/nodes'; -import { getLastNodeFromPath, NodePath } from '@codama/visitors-core'; +import { findProgramNodeFromPath, getLastNodeFromPath, NodePath } from '@codama/visitors-core'; import { createFragment, Fragment, getBorshAnnotation, getTypeInfo, RenderScope } from '../utils'; @@ -17,18 +17,29 @@ export function getInstructionDataFragment( const instructionDataName = nameApi.instructionDataType(instructionNode.name); const structNode = structTypeNodeFromInstructionArgumentNodes(instructionNode.arguments); - const factoryParams = structNode.fields - .filter(field => field.name !== 'discriminator') // Exclude discriminator field + const programNode = findProgramNodeFromPath(instructionPath); + const programDefinedTypes = programNode?.definedTypes || []; + + const allImports = new Set([ + 'package:borsh_annotation_extended/borsh_annotation_extended.dart', + 'package:solana/encoder.dart', + ]); + + const nonDiscriminatorFields = structNode.fields.filter(field => field.name !== 'discriminator'); + + const factoryParams = nonDiscriminatorFields .map(field => { - const typeInfo = getTypeInfo(field.type, nameApi); - const borshAnnotation = getBorshAnnotation(field.type, nameApi); + const typeInfo = getTypeInfo(field.type, nameApi, programDefinedTypes); + const borshAnnotation = getBorshAnnotation(field.type, nameApi, programDefinedTypes); const fieldName = nameApi.instructionField(field.name); + + typeInfo.imports.forEach(imp => allImports.add(imp)); + return ` ${borshAnnotation} required ${typeInfo.dartType} ${fieldName},`; }) .join('\n'); - const validations = structNode.fields - .filter(field => field.name !== 'discriminator') + const validations = nonDiscriminatorFields .map(field => { const fieldName = nameApi.instructionField(field.name); @@ -55,16 +66,6 @@ export function getInstructionDataFragment( ? ` if (discriminator.length != 8) throw ArgumentError('discriminator must be exactly 8 bytes, got \${discriminator.length}');\n${validations}` : ` if (discriminator.length != 8) throw ArgumentError('discriminator must be exactly 8 bytes, got \${discriminator.length}');`; - const allImports = new Set([ - 'package:borsh_annotation_extended/borsh_annotation_extended.dart', - 'package:solana/solana.dart', - 'package:solana/encoder.dart', - ]); - structNode.fields.forEach(field => { - const typeInfo = getTypeInfo(field.type, nameApi); - typeInfo.imports.forEach(imp => allImports.add(imp)); - }); - const discriminatorBytes = (() => { const data = instructionNode.arguments.find(arg => arg.name === 'discriminator')?.defaultValue; return data && isNode(data, 'bytesValueNode') @@ -87,8 +88,7 @@ ${allParams} ${allValidations} return _${instructionDataName}( discriminator: discriminator, -${structNode.fields - .filter(field => field.name !== 'discriminator') +${nonDiscriminatorFields .map(field => { const fieldName = nameApi.instructionField(field.name); return ` ${fieldName}: ${fieldName},`; diff --git a/src/fragments/libraryIndex.ts b/src/fragments/libraryIndex.ts index 54964f0..72e0791 100644 --- a/src/fragments/libraryIndex.ts +++ b/src/fragments/libraryIndex.ts @@ -1,4 +1,4 @@ -import { getAllInstructionsWithSubs, getAllPrograms, RootNode } from '@codama/nodes'; +import { camelCase, getAllInstructionsWithSubs, getAllPrograms, RootNode } from '@codama/nodes'; import { Fragment } from '../utils'; @@ -7,19 +7,15 @@ export function getLibraryIndexFragment(scope: { rootNode: RootNode }): Fragment const programs = getAllPrograms(rootNode); const exports: string[] = []; - programs.forEach(program => { - // Export accounts program.accounts.forEach(account => { exports.push(`export 'accounts/${account.name}.dart';`); }); - // Export instructions const instructions = getAllInstructionsWithSubs(program, { leavesOnly: true }); instructions.forEach(instruction => { exports.push(`export 'instructions/${instruction.name}.dart';`); - - // Export PDAs from instruction accounts + instruction.accounts.forEach(account => { if (account.defaultValue?.kind === 'pdaValueNode') { exports.push(`export 'pdas/${account.name}.dart';`); @@ -27,19 +23,26 @@ export function getLibraryIndexFragment(scope: { rootNode: RootNode }): Fragment }); }); - // Export defined types program.definedTypes.forEach(type => { if (type.type.kind === 'structTypeNode') { exports.push(`export 'types/${type.name}.dart';`); + } else if (type.type.kind === 'enumTypeNode') { + const enumName = camelCase(type.name); + exports.push(`export 'types/${enumName}/${enumName}.dart';`); + + const enumType = type.type; + const variants = enumType.variants || []; + variants.forEach(variant => { + const variantName = camelCase(variant.name); + exports.push(`export 'types/${enumName}/${variantName}.dart';`); + }); } }); - // Export errors if (program.errors.length > 0) { exports.push(`export 'errors/${program.name}.dart';`); } - // Export programs if (program.name) { exports.push(`export 'programs/${program.name}.dart';`); } diff --git a/src/fragments/structType.ts b/src/fragments/structType.ts index 09b71c0..98b8075 100644 --- a/src/fragments/structType.ts +++ b/src/fragments/structType.ts @@ -3,7 +3,7 @@ import { camelCase, StructTypeNode } from '@codama/nodes'; import { createFragment, Fragment, getBorshAnnotation, getTypeInfo, RenderScope } from '../utils'; export function getStructTypeFragment( - scope: Pick & { + scope: Pick & { name: string; node: StructTypeNode; size: number | null; @@ -13,19 +13,15 @@ export function getStructTypeFragment( const className = nameApi.definedType(camelCase(name)); const fields = node.fields || []; - - // Collect all imports const allImports = new Set(['package:borsh_annotation_extended/borsh_annotation_extended.dart']); - fields.forEach(field => { - const typeInfo = getTypeInfo(field.type, scope.nameApi); - typeInfo.imports.forEach(imp => allImports.add(imp)); - }); - const factoryParams = fields .map(field => { - const typeInfo = getTypeInfo(field.type, scope.nameApi); - const borshAnnotation = getBorshAnnotation(field.type, scope.nameApi); + const typeInfo = getTypeInfo(field.type, nameApi, scope.definedTypes); + const borshAnnotation = getBorshAnnotation(field.type, nameApi, scope.definedTypes); const fieldName = nameApi.accountField(field.name); + + typeInfo.imports.forEach(imp => allImports.add(imp)); + return ` ${borshAnnotation} required ${typeInfo.dartType} ${fieldName},`; }) .join('\n'); @@ -37,7 +33,6 @@ class ${className} with _$${className} { factory ${className}({ ${factoryParams} }) = _${className}; - const ${className}._(); diff --git a/src/utils/fragment.ts b/src/utils/fragment.ts index bb16e09..ba3678d 100644 --- a/src/utils/fragment.ts +++ b/src/utils/fragment.ts @@ -2,14 +2,12 @@ import { BaseFragment } from '@codama/renderers-core'; export type Fragment = BaseFragment & { imports: Set; - libraryName?: string; }; export function createFragment(content: string, imports: string[] = []): Fragment { return { content, imports: new Set(imports), - libraryName: undefined, }; } @@ -37,14 +35,7 @@ export function getDocblockFragment(lines: string[]): Fragment | undefined { return createFragment(prefixedLines.join('\n')); } -export function getPageFragment( - page: Fragment, - options: { - libraryName?: string; - } = {}, -): Fragment { - const { libraryName } = options; - +export function getPageFragment(page: Fragment): Fragment { // Create header fragment const header = getDocblockFragment([ 'This code was AUTOGENERATED using the Codama library.', @@ -55,7 +46,7 @@ export function getPageFragment( ]); // Create library fragment - const library = libraryName ? createFragment(`library ${libraryName};`) : undefined; + const library = createFragment(`library lib;`); // Create imports fragment let imports: Fragment | undefined = undefined; diff --git a/src/utils/options.ts b/src/utils/options.ts index 4dbab9e..a7bc6e7 100644 --- a/src/utils/options.ts +++ b/src/utils/options.ts @@ -1,3 +1,5 @@ +import { DefinedTypeNode } from '@codama/nodes'; + import { NameApi } from './nameTransformers'; export type RenderOptions = GetRenderMapOptions & { @@ -7,13 +9,10 @@ export type RenderOptions = GetRenderMapOptions & { }; export type GetRenderMapOptions = { - libraryName?: string; nameTransformers?: Partial; - outputDirectory: string; }; export type RenderScope = { - libraryName: string; + definedTypes: DefinedTypeNode[]; nameApi: NameApi; - outputDirectory: string; }; diff --git a/src/utils/pda.ts b/src/utils/pda.ts index 10b3657..2381b93 100644 --- a/src/utils/pda.ts +++ b/src/utils/pda.ts @@ -57,10 +57,7 @@ export function createInlinePdaFile( nameApi: RenderScope['nameApi'], programPublicKey: string | undefined, programName: string | undefined, - asPage: ( - fragment: TFragment, - pageOptions?: { libraryName?: string }, - ) => TFragment, + asPage: (fragment: TFragment) => TFragment, ): Fragment | undefined { const functionName = `derive${accountName.charAt(0).toUpperCase() + accountName.slice(1)}Pda`; const seeds = generatePdaSeeds(pdaNode, pdaSeedValues, nameApi); @@ -71,7 +68,6 @@ export function createInlinePdaFile( programClassName = nameApi.programType(programName as CamelCaseString); } - pdaNode.seeds.forEach(seed => { if (isNode(seed, 'variablePdaSeedNode')) { const valueSeed = pdaSeedValues?.find((s: PdaSeedValueNode) => s.name === seed.name)?.value; @@ -88,8 +84,8 @@ export function createInlinePdaFile( programClassName && programClassName !== '' ? `Ed25519HDPublicKey.fromBase58(${programClassName}.programId)` : programPublicKey - ? `Ed25519HDPublicKey.fromBase58('${programPublicKey}')` - : 'PROGRAM_ID_HERE'; + ? `Ed25519HDPublicKey.fromBase58('${programPublicKey}')` + : 'PROGRAM_ID_HERE'; const content = `/// Returns the PDA address for ${accountName} Future ${functionName}(${parameterList}) async { diff --git a/src/utils/pubspec.ts b/src/utils/pubspec.ts index 7a8d6e1..c81a229 100644 --- a/src/utils/pubspec.ts +++ b/src/utils/pubspec.ts @@ -1,6 +1,6 @@ // Constants for Git-based dependencies const REPO_URL = 'https://github.com/vlady-kotsev/borsh_annotation_extended.git'; -const REPO_REF = '4fcc50dced4717257fc7cd4cf2ae5489ebfc1f48'; +const REPO_REF = 'b6b2c80d3b198fc2af9fc74d78fcdc86c23ed7cd'; export function generatePubspec( packageName: string, diff --git a/src/utils/types.ts b/src/utils/types.ts index 97ea308..bddbfb4 100644 --- a/src/utils/types.ts +++ b/src/utils/types.ts @@ -1,7 +1,12 @@ import { ArrayTypeNode, + camelCase, DefinedTypeLinkNode, + DefinedTypeNode, + EnumTypeNode, + EnumVariantTypeNode, FixedSizeTypeNode, + Node, NumberTypeNode, OptionTypeNode, resolveNestedTypeNode, @@ -9,11 +14,11 @@ import { TypeNode, } from '@codama/nodes'; +import { getBEnumAnnotation } from '../fragments/enumType'; import { NameApi } from './nameTransformers'; // Constants for hardcoded values const DEFAULT_PUBLIC_KEY = '11111111111111111111111111111111'; -const TYPES_IMPORT_PREFIX = '../types/'; const SOLANA_PUBLIC_KEY_SIZE = 32; export interface TypeInfo { @@ -23,7 +28,7 @@ export interface TypeInfo { serializationSize?: number; } -export function getTypeInfo(typeNode: TypeNode, nameApi: NameApi): TypeInfo { +export function getTypeInfo(typeNode: TypeNode, nameApi: NameApi, allDefinedTypes: DefinedTypeNode[]): TypeInfo { switch (typeNode.kind) { case 'numberTypeNode': return getNumberTypeInfo(typeNode); @@ -34,21 +39,20 @@ export function getTypeInfo(typeNode: TypeNode, nameApi: NameApi): TypeInfo { case 'bytesTypeNode': return getBytesTypeInfo(); case 'arrayTypeNode': - return getArrayTypeInfo(typeNode, nameApi); + return getArrayTypeInfo(typeNode, nameApi, allDefinedTypes); case 'optionTypeNode': - return getOptionTypeInfo(typeNode, nameApi); + return getOptionTypeInfo(typeNode, nameApi, allDefinedTypes); case 'publicKeyTypeNode': return getPublicKeyTypeInfo(); case 'fixedSizeTypeNode': - return getFixedSizeTypeInfo(typeNode, nameApi); + return getFixedSizeTypeInfo(typeNode, nameApi, allDefinedTypes); case 'solAmountTypeNode': return getAmountTypeInfo(); case 'definedTypeLinkNode': - return getDefinedTypeLinkTypeInfo(typeNode, nameApi.definedType(typeNode.name)); + return getDefinedTypeLinkTypeInfo(typeNode, nameApi.definedType(typeNode.name), allDefinedTypes); case 'sizePrefixTypeNode': return getSizePrefixTypeInfo(typeNode); default: - // For unsupported types, return a generic object type return { dartType: 'Object', defaultValue: 'Object()', @@ -115,9 +119,14 @@ function getBytesTypeInfo(): TypeInfo { }; } -function getArrayTypeInfo(node: ArrayTypeNode, nameApi: NameApi): TypeInfo { +function getArrayTypeInfo( + node: ArrayTypeNode, + nameApi: NameApi, + + allDefinedTypes: DefinedTypeNode[], +): TypeInfo { const resolvedType = resolveNestedTypeNode(node.item); - const innerTypeInfo = getTypeInfo(resolvedType, nameApi); + const innerTypeInfo = getTypeInfo(resolvedType, nameApi, allDefinedTypes); if (node.count && node.count.kind === 'fixedCountNode') { const size = node.count.value; @@ -136,9 +145,9 @@ function getArrayTypeInfo(node: ArrayTypeNode, nameApi: NameApi): TypeInfo { }; } -function getOptionTypeInfo(node: OptionTypeNode, nameApi: NameApi): TypeInfo { +function getOptionTypeInfo(node: OptionTypeNode, nameApi: NameApi, allDefinedTypes: DefinedTypeNode[]): TypeInfo { const resolvedType = resolveNestedTypeNode(node.item); - const innerTypeInfo = getTypeInfo(resolvedType, nameApi); + const innerTypeInfo = getTypeInfo(resolvedType, nameApi, allDefinedTypes); return { dartType: `${innerTypeInfo.dartType}?`, @@ -151,16 +160,14 @@ function getPublicKeyTypeInfo(): TypeInfo { return { dartType: 'Ed25519HDPublicKey', defaultValue: `Ed25519HDPublicKey.fromBase58("${DEFAULT_PUBLIC_KEY}")`, - imports: ['package:solana/solana.dart'], + imports: [], serializationSize: SOLANA_PUBLIC_KEY_SIZE, }; } -function getFixedSizeTypeInfo(node: FixedSizeTypeNode, nameApi: NameApi): TypeInfo { - // Resolve the nested type to get the actual inner type +function getFixedSizeTypeInfo(node: FixedSizeTypeNode, nameApi: NameApi, allDefinedTypes: DefinedTypeNode[]): TypeInfo { const resolvedType = resolveNestedTypeNode(node.type); - // Special case: fixed-size byte arrays should be Uint8List if (resolvedType.kind === 'bytesTypeNode') { return { dartType: 'Uint8List', @@ -170,9 +177,8 @@ function getFixedSizeTypeInfo(node: FixedSizeTypeNode, nameApi: NameApi): TypeIn }; } - const innerTypeInfo = getTypeInfo(resolvedType, nameApi); + const innerTypeInfo = getTypeInfo(resolvedType, nameApi, allDefinedTypes); - // For other fixed-size arrays, create List return { dartType: `List<${innerTypeInfo.dartType}>`, defaultValue: `List.filled(${node.size}, ${innerTypeInfo.defaultValue})`, @@ -190,7 +196,11 @@ function getAmountTypeInfo(): TypeInfo { }; } -export function getBorshAnnotation(typeNode: TypeNode, nameApi: NameApi): string { +function stripAnnotationPrefix(annotation: string): string { + return annotation.replace('@', ''); +} + +export function getBorshAnnotation(typeNode: Node, nameApi: NameApi, allDefinedTypes: DefinedTypeNode[]): string { switch (typeNode.kind) { case 'numberTypeNode': { const numberFormat = (typeNode as NumberTypeNode).format; @@ -232,41 +242,43 @@ export function getBorshAnnotation(typeNode: TypeNode, nameApi: NameApi): string case 'publicKeyTypeNode': return '@BPublicKey()'; case 'fixedSizeTypeNode': { - const fixedNode = typeNode as FixedSizeTypeNode; - const resolvedType = resolveNestedTypeNode(fixedNode.type); + const resolvedType = resolveNestedTypeNode(typeNode.type); if (resolvedType.kind === 'bytesTypeNode') { - return `@BFixedBytes(${fixedNode.size})`; + return `@BFixedBytes(${typeNode.size})`; } - const innerAnnotation = getBorshAnnotation(resolvedType, nameApi); - return `@BFixedArray(${fixedNode.size}, ${innerAnnotation.replace('@', '')})`; + const innerAnnotation = getBorshAnnotation(resolvedType, nameApi, allDefinedTypes); + return `@BFixedArray(${typeNode.size}, ${stripAnnotationPrefix(innerAnnotation)})`; } case 'arrayTypeNode': { - const arrayNode = typeNode as ArrayTypeNode; - const arrayInnerAnnotation = getBorshAnnotation(arrayNode.item, nameApi); + const arrayInnerAnnotation = getBorshAnnotation(typeNode.item, nameApi, allDefinedTypes); - if (arrayNode.count && arrayNode.count.kind === 'fixedCountNode') { - const size = arrayNode.count.value; - return `@BFixedArray(${size}, ${arrayInnerAnnotation.replace('@', '')})`; + if (typeNode.count && typeNode.count.kind === 'fixedCountNode') { + const size = typeNode.count.value; + return `@BFixedArray(${size}, ${stripAnnotationPrefix(arrayInnerAnnotation)})`; } - return `@BArray(${arrayInnerAnnotation.replace('@', '')})`; + return `@BArray(${stripAnnotationPrefix(arrayInnerAnnotation)})`; } case 'optionTypeNode': { - const optionNode = typeNode as OptionTypeNode; - const optionResolvedType = resolveNestedTypeNode(optionNode.item); - const optionInnerAnnotation = getBorshAnnotation(optionResolvedType, nameApi); - return `@BOption(${optionInnerAnnotation.replace('@', '')})`; + const optionResolvedType = resolveNestedTypeNode(typeNode.item); + const optionInnerAnnotation = getBorshAnnotation(optionResolvedType, nameApi, allDefinedTypes); + return `@BOption(${stripAnnotationPrefix(optionInnerAnnotation)})`; } case 'definedTypeLinkNode': { - const definedTypeNode = typeNode as DefinedTypeLinkNode; - const className = nameApi.definedType(definedTypeNode.name); + const className = nameApi.definedType(typeNode.name); + + const resolvedType = allDefinedTypes.find(dt => dt.name === typeNode.name); + if (resolvedType && resolvedType.type.kind === 'enumTypeNode') { + const variants = resolvedType.type.variants || []; + return getBEnumAnnotation(className, variants, nameApi); + } + return `@B${className}()`; } case 'sizePrefixTypeNode': { - const sizePrefixNode = typeNode as SizePrefixTypeNode; - const resolvedType = resolveNestedTypeNode(sizePrefixNode.type); + const resolvedType = resolveNestedTypeNode(typeNode.type); if (resolvedType.kind === 'bytesTypeNode') { return '@BBytes()'; @@ -279,22 +291,44 @@ export function getBorshAnnotation(typeNode: TypeNode, nameApi: NameApi): string } } -function getDefinedTypeLinkTypeInfo(node: DefinedTypeLinkNode, className: string): TypeInfo { - // Generate import path for the custom type - const importPath = `${TYPES_IMPORT_PREFIX}${node.name}.dart`; +function getDefinedTypeLinkTypeInfo( + node: DefinedTypeLinkNode, + className: string, + allDefinedTypes: DefinedTypeNode[], +): TypeInfo { + const imports: string[] = []; + const libName = 'lib'; + const typeDefinition = allDefinedTypes.find(dt => dt.name === node.name); + if (typeDefinition && typeDefinition.type.kind === 'enumTypeNode') { + const enumName = camelCase(node.name); + const enumType = typeDefinition.type as EnumTypeNode; + const variants = enumType.variants || []; + + imports.push(`package:${libName}/types/${enumName}/${enumName}.dart`); + + variants.forEach((variant: EnumVariantTypeNode) => { + const variantFileName = camelCase(variant.name); + imports.push(`package:${libName}/types/${enumName}/${variantFileName}.dart`); + }); + + return { + dartType: className, + defaultValue: `${className}()`, + imports, + }; + } + const importPath = `package:${libName}/types/${node.name}.dart`; return { dartType: className, - defaultValue: `${className}()`, // Assumes a default constructor + defaultValue: `${className}()`, imports: [importPath], }; } function getSizePrefixTypeInfo(node: SizePrefixTypeNode): TypeInfo { - // Resolve the nested type to see what we're prefixing const resolvedType = resolveNestedTypeNode(node.type); - // If it's bytes, return Uint8List if (resolvedType.kind === 'bytesTypeNode') { return { dartType: 'Uint8List', @@ -303,7 +337,6 @@ function getSizePrefixTypeInfo(node: SizePrefixTypeNode): TypeInfo { }; } - // For other size-prefixed types, default to String (like prefixed strings) return { dartType: 'String', defaultValue: "''", diff --git a/src/visitors/getRenderMapVisitor.ts b/src/visitors/getRenderMapVisitor.ts index 11566f0..8a70d0c 100644 --- a/src/visitors/getRenderMapVisitor.ts +++ b/src/visitors/getRenderMapVisitor.ts @@ -1,4 +1,12 @@ -import { camelCase, getAllInstructionsWithSubs, getAllPrograms, isNode } from '@codama/nodes'; +import { + camelCase, + DefinedTypeNode, + EnumTypeNode, + EnumVariantTypeNode, + getAllInstructionsWithSubs, + getAllPrograms, + isNode, +} from '@codama/nodes'; import { createRenderMap, mergeRenderMaps } from '@codama/renderers-core'; import { extendVisitor, @@ -15,6 +23,8 @@ import { import { getAccountPageFragment, + getEnumMainFragment, + getEnumVariantFragment, getErrorPageFragment, getInstructionPageFragment, getLibraryIndexFragment, @@ -27,24 +37,107 @@ import { createInlinePdaFile } from '../utils/pda'; export function getRenderMapVisitor(options: GetRenderMapOptions) { const linkables = new LinkableDictionary(); const stack = new NodeStack(); - - const byteSizeVisitor = getByteSizeVisitor(linkables, { stack }); const libraryName = 'lib'; - const outputDirectory = options.outputDirectory; + const byteSizeVisitor = getByteSizeVisitor(linkables, { stack }); + + const getProgramDefinedTypes = (): DefinedTypeNode[] => { + try { + const programNode = findProgramNodeFromPath(stack.getPath('definedTypeNode')); + return programNode?.definedTypes || []; + } catch { + return []; + } + }; - // Create the complete render scope const renderScope: RenderScope = { - libraryName, + definedTypes: getProgramDefinedTypes(), nameApi: getNameApi(options.nameTransformers), - outputDirectory, }; - const asPage = ( - fragment: TFragment, - pageOptions: { libraryName?: string } = {}, - ): TFragment => { + const asPage = (fragment: TFragment): TFragment => { if (!fragment) return undefined as TFragment; - return getPageFragment(fragment, pageOptions) as TFragment; + return getPageFragment(fragment) as TFragment; + }; + + const visitEnumType = (enumNode: DefinedTypeNode, programDefinedTypes: DefinedTypeNode[]) => { + const enumType = enumNode.type as EnumTypeNode; + const variants = enumType.variants || []; + const enumName = camelCase(enumNode.name); + const enumRenderScope = { ...renderScope, definedTypes: programDefinedTypes }; + + const variantRenderMaps = variants.map((variant: EnumVariantTypeNode) => { + switch (variant.kind) { + case 'enumEmptyVariantTypeNode': + return visitEnumEmptyVariantType(variant, enumName, enumRenderScope); + case 'enumStructVariantTypeNode': + return visitEnumStructVariantType(variant, enumName, enumRenderScope); + case 'enumTupleVariantTypeNode': + return visitEnumTupleVariantType(variant, enumName, enumRenderScope); + default: + return createRenderMap(); + } + }); + + const mainEnumRenderMap = createRenderMap( + `${libraryName}/types/${enumName}/${enumName}.dart`, + asPage( + getEnumMainFragment({ + ...enumRenderScope, + name: enumNode.name, + node: enumNode, + }), + ), + ); + + return mergeRenderMaps([mainEnumRenderMap, ...variantRenderMaps]); + }; + + const visitEnumEmptyVariantType = (variant: EnumVariantTypeNode, enumName: string, scope: RenderScope) => { + const variantName = renderScope.nameApi.definedType(camelCase(variant.name)); + const variantFileName = camelCase(variant.name); + + return createRenderMap( + `${libraryName}/types/${enumName}/${variantFileName}.dart`, + asPage( + getEnumVariantFragment({ + ...scope, + variant, + variantName, + }), + ), + ); + }; + + const visitEnumStructVariantType = (variant: EnumVariantTypeNode, enumName: string, scope: RenderScope) => { + const variantName = renderScope.nameApi.definedType(camelCase(variant.name)); + const variantFileName = camelCase(variant.name); + + return createRenderMap( + `${libraryName}/types/${enumName}/${variantFileName}.dart`, + asPage( + getEnumVariantFragment({ + ...scope, + variant, + variantName, + }), + ), + ); + }; + + const visitEnumTupleVariantType = (variant: EnumVariantTypeNode, enumName: string, scope: RenderScope) => { + const variantName = renderScope.nameApi.definedType(camelCase(variant.name)); + const variantFileName = camelCase(variant.name); + + return createRenderMap( + `${libraryName}/types/${enumName}/${variantFileName}.dart`, + asPage( + getEnumVariantFragment({ + ...scope, + variant, + variantName, + }), + ), + ); }; return pipe( @@ -68,18 +161,23 @@ export function getRenderMapVisitor(options: GetRenderMapOptions) { }, visitDefinedType(node) { + const programDefinedTypes = getProgramDefinedTypes(); + if (node.type.kind === 'structTypeNode') { return createRenderMap( `${libraryName}/types/${camelCase(node.name)}.dart`, asPage( getStructTypeFragment({ ...renderScope, + definedTypes: programDefinedTypes, name: node.name, node: node.type, size: visit(node, byteSizeVisitor), }), ), ); + } else if (node.type.kind === 'enumTypeNode') { + return visitEnumType(node, programDefinedTypes); } return createRenderMap(); }, @@ -134,14 +232,16 @@ export function getRenderMapVisitor(options: GetRenderMapOptions) { }, visitProgram(node, { self }) { + const programRenderScope = { ...renderScope, definedTypes: node.definedTypes }; + return mergeRenderMaps([ createRenderMap({ [`${libraryName}/programs/${camelCase(node.name)}.dart`]: asPage( - getProgramPageFragment({ ...renderScope, programNode: node }), + getProgramPageFragment({ ...programRenderScope, programNode: node }), ), [`${libraryName}/errors/${camelCase(node.name)}.dart`]: node.errors.length > 0 - ? asPage(getErrorPageFragment({ ...renderScope, programNode: node })) + ? asPage(getErrorPageFragment({ ...programRenderScope, programNode: node })) : undefined, }), ...node.pdas.map(p => visit(p, self)), diff --git a/src/visitors/renderVisitor.ts b/src/visitors/renderVisitor.ts index 6c12604..d6270ce 100644 --- a/src/visitors/renderVisitor.ts +++ b/src/visitors/renderVisitor.ts @@ -20,7 +20,6 @@ export function renderVisitor(path: string, options: RenderOptions) { cwd: path, stdio: 'ignore', }); - console.log('Dart formatting completed successfully.'); } catch (error) { console.warn( `Warning: Failed to format Dart code. Make sure Dart SDK is installed and accessible.: ${error instanceof Error ? error.message : String(error)}`, @@ -42,10 +41,11 @@ export function renderVisitor(path: string, options: RenderOptions) { cwd: path, stdio: 'ignore', }); - } catch (error) { - console.log(error); + } catch { console.warn('Warning: Failed to run Dart commands. Make sure Dart SDK is installed.'); - console.warn(`You can manually run commands in ${path}: dart pub get && dart run build_runner build && dart fix --apply`); + console.warn( + `You can manually run commands in ${path}: dart pub get && dart run build_runner build && dart fix --apply`, + ); } } });