Build type-safe finite state machines with higher-order states.
With Polystate, you can write 'state functions' that produce entirely new states, whose transitions are decided by a set of parameters. This enables composition: constructing states using other states, or even other state functions.
Download and add Polystate as a dependency by running the following command in your project root:
zig fetch --save git+https://github.com/sdzx-1/polystate.git

Then, retrieve the dependency in your build.zig:
const polystate = b.dependency("polystate", .{
.target = target,
.optimize = optimize,
});

Finally, add the dependency's module to your module's imports:
exe_mod.addImport("polystate", polystate.module("root"));

You should now be able to import Polystate in your module's code:
const ps = @import("polystate");

Let's build a state machine that completes a simple task: capitalize all words in a string that contain an underscore.
Our state machine will contain three states: FindWord, CheckWord, and Capitalize:
- FindWord finds the start of a word. FindWord transitions to CheckWord if it finds the start of a word.
- CheckWord checks if an underscore exists in the word. CheckWord transitions to Capitalize if an underscore is found, or transitions back to FindWord if no underscore is found.
- Capitalize capitalizes the word. Capitalize transitions back to FindWord once the word is capitalized.
Here's our state machine implemented with Polystate:
main.zig
const std = @import("std");
const ps = @import("polystate");
/// State that skips whitespace until it reaches the first character of a word.
///
/// Transitions:
/// - `to_check_word`: a non-whitespace character was found; `ctx.word` now points at it.
/// - `exit`: the terminating 0 byte was reached.
/// - `no_transition`: one whitespace character was skipped; stay in `FindWord`.
pub const FindWord = union(enum) {
    to_check_word: CapsFsm(CheckWord),
    exit: CapsFsm(ps.Exit),
    no_transition: CapsFsm(FindWord),

    /// Examines the character under the cursor (`ctx.string`) and chooses a transition.
    pub fn handler(ctx: *Context) FindWord {
        const char = ctx.string[0];
        if (char == 0) return .exit;

        // `std.ascii.isWhitespace` accepts exactly ' ' and '\t'...'\r'.
        if (std.ascii.isWhitespace(char)) {
            ctx.string += 1;
            return .no_transition;
        }

        // First character of a word: remember where the word begins.
        ctx.word = ctx.string;
        return .to_check_word;
    }
};
/// State that scans the current word looking for an underscore.
///
/// Transitions:
/// - `to_find_word`: the word ended (whitespace) without an underscore.
/// - `to_capitalize`: an underscore was found; the cursor is rewound to the word start.
/// - `exit`: the terminating 0 byte was reached.
/// - `no_transition`: an ordinary character was consumed; stay in `CheckWord`.
pub const CheckWord = union(enum) {
    to_find_word: CapsFsm(FindWord),
    to_capitalize: CapsFsm(Capitalize),
    exit: CapsFsm(ps.Exit),
    no_transition: CapsFsm(CheckWord),

    /// Examines the character under the cursor (`ctx.string`) and chooses a transition.
    pub fn handler(ctx: *Context) CheckWord {
        const char = ctx.string[0];
        if (char == 0) return .exit;

        if (char == '_') {
            // Rewind so that `Capitalize` starts at the beginning of the word.
            ctx.string = ctx.word;
            return .to_capitalize;
        }

        // `std.ascii.isWhitespace` accepts exactly ' ' and '\t'...'\r'.
        if (std.ascii.isWhitespace(char)) {
            ctx.string += 1;
            return .to_find_word;
        }

        ctx.string += 1;
        return .no_transition;
    }
};
/// State that upper-cases the current word one character at a time.
///
/// Transitions:
/// - `to_find_word`: whitespace was reached, so the word is finished.
/// - `exit`: the terminating 0 byte was reached.
/// - `no_transition`: one character was upper-cased; stay in `Capitalize`.
pub const Capitalize = union(enum) {
    to_find_word: CapsFsm(FindWord),
    exit: CapsFsm(ps.Exit),
    no_transition: CapsFsm(Capitalize),

    /// Examines the character under the cursor (`ctx.string`) and chooses a transition.
    pub fn handler(ctx: *Context) Capitalize {
        const char = ctx.string[0];
        if (char == 0) return .exit;

        // `std.ascii.isWhitespace` accepts exactly ' ' and '\t'...'\r'.
        if (std.ascii.isWhitespace(char)) {
            ctx.string += 1;
            return .to_find_word;
        }

        // Overwrite the character in place with its upper-case form, then advance.
        ctx.string[0] = std.ascii.toUpper(char);
        ctx.string += 1;
        return .no_transition;
    }
};
/// Mutable data shared by every state handler of the capitalizer.
pub const Context = struct {
    /// Cursor: points at the character currently being examined.
    string: [*:0]u8,
    /// Points at the first character of the word most recently found by `FindWord`.
    word: [*:0]u8,

    /// Builds a context whose cursor and word marker both start at the beginning of `string`.
    pub fn init(string: [:0]u8) Context {
        const first_char: [*:0]u8 = string.ptr;
        return Context{
            .string = first_char,
            .word = first_char,
        };
    }
};
/// Wraps `State` as a state of the "Underscore Capitalizer" machine.
///
/// Every transition field and the runner's starting state go through this wrapper so that
/// `ps.FSM` is always given the same name and mode.
pub fn CapsFsm(comptime State: type) type {
    return ps.FSM(
        "Underscore Capitalizer", // state machine name
        .not_suspendable, // mode
        null, // NOTE(review): purpose of this argument is not visible here; confirm against ps.FSM
        {}, // NOTE(review): presumably the transition-method slot, unused in this mode; confirm
        State,
    );
}
/// Runs the capitalizer over a sample string, printing the text before and after.
pub fn main() void {
    const StartingFsmState = CapsFsm(FindWord);
    // NOTE(review): the meaning of the `true` argument is not visible here; confirm against ps.Runner.
    const Runner = ps.Runner(true, StartingFsmState);

    // Dereferencing the literal copies it into a mutable array that the handlers can edit in place.
    var string_backing =
        \\capitalize_me
        \\DontCapitalizeMe
        \\ineedcaps_ _IAlsoNeedCaps idontneedcaps
        \\_/\o_o/\_ <-- wide_eyed
    .*;
    const string: [:0]u8 = &string_backing;

    std.debug.print("Without caps:\n{s}\n\n", .{string});

    var ctx = Context.init(string);
    // `StartingFsmState` is the wrapper type; its `State` declaration is the wrapped state.
    const start_id = Runner.idFromState(StartingFsmState.State);
    Runner.runHandler(start_id, &ctx);

    std.debug.print("With caps:\n{s}\n", .{string});
}

As you can see, each of our states is represented by a tagged union. These unions have two main components: their fields and their handler function.
Rules for the fields:
- Each field represents one of the state's transitions.
- The type of a field describes the transition, primarily what the transitioned-to state will be.
- Field types must be generated by ps.FSM, which wraps state union types and attaches additional information about the transition and its state machine.
- For a single state machine's transitions, ps.FSM must always be given the same name and mode. In our case, we ensure this by wrapping ps.FSM with CapsFsm. In CapsFsm, the name is set to "Underscore Capitalizer" and the mode is set to not_suspendable.
Rules for the handler function:
- handler executes the state's logic and determines which transition to take.
- handler takes a context parameter (ctx), which points to mutable data that is shared across all states.
- handler returns a transition (one of the state's union fields).
Once we have defined the states of our state machine, we make a runner using ps.Runner. Just like our state's transition types, the starting state we pass into ps.Runner must be generated using ps.FSM, which we accomplish using our CapsFsm wrapper: CapsFsm(FindWord). Since our FSM's mode is set to not_suspendable, calling runHandler on our runner will run the state machine until completion (when the special ps.Exit state is reached).
runHandler also requires the 'state ID' of the state you want to start at. A runner provides both the StateId type and functions to convert between states and their ID. We use this to get the starting state ID: Runner.idFromState(StartingFsmState.State).
It may seem odd that we call idFromState with StartingFsmState.State instead of StartingFsmState, but this is because StartingFsmState is the wrapper type produced by ps.FSM, whereas StartingFsmState.State is the underlying state (FindWord). That's why we call it StartingFsmState instead of StartingState: the 'FSM' naming convention helps us remember that it's a wrapped state, and that we need to use its State declaration if we want the state it is wrapping.
In our previous example, our state machine's mode was not_suspendable. What if we set it to suspendable? Well, this would allow us to 'suspend' the execution of our state machine, run code outside of the state machine, and then resume the execution of our state machine.
However, suspendable adds an additional requirement to your state transitions: they must tell the state machine whether or not to suspend after transitioning.
This is our capitalization state machine, updated such that every time a word is chosen to be capitalized, we suspend execution and print the chosen word:
main.zig
const std = @import("std");
const ps = @import("polystate");
/// State that skips whitespace until it reaches the first character of a word.
///
/// Every transition uses the `.current` method, so the machine keeps running after it.
///
/// Transitions:
/// - `to_check_word`: a non-whitespace character was found; `ctx.word` now points at it.
/// - `exit`: the terminating 0 byte was reached.
/// - `no_transition`: one whitespace character was skipped; stay in `FindWord`.
pub const FindWord = union(enum) {
    to_check_word: CapsFsm(.current, CheckWord),
    exit: CapsFsm(.current, ps.Exit),
    no_transition: CapsFsm(.current, FindWord),

    /// Examines the character under the cursor (`ctx.string`) and chooses a transition.
    pub fn handler(ctx: *Context) FindWord {
        const current = ctx.string[0];
        if (current == 0) return .exit;

        // `std.ascii.isWhitespace` accepts exactly ' ' and '\t'...'\r'.
        const is_space = std.ascii.isWhitespace(current);
        if (!is_space) {
            // First character of a word: remember where the word begins.
            ctx.word = ctx.string;
            return .to_check_word;
        }

        ctx.string += 1;
        return .no_transition;
    }
};
/// State that scans the current word looking for an underscore.
///
/// Transitions:
/// - `to_find_word` (`.current`): the word ended (whitespace) without an underscore.
/// - `to_capitalize` (`.next`): an underscore was found; the cursor is rewound to the word
///   start. This is the only `.next` transition, so it is where the machine suspends.
/// - `exit` (`.current`): the terminating 0 byte was reached.
/// - `no_transition` (`.current`): an ordinary character was consumed; stay in `CheckWord`.
pub const CheckWord = union(enum) {
    to_find_word: CapsFsm(.current, FindWord),
    to_capitalize: CapsFsm(.next, Capitalize),
    exit: CapsFsm(.current, ps.Exit),
    no_transition: CapsFsm(.current, CheckWord),

    /// Examines the character under the cursor (`ctx.string`) and chooses a transition.
    pub fn handler(ctx: *Context) CheckWord {
        const current = ctx.string[0];
        if (current == 0) return .exit;

        if (current == '_') {
            // Rewind so that `Capitalize` starts at the beginning of the word.
            ctx.string = ctx.word;
            return .to_capitalize;
        }

        // `std.ascii.isWhitespace` accepts exactly ' ' and '\t'...'\r'.
        if (std.ascii.isWhitespace(current)) {
            ctx.string += 1;
            return .to_find_word;
        }

        ctx.string += 1;
        return .no_transition;
    }
};
/// State that upper-cases the current word one character at a time.
///
/// Every transition uses the `.current` method, so the machine keeps running after it.
///
/// Transitions:
/// - `to_find_word`: whitespace was reached, so the word is finished.
/// - `exit`: the terminating 0 byte was reached.
/// - `no_transition`: one character was upper-cased; stay in `Capitalize`.
pub const Capitalize = union(enum) {
    to_find_word: CapsFsm(.current, FindWord),
    exit: CapsFsm(.current, ps.Exit),
    no_transition: CapsFsm(.current, Capitalize),

    /// Examines the character under the cursor (`ctx.string`) and chooses a transition.
    pub fn handler(ctx: *Context) Capitalize {
        const current = ctx.string[0];
        if (current == 0) return .exit;

        // `std.ascii.isWhitespace` accepts exactly ' ' and '\t'...'\r'.
        if (std.ascii.isWhitespace(current)) {
            ctx.string += 1;
            return .to_find_word;
        }

        // Overwrite the character in place with its upper-case form, then advance.
        ctx.string[0] = std.ascii.toUpper(current);
        ctx.string += 1;
        return .no_transition;
    }
};
/// Mutable data shared by every state handler of the capitalizer.
pub const Context = struct {
    /// Cursor: points at the character currently being examined.
    string: [*:0]u8,
    /// Points at the first character of the word most recently found by `FindWord`.
    word: [*:0]u8,

    /// Builds a context whose cursor and word marker both start at the beginning of `string`.
    pub fn init(string: [:0]u8) Context {
        const beginning: [*:0]u8 = string.ptr;
        return Context{
            .string = beginning,
            .word = beginning,
        };
    }
};
/// Wraps `State` as a state of the suspendable "Underscore Capitalizer" machine.
///
/// `method` says what happens after transitioning to `State`: `.current` keeps the machine
/// running, `.next` suspends it. Routing every transition through this wrapper keeps the
/// name and mode given to `ps.FSM` identical across the whole machine.
pub fn CapsFsm(comptime method: ps.Method, comptime State: type) type {
    return ps.FSM(
        "Underscore Capitalizer", // state machine name
        .suspendable, // mode
        null, // NOTE(review): purpose of this argument is not visible here; confirm against ps.FSM
        method,
        State,
    );
}
/// Runs the suspendable capitalizer over a sample string, reporting each word it is about to
/// capitalize, and prints the text before and after.
pub fn main() void {
    const StartingFsmState = CapsFsm(.current, FindWord);
    // NOTE(review): the meaning of the `true` argument is not visible here; confirm against ps.Runner.
    const Runner = ps.Runner(true, StartingFsmState);

    // Dereferencing the literal copies it into a mutable array that the handlers can edit in place.
    var string_backing =
        \\capitalize_me
        \\DontCapitalizeMe
        \\ineedcaps_ _IAlsoNeedCaps idontneedcaps
        \\_/\o_o/\_ <-- wide_eyed
    .*;
    const string: [:0]u8 = &string_backing;

    std.debug.print("Without caps:\n{s}\n\n", .{string});

    var ctx = Context.init(string);
    var state_id = Runner.idFromState(StartingFsmState.State);

    // Each iteration runs the machine until it suspends; null means it reached ps.Exit.
    while (Runner.runHandler(state_id, &ctx)) |suspended_on| {
        state_id = suspended_on;

        // Measure the word at `ctx.word`: it ends at whitespace or at the 0 terminator.
        var word_len: usize = 0;
        while (true) : (word_len += 1) {
            const char = ctx.word[word_len];
            if (char == 0 or std.ascii.isWhitespace(char)) break;
        }

        std.debug.print("capitalizing word: {s}\n", .{ctx.word[0..word_len]});
    }

    std.debug.print("\nWith caps:\n{s}\n", .{string});
}

We've updated our CapsFsm wrapper to take an additional parameter of type ps.Method, which has two possible values: current and next.
- If a transition has the method current, the state machine will continue execution after transitioning.
- If a transition has the method next, the state machine will suspend execution after transitioning.
A transition's ps.Method basically answers the following question: "Should I set this new state as my current state and keep going (current), or save this new state for the next execution (next)?".
In addition to updating our state transitions with current or next, we also need to change how we use runHandler.
Before, since our state machine was not_suspendable, runHandler didn't return anything and only needed to be called once. Now, since our state machine is suspendable, runHandler only runs the state machine until it is suspended, and returns the ID of the state it was suspended on.
So, we now call runHandler in a loop, passing in the current state ID and using the result as the new state ID. We continue this until runHandler returns null, indicating that the state machine has completed (reached ps.Exit).
If you've read the previous sections where we cover the basics of Polystate, you may feel like it's a bit overkill to use a library instead of just implementing your FSM manually. After all, it can seem like Polystate does little more than provide a convenient framework for structuring state machines.
This changes when you start using higher-order states.
A higher-order state is a function that takes states as parameters and returns a new state, AKA a 'state function'. Since states are represented as types (specifically, tagged unions), a state function is no different than any other Zig generic: a type-returning function that takes types as parameters.
While being simple at their core, higher-order states allow endless ways to construct, compose, and re-use transition logic among your states. We will demonstrate this by expanding on our previous state machine, making it perform various new word-processing operations.
Our capitalization state machine was designed with one purpose: capitalize words with underscores. Sure, you could tweak it, splicing in new states to add additional processing steps... only for it to decline into a spaghetti-like mess, becoming less decipherable with each new state. If we want to make our state machine truly extendable, we'll need to try another approach.
Let's make a new utility called Words that will provide us with several state functions, producing states that operate on a WordsContext. These state functions will allow us to construct word operations, expressing in a consistent manner both the word mutations and the conditions in which those mutations occur.
As a bonus, Words itself will be produced by a generic function, automatically specializing its state functions based on its parameters. This allows Words to be used in any state machine, and even lets you create multiple independent instances of Words in one state machine, each one having its own WordsContext!
Our new state machine will demonstrate the capabilities of Words, using two independent instances of Words to do the following:
- Capitalize all words in string 1 that contain an underscore or are palindromes.
- Reverse all words in string 2 that contain a vowel.
Here's the implementation:
main.zig
const std = @import("std");
const ps = @import("polystate");
/// Generates a namespace of word-processing state functions bound to one `WordsContext`.
///
/// Parameters:
/// - `Fsm`: wraps a state type so it can be used as a transition field type.
/// - `ParentContext`: the context type that every handler receives a pointer to.
/// - `ctx_field`: names the field of `ParentContext` that holds this instance's `WordsContext`.
///
/// Throughout, the current word is the half-open index range `[word_start, word_end)` of
/// `ctx.string`; `head` and `tail` are scratch cursors reused by the individual states.
pub fn Words(
    comptime Fsm: fn (State: type) type,
    comptime ParentContext: type,
    comptime ctx_field: std.meta.FieldEnum(ParentContext),
) type {
    return struct {
        /// Applies `WordOperation` to every word of the string, then moves to `NoWordsLeft`.
        ///
        /// `WordOperation` is given this state as its `Next`, so each finished word loops back
        /// here to look for the following one.
        pub fn IterateWords(
            comptime WordOperation: fn (Next: type) type,
            comptime NoWordsLeft: type,
        ) type {
            return union(enum) {
                to_inner: Fsm(IterateWordsInner(WordOperation(@This()), NoWordsLeft)),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    // Resume scanning right after the previously processed word.
                    ctx.head = ctx.word_end;
                    return .to_inner;
                }
            };
        }

        /// Decides whether the string is exhausted (`NoWordsLeft`) or another word should be
        /// searched for (`FindWord`, which continues to `FoundWord`).
        pub fn IterateWordsInner(
            comptime FoundWord: type,
            comptime NoWordsLeft: type,
        ) type {
            return union(enum) {
                to_find_word: Fsm(FindWord(FoundWord)),
                to_no_words_left: Fsm(NoWordsLeft),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    if (ctx.word_end >= ctx.string.len) {
                        return .to_no_words_left;
                    } else {
                        return .to_find_word;
                    }
                }
            };
        }

        /// Advances `head` past whitespace, then marks the start of a word there.
        ///
        /// If the end of the string is reached first, an empty word is marked at the end.
        pub fn FindWord(comptime Next: type) type {
            return union(enum) {
                to_find_word_end: Fsm(FindWordEnd(Next)),
                no_transition: Fsm(@This()),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);

                    // `std.ascii.isWhitespace` accepts exactly ' ' and '\t'...'\r'.
                    const at_whitespace = ctx.head < ctx.string.len and
                        std.ascii.isWhitespace(ctx.string[ctx.head]);
                    if (at_whitespace) {
                        ctx.head += 1;
                        return .no_transition;
                    }

                    // Either a word character or the end of the string: the word starts
                    // (and, for now, ends) here.
                    ctx.word_start = ctx.head;
                    ctx.word_end = ctx.head;
                    return .to_find_word_end;
                }
            };
        }

        /// Advances `word_end` until it reaches whitespace or the end of the string, then
        /// transitions to `Next`.
        pub fn FindWordEnd(comptime Next: type) type {
            return union(enum) {
                to_next: Fsm(Next),
                no_transition: Fsm(@This()),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);

                    const word_is_over = ctx.word_end >= ctx.string.len or
                        std.ascii.isWhitespace(ctx.string[ctx.word_end]);
                    if (word_is_over) return .to_next;

                    ctx.word_end += 1;
                    return .no_transition;
                }
            };
        }

        /// Replaces every character of the current word with `mutateChar(char)`, then
        /// transitions to `Next`.
        pub fn CharMutation(
            comptime Next: type,
            comptime mutateChar: fn (char: u8) u8,
        ) type {
            return union(enum) {
                to_inner: Fsm(CharMutationInner(Next, mutateChar)),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    ctx.head = ctx.word_start;
                    return .to_inner;
                }
            };
        }

        /// Loop body of `CharMutation`: mutates the character at `head` and advances until
        /// `head` reaches `word_end`.
        pub fn CharMutationInner(
            comptime Next: type,
            comptime mutateChar: fn (char: u8) u8,
        ) type {
            return union(enum) {
                to_next: Fsm(Next),
                no_transition: Fsm(@This()),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    if (ctx.head >= ctx.word_end) {
                        return .to_next;
                    } else {
                        ctx.string[ctx.head] = mutateChar(ctx.string[ctx.head]);
                        ctx.head += 1;
                        return .no_transition;
                    }
                }
            };
        }

        /// Reverses the characters of the current word in place, then transitions to `Next`.
        pub fn Reverse(comptime Next: type) type {
            return union(enum) {
                to_inner: Fsm(ReverseInner(Next)),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    ctx.tail = ctx.word_start;
                    // Saturating subtraction: with `word_end == 0` (an empty word at index 0)
                    // a plain `- 1` would underflow `usize`. Saturating leaves `tail >= head`,
                    // so the inner state finishes immediately.
                    ctx.head = ctx.word_end -| 1;
                    return .to_inner;
                }
            };
        }

        /// Loop body of `Reverse`: swaps the characters at `tail` and `head` and moves the
        /// two cursors toward each other until they meet or cross.
        pub fn ReverseInner(comptime Next: type) type {
            return union(enum) {
                to_next: Fsm(Next),
                no_transition: Fsm(@This()),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    if (ctx.tail >= ctx.head) {
                        return .to_next;
                    } else {
                        const temp = ctx.string[ctx.tail];
                        ctx.string[ctx.tail] = ctx.string[ctx.head];
                        ctx.string[ctx.head] = temp;
                        ctx.tail += 1;
                        // `head > tail >= 0` here, so this cannot underflow.
                        ctx.head -= 1;
                        return .no_transition;
                    }
                }
            };
        }

        /// Transitions to `Pass` if any character of the current word satisfies `predicate`,
        /// otherwise to `Fail`.
        pub fn CharFilter(
            comptime Pass: type,
            comptime Fail: type,
            comptime predicate: fn (char: u8) bool,
        ) type {
            return union(enum) {
                to_inner: Fsm(CharFilterInner(Pass, Fail, predicate)),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    ctx.head = ctx.word_start;
                    return .to_inner;
                }
            };
        }

        /// Loop body of `CharFilter`: tests the character at `head`, advancing until a match
        /// is found (`Pass`) or `head` reaches `word_end` (`Fail`).
        pub fn CharFilterInner(
            comptime Pass: type,
            comptime Fail: type,
            comptime predicate: fn (char: u8) bool,
        ) type {
            return union(enum) {
                to_pass: Fsm(Pass),
                to_fail: Fsm(Fail),
                no_transition: Fsm(@This()),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    if (ctx.head >= ctx.word_end) {
                        return .to_fail;
                    }
                    if (predicate(ctx.string[ctx.head])) {
                        return .to_pass;
                    } else {
                        ctx.head += 1;
                        return .no_transition;
                    }
                }
            };
        }

        /// Transitions to `Pass` if the current word reads the same in both directions,
        /// otherwise to `Fail`.
        pub fn PalindromeFilter(
            comptime Pass: type,
            comptime Fail: type,
        ) type {
            return union(enum) {
                to_inner: Fsm(PalindromeFilterInner(Pass, Fail)),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    ctx.tail = ctx.word_start;
                    // Saturating subtraction: with `word_end == 0` (an empty word at index 0)
                    // a plain `- 1` would underflow `usize`. Saturating leaves `tail >= head`,
                    // so the inner state passes immediately, as it does for any empty word.
                    ctx.head = ctx.word_end -| 1;
                    return .to_inner;
                }
            };
        }

        /// Loop body of `PalindromeFilter`: compares the characters at `tail` and `head`,
        /// moving the cursors toward each other until they differ (`Fail`) or meet (`Pass`).
        pub fn PalindromeFilterInner(
            comptime Pass: type,
            comptime Fail: type,
        ) type {
            return union(enum) {
                to_pass: Fsm(Pass),
                to_fail: Fsm(Fail),
                no_transition: Fsm(@This()),

                pub fn handler(parent_ctx: *ParentContext) @This() {
                    const ctx = ctxFromParent(parent_ctx);
                    if (ctx.tail >= ctx.head) {
                        return .to_pass;
                    } else if (ctx.string[ctx.tail] != ctx.string[ctx.head]) {
                        return .to_fail;
                    } else {
                        ctx.tail += 1;
                        // `head > tail >= 0` here, so this cannot underflow.
                        ctx.head -= 1;
                        return .no_transition;
                    }
                }
            };
        }

        /// Returns a pointer to the `WordsContext` stored in the `ctx_field` field of `parent_ctx`.
        fn ctxFromParent(parent_ctx: *ParentContext) *WordsContext {
            return &@field(parent_ctx, @tagName(ctx_field));
        }
    };
}
/// Per-string working state used by the state functions that `Words` generates.
pub const WordsContext = struct {
    /// The text being processed; it is modified in place.
    string: []u8,
    /// Scratch cursor (index into `string`).
    head: usize,
    /// Index of the first character of the current word.
    word_start: usize,
    /// Index one past the last character of the current word.
    word_end: usize,
    /// Second scratch cursor (index into `string`).
    tail: usize,

    /// Creates a context over `string` with every index at the start of the string.
    pub fn init(string: []u8) WordsContext {
        const start_of_string: usize = 0;
        return WordsContext{
            .string = string,
            .head = start_of_string,
            .word_start = start_of_string,
            .word_end = start_of_string,
            .tail = start_of_string,
        };
    }
};
/// Context of the word-processing machine: one independent `WordsContext` per string.
pub const Context = struct {
    /// Working state for the first string.
    string1_ctx: WordsContext,
    /// Working state for the second string.
    string2_ctx: WordsContext,

    /// Creates a context with a fresh `WordsContext` over each of the two strings.
    pub fn init(string1: []u8, string2: []u8) Context {
        const first = WordsContext.init(string1);
        const second = WordsContext.init(string2);
        return Context{
            .string1_ctx = first,
            .string2_ctx = second,
        };
    }
};
/// Wraps `State` as a state of the "Word Processor" machine.
///
/// Routing every state through this wrapper keeps the name and mode given to `ps.FSM`
/// identical across the whole machine.
pub fn CapsFsm(comptime State: type) type {
    return ps.FSM(
        "Word Processor", // state machine name
        .not_suspendable, // mode
        null, // NOTE(review): purpose of this argument is not visible here; confirm against ps.FSM
        {}, // NOTE(review): presumably the transition-method slot, unused in this mode; confirm
        State,
    );
}
/// State functions for string 1: capitalize every word that contains an underscore or is a
/// palindrome.
const string1_states = struct {
    /// `Words` instance bound to `Context.string1_ctx`.
    const W = Words(CapsFsm, Context, .string1_ctx);

    /// Predicate: is `char` an underscore?
    fn isUnderscore(char: u8) bool {
        return switch (char) {
            '_' => true,
            else => false,
        };
    }

    /// Mutation: the ASCII upper-case form of `char`.
    fn capitalize(char: u8) u8 {
        const upper = std.ascii.toUpper(char);
        return upper;
    }

    /// `Pass` if the current word contains an underscore, otherwise `Fail`.
    pub fn UnderscoreFilter(comptime Pass: type, comptime Fail: type) type {
        return W.CharFilter(Pass, Fail, isUnderscore);
    }

    /// Upper-cases the current word, then continues to `Next`.
    pub fn Capitalize(comptime Next: type) type {
        return W.CharMutation(Next, capitalize);
    }

    /// `Pass` if the current word contains an underscore or is a palindrome, otherwise `Fail`.
    pub fn UnderscoreOrPalindromeFilter(comptime Pass: type, comptime Fail: type) type {
        // The palindrome check only runs when the underscore check fails.
        const WhenNoUnderscore = W.PalindromeFilter(Pass, Fail);
        return UnderscoreFilter(Pass, WhenNoUnderscore);
    }

    /// Capitalizes the current word if it matches the filter; either way continues to `Next`.
    pub fn CapitalizeUnderscoreOrPalindromeWord(comptime Next: type) type {
        const WhenMatched = Capitalize(Next);
        return UnderscoreOrPalindromeFilter(WhenMatched, Next);
    }

    /// Runs `CapitalizeUnderscoreOrPalindromeWord` on every word, then continues to `Next`.
    pub fn CapitalizeUnderscoreOrPalindromeWords(comptime Next: type) type {
        return W.IterateWords(CapitalizeUnderscoreOrPalindromeWord, Next);
    }
};
/// State functions for string 2: reverse every word that contains a vowel.
const string2_states = struct {
    /// `Words` instance bound to `Context.string2_ctx`.
    const W = Words(CapsFsm, Context, .string2_ctx);

    /// Predicate: is `char` one of a, e, i, o, u in either case?
    fn isVowel(char: u8) bool {
        const vowels = "aeiouAEIOU";
        return std.mem.indexOfScalar(u8, vowels, char) != null;
    }

    /// `Pass` if the current word contains a vowel, otherwise `Fail`.
    pub fn VowelFilter(comptime Pass: type, comptime Fail: type) type {
        return W.CharFilter(Pass, Fail, isVowel);
    }

    /// Reverses the current word if it contains a vowel; either way continues to `Next`.
    pub fn ReverseWordWithVowel(comptime Next: type) type {
        const WhenVowelFound = W.Reverse(Next);
        return VowelFilter(WhenVowelFound, Next);
    }

    /// Runs `ReverseWordWithVowel` on every word, then continues to `Next`.
    pub fn ReverseVowelWords(comptime Next: type) type {
        return W.IterateWords(ReverseWordWithVowel, Next);
    }
};
/// Entry state of the machine: process string 1, then string 2, then exit.
pub const EnterFsmState = blk: {
    // Built back to front, because each stage takes the stage that follows it as a parameter.
    const Finish = ps.Exit;
    const ProcessString2 = string2_states.ReverseVowelWords(Finish);
    const ProcessString1 = string1_states.CapitalizeUnderscoreOrPalindromeWords(ProcessString2);
    break :blk CapsFsm(ProcessString1);
};
/// Runs the word processor over two sample strings, printing both before and after.
pub fn main() !void {
    // NOTE(review): the meaning of the `true` argument is not visible here; confirm against ps.Runner.
    const Runner = ps.Runner(true, EnterFsmState);

    // Dereferencing the literals copies them into mutable arrays the handlers can edit in place.
    var string1_backing =
        \\capitalize_me
        \\DontCapitalizeMe
        \\ineedcaps_ _IAlsoNeedCaps idontneedcaps
        \\_/\o_o/\_ <-- wide_eyed
        \\tacocat 123Hello--olleH321
    .*;
    var string2_backing =
        \\apple gym cry
        \\elephant pfft sphinx
        \\amazing fly grr
    .*;

    std.debug.print("Before processing:\n", .{});
    std.debug.print("String 1: {s}\n", .{string1_backing});
    std.debug.print("String 2: {s}\n\n", .{string2_backing});

    var ctx = Context.init(&string1_backing, &string2_backing);
    // `EnterFsmState` is the wrapper type; its `State` declaration is the wrapped state.
    const entry_id = Runner.idFromState(EnterFsmState.State);
    Runner.runHandler(entry_id, &ctx);

    std.debug.print("After processing:\n", .{});
    std.debug.print("String 1: {s}\n", .{string1_backing});
    std.debug.print("String 2: {s}\n", .{string2_backing});
}
<< TODO: finish README >>