diff --git a/src/network/connection.cpp b/src/network/connection.cpp index edada9f4..316526b9 100644 --- a/src/network/connection.cpp +++ b/src/network/connection.cpp @@ -23,10 +23,15 @@ #include "connection.h" +#include +#include +#include #include #include "GlobalVars.h" +#include "IPAddress.h" #include "logging/Logger.h" +#include "lwip/def.h" #include "packets.h" #define TIMEOUT 3000UL @@ -547,6 +552,20 @@ void Connection::searchForServer() { // receive incoming UDP packets [[maybe_unused]] int len = m_UDP.read(m_Packet, sizeof(m_Packet)); + if (mdnsResolver.isPacketMDNS(m_Packet)) { + auto mdnsResult = mdnsResolver.parseMDNSPacket(m_Packet); + if (!mdnsResult) { + continue; + } + + m_Logger.info( + "Found mDNS record for server with IP %s!", + mdnsResult->toString().c_str() + ); + connectTo(*mdnsResult, m_ServerPort); + return; + } + #ifdef DEBUG_NETWORK m_Logger.trace( "Received %d bytes from %s, port %d", @@ -566,16 +585,7 @@ void Connection::searchForServer() { continue; } - m_ServerHost = m_UDP.remoteIP(); - m_ServerPort = m_UDP.remotePort(); - m_LastPacketTimestamp = millis(); - m_Connected = true; - - m_FeatureFlagsRequestAttempts = 0; - m_ServerFeatures = ServerFeatures{}; - - statusManager.setStatus(SlimeVR::Status::SERVER_CONNECTING, false); - ledManager.off(); + connectTo(m_UDP.remoteIP(), m_UDP.remotePort()); m_Logger.debug( "Handshake successful, server is %s:%d", @@ -600,6 +610,19 @@ void Connection::searchForServer() { } } +void Connection::connectTo(const IPAddress& ip, uint16_t port) { + m_ServerHost = ip; + m_ServerPort = port; + m_LastPacketTimestamp = millis(); + m_Connected = true; + + m_FeatureFlagsRequestAttempts = 0; + m_ServerFeatures = ServerFeatures{}; + + statusManager.setStatus(SlimeVR::Status::SERVER_CONNECTING, false); + ledManager.off(); +} + void Connection::reset() { m_Connected = false; std::fill( @@ -625,6 +648,7 @@ void Connection::reset() { void Connection::update() { if (!m_Connected) { + mdnsResolver.searchForMDNS(); searchForServer(); return; } diff --git a/src/network/connection.h b/src/network/connection.h index dc7d26c1..fbcd0ff7 100644 --- a/src/network/connection.h +++ b/src/network/connection.h @@ -26,11 +26,14 @@ #include #include +#include #include #include "../configuration/SensorConfig.h" +#include "IPAddress.h" #include "featureflags.h" #include "globals.h" +#include "network/mdns.h" #include "packets.h" #include "quat.h" #include "sensors/sensor.h" @@ -59,6 +62,7 @@ class Connection { } void searchForServer(); + void connectTo(const IPAddress& ip, uint16_t port); void update(); void reset(); bool isConnected() const { return m_Connected; } @@ -240,6 +244,8 @@ class Connection { uint16_t m_BundlePacketInnerCount = 0; unsigned char m_Buf[8]; + + MDNSResolver mdnsResolver{m_UDP, m_Logger}; }; } // namespace SlimeVR::Network diff --git a/src/network/mdns.cpp b/src/network/mdns.cpp new file mode 100644 index 00000000..01dec0a7 --- /dev/null +++ b/src/network/mdns.cpp @@ -0,0 +1,160 @@ +#if ESP8266 +#include +#else +#include +#endif + +#include +#include +#include + +#include "IPAddress.h" +#include "WiFiUdp.h" +#include "logging/Logger.h" +#include "mdns.h" + +namespace SlimeVR::Network { + +MDNSResolver::MDNSResolver(WiFiUDP& udp, SlimeVR::Logging::Logger& logger) + : udp{udp} + , logger{logger} {} + +void MDNSResolver::searchForMDNS() { + if (millis() - lastMDNSQueryMillis + >= static_cast(MDNSSearchIntervalSeconds * 1000)) { + lastMDNSQueryMillis = millis(); + sendMDNSQuery(); + } +} + +bool MDNSResolver::isPacketMDNS(const uint8_t* buffer) { + const uint8_t packetHeader[] = {0x00, 0x00, 0x84, 0x00, 0x00, 0x01, 0x00, 0x01}; + + return memcmp(buffer, packetHeader, sizeof(packetHeader)) == 0; +} + +void MDNSResolver::sendMDNSQuery() { + logger.info("Searching for mDNS record"); + + uint8_t packet[64] = {0}; + uint16_t id = 0; + uint16_t flags = 0; + uint16_t questionCount = htons(1); + uint16_t answerCount = 0; + uint16_t authorityRRs = 0; + uint16_t additionalRRs = 0; + + memcpy(&packet[0], &id, sizeof(id)); + memcpy(&packet[2], &flags, sizeof(flags)); + memcpy(&packet[4], &questionCount, sizeof(questionCount)); + memcpy(&packet[6], &answerCount, sizeof(answerCount)); + memcpy(&packet[8], &authorityRRs, sizeof(authorityRRs)); + memcpy(&packet[10], &additionalRRs, sizeof(additionalRRs)); + + uint8_t* packetWrite = &packet[12]; + size_t hostNameLength = strlen(MDNSHostName); + *packetWrite = static_cast(hostNameLength); + packetWrite++; + memcpy(packetWrite, MDNSHostName, hostNameLength); + packetWrite += hostNameLength; + const char* tld = "local"; + size_t tldLength = strlen(tld); + *packetWrite = static_cast(tldLength); + packetWrite++; + memcpy(packetWrite, tld, tldLength); + packetWrite += tldLength; + *packetWrite = '\0'; + packetWrite++; + + uint16_t questionType = ntohs(1); // A record + uint16_t questionClass = ntohs(1); // IN class + memcpy(packetWrite, &questionType, sizeof(questionType)); + packetWrite += sizeof(questionType); + memcpy(packetWrite, &questionClass, sizeof(questionClass)); + packetWrite += sizeof(questionClass); + + IPAddress mdnsAddress{224, 0, 0, 251}; + const uint16_t mdnsPort = 5353; +#if ESP8266 + udp.beginPacketMulticast(mdnsAddress, mdnsPort, WiFi.localIP(), 255); +#else + udp.beginPacket(mdnsAddress, mdnsPort); +#endif + udp.write(packet, packetWrite - packet); + udp.endPacket(); +} + +std::optional MDNSResolver::parseMDNSPacket(const uint8_t* buffer) const { + const uint8_t* packetRead = buffer; + + auto readUint16 = [&]() { + uint16_t result = packetRead[0] << 8 | packetRead[1]; + packetRead += 2; + return result; + }; + + uint16_t transactionId = readUint16(); + uint16_t flags = readUint16(); + uint16_t questionCount = readUint16(); + uint16_t answerCount = readUint16(); + uint16_t authorityRRs = readUint16(); + uint16_t additionalRRs = readUint16(); + + if (transactionId != 0 || flags != 0x8400 || questionCount != 1 || answerCount != 1 + || authorityRRs != 0 || additionalRRs != 0) { + return {}; + } + + uint8_t hostNameSize = *packetRead; + packetRead++; + if (hostNameSize != strlen(MDNSHostName) + || memcmp(MDNSHostName, packetRead, hostNameSize) != 0) { + return {}; + } + packetRead += hostNameSize; + + uint8_t tldSize = *packetRead; + packetRead++; + if (tldSize != strlen("local") || memcmp("local", packetRead, tldSize) != 0) { + return {}; + } + packetRead += tldSize; + + if (*packetRead != '\0') { + return {}; + } + packetRead++; + + uint16_t recordType = readUint16(); + uint16_t recordClass = readUint16(); + if (recordType != 1 || recordClass != 1) { + return {}; + } + + uint8_t sectionLength = *packetRead; + while (sectionLength != 0) { + if ((sectionLength & 0xc0) == 0xc0) { + // Pointer to a previous section + packetRead++; + break; + } else { + packetRead += sectionLength + 1; + sectionLength = *packetRead; + } + } + packetRead++; + + // Skip record type, class and TTL + packetRead += 8; + + uint16_t ipLength = readUint16(); + if (ipLength != 4) { + return {}; + } + + return std::optional{ + IPAddress{packetRead[0], packetRead[1], packetRead[2], packetRead[3]} + }; +} + +} // namespace SlimeVR::Network diff --git a/src/network/mdns.h b/src/network/mdns.h new file mode 100644 index 00000000..4a96eb4b --- /dev/null +++ b/src/network/mdns.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +#include +#include + +#include "logging/Logger.h" + +namespace SlimeVR::Network { + +class MDNSResolver { +public: + explicit MDNSResolver(WiFiUDP& udp, SlimeVR::Logging::Logger& logger); + void searchForMDNS(); + static bool isPacketMDNS(const uint8_t* buffer); + std::optional parseMDNSPacket(const uint8_t* buffer) const; + +private: + constexpr static float MDNSSearchIntervalSeconds = 5; + const char* MDNSHostName = "slimevr-server"; + + void sendMDNSQuery(); + + WiFiUDP& udp; + SlimeVR::Logging::Logger& logger; + uint64_t lastMDNSQueryMillis = 0; +}; + +} // namespace SlimeVR::Network