diff --git a/stl/inc/functional b/stl/inc/functional index 9c61372fcd..b8e87283b7 100644 --- a/stl/inc/functional +++ b/stl/inc/functional @@ -971,6 +971,11 @@ private: _Callable _Callee; }; +#if _HAS_CXX23 +template +class _Move_only_function_base; +#endif // _HAS_CXX23 + template class _Func_class : public _Arg_types<_Types...> { public: @@ -1097,6 +1102,10 @@ protected: #endif // _HAS_STATIC_RTTI private: +#if _HAS_CXX23 + friend _Move_only_function_base<_Ret, false, _Types...>; +#endif // _HAS_CXX23 + bool _Local() const noexcept { // test for locally stored copy of object return _Getimpl() == static_cast(&_Mystorage); } @@ -1370,6 +1379,8 @@ _NODISCARD bool operator!=(nullptr_t, const function<_Fty>& _Other) noexcept { #endif // !_HAS_CXX20 #if _HAS_CXX23 +enum class _Function_storage_mode { _Small, _Large }; + // _Move_only_function_data is defined as an array of pointers. // The first element is always a pointer to _Move_only_function_base::_Impl_t; it emulates a vtable pointer. // The other pointers are used as storage for a small functor; @@ -1392,16 +1403,15 @@ union alignas(max_align_t) _Move_only_function_data { return &_Data + _Buf_offset<_Fn>; } - template - _NODISCARD _Fn* _Small_fn_ptr() const noexcept { - // cast away const to avoid complication of const propagation to here; - // const correctness is still enforced by _Move_only_function_call specializations. - return static_cast<_Fn*>(const_cast<_Move_only_function_data*>(this)->_Buf_ptr<_Fn>()); - } - - template - _NODISCARD _Fn* _Large_fn_ptr() const noexcept { - return static_cast<_Fn*>(_Pointers[1]); + template + _NODISCARD _Fn* _Fn_ptr() const noexcept { + if constexpr (_Mode == _Function_storage_mode::_Small) { + // cast away const to avoid complication of const propagation to here; + // const correctness is still enforced by _Move_only_function_call specializations. + return static_cast<_Fn*>(const_cast<_Move_only_function_data*>(this)->_Buf_ptr<_Fn>()); + } else { + return static_cast<_Fn*>(_Pointers[1]); + } } void _Set_large_fn_ptr(void* const _Value) noexcept { @@ -1423,27 +1433,28 @@ template _STL_UNREACHABLE; // no return value available for "continue on error" } -template -_NODISCARD _Rx __stdcall _Function_inv_small(const _Move_only_function_data& _Self, _Types&&... _Args) noexcept(_Noex) { - if constexpr (is_void_v<_Rx>) { - (void) _STD invoke(static_cast<_VtInvQuals>(*_Self._Small_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...); - } else { - return _STD invoke(static_cast<_VtInvQuals>(*_Self._Small_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...); - } +template +[[noreturn]] _Rx __stdcall _Function_old_not_callable(const _Move_only_function_data&, _Types&&...) { + _Xbad_function_call(); } -template -_NODISCARD _Rx __stdcall _Function_inv_large(const _Move_only_function_data& _Self, _Types&&... _Args) noexcept(_Noex) { +template +_NODISCARD _Rx __stdcall _Function_inv(const _Move_only_function_data& _Self, _Types&&... _Args) noexcept(_Noex) { if constexpr (is_void_v<_Rx>) { - (void) _STD invoke(static_cast<_VtInvQuals>(*_Self._Large_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...); + (void) _STD invoke(static_cast<_VtInvQuals>(*_Self._Fn_ptr<_Vt, _Mode>()), _STD forward<_Types>(_Args)...); } else { - return _STD invoke(static_cast<_VtInvQuals>(*_Self._Large_fn_ptr<_Vt>()), _STD forward<_Types>(_Args)...); + return _STD invoke(static_cast<_VtInvQuals>(*_Self._Fn_ptr<_Vt, _Mode>()), _STD forward<_Types>(_Args)...); } } +template +_NODISCARD _Rx __stdcall _Function_inv_old(const _Move_only_function_data& _Self, _Types&&... _Args) { + return _Self._Fn_ptr<_Fn, _Mode>()->_Do_call(_STD forward<_Types>(_Args)...); +} + template void __stdcall _Function_move_small(_Move_only_function_data& _Self, _Move_only_function_data& _Src) noexcept { - const auto _Src_fn_ptr = _Src._Small_fn_ptr<_Vt>(); + const auto _Src_fn_ptr = _Src._Fn_ptr<_Vt, _Function_storage_mode::_Small>(); ::new (_Self._Buf_ptr<_Vt>()) _Vt(_STD move(*_Src_fn_ptr)); _Src_fn_ptr->~_Vt(); _Self._Impl = _Src._Impl; @@ -1458,28 +1469,57 @@ inline void __stdcall _Function_move_large(_Move_only_function_data& _Self, _Mov _CSTD memcpy(&_Self._Data, &_Src._Data, _Minimum_function_size); // Copy Impl* and functor data } +#ifdef _WIN64 +template +void __stdcall _Function_move_old_small(_Move_only_function_data& _Self, _Move_only_function_data& _Src) noexcept { + _Fn* const _Old_fn_impl = _Src._Fn_ptr<_Fn, _Function_storage_mode::_Small>(); + _Old_fn_impl->_Move(_Self._Buf_ptr()); + _Old_fn_impl->_Delete_this(false); + _Self._Impl = _Src._Impl; +} +#endif // ^^^ 64-bit ^^^ + template void __stdcall _Function_destroy_small(_Move_only_function_data& _Self) noexcept { - _Self._Small_fn_ptr<_Vt>()->~_Vt(); + _Self._Fn_ptr<_Vt, _Function_storage_mode::_Small>()->~_Vt(); } inline void __stdcall _Function_deallocate_large_default_aligned(_Move_only_function_data& _Self) noexcept { - ::operator delete(_Self._Large_fn_ptr()); + ::operator delete(_Self._Fn_ptr()); +} + +template +void __stdcall _Function_destroy_old_large(_Move_only_function_data& _Self) noexcept { + _Self._Fn_ptr<_Fn, _Function_storage_mode::_Large>()->_Delete_this(true); +} + +#ifdef _WIN64 +template +void __stdcall _Function_destroy_old_small(_Move_only_function_data& _Self) noexcept { + _Self._Fn_ptr<_Fn, _Function_storage_mode::_Small>()->_Delete_this(false); +} +#else // ^^^ 64-bit / 32-bit vvv +template +void __stdcall _Function_destroy_old_small_as_large(_Move_only_function_data& _Self) noexcept { + _Fn* const _Old_fn_impl = _Self._Fn_ptr<_Fn, _Function_storage_mode::_Large>(); + _Old_fn_impl->_Delete_this(false); + ::operator delete(static_cast(_Old_fn_impl)); } +#endif // ^^^ 32-bit ^^^ template void __stdcall _Function_deallocate_large_overaligned(_Move_only_function_data& _Self) noexcept { _STL_INTERNAL_STATIC_ASSERT(_Align > __STDCPP_DEFAULT_NEW_ALIGNMENT__); #ifdef __cpp_aligned_new - ::operator delete(_Self._Large_fn_ptr(), align_val_t{_Align}); + ::operator delete(_Self._Fn_ptr(), align_val_t{_Align}); #else // ^^^ defined(__cpp_aligned_new) / !defined(__cpp_aligned_new) vvv - ::operator delete(_Self._Large_fn_ptr()); + ::operator delete(_Self._Fn_ptr()); #endif // ^^^ !defined(__cpp_aligned_new) ^^^ } template void __stdcall _Function_destroy_large(_Move_only_function_data& _Self) noexcept { - const auto _Pfn = _Self._Large_fn_ptr<_Vt>(); + const auto _Pfn = _Self._Fn_ptr<_Vt, _Function_storage_mode::_Large>(); _Pfn->~_Vt(); #ifdef __cpp_aligned_new if constexpr (alignof(_Vt) > __STDCPP_DEFAULT_NEW_ALIGNMENT__) { @@ -1555,6 +1595,17 @@ public: void(__stdcall* _Destroy)(_Move_only_function_data&) _NOEXCEPT_FNPTR; }; + enum class _Impl_kind { + _Usual = 0, + _Old_fn_null = 1, + _Old_fn_large = 2, +#ifdef _WIN64 + _Old_fn_small = 3, +#else // ^^^ 64-bit / 32-bit vvv + _Old_fn_small_as_large = 4, +#endif // ^^^ 32-bit ^^^ + }; + _Move_only_function_data _Data; _Move_only_function_base() noexcept = default; // leaves fields uninitialized @@ -1568,9 +1619,40 @@ public: _Data._Impl = nullptr; } + template + void _Construct_with_old_fn(_Fn&& _Func) { + const auto _Old_fn_impl = _Func._Getimpl(); + if (_Old_fn_impl == nullptr) { + _Data._Impl = _Create_impl_ptr<_Impl_kind::_Old_fn_null, _Vt, void>(); + } else if (_Func._Local()) { +#ifdef _WIN64 + _STL_INTERNAL_STATIC_ASSERT(alignof(max_align_t) == alignof(void*)); + // 64-bit target, can put small function into small move_only_function directly + _Data._Impl = _Create_impl_ptr<_Impl_kind::_Old_fn_small, _Vt, void>(); + _Old_fn_impl->_Move(_Data._Buf_ptr()); + _Func._Tidy(); +#else // ^^^ 64-bit / 32-bit vvv + _STL_INTERNAL_STATIC_ASSERT(alignof(max_align_t) > alignof(void*)); + // 32-bit target, cannot put small function into small move_only_function directly + // due to potentially not enough alignment. Allocate large function + void* _Where = ::operator new((_Small_object_num_ptrs - 1) * sizeof(void*)); + _Old_fn_impl->_Move(_Where); + _Func._Tidy(); + + _Data._Impl = _Create_impl_ptr<_Impl_kind::_Old_fn_small_as_large, _Vt, void>(); + _Data._Set_large_fn_ptr(_Where); +#endif // ^^^ 32-bit ^^^ + } else { + // Just take ownership of the inner impl pointer + _Data._Impl = _Create_impl_ptr<_Impl_kind::_Old_fn_large, _Vt, void>(); + _Data._Set_large_fn_ptr(_Old_fn_impl); + _Func._Set(nullptr); + } + } + template void _Construct_with_fn(_CTypes&&... _Args) { - _Data._Impl = _Create_impl_ptr<_Vt, _VtInvQuals>(); + _Data._Impl = _Create_impl_ptr<_Impl_kind::_Usual, _Vt, _VtInvQuals>(); if constexpr (_Large_function_engaged<_Vt>) { _Data._Set_large_fn_ptr(_STD _Function_new_large<_Vt>(_STD forward<_CTypes>(_Args)...)); } else { @@ -1661,11 +1743,36 @@ public: return _Ret ? _Ret : &_Null_move_only_function; } - template + template <_Impl_kind _Kind, class _Vt, class _VtInvQuals> _NODISCARD static constexpr _Impl_t _Create_impl() noexcept { _Impl_t _Impl{}; - if constexpr (_Large_function_engaged<_Vt>) { - _Impl._Invoke = _Function_inv_large<_Vt, _VtInvQuals, _Rx, _Noexcept, _Types...>; + if constexpr (_Kind != _Impl_kind::_Usual) { + _STL_INTERNAL_STATIC_ASSERT(!_Noexcept); + _STL_INTERNAL_STATIC_ASSERT(is_void_v<_VtInvQuals>); + using _Fn = remove_pointer_t()._Getimpl())>; + if constexpr (_Kind == _Impl_kind::_Old_fn_null) { + _Impl._Invoke = _Function_old_not_callable<_Rx, _Types...>; + _Impl._Move = nullptr; + _Impl._Destroy = nullptr; + } else if constexpr (_Kind == _Impl_kind::_Old_fn_large) { + _Impl._Invoke = _Function_inv_old<_Fn, _Function_storage_mode::_Large, _Rx, _Types...>; + _Impl._Move = nullptr; + _Impl._Destroy = _Function_destroy_old_large<_Fn>; + } else { +#ifdef _WIN64 + static_assert(_Kind == _Impl_kind::_Old_fn_small); + _Impl._Invoke = _Function_inv_old<_Fn, _Function_storage_mode::_Small, _Rx, _Types...>; + _Impl._Move = _Function_move_old_small<_Fn>; + _Impl._Destroy = _Function_destroy_old_small<_Fn>; +#else // ^^^ 64-bit / 32-bit vvv + static_assert(_Kind == _Impl_kind::_Old_fn_small_as_large); + _Impl._Invoke = _Function_inv_old<_Fn, _Function_storage_mode::_Large, _Rx, _Types...>; + _Impl._Move = nullptr; + _Impl._Destroy = _Function_destroy_old_small_as_large<_Fn>; +#endif // ^^^ 32-bit ^^^ + } + } else if constexpr (_Large_function_engaged<_Vt>) { + _Impl._Invoke = _Function_inv<_Vt, _VtInvQuals, _Function_storage_mode::_Large, _Rx, _Noexcept, _Types...>; _Impl._Move = nullptr; if constexpr (is_trivially_destructible_v<_Vt>) { @@ -1678,7 +1785,7 @@ public: _Impl._Destroy = _Function_destroy_large<_Vt>; } } else { - _Impl._Invoke = _Function_inv_small<_Vt, _VtInvQuals, _Rx, _Noexcept, _Types...>; + _Impl._Invoke = _Function_inv<_Vt, _VtInvQuals, _Function_storage_mode::_Small, _Rx, _Noexcept, _Types...>; if constexpr (is_trivially_copyable_v<_Vt> && is_trivially_destructible_v<_Vt>) { if constexpr ((_Function_small_copy_size<_Vt>) > _Minimum_function_size) { @@ -1699,9 +1806,9 @@ public: return _Impl; } - template + template <_Impl_kind _Kind, class _Vt, class _VtInvQuals> _NODISCARD static const _Impl_t* _Create_impl_ptr() noexcept { - static constexpr _Impl_t _Impl = _Create_impl<_Vt, _VtInvQuals>(); + static constexpr _Impl_t _Impl = _Create_impl<_Kind, _Vt, _VtInvQuals>(); return &_Impl; } }; @@ -1932,16 +2039,20 @@ public: using _Vt = decay_t<_Fn>; static_assert(is_constructible_v<_Vt, _Fn>, "_Vt should be constructible from _Fn. " "(N4950 [func.wrap.move.ctor]/6)"); - - if constexpr (is_member_pointer_v<_Vt> || is_pointer_v<_Vt> || _Is_specialization_v<_Vt, move_only_function>) { - if (_Callable == nullptr) { - this->_Reset_to_null(); - return; + if constexpr (is_same_v<_Vt, function<_Signature...>>) { + this->template _Construct_with_old_fn<_Vt>(_STD forward<_Fn>(_Callable)); + } else { + if constexpr (is_member_pointer_v<_Vt> || is_pointer_v<_Vt> + || _Is_specialization_v<_Vt, move_only_function>) { + if (_Callable == nullptr) { + this->_Reset_to_null(); + return; + } } - } - using _VtInvQuals = _Call::template _VtInvQuals<_Vt>; - this->template _Construct_with_fn<_Vt, _VtInvQuals>(_STD forward<_Fn>(_Callable)); + using _VtInvQuals = _Call::template _VtInvQuals<_Vt>; + this->template _Construct_with_fn<_Vt, _VtInvQuals>(_STD forward<_Fn>(_Callable)); + } } template diff --git a/tests/std/test.lst b/tests/std/test.lst index 3d13fe2431..d13c3dd3a4 100644 --- a/tests/std/test.lst +++ b/tests/std/test.lst @@ -269,6 +269,7 @@ tests\GH_005315_destructor_tombstones tests\GH_005402_string_with_volatile_range tests\GH_005421_vector_algorithms_integer_class_type_iterator tests\GH_005472_do_not_overlap +tests\GH_005504_avoid_function_call_wrapping tests\GH_005546_containers_size_type_cast tests\GH_005553_regex_character_translation tests\GH_005768_pow_accuracy diff --git a/tests/std/tests/GH_005504_avoid_function_call_wrapping/env.lst b/tests/std/tests/GH_005504_avoid_function_call_wrapping/env.lst new file mode 100644 index 0000000000..642f530ffa --- /dev/null +++ b/tests/std/tests/GH_005504_avoid_function_call_wrapping/env.lst @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +RUNALL_INCLUDE ..\usual_latest_matrix.lst diff --git a/tests/std/tests/GH_005504_avoid_function_call_wrapping/test.cpp b/tests/std/tests/GH_005504_avoid_function_call_wrapping/test.cpp new file mode 100644 index 0000000000..44ec403e22 --- /dev/null +++ b/tests/std/tests/GH_005504_avoid_function_call_wrapping/test.cpp @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +#pragma warning(disable : 4324) // 'large_callable': structure was padded due to alignment specifier +#pragma warning(disable : 28251) // Inconsistent annotation for 'new': this instance has no annotations. + +int alloc_count = 0; +int dealloc_count = 0; + +size_t adjust_alloc_size(const size_t size) { + return size != 0 ? size : 1; +} + +void* check_alloc(void* const result) { + if (!result) { + throw bad_alloc{}; + } + return result; +} + +void* operator new(const size_t size) { + ++alloc_count; + return check_alloc(malloc(adjust_alloc_size(size))); +} + +void operator delete(void* const mem) noexcept { + ++dealloc_count; + free(mem); +} + +void* operator new(const size_t size, const align_val_t al) { + ++alloc_count; + return check_alloc(_aligned_malloc(adjust_alloc_size(size), static_cast(al))); +} + +void operator delete(void* const mem, align_val_t) noexcept { + ++dealloc_count; + _aligned_free(mem); +} + +struct alloc_checker { + explicit alloc_checker(const int expected_delta_) : expected_delta(expected_delta_) {} + alloc_checker(const alloc_checker&) = delete; + alloc_checker& operator=(const alloc_checker&) = delete; + + ~alloc_checker() { + assert(alloc_count - before == expected_delta); + assert(alloc_count == dealloc_count); + } + + const int expected_delta; + const int before = alloc_count; +}; + +struct copy_counter { + copy_counter() = default; + copy_counter(const copy_counter& other) : count(other.count + 1) {} + + int count = 0; +}; + +using fn_type = int(copy_counter); + +struct small_callable { + const int context = 42; + + int operator()(const copy_counter& counter) { + assert(context == 42); + return counter.count; + } +}; + +struct alignas(128) large_callable { + const int context = 1729; + + int operator()(const copy_counter& counter) { + assert((reinterpret_cast(this) & 0x7f) == 0); + assert(context == 1729); + return counter.count; + } +}; + +template +void test_plain_call(const int expected_copies) { + Wrapper fn{Callable{}}; + assert(fn(copy_counter{}) == expected_copies); +} + +template +void test_wrapped_call(const int expected_copies) { + InnerWrapper inner{Callable{}}; + OuterWrapper outer{move(inner)}; + assert(!inner); + assert(outer(copy_counter{}) == expected_copies); +} + +template +void check_call_null(Wrapper& wrapper, const bool throws) { + if (throws) { + try { + wrapper(copy_counter{}); + assert(false); // should not reach + } catch (const bad_function_call&) { + } + } else { + // UB that in our implementation tries to call doom function; we do not test that + } +} + +template +void test_plain_null(const bool throws) { + Wrapper fn{}; + assert(!fn); + check_call_null(fn, throws); +} + +template +void test_wrapped_null(const bool outer_is_null, const bool outer_throws) { + InnerWrapper inner{}; + OuterWrapper outer{move(inner)}; + assert(!inner); + assert(!outer == outer_is_null); + check_call_null(outer, outer_throws); +} + +int main() { + // Plain calls + alloc_checker{0}, test_plain_call, small_callable>(0); + alloc_checker{1}, test_plain_call, large_callable>(0); + alloc_checker{0}, test_plain_call, small_callable>(0); + alloc_checker{1}, test_plain_call, large_callable>(0); + + // Moves to the same + alloc_checker{0}, test_wrapped_call, function, small_callable>(0); + alloc_checker{1}, test_wrapped_call, function, large_callable>(0); + alloc_checker{0}, test_wrapped_call, move_only_function, small_callable>(0); + alloc_checker{1}, test_wrapped_call, move_only_function, large_callable>(0); + + // Moves from function to move_only_function +#ifdef _WIN64 + alloc_checker{0}, +#else + alloc_checker{1}, +#endif + test_wrapped_call, function, small_callable>(0); + alloc_checker{1}, test_wrapped_call, function, large_callable>(0); + + // nulls + alloc_checker{0}, test_plain_null>(true); + alloc_checker{0}, test_plain_null>(false); + + // wrapped nulls + alloc_checker{0}, test_wrapped_null, function>(true, true); + alloc_checker{0}, test_wrapped_null, move_only_function>(true, false); + alloc_checker{0}, test_wrapped_null, function>(false, true); +}