diff --git a/server/src/Tests.cpp b/server/src/Tests.cpp index 8103bb92b..5a03e0983 100644 --- a/server/src/Tests.cpp +++ b/server/src/Tests.cpp @@ -280,8 +280,10 @@ std::shared_ptr KTestObjectParser::structView(const std::vector size_t structEndOffset = offsetInBits + curStruct.size; size_t fieldIndex = 0; bool dirtyInitializedStruct = false; + bool isInitializedStruct = curStruct.subType == types::SubType::Struct; for (const auto &field: curStruct.fields) { bool dirtyInitializedField = false; + bool isInitializedField = true; size_t fieldLen = typesHandler.typeSize(field.type); size_t fieldStartOffset = offsetInBits + field.offset; size_t fieldEndOffset = fieldStartOffset + fieldLen; @@ -289,7 +291,7 @@ std::shared_ptr KTestObjectParser::structView(const std::vector prevFieldEndOffset = offsetInBits; } - auto dirtyCheck = [&](int i) { + auto dirtyCheck = [&](size_t i) { if (i >= byteArray.size()) { LOG_S(ERROR) << "Bad type size info: " << field.name << " index: " << fieldIndex; } else if (byteArray[i] == 0) { @@ -302,15 +304,16 @@ std::shared_ptr KTestObjectParser::structView(const std::vector if (prevFieldEndOffset < fieldStartOffset) { // check an alignment gap - for (int i = prevFieldEndOffset/8; i < fieldStartOffset/8; ++i) { + for (size_t i = prevFieldEndOffset / 8; i < fieldStartOffset / 8; ++i) { if (dirtyCheck(i)) { break; } } } - if (!dirtyInitializedField && curStruct.subType == types::SubType::Union) { - // check the rest of the union - for (int i = fieldEndOffset/8; i < structEndOffset/8; ++i) { + if (!dirtyInitializedField && (curStruct.subType == types::SubType::Union || + fieldIndex + 1 == curStruct.fields.size())) { + // check the rest of the union or the last field of the struct + for (size_t i = fieldEndOffset / 8; i < structEndOffset / 8; ++i) { if (dirtyCheck(i)) { break; } @@ -325,6 +328,7 @@ std::shared_ptr KTestObjectParser::structView(const std::vector PrinterUtils::getFieldAccess(name, field), objects, initReferences); dirtyInitializedField |= sv->isDirtyInit(); + isInitializedField = sv->isInitialized(); subViews.push_back(sv); } break; @@ -392,34 +396,38 @@ std::shared_ptr KTestObjectParser::structView(const std::vector throw NoSuchTypeException(message); } - if (!dirtyInitializedField && sizeOfFieldToInitUnion < fieldLen) { + if (!dirtyInitializedField && sizeOfFieldToInitUnion < fieldLen && + curStruct.subType == types::SubType::Union) { fieldIndexToInitUnion = fieldIndex; sizeOfFieldToInitUnion = fieldLen; - } else { - dirtyInitializedStruct = true; + isInitializedStruct = true; + dirtyInitializedStruct = false; + } + if (curStruct.subType == types::SubType::Struct) { + dirtyInitializedStruct |= dirtyInitializedField; + isInitializedStruct &= isInitializedField; } prevFieldEndOffset = fieldEndOffset; ++fieldIndex; } std::optional entryValue; - if (curStruct.subType == types::SubType::Union) { - if (fieldIndexToInitUnion == SIZE_MAX && !curStruct.name.empty()) { - // init by memory copy - entryValue = PrinterUtils::convertBytesToUnion( - curStruct.name, - arrayView(byteArray, lazyPointersArray, - types::Type::createSimpleTypeFromName("utbot_byte"), - curStruct.size, - offsetInBits, usage)->getEntryValue(nullptr)); - dirtyInitializedStruct = false; - } - if (fieldIndexToInitUnion != SIZE_MAX) { - dirtyInitializedStruct = false; - } + if (!isInitializedStruct && !curStruct.name.empty() && !anonymousField) { + // init by memory copy + entryValue = PrinterUtils::convertBytesToStruct( + curStruct.name, + arrayView(byteArray, lazyPointersArray, + types::Type::createSimpleTypeFromName("utbot_byte"), + curStruct.size, + offsetInBits, usage)->getEntryValue(nullptr)); + isInitializedStruct = true; + dirtyInitializedStruct = false; + } + if (!isInitializedStruct) { + dirtyInitializedStruct = false; } return std::make_shared(curStruct, subViews, entryValue, - anonymousField, dirtyInitializedStruct, fieldIndexToInitUnion); + anonymousField, isInitializedStruct, dirtyInitializedStruct, fieldIndexToInitUnion); } std::string KTestObjectParser::primitiveCharView(const types::Type &type, std::string value) { diff --git a/server/src/Tests.h b/server/src/Tests.h index c3abb4a9a..d23829a3b 100644 --- a/server/src/Tests.h +++ b/server/src/Tests.h @@ -238,15 +238,21 @@ namespace tests { std::vector> _subViews, std::optional _entryValue, bool _anonymous, + bool _isInit, bool _dirtyInit, size_t _fieldIndexToInitUnion) : AbstractValueView(std::move(_subViews)) , entryValue(std::move(_entryValue)) , structInfo(_structInfo) , anonymous(_anonymous) + , isInit(_isInit) , dirtyInit(_dirtyInit) , fieldIndexToInitUnion(_fieldIndexToInitUnion){} + bool isInitialized() const { + return isInit; + } + bool isDirtyInit() const { return dirtyInit; } @@ -317,6 +323,7 @@ namespace tests { std::optional entryValue; bool anonymous; + bool isInit; bool dirtyInit; size_t fieldIndexToInitUnion; }; diff --git a/server/src/utils/PrinterUtils.cpp b/server/src/utils/PrinterUtils.cpp index 4cdc41b63..78dcb3633 100644 --- a/server/src/utils/PrinterUtils.cpp +++ b/server/src/utils/PrinterUtils.cpp @@ -10,7 +10,7 @@ namespace PrinterUtils { std::string convertToBytesFunctionName(const std::string &typeName) { return StringUtils::stringFormat("from_bytes<%s>", typeName); } - std::string convertBytesToUnion(const std::string &typeName, const std::string &bytes) { + std::string convertBytesToStruct(const std::string &typeName, const std::string &bytes) { return StringUtils::stringFormat("%s(%s)", convertToBytesFunctionName(typeName), bytes); } diff --git a/server/src/utils/PrinterUtils.h b/server/src/utils/PrinterUtils.h index 94b4d70b5..26141c6b0 100644 --- a/server/src/utils/PrinterUtils.h +++ b/server/src/utils/PrinterUtils.h @@ -62,7 +62,7 @@ namespace PrinterUtils { std::string convertToBytesFunctionName(std::string const &typeName); - std::string convertBytesToUnion(const std::string &typeName, const std::string &bytes); + std::string convertBytesToStruct(const std::string &typeName, const std::string &bytes); std::string wrapperName(const std::string &declName, utbot::ProjectContext const &projectContext, diff --git a/server/src/visitors/AbstractValueViewVisitor.cpp b/server/src/visitors/AbstractValueViewVisitor.cpp index e8ef7653f..184ab1fe3 100644 --- a/server/src/visitors/AbstractValueViewVisitor.cpp +++ b/server/src/visitors/AbstractValueViewVisitor.cpp @@ -93,7 +93,7 @@ namespace visitor { auto subViews = view ? &view->getSubViews() : nullptr; bool oldFlag = inUnion; - inUnion = structInfo.subType == types::SubType::Union; + inUnion |= structInfo.subType == types::SubType::Union; for (int i = 0; i < structInfo.fields.size(); ++i) { auto const &field = structInfo.fields[i]; auto newName = PrinterUtils::getFieldAccess(name, field); @@ -101,6 +101,7 @@ namespace visitor { auto newAccess = PrinterUtils::getFieldAccess(access, field); visitAny(field.type, newName, newView, newAccess, depth + 1); } + inUnion = oldFlag; } void AbstractValueViewVisitor::visitEnum(const types::Type &type, diff --git a/server/src/visitors/VerboseAssertsReturnValueVisitor.cpp b/server/src/visitors/VerboseAssertsReturnValueVisitor.cpp index cff03d2a9..fe15a99a8 100644 --- a/server/src/visitors/VerboseAssertsReturnValueVisitor.cpp +++ b/server/src/visitors/VerboseAssertsReturnValueVisitor.cpp @@ -32,7 +32,7 @@ namespace visitor { auto signature = processExpect(type, gtestMacro, {PrinterUtils::fillVarName(access, PrinterUtils::EXPECTED), getDecorateActualVarName(access)}); signature = changeSignatureToNullCheck(signature, type, view, access); - printer->strFunctionCall(signature.name, signature.args); + printer->strFunctionCall(signature.name, signature.args, SCNL, std::nullopt, true, 0, std::nullopt, inUnion); } void VerboseAssertsReturnValueVisitor::visitPointer(const types::Type &type, diff --git a/server/test/framework/Server_Tests.cpp b/server/test/framework/Server_Tests.cpp index 53ebadd92..84cebd1f2 100644 --- a/server/test/framework/Server_Tests.cpp +++ b/server/test/framework/Server_Tests.cpp @@ -456,35 +456,6 @@ namespace { } } - void checkStructWithUnion_C(BaseTestGen &testGen) { - for (const auto &[methodName, methodDescription] : - testGen.tests.at(struct_with_union_c).methods) { - if (methodName == "struct_with_union_of_unnamed_type_as_return_type") { - checkTestCasePredicates( - methodDescription.testCases, - std::vector( - {[] (const tests::Tests::MethodTestCase& testCase) { - return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) < - stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && - testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\x99', -2.530171e-98}}}"; - }, - [] (const tests::Tests::MethodTestCase& testCase) { - return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) == - stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && - StringUtils::startsWith(testCase.returnValue.view->getEntryValue(nullptr), - "{from_bytes({"); - }, - [] (const tests::Tests::MethodTestCase& testCase) { - return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) > - stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && - testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\0', -2.530171e-98}}}"; - } - }), - methodName); - } - } - } - void checkInnerBasicFunctions_C(BaseTestGen &testGen) { for (const auto &[methodName, methodDescription] : testGen.tests.at(inner_basic_functions_c).methods) { @@ -2187,8 +2158,7 @@ namespace { auto testGen = FileTestGen(*request, writer.get(), TESTMODE); Status status = Server::TestsGenServiceImpl::ProcessBaseTestRequest(testGen, writer.get()); ASSERT_TRUE(status.ok()) << status.error_message(); - EXPECT_GE(testUtils::getNumberOfTests(testGen.tests), 3); - checkStructWithUnion_C(testGen); + EXPECT_GE(testUtils::getNumberOfTests(testGen.tests), 6); fs::path testsDirPath = getTestFilePath("tests"); @@ -2214,7 +2184,7 @@ namespace { auto resultsMap = coverageGenerator.getTestResultMap(); auto tests = coverageGenerator.getTestsToLaunch(); - StatusCountMap expectedStatusCountMap{ { testsgen::TEST_PASSED, 3 } }; + StatusCountMap expectedStatusCountMap{ { testsgen::TEST_PASSED, 6 } }; testUtils::checkStatuses(resultsMap, tests); } diff --git a/server/test/framework/Syntax_Tests.cpp b/server/test/framework/Syntax_Tests.cpp index 10463774e..7a5a5850b 100644 --- a/server/test/framework/Syntax_Tests.cpp +++ b/server/test/framework/Syntax_Tests.cpp @@ -1882,7 +1882,7 @@ namespace { return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) == stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && StringUtils::startsWith(testCase.returnValue.view->getEntryValue(nullptr), - "{from_bytes({"); + "{from_bytes({");; }, [] (const tests::Tests::MethodTestCase& testCase) { return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) > @@ -1903,18 +1903,18 @@ namespace { std::vector( {[] (const tests::Tests::MethodTestCase& testCase) { return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) < - stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && + stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\x99', -2.530171e-98}}}"; }, [] (const tests::Tests::MethodTestCase& testCase) { return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) == - stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && + stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && StringUtils::startsWith(testCase.returnValue.view->getEntryValue(nullptr), "{from_bytes({"); }, [] (const tests::Tests::MethodTestCase& testCase) { return stoi(testCase.paramValues[0].view->getEntryValue(nullptr)) > - stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && + stoi(testCase.paramValues[1].view->getEntryValue(nullptr)) && testCase.returnValue.view->getEntryValue(nullptr) == "{{{'\\0', -2.530171e-98}}}"; } }) diff --git a/server/test/suites/server/struct_with_union.c b/server/test/suites/server/struct_with_union.c index 74ca7342c..0c872e9c1 100644 --- a/server/test/suites/server/struct_with_union.c +++ b/server/test/suites/server/struct_with_union.c @@ -11,4 +11,17 @@ struct StructWithUnionOfUnnamedType struct_with_union_of_unnamed_type_as_return_ ans.un.ds.d = 1.0101; } return ans; +} + +struct StructWithAnonymousUnion struct_with_anonymous_union_as_return_type(int a, int b) { + struct StructWithAnonymousUnion ans; + if (a > b) { + ans.ptr = 0; + } else if (a < b) { + ans.x = 153; + } else { + ans.c = 'k'; + ans.d = 1.0101; + } + return ans; } \ No newline at end of file diff --git a/server/test/suites/server/struct_with_union.h b/server/test/suites/server/struct_with_union.h index 53669f4bf..7e5e87d69 100644 --- a/server/test/suites/server/struct_with_union.h +++ b/server/test/suites/server/struct_with_union.h @@ -12,6 +12,19 @@ struct StructWithUnionOfUnnamedType { } un; }; +struct StructWithAnonymousUnion { + union { + int x; + struct { + char c; + double d; + }; + long long *ptr; + }; +}; + struct StructWithUnionOfUnnamedType struct_with_union_of_unnamed_type_as_return_type(int a, int b); +struct StructWithAnonymousUnion struct_with_anonymous_union_as_return_type(int a, int b); + #endif // SIMPLE_TEST_PROJECT_STRUCT_WITH_UNION_H