Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 88 additions & 22 deletions packages/renderers-dart/src/getRenderMapVisitor.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import {
camelCase,
getAllAccounts,
getAllDefinedTypes,
getAllInstructionsWithSubs,
getAllPdas,
getAllPrograms,
pascalCase,
snakeCase,
resolveNestedTypeNode,
snakeCase,
structTypeNodeFromInstructionArgumentNodes,
camelCase,
} from '@codama/nodes';
import { RenderMap } from '@codama/renderers-core';
import {
Expand All @@ -25,7 +25,7 @@ import {
} from '@codama/visitors-core';
import { join } from 'path';

import { getTypeManifestVisitor, TypeManifest } from './getTypeManifestVisitor';
import TypeManifest, { getTypeManifestVisitor } from './getTypeManifestVisitor';
import { ImportMap } from './ImportMap';
import { extractDiscriminatorBytes, getImportFromFactory, LinkOverrides, render } from './utils';

Expand All @@ -35,30 +35,72 @@ export type GetRenderMapOptions = {
renderParentInstructions?: boolean;
};

function extractFieldsFromTypeManifest(typeManifest: TypeManifest): { name: string, type: string, field: string }[] {
// ' final Int64List i64_array;\n' +
// ' final Int8List /* length: 2 */ fixed_i8;\n' +
function extractFieldsFromTypeManifest(typeManifest: TypeManifest): {
baseType: string
field: string,
name: string,
nesting: number,
type: string,
}[] {
return typeManifest.type
.split('\n')
.map((line) => {
// That handles lines like: final Uint8List fieldName; and extracts the type and name in order to be used from borsh readers/writers
const match = line.trim().match(/^final\s+([\w<>, ?]+)\s+(\w+);$/);
if (match && match[2] !== 'discriminator') {
const isOptional = /\?$/.test(match[1]);
const rawType = match[1].replace(/\?$/, '').trim();

// Count nesting depth of List<>
let nesting = 0;
let inner = rawType;
const listRegex = /^List<(.+)>$/;

while (listRegex.test(inner)) {
nesting += 1;
inner = inner.replace(listRegex, '$1').trim();
}

return {
baseType: inner,
field: line,
name: match[2],
nesting,
optional: isOptional,
type: match[1].replace(/\?$/, ''),
field: line,
};
}
return null;
})
.filter((entry): entry is { name: string; type: string; field: string } => entry !== null);
.filter((entry): entry is {
baseType: string;
field: string;
name: string;
nesting: number;
optional: boolean;
type: string;
} => entry !== null);
}

// Given a type like List<SomeStruct> or Option<SomeStruct> or Set<SomeStruct>, it returns SomeStruct
function getBaseType(type: string): string {
const match = type.match(/^(?:List|Set|Option)<([\w\d_]+)>$/);
return match ? match[1] : type;
}

/**
* Returns a set of all struct type names defined in the given program node.
* Used for distinguishing user-defined struct types during code generation.
*/
function getAllDefinedTypesInNode(programNode: any): Set<string> {
if (!programNode) return new Set();
return new Set(getAllDefinedTypes(programNode).map(typeNode => pascalCase(typeNode.name)));
}

export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<RenderMap,
| 'rootNode'
| 'programNode'
| 'pdaNode'
| 'instructionNode'
| 'accountNode'
| 'definedTypeNode'
'accountNode' | 'definedTypeNode' | 'instructionNode' | 'pdaNode' | 'programNode' | 'rootNode'
> {
const linkables = new LinkableDictionary();
const stack = new NodeStack();
Expand Down Expand Up @@ -88,12 +130,22 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
visitAccount(node) {
const typeManifest = visit(node, typeManifestVisitor) as TypeManifest;
const { imports } = typeManifest;

imports.add('dartTypedData', new Set(['Uint8List', 'ByteData']));
imports.add('package:collection/collection.dart', new Set(['ListEquality']));
imports.add('package:solana/dto.dart', new Set(['AccountResult', 'BinaryAccountData']));
imports.add('package:solana/solana.dart', new Set(['RpcClient', 'Ed25519HDPublicKey']));
imports.add('../shared.dart', new Set(['BinaryReader', 'BinaryWriter', 'AccountNotFoundError']));

// Find the current program node for this account using the stack path.
// This allows us to get all user-defined struct types for codegen and type checks.
const programNode = findProgramNodeFromPath(stack.getPath('accountNode'));
const structTypeNames = getAllDefinedTypesInNode(programNode);

const fields = extractFieldsFromTypeManifest(typeManifest).map(field => ({
...field,
isStruct: structTypeNames.has(getBaseType(field.type))
}));

return new RenderMap().add(
`accounts/${snakeCase(node.name)}.dart`,
Expand All @@ -102,12 +154,12 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
...node,
discriminator: extractDiscriminatorBytes(node)
},
fields: fields,
getBaseType, // <-- Pass the function that is used to find if i have a collection
imports: imports
.remove(`generatedAccounts::${pascalCase(node.name)}`, [pascalCase(node.name)])
.toString(dependencyMap),
typeManifest,
fields: extractFieldsFromTypeManifest(typeManifest),

}),
);
},
Expand All @@ -116,10 +168,24 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
const typeManifest = visit(node, typeManifestVisitor) as TypeManifest;
const imports = new ImportMap().mergeWithManifest(typeManifest);

imports.add('../shared.dart', new Set(['BinaryReader', 'BinaryWriter', 'AccountNotFoundError']));
imports.add('dartTypedData', new Set(['Uint8List', 'ByteData']));

// This allows us to later distinguish between Object Data types and other types like string, int, arrays, etc
// So that when we generate the serialization logic we know which types needs special handling
const programNode = findProgramNodeFromPath(stack.getPath('definedTypeNode'));
const structTypeNames = getAllDefinedTypesInNode(programNode);

const fields = extractFieldsFromTypeManifest(typeManifest).map(field => ({
...field,
isStruct: structTypeNames.has(getBaseType(field.type))
}));

return new RenderMap().add(
`types/${snakeCase(node.name)}.dart`,
renderTemplate('definedTypesPage.njk', {
definedType: node,
fields: fields,
imports: imports.remove(`generatedTypes::${pascalCase(node.name)}`, [pascalCase(node.name)]).toString(dependencyMap),
typeManifest,
}),
Expand Down Expand Up @@ -161,21 +227,21 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
imports.mergeWith(argManifest.imports);
const rt = resolveNestedTypeNode(a.type);
return {
name: camelCase(a.name),
dartType: argManifest.type,
name: camelCase(a.name),
resolvedType: rt,
};
});

const context = {
args,
imports: importsString,
instruction: {
...node,
discriminator: extractDiscriminatorBytes(node),
},
args,
typeManifest: typeManifest || { nestedStructs: [] },
program: { name: pascalCase(programNode.name || '') },
typeManifest: typeManifest || { nestedStructs: [] },
};

return new RenderMap().add(
Expand Down Expand Up @@ -237,12 +303,12 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
definedTypesToExport.length > 0;

const ctx = {
programsToExport,
pdasToExport,
accountsToExport,
instructionsToExport,
definedTypesToExport,
hasAnythingToExport,
instructionsToExport,
pdasToExport,
programsToExport,
root: node,
};

Expand Down Expand Up @@ -282,4 +348,4 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
(v) => recordNodeStackVisitor(v, stack),
(v) => recordLinkablesOnFirstVisitVisitor(v, linkables),
);
}
}
Loading