Skip to content

Commit 1fcb57e

Browse files
authored
[NFC] Refactor out subtyping discovery code (#6106)
This implements an idea I mentioned in the past, to extract the subtyping discovery code out of Unsubtyping so it could be reused elsewhere. Example possible uses: the validator could use to remove a lot of code, and also a future PR of mine will need it. Separately from those, I think this is a nice refactoring as it makes Unsubtyping much smaller. This just moves the code out and adds some C++ template elbow grease as needed.
1 parent bf76357 commit 1fcb57e

File tree

2 files changed

+358
-259
lines changed

2 files changed

+358
-259
lines changed

src/ir/subtype-exprs.h

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
/*
2+
* Copyright 2023 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef wasm_ir_subtype_exprs_h
18+
#define wasm_ir_subtype_exprs_h
19+
20+
#include "ir/branch-utils.h"
21+
#include "wasm-traversal.h"
22+
#include "wasm.h"
23+
24+
namespace wasm {
25+
26+
//
27+
// Analyze subtyping relationships between expressions. This must CRTP with a
28+
// class that implements:
29+
//
30+
// * noteSubType(A, B) indicating A must be a subtype of B
31+
// * noteCast(A, B) indicating A is cast to B
32+
//
33+
// There must be multiple versions of each of those, supporting A and B being
34+
// either a Type, which indicates a fixed type requirement, or an Expression*,
35+
// indicating a flexible requirement that depends on the type of that
36+
// expression. Specifically:
37+
//
38+
// * noteSubType(Type, Type) - A constraint not involving expressions at all,
39+
// for example, an element segment's type must be
40+
// a subtype of the corresponding table's.
41+
// * noteSubType(HeapType, HeapType) - Ditto, with heap types, for example in a
42+
// CallIndirect.
43+
// * noteSubType(Type, Expression) - A fixed type must be a subtype of an
44+
// expression's type, for example, in BrOn
45+
// (the declared sent type must be a subtype
46+
// of the block we branch to).
47+
// * noteSubType(Expression, Type) - An expression's type must be a subtype of
48+
// a fixed type, for example, a Call operand
49+
// must be a subtype of the signature's
50+
// param.
51+
// * noteSubType(Expression, Expression) - An expression's type must be a
52+
// subtype of anothers, for example,
53+
// a block and its last child.
54+
//
55+
// * noteCast(HeapType, HeapType) - A fixed type is cast to another, for
56+
// example, in a CallIndirect.
57+
// * noteCast(Expression, Type) - An expression's type is cast to a fixed type,
58+
// for example, in RefTest.
59+
// * noteCast(Expression, Expression) - An expression's type is cast to
60+
// another, for example, in RefCast.
61+
//
62+
// Note that noteCast(Type, Type) and noteCast(Type, Expression) never occur and
63+
// do not need to be implemented.
64+
//
65+
// The class must also inherit from ControlFlowWalker (for findBreakTarget).
66+
//
67+
68+
template<typename SubType>
69+
struct SubtypingDiscoverer : public OverriddenVisitor<SubType> {
70+
SubType* self() { return static_cast<SubType*>(this); }
71+
72+
void visitFunction(Function* func) {
73+
if (func->body) {
74+
self()->noteSubtype(func->body, func->getResults());
75+
}
76+
}
77+
void visitGlobal(Global* global) {
78+
if (global->init) {
79+
self()->noteSubtype(global->init, global->type);
80+
}
81+
}
82+
void visitElementSegment(ElementSegment* seg) {
83+
if (seg->offset) {
84+
self()->noteSubtype(seg->type,
85+
self()->getModule()->getTable(seg->table)->type);
86+
}
87+
for (auto init : seg->data) {
88+
self()->noteSubtype(init->type, seg->type);
89+
}
90+
}
91+
void visitNop(Nop* curr) {}
92+
void visitBlock(Block* curr) {
93+
if (!curr->list.empty()) {
94+
self()->noteSubtype(curr->list.back(), curr);
95+
}
96+
}
97+
void visitIf(If* curr) {
98+
if (curr->ifFalse) {
99+
self()->noteSubtype(curr->ifTrue, curr);
100+
self()->noteSubtype(curr->ifFalse, curr);
101+
}
102+
}
103+
void visitLoop(Loop* curr) { self()->noteSubtype(curr->body, curr); }
104+
void visitBreak(Break* curr) {
105+
if (curr->value) {
106+
self()->noteSubtype(curr->value, self()->findBreakTarget(curr->name));
107+
}
108+
}
109+
void visitSwitch(Switch* curr) {
110+
if (curr->value) {
111+
for (auto name : BranchUtils::getUniqueTargets(curr)) {
112+
self()->noteSubtype(curr->value, self()->findBreakTarget(name));
113+
}
114+
}
115+
}
116+
template<typename T> void handleCall(T* curr, Signature sig) {
117+
assert(curr->operands.size() == sig.params.size());
118+
for (size_t i = 0, size = sig.params.size(); i < size; ++i) {
119+
self()->noteSubtype(curr->operands[i], sig.params[i]);
120+
}
121+
if (curr->isReturn) {
122+
self()->noteSubtype(sig.results, self()->getFunction()->getResults());
123+
}
124+
}
125+
void visitCall(Call* curr) {
126+
handleCall(curr, self()->getModule()->getFunction(curr->target)->getSig());
127+
}
128+
void visitCallIndirect(CallIndirect* curr) {
129+
handleCall(curr, curr->heapType.getSignature());
130+
auto* table = self()->getModule()->getTable(curr->table);
131+
auto tableType = table->type.getHeapType();
132+
if (HeapType::isSubType(tableType, curr->heapType)) {
133+
// Unlike other casts, where cast targets are always subtypes of cast
134+
// sources, call_indirect target types may be supertypes of their source
135+
// table types. In this case, the cast will always succeed, but only if we
136+
// keep the types related.
137+
self()->noteSubtype(tableType, curr->heapType);
138+
} else if (HeapType::isSubType(curr->heapType, tableType)) {
139+
self()->noteCast(tableType, curr->heapType);
140+
} else {
141+
// The types are unrelated and the cast will fail. We can keep the types
142+
// unrelated.
143+
}
144+
}
145+
void visitLocalGet(LocalGet* curr) {}
146+
void visitLocalSet(LocalSet* curr) {
147+
self()->noteSubtype(curr->value,
148+
self()->getFunction()->getLocalType(curr->index));
149+
}
150+
void visitGlobalGet(GlobalGet* curr) {}
151+
void visitGlobalSet(GlobalSet* curr) {
152+
self()->noteSubtype(curr->value,
153+
self()->getModule()->getGlobal(curr->name)->type);
154+
}
155+
void visitLoad(Load* curr) {}
156+
void visitStore(Store* curr) {}
157+
void visitAtomicRMW(AtomicRMW* curr) {}
158+
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {}
159+
void visitAtomicWait(AtomicWait* curr) {}
160+
void visitAtomicNotify(AtomicNotify* curr) {}
161+
void visitAtomicFence(AtomicFence* curr) {}
162+
void visitSIMDExtract(SIMDExtract* curr) {}
163+
void visitSIMDReplace(SIMDReplace* curr) {}
164+
void visitSIMDShuffle(SIMDShuffle* curr) {}
165+
void visitSIMDTernary(SIMDTernary* curr) {}
166+
void visitSIMDShift(SIMDShift* curr) {}
167+
void visitSIMDLoad(SIMDLoad* curr) {}
168+
void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) {}
169+
void visitMemoryInit(MemoryInit* curr) {}
170+
void visitDataDrop(DataDrop* curr) {}
171+
void visitMemoryCopy(MemoryCopy* curr) {}
172+
void visitMemoryFill(MemoryFill* curr) {}
173+
void visitConst(Const* curr) {}
174+
void visitUnary(Unary* curr) {}
175+
void visitBinary(Binary* curr) {}
176+
void visitSelect(Select* curr) {
177+
self()->noteSubtype(curr->ifTrue, curr);
178+
self()->noteSubtype(curr->ifFalse, curr);
179+
}
180+
void visitDrop(Drop* curr) {}
181+
void visitReturn(Return* curr) {
182+
if (curr->value) {
183+
self()->noteSubtype(curr->value, self()->getFunction()->getResults());
184+
}
185+
}
186+
void visitMemorySize(MemorySize* curr) {}
187+
void visitMemoryGrow(MemoryGrow* curr) {}
188+
void visitUnreachable(Unreachable* curr) {}
189+
void visitPop(Pop* curr) {}
190+
void visitRefNull(RefNull* curr) {}
191+
void visitRefIsNull(RefIsNull* curr) {}
192+
void visitRefFunc(RefFunc* curr) {}
193+
void visitRefEq(RefEq* curr) {}
194+
void visitTableGet(TableGet* curr) {}
195+
void visitTableSet(TableSet* curr) {
196+
self()->noteSubtype(curr->value,
197+
self()->getModule()->getTable(curr->table)->type);
198+
}
199+
void visitTableSize(TableSize* curr) {}
200+
void visitTableGrow(TableGrow* curr) {}
201+
void visitTableFill(TableFill* curr) {
202+
self()->noteSubtype(curr->value,
203+
self()->getModule()->getTable(curr->table)->type);
204+
}
205+
void visitTableCopy(TableCopy* curr) {
206+
self()->noteSubtype(self()->getModule()->getTable(curr->sourceTable)->type,
207+
self()->getModule()->getTable(curr->destTable)->type);
208+
}
209+
void visitTry(Try* curr) {
210+
self()->noteSubtype(curr->body, curr);
211+
for (auto* body : curr->catchBodies) {
212+
self()->noteSubtype(body, curr);
213+
}
214+
}
215+
void visitThrow(Throw* curr) {
216+
Type params = self()->getModule()->getTag(curr->tag)->sig.params;
217+
assert(params.size() == curr->operands.size());
218+
for (size_t i = 0, size = curr->operands.size(); i < size; ++i) {
219+
self()->noteSubtype(curr->operands[i], params[i]);
220+
}
221+
}
222+
void visitRethrow(Rethrow* curr) {}
223+
void visitTupleMake(TupleMake* curr) {}
224+
void visitTupleExtract(TupleExtract* curr) {}
225+
void visitRefI31(RefI31* curr) {}
226+
void visitI31Get(I31Get* curr) {}
227+
void visitCallRef(CallRef* curr) {
228+
if (!curr->target->type.isSignature()) {
229+
return;
230+
}
231+
handleCall(curr, curr->target->type.getHeapType().getSignature());
232+
}
233+
void visitRefTest(RefTest* curr) {
234+
self()->noteCast(curr->ref, curr->castType);
235+
}
236+
void visitRefCast(RefCast* curr) { self()->noteCast(curr->ref, curr); }
237+
void visitBrOn(BrOn* curr) {
238+
if (curr->op == BrOnCast || curr->op == BrOnCastFail) {
239+
self()->noteCast(curr->ref, curr->castType);
240+
}
241+
self()->noteSubtype(curr->getSentType(),
242+
self()->findBreakTarget(curr->name));
243+
}
244+
void visitStructNew(StructNew* curr) {
245+
if (!curr->type.isStruct() || curr->isWithDefault()) {
246+
return;
247+
}
248+
const auto& fields = curr->type.getHeapType().getStruct().fields;
249+
assert(fields.size() == curr->operands.size());
250+
for (size_t i = 0, size = fields.size(); i < size; ++i) {
251+
self()->noteSubtype(curr->operands[i], fields[i].type);
252+
}
253+
}
254+
void visitStructGet(StructGet* curr) {}
255+
void visitStructSet(StructSet* curr) {
256+
if (!curr->ref->type.isStruct()) {
257+
return;
258+
}
259+
const auto& fields = curr->ref->type.getHeapType().getStruct().fields;
260+
self()->noteSubtype(curr->value, fields[curr->index].type);
261+
}
262+
void visitArrayNew(ArrayNew* curr) {
263+
if (!curr->type.isArray() || curr->isWithDefault()) {
264+
return;
265+
}
266+
auto array = curr->type.getHeapType().getArray();
267+
self()->noteSubtype(curr->init, array.element.type);
268+
}
269+
void visitArrayNewData(ArrayNewData* curr) {}
270+
void visitArrayNewElem(ArrayNewElem* curr) {
271+
if (!curr->type.isArray()) {
272+
return;
273+
}
274+
auto array = curr->type.getHeapType().getArray();
275+
auto* seg = self()->getModule()->getElementSegment(curr->segment);
276+
self()->noteSubtype(seg->type, array.element.type);
277+
}
278+
void visitArrayNewFixed(ArrayNewFixed* curr) {
279+
if (!curr->type.isArray()) {
280+
return;
281+
}
282+
auto array = curr->type.getHeapType().getArray();
283+
for (auto* value : curr->values) {
284+
self()->noteSubtype(value, array.element.type);
285+
}
286+
}
287+
void visitArrayGet(ArrayGet* curr) {}
288+
void visitArraySet(ArraySet* curr) {
289+
if (!curr->ref->type.isArray()) {
290+
return;
291+
}
292+
auto array = curr->ref->type.getHeapType().getArray();
293+
self()->noteSubtype(curr->value->type, array.element.type);
294+
}
295+
void visitArrayLen(ArrayLen* curr) {}
296+
void visitArrayCopy(ArrayCopy* curr) {
297+
if (!curr->srcRef->type.isArray() || !curr->destRef->type.isArray()) {
298+
return;
299+
}
300+
auto src = curr->srcRef->type.getHeapType().getArray();
301+
auto dest = curr->destRef->type.getHeapType().getArray();
302+
self()->noteSubtype(src.element.type, dest.element.type);
303+
}
304+
void visitArrayFill(ArrayFill* curr) {
305+
if (!curr->ref->type.isArray()) {
306+
return;
307+
}
308+
auto array = curr->ref->type.getHeapType().getArray();
309+
self()->noteSubtype(curr->value->type, array.element.type);
310+
}
311+
void visitArrayInitData(ArrayInitData* curr) {}
312+
void visitArrayInitElem(ArrayInitElem* curr) {
313+
if (!curr->ref->type.isArray()) {
314+
return;
315+
}
316+
auto array = curr->ref->type.getHeapType().getArray();
317+
auto* seg = self()->getModule()->getElementSegment(curr->segment);
318+
self()->noteSubtype(seg->type, array.element.type);
319+
}
320+
void visitRefAs(RefAs* curr) {}
321+
void visitStringNew(StringNew* curr) {}
322+
void visitStringConst(StringConst* curr) {}
323+
void visitStringMeasure(StringMeasure* curr) {}
324+
void visitStringEncode(StringEncode* curr) {}
325+
void visitStringConcat(StringConcat* curr) {}
326+
void visitStringEq(StringEq* curr) {}
327+
void visitStringAs(StringAs* curr) {}
328+
void visitStringWTF8Advance(StringWTF8Advance* curr) {}
329+
void visitStringWTF16Get(StringWTF16Get* curr) {}
330+
void visitStringIterNext(StringIterNext* curr) {}
331+
void visitStringIterMove(StringIterMove* curr) {}
332+
void visitStringSliceWTF(StringSliceWTF* curr) {}
333+
void visitStringSliceIter(StringSliceIter* curr) {}
334+
};
335+
336+
} // namespace wasm
337+
338+
#endif // #define wasm_ir_subtype_exprs_h

0 commit comments

Comments
 (0)