diff --git a/src/analysis/lattices/array.h b/src/analysis/lattices/array.h index 3426a3a77d2..7ac0273022b 100644 --- a/src/analysis/lattices/array.h +++ b/src/analysis/lattices/array.h @@ -27,6 +27,7 @@ namespace wasm::analysis { // A lattice whose elements are N-tuples of elements of L. Also written as L^N. +// N is supplied at compile time rather than run time like it is for Vector. template struct Array { using Element = std::array; diff --git a/src/analysis/lattices/vector.h b/src/analysis/lattices/vector.h new file mode 100644 index 00000000000..d13380868fd --- /dev/null +++ b/src/analysis/lattices/vector.h @@ -0,0 +1,138 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_analysis_lattices_vector_h +#define wasm_analysis_lattices_vector_h + +#include + +#include "../lattice.h" +#include "bool.h" +#include "flat.h" + +namespace wasm::analysis { + +// A lattice whose elements are N-tuples of elements of L. Also written as L^N. +// N is supplied at run time rather than compile time like it is for Array. +template struct Vector { + using Element = std::vector; + + L lattice; + const size_t size; + + Vector(L&& lattice, size_t size) : lattice(std::move(lattice)), size(size) {} + + Element getBottom() const noexcept { + return Element(size, lattice.getBottom()); + } + + Element getTop() const noexcept +#if __cplusplus >= 202002L + requires FullLattice +#endif + { + return Element(size, lattice.getTop()); + } + + // `a` <= `b` if their elements are pairwise <=, etc. Unless we determine + // that there is no relation, we must check all the elements. + LatticeComparison compare(const Element& a, const Element& b) const noexcept { + assert(a.size() == size); + assert(b.size() == size); + auto result = EQUAL; + for (size_t i = 0; i < size; ++i) { + switch (lattice.compare(a[i], b[i])) { + case NO_RELATION: + return NO_RELATION; + case EQUAL: + continue; + case LESS: + if (result == GREATER) { + // Cannot be both less and greater. + return NO_RELATION; + } + result = LESS; + continue; + case GREATER: + if (result == LESS) { + // Cannot be both greater and less. + return NO_RELATION; + } + result = GREATER; + continue; + } + } + return result; + } + + // Pairwise join on the elements. + bool join(Element& joinee, const Element& joiner) const noexcept { + assert(joinee.size() == size); + assert(joiner.size() == size); + bool result = false; + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + // The vector specialization does not expose references to the + // individual bools because they might be in a bitmap, so we need a + // workaround. + bool e = joinee[i]; + if (lattice.join(e, joiner[i])) { + joinee[i] = e; + result = true; + } + } else { + result |= lattice.join(joinee[i], joiner[i]); + } + } + + return result; + } + + // Pairwise meet on the elements. + bool meet(Element& meetee, const Element& meeter) const noexcept +#if __cplusplus >= 202002L + requires FullLattice +#endif + { + assert(meetee.size() == size); + assert(meeter.size() == size); + bool result = false; + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + // The vector specialization does not expose references to the + // individual bools because they might be in a bitmap, so we need a + // workaround. + bool e = meetee[i]; + if (lattice.meet(e, meeter[i])) { + meetee[i] = e; + result = true; + } + } else { + result |= lattice.meet(meetee[i], meeter[i]); + } + } + return result; + } +}; + +#if __cplusplus >= 202002L +static_assert(FullLattice>); +static_assert(Lattice>>); +#endif + +} // namespace wasm::analysis + +#endif // wasm_analysis_lattices_vector_h diff --git a/src/tools/wasm-fuzz-lattices.cpp b/src/tools/wasm-fuzz-lattices.cpp index f073aa72492..f583103d6b1 100644 --- a/src/tools/wasm-fuzz-lattices.cpp +++ b/src/tools/wasm-fuzz-lattices.cpp @@ -28,6 +28,7 @@ #include "analysis/lattices/inverted.h" #include "analysis/lattices/lift.h" #include "analysis/lattices/stack.h" +#include "analysis/lattices/vector.h" #include "analysis/liveness-transfer-function.h" #include "analysis/reaching-definitions-transfer-function.h" #include "analysis/transfer-function.h" @@ -151,33 +152,38 @@ static_assert(Lattice); using ArrayFullLattice = analysis::Array; using ArrayLattice = analysis::Array; -struct RandomFullLattice::L - : std::variant, ArrayFullLattice> { -}; +struct RandomFullLattice::L : std::variant, + ArrayFullLattice, + Vector> {}; struct RandomFullLattice::ElementImpl : std::variant::Element, - typename ArrayFullLattice::Element> {}; + typename ArrayFullLattice::Element, + typename Vector::Element> {}; struct RandomLattice::L : std::variant, Lift, - ArrayLattice> {}; + ArrayLattice, + Vector> {}; struct RandomLattice::ElementImpl : std::variant::Element, typename Lift::Element, - typename ArrayLattice::Element> {}; + typename ArrayLattice::Element, + typename Vector::Element> {}; RandomFullLattice::RandomFullLattice(Random& rand, size_t depth, std::optional maybePick) : rand(rand) { // TODO: Limit the depth once we get lattices with more fan-out. - uint32_t pick = maybePick ? *maybePick : rand.upTo(4); + uint32_t pick = maybePick ? *maybePick : rand.upTo(5); switch (pick) { case 0: lattice = std::make_unique(L{Bool{}}); @@ -193,30 +199,39 @@ RandomFullLattice::RandomFullLattice(Random& rand, lattice = std::make_unique( L{ArrayFullLattice{RandomFullLattice{rand, depth + 1}}}); return; + case 4: + lattice = std::make_unique( + L{Vector{RandomFullLattice{rand, depth + 1}, rand.upTo(4)}}); + return; } WASM_UNREACHABLE("unexpected pick"); } RandomLattice::RandomLattice(Random& rand, size_t depth) : rand(rand) { // TODO: Limit the depth once we get lattices with more fan-out. - uint32_t pick = rand.upTo(7); + uint32_t pick = rand.upTo(9); switch (pick) { case 0: case 1: case 2: case 3: + case 4: lattice = std::make_unique(L{RandomFullLattice{rand, depth, pick}}); return; - case 4: + case 5: lattice = std::make_unique(L{Flat{}}); return; - case 5: + case 6: lattice = std::make_unique(L{Lift{RandomLattice{rand, depth + 1}}}); return; - case 6: + case 7: lattice = std::make_unique(L{ArrayLattice{RandomLattice{rand, depth + 1}}}); return; + case 8: + lattice = std::make_unique( + L{Vector{RandomLattice{rand, depth + 1}, rand.upTo(4)}}); + return; } WASM_UNREACHABLE("unexpected pick"); } @@ -235,6 +250,14 @@ RandomFullLattice::Element RandomFullLattice::makeElement() const noexcept { return ElementImpl{typename ArrayFullLattice::Element{ l->lattice.makeElement(), l->lattice.makeElement()}}; } + if (const auto* l = std::get_if>(lattice.get())) { + std::vector elem; + elem.reserve(l->size); + for (size_t i = 0; i < l->size; ++i) { + elem.push_back(l->lattice.makeElement()); + } + return ElementImpl{std::move(elem)}; + } WASM_UNREACHABLE("unexpected lattice"); } @@ -261,6 +284,14 @@ RandomLattice::Element RandomLattice::makeElement() const noexcept { return ElementImpl{typename ArrayLattice::Element{ l->lattice.makeElement(), l->lattice.makeElement()}}; } + if (const auto* l = std::get_if>(lattice.get())) { + std::vector elem; + elem.reserve(l->size); + for (size_t i = 0; i < l->size; ++i) { + elem.push_back(l->lattice.makeElement()); + } + return ElementImpl{std::move(elem)}; + } WASM_UNREACHABLE("unexpected lattice"); } @@ -293,6 +324,17 @@ void printFullElement(std::ostream& os, printFullElement(os, e->back(), depth + 1); indent(os, depth); os << "]\n"; + } else if (const auto* vec = + std::get_if::Element>( + &*elem)) { + os << "Vector[\n"; + for (const auto& e : *vec) { + printFullElement(os, e, depth + 1); + } + indent(os, depth); + os << "]\n"; + } else { + WASM_UNREACHABLE("unexpected element"); } } @@ -332,6 +374,16 @@ void printElement(std::ostream& os, printElement(os, e->back(), depth + 1); indent(os, depth); os << ")\n"; + } else if (const auto* vec = + std::get_if::Element>(&*elem)) { + os << "Vector[\n"; + for (const auto& e : *vec) { + printElement(os, e, depth + 1); + } + indent(os, depth); + os << "]\n"; + } else { + WASM_UNREACHABLE("unexpected element"); } } diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index 50adffc7e4b..3a2bf8b0d75 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -20,6 +20,7 @@ #include "analysis/lattices/int.h" #include "analysis/lattices/inverted.h" #include "analysis/lattices/lift.h" +#include "analysis/lattices/vector.h" #include "gtest/gtest.h" using namespace wasm; @@ -483,3 +484,111 @@ TEST(ArrayLattice, Meet) { test(tt, tf, true, tf); test(tt, tt, false, tt); } + +TEST(VectorLattice, GetBottom) { + analysis::Vector vector{analysis::Bool{}, 2}; + EXPECT_EQ(vector.getBottom(), (std::vector{false, false})); +} + +TEST(VectorLattice, GetTop) { + analysis::Vector vector{analysis::Bool{}, 2}; + EXPECT_EQ(vector.getTop(), (std::vector{true, true})); +} + +TEST(VectorLattice, Compare) { + analysis::Vector vector{analysis::Bool{}, 2}; + std::vector ff{false, false}; + std::vector ft{false, true}; + std::vector tf{true, false}; + std::vector tt{true, true}; + + EXPECT_EQ(vector.compare(ff, ff), analysis::EQUAL); + EXPECT_EQ(vector.compare(ff, ft), analysis::LESS); + EXPECT_EQ(vector.compare(ff, tf), analysis::LESS); + EXPECT_EQ(vector.compare(ff, tt), analysis::LESS); + + EXPECT_EQ(vector.compare(ft, ff), analysis::GREATER); + EXPECT_EQ(vector.compare(ft, ft), analysis::EQUAL); + EXPECT_EQ(vector.compare(ft, tf), analysis::NO_RELATION); + EXPECT_EQ(vector.compare(ft, tt), analysis::LESS); + + EXPECT_EQ(vector.compare(tf, ff), analysis::GREATER); + EXPECT_EQ(vector.compare(tf, ft), analysis::NO_RELATION); + EXPECT_EQ(vector.compare(tf, tf), analysis::EQUAL); + EXPECT_EQ(vector.compare(tf, tt), analysis::LESS); + + EXPECT_EQ(vector.compare(tt, ff), analysis::GREATER); + EXPECT_EQ(vector.compare(tt, ft), analysis::GREATER); + EXPECT_EQ(vector.compare(tt, tf), analysis::GREATER); + EXPECT_EQ(vector.compare(tt, tt), analysis::EQUAL); +} + +TEST(VectorLattice, Join) { + analysis::Vector vector{analysis::Bool{}, 2}; + auto ff = []() { return std::vector{false, false}; }; + auto ft = []() { return std::vector{false, true}; }; + auto tf = []() { return std::vector{true, false}; }; + auto tt = []() { return std::vector{true, true}; }; + + auto test = + [&](auto& makeJoinee, auto& makeJoiner, bool modified, auto& makeExpected) { + auto joinee = makeJoinee(); + EXPECT_EQ(vector.join(joinee, makeJoiner()), modified); + EXPECT_EQ(joinee, makeExpected()); + }; + + test(ff, ff, false, ff); + test(ff, ft, true, ft); + test(ff, tf, true, tf); + test(ff, tt, true, tt); + + test(ft, ff, false, ft); + test(ft, ft, false, ft); + test(ft, tf, true, tt); + test(ft, tt, true, tt); + + test(tf, ff, false, tf); + test(tf, ft, true, tt); + test(tf, tf, false, tf); + test(tf, tt, true, tt); + + test(tt, ff, false, tt); + test(tt, ft, false, tt); + test(tt, tf, false, tt); + test(tt, tt, false, tt); +} + +TEST(VectorLattice, Meet) { + analysis::Vector vector{analysis::Bool{}, 2}; + auto ff = []() { return std::vector{false, false}; }; + auto ft = []() { return std::vector{false, true}; }; + auto tf = []() { return std::vector{true, false}; }; + auto tt = []() { return std::vector{true, true}; }; + + auto test = + [&](auto& makeMeetee, auto& makeMeeter, bool modified, auto& makeExpected) { + auto meetee = makeMeetee(); + EXPECT_EQ(vector.meet(meetee, makeMeeter()), modified); + EXPECT_EQ(meetee, makeExpected()); + }; + + test(ff, ff, false, ff); + test(ff, ft, false, ff); + test(ff, tf, false, ff); + test(ff, tt, false, ff); + + test(ft, ff, true, ff); + test(ft, ft, false, ft); + test(ft, tf, true, ff); + test(ft, tt, false, ft); + + test(tf, ff, true, ff); + test(tf, ft, true, ff); + test(tf, tf, false, tf); + test(tf, tt, false, tf); + + test(tt, ff, true, ff); + test(tt, ft, true, ft); + test(tt, tf, true, tf); + test(tt, tt, false, tt); +}