Add a simple tuple optimization pass #5937
The diff adds a new 357-line pass implementation (@@ -0,0 +1,357 @@):

```cpp
/*
 * Copyright 2023 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Optimize away trivial tuples. When values are bundled together in a tuple, we
// are limited in how we can optimize them in the various local-related passes,
// like this:
//
//  (local.set $tuple
//    (tuple.make (A) (B) (C)))
//  (use
//    (tuple.extract 0
//      (local.get $tuple)))
//
// If there are no other uses, then we just need one of the three lanes. By
// lowering them to three separate locals, other passes can remove the other
// two.
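//
// After this pass, the code above becomes, roughly (the pass adds unnamed
// scalar locals; the names $tuple.0 etc. here are just illustrative):
//
//  (local.set $tuple.0 (A))
//  (local.set $tuple.1 (B))
//  (local.set $tuple.2 (C))
//  (use
//    (local.get $tuple.0))
//
// leaving the unused lanes for later passes to remove.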
//
// Specifically, this pass seeks out tuple locals that have these properties:
//
//  * They are always written from either a tuple.make or another tuple local
//    with these properties.
//  * They are always used either in a tuple.extract, or they are copied to
//    another tuple local with these properties.
//
// The set of such tuple locals can easily be optimized into individual locals,
// as the tuple does not "escape" into, say, a return value.
//
// TODO: Blocks etc. might be handled here, but it's not clear if we want to:
//       there are situations where multivalue leads to smaller code using
//       those constructs. Atm this pass should only remove things that are
//       definitely worth lowering.
//

#include <pass.h>
#include <support/unique_deferring_queue.h>
#include <wasm-builder.h>
#include <wasm.h>

namespace wasm {

struct TupleOptimization : public WalkerPass<PostWalker<TupleOptimization>> {
  bool isFunctionParallel() override { return true; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<TupleOptimization>();
  }

  // Track the number of uses for each tuple local. We consider a use to be a
  // local.get, a local.set, or a local.tee. A tee counts as two uses (since
  // it both sets and gets, and so we must check that both of those are
  // valid).
  std::vector<Index> uses;

  // Tracks which tuple local uses are valid, that is, which follow the
  // properties above. If we have more uses than valid uses then we must have
  // an invalid one, and the local cannot be optimized.
  std::vector<Index> validUses;

  // When one tuple local copies the value of another, we need to track the
  // index that was copied, as if the source ends up bad then the target is
  // bad as well.
  //
  // This is a symmetrical map, that is, we consider copies to work both ways:
  //
  //   x \in copiedIndexes[y] <==> y \in copiedIndexes[x]
  //
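  // For example, if local 1 is copied to local 3 (and there are no other
  // copies), then copiedIndexes[1] = {3} and copiedIndexes[3] = {1}.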
  std::vector<std::unordered_set<Index>> copiedIndexes;

  void doWalkFunction(Function* func) {
    // If tuples are not enabled, or there are no tuple locals, then there is
    // no work to do.
    if (!getModule()->features.hasMultivalue()) {
      return;
    }
    bool hasTuple = false;
    for (auto var : func->vars) {
      if (var.isTuple()) {
        hasTuple = true;
        break;
      }
    }
    if (!hasTuple) {
      return;
    }

    // Prepare global data structures before we collect info.
    auto numLocals = func->getNumLocals();
    uses.resize(numLocals);
    validUses.resize(numLocals);
    copiedIndexes.resize(numLocals);

    // Walk the code to collect info.
    super::doWalkFunction(func);

    // Analyze and optimize.
    optimize(func);
  }

  void visitLocalGet(LocalGet* curr) {
    if (curr->type.isTuple()) {
      uses[curr->index]++;
    }
  }

  void visitLocalSet(LocalSet* curr) {
    if (getFunction()->getLocalType(curr->index).isTuple()) {
      // See the comment above about tees (we consider their set and get each
      // a separate use).
      uses[curr->index] += curr->isTee() ? 2 : 1;
      auto* value = curr->value;

      // We need the input to the local to be another such local (from a tee,
      // or a get), or a tuple.make.
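      //
      // For example, in (local.set $x (local.tee $y (tuple.make ...))) where
      // $x and $y are tuple locals, the tee gives $y two uses, both valid ($y
      // is written a tuple.make, and is read by the copy into $x), while $x
      // gets one valid use, and $x and $y are recorded as copies of each
      // other in copiedIndexes.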
      if (auto* set = value->dynCast<LocalSet>()) {
        assert(set->isTee());
        validUses[set->index]++;
        validUses[curr->index]++;
        copiedIndexes[set->index].insert(curr->index);
        copiedIndexes[curr->index].insert(set->index);
      } else if (auto* get = value->dynCast<LocalGet>()) {
        validUses[get->index]++;
        validUses[curr->index]++;
        copiedIndexes[get->index].insert(curr->index);
        copiedIndexes[curr->index].insert(get->index);
      } else if (value->is<TupleMake>()) {
        validUses[curr->index]++;
      }
    }
  }

  void visitTupleExtract(TupleExtract* curr) {
    // We need the input to be a local, either from a tee or a get.
    if (auto* set = curr->tuple->dynCast<LocalSet>()) {
      validUses[set->index]++;
    } else if (auto* get = curr->tuple->dynCast<LocalGet>()) {
      validUses[get->index]++;
    }
  }
  void optimize(Function* func) {
    auto numLocals = func->getNumLocals();

    // Find the set of bad indexes. We add each such candidate to a worklist
    // that we will then flow, to find everything that is corrupted.
    std::vector<bool> bad(numLocals);
    UniqueDeferredQueue<Index> work;

    for (Index i = 0; i < uses.size(); i++) {
      assert(validUses[i] <= uses[i]);
      if (uses[i] > 0 && validUses[i] < uses[i]) {
        // This is a bad tuple.
        work.push(i);
      }
    }

    // Flow badness forward.
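    // (For example, if local 2 turns out to be bad, and locals 2 and 5 copy
    // between each other, then 5 becomes bad as well, and so on transitively
    // through any further copies.)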
    while (!work.empty()) {
      auto i = work.pop();
      if (bad[i]) {
        continue;
      }
      bad[i] = true;
      for (auto target : copiedIndexes[i]) {
        work.push(target);
      }
    }

    // Good indexes that we can optimize are tuple locals that have uses, none
    // of which is bad.
    std::vector<bool> good(numLocals);
    bool hasGood = false;
    for (Index i = 0; i < uses.size(); i++) {
      if (uses[i] > 0 && !bad[i]) {
        good[i] = true;
        hasGood = true;
      }
    }

    if (!hasGood) {
      return;
    }

    // We found things to optimize! Create new non-tuple locals for their
    // contents, and then rewrite the code to use those, according to the
    // mapping from tuple locals to normal ones. The mapping maps a tuple
    // local to the base index used for its contents: that index plus several
    // right after it, depending on the tuple size.
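    // For example, a tuple local of type (i32, i64) that is mapped to base
    // index k will have its first lane in local k (an i32) and its second
    // lane in local k+1 (an i64).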
    std::unordered_map<Index, Index> tupleToNewBaseMap;
    for (Index i = 0; i < good.size(); i++) {
      if (good[i]) {
        auto newBase = func->getNumLocals();
        tupleToNewBaseMap[i] = newBase;
        Index lastNewIndex = 0;
        for (auto t : func->getLocalType(i)) {
          Index newIndex = Builder::addVar(func, t);
          if (lastNewIndex == 0) {
            // This is the first new local we added (0 is an impossible value,
            // since tuple locals exist, hence index 0 was already taken), so
            // it must be equal to the base.
            assert(newIndex == newBase);
          } else {
            // This must be right after the previous one.
            assert(newIndex == lastNewIndex + 1);
          }
          lastNewIndex = newIndex;
        }
      }
    }

    MapApplier mapApplier(tupleToNewBaseMap);
    mapApplier.walkFunctionInModule(func, getModule());
  }

  struct MapApplier : public PostWalker<MapApplier> {
    std::unordered_map<Index, Index>& tupleToNewBaseMap;

    MapApplier(std::unordered_map<Index, Index>& tupleToNewBaseMap)
      : tupleToNewBaseMap(tupleToNewBaseMap) {}

    // Gets the new base index if there is one, or 0 if not (0 is an
    // impossible value for a new index, since local index 0 was already taken
    // before we added any new locals, given that tuple locals existed).
    Index getNewBaseIndex(Index i) {
      auto iter = tupleToNewBaseMap.find(i);
      if (iter == tupleToNewBaseMap.end()) {
        return 0;
      }
      return iter->second;
    }

    // Given a local.get or local.set, returns the new base index for the
    // local index used there, or 0 (an impossible value, see above)
    // otherwise.
    Index getSetOrGetBaseIndex(Expression* setOrGet) {
      Index index;
      if (auto* set = setOrGet->dynCast<LocalSet>()) {
        index = set->index;
      } else if (auto* get = setOrGet->dynCast<LocalGet>()) {
        index = get->index;
      } else {
        return 0;
      }

      return getNewBaseIndex(index);
    }

    // Replacing a local.tee requires some care, since we might have
    //
    //  (local.set
    //   (local.tee
    //    ..
    //
    // We replace the local.tee with a block of sets of the new non-tuple
    // locals, and the outer set must then (1) keep those around and also (2)
    // identify the local that was tee'd, so that we know what to get (the tee
    // itself having been replaced by the block). To make that simple, we keep
    // a map of the things that replaced tees.
    std::unordered_map<Expression*, LocalSet*> teeReplacements;

    void visitLocalSet(LocalSet* curr) {
      auto replace = [&](Expression* replacement) {
        if (curr->isTee()) {
          teeReplacements[replacement] = curr;
        }
        replaceCurrent(replacement);
      };

      if (auto targetBase = getNewBaseIndex(curr->index)) {
        Builder builder(*getModule());
        auto type = getFunction()->getLocalType(curr->index);

        auto* value = curr->value;
        if (auto* make = value->dynCast<TupleMake>()) {
          // Write each of the tuple.make fields into the proper local.
          std::vector<Expression*> sets;
          for (Index i = 0; i < type.size(); i++) {
            auto* value = make->operands[i];
            sets.push_back(builder.makeLocalSet(targetBase + i, value));
          }
          replace(builder.makeBlock(sets));
          return;
        }

        std::vector<Expression*> contents;

        auto iter = teeReplacements.find(value);
        if (iter != teeReplacements.end()) {
          // The input to us was a tee that has been replaced. The actual
          // value we read from (the tee) can be found in teeReplacements.
          // Also, we need to keep around the replacement of the tee.
          contents.push_back(value);
          value = iter->second;
        }

        // This is a copy of one tuple local into another. Copy all the fields
        // between them.
        Index sourceBase = getSetOrGetBaseIndex(value);

        // The target is being optimized, so the source must be as well, or
        // else we were confused earlier and the target should not be
        // optimized.
        assert(sourceBase);

        for (Index i = 0; i < type.size(); i++) {
          auto* get = builder.makeLocalGet(sourceBase + i, type[i]);
          contents.push_back(builder.makeLocalSet(targetBase + i, get));
        }
        replace(builder.makeBlock(contents));
      }
    }

    void visitTupleExtract(TupleExtract* curr) {
      auto* value = curr->tuple;
      Expression* extraContents = nullptr;

      auto iter = teeReplacements.find(value);
      if (iter != teeReplacements.end()) {
        // The input to us was a tee that has been replaced. Handle it as in
        // visitLocalSet.
        extraContents = value;
        value = iter->second;
      }

      auto type = value->type;
      if (type == Type::unreachable) {
        return;
      }

      Index sourceBase = getSetOrGetBaseIndex(value);
      if (!sourceBase) {
        return;
      }

      Builder builder(*getModule());
      auto i = curr->index;
      auto* get = builder.makeLocalGet(sourceBase + i, type[i]);
      if (extraContents) {
        replaceCurrent(builder.makeSequence(extraContents, get));
      } else {
        replaceCurrent(get);
      }
    }
  };
};

Pass* createTupleOptimizationPass() { return new TupleOptimization(); }

} // namespace wasm
```

The pass is then registered with the pass registry:
```diff
@@ -480,6 +480,9 @@ void PassRegistry::registerPasses() {
   registerPass("trap-mode-js",
                "replace trapping operations with js semantics",
                createTrapModeJS);
+  registerPass("tuple-optimization",
+               "optimize trivial tuples away",
+               createTupleOptimizationPass);
   registerPass("type-merging",
                "merge types to their supertypes where possible",
                createTypeMergingPass);
```

It is also added to the default function optimization pipeline:

```diff
@@ -558,6 +561,9 @@ void PassRunner::addDefaultFunctionOptimizationPasses() {
   if (options.optimizeLevel >= 2 || options.shrinkLevel >= 2) {
     addIfNoDWARFIssues("code-pushing");
   }
+  if (wasm->features.hasMultivalue()) {
+    addIfNoDWARFIssues("tuple-optimization");
+  }
   // don't create if/block return values yet, as coalesce can remove copies
   // that that could inhibit
   addIfNoDWARFIssues("simplify-locals-nostructure");
```

A review thread on the placement of the pass in the pipeline:

Member: Are we putting this here because it is just before all the local optimizations?

Author: Sort of, and also after at least one

Member: It will also be useful to do this after inlining, if that's not already the case.

Author: Yes, already the case:
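With that registration, the pass runs automatically in the default -O pipelines whenever the multivalue feature is enabled. It should also be runnable on its own by name, following the usual Binaryen convention that registered passes become wasm-opt flags; a representative invocation (an illustration, not taken from the PR) would be `wasm-opt --enable-multivalue --tuple-optimization in.wasm -o out.wasm`.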

A second review thread discusses the symmetric `copiedIndexes` representation:

Member: Instead of setting `copiedIndexes` twice, how about maintaining the invariant that the smaller index is always the key and setting it just once?

Author: This isn't a set of tuples `(x, y)` that we can store by flipping them. We are given an index and need to find all related indexes to it - that is, we know one of `x, y` and want to know the other (and it is also a set of others, but that's separate).

Member: Right, so if we need to store the edges (0, 1), (1, 2), (0, 3), the current code would have:
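Presumably along these lines (a reconstruction; the patch's symmetric scheme stores each edge under both endpoints):

```cpp
copiedIndexes[0] = {1, 3};
copiedIndexes[1] = {0, 2};
copiedIndexes[2] = {1};
copiedIndexes[3] = {0};
```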
I'm suggesting that instead we construct this mapping:
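That is (again a reconstruction), each edge stored only under its smaller endpoint:

```cpp
copiedIndexes[0] = {1, 3};
copiedIndexes[1] = {2};
```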
From the comment mentioning that this is a bidirectional mapping, this is what I would have expected.

Author: Sorry, I'm missing something. Say I have your mapping, and I get the index "2". I want to find the other indexes I need to mark as bad. How do I do that?

Member: You would iterate through the map and, for each bad key, you would add all the corresponding values to the work list, and for each bad value, you would add the corresponding key to the work list. This may not be simpler overall!
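A hypothetical sketch of what that lookup would entail under the one-directional scheme (not part of the patch): finding everything related to a newly-discovered bad index requires scanning the whole map.

```cpp
#include <unordered_map>
#include <unordered_set>

#include <support/unique_deferring_queue.h>
#include <wasm.h> // for wasm::Index

using namespace wasm;

// Hypothetical flow step for a one-directional edge map (smaller index as
// the key): to find everything related to a bad index, scan every entry.
void flowBadness(Index bad,
                 std::unordered_map<Index, std::unordered_set<Index>>& edges,
                 UniqueDeferredQueue<Index>& work) {
  for (auto& [key, values] : edges) {
    if (key == bad) {
      // Everything this key was copied to/from is now bad as well.
      for (auto value : values) {
        work.push(value);
      }
    } else if (values.count(bad)) {
      // This key is related to the bad index, so it is bad too.
      work.push(key);
    }
  }
}
```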

Author: Ah, but that's O(map size) for each index we realize is bad? It's O(num related indexes) in the code atm. So I worry changing this would be a regression, though it would save some memory otoh. For now though I think this is good enough.

Member: sgtm